Rewrite the unit test of sequence generator

Summary:
1. Overwrite the base class function `get_normalized_probs` in scriptable TransformerModel
2. Change the unit test setup to match the Transformer decoder output format
3. Initialize the buffer in the simple sequence generator [WIP]
   1. It is the initial step to script the sequence generator from simple scriptable version.
4. Refactor the unit test of simple sequence generator.
5. Change the input format of simple sequence generator and unit test.

Reviewed By: myleott

Differential Revision: D20017859

fbshipit-source-id: a3e93b57c22e49840e460469fa2b1c530346886d
This commit is contained in:
Chen Liu 2020-02-26 11:06:29 -08:00 committed by Facebook Github Bot
parent 07eed27d9f
commit fdfdbec9e2
3 changed files with 29 additions and 3 deletions

View File

@ -50,6 +50,19 @@ class BaseFairseqModel(nn.Module):
sample: Optional[Dict[str, Tensor]] = None,
):
"""Get normalized probabilities (or log probs) from a net's output."""
return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
# TorchScript doesn't support super() method so that the scriptable Subclass
# can't access the base class model in Torchscript.
# Current workaround is to add a helper function with different name and
# call the helper function from scriptable Subclass.
def get_normalized_probs_scriptable(
self,
net_output: Tuple[Tensor, Dict[str, List[Optional[Tensor]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
"""Scriptable helper function for get_normalized_probs in ~BaseFairseqModel"""
if hasattr(self, "decoder"):
return self.decoder.get_normalized_probs(net_output, log_probs, sample)
elif torch.is_tensor(net_output):

View File

@ -4,13 +4,12 @@
# LICENSE file in the root directory of this source tree.
import math
from typing import Any, Dict, List, NamedTuple, Optional
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
@ -18,6 +17,7 @@ from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.modules import (
AdaptiveSoftmax,
LayerNorm,
@ -272,6 +272,19 @@ class TransformerModel(FairseqEncoderDecoderModel):
)
return decoder_out
# TorchScript cannot call super() on the (non-scriptable) base Fairseq model,
# so this override simply forwards to the scriptable helper defined on the
# base class under a different name.
@torch.jit.export
def get_normalized_probs(
    self,
    net_output: Tuple[Tensor, Dict[str, List[Optional[Tensor]]]],
    log_probs: bool,
    sample: Optional[Dict[str, Tensor]] = None,
):
    """Return normalized probabilities (or log-probabilities) for a net output.

    Delegates to the base class's scriptable helper so the same logic is
    usable both in eager mode and under TorchScript.
    """
    probs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
    return probs
@register_model("transformer_align")
class TransformerAlignModel(TransformerModel):

View File

@ -220,7 +220,7 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
attn = torch.rand(bbsz, tgt_len, src_len)
dev = prev_output_tokens.device
return probs.to(dev), attn.to(dev)
return probs.to(dev), {"attn": [attn.to(dev)]}
def get_normalized_probs(self, net_output, log_probs, _):
# the decoder returns probabilities directly