Revert D20797390: Script reoder_incremental_state in fairseq baseline model

Differential Revision: D20797390

Original commit changeset: ab29874973ad

fbshipit-source-id: efd2d720c96ee90d1e8dc36178e04f0bf5510278
Aapo Kyrola 2020-04-07 00:57:53 -07:00 committed by Facebook GitHub Bot
parent d369c88019
commit 8a528888e4
5 changed files with 39 additions and 21 deletions


@@ -2,11 +2,9 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional
from fairseq.models import FairseqDecoder
from fairseq.incremental_decoding_utils import with_incremental_state
from torch import Tensor
@with_incremental_state
@@ -61,25 +59,24 @@ class FairseqIncrementalDecoder(FairseqDecoder):
        """
        raise NotImplementedError

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder incremental state.

        This should be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.
        """
        seen: Dict[int, Optional[Tensor]] = {}
        for _, module in self.named_modules():
            if hasattr(module, 'reorder_incremental_state'):
                if id(module) not in seen and module is not self:
                    seen[id(module)] = None
                    result = module.reorder_incremental_state(incremental_state, new_order)
                    if result is not None:
                        incremental_state = result
        seen = set()
        for module in self.modules():
            if (
                module != self
                and hasattr(module, 'reorder_incremental_state')
                and module not in seen
            ):
                seen.add(module)
                result = module.reorder_incremental_state(incremental_state, new_order)
                if result is not None:
                    incremental_state = result

    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""


@@ -879,6 +879,16 @@ class TransformerDecoder(FairseqIncrementalDecoder):
            self._future_mask = self._future_mask.to(tensor)
        return self._future_mask[:dim, :dim]

    # Overwrite the method to temporarily support jit scripting in the Transformer
    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Scriptable reorder incremental state in the transformer."""
        for layer in self.layers:
            layer.reorder_incremental_state(incremental_state, new_order)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""


@@ -355,6 +355,18 @@ class TransformerDecoderLayer(nn.Module):
    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Scriptable reorder incremental state in transformer layers."""
        self.self_attn.reorder_incremental_state(incremental_state, new_order)
        if self.encoder_attn is not None:
            self.encoder_attn.reorder_incremental_state(incremental_state, new_order)


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
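
Each call in the layer method above hands the shared cache to one attention module, which reorders only its own buffers. A sketch of that per-module step (an illustrative helper, not fairseq's MultiheadAttention implementation):

from typing import Dict, Optional

from torch import Tensor


def reorder_attn_cache(
    cache: Dict[str, Optional[Tensor]], new_order: Tensor
) -> Dict[str, Optional[Tensor]]:
    # Gather every buffered tensor along the beam dimension (dim 0) so that the
    # cached keys/values follow the reordered hypotheses.
    for name, buf in cache.items():
        if buf is not None:
            cache[name] = buf.index_select(0, new_order)
    return cache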


@@ -709,7 +709,11 @@ class EnsembleModel(nn.Module):
            decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)

        attn: Optional[Tensor] = None
        decoder_len = len(decoder_out)
        # __len__ is not supported in Tuple in Script.
        decoder_len = 0
        for _ in decoder_out:
            decoder_len += 1
        if decoder_len > 1 and decoder_out[1] is not None:
            if isinstance(decoder_out[1], Tensor):
                attn = decoder_out[1]
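
The counting loop restored above works around the limitation named in the comment: len() on the tuple is not scriptable there, while a for loop over a tuple is unrolled at compile time. A minimal standalone sketch of the same pattern (not taken from fairseq):

from typing import Optional, Tuple

import torch
from torch import Tensor


@torch.jit.script
def count_elements(out: Tuple[Tensor, Optional[Tensor]]) -> int:
    # The loop is unrolled per tuple element by the TorchScript compiler,
    # which sidesteps the unsupported len() call.
    n = 0
    for _ in out:
        n += 1
    return n


# count_elements((torch.zeros(1), None)) returns 2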


@@ -109,7 +109,6 @@ class TestJitSequenceGeneratorBase(unittest.TestCase):
class TestJitSequeneceGenerator(TestJitSequenceGeneratorBase):
    @unittest.skipIf(
        torch.__version__ < "1.5.0", "Targeting OSS scriptability for the 1.5 release"
    )

@@ -128,10 +127,6 @@ class TestJitSequeneceGenerator(TestJitSequenceGeneratorBase):
class TestJitEnsemble(TestJitSequenceGeneratorBase):
    @unittest.skipIf(
        torch.__version__ < "1.5.0", "Targeting OSS scriptability for the 1.5 release"
    )
    def test_export_ensemble_model(self):
        model = self.transformer_model
        ensemble_models = EnsembleModel([model])