Pass all net_inputs in SequenceGenerator (#2090)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?
Fixes https://github.com/pytorch/fairseq/issues/2022.
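
In short, `SequenceGenerator` now hands the entire `net_input` dict to the encoder (through the new `FairseqEncoder.forward_torchscript`), so models whose encoders consume inputs beyond `src_tokens`/`src_lengths` work during generation without subclassing the generator. A minimal standalone sketch of that dispatch (not fairseq code; `ToyEncoder` and the free-standing `forward_encoder` helper are illustrative names):

```python
from typing import Dict

import torch
from torch import Tensor


class ToyEncoder(torch.nn.Module):
    def forward(self, src_tokens, src_lengths=None, **kwargs):
        # Extra generation-time inputs arrive as keyword arguments.
        assert "fancy_other_input" in kwargs
        return src_tokens


def forward_encoder(encoder: torch.nn.Module, net_input: Dict[str, Tensor]):
    # Mirrors FairseqEncoder.forward_non_torchscript: pass everything in
    # net_input to the encoder except the decoder-only key.
    encoder_input = {k: v for k, v in net_input.items() if k != "prev_output_tokens"}
    return encoder(**encoder_input)


net_input = {
    "src_tokens": torch.tensor([[4, 5, 2]]),
    "src_lengths": torch.tensor([3]),
    "fancy_other_input": torch.tensor([[4, 5, 2]]),
    "prev_output_tokens": torch.tensor([[2, 4, 5]]),
}
out = forward_encoder(ToyEncoder(), net_input)  # ToyEncoder sees fancy_other_input
```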

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃
Pull Request resolved: https://github.com/pytorch/fairseq/pull/2090

Reviewed By: cndn

Differential Revision: D21385984

Pulled By: myleott

fbshipit-source-id: 1428e02e625b8625df71a83c05dcf933c3f899df
Authored by Marco Gaido on 2020-05-10 06:11:24 -07:00, committed by Facebook GitHub Bot
parent be86e7ebef
commit 11345a7608
4 changed files with 90 additions and 14 deletions


@@ -3,8 +3,9 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import torch
 import torch.nn as nn
-from typing import List, NamedTuple, Optional
+from typing import Dict, List, NamedTuple, Optional
 from torch import Tensor
 
 EncoderOut = NamedTuple(
@@ -37,6 +38,29 @@ class FairseqEncoder(nn.Module):
         """
         raise NotImplementedError
 
+    def forward_torchscript(self, net_input: Dict[str, Tensor]):
+        """A TorchScript-compatible version of forward.
+
+        Encoders which use additional arguments may want to override
+        this method for TorchScript compatibility.
+        """
+        if torch.jit.is_scripting():
+            return self.forward(
+                src_tokens=net_input["src_tokens"],
+                src_lengths=net_input["src_lengths"],
+            )
+        else:
+            return self.forward_non_torchscript(net_input)
+
+    @torch.jit.unused
+    def forward_non_torchscript(self, net_input: Dict[str, Tensor]):
+        encoder_input = {
+            k: v
+            for k, v in net_input.items()
+            if k != "prev_output_tokens"
+        }
+        return self.forward(**encoder_input)
+
     def reorder_encoder_out(self, encoder_out, new_order):
         """
         Reorder encoder output according to `new_order`.
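
The docstring on `forward_torchscript` above says that encoders taking additional arguments may want to override it for TorchScript compatibility (under scripting, the default implementation only forwards `src_tokens` and `src_lengths`). A hedged sketch of such an override, not part of this commit; the class name is made up and the `fancy_other_input` key is borrowed from the test added below:

```python
from typing import Dict

from torch import Tensor

from fairseq.models import FairseqEncoder


class MyScriptableEncoder(FairseqEncoder):
    """Illustrative encoder that needs one extra tensor at generation time."""

    def forward(self, src_tokens, src_lengths, fancy_other_input):
        # Toy body; a real encoder would build and return an EncoderOut here.
        return src_tokens + fancy_other_input

    def forward_torchscript(self, net_input: Dict[str, Tensor]):
        # Explicit key access (instead of **kwargs) keeps this scriptable.
        return self.forward(
            src_tokens=net_input["src_tokens"],
            src_lengths=net_input["src_lengths"],
            fancy_other_input=net_input["fancy_other_input"],
        )
```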


@@ -112,6 +112,7 @@ class SequenceGenerator(nn.Module):
         self.model.reset_incremental_state()
         return self._generate(sample, prefix_tokens, bos_token)
 
+    # TODO(myleott): unused, deprecate after pytorch-translate migration
     def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
         """Iterate over a batched dataset and yield individual translations.
         Args:
@@ -165,13 +166,8 @@
         prefix_tokens: Optional[Tensor] = None,
         bos_token: Optional[int] = None,
     ):
-        encoder_input: Dict[str, Tensor] = {}
-        for k, v in sample["net_input"].items():
-            if k != "prev_output_tokens":
-                encoder_input[k] = v
-
-        src_tokens = encoder_input["src_tokens"]
-
+        net_input = sample["net_input"]
+        src_tokens = net_input["src_tokens"]
         # length of the source text being the character length except EndOfSentence and pad
         src_lengths = (
             (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
@@ -194,10 +190,7 @@
             self.min_len <= max_len
         ), "min_len cannot be larger than max_len, please adjust these!"
 
         # compute the encoder output for each beam
-        encoder_outs = self.model.forward_encoder(
-            src_tokens=encoder_input["src_tokens"],
-            src_lengths=encoder_input["src_lengths"],
-        )
+        encoder_outs = self.model.forward_encoder(net_input)
 
         # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
         new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
@@ -707,11 +700,11 @@
         return min([m.max_decoder_positions() for m in self.models])
 
     @torch.jit.export
-    def forward_encoder(self, src_tokens, src_lengths):
+    def forward_encoder(self, net_input: Dict[str, Tensor]):
         if not self.has_encoder():
             return None
         return [
-            model.encoder(src_tokens=src_tokens, src_lengths=src_lengths)
+            model.encoder.forward_torchscript(net_input)
             for model in self.models
         ]
 


@@ -302,6 +302,19 @@ class TestSequeneceGenerator(TestSequenceGeneratorBase):
             for beam in [0, 1]:
                 assert hypos[sent][beam]['attention'] is not None
 
+    def test_generation_with_additional_input(self):
+        args = self.model.encoder.args
+        task = test_utils.TestTranslationTask.setup_task(args, self.tgt_dict, self.tgt_dict)
+        add_input_model = test_utils.TestAdditionalInputModel.build_model(args, task)
+        generator = SequenceGenerator([add_input_model], self.tgt_dict, beam_size=2)
+        sample = self.sample.copy()
+        sample['net_input']['fancy_other_input'] = sample['net_input']['src_tokens']
+        hypos = generator.forward(self.sample)
+        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
+        # sentence 1, beam 1
+        self.assertHypoTokens(hypos[0][0], [w1, eos])
+        self.assertHypoScore(hypos[0][0], [0.9, 1.0])
+
 
 class TestDiverseBeamSearch(TestSequenceGeneratorBase):


@@ -292,3 +292,49 @@ class TestReshapingModel(FairseqEncoderDecoderModel):
         encoder = TestReshapingEncoder(args, task.source_dictionary)
         decoder = TestIncrementalDecoder(args, task.target_dictionary)
         return cls(encoder, decoder)
+
+
+class TestAdditionalInputEncoder(FairseqEncoder):
+    def __init__(self, args, dictionary):
+        super().__init__(dictionary)
+        self.args = args
+
+    def forward(self, src_tokens, src_lengths=None, **kwargs):
+        assert 'fancy_other_input' in kwargs
+        assert kwargs['fancy_other_input'] is not None
+        return EncoderOut(
+            encoder_out=src_tokens,
+            encoder_padding_mask=None,
+            encoder_embedding=None,
+            encoder_states=None,
+            src_tokens=None,
+            src_lengths=None,
+        )
+
+    def reorder_encoder_out(self, encoder_out, new_order):
+        return EncoderOut(
+            encoder_out=encoder_out.encoder_out.index_select(0, new_order),
+            encoder_padding_mask=None,
+            encoder_embedding=None,
+            encoder_states=None,
+            src_tokens=None,
+            src_lengths=None,
+        )
+
+
+class TestAdditionalInputModel(FairseqEncoderDecoderModel):
+    def __init__(self, encoder, decoder):
+        super().__init__(encoder, decoder)
+
+    @classmethod
+    def build_model(cls, args, task):
+        encoder = TestAdditionalInputEncoder(args, task.source_dictionary)
+        decoder = TestIncrementalDecoder(args, task.target_dictionary)
+        return cls(encoder, decoder)
+
+    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
+        encoder_out = self.encoder(
+            src_tokens, src_lengths=src_lengths, **kwargs)
+        decoder_out = self.decoder(
+            prev_output_tokens, encoder_out=encoder_out, **kwargs)
+        return decoder_out