Deprecate the SequenceGenerator in favor of the scripted version (#1120)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1120

Pull Request resolved: https://github.com/pytorch/fairseq/pull/1940

Deprecate the SequenceGenerator in Fairseq in favor of the scripted version.

Passes all integration and unit tests.

- Copy ScriptSequenceGenerator to SequenceGenerator:
  - Modify the forward_decoder to fix a bug when using adaptive_softmax in `get_prob_normalize` (marked with an inline comment)
  - Add support for other EnsembleModels as an input arg (marked with an inline comment)
- Add `FBEnsembleModelWithFork` to support fork/join in EnsembleModel
  - Add `test_fb_ensemble_model` to test the fork/join feature
  - The fork/join feature still has bugs when run through the Fairseq interfaces (e.g. generation and interactive); needs further investigation: P128130029. cc cndn, jhcross
- Modify the SequenceGenerator initialization interface (a sketch of the new call pattern follows this list)
- Clean up the code: delete the unused functions `get_normalized_probs` and `_decode`
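A minimal sketch of the new call pattern, distilled from the call sites updated in this diff. `build_and_run_generator` and its arguments (`model`, `tgt_dict`, `sample`) are illustrative placeholders for an already-built fairseq model, its target dictionary, and a prepared batch; this is not fairseq code itself:

```python
import torch
from fairseq.sequence_generator import SequenceGenerator


def build_and_run_generator(model, tgt_dict, sample):
    # Old interface: the ensemble was passed to generate() at call time.
    #   generator = SequenceGenerator(tgt_dict, beam_size=2)
    #   hypos = generator.generate([model], sample)

    # New interface: the model list and the target dictionary are
    # constructor arguments, and decoding runs through forward().
    generator = SequenceGenerator([model], tgt_dict, beam_size=2)
    hypos = generator.forward(sample)

    # The generator is now a self-contained nn.Module, so it can be
    # compiled end to end (targeting PyTorch >= 1.5, per the tests below).
    scripted_generator = torch.jit.script(generator)
    return hypos, scripted_generator
```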

Reviewed By: myleott

Differential Revision: D20685075

fbshipit-source-id: 046b76874465a70d8118a97ad670311c6ce1d1c8
Authored by Chen Liu on 2020-04-06 17:45:48 -07:00; committed by Facebook GitHub Bot
parent e9014fb424
commit bc93681348
8 changed files with 734 additions and 416 deletions


@@ -57,7 +57,7 @@ class FairseqDecoder(nn.Module):
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Dict[str, List[Optional[Tensor]]]],
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]],
):


@@ -45,7 +45,7 @@ class BaseFairseqModel(nn.Module):
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Dict[str, List[Optional[Tensor]]]],
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
@@ -58,7 +58,7 @@ class BaseFairseqModel(nn.Module):
# call the helper function from scriptable Subclass.
def get_normalized_probs_scriptable(
self,
net_output: Tuple[Tensor, Dict[str, List[Optional[Tensor]]]],
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):


@@ -280,7 +280,7 @@ class TransformerModel(FairseqEncoderDecoderModel):
@torch.jit.export
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Dict[str, List[Optional[Tensor]]]],
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
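The hunks above all make the same change: the `net_output` annotation now allows the extra-state dictionary to be `None`, presumably so the scripted generator can handle decoder outputs that carry no attention/extra state. Under TorchScript the annotation must say `Optional[...]` explicitly before a caller may pass `None`. A minimal standalone sketch of the pattern, assuming a hypothetical helper `normalize_logits` whose body is purely illustrative (it is not the fairseq implementation):

```python
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor


@torch.jit.script
def normalize_logits(
    net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
    log_probs: bool,
) -> Tensor:
    # Because the second tuple element is annotated Optional[...], a
    # TorchScript caller may pass None for the extra state; the old
    # non-Optional annotation rejected that call.
    logits = net_output[0]
    if log_probs:
        return torch.log_softmax(logits, dim=-1)
    return torch.softmax(logits, dim=-1)


# Usage: the extra-state dict can now be omitted entirely.
probs = normalize_logits((torch.randn(2, 5, 10), None), True)
```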

File diff suppressed because it is too large.


@@ -295,6 +295,7 @@ class FairseqTask(object):
seq_gen_cls = SequenceGenerator
return seq_gen_cls(
models,
self.target_dictionary,
beam_size=getattr(args, "beam", 5),
max_len_a=getattr(args, "max_len_a", 0),


@@ -89,6 +89,7 @@ class TranslationFromPretrainedBARTTask(TranslationTask):
else:
from fairseq.sequence_generator import SequenceGenerator
return SequenceGenerator(
models,
self.target_dictionary,
beam_size=getattr(args, 'beam', 5),
max_len_a=getattr(args, 'max_len_a', 0),


@@ -42,6 +42,7 @@ class TestBacktranslationDataset(unittest.TestCase):
)
generator = SequenceGenerator(
[self.model],
tgt_dict=self.tgt_dict,
max_len_a=0,
max_len_b=200,


@@ -4,29 +4,173 @@
# LICENSE file in the root directory of this source tree.
import argparse
import tempfile
import unittest
import torch
from fairseq import search
from fairseq.sequence_generator import SequenceGenerator
import tests.utils as test_utils
import torch
from fairseq import search
from fairseq.data.dictionary import Dictionary
from fairseq.models.transformer import TransformerModel
from fairseq.sequence_generator import SequenceGenerator, EnsembleModel
from fairseq.tasks.fairseq_task import FairseqTask
DEFAULT_TEST_VOCAB_SIZE = 100
class DummyTask(FairseqTask):
def __init__(self, args):
super().__init__(args)
self.dictionary = get_dummy_dictionary()
if getattr(self.args, "ctc", False):
self.dictionary.add_symbol("<ctc_blank>")
self.src_dict = self.dictionary
self.tgt_dict = self.dictionary
@property
def source_dictionary(self):
return self.src_dict
@property
def target_dictionary(self):
return self.dictionary
def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
dummy_dict = Dictionary()
# add dummy symbol to satisfy vocab size
for id, _ in enumerate(range(vocab_size)):
dummy_dict.add_symbol("{}".format(id), 1000)
return dummy_dict
def get_dummy_task_and_parser():
"""
to build a fariseq model, we need some dummy parse and task. This function
is used to create dummy task and parser to faciliate model/criterion test
Note: we use FbSpeechRecognitionTask as the dummy task. You may want
to use other task by providing another function
"""
parser = argparse.ArgumentParser(
description="test_dummy_s2s_task", argument_default=argparse.SUPPRESS
)
DummyTask.add_args(parser)
args = parser.parse_args([])
task = DummyTask.setup_task(args)
return task, parser
class TestJitSequenceGeneratorBase(unittest.TestCase):
def setUp(self):
self.task, self.parser = get_dummy_task_and_parser()
eos = self.task.tgt_dict.eos()
src_tokens = torch.randint(3, 50, (2, 10)).long()
src_tokens = torch.cat((src_tokens, torch.LongTensor([[eos], [eos]])), -1)
src_lengths = torch.LongTensor([2, 10])
self.sample = {
"net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths}
}
TransformerModel.add_args(self.parser)
args = self.parser.parse_args([])
args.encoder_layers = 2
args.decoder_layers = 1
self.transformer_model = TransformerModel.build_model(args, self.task)
def assertOutputEqual(self, hypo, pos_probs):
pos_scores = torch.FloatTensor(pos_probs).log()
self.assertTensorSizeEqual(hypo["positional_scores"], pos_scores)
self.assertTensorSizeEqual(pos_scores.numel(), hypo["tokens"].numel())
def assertTensorSizeEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertLess((t1 - t2).abs().max(), 1e-4)
def assertTensorEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertEqual(t1.ne(t2).long().sum(), 0)
def assertHypoEqual(self, h1, h2):
"Check two hypos are equal"
self.assertTensorEqual(h1["tokens"], h2["tokens"])
self.assertAlmostEqual(h1["positional_scores"], h2["positional_scores"])
self.assertLess(abs(h1["score"] - h2["score"]), 1e-6)
self.assertAlmostEqual(h1["attention"], h2["attention"])
def _test_save_and_load(self, scripted_module):
with tempfile.NamedTemporaryFile() as f:
scripted_module.save(f.name)
torch.jit.load(f.name)
class TestJitSequeneceGenerator(TestJitSequenceGeneratorBase):
@unittest.skipIf(
torch.__version__ < "1.5.0", "Targeting OSS scriptability for the 1.5 release"
)
def test_export_transformer(self):
model = self.transformer_model
torch.jit.script(model)
@unittest.skipIf(
torch.__version__ < "1.5.0", "Targeting OSS scriptability for the 1.5 release"
)
def test_ensemble_sequence_generator(self):
model = self.transformer_model
generator = SequenceGenerator([model], self.task.tgt_dict, beam_size=2)
scripted_model = torch.jit.script(generator)
self._test_save_and_load(scripted_model)
class TestJitEnsemble(TestJitSequenceGeneratorBase):
def test_export_ensemble_model(self):
model = self.transformer_model
ensemble_models = EnsembleModel([model])
torch.jit.script(ensemble_models)
class TestExportSearch(unittest.TestCase):
def setUp(self):
task, _ = get_dummy_task_and_parser()
self.tgt_dict = task.tgt_dict
self.min_top1_prob = 0.4
def test_export_diverse_bs(self):
search_strategy = search.DiverseBeamSearch(
self.tgt_dict, num_groups=2, diversity_strength=0.0
)
torch.jit.script(search_strategy)
def test_export_sampling(self):
low_sampling_topp = self.min_top1_prob / 2.0
search_strategy = search.Sampling(
self.tgt_dict, sampling_topp=low_sampling_topp
)
torch.jit.script(search_strategy)
def test_export_diverse_siblings_search(self):
search_strategy = search.DiverseSiblingsSearch(
self.tgt_dict, diversity_rate=0.5
)
torch.jit.script(search_strategy)
class TestSequenceGeneratorBase(unittest.TestCase):
def assertHypoTokens(self, hypo, tokens):
self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))
self.assertTensorEqual(hypo["tokens"], torch.LongTensor(tokens))
def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0):
pos_scores = torch.FloatTensor(pos_probs).log()
self.assertAlmostEqual(hypo['positional_scores'], pos_scores)
self.assertEqual(pos_scores.numel(), hypo['tokens'].numel())
self.assertAlmostEqual(hypo["positional_scores"], pos_scores)
self.assertEqual(pos_scores.numel(), hypo["tokens"].numel())
score = pos_scores.sum()
if normalized:
score /= pos_scores.numel()**lenpen
self.assertLess(abs(score - hypo['score']), 1e-6)
score /= pos_scores.numel() ** lenpen
self.assertLess(abs(score - hypo["score"]), 1e-6)
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
@@ -37,21 +181,18 @@ class TestSequenceGeneratorBase(unittest.TestCase):
self.assertEqual(t1.ne(t2).long().sum(), 0)
class TestSequenceGenerator(TestSequenceGeneratorBase):
class TestSequeneceGenerator(TestSequenceGeneratorBase):
def setUp(self):
self.tgt_dict, self.w1, self.w2, src_tokens, src_lengths, self.model = (
test_utils.sequence_generator_setup()
)
self.sample = {
'net_input': {
'src_tokens': src_tokens, 'src_lengths': src_lengths,
},
"net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths}
}
def test_with_normalization(self):
generator = SequenceGenerator(self.tgt_dict, beam_size=2)
hypos = generator.generate([self.model], self.sample)
generator = SequenceGenerator([self.model], self.tgt_dict, beam_size=2)
hypos = generator.forward(self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -69,8 +210,10 @@ class TestSequenceGenerator(TestSequenceGeneratorBase):
def test_without_normalization(self):
# Sentence 1: unchanged from the normalized case
# Sentence 2: beams swap order
generator = SequenceGenerator(self.tgt_dict, beam_size=2, normalize_scores=False)
hypos = generator.generate([self.model], self.sample)
generator = SequenceGenerator(
[self.model], self.tgt_dict, beam_size=2, normalize_scores=False
)
hypos = generator.forward(self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -87,8 +230,10 @@ class TestSequenceGenerator(TestSequenceGeneratorBase):
def test_with_lenpen_favoring_short_hypos(self):
lenpen = 0.6
generator = SequenceGenerator(self.tgt_dict, beam_size=2, len_penalty=lenpen)
hypos = generator.generate([self.model], self.sample)
generator = SequenceGenerator(
[self.model], self.tgt_dict, beam_size=2, len_penalty=lenpen
)
hypos = generator.forward(self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -105,8 +250,10 @@ class TestSequenceGenerator(TestSequenceGeneratorBase):
def test_with_lenpen_favoring_long_hypos(self):
lenpen = 5.0
generator = SequenceGenerator(self.tgt_dict, beam_size=2, len_penalty=lenpen)
hypos = generator.generate([self.model], self.sample)
generator = SequenceGenerator(
[self.model], self.tgt_dict, beam_size=2, len_penalty=lenpen
)
hypos = generator.forward(self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
@@ -122,8 +269,8 @@ class TestSequenceGenerator(TestSequenceGeneratorBase):
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)
def test_maxlen(self):
generator = SequenceGenerator(self.tgt_dict, beam_size=2, max_len_b=2)
hypos = generator.generate([self.model], self.sample)
generator = SequenceGenerator([self.model], self.tgt_dict, beam_size=2, max_len_b=2)
hypos = generator.forward(self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -139,11 +286,11 @@ class TestSequenceGenerator(TestSequenceGeneratorBase):
self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01])
def test_encoder_with_different_output_len(self):
generator = SequenceGenerator(self.tgt_dict, beam_size=2, max_len_b=2)
args = self.model.encoder.args
task = test_utils.TestTranslationTask.setup_task(args, self.tgt_dict, self.tgt_dict)
reshaping_model = test_utils.TestReshapingModel.build_model(args, task)
hypos = generator.generate([reshaping_model], self.sample)
generator = SequenceGenerator([reshaping_model], self.tgt_dict, beam_size=2, max_len_b=2)
hypos = generator.forward(self.sample)
for sent in [0, 1]:
for beam in [0, 1]:
assert hypos[sent][beam]['attention'] is not None
@@ -210,10 +357,10 @@ class TestDiverseBeamSearch(TestSequenceGeneratorBase):
def test_diverse_beam_search(self):
search_strategy = search.DiverseBeamSearch(self.tgt_dict, num_groups=2, diversity_strength=0.)
generator = SequenceGenerator(
self.tgt_dict, beam_size=2, search_strategy=search_strategy,
[self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy,
)
sample = {'net_input': {'src_tokens': self.src_tokens, 'src_lengths': self.src_lengths}}
hypos = generator.generate([self.model], sample)
hypos = generator.forward(sample)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
@@ -247,7 +394,7 @@ class TestDiverseSiblingsSearch(TestDiverseBeamSearch):
self.tgt_dict, diversity_rate=0.5
)
generator = SequenceGenerator(
self.tgt_dict, beam_size=2, search_strategy=search_strategy
[self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy
)
sample = {
"net_input": {
@@ -255,7 +402,7 @@ class TestDiverseSiblingsSearch(TestDiverseBeamSearch):
"src_lengths": self.src_lengths,
}
}
hypos = generator.generate([self.model], sample)
hypos = generator.forward(sample)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
@@ -338,14 +485,14 @@ class TestTopPSamplingSearch(TestSequenceGeneratorBase):
low_sampling_topp = self.min_top1_prob/2.0
search_strategy = search.Sampling(self.tgt_dict, sampling_topp=low_sampling_topp)
generator = SequenceGenerator(
self.tgt_dict, beam_size=2, search_strategy=search_strategy)
[self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy)
sample = {
'net_input': {
'src_tokens': self.src_tokens,
'src_lengths': self.src_lengths
}
}
hypos = generator.generate([self.model], sample)
hypos = generator.forward(sample)
eos, w1 = self.eos, self.w1
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
@@ -366,14 +513,14 @@ class TestTopPSamplingSearch(TestSequenceGeneratorBase):
high_sampling_topp = (self.min_top1_prob+self.min_top2_prob)/2.0
search_strategy = search.Sampling(self.tgt_dict, sampling_topp=high_sampling_topp)
generator = SequenceGenerator(
self.tgt_dict, beam_size=2, search_strategy=search_strategy)
[self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy)
sample = {
'net_input': {
'src_tokens': self.src_tokens,
'src_lengths': self.src_lengths
}
}
hypos = generator.generate([self.model], sample)
hypos = generator.forward(sample)
eos, w1, w2 = self.eos, self.w1, self.w2
# sentence 1, beam 1
self.assertTrue(self.hypoTokens(hypos[0][0], [w1, w1, eos]) or
@@ -420,5 +567,5 @@ class TestTopPSamplingSearch(TestSequenceGeneratorBase):
return t1.size() == t2.size() and t1.ne(t2).long().sum() == 0
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()