2020-04-17 01:47:34 +03:00
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
|
|
#
|
|
|
|
# This source code is licensed under the MIT license found in the
|
|
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import tempfile
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from fairseq.data.dictionary import Dictionary
|
|
|
|
from fairseq.models.lstm import LSTMModel
|
2020-09-10 03:00:56 +03:00
|
|
|
from fairseq.tasks.fairseq_task import LegacyFairseqTask
|
2020-04-17 01:47:34 +03:00
|
|
|
|
|
|
|
|
|
|
|
# Default number of dummy symbols that get_dummy_dictionary() adds.
DEFAULT_TEST_VOCAB_SIZE = 100
|
|
|
|
|
|
|
|
|
2020-09-10 03:00:56 +03:00
|
|
|
class DummyTask(LegacyFairseqTask):
    """Minimal task used to build models in tests.

    A single dummy dictionary is shared between the source and target sides.
    """

    def __init__(self, args):
        """Create the task and its shared dictionary.

        Args:
            args: parsed argument namespace. If it has a truthy ``ctc``
                attribute, a ``"<ctc_blank>"`` symbol is added to the
                dictionary.
        """
        super().__init__(args)
        self.dictionary = get_dummy_dictionary()
        if getattr(self.args, "ctc", False):
            self.dictionary.add_symbol("<ctc_blank>")
        # Source and target deliberately alias the same dictionary object.
        self.src_dict = self.dictionary
        self.tgt_dict = self.dictionary

    @property
    def source_dictionary(self):
        """Return the dictionary used for the source side."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the dictionary used for the target side."""
        # Mirror source_dictionary by going through the dedicated attribute
        # (it refers to the same object as self.dictionary after __init__).
        return self.tgt_dict
|
|
|
|
|
|
|
|
|
|
|
|
def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
    """Build a ``Dictionary`` populated with placeholder symbols.

    Args:
        vocab_size: number of dummy symbols to add. The symbols are the
            decimal strings ``"0"`` through ``str(vocab_size - 1)``.

    Returns:
        The newly created ``Dictionary`` after the dummy symbols were added.
    """
    dummy_dict = Dictionary()
    # Add dummy symbols to satisfy the requested vocab size.
    for index in range(vocab_size):
        # 1000 is passed as add_symbol's second argument -- presumably the
        # symbol count; confirm against Dictionary.add_symbol.
        dummy_dict.add_symbol(str(index), 1000)
    return dummy_dict
|
|
|
|
|
|
|
|
|
|
|
|
def get_dummy_task_and_parser():
    """Create a dummy task plus the argument parser used to configure it.

    Building a fairseq model requires a task and a parser; this helper
    supplies both so that model/criterion tests can be set up easily.

    Note: ``DummyTask`` is used as the task here. Provide another helper
    function if a test needs a different task.

    Returns:
        A ``(task, parser)`` tuple. ``parser`` has the ``DummyTask``
        arguments registered and ``task`` was set up from the parser's
        values for an empty command line.
    """
    dummy_parser = argparse.ArgumentParser(
        argument_default=argparse.SUPPRESS,
        description="test_dummy_s2s_task",
    )
    DummyTask.add_args(dummy_parser)
    # Parse an empty command line so the task is built without CLI input.
    dummy_task = DummyTask.setup_task(dummy_parser.parse_args([]))
    return dummy_task, dummy_parser
|
|
|
|
|
|
|
|
|
|
|
|
class TestJitLSTMModel(unittest.TestCase):
    """Tests that ``LSTMModel`` can be scripted with TorchScript.

    Covers saving/loading the scripted module and comparing its outputs
    against the eager model on random inputs.
    """

    def _build_task_and_model(self):
        """Return ``(task, model)`` with an ``LSTMModel`` built on the dummy task."""
        task, parser = get_dummy_task_and_parser()
        LSTMModel.add_args(parser)
        args = parser.parse_args([])
        args.criterion = ""
        return task, LSTMModel.build_model(args, task)

    def _test_save_and_load(self, scripted_module):
        """Save ``scripted_module`` to a temporary file and load it back."""
        with tempfile.NamedTemporaryFile() as f:
            scripted_module.save(f.name)
            torch.jit.load(f.name)

    def assertTensorEqual(self, t1, t2):
        """Assert that ``t1`` and ``t2`` are equal once NaN entries are dropped.

        NaNs are removed from each tensor independently, so tensors with a
        different number of NaNs fail with a size mismatch.
        """
        t1 = t1[~torch.isnan(t1)]  # can cause size mismatch errors if there are NaNs
        t2 = t2[~torch.isnan(t2)]
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(int(t1.ne(t2).sum()), 0)

    def test_jit_and_export_lstm(self):
        """The model can be scripted, saved and loaded again."""
        _, model = self._build_task_and_model()
        scripted_model = torch.jit.script(model)
        self._test_save_and_load(scripted_model)

    def test_assert_jit_vs_nonjit_(self):
        """Eager and scripted models produce the same outputs on random inputs."""
        task, model = self._build_task_and_model()
        model.eval()
        scripted_model = torch.jit.script(model)
        scripted_model.eval()

        vocab_size = len(task.source_dictionary)
        num_trials = 100
        # Draw all random shapes up front and convert them to plain Python
        # ints so they can be used directly as sizes and bounds.
        seq_lens = torch.randint(1, 10, (num_trials,)).tolist()
        batch_sizes = torch.randint(1, 10, (num_trials,)).tolist()

        # Inject random input and check output
        for seq_len, num_samples in zip(seq_lens, batch_sizes):
            src_tokens = torch.randint(0, vocab_size, (num_samples, seq_len))
            src_lengths = torch.randint(1, seq_len + 1, (num_samples,))
            src_lengths, _ = torch.sort(src_lengths, descending=True)
            # Force the first sample to have seq_len
            src_lengths[0] = seq_len
            prev_output_tokens = torch.randint(0, vocab_size, (num_samples, 1))

            result = model(src_tokens, src_lengths, prev_output_tokens, None)
            scripted_result = scripted_model(
                src_tokens, src_lengths, prev_output_tokens, None
            )
            self.assertTensorEqual(result[0], scripted_result[0])
            self.assertTensorEqual(result[1], scripted_result[1])
|
|
|
|
|
2020-04-21 20:46:21 +03:00
|
|
|
|
2020-04-17 01:47:34 +03:00
|
|
|
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|