OSS tracing compliant transformer to unbreak master (#1299)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1299

 LevT calls into tracing compliant transformer we didn't plan to OSS earlier. This is a workaround to unbreak the master. Will revisit and simplify the code later.

Reviewed By: pipibjc

Differential Revision: D18110339

fbshipit-source-id: 3bb51c56c2c20f45db1d5786d030b374b412eab1
This commit is contained in:
Ning Dong 2019-10-24 09:54:33 -07:00 committed by Facebook Github Bot
parent d0358bb38e
commit 5b086a0c17
2 changed files with 626 additions and 1 deletions

View File

@ -16,7 +16,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.models import register_model, register_model_architecture
from fairseq.models.fb_tracing_transformer import (
from fairseq.models.tracing_compliant_transformer import (
TracingTransformerDecoder,
TracingTransformerEncoder,
TracingTransformerModel,

View File

@ -0,0 +1,625 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import Embedding, Linear, base_architecture
from fairseq.modules import (
AdaptiveSoftmax,
LayerNorm,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
TransformerDecoderLayer,
TransformerEncoderLayer,
)
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model('tracing_transformer')
class TracingTransformerModel(FairseqEncoderDecoderModel):
"""
Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
<https://arxiv.org/abs/1706.03762>`_.
Args:
encoder (TransformerEncoder): the encoder
decoder (TransformerDecoder): the decoder
The Transformer model provides the following named architectures and
command-line arguments:
.. argparse::
:ref: fairseq.models.transformer_parser
:prog:
"""
@classmethod
def hub_models(cls):
# fmt: off
return {
'transformer.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2',
'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
'transformer.wmt18.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz',
'transformer.wmt19.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz',
'transformer.wmt19.en-ru': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz',
'transformer.wmt19.de-en': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz',
'transformer.wmt19.ru-en': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz',
'transformer.wmt19.en-de.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz',
'transformer.wmt19.en-ru.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz',
'transformer.wmt19.de-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz',
'transformer.wmt19.ru-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz',
}
# fmt: on
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
self.supports_align_args = True
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--activation-fn',
choices=utils.get_available_activation_fns(),
help='activation function to use')
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
help='dropout probability after activation in FFN.')
parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
help='path to pre-trained encoder embedding')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads')
parser.add_argument('--encoder-normalize-before', action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--encoder-learned-pos', action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
help='path to pre-trained decoder embedding')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads')
parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--share-decoder-input-output-embed', action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
help='if set, disables positional embeddings (outside self attention)')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
# args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
parser.add_argument('--no-cross-attention', default=False, action='store_true',
help='do not perform cross-attention')
parser.add_argument('--cross-self-attention', default=False, action='store_true',
help='perform cross+self-attention')
parser.add_argument('--layer-wise-attention', default=False, action='store_true',
help='perform layer-wise attention (cross-attention or cross+self-attention)')
# fmt: on
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
if not hasattr(args, 'max_target_positions'):
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
def build_embedding(dictionary, embed_dim, path=None):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
emb = Embedding(num_embeddings, embed_dim, padding_idx)
# if provided, load from preloaded dictionaries
if path:
embed_dict = utils.parse_embedding(path)
utils.load_embedding(embed_dict, dictionary, emb)
return emb
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise ValueError('--share-all-embeddings requires a joined dictionary')
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
'--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
if args.decoder_embed_path and (
args.decoder_embed_path != args.encoder_embed_path):
raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
encoder_embed_tokens = build_embedding(
src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = build_embedding(
src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
decoder_embed_tokens = build_embedding(
tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
)
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
return cls(encoder, decoder)
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TracingTransformerEncoder(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TracingTransformerDecoder(
args,
tgt_dict,
embed_tokens,
no_encoder_attn=getattr(args, 'no_cross_attention', False),
)
class TracingTransformerEncoder(FairseqEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
"""
def __init__(self, args, dictionary, embed_tokens):
super().__init__(dictionary)
self.register_buffer('version', torch.Tensor([3]))
self.dropout = args.dropout
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.max_source_positions = args.max_source_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
args.max_source_positions, embed_dim, self.padding_idx,
learned=args.encoder_learned_pos,
) if not args.no_token_positional_embeddings else None
self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerEncoderLayer(args)
for i in range(args.encoder_layers)
])
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
def forward_embedding(self, src_tokens):
# embed tokens and positions
embed = self.embed_scale * self.embed_tokens(src_tokens)
if self.embed_positions is not None:
x = embed + self.embed_positions(src_tokens)
x = F.dropout(x, p=self.dropout, training=self.training)
return x, embed
def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
if self.layer_wise_attention:
return_all_hiddens = True
x, encoder_embedding = self.forward_embedding(src_tokens)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
encoder_states = [] if return_all_hiddens else None
# encoder layers
for layer in self.layers:
x = layer(x, encoder_padding_mask)
if return_all_hiddens:
encoder_states.append(x)
if self.layer_norm:
x = self.layer_norm(x)
if return_all_hiddens:
encoder_states[-1] = x
if encoder_states is not None:
return x, encoder_padding_mask, encoder_embedding, encoder_states
else:
return x, encoder_padding_mask, encoder_embedding
def reorder_encoder_out(self, encoder_out, new_order):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
# 0: encoder_out
# 1: encoder_padding_mask
# 2: encoder_states
if encoder_out[0] is not None:
encoder_out[0] = \
encoder_out[0].index_select(1, new_order)
if encoder_out[1] is not None:
encoder_out[1] = \
encoder_out[1].index_select(0, new_order)
if len(encoder_out) == 3 and encoder_out[2] is not None:
for idx, state in enumerate(encoder_out[2]):
encoder_out[2][idx] = state.index_select(1, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
if self.embed_positions is None:
return self.max_source_positions
return min(self.max_source_positions, self.embed_positions.max_positions())
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
if self._future_mask.size(0) < dim:
self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = '{}.embed_positions.weights'.format(name)
if weights_key in state_dict:
del state_dict[weights_key]
state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
for i in range(len(self.layers)):
# update layer norms
self.layers[i].upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i))
version_key = '{}.version'.format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict
class TracingTransformerDecoder(FairseqIncrementalDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
super().__init__(dictionary)
self.register_buffer('version', torch.Tensor([3]))
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
input_embed_dim = embed_tokens.embedding_dim
embed_dim = args.decoder_embed_dim
self.output_embed_dim = args.decoder_output_dim
self.padding_idx = embed_tokens.padding_idx
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim
self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None
self.embed_positions = PositionalEmbedding(
args.max_target_positions, embed_dim, self.padding_idx,
learned=args.decoder_learned_pos,
) if not args.no_token_positional_embeddings else None
self.cross_self_attention = getattr(args, 'cross_self_attention', False)
self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
])
self.adaptive_softmax = None
self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False) \
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None
if args.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
self.output_embed_dim,
options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
dropout=args.adaptive_softmax_dropout,
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
factor=args.adaptive_softmax_factor,
tie_proj=args.tie_adaptive_proj,
)
elif not self.share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), self.output_embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)
if args.decoder_normalize_before and not getattr(args, 'no_decoder_final_norm', False):
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
def forward(
self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
features_only=False,
**extra_args,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens, encoder_out, incremental_state, **extra_args,
)
if not features_only:
x = self.output_layer(x)
return x, extra
def extract_features(
self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
full_context_alignment=False,
alignment_layer=None,
alignment_heads=None,
**unused,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
if alignment_layer is None:
alignment_layer = len(self.layers) - 1
# embed positions
positions = self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
) if self.embed_positions is not None else None
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
if not self_attn_padding_mask.any() and not self.cross_self_attention:
self_attn_padding_mask = None
# decoder layers
attn = None
inner_states = [x]
for idx, layer in enumerate(self.layers):
encoder_state = None
if encoder_out is not None:
if self.layer_wise_attention:
encoder_state = encoder_out[3][idx]
else:
encoder_state = encoder_out[0]
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
x, layer_attn = layer(
x,
encoder_state
if encoder_state is not None else None,
encoder_out[1]
if encoder_out is not None else None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=(idx == alignment_layer),
need_head_weights=(idx == alignment_layer),
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float()
if attn is not None:
if alignment_heads is not None:
attn = attn[:alignment_heads]
# average probabilities over heads
attn = attn.mean(dim=0)
if self.layer_norm:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {'attn': attn, 'inner_states': inner_states}
def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
if self.share_input_output_embed:
return F.linear(features, self.embed_tokens.weight)
else:
return F.linear(features, self.embed_out)
else:
return features
def max_positions(self):
"""Maximum output length supported by the decoder."""
if self.embed_positions is None:
return self.max_target_positions
return min(self.max_target_positions, self.embed_positions.max_positions())
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
not hasattr(self, '_future_mask')
or self._future_mask is None
or self._future_mask.device != tensor.device
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = '{}.embed_positions.weights'.format(name)
if weights_key in state_dict:
del state_dict[weights_key]
state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
for i in range(len(self.layers)):
# update layer norms
layer_norm_map = {
'0': 'self_attn_layer_norm',
'1': 'encoder_attn_layer_norm',
'2': 'final_layer_norm'
}
for old, new in layer_norm_map.items():
for m in ('weight', 'bias'):
k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m)
if k in state_dict:
state_dict['{}.layers.{}.{}.{}'.format(name, i, new, m)] = state_dict[k]
del state_dict[k]
version_key = '{}.version'.format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict