# Mirror of https://github.com/facebookresearch/fairseq.git
# Synced 2024-09-21 14:17:25 +03:00, commit a48f235636 (124 lines, 4.0 KiB, Python).
# Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1357
# Reviewed By: alexeib  Differential Revision: D24377772
# fbshipit-source-id: 51581af041d42d62166b33a35a1a4228b1a76f0c
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
|
|
|
|
import tests.utils as test_utils
|
|
import torch
|
|
from fairseq.data import (
|
|
BacktranslationDataset,
|
|
LanguagePairDataset,
|
|
TransformEosDataset,
|
|
)
|
|
from fairseq.sequence_generator import SequenceGenerator
|
|
|
|
|
|
class TestBacktranslationDataset(unittest.TestCase):
|
|
def setUp(self):
|
|
(
|
|
self.tgt_dict,
|
|
self.w1,
|
|
self.w2,
|
|
self.src_tokens,
|
|
self.src_lengths,
|
|
self.model,
|
|
) = test_utils.sequence_generator_setup()
|
|
|
|
dummy_src_samples = self.src_tokens
|
|
|
|
self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)
|
|
self.cuda = torch.cuda.is_available()
|
|
|
|
def _backtranslation_dataset_helper(
|
|
self,
|
|
remove_eos_from_input_src,
|
|
remove_eos_from_output_src,
|
|
):
|
|
tgt_dataset = LanguagePairDataset(
|
|
src=self.tgt_dataset,
|
|
src_sizes=self.tgt_dataset.sizes,
|
|
src_dict=self.tgt_dict,
|
|
tgt=None,
|
|
tgt_sizes=None,
|
|
tgt_dict=None,
|
|
)
|
|
|
|
generator = SequenceGenerator(
|
|
[self.model],
|
|
tgt_dict=self.tgt_dict,
|
|
max_len_a=0,
|
|
max_len_b=200,
|
|
beam_size=2,
|
|
unk_penalty=0,
|
|
)
|
|
|
|
backtranslation_dataset = BacktranslationDataset(
|
|
tgt_dataset=TransformEosDataset(
|
|
dataset=tgt_dataset,
|
|
eos=self.tgt_dict.eos(),
|
|
# remove eos from the input src
|
|
remove_eos_from_src=remove_eos_from_input_src,
|
|
),
|
|
src_dict=self.tgt_dict,
|
|
backtranslation_fn=(
|
|
lambda sample: generator.generate([self.model], sample)
|
|
),
|
|
output_collater=TransformEosDataset(
|
|
dataset=tgt_dataset,
|
|
eos=self.tgt_dict.eos(),
|
|
# if we remove eos from the input src, then we need to add it
|
|
# back to the output tgt
|
|
append_eos_to_tgt=remove_eos_from_input_src,
|
|
remove_eos_from_src=remove_eos_from_output_src,
|
|
).collater,
|
|
cuda=self.cuda,
|
|
)
|
|
dataloader = torch.utils.data.DataLoader(
|
|
backtranslation_dataset,
|
|
batch_size=2,
|
|
collate_fn=backtranslation_dataset.collater,
|
|
)
|
|
backtranslation_batch_result = next(iter(dataloader))
|
|
|
|
eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2
|
|
|
|
# Note that we sort by src_lengths and add left padding, so actually
|
|
# ids will look like: [1, 0]
|
|
expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
|
|
if remove_eos_from_output_src:
|
|
expected_src = expected_src[:, :-1]
|
|
expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
|
|
generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
|
|
tgt_tokens = backtranslation_batch_result["target"]
|
|
|
|
self.assertTensorEqual(expected_src, generated_src)
|
|
self.assertTensorEqual(expected_tgt, tgt_tokens)
|
|
|
|
def test_backtranslation_dataset_no_eos_in_output_src(self):
|
|
self._backtranslation_dataset_helper(
|
|
remove_eos_from_input_src=False,
|
|
remove_eos_from_output_src=True,
|
|
)
|
|
|
|
def test_backtranslation_dataset_with_eos_in_output_src(self):
|
|
self._backtranslation_dataset_helper(
|
|
remove_eos_from_input_src=False,
|
|
remove_eos_from_output_src=False,
|
|
)
|
|
|
|
def test_backtranslation_dataset_no_eos_in_input_src(self):
|
|
self._backtranslation_dataset_helper(
|
|
remove_eos_from_input_src=True,
|
|
remove_eos_from_output_src=False,
|
|
)
|
|
|
|
def assertTensorEqual(self, t1, t2):
|
|
self.assertEqual(t1.size(), t2.size(), "size mismatch")
|
|
self.assertEqual(t1.ne(t2).long().sum(), 0)
|
|
|
|
|
|
if __name__ == "__main__":
    # Discover and run every test case in this module when executed directly.
    unittest.main()