Improve TorchScript compatibility of transformer and transformer_pg (#3247)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?

Fixes https://github.com/pytorch/fairseq/issues/3246
Fixes https://github.com/pytorch/fairseq/issues/3248
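
Both linked issues concern TorchScript compatibility of the transformer and pointer-generator transformer models. As a minimal, hedged sketch of the workflow this PR is meant to enable (model construction follows fairseq's usual `args`/`task` API; the function name and output path are illustrative, not part of the PR):

```python
import torch
from fairseq.models.transformer import TransformerModel

def export_scripted_transformer(args, task, path: str = "transformer_scripted.pt"):
    # Build the model as usual, then compile and serialize it with TorchScript.
    model = TransformerModel.build_model(args, task)
    scripted = torch.jit.script(model)  # failed before this PR for some configurations
    scripted.save(path)
    return torch.jit.load(path)
```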

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3247

Reviewed By: myleott

Differential Revision: D26513267

Pulled By: lematt1991

fbshipit-source-id: 958de0b3a58a0dd2a56bd6c6d7fb2644a89f6746
Miguel Del-Agua 2021-02-22 14:21:36 -08:00 committed by Facebook GitHub Bot
parent 38258a79a4
commit 808b751597
4 changed files with 133 additions and 20 deletions


@@ -4,13 +4,12 @@
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, List, Tuple
import torch
import torch.nn as nn
from fairseq import metrics, utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import (
DEFAULT_MAX_SOURCE_POSITIONS,
DEFAULT_MAX_TARGET_POSITIONS,
@@ -155,7 +154,13 @@ class TransformerPointerGeneratorEncoder(TransformerEncoder):
to the decoder.
"""
def forward(self, src_tokens, src_lengths, **kwargs):
def forward(
self,
src_tokens,
src_lengths: Optional[Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[Tensor] = None
):
"""
Runs the `forward()` method of the parent Transformer class. Then adds
the source tokens into the encoder output tuple.
@@ -169,6 +174,10 @@ class TransformerPointerGeneratorEncoder(TransformerEncoder):
shape `(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
namedtuple:
@@ -184,7 +193,15 @@ class TransformerPointerGeneratorEncoder(TransformerEncoder):
- **src_tokens** (Tensor): input token ids of shape
`(batch, src_len)`
"""
encoder_out = super().forward(src_tokens, src_lengths, **kwargs)
encoder_out = self.forward_scriptable(src_tokens,
src_lengths,
return_all_hiddens,
token_embeddings)
# The PyTorch Mobile lite interpreter does not support returning a NamedTuple from
# `forward`, so we use a dictionary instead.
# TorchScript does not support mixed value types, so the values are all lists.
# An empty list is equivalent to None.
return {
"encoder_out": encoder_out["encoder_out"], # T x B x C
"encoder_padding_mask": encoder_out["encoder_padding_mask"], # B x T
@@ -236,7 +253,7 @@ class TransformerPointerGeneratorDecoder(TransformerDecoder):
def forward(
self,
prev_output_tokens,
encoder_out: Optional[EncoderOut] = None,
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
features_only: bool = False,
alignment_layer: Optional[int] = 0,
@@ -248,8 +265,8 @@ class TransformerPointerGeneratorDecoder(TransformerDecoder):
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (EncoderOut, optional): output from the encoder, used
for encoder-side attention
encoder_out (optional): output from the encoder, used for
encoder-side attention
incremental_state (dict, optional): dictionary used for storing
state during :ref:`Incremental decoding`
features_only (bool, optional): only return features without
@@ -284,10 +301,21 @@ class TransformerPointerGeneratorDecoder(TransformerDecoder):
predictors = torch.cat((prev_output_embed, x), 2)
p_gens = self.project_p_gens(predictors)
p_gens = torch.sigmoid(p_gens)
x = self.output_layer(x, extra["attn"][0], encoder_out["src_tokens"][0], p_gens)
# TorchScript complains if encoder_out or attn is None, because the
# `output_layer()` signature expects tensors instead
attn: Optional[Tensor] = extra["attn"][0]
assert encoder_out is not None
assert attn is not None
x = self.output_layer(x, attn, encoder_out["src_tokens"][0], p_gens)
return x, extra
def output_layer(self, features, attn, src_tokens, p_gens, **kwargs):
def output_layer(
self,
features: Tensor,
attn: Tensor,
src_tokens: Tensor,
p_gens: Tensor
) -> Tensor:
"""
Project features to the vocabulary size and mix with the attention
distributions.
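
A short, hedged note on the `assert` statements a few lines up: TorchScript narrows an `Optional[Tensor]` to `Tensor` only after an explicit `None` check, so those asserts are there mainly for the compiler. A toy illustration (not fairseq code):

```python
from typing import Optional

import torch
from torch import Tensor


@torch.jit.script
def scale(x: Tensor, gain: Optional[Tensor] = None) -> Tensor:
    # Without this check, passing `gain` to torch.mul would not compile:
    # an Optional[Tensor] cannot be used where a Tensor is expected.
    assert gain is not None
    return torch.mul(x, gain)


print(scale(torch.ones(2), torch.tensor(3.0)))  # tensor([3., 3.])
```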
@@ -296,7 +324,10 @@ class TransformerPointerGeneratorDecoder(TransformerDecoder):
p_gens = self.force_p_gen
# project back to size of vocabulary
logits = super().output_layer(features, **kwargs)
if self.adaptive_softmax is None:
logits = self.output_projection(features)
else:
logits = features
batch_size = logits.shape[0]
output_length = logits.shape[1]
@@ -306,7 +337,7 @@ class TransformerPointerGeneratorDecoder(TransformerDecoder):
# The final output distribution will be a mixture of the normal output
# distribution (softmax of logits) and attention weights.
gen_dists = super().get_normalized_probs(
gen_dists = self.get_normalized_probs_scriptable(
(logits, None), log_probs=False, sample=None
)
gen_dists = torch.mul(gen_dists, p_gens)
@@ -330,7 +361,12 @@ class TransformerPointerGeneratorDecoder(TransformerDecoder):
# Final distributions, [batch_size, output_length, num_types].
return gen_dists + attn_dists
def get_normalized_probs(self, net_output, log_probs, sample):
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
"""
Get normalized probabilities (or log probs) from a net's output.
Pointer-generator network output is already normalized.
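
Stepping back from the hunks above: the decoder mixes the vocabulary distribution (scaled by `p_gen`) with the attention distribution over source positions (scaled by `1 - p_gen` and scattered onto the vocabulary). A self-contained, hedged sketch of that mixture with made-up shapes; fairseq's actual implementation differs in detail:

```python
import torch

batch, tgt_len, src_len, vocab = 2, 3, 4, 10
probs = torch.softmax(torch.randn(batch, tgt_len, vocab), dim=-1)    # generator distribution
attn = torch.softmax(torch.randn(batch, tgt_len, src_len), dim=-1)   # attention over source
src_tokens = torch.randint(0, vocab, (batch, src_len))
p_gens = torch.sigmoid(torch.randn(batch, tgt_len, 1))

gen_dists = probs * p_gens                                           # weight the generator
attn = attn * (1.0 - p_gens)                                         # weight the copier
index = src_tokens.unsqueeze(1).expand(batch, tgt_len, src_len)
attn_dists = torch.zeros(batch, tgt_len, vocab).scatter_add_(2, index, attn)

final = gen_dists + attn_dists
print(final.sum(-1))  # each row sums to 1
```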
@@ -375,8 +411,19 @@ class Embedding(nn.Embedding):
"""
__constants__ = ["unk_idx"]
def __init__(self, num_embeddings, embedding_dim, padding_idx, unk_idx):
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
# TorchScript: inheriting from the Embedding class produces an error when exporting to TorchScript
# -> RuntimeError: Unable to cast Python instance to C++ type (compile in debug mode for details).
# This happens because the max_norm attribute of nn.Embedding is None by default and cannot be
# cast to a C++ type.
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int],
unk_idx: int,
max_norm: Optional[float] = float("inf"),
):
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx, max_norm=max_norm)
self.unk_idx = unk_idx
nn.init.normal_(self.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(self.weight[padding_idx], 0)
@@ -385,7 +432,10 @@ class Embedding(nn.Embedding):
input = torch.where(
input >= self.num_embeddings, torch.ones_like(input) * self.unk_idx, input
)
return super().forward(input)
return nn.functional.embedding(
input, self.weight, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse
)
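
To make the comment above concrete, here is the same pattern as a self-contained toy: out-of-range ids are mapped to `unk`, `max_norm` gets a finite sentinel instead of `None`, and the lookup goes through `nn.functional.embedding` rather than `super().forward()`. Illustrative only (the class name is made up), and how the original failure manifests may vary across PyTorch versions:

```python
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


class UnkFallbackEmbedding(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 padding_idx: Optional[int], unk_idx: int,
                 max_norm: Optional[float] = float("inf")):
        # A finite max_norm keeps every attribute castable to a C++ type.
        super().__init__(num_embeddings, embedding_dim,
                         padding_idx=padding_idx, max_norm=max_norm)
        self.unk_idx = unk_idx

    def forward(self, input: Tensor) -> Tensor:
        # Map out-of-vocabulary ids to unk, then embed via the functional API
        # instead of super().forward(), which TorchScript does not support.
        input = torch.where(input >= self.num_embeddings,
                            torch.ones_like(input) * self.unk_idx, input)
        return nn.functional.embedding(input, self.weight, self.padding_idx,
                                       self.max_norm, self.norm_type,
                                       self.scale_grad_by_freq, self.sparse)


emb = torch.jit.script(UnkFallbackEmbedding(8, 4, padding_idx=0, unk_idx=1))
print(emb(torch.tensor([[2, 99]])).shape)  # id 99 falls back to unk -> (1, 2, 4)
```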
@register_model_architecture(


@@ -64,6 +64,19 @@ class FairseqDecoder(nn.Module):
sample: Optional[Dict[str, Tensor]] = None,
):
"""Get normalized probabilities (or log probs) from a net's output."""
return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
# TorchScript doesn't support the super() method, so a scriptable subclass
# can't access the base class implementation in TorchScript.
# The current workaround is to add a helper function with a different name and
# call that helper from the scriptable subclass.
def get_normalized_probs_scriptable(
self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
"""Get normalized probabilities (or log probs) from a net's output."""
if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
if sample is not None:
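
The "scriptable helper" workaround described in the comment above, reduced to a toy (the class and method names are illustrative): the shared logic lives in a separately named method, the Python entry point delegates to it, and subclasses call the helper instead of `super()`:

```python
import torch
from torch import Tensor


class Base(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return self.forward_scriptable(x)

    # TorchScript cannot compile a super().forward() call from a subclass,
    # so the shared logic gets its own name.
    def forward_scriptable(self, x: Tensor) -> Tensor:
        return x * 2.0


class Child(Base):
    def forward(self, x: Tensor) -> Tensor:
        # Call the helper directly instead of super().forward(x).
        return self.forward_scriptable(x) + 1.0


print(torch.jit.script(Child())(torch.ones(2)))  # tensor([3., 3.])
```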


@@ -422,6 +422,45 @@ class TransformerEncoder(FairseqEncoder):
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
return self.forward_scriptable(src_tokens,
src_lengths,
return_all_hiddens,
token_embeddings)
# TorchScript doesn't support the super() method, so a scriptable subclass
# can't access the base class implementation in TorchScript.
# The current workaround is to add a helper function with a different name and
# call that helper from the scriptable subclass.
def forward_scriptable(
self,
src_tokens,
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
@@ -787,13 +826,11 @@ class TransformerDecoder(FairseqIncrementalDecoder):
alignment_layer = self.num_layers - 1
# embed positions
positions = (
self.embed_positions(
positions = None
if self.embed_positions is not None:
positions = self.embed_positions(
prev_output_tokens, incremental_state=incremental_state
)
if self.embed_positions is not None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
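
A hedged note on the `embed_positions` rewrite just above: the one-line conditional form apparently tripped up the compiler when `embed_positions` is `None` (e.g. with `--no-token-positional-embeddings`, the configuration the new test below exercises), while an explicit `if` statement gives TorchScript a refinement point it handles well. The unrolled shape of the fix as a self-contained toy (names are made up):

```python
from typing import Optional

import torch
from torch import Tensor


class MaybeShift(torch.nn.Module):
    offset: Optional[Tensor]  # class-level annotation keeps the Optional type under scripting

    def __init__(self, use_offset: bool):
        super().__init__()
        self.offset = torch.ones(1) if use_offset else None

    def forward(self, x: Tensor) -> Tensor:
        # Unrolled form of `x + self.offset if self.offset is not None else x`.
        out = x
        offset = self.offset
        if offset is not None:
            out = out + offset
        return out


print(torch.jit.script(MaybeShift(True))(torch.zeros(3)))  # tensor([1., 1., 1.])
```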


@@ -103,6 +103,19 @@ class TestExportModels(unittest.TestCase):
scripted = torch.jit.script(model)
_test_save_and_load(scripted)
@unittest.skipIf(
torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release"
)
def test_export_transformer_no_token_pos_emb(self):
task, parser = get_dummy_task_and_parser()
TransformerModel.add_args(parser)
args = parser.parse_args([])
args.no_token_positional_embeddings = True
model = TransformerModel.build_model(args, task)
scripted = torch.jit.script(model)
_test_save_and_load(scripted)
if __name__ == "__main__":
unittest.main()
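
For completeness, a hedged sketch of the save/load round trip that `_test_save_and_load` presumably performs (the helper itself is not shown in this diff):

```python
import io

import torch


def save_and_load(scripted: torch.jit.ScriptModule) -> torch.jit.ScriptModule:
    # Round-trip the scripted module through an in-memory buffer as a smoke test.
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)
```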