Fix tests + style nits + Python 3.5 compat

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/336

Differential Revision: D12876709

Pulled By: myleott

fbshipit-source-id: a31536e2eb93f752600b9940c28e9b9fcefc8b86
Authored by Myle Ott on 2018-11-01 01:23:43 -07:00; committed by Facebook Github Bot
parent f3a0939eed
commit 5bbd148e6e
5 changed files with 33 additions and 27 deletions
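Note on the "Python 3.5 compat" part: f-strings (PEP 498) are a Python 3.6+ feature and are a SyntaxError on 3.5, so the hunks below rewrite them as str.format() calls. A minimal sketch of the rewrite pattern, with made-up values:

    eos = 2
    token_id = 7
    # Python 3.6+ only; SyntaxError on 3.5:
    #     msg = f"Expected eos (id: {eos}) at end, but found token id {token_id}"
    # 3.5-compatible equivalent used throughout this commit:
    msg = "Expected eos (id: {eos}) at end, but found token id {token_id}".format(
        eos=eos, token_id=token_id,
    )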

fairseq/data/__init__.py

@@ -7,10 +7,10 @@
 from .dictionary import Dictionary, TruncatedDictionary
 from .fairseq_dataset import FairseqDataset
-from .concat_dataset import ConcatDataset
-from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
 from .append_eos_dataset import AppendEosDataset
 from .backtranslation_dataset import BacktranslationDataset
+from .concat_dataset import ConcatDataset
+from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
 from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
 from .round_robin_zip_datasets import RoundRobinZipDatasets
@@ -25,6 +25,7 @@ from .iterators import (
 __all__ = [
     'AppendEosDataset',
+    'BacktranslationDataset',
     'ConcatDataset',
     'CountingIterator',
     'Dictionary',
@@ -40,5 +41,4 @@ __all__ = [
     'RoundRobinZipDatasets',
     'ShardedIterator',
     'TokenBlockDataset',
-    'BacktranslationDataset',
 ]

fairseq/data/append_eos_dataset.py

@@ -17,7 +17,6 @@ class AppendEosDataset(torch.utils.data.Dataset):

     def __getitem__(self, index):
         item = torch.cat([self.dataset[index], torch.LongTensor([self.eos])])
-        print(item)
         return item

     def __len__(self):

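Note: besides dropping the stray debug print, the hunk leaves the append-EOS pattern itself unchanged: the wrapped dataset's item is extended with the EOS id via torch.cat. A self-contained illustration, with made-up token ids:

    import torch

    eos = 2  # hypothetical EOS id
    item = torch.LongTensor([11, 12, 13])
    item = torch.cat([item, torch.LongTensor([eos])])
    # item is now [11, 12, 13, 2]
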
fairseq/data/backtranslation_dataset.py

@@ -23,6 +23,7 @@ class BacktranslationDataset(FairseqDataset):
         max_len_b,
         remove_eos_at_src=False,
         generator_class=sequence_generator.SequenceGenerator,
+        cuda=True,
         **kwargs
     ):
         """
@@ -51,6 +52,7 @@ class BacktranslationDataset(FairseqDataset):
             generator_class: which SequenceGenerator class to use for
                 backtranslation. Output of generate() should be the same format
                 as fairseq's SequenceGenerator
+            cuda: use GPU for generation
             kwargs: generation args to init the backtranslation
                 SequenceGenerator
         """
@@ -73,6 +75,10 @@ class BacktranslationDataset(FairseqDataset):
             **kwargs
         )

+        self.cuda = cuda if torch.cuda.is_available() else False
+        if self.cuda:
+            self.backtranslation_generator.cuda()
+
     def __getitem__(self, index):
         """
         Returns a single sample. Multiple samples are fed to the collater to
@@ -105,32 +111,32 @@ class BacktranslationDataset(FairseqDataset):
         # {id: id, source: generated backtranslation, target: original tgt}
         generated_samples = []
         for input_sample, hypos in zip(samples, backtranslation_hypos):
-            eos = self.tgt_dataset.src_dict.eos()
-            original_tgt = input_sample["source"].cpu()
-            generated_source = hypos[0]["tokens"].cpu()  # first hypo is best hypo
             # Append EOS to the tgt sentence if it does not have an EOS
             # This is the case if the samples in monolingual tgt_dataset don't
             # have an EOS appended to the end of each sentence.
+            original_tgt = input_sample["source"]
+            eos = self.tgt_dataset.src_dict.eos()
             if original_tgt[-1] != eos:
                 original_tgt = torch.cat([original_tgt, torch.LongTensor([eos])])

             # The generated source dialect backtranslation will have an EOS.
             # If we want our parallel data source to not have an EOS, we will
             # have to remove it.
+            generated_source = hypos[0]["tokens"]  # first hypo is best hypo
             if self.remove_eos_at_src:
                 assert generated_source[-1] == eos, (
-                    f"Expected generated backtranslation to have eos (id: "
-                    f"{eos}) at end, but instead found token id "
-                    f"{generated_source[-1]} at end."
-                )
+                    "Expected generated backtranslation to have eos (id: "
+                    "{eos}) at end, but instead found token id "
+                    "{other} at end."
+                ).format(eos=eos, other=generated_source[-1])
                 generated_source = generated_source[:-1]

             generated_samples.append(
                 {
                     "id": input_sample["id"],
-                    "source": generated_source.cpu(),
-                    "target": original_tgt.cpu(),
+                    "source": generated_source,
+                    "target": original_tgt,
                 }
             )
@@ -162,11 +168,7 @@ class BacktranslationDataset(FairseqDataset):
             sample. Note in this case, sample["target"] is None, and
             sample["net_input"]["src_tokens"] is really in tgt language.
         """
-        if torch.cuda.is_available():
-            s = utils.move_to_cuda(sample)
-        else:
-            s = sample
-        self.backtranslation_generator.cuda()
+        s = utils.move_to_cuda(sample) if self.cuda else sample
         input = s["net_input"]
         srclen = input["src_tokens"].size(1)
         hypos = self.backtranslation_generator.generate(

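Note: the device handling above follows a decide-once pattern: the constructor resolves the cuda flag against torch.cuda.is_available() and moves the generator to the GPU a single time, so generation only has to move each incoming batch. A stripped-down sketch of the same pattern; GeneratorRunner and this simplified move_to_cuda are illustrative stand-ins, not fairseq API:

    import torch

    def move_to_cuda(sample):
        # simplified stand-in for fairseq.utils.move_to_cuda
        if torch.is_tensor(sample):
            return sample.cuda()
        if isinstance(sample, dict):
            return {k: move_to_cuda(v) for k, v in sample.items()}
        return sample

    class GeneratorRunner(object):
        # hypothetical analogue of BacktranslationDataset's device logic;
        # generator is assumed to be an nn.Module, so .cuda() returns it
        def __init__(self, generator, cuda=True):
            self.cuda = cuda if torch.cuda.is_available() else False
            self.generator = generator.cuda() if self.cuda else generator

        def generate(self, sample):
            # per batch: move inputs only if we committed to the GPU
            s = move_to_cuda(sample) if self.cuda else sample
            return self.generator(s["net_input"])
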
tests/test_noising.py

@@ -66,13 +66,15 @@ class TestDataNoising(unittest.TestCase):
         return vocab, x, torch.LongTensor(src_len)

     def assert_eos_at_end(self, x, x_len, eos):
-        """ Asserts last token of every sentence in x is EOS """
+        """Asserts last token of every sentence in x is EOS """
         for i in range(len(x_len)):
             self.assertEqual(
                 x[x_len[i]-1][i],
                 eos,
-                f"Expected eos (token id {eos}) at the end of sentence {i} but "
-                f"got {x[i][-1]} instead"
+                (
+                    "Expected eos (token id {eos}) at the end of sentence {i} but "
+                    "got {other} instead"
+                ).format(i=i, eos=eos, other=x[i][-1])
             )

     def assert_word_dropout_correct(self, x, x_noised, x_len, l_noised):
@@ -192,16 +194,18 @@ class TestDataNoising(unittest.TestCase):
         self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

     def assert_no_eos_at_end(self, x, x_len, eos):
-        """ Asserts that the last token of each sentence in x is not EOS """
+        """Asserts that the last token of each sentence in x is not EOS """
         for i in range(len(x_len)):
             self.assertNotEqual(
                 x[x_len[i]-1][i],
                 eos,
-                f"Expected no eos (token id {eos}) at the end of sentence {i}."
+                "Expected no eos (token id {eos}) at the end of sentence {i}.".format(
+                    eos=eos, i=i,
+                )
             )

     def test_word_dropout_without_eos(self):
-        """ Same result as word dropout with eos except no EOS at end"""
+        """Same result as word dropout with eos except no EOS at end"""
         vocab, x, x_len = self._get_test_data(append_eos=False)

         with data_utils.numpy_seed(1234):
@@ -213,7 +217,7 @@ class TestDataNoising(unittest.TestCase):
         self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

     def test_word_blank_without_eos(self):
-        """ Same result as word blank with eos except no EOS at end"""
+        """Same result as word blank with eos except no EOS at end"""
         vocab, x, x_len = self._get_test_data(append_eos=False)

         with data_utils.numpy_seed(1234):
@@ -225,7 +229,7 @@ class TestDataNoising(unittest.TestCase):
         self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

     def test_word_shuffle_without_eos(self):
-        """ Same result as word shuffle with eos except no EOS at end """
+        """Same result as word shuffle with eos except no EOS at end"""
         vocab, x, x_len = self._get_test_data(append_eos=False)

         with data_utils.numpy_seed(1234):

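Note: the rewrites in this file pass computed values (e.g. other=x[i][-1]) to str.format as keyword arguments rather than embedding expressions in the replacement field. That is deliberate: format fields only support attribute access and simple indexing, so there is no direct equivalent of an f-string expression like {x[i][-1]}. For example:

    a = [10, 20, 30]
    print("{a[0]}".format(a=a))         # OK, prints 10
    print("{last}".format(last=a[-1]))  # compute first, then format: prints 30
    # "{a[-1]}".format(a=a) raises TypeError: the field index -1 is parsed
    # as the string key "-1", not as a negative integer
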
tests/utils.py

@@ -221,7 +221,8 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
         # random attention
         attn = torch.rand(bbsz, tgt_len, src_len)
-        return probs, attn
+        dev = prev_output_tokens.device
+        return probs.to(dev), attn.to(dev)

     def get_normalized_probs(self, net_output, log_probs, _):
         # the decoder returns probabilities directly
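
Note: the fix above makes the stub decoder in the test utilities honor the input device: probs and attn are allocated on CPU by torch.rand and then moved to prev_output_tokens.device, so the fake model also behaves when the tests run on CUDA. The same pattern in isolation (the function name and default shapes are illustrative):

    import torch

    def fake_decoder_outputs(prev_output_tokens, vocab_size=8, src_len=5):
        bsz, tgt_len = prev_output_tokens.size()
        # allocate on CPU ...
        probs = torch.rand(bsz, tgt_len, vocab_size)
        attn = torch.rand(bsz, tgt_len, src_len)
        # ... then follow the device of the incoming tokens
        dev = prev_output_tokens.device
        return probs.to(dev), attn.to(dev)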