fairseq/tests/test_backtranslation_dataset.py
Myle Ott 864b89d044 Online backtranslation module
Co-authored-by: liezl200 <lie@fb.com>
2018-09-25 17:36:43 -04:00

71 lines
2.5 KiB
Python

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import argparse
import unittest
import tests.utils as test_utils
import torch
from fairseq.data.backtranslation_dataset import BacktranslationDataset
class TestBacktranslationDataset(unittest.TestCase):
    """Check the batch produced by BacktranslationDataset on the toy generator fixture."""

    def setUp(self):
        """Build the dictionary, model, backtranslation args and target dataset."""
        (
            self.tgt_dict,
            self.w1,
            self.w2,
            self.src_tokens,
            self.src_lengths,
            self.model,
        ) = test_utils.sequence_generator_setup()

        # Same as defaults from fairseq/options.py
        self.backtranslation_args = argparse.Namespace(
            backtranslation_unkpen=0,
            backtranslation_sampling=False,
            backtranslation_max_len_a=0,
            backtranslation_max_len_b=200,
            backtranslation_beam=2,
        )

        # The fixture's source tokens are reused as the samples of the
        # target-side dataset handed to BacktranslationDataset.
        self.tgt_dataset = test_utils.TestDataset(data=self.src_tokens)

    def test_backtranslation_dataset(self):
        """Collate the first batch of two samples and compare it to fixed tensors."""
        dataset = BacktranslationDataset(
            args=self.backtranslation_args,
            tgt_dataset=self.tgt_dataset,
            tgt_dict=self.tgt_dict,
            backtranslation_model=self.model,
        )
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=2,
            collate_fn=dataset.collater,
        )
        first_batch = next(iter(loader))

        eos = self.tgt_dict.eos()
        pad = self.tgt_dict.pad()
        w1 = self.w1
        w2 = self.w2

        # Note that we sort by src_lengths and add left padding, so actually
        # ids will look like: [1, 0]
        expected_src = torch.LongTensor([
            [w1, w2, w1, eos],
            [pad, pad, w1, eos],
        ])
        expected_tgt = torch.LongTensor([
            [w1, w2, eos],
            [w1, w2, eos],
        ])

        # Pull both fields out of the batch before asserting anything.
        generated_src = first_batch["net_input"]["src_tokens"]
        generated_tgt = first_batch["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, generated_tgt)

    def assertTensorEqual(self, t1, t2):
        """Assert that ``t1`` and ``t2`` have the same size and no differing elements."""
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        # Count the positions where the two tensors disagree; it must be zero.
        mismatch_count = t1.ne(t2).long().sum()
        self.assertEqual(mismatch_count, 0)
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()