fairseq/tests/tasks/test_multilingual_denoising.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest
from tempfile import TemporaryDirectory

from fairseq import options
from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.tasks.multilingual_denoising import MultilingualDenoisingTask
from tests.utils import build_vocab, make_data


class TestMultilingualDenoising(unittest.TestCase):
    def test_multilingual_denoising(self):
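        """Round-trip check for the multilingual_denoising task: every
        source position that was replaced by <mask> must still carry the
        original token in the corresponding target sequence."""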
        with TemporaryDirectory() as dirname:
            # prep input file
            lang_dir = os.path.join(dirname, "en")
            os.mkdir(lang_dir)
            raw_file = os.path.join(lang_dir, "raw")
            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)
            # binarize
            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
            split = "train"
            bin_file = os.path.join(lang_dir, split)
            dataset_impl = "mmap"
            FileBinarizer.multiprocess_dataset(
                input_file=raw_file,
                binarizer=binarizer,
                dataset_impl=dataset_impl,
                vocab_size=len(vocab),
                output_prefix=bin_file,
            )
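            # the "mmap" impl materializes the dataset as
            # <output_prefix>.bin/.idx; task.load_dataset() below picks the
            # split up by name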
            # setup task
            train_args = options.parse_args_and_arch(
                options.get_training_parser(),
                [
                    "--task",
                    "multilingual_denoising",
                    "--arch",
                    "bart_base",
                    "--seed",
                    "42",
                    "--mask-length",
                    "word",
                    "--permute-sentences",
                    "1",
                    "--rotate",
                    "0",
                    "--replace-length",
                    "-1",
                    "--mask",
                    "0.2",
                    dirname,
                ],
            )
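            # the task consumes an OmegaConf config, so convert the parsed
            # argparse namespace before instantiating it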
            cfg = convert_namespace_to_omegaconf(train_args)
            task = MultilingualDenoisingTask(cfg.task, binarizer.dict)
            # load datasets
            original_dataset = task._load_dataset_split(bin_file, 1, False)
            task.load_dataset(split)
            masked_dataset = task.dataset(split)

            iterator = task.get_batch_iterator(
                dataset=masked_dataset,
                max_tokens=65_536,
                max_positions=4_096,
            ).next_epoch_itr(shuffle=False)
            mask_index = task.source_dictionary.index("<mask>")
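            # for each sample, select the source positions that were replaced
            # by <mask> and check the target still holds the original tokens
            # at exactly those positions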
            for batch in iterator:
                # batch is a dict, so iterate over the ids tensor to cover
                # every sample in the batch (len(batch) would only count keys)
                for sample in range(len(batch["id"])):
                    net_input = batch["net_input"]
                    masked_src_tokens = net_input["src_tokens"][sample]
                    masked_src_length = net_input["src_lengths"][sample]
                    masked_tgt_tokens = batch["target"][sample]

                    sample_id = batch["id"][sample]
                    original_tokens = original_dataset[sample_id]
                    original_tokens = original_tokens.masked_select(
                        masked_src_tokens[:masked_src_length] == mask_index
                    )
                    masked_tokens = masked_tgt_tokens.masked_select(
                        masked_src_tokens == mask_index
                    )

                    assert masked_tokens.equal(original_tokens)


if __name__ == "__main__":
    unittest.main()