Enable translation_multi_simple_epoch to have different source and target dictionaries

Summary: In the past, we always used a shared dictionary for multilingual experiments. This diff re-enables different dictionaries for source and target languages by relaxing the assertion criteria, and reverts to using the specific source and target languages to return source_dict and target_dict.

Reviewed By: chtran

Differential Revision: D24637682

fbshipit-source-id: a982e4f1e48395cc5bf10dc03b98fbe970062f8d
Yuqing Tang 2020-10-30 18:23:14 -07:00 committed by Facebook GitHub Bot
parent a4356b1da2
commit de859692ff
2 changed files with 75 additions and 11 deletions
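For context, here is a minimal sketch of what the relaxed dictionary check permits. It is illustrative only and not part of the commit; the language codes "in"/"out" and the toy vocabularies are made up, while Dictionary and the check_dicts classmethod added below are real fairseq APIs.

    from fairseq.data import Dictionary
    from fairseq.tasks.translation_multi_simple_epoch import (
        TranslationMultiSimpleEpochTask,
    )

    # Two deliberately different toy vocabularies.
    src_dict, tgt_dict = Dictionary(), Dictionary()
    for w in ["hello", "world"]:
        src_dict.add_symbol(w)
    for w in ["bonjour", "monde"]:
        tgt_dict.add_symbol(w)

    dicts = {"in": src_dict, "out": tgt_dict}

    # Previously, setup_task asserted that every entry of `dicts` was identical.
    # With this diff, only languages on the same side must share a dictionary,
    # so a source-only / target-only split like this one passes the check.
    TranslationMultiSimpleEpochTask.check_dicts(
        dicts, source_langs=["in"], target_langs=["out"]
    )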


@@ -96,25 +96,35 @@ class TranslationMultiSimpleEpochTask(LegacyFairseqTask):
        # models.build_model(). This allows multitask-type sub-classes to
        # build models other than the input lang_pairs
self.model_lang_pairs = self.lang_pairs
self.source_langs = [d.split("-")[0] for d in self.lang_pairs]
self.target_langs = [d.split("-")[1] for d in self.lang_pairs]
self.check_dicts(self.dicts, self.source_langs, self.target_langs)
self.sampling_method = SamplingMethod.build_sampler(args, self)
self.data_manager = MultilingualDatasetManager.setup_data_manager(
args, self.lang_pairs, langs, dicts, self.sampling_method
)
@classmethod
def check_dicts(cls, dicts, source_langs, target_langs):
src_dict = dicts[source_langs[0]]
tgt_dict = dicts[target_langs[0]]
        for src_lang in source_langs:
            assert src_dict == dicts[src_lang], (
                "Different dictionaries are specified for different source languages; "
                "TranslationMultiSimpleEpochTask only supports one shared dictionary "
                "across all source languages"
            )
        for tgt_lang in target_langs:
            assert tgt_dict == dicts[tgt_lang], (
                "Different dictionaries are specified for different target languages; "
                "TranslationMultiSimpleEpochTask only supports one shared dictionary "
                "across all target languages"
            )
@classmethod
def setup_task(cls, args, **kwargs):
langs, dicts, training = MultilingualDatasetManager.prepare(
cls.load_dictionary, args, **kwargs
)
dict0 = None
for _, lang_dict in dicts.items():
if dict0 is None:
dict0 = lang_dict
else:
                assert dict0 == lang_dict, (
                    "Different dictionaries are specified for different languages; "
                    "TranslationMultiSimpleEpochTask only supports one shared "
                    "dictionary across all languages"
                )
return cls(args, langs, dicts, training)
def has_sharded_data(self, split):
@@ -249,11 +259,11 @@ class TranslationMultiSimpleEpochTask(LegacyFairseqTask):
@property
def source_dictionary(self):
return next(iter(self.dicts.values()))
return self.dicts[self.source_langs[0]]
@property
def target_dictionary(self):
return next(iter(self.dicts.values()))
return self.dicts[self.target_langs[0]]
def create_batch_sampler_func(
self,


@@ -425,6 +425,60 @@ class TestTranslation(unittest.TestCase):
+ dec_ltok_flag,
)
def test_translation_multi_simple_epoch_dicts(self):
        # test translation_multi_simple_epoch with separate source/target dictionaries and encoder/decoder lang tokens
with contextlib.redirect_stdout(StringIO()):
enc_ltok_flag = ["--encoder-langtok", "src"]
dec_ltok_flag = ["--decoder-langtok"]
with tempfile.TemporaryDirectory(
"test_translation_multi_simple_epoch_dict"
) as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(
data_dir, extra_flags=[]
)
train_translation_model(
data_dir,
arch="transformer",
task="translation_multi_simple_epoch",
extra_flags=[
"--encoder-layers",
"2",
"--decoder-layers",
"2",
"--encoder-embed-dim",
"8",
"--decoder-embed-dim",
"8",
"--sampling-method",
"temperature",
"--sampling-temperature",
"1.5",
"--virtual-epoch-size",
"1000",
]
+ enc_ltok_flag
+ dec_ltok_flag,
lang_flags=["--lang-pairs", "in-out"],
run_validation=True,
extra_valid_flags=enc_ltok_flag + dec_ltok_flag,
)
generate_main(
data_dir,
extra_flags=[
"--task",
"translation_multi_simple_epoch",
"--lang-pairs",
"in-out",
"--source-lang",
"in",
"--target-lang",
"out",
]
+ enc_ltok_flag
+ dec_ltok_flag,
)
def test_transformer_cross_self_attention(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory(