Mirror of https://github.com/facebookresearch/fairseq.git, synced 2024-10-26 17:32:57 +03:00
c2e8904b60
Summary:
# Before submitting

- [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests? too many of them actually ^^

## What does this PR do?

This is a rewrite of https://github.com/fairinternal/fairseq-py/issues/1538, following the discussion there and taking into account the proposal in https://github.com/fairinternal/fairseq-py/issues/1560 from Myle. It brings online backtranslation to fairseq and adds a RobertaEncDec model. RobertaEncDec can be built from a pretrained Roberta model, enabling transfer learning, which is crucial for backtranslation.

## PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues, there's a high chance it will not be merged.

## Did you have fun?

Make sure you had fun coding!

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1614
Reviewed By: myleott
Differential Revision: D27157296
Pulled By: gwenzek
fbshipit-source-id: 43020bc27743419bd4b138716165bf5764117c21
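For context, a minimal sketch of how the new task can be constructed programmatically, following the pattern used in the test file below. `data_dir` and the option values are illustrative placeholders, not part of the PR description:

```python
# Sketch only: mirrors how the tests below build the task. "data_dir" is a
# placeholder and must contain a fairseq dictionary (dict.txt).
import fairseq.options
import fairseq.tasks.online_backtranslation as obt

args = fairseq.options.get_args(
    "data_dir",
    task="online_backtranslation",
    arch="transformer",
    mono_langs="en,ro",        # comma-separated monolingual languages
    valid_lang_pairs="en-ro",  # parallel pairs used for validation
    tokens_per_sample=256,
)
task = obt.OnlineBackTranslationTask.setup_task(args)
model = task.build_model(task.args)  # building the model finalizes the dictionary
```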
207 lines
7.5 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest
from pathlib import Path
from typing import Any, Dict, Sequence

import fairseq.data.indexed_dataset as indexed_dataset
import fairseq.options
import fairseq.tasks.online_backtranslation as obt
import torch
from tests import utils


def mk_sample(tokens: Sequence[int], batch_size: int = 2) -> Dict[str, Any]:
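    """Build a minimal fairseq-style batch from one token sequence.

    The same sequence is stacked `batch_size` times; the dict layout
    (net_input.src_tokens / prev_output_tokens / src_lengths, plus target)
    mirrors what fairseq's collaters produce.
    """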
    batch = torch.stack([torch.tensor(tokens, dtype=torch.long)] * batch_size)
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor([len(tokens)] * batch_size, dtype=torch.long),
        },
        "target": batch[:, 1:],
    }
    return sample


def mk_dataset(num_samples: int, max_len: int, output: Path):
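    """Write `num_samples` random sequences as a binary indexed dataset at `output`.

    The first token of every sequence is pinned to index 0 (bos in fairseq's
    default dictionary layout) so the task can later swap in a language token.
    """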
    output.parent.mkdir(exist_ok=True)
    idx = indexed_dataset.IndexedDatasetBuilder(str(output))
    data = torch.randint(5, 100, (num_samples, max_len))
    lengths = torch.randint(3, max_len, (num_samples,))
    for d, l in zip(data, lengths):
        d[0] = 0
        idx.add_item(d[:l])
    idx.finalize(output.with_suffix(".idx"))
    assert output.exists()
    assert output.with_suffix(".idx").exists()


class OnlineBacktranslationTest(unittest.TestCase):

    tmp_dir = Path(tempfile.mkdtemp(suffix="OnlineBacktranslationTest"))

    @classmethod
    def obt_task(
        cls, languages: Sequence[str], data: Path = None, language_mapping: str = None
    ):
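        """Set up a small online_backtranslation task and model over `languages`.

        Tests that never load a dataset can leave `data` unset; only the shared
        dictionary is written in that case.
        """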
        dict_path = cls.tmp_dir / "dict.txt"
        if not dict_path.exists():
            dictionary = utils.dummy_dictionary(100)
            dictionary.save(str(dict_path))

        if data is not None:
            (data / "dict.txt").write_text(dict_path.read_text())
        else:
            data = cls.tmp_dir
        assert len(languages) >= 2

        kwargs = {
            "arch": "transformer",
            # --max-sentences=1 for better predictability of batches
            "max_sentences": 1,
            # Use small, distinct dimensions so shape mismatches are easy to spot.
            "encoder_layers": 3,
            "encoder_embed_dim": 12,
            "encoder_ffn_embed_dim": 14,
            "encoder_attention_heads": 4,
            "decoder_layers": 3,
            "decoder_embed_dim": 12,
            "decoder_output_dim": 12,
            "decoder_ffn_embed_dim": 14,
            "decoder_attention_heads": 4,
            # Disable dropout so we have comparable tests.
            "dropout": 0,
            "attention_dropout": 0,
            "activation_dropout": 0,
            "encoder_layerdrop": 0,
        }

        args = fairseq.options.get_args(
            data,
            task="online_backtranslation",
            mono_langs=",".join(languages),
            valid_lang_pairs=f"{languages[0]}-{languages[1]}",
            tokens_per_sample=256,
            language_mapping=language_mapping,
            **kwargs,
        )
        task = obt.OnlineBackTranslationTask.setup_task(args)
        # we need to build the model to have the correct dictionary
        model = task.build_model(task.args)
        return task, model

    def tmp_path(self, test_case: str) -> Path:
        return Path(tempfile.mkdtemp(test_case, dir=self.tmp_dir))

    def test_lang_tokens(self):
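        # Each language is represented by a reserved token (lang code wrapped in
        # underscores) in the shared dictionary; it doubles as the target-side bos.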
        task, model = self.obt_task(["en", "ro", "zh"])
        assert obt._lang_token("en") in task.dictionary
        assert obt._lang_token("ro") in task.dictionary
        assert obt._lang_token("zh") in task.dictionary

        en_bos = obt._lang_token_index(task.common_dict, "en")
        assert "en" == task.common_dict[en_bos].strip("_")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")
        assert "zh" == task.common_dict[zh_bos].strip("_")
        zh_sample = mk_sample([zh_bos, 16, 14, 12, 10])

        # we expect to receive the bos token for translation
        assert task.get_bos_token_from_sample(zh_sample) == en_bos

    def test_backtranslate_sample(self):
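        # backtranslate_sample rewires the sample in place: the original zh
        # tokens become the target, and the generated en becomes the source.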
        task, model = self.obt_task(["en", "ro", "zh"])

        en_bos = obt._lang_token_index(task.common_dict, "en")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")
        sample = mk_sample([zh_bos, 16, 14, 12, 10])

        task.backtranslate_sample(sample, "zh", "en")
        target_zh = list(sample["target"][0])
        assert target_zh == [16, 14, 12, 10]  # original zh sentence
        generated_en = sample["net_input"]["src_tokens"][0]
        assert generated_en[0] == en_bos

    def test_train_dataset(self):
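        # Two monolingual datasets of different sizes; the task exposes a
        # BT (backtranslation) and a DENOISE view of each language.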
        data = self.tmp_path("test_train_dataset")
        mk_dataset(20, 10, data / "en" / "train.bin")
        mk_dataset(10, 10, data / "zh" / "train.bin")
        task, model = self.obt_task(["en", "zh"], data)
        task.load_dataset("train")

        en_bos = obt._lang_token_index(task.common_dict, "en")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")

        train = task.datasets["train"]
        train.ordered_indices()
        train.prefetch([0, 19])
        sample_0 = train[0]
        sample_19 = train[19]
        self.assertEqual(
            set(sample_0.keys()), {"en-BT", "en-DENOISE", "zh-BT", "zh-DENOISE"}
        )
        for sample in (sample_0, sample_19):
            self.assertEqual(sample["en-BT"]["source"][0], en_bos)
            # BT target isn't ready to look at.
            self.assertEqual(sample["en-DENOISE"]["source"][0], en_bos)
            # TODO: what could we check on the target side?

        for i in range(10):
            # The zh dataset is shorter, so it wraps around the en dataset.
            train.prefetch([i, i + 10])
            self.assertEqual(
                list(train[i]["zh-DENOISE"]["source"]),
                list(train[i + 10]["zh-DENOISE"]["source"]),
            )
            self.assertEqual(train[i]["zh-DENOISE"]["source"][0].item(), zh_bos)

        # Sorted by increasing length.
        self.assertLess(
            len(sample_0["en-BT"]["source"]), len(sample_19["en-BT"]["source"])
        )

    def test_valid_dataset(self):
        data = self.tmp_path("test_valid_dataset")
        mk_dataset(10, 21, data / "valid.en-zh.en.bin")
        mk_dataset(10, 21, data / "valid.en-zh.zh.bin")

        task, model = self.obt_task(["en", "zh"], data)
        valid = task.load_dataset("valid")
        en_bos = obt._lang_token_index(task.common_dict, "en")

        assert valid is not None
        valid.prefetch(range(10))
        sample_0 = valid[0]
        sample_9 = valid[9]
        self.assertEqual(sample_0["id"], 0)
        self.assertEqual(sample_9["id"], 9)
        self.assertEqual(sample_0["source"][0], en_bos)
        self.assertEqual(sample_9["source"][0], en_bos)
        # TODO: could we test the target side?

    def assertFnMatch(self, fn, values):
        for x, y in values.items():
            fn_x = fn(x)
            self.assertEqual(fn_x, y, f"Fn has wrong value: fn({x}) = {fn_x} != {y}")

    def test_piecewise_linear_fn(self):
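        # PiecewiseLinearFn.from_string parses "x0:y0,x1:y1,..." schedules (or a
        # single constant), interpolating linearly between the given points and
        # clamping to the nearest endpoint outside the range.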
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("1.0"), {0: 1, 100: 1, 500: 1, 1000: 1}
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:1,1000:0"),
            {0: 1, 500: 0.5, 1000: 0, 2000: 0},
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:0,1000:1"),
            {0: 0, 500: 0.5, 1000: 1, 2000: 1},
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:0,1000:1,2000:0"),
            {0: 0, 500: 0.5, 1000: 1, 1500: 0.5, 2000: 0, 3000: 0},
        )