diff --git a/tests/gpu/__init__.py b/tests/gpu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/gpu/test_binaries_gpu.py b/tests/gpu/test_binaries_gpu.py new file mode 100644 index 000000000..5ccb84c55 --- /dev/null +++ b/tests/gpu/test_binaries_gpu.py @@ -0,0 +1,281 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import logging +import os +import tempfile +import unittest +from io import StringIO + +import torch +from fairseq import options +from fairseq_cli import train +from tests.utils import ( + create_dummy_data, + generate_main, + preprocess_lm_data, + preprocess_translation_data, + train_translation_model, +) + + +class TestTranslationGPU(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_fp16(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_fp16") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model(data_dir, "fconv_iwslt_de_en", ["--fp16"]) + generate_main(data_dir) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_memory_efficient_fp16(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_memory_efficient_fp16") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, "fconv_iwslt_de_en", ["--memory-efficient-fp16"] + ) + generate_main(data_dir) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_levenshtein_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory( + "test_levenshtein_transformer" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir, ["--joined-dictionary"]) + train_translation_model( + data_dir, + "levenshtein_transformer", + [ + "--apply-bert-init", + "--early-exit", + "6,6,6", + "--criterion", + "nat_loss", + ], + task="translation_lev", + ) + generate_main( + data_dir, + [ + "--task", + "translation_lev", + "--iter-decode-max-iter", + "9", + "--iter-decode-eos-penalty", + "0", + "--print-step", + ], + ) + + +def _quantize_language_model(data_dir, arch, extra_flags=None, run_validation=False): + train_parser = options.get_training_parser() + train_args = options.parse_args_and_arch( + train_parser, + [ + "--task", + "language_modeling", + data_dir, + "--arch", + arch, + "--optimizer", + "adam", + "--lr", + "0.0001", + "--criterion", + "adaptive_loss", + "--adaptive-softmax-cutoff", + "5,10,15", + "--max-tokens", + "500", + "--tokens-per-sample", + "500", + "--save-dir", + data_dir, + "--max-epoch", + "1", + "--no-progress-bar", + "--distributed-world-size", + "1", + "--ddp-backend", + "no_c10d", + "--num-workers", + 0, + ] + + (extra_flags or []), + ) + train.main(train_args) + + # try scalar quantization + scalar_quant_train_parser = options.get_training_parser() + scalar_quant_train_args = options.parse_args_and_arch( + scalar_quant_train_parser, + [ + "--task", + "language_modeling", + data_dir, + "--arch", + arch, + "--optimizer", + "adam", + "--lr", + "0.0001", + "--criterion", + "adaptive_loss", + "--adaptive-softmax-cutoff", + "5,10,15", + "--max-tokens", + "500", + "--tokens-per-sample", + "500", + "--save-dir", + data_dir, + "--max-update", + "3", + "--no-progress-bar", + "--distributed-world-size", + "1", + "--ddp-backend", + "no_c10d", + "--num-workers", + 0, + "--quant-noise-scalar", + "0.5", + ] + + (extra_flags or []), + ) + train.main(scalar_quant_train_args) + + # try iterative PQ quantization + quantize_parser = options.get_training_parser() + quantize_args = options.parse_args_and_arch( + quantize_parser, + [ + "--task", + "language_modeling", + data_dir, + "--arch", + arch, + "--optimizer", + "adam", + "--lr", + "0.0001", + "--criterion", + "adaptive_loss", + "--adaptive-softmax-cutoff", + "5,10,15", + "--max-tokens", + "50", + "--tokens-per-sample", + "50", + "--max-update", + "6", + "--no-progress-bar", + "--distributed-world-size", + "1", + "--ddp-backend", + "no_c10d", + "--num-workers", + 0, + "--restore-file", + os.path.join(data_dir, "checkpoint_last.pt"), + "--reset-optimizer", + "--quantization-config-path", + os.path.join( + os.path.dirname(__file__), "transformer_quantization_config.yaml" + ), + ] + + (extra_flags or []), + ) + train.main(quantize_args) + + +class TestQuantization(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_quantization(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_quantization") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + # tests both scalar and iterative PQ quantization + _quantize_language_model(data_dir, "transformer_lm") + + +class TestOptimizersGPU(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") + def test_flat_grads(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_flat_grads") as data_dir: + # Use just a bit of data and tiny model to keep this test runtime reasonable + create_dummy_data(data_dir, num_examples=10, maxlen=5) + preprocess_translation_data(data_dir) + with self.assertRaises(RuntimeError): + # adafactor isn't compatible with flat grads, which + # are used by default with --fp16 + train_translation_model( + data_dir, + "lstm", + [ + "--required-batch-size-multiple", + "1", + "--encoder-layers", + "1", + "--encoder-hidden-size", + "32", + "--decoder-layers", + "1", + "--optimizer", + "adafactor", + "--fp16", + ], + ) + # but it should pass once we set --fp16-no-flatten-grads + train_translation_model( + data_dir, + "lstm", + [ + "--required-batch-size-multiple", + "1", + "--encoder-layers", + "1", + "--encoder-hidden-size", + "32", + "--decoder-layers", + "1", + "--optimizer", + "adafactor", + "--fp16", + "--fp16-no-flatten-grads", + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/transformer_quantization_config.yaml b/tests/gpu/transformer_quantization_config.yaml similarity index 100% rename from tests/transformer_quantization_config.yaml rename to tests/gpu/transformer_quantization_config.yaml diff --git a/tests/test_binaries.py b/tests/test_binaries.py index e1f037bcb..8e8732b64 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -8,19 +8,22 @@ from io import StringIO import logging import os import random -import sys import tempfile import unittest import torch from fairseq import options -from fairseq_cli import preprocess from fairseq_cli import train -from fairseq_cli import generate -from fairseq_cli import interactive from fairseq_cli import eval_lm from fairseq_cli import validate +from tests.utils import ( + create_dummy_data, + preprocess_lm_data, + preprocess_translation_data, + train_translation_model, + generate_main, +) class TestTranslation(unittest.TestCase): @@ -47,24 +50,6 @@ class TestTranslation(unittest.TestCase): train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--dataset-impl', 'raw']) generate_main(data_dir, ['--dataset-impl', 'raw']) - @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') - def test_fp16(self): - with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_fp16') as data_dir: - create_dummy_data(data_dir) - preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--fp16']) - generate_main(data_dir) - - @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') - def test_memory_efficient_fp16(self): - with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_memory_efficient_fp16') as data_dir: - create_dummy_data(data_dir) - preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--memory-efficient-fp16']) - generate_main(data_dir) - def test_update_freq(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory('test_update_freq') as data_dir: @@ -184,19 +169,28 @@ class TestTranslation(unittest.TestCase): ], run_validation=True) generate_main(data_dir) - @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') + @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") def test_transformer_fp16(self): with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_transformer') as data_dir: + with tempfile.TemporaryDirectory("test_transformer") as data_dir: create_dummy_data(data_dir) preprocess_translation_data(data_dir) - train_translation_model(data_dir, 'transformer_iwslt_de_en', [ - '--encoder-layers', '2', - '--decoder-layers', '2', - '--encoder-embed-dim', '8', - '--decoder-embed-dim', '8', - '--fp16', - ], run_validation=True) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--fp16", + ], + run_validation=True, + ) generate_main(data_dir) def test_multilingual_transformer(self): @@ -296,23 +290,6 @@ class TestTranslation(unittest.TestCase): '--print-step', ]) - @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') - def test_levenshtein_transformer(self): - with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_levenshtein_transformer') as data_dir: - create_dummy_data(data_dir) - preprocess_translation_data(data_dir, ['--joined-dictionary']) - train_translation_model(data_dir, 'levenshtein_transformer', [ - '--apply-bert-init', '--early-exit', '6,6,6', - '--criterion', 'nat_loss' - ], task='translation_lev') - generate_main(data_dir, [ - '--task', 'translation_lev', - '--iter-decode-max-iter', '9', - '--iter-decode-eos-penalty', '0', - '--print-step', - ]) - def test_nonautoregressive_transformer(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory('test_nonautoregressive_transformer') as data_dir: @@ -714,23 +691,6 @@ def train_legacy_masked_language_model(data_dir, arch, extra_args=()): train.main(train_args) -class TestQuantization(unittest.TestCase): - def setUp(self): - logging.disable(logging.CRITICAL) - - def tearDown(self): - logging.disable(logging.NOTSET) - - @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') - def test_quantization(self): - with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_quantization') as data_dir: - create_dummy_data(data_dir) - preprocess_lm_data(data_dir) - # tests both scalar and iterative PQ quantization - quantize_language_model(data_dir, 'transformer_lm') - - class TestOptimizers(unittest.TestCase): def setUp(self): @@ -759,74 +719,6 @@ class TestOptimizers(unittest.TestCase): ]) generate_main(data_dir) - @unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU') - def test_flat_grads(self): - with contextlib.redirect_stdout(StringIO()): - with tempfile.TemporaryDirectory('test_flat_grads') as data_dir: - # Use just a bit of data and tiny model to keep this test runtime reasonable - create_dummy_data(data_dir, num_examples=10, maxlen=5) - preprocess_translation_data(data_dir) - with self.assertRaises(RuntimeError): - # adafactor isn't compatible with flat grads, which - # are used by default with --fp16 - train_translation_model(data_dir, 'lstm', [ - '--required-batch-size-multiple', '1', - '--encoder-layers', '1', - '--encoder-hidden-size', '32', - '--decoder-layers', '1', - '--optimizer', 'adafactor', - '--fp16', - ]) - # but it should pass once we set --fp16-no-flatten-grads - train_translation_model(data_dir, 'lstm', [ - '--required-batch-size-multiple', '1', - '--encoder-layers', '1', - '--encoder-hidden-size', '32', - '--decoder-layers', '1', - '--optimizer', 'adafactor', - '--fp16', - '--fp16-no-flatten-grads', - ]) - - -def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False): - def _create_dummy_data(filename): - data = torch.rand(num_examples * maxlen) - data = 97 + torch.floor(26 * data).int() - with open(os.path.join(data_dir, filename), 'w') as h: - offset = 0 - for _ in range(num_examples): - ex_len = random.randint(1, maxlen) - ex_str = ' '.join(map(chr, data[offset:offset+ex_len])) - print(ex_str, file=h) - offset += ex_len - - def _create_dummy_alignment_data(filename_src, filename_tgt, filename): - with open(os.path.join(data_dir, filename_src), 'r') as src_f, \ - open(os.path.join(data_dir, filename_tgt), 'r') as tgt_f, \ - open(os.path.join(data_dir, filename), 'w') as h: - for src, tgt in zip(src_f, tgt_f): - src_len = len(src.split()) - tgt_len = len(tgt.split()) - avg_len = (src_len + tgt_len) // 2 - num_alignments = random.randint(avg_len // 2, 2 * avg_len) - src_indices = torch.floor(torch.rand(num_alignments) * src_len).int() - tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int() - ex_str = ' '.join(["{}-{}".format(src, tgt) for src, tgt in zip(src_indices, tgt_indices)]) - print(ex_str, file=h) - - _create_dummy_data('train.in') - _create_dummy_data('train.out') - _create_dummy_data('valid.in') - _create_dummy_data('valid.out') - _create_dummy_data('test.in') - _create_dummy_data('test.out') - - if alignment: - _create_dummy_alignment_data('train.in', 'train.out', 'train.align') - _create_dummy_alignment_data('valid.in', 'valid.out', 'valid.align') - _create_dummy_alignment_data('test.in', 'test.out', 'test.align') - def create_dummy_roberta_head_data(data_dir, num_examples=100, maxlen=10, num_classes=2, regression=False): input_dir = 'input0' @@ -862,109 +754,6 @@ def create_dummy_roberta_head_data(data_dir, num_examples=100, maxlen=10, num_cl _create_dummy_data('test') -def preprocess_translation_data(data_dir, extra_flags=None): - preprocess_parser = options.get_preprocessing_parser() - preprocess_args = preprocess_parser.parse_args( - [ - '--source-lang', 'in', - '--target-lang', 'out', - '--trainpref', os.path.join(data_dir, 'train'), - '--validpref', os.path.join(data_dir, 'valid'), - '--testpref', os.path.join(data_dir, 'test'), - '--thresholdtgt', '0', - '--thresholdsrc', '0', - '--destdir', data_dir, - ] + (extra_flags or []), - ) - preprocess.main(preprocess_args) - - -def train_translation_model(data_dir, arch, extra_flags=None, task='translation', run_validation=False, - lang_flags=None, extra_valid_flags=None): - if lang_flags is None: - lang_flags = [ - '--source-lang', 'in', - '--target-lang', 'out', - ] - train_parser = options.get_training_parser() - train_args = options.parse_args_and_arch( - train_parser, - [ - '--task', task, - data_dir, - '--save-dir', data_dir, - '--arch', arch, - '--lr', '0.05', - '--max-tokens', '500', - '--max-epoch', '1', - '--no-progress-bar', - '--distributed-world-size', '1', - '--num-workers', 0, - ] + lang_flags + (extra_flags or []), - ) - train.main(train_args) - - if run_validation: - # test validation - validate_parser = options.get_validation_parser() - validate_args = options.parse_args_and_arch( - validate_parser, - [ - '--task', task, - data_dir, - '--path', os.path.join(data_dir, 'checkpoint_last.pt'), - '--valid-subset', 'valid', - '--max-tokens', '500', - '--no-progress-bar', - ] + lang_flags + (extra_valid_flags or []) - ) - validate.main(validate_args) - - -def generate_main(data_dir, extra_flags=None): - if extra_flags is None: - extra_flags = [ - '--print-alignment', - ] - generate_parser = options.get_generation_parser() - generate_args = options.parse_args_and_arch( - generate_parser, - [ - data_dir, - '--path', os.path.join(data_dir, 'checkpoint_last.pt'), - '--beam', '3', - '--batch-size', '64', - '--max-len-b', '5', - '--gen-subset', 'valid', - '--no-progress-bar', - ] + (extra_flags or []), - ) - - # evaluate model in batch mode - generate.main(generate_args) - - # evaluate model interactively - generate_args.buffer_size = 0 - generate_args.input = '-' - generate_args.max_sentences = None - orig_stdin = sys.stdin - sys.stdin = StringIO('h e l l o\n') - interactive.main(generate_args) - sys.stdin = orig_stdin - - -def preprocess_lm_data(data_dir): - preprocess_parser = options.get_preprocessing_parser() - preprocess_args = preprocess_parser.parse_args([ - '--only-source', - '--trainpref', os.path.join(data_dir, 'train.out'), - '--validpref', os.path.join(data_dir, 'valid.out'), - '--testpref', os.path.join(data_dir, 'test.out'), - '--destdir', data_dir, - ]) - preprocess.main(preprocess_args) - - def train_masked_lm(data_dir, arch, extra_flags=None): train_parser = options.get_training_parser() train_args = options.parse_args_and_arch( @@ -1130,80 +919,5 @@ def train_masked_language_model(data_dir, arch, extra_args=()): train.main(train_args) -def quantize_language_model(data_dir, arch, extra_flags=None, run_validation=False): - train_parser = options.get_training_parser() - train_args = options.parse_args_and_arch( - train_parser, - [ - '--task', 'language_modeling', - data_dir, - '--arch', arch, - '--optimizer', 'adam', - '--lr', '0.0001', - '--criterion', 'adaptive_loss', - '--adaptive-softmax-cutoff', '5,10,15', - '--max-tokens', '500', - '--tokens-per-sample', '500', - '--save-dir', data_dir, - '--max-epoch', '1', - '--no-progress-bar', - '--distributed-world-size', '1', - '--ddp-backend', 'no_c10d', - '--num-workers', 0, - ] + (extra_flags or []), - ) - train.main(train_args) - - # try scalar quantization - scalar_quant_train_parser = options.get_training_parser() - scalar_quant_train_args = options.parse_args_and_arch( - scalar_quant_train_parser, - [ - '--task', 'language_modeling', - data_dir, - '--arch', arch, - '--optimizer', 'adam', - '--lr', '0.0001', - '--criterion', 'adaptive_loss', - '--adaptive-softmax-cutoff', '5,10,15', - '--max-tokens', '500', - '--tokens-per-sample', '500', - '--save-dir', data_dir, - '--max-update', '3', - '--no-progress-bar', - '--distributed-world-size', '1', - '--ddp-backend', 'no_c10d', - '--num-workers', 0, - '--quant-noise-scalar', '0.5', - ] + (extra_flags or []), - ) - train.main(scalar_quant_train_args) - - # try iterative PQ quantization - quantize_parser = options.get_training_parser() - quantize_args = options.parse_args_and_arch( - quantize_parser, - [ - '--task', 'language_modeling', - data_dir, - '--arch', arch, - '--optimizer', 'adam', - '--lr', '0.0001', - '--criterion', 'adaptive_loss', - '--adaptive-softmax-cutoff', '5,10,15', - '--max-tokens', '50', - '--tokens-per-sample', '50', - '--max-update', '6', - '--no-progress-bar', - '--distributed-world-size', '1', - '--ddp-backend', 'no_c10d', - '--num-workers', 0, - '--restore-file', os.path.join(data_dir, 'checkpoint_last.pt'), - '--reset-optimizer', - '--quantization-config-path', os.path.join(os.path.dirname(__file__), 'transformer_quantization_config.yaml'), - ] + (extra_flags or []), - ) - train.main(quantize_args) - if __name__ == '__main__': unittest.main() diff --git a/tests/utils.py b/tests/utils.py index f908e5e74..e207575d6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,10 +4,14 @@ # LICENSE file in the root directory of this source tree. import argparse +import os +import random +import sys import torch import torch.nn.functional as F -from fairseq import utils +from io import StringIO +from fairseq import options, utils from fairseq.data import Dictionary from fairseq.data.language_pair_dataset import collate from fairseq.models import ( @@ -17,6 +21,13 @@ from fairseq.models import ( ) from fairseq.models.fairseq_encoder import EncoderOut from fairseq.tasks import FairseqTask +from fairseq_cli import ( + generate, + interactive, + preprocess, + train, + validate, +) def dummy_dictionary(vocab_size, prefix='token_'): @@ -116,6 +127,148 @@ def sequence_generator_setup(): return tgt_dict, w1, w2, src_tokens, src_lengths, model +def create_dummy_data(data_dir, num_examples=100, maxlen=20, alignment=False): + def _create_dummy_data(filename): + data = torch.rand(num_examples * maxlen) + data = 97 + torch.floor(26 * data).int() + with open(os.path.join(data_dir, filename), 'w') as h: + offset = 0 + for _ in range(num_examples): + ex_len = random.randint(1, maxlen) + ex_str = ' '.join(map(chr, data[offset:offset+ex_len])) + print(ex_str, file=h) + offset += ex_len + + def _create_dummy_alignment_data(filename_src, filename_tgt, filename): + with open(os.path.join(data_dir, filename_src), 'r') as src_f, \ + open(os.path.join(data_dir, filename_tgt), 'r') as tgt_f, \ + open(os.path.join(data_dir, filename), 'w') as h: + for src, tgt in zip(src_f, tgt_f): + src_len = len(src.split()) + tgt_len = len(tgt.split()) + avg_len = (src_len + tgt_len) // 2 + num_alignments = random.randint(avg_len // 2, 2 * avg_len) + src_indices = torch.floor(torch.rand(num_alignments) * src_len).int() + tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int() + ex_str = ' '.join(["{}-{}".format(src, tgt) for src, tgt in zip(src_indices, tgt_indices)]) + print(ex_str, file=h) + + _create_dummy_data('train.in') + _create_dummy_data('train.out') + _create_dummy_data('valid.in') + _create_dummy_data('valid.out') + _create_dummy_data('test.in') + _create_dummy_data('test.out') + + if alignment: + _create_dummy_alignment_data('train.in', 'train.out', 'train.align') + _create_dummy_alignment_data('valid.in', 'valid.out', 'valid.align') + _create_dummy_alignment_data('test.in', 'test.out', 'test.align') + + +def preprocess_lm_data(data_dir): + preprocess_parser = options.get_preprocessing_parser() + preprocess_args = preprocess_parser.parse_args([ + '--only-source', + '--trainpref', os.path.join(data_dir, 'train.out'), + '--validpref', os.path.join(data_dir, 'valid.out'), + '--testpref', os.path.join(data_dir, 'test.out'), + '--destdir', data_dir, + ]) + preprocess.main(preprocess_args) + + +def preprocess_translation_data(data_dir, extra_flags=None): + preprocess_parser = options.get_preprocessing_parser() + preprocess_args = preprocess_parser.parse_args( + [ + '--source-lang', 'in', + '--target-lang', 'out', + '--trainpref', os.path.join(data_dir, 'train'), + '--validpref', os.path.join(data_dir, 'valid'), + '--testpref', os.path.join(data_dir, 'test'), + '--thresholdtgt', '0', + '--thresholdsrc', '0', + '--destdir', data_dir, + ] + (extra_flags or []), + ) + preprocess.main(preprocess_args) + + +def train_translation_model(data_dir, arch, extra_flags=None, task='translation', run_validation=False, + lang_flags=None, extra_valid_flags=None): + if lang_flags is None: + lang_flags = [ + '--source-lang', 'in', + '--target-lang', 'out', + ] + train_parser = options.get_training_parser() + train_args = options.parse_args_and_arch( + train_parser, + [ + '--task', task, + data_dir, + '--save-dir', data_dir, + '--arch', arch, + '--lr', '0.05', + '--max-tokens', '500', + '--max-epoch', '1', + '--no-progress-bar', + '--distributed-world-size', '1', + '--num-workers', 0, + ] + lang_flags + (extra_flags or []), + ) + train.main(train_args) + + if run_validation: + # test validation + validate_parser = options.get_validation_parser() + validate_args = options.parse_args_and_arch( + validate_parser, + [ + '--task', task, + data_dir, + '--path', os.path.join(data_dir, 'checkpoint_last.pt'), + '--valid-subset', 'valid', + '--max-tokens', '500', + '--no-progress-bar', + ] + lang_flags + (extra_valid_flags or []) + ) + validate.main(validate_args) + + +def generate_main(data_dir, extra_flags=None): + if extra_flags is None: + extra_flags = [ + '--print-alignment', + ] + generate_parser = options.get_generation_parser() + generate_args = options.parse_args_and_arch( + generate_parser, + [ + data_dir, + '--path', os.path.join(data_dir, 'checkpoint_last.pt'), + '--beam', '3', + '--batch-size', '64', + '--max-len-b', '5', + '--gen-subset', 'valid', + '--no-progress-bar', + ] + (extra_flags or []), + ) + + # evaluate model in batch mode + generate.main(generate_args) + + # evaluate model interactively + generate_args.buffer_size = 0 + generate_args.input = '-' + generate_args.max_sentences = None + orig_stdin = sys.stdin + sys.stdin = StringIO('h e l l o\n') + interactive.main(generate_args) + sys.stdin = orig_stdin + + class TestDataset(torch.utils.data.Dataset): def __init__(self, data):