Make TransformerDecoupled model scriptable (#1125)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1125

Pull Request resolved: https://github.com/pytorch/translate/pull/695

Pull Request resolved: https://github.com/pytorch/fairseq/pull/1927

- Switches the model to the scripted sequence generator recently implemented in fairseq. This involved making the input/output format of this model conform to that of the fairseq TransformerEncoder/Decoder
- Modifies the `EncoderOut` format for the fairseq transformer, adding the optional fields needed by the copy-ptr decoder
- Switches to using WordEmbedding directly instead of the non-scriptable EmbeddingList for the src/trg embedding layers
- Assorted small syntactic changes to make the model JIT scriptable
- Adds a torchscriptify method for this model. Preliminary latency measurements are similar to the unexported model, and the outputs were verified to match (see the sketch after this list)
- The Roberta decoupled model is currently not scriptable because the base TransformerSentenceEncoder it is built on is not scriptable; we can look at adding that later
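
A minimal sketch of the script-and-verify flow the last two bullets describe, using a self-contained toy module as a stand-in for the actual decoupled model (which lives in pytorch/translate and is not reproduced here):

```python
import torch
import torch.nn as nn
from typing import Optional


class TinyEncoder(nn.Module):
    """Toy stand-in: the point is the export/verification flow, not the architecture."""

    def __init__(self, vocab: int = 100, dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim, padding_idx=1)
        self.proj = nn.Linear(dim, vocab)

    def forward(
        self, src_tokens: torch.Tensor, src_lengths: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        return self.proj(self.embed(src_tokens))


model = TinyEncoder().eval()
scripted = torch.jit.script(model)  # raises if any submodule is not scriptable

src_tokens = torch.randint(2, 100, (4, 7))  # B x T toy batch
with torch.no_grad():
    # Same check as in the summary: exported and eager outputs should match.
    assert torch.allclose(model(src_tokens), scripted(src_tokens))
```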

Reviewed By: einolghozati

Differential Revision: D20687247

fbshipit-source-id: 8232972bba2f1b2df4100f3c1776b6bad08a71db
Anchit Gupta 2020-04-01 17:51:32 -07:00 committed by Facebook GitHub Bot
parent 0e608fdba6
commit f6f092f489
6 changed files with 52 additions and 19 deletions

View File

@@ -14,6 +14,8 @@ EncoderOut = NamedTuple(
("encoder_padding_mask", Tensor), # B x T
("encoder_embedding", Tensor), # B x T x C
("encoder_states", Optional[List[Tensor]]), # List[T x B x C]
("src_tokens", Optional[Tensor]), # B x T
("src_lengths", Optional[Tensor]), # B x 1
],
)
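
For reference, the complete named tuple after this change; the leading `encoder_out` field is unchanged and sits just above the hunk:

```python
from typing import List, NamedTuple, Optional
from torch import Tensor

EncoderOut = NamedTuple(
    "EncoderOut",
    [
        ("encoder_out", Tensor),  # T x B x C
        ("encoder_padding_mask", Tensor),  # B x T
        ("encoder_embedding", Tensor),  # B x T x C
        ("encoder_states", Optional[List[Tensor]]),  # List[T x B x C]
        ("src_tokens", Optional[Tensor]),  # B x T (new, optional)
        ("src_lengths", Optional[Tensor]),  # B x 1 (new, optional)
    ],
)
```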

View File

@@ -485,6 +485,8 @@ class TransformerEncoder(FairseqEncoder):
encoder_padding_mask=encoder_padding_mask, # B x T
encoder_embedding=encoder_embedding, # B x T x C
encoder_states=encoder_states, # List[T x B x C]
src_tokens=None,
src_lengths=None,
)
@torch.jit.export
@@ -516,6 +518,13 @@ class TransformerEncoder(FairseqEncoder):
if encoder_out.encoder_embedding is None
else encoder_out.encoder_embedding.index_select(0, new_order)
)
src_tokens = encoder_out.src_tokens
if src_tokens is not None:
src_tokens = src_tokens.index_select(0, new_order)
src_lengths = encoder_out.src_lengths
if src_lengths is not None:
src_lengths = src_lengths.index_select(0, new_order)
encoder_states = encoder_out.encoder_states
if encoder_states is not None:
@@ -527,6 +536,8 @@ class TransformerEncoder(FairseqEncoder):
encoder_padding_mask=new_encoder_out["encoder_padding_mask"], # B x T
encoder_embedding=new_encoder_out["encoder_embedding"], # B x T x C
encoder_states=encoder_states, # List[T x B x C]
src_tokens=src_tokens, # B x T
src_lengths=src_lengths, # B x 1
)
def max_positions(self):
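
The explicit `is not None` checks above are what let TorchScript narrow `Optional[Tensor]` to `Tensor` before calling `index_select`; a standalone sketch of the same pattern:

```python
import torch
from typing import Optional
from torch import Tensor


@torch.jit.script
def reorder_optional(t: Optional[Tensor], new_order: Tensor) -> Optional[Tensor]:
    # Inside the None check TorchScript treats `t` as a plain Tensor,
    # so index_select is allowed; outside it, it is still Optional.
    if t is not None:
        return t.index_select(0, new_order)
    return None


reorder_optional(torch.arange(6).view(3, 2), torch.tensor([2, 0, 1]))
```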

View File

@@ -22,7 +22,4 @@ def gelu_accurate(x):
def gelu(x: torch.Tensor) -> torch.Tensor:
if hasattr(torch.nn.functional, "gelu"):
return torch.nn.functional.gelu(x.float()).type_as(x)
else:
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
return torch.nn.functional.gelu(x.float()).type_as(x)
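
After this change the helper is a single scriptable expression; the erf-based fallback for PyTorch builds without `F.gelu` is gone (a reasonable guess is that the `hasattr` branch got in the way of scripting). Decorated here only to show it compiles; the fairseq source does not script the function itself:

```python
import torch


@torch.jit.script
def gelu(x: torch.Tensor) -> torch.Tensor:
    # Always delegate to the built-in implementation.
    return torch.nn.functional.gelu(x.float()).type_as(x)
```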

View File

@@ -3,9 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor
class LearnedPositionalEmbedding(nn.Embedding):
@@ -16,12 +20,7 @@ class LearnedPositionalEmbedding(nn.Embedding):
position ids are passed to the forward function.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int,
):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.onnx_trace = False
if self.padding_idx is not None:
@@ -29,19 +28,34 @@ class LearnedPositionalEmbedding(nn.Embedding):
else:
self.max_positions = self.num_embeddings
def forward(self, input, incremental_state=None, positions=None):
def forward(
self,
input: Tensor,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
positions: Optional[Tensor] = None,
):
"""Input is expected to be of size [bsz x seqlen]."""
assert (
(positions is None) or (self.padding_idx is None)
assert (positions is None) or (
self.padding_idx is None
), "If positions is pre-computed then padding_idx should not be set."
if positions is None:
if incremental_state is not None:
# positions is the same for every token when decoding a single step
# Without the int() cast, it doesn't work in some cases when exporting to ONNX
positions = input.data.new(1, 1).fill_(int(self.padding_idx + input.size(1)))
positions = torch.zeros(
(1, 1), device=input.device, dtype=input.dtype
).fill_(int(self.padding_idx + input.size(1)))
else:
positions = utils.make_positions(
input, self.padding_idx, onnx_trace=self.onnx_trace,
input, self.padding_idx, onnx_trace=self.onnx_trace
)
return super().forward(positions)
return F.embedding(
positions,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
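
A small usage sketch of the rewritten module, assuming the usual `fairseq.modules` import path; the toy tensors are illustrative only:

```python
import torch
from fairseq.modules import LearnedPositionalEmbedding

emb = LearnedPositionalEmbedding(num_embeddings=128, embedding_dim=16, padding_idx=1)

# Full-sequence path: positions come from utils.make_positions.
tokens = torch.tensor([[5, 6, 7, 1]])  # B x T, right-padded with padding_idx
out = emb(tokens)  # B x T x 16

# Single-step path: a non-None incremental_state makes every token in the
# step share position padding_idx + input.size(1).
step = emb(tokens[:, -1:], incremental_state={})  # B x 1 x 16
```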

View File

@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
import torch.nn as nn
@@ -60,8 +61,8 @@ class TransformerSentenceEncoderLayer(nn.Module):
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None,
):
"""
LayerNorm is applied either before or after the self-attention/ffn
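
The `Optional[...]` annotations matter because `torch.jit.script` rejects a `None` default on a parameter declared as a plain `torch.Tensor`; a minimal illustration of the accepted form:

```python
import torch
from typing import Optional


@torch.jit.script
def apply_mask(x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # With the Optional annotation (as in the layer above) the script compiler
    # accepts the None default and requires a None check before use.
    if mask is not None:
        x = x.masked_fill(mask, 0.0)
    return x


apply_mask(torch.randn(3), torch.tensor([True, False, True]))
```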

View File

@@ -176,6 +176,8 @@ class TestEncoder(FairseqEncoder):
encoder_padding_mask=None,
encoder_embedding=None,
encoder_states=None,
src_tokens=None,
src_lengths=None,
)
def reorder_encoder_out(self, encoder_out, new_order):
@@ -184,6 +186,8 @@ class TestEncoder(FairseqEncoder):
encoder_padding_mask=None,
encoder_embedding=None,
encoder_states=None,
src_tokens=None,
src_lengths=None,
)
@@ -264,6 +268,8 @@ class TestReshapingEncoder(FairseqEncoder):
encoder_padding_mask=None,
encoder_embedding=None,
encoder_states=None,
src_tokens=None,
src_lengths=None,
)
def reorder_encoder_out(self, encoder_out, new_order):
@@ -272,6 +278,8 @@ class TestReshapingEncoder(FairseqEncoder):
encoder_padding_mask=None,
encoder_embedding=None,
encoder_states=None,
src_tokens=None,
src_lengths=None,
)