Test max_positions

This commit is contained in:
Myle Ott 2018-09-02 20:05:38 -07:00
parent dfd77717b9
commit d473620e39

View File

@ -58,6 +58,27 @@ class TestTranslation(unittest.TestCase):
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--update-freq', '3'])
generate_main(data_dir)
def test_max_positions(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_max_positions') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
with self.assertRaises(Exception) as context:
train_translation_model(
data_dir, 'fconv_iwslt_de_en', ['--max-target-positions', '5'],
)
self.assertTrue(
'skip this example with --skip-invalid-size-inputs-valid-test' \
in str(context.exception)
)
train_translation_model(
data_dir, 'fconv_iwslt_de_en',
['--max-target-positions', '5', '--skip-invalid-size-inputs-valid-test'],
)
with self.assertRaises(Exception) as context:
generate_main(data_dir)
generate_main(data_dir, ['--skip-invalid-size-inputs-valid-test'])
def test_generation(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_sampling') as data_dir: