Script reorder_incremental_state in the fairseq base model (#1127)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1127

Pull Request resolved: https://github.com/pytorch/fairseq/pull/1953

Script `reorder_incremental_state` in the base `FairseqIncrementalDecoder`
Remove the overriding scriptable `reorder_incremental_state` from `TransformerDecoder` and `TransformerDecoderLayer`
Simplify the `decoder_len` computation, since `len()` on a Tuple is now supported in TorchScript
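
Below is a minimal sketch (a toy module and cache, not the actual fairseq classes) of the pattern this change relies on: a method decorated with `@torch.jit.export` is compiled alongside `forward()` when the module is scripted, so a scripted caller can reorder the cached incremental state along the batch/beam dimension.

```python
from typing import Dict, Optional

import torch
from torch import Tensor, nn


class ToyIncrementalDecoder(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return x

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        # Reindex every cached tensor along the batch (beam) dimension.
        for key in incremental_state.keys():
            inner = incremental_state[key]
            for name in inner.keys():
                value = inner[name]
                if value is not None:
                    inner[name] = value.index_select(0, new_order)


scripted = torch.jit.script(ToyIncrementalDecoder())
state: Dict[str, Dict[str, Optional[Tensor]]] = {
    "attn": {"prev_key": torch.randn(2, 4)}
}
scripted.reorder_incremental_state(state, torch.tensor([1, 0]))
```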

Reviewed By: myleott

Differential Revision: D20797390

fbshipit-source-id: ab29874973adc5dbd556c591942a0e071c81fc52
Chen Liu 2020-04-06 20:38:18 -07:00 committed by Facebook GitHub Bot
parent bc93681348
commit d369c88019
5 changed files with 21 additions and 39 deletions


@@ -2,9 +2,11 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional
from fairseq.models import FairseqDecoder
from fairseq.incremental_decoding_utils import with_incremental_state
from torch import Tensor
@with_incremental_state
@@ -59,24 +61,25 @@ class FairseqIncrementalDecoder(FairseqDecoder):
"""
raise NotImplementedError
def reorder_incremental_state(self, incremental_state, new_order):
def reorder_incremental_state(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
new_order: Tensor,
):
"""Reorder incremental state.
This should be called when the order of the input has changed from the
previous time step. A typical use case is beam search, where the input
order changes between time steps based on the selection of beams.
"""
seen = set()
for module in self.modules():
if (
module != self
and hasattr(module, 'reorder_incremental_state')
and module not in seen
):
seen.add(module)
result = module.reorder_incremental_state(incremental_state, new_order)
if result is not None:
incremental_state = result
seen: Dict[int, Optional[Tensor]] = {}
for _, module in self.named_modules():
if hasattr(module, 'reorder_incremental_state'):
if id(module) not in seen and module is not self:
seen[id(module)] = None
result = module.reorder_incremental_state(incremental_state, new_order)
if result is not None:
incremental_state = result
def set_beam_size(self, beam_size):
"""Sets the beam size in the decoder and all children."""


@@ -879,16 +879,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self._future_mask = self._future_mask.to(tensor)
return self._future_mask[:dim, :dim]
# Override the method to temporarily support JIT scripting in the Transformer
@torch.jit.export
def reorder_incremental_state(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
new_order: Tensor,
):
"""Scriptable reorder incremental state in the transformer."""
for layer in self.layers:
layer.reorder_incremental_state(incremental_state, new_order)
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""


@@ -355,18 +355,6 @@ class TransformerDecoderLayer(nn.Module):
def make_generation_fast_(self, need_attn: bool = False, **kwargs):
self.need_attn = need_attn
@torch.jit.export
def reorder_incremental_state(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
new_order: Tensor,
):
"""Scriptable reorder incremental state in transformer layers."""
self.self_attn.reorder_incremental_state(incremental_state, new_order)
if self.encoder_attn is not None:
self.encoder_attn.reorder_incremental_state(incremental_state, new_order)
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)


@@ -709,11 +709,7 @@ class EnsembleModel(nn.Module):
decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)
attn: Optional[Tensor] = None
# __len__ is not supported in Tuple in Script.
decoder_len = 0
for _ in decoder_out:
decoder_len += 1
decoder_len = len(decoder_out)
if decoder_len > 1 and decoder_out[1] is not None:
if isinstance(decoder_out[1], Tensor):
attn = decoder_out[1]
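
The simplification leans on the fact noted in the summary: `len()` over a typed `Tuple` now compiles under TorchScript. A standalone check (hypothetical helper, not fairseq code):

```python
from typing import Optional, Tuple

import torch
from torch import Tensor


@torch.jit.script
def decoder_out_len(decoder_out: Tuple[Tensor, Optional[Tensor]]) -> int:
    # Tuples have a statically known length, so len() resolves at compile time.
    return len(decoder_out)


assert decoder_out_len((torch.zeros(1, 2), None)) == 2
```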


@@ -109,6 +109,7 @@ class TestJitSequenceGeneratorBase(unittest.TestCase):
class TestJitSequeneceGenerator(TestJitSequenceGeneratorBase):
@unittest.skipIf(
torch.__version__ < "1.5.0", "Targeting OSS scriptability for the 1.5 release"
)
@@ -127,6 +128,10 @@ class TestJitSequeneceGenerator(TestJitSequenceGeneratorBase):
class TestJitEnsemble(TestJitSequenceGeneratorBase):
@unittest.skipIf(
torch.__version__ < "1.5.0", "Targeting OSS scriptability for the 1.5 release"
)
def test_export_ensemble_model(self):
model = self.transformer_model
ensemble_models = EnsembleModel([model])