Option to remove EOS at source in backtranslation dataset

Summary:
If we want our parallel data to have EOS at the end of source, we keep the EOS at the end of the generated source dialect backtranslation.
If we don't want our parallel data to have EOS at the end of source, we **remove** the EOS at the end of the generated source dialect backtranslation.

Note: we always want EOS at the end of our target / reference in parallel data so our model can learn to generate sentences of arbitrary length. So we make sure that the original target has an EOS before returning a batch of {generated src, original target}. If the original targets in the tgt dataset don't have an EOS, we append EOS to each tgt sample before collating.
We only do this for the purpose of collating a {generated src, original tgt} batch AFTER generating the backtranslations. We don't enforce any EOS before passing tgt to the tgt->src model that generates the backtranslation. Users of this dataset are expected to format the tgt dataset examples in whatever format the tgt->src model expects.
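
To illustrate, a minimal sketch of the per-sample logic described above (not the actual collater below); the function name `pair_up` and the toy token ids are made up for the example, and `eos` stands for the dictionary's EOS index:

```python
import torch

def pair_up(generated_source, original_tgt, eos, remove_eos_at_src):
    # Target side: always ends with EOS so the model learns when to stop.
    if original_tgt[-1] != eos:
        original_tgt = torch.cat([original_tgt, torch.LongTensor([eos])])
    # Source side: the generated backtranslation always ends with EOS;
    # strip it only if the parallel data should not carry EOS at source.
    if remove_eos_at_src and generated_source[-1] == eos:
        generated_source = generated_source[:-1]
    return {"source": generated_source, "target": original_tgt}

eos = 2
gen_src = torch.LongTensor([5, 6, 7, eos])   # backtranslation, ends with EOS
orig_tgt = torch.LongTensor([8, 9])          # monolingual target, no EOS
pair = pair_up(gen_src, orig_tgt, eos, remove_eos_at_src=True)
# pair["source"].tolist() == [5, 6, 7]; pair["target"].tolist() == [8, 9, 2]
```

With remove_eos_at_src=False, the generated source keeps its trailing EOS and only the target is EOS-terminated if it wasn't already.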

Reviewed By: jmp84

Differential Revision: D10157725

fbshipit-source-id: eb6a15f13c651f7c435b8db28103c9a8189845fb
Liezl Puzon 2018-10-03 18:18:03 -07:00 committed by Facebook Github Bot
parent fc677c945e
commit b9e29a4711
2 changed files with 43 additions and 4 deletions


@@ -5,6 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from fairseq import sequence_generator
@@ -19,6 +20,7 @@ class BacktranslationDataset(FairseqDataset):
backtranslation_model,
max_len_a,
max_len_b,
remove_eos_at_src=False,
generator_class=sequence_generator.SequenceGenerator,
**kwargs,
):
@@ -33,11 +35,16 @@ class BacktranslationDataset(FairseqDataset):
We use language_pair_dataset here to encapsulate the tgt_dataset
so we can re-use the LanguagePairDataset collater to format the
batches in the structure that SequenceGenerator expects.
Note: tgt_dataset samples should not have EOS at the end if
the tgt-src model expects an input without EOS. This dataset
does not enforce that; you should enforce it in preprocessing.
tgt_dict: tgt dictionary (typically a joint src/tgt BPE dictionary)
backtranslation_model: tgt-src model to use in the SequenceGenerator
to generate backtranslations from tgt batches
max_len_a, max_len_b: args passed into generate() function of
the backtranslation SequenceGenerator
remove_eos_at_src: whether we should remove EOS from the source
dialect text generated by the backtranslation model.
generator_class: which SequenceGenerator class to use for
backtranslation. Output of generate() should be the same format
as fairseq's SequenceGenerator
@@ -55,6 +62,8 @@ class BacktranslationDataset(FairseqDataset):
self.max_len_a = max_len_a
self.max_len_b = max_len_b
self.remove_eos_at_src = remove_eos_at_src
self.backtranslation_generator = generator_class(
models=[backtranslation_model],
tgt_dict=tgt_dict,
@@ -93,11 +102,32 @@ class BacktranslationDataset(FairseqDataset):
# {id: id, source: generated backtranslation, target: original tgt}
generated_samples = []
for input_sample, hypos in zip(samples, backtranslation_hypos):
eos = self.tgt_dataset.src_dict.eos()
# Append EOS to the tgt sentence if it does not have an EOS
# This is the case if the samples in monolingual tgt_dataset don't
# have an EOS appended to the end of each sentence.
original_tgt = input_sample["source"]
if original_tgt[-1] != eos:
original_tgt = torch.cat([original_tgt, torch.LongTensor([eos])])
# The generated source dialect backtranslation will have an EOS.
# If we want our parallel data source to not have an EOS, we will
# have to remove it.
generated_source = hypos[0]["tokens"] # first hypo is best hypo
if self.remove_eos_at_src:
assert generated_source[-1] == eos, (
f"Expected generated backtranslation to have eos (id: "
f"{eos}) at end, but instead found token id "
f"{generated_source[-1]} at end."
)
generated_source = generated_source[:-1]
generated_samples.append(
{
"id": input_sample["id"],
"source": hypos[0]["tokens"], # first hypo is best hypo
"target": input_sample["source"],
"source": generated_source,
"target": original_tgt,
}
)
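
For context, a hypothetical construction of the dataset with the new flag. The module path, the keyword names tgt_dataset and tgt_dict, the placeholder objects, and the max_len / beam values are assumptions based on the surrounding code and the test below, not taken verbatim from this diff:

```python
from fairseq import sequence_generator
# Module path assumed; the placeholder objects below (monolingual_tgt_dataset,
# joint_dict, tgt_to_src_model) stand in for objects built elsewhere.
from fairseq.data.backtranslation_dataset import BacktranslationDataset

bt_dataset = BacktranslationDataset(
    tgt_dataset=monolingual_tgt_dataset,     # monolingual target-side dataset
    tgt_dict=joint_dict,                     # joint src/tgt BPE dictionary
    backtranslation_model=tgt_to_src_model,  # tgt->src model used for generation
    max_len_a=0,                             # example values, not prescribed here
    max_len_b=200,
    remove_eos_at_src=True,                  # strip EOS from generated sources
    generator_class=sequence_generator.SequenceGenerator,
    beam_size=2,                             # extra kwargs as in the test below
)
```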


@@ -23,7 +23,7 @@ class TestBacktranslationDataset(unittest.TestCase):
self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)
def test_backtranslation_dataset(self):
def _backtranslation_dataset_helper(self, remove_eos_at_src):
"""
SequenceGenerator kwargs are the same as the defaults from fairseq/options.py
"""
@@ -36,6 +36,7 @@ class TestBacktranslationDataset(unittest.TestCase):
beam_size=2,
unk_penalty=0,
sampling=False,
remove_eos_at_src=remove_eos_at_src,
generator_class=sequence_generator.SequenceGenerator,
)
dataloader = torch.utils.data.DataLoader(
@@ -50,6 +51,8 @@ class TestBacktranslationDataset(unittest.TestCase):
# Note that we sort by src_lengths and add left padding, so actually
# ids will look like: [1, 0]
expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
if remove_eos_at_src:
expected_src = expected_src[:, :-1]
expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
tgt_tokens = backtranslation_batch_result["target"]
@@ -57,6 +60,12 @@ class TestBacktranslationDataset(unittest.TestCase):
self.assertTensorEqual(expected_src, generated_src)
self.assertTensorEqual(expected_tgt, tgt_tokens)
def test_backtranslation_dataset_no_eos_at_src(self):
self._backtranslation_dataset_helper(remove_eos_at_src=True)
def test_backtranslation_dataset_with_eos_at_src(self):
self._backtranslation_dataset_helper(remove_eos_at_src=False)
def assertTensorEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertEqual(t1.ne(t2).long().sum(), 0)
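
One detail worth spelling out about the expected_src trimming in the test: because the collater left-pads, EOS sits in the last column of every row, so dropping that column is exactly the batch expected when remove_eos_at_src=True. A small standalone check (token ids are made-up stand-ins for w1, w2, eos, pad):

```python
import torch

w1, w2, eos, pad = 4, 5, 2, 1  # made-up ids mirroring the test's symbols

expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
# Left padding puts EOS in the last column of every row, so slicing it off
# is equivalent to removing EOS from each generated source before collating.
assert expected_src[:, :-1].tolist() == [[w1, w2, w1], [pad, pad, w1]]
```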