Simplify and generalize utils.make_positions

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/625

Differential Revision: D14822123

Pulled By: myleott

fbshipit-source-id: 8a263d30020588577ee02fb8c6959ff918705103
Authored by Myle Ott on 2019-04-15 07:29:14 -07:00; committed by Facebook Github Bot
parent a47630e127
commit e12e1d254c
9 changed files with 49 additions and 87 deletions


@@ -179,17 +179,14 @@ class FConvEncoder(FairseqEncoder):
             connections are added between layers when ``residual=1`` (which is
             the default behavior).
         dropout (float, optional): dropout to be applied before each conv layer
-        left_pad (bool, optional): whether the input is left-padded
-            (default: True).
     """

     def __init__(
-        self, dictionary, embed_dim=512, embed_dict=None, max_positions=1024,
-        convolutions=((512, 3),) * 20, dropout=0.1, left_pad=True,
+        self, dictionary, embed_dim=512, embed_dict=None, max_positions=1024,
+        convolutions=((512, 3),) * 20, dropout=0.1,
     ):
         super().__init__(dictionary)
         self.dropout = dropout
-        self.left_pad = left_pad
         self.num_attention_layers = None

         num_embeddings = len(dictionary)
@@ -202,7 +199,6 @@ class FConvEncoder(FairseqEncoder):
             max_positions,
             embed_dim,
             self.padding_idx,
-            left_pad=self.left_pad,
         )

         convolutions = extend_conv_spec(convolutions)
@@ -387,16 +383,14 @@ class FConvDecoder(FairseqIncrementalDecoder):
     """Convolutional decoder"""

     def __init__(
-        self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256,
-        max_positions=1024, convolutions=((512, 3),) * 20, attention=True,
-        dropout=0.1, share_embed=False, positional_embeddings=True,
-        adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0,
-        left_pad=False,
+        self, dictionary, embed_dim=512, embed_dict=None, out_embed_dim=256,
+        max_positions=1024, convolutions=((512, 3),) * 20, attention=True,
+        dropout=0.1, share_embed=False, positional_embeddings=True,
+        adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0,
     ):
         super().__init__(dictionary)
         self.register_buffer('version', torch.Tensor([2]))
         self.dropout = dropout
-        self.left_pad = left_pad
         self.need_attn = True

         convolutions = extend_conv_spec(convolutions)
@@ -418,7 +412,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
             max_positions,
             embed_dim,
             padding_idx,
-            left_pad=self.left_pad,
         ) if positional_embeddings else None

         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
@@ -616,8 +609,8 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     return m


-def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad):
-    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
+def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx):
+    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
     nn.init.normal_(m.weight, 0, 0.1)
     nn.init.constant_(m.weight[padding_idx], 0)
     return m


@@ -140,12 +140,11 @@ class FConvEncoder(FairseqEncoder):
     def __init__(
         self, dictionary, embed_dim=512, max_positions=1024,
         convolutions=((512, 3),) * 20, dropout=0.1, attention=False,
-        attention_nheads=1, left_pad=True,
+        attention_nheads=1,
     ):
         super().__init__(dictionary)
         self.dropout = dropout
         self.num_attention_layers = None
-        self.left_pad = left_pad

         num_embeddings = len(dictionary)
         self.padding_idx = dictionary.pad()
@@ -154,7 +153,6 @@ class FConvEncoder(FairseqEncoder):
             max_positions,
             embed_dim,
             self.padding_idx,
-            left_pad=self.left_pad,
         )

         def expand_bool_array(val):
@@ -269,14 +267,13 @@ class FConvDecoder(FairseqDecoder):
         convolutions=((512, 3),) * 8, attention=True, dropout=0.1,
         selfattention=False, attention_nheads=1, selfattention_nheads=1,
         project_input=False, gated_attention=False, downsample=False,
-        pretrained=False, trained_decoder=None, left_pad=False,
+        pretrained=False, trained_decoder=None,
     ):
         super().__init__(dictionary)
         self.register_buffer('version', torch.Tensor([2]))
         self.pretrained = pretrained
         self.pretrained_decoder = trained_decoder
         self.dropout = dropout
-        self.left_pad = left_pad
         self.need_attn = True

         in_channels = convolutions[0][0]
@@ -301,7 +298,6 @@ class FConvDecoder(FairseqDecoder):
             max_positions,
             embed_dim,
             padding_idx,
-            left_pad=self.left_pad,
         )

         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
@@ -487,8 +483,8 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     return m


-def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad):
-    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
+def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx):
+    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
     m.weight.data.normal_(0, 0.1)
     return m


@@ -291,11 +291,9 @@ class LightConvEncoder(FairseqEncoder):
         args (argparse.Namespace): parsed command-line arguments
         dictionary (~fairseq.data.Dictionary): encoding dictionary
         embed_tokens (torch.nn.Embedding): input embedding
-        left_pad (bool, optional): whether the input is left-padded. Default:
-            ``True``
     """

-    def __init__(self, args, dictionary, embed_tokens, left_pad=True):
+    def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
         self.dropout = args.dropout
@@ -307,7 +305,6 @@ class LightConvEncoder(FairseqEncoder):
         self.embed_scale = math.sqrt(embed_dim)
         self.embed_positions = PositionalEmbedding(
             args.max_source_positions, embed_dim, self.padding_idx,
-            left_pad=left_pad,
             learned=args.encoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None
@@ -399,11 +396,9 @@ class LightConvDecoder(FairseqIncrementalDecoder):
         embed_tokens (torch.nn.Embedding): output embedding
         no_encoder_attn (bool, optional): whether to attend to encoder outputs.
             Default: ``False``
-        left_pad (bool, optional): whether the input is left-padded. Default:
-            ``False``
     """

-    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True):
+    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True):
         super().__init__(dictionary)
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed
@@ -422,7 +417,6 @@ class LightConvDecoder(FairseqIncrementalDecoder):
         self.embed_positions = PositionalEmbedding(
             args.max_target_positions, embed_dim, padding_idx,
-            left_pad=left_pad,
             learned=args.decoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None
@@ -778,13 +772,17 @@ def Linear(in_features, out_features, bias=True):
     return m


-def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
+def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, learned=False):
     if learned:
-        m = LearnedPositionalEmbedding(num_embeddings + padding_idx + 1, embedding_dim, padding_idx, left_pad)
+        m = LearnedPositionalEmbedding(
+            num_embeddings + padding_idx + 1, embedding_dim, padding_idx,
+        )
         nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
         nn.init.constant_(m.weight[padding_idx], 0)
     else:
-        m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings + padding_idx + 1)
+        m = SinusoidalPositionalEmbedding(
+            embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1,
+        )
     return m


@@ -263,11 +263,9 @@ class TransformerEncoder(FairseqEncoder):
         args (argparse.Namespace): parsed command-line arguments
         dictionary (~fairseq.data.Dictionary): encoding dictionary
         embed_tokens (torch.nn.Embedding): input embedding
-        left_pad (bool, optional): whether the input is left-padded
-            (default: True).
     """

-    def __init__(self, args, dictionary, embed_tokens, left_pad=True):
+    def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
         self.dropout = args.dropout
@@ -279,7 +277,6 @@ class TransformerEncoder(FairseqEncoder):
         self.embed_scale = math.sqrt(embed_dim)
         self.embed_positions = PositionalEmbedding(
             args.max_source_positions, embed_dim, self.padding_idx,
-            left_pad=left_pad,
             learned=args.encoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None
@@ -390,13 +387,11 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         embed_tokens (torch.nn.Embedding): output embedding
         no_encoder_attn (bool, optional): whether to attend to encoder outputs
             (default: False).
-        left_pad (bool, optional): whether the input is left-padded
-            (default: False).
         final_norm (bool, optional): apply layer norm to the output of the
             final decoder layer (default: True).
     """

-    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True):
+    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True):
         super().__init__(dictionary)
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed
@@ -415,7 +410,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         self.embed_positions = PositionalEmbedding(
             args.max_target_positions, embed_dim, padding_idx,
-            left_pad=left_pad,
             learned=args.decoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None
@@ -796,13 +790,17 @@ def Linear(in_features, out_features, bias=True):
     return m


-def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
+def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, learned=False):
     if learned:
-        m = LearnedPositionalEmbedding(num_embeddings + padding_idx + 1, embedding_dim, padding_idx, left_pad)
+        m = LearnedPositionalEmbedding(
+            num_embeddings + padding_idx + 1, embedding_dim, padding_idx,
+        )
         nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
         nn.init.constant_(m.weight[padding_idx], 0)
     else:
-        m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings + padding_idx + 1)
+        m = SinusoidalPositionalEmbedding(
+            embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1,
+        )
     return m

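Both the LightConv and Transformer model files now build their positional embeddings through this two-branch helper with no padding-side flag. A minimal usage sketch, assuming the updated PositionalEmbedding helper above is in scope (the argument values below are illustrative placeholders, not taken from this diff):

# Sketch only: calling the updated helper with placeholder sizes.
embed_dim, padding_idx, max_positions = 512, 1, 1024

# Learned branch: returns an nn.Embedding with normally initialized weights.
learned_pos = PositionalEmbedding(max_positions, embed_dim, padding_idx, learned=True)

# Sinusoidal branch: returns a module that computes fixed sin/cos features.
sinusoidal_pos = PositionalEmbedding(max_positions, embed_dim, padding_idx, learned=False)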

@@ -13,13 +13,11 @@ from fairseq import utils
 class LearnedPositionalEmbedding(nn.Embedding):
     """This module learns positional embeddings up to a fixed maximum size.

-    Padding symbols are ignored, but it is necessary to specify whether padding
-    is added on the left side (left_pad=True) or right side (left_pad=False).
+    Padding symbols are ignored.
     """

-    def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
+    def __init__(self, num_embeddings, embedding_dim, padding_idx):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
-        self.left_pad = left_pad
         self.onnx_trace = False

     def forward(self, input, incremental_state=None):
@@ -28,7 +26,9 @@ class LearnedPositionalEmbedding(nn.Embedding):
             # positions is the same for every token when decoding a single step
             positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
         else:
-            positions = utils.make_positions(input.data, self.padding_idx, self.left_pad, self.onnx_trace)
+            positions = utils.make_positions(
+                input.data, self.padding_idx, onnx_trace=self.onnx_trace,
+            )
         return super().forward(positions)

     def max_positions(self):

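With left_pad gone, callers only pass the table size, embedding dimension, and padding index; the padding side is inferred inside forward via utils.make_positions. A minimal usage sketch (the import path and sizes are assumptions for illustration, not part of this diff):

import torch

from fairseq.modules import LearnedPositionalEmbedding  # assumed export path

pad = 1  # padding index used by the dictionary
max_positions = 128

# The table must cover positions padding_idx+1 .. padding_idx+max_positions.
pos_emb = LearnedPositionalEmbedding(max_positions + pad + 1, 16, padding_idx=pad)

# One left-padded and one right-padded sentence with the same tokens.
tokens = torch.tensor([
    [pad, pad, 4, 5, 6],
    [4, 5, 6, pad, pad],
])
out = pos_emb(tokens)  # shape (2, 5, 16); no left_pad flag needed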

@@ -17,15 +17,13 @@ from fairseq import utils
 class SinusoidalPositionalEmbedding(nn.Module):
     """This module produces sinusoidal positional embeddings of any length.

-    Padding symbols are ignored, but it is necessary to specify whether padding
-    is added on the left side (left_pad=True) or right side (left_pad=False).
+    Padding symbols are ignored.
     """

-    def __init__(self, embedding_dim, padding_idx, left_pad, init_size=1024):
+    def __init__(self, embedding_dim, padding_idx, init_size=1024):
         super().__init__()
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
-        self.left_pad = left_pad
         self.weights = SinusoidalPositionalEmbedding.get_embedding(
             init_size,
             embedding_dim,
@@ -76,7 +74,7 @@ class SinusoidalPositionalEmbedding(nn.Module):
                 return self.weights[self.padding_idx + pos, :].unsqueeze(1).repeat(bsz, 1, 1)
             return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

-        positions = utils.make_positions(input, self.padding_idx, self.left_pad, self.onnx_trace)
+        positions = utils.make_positions(input, self.padding_idx, onnx_trace=self.onnx_trace)
         if self.onnx_trace:
             flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
             embedding_shape = torch.cat((bsz.view(1), seq_len.view(1), torch.LongTensor([-1])))

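The sinusoidal module follows the same pattern: the constructor drops left_pad and forward derives positions from the pad mask. A small sketch of the new call (import path and sizes are assumptions for illustration):

import torch

from fairseq.modules import SinusoidalPositionalEmbedding  # assumed export path

pad = 1
pos_emb = SinusoidalPositionalEmbedding(embedding_dim=16, padding_idx=pad, init_size=128)

tokens = torch.tensor([[pad, pad, 4, 5, 6]])  # left-padded; no flag required
out = pos_emb(tokens)  # shape (1, 5, 16); pad slots pick up the padding_idx row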

@@ -40,13 +40,13 @@ def init_bert_params(module):
 def PositionalEmbedding(
-    num_embeddings: int,
-    embedding_dim: int,
-    padding_idx: int,
-    left_pad: bool
+    num_embeddings: int,
+    embedding_dim: int,
+    padding_idx: int,
 ) -> nn.Embedding:
     m = LearnedPositionalEmbedding(
-        num_embeddings + padding_idx + 1, embedding_dim, padding_idx, left_pad)
+        num_embeddings + padding_idx + 1, embedding_dim, padding_idx,
+    )
     nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
     nn.init.constant_(m.weight[padding_idx], 0)
     return m
@@ -121,7 +121,6 @@ class TransformerSentenceEncoder(nn.Module):
                 self.max_seq_len,
                 self.embedding_dim,
                 self.padding_idx,
-                left_pad=False,
             )
             if self.use_position_embeddings
             else None


@@ -312,33 +312,13 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dic
     return hypo_tokens, hypo_str, alignment


-def make_positions(tensor, padding_idx, left_pad, onnx_trace=False):
+def make_positions(tensor, padding_idx, onnx_trace=False):
     """Replace non-padding symbols with their position numbers.

-    Position numbers begin at padding_idx+1.
-    Padding symbols are ignored, but it is necessary to specify whether padding
-    is added on the left side (left_pad=True) or right side (left_pad=False).
+    Position numbers begin at padding_idx+1. Padding symbols are ignored.
     """
-    if onnx_trace:
-        range_buf = torch._dim_arange(like=tensor, dim=1) + padding_idx + 1
-        mask = tensor.ne(padding_idx)
-        positions = range_buf.expand_as(tensor)
-        if left_pad:
-            positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
-        return positions * mask.long() + padding_idx * (1 - mask.long())
-
-    max_pos = padding_idx + 1 + tensor.size(1)
-    if not hasattr(make_positions, 'range_buf'):
-        make_positions.range_buf = tensor.new()
-    make_positions.range_buf = make_positions.range_buf.type_as(tensor)
-    if make_positions.range_buf.numel() < max_pos:
-        torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
-    mask = tensor.ne(padding_idx)
-    positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
-    if left_pad:
-        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
-    return tensor.clone().masked_scatter_(mask, positions[mask])
+    mask = tensor.ne(padding_idx).long()
+    return torch.cumsum(mask, dim=1) * mask + padding_idx


 def strip_pad(tensor, pad):

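The new implementation replaces the cached range buffer and the left/right branching with a single cumulative sum over the non-padding mask, so the padding side never has to be passed in. A self-contained check of the behavior (the example tensors here are chosen for illustration; they are not the ones in the test below):

import torch

def make_positions(tensor, padding_idx, onnx_trace=False):
    # Same two lines as the simplified fairseq.utils.make_positions above.
    mask = tensor.ne(padding_idx).long()
    return torch.cumsum(mask, dim=1) * mask + padding_idx

pad = 1
left_padded = torch.tensor([[pad, pad, 4, 5, 6]])
right_padded = torch.tensor([[4, 5, 6, pad, pad]])

print(make_positions(left_padded, pad))   # tensor([[1, 1, 2, 3, 4]])
print(make_positions(right_padded, pad))  # tensor([[2, 3, 4, 1, 1]])
# Non-pad tokens count up from padding_idx + 1 in both cases; pad slots stay at padding_idx.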

@@ -69,11 +69,11 @@ class TestUtils(unittest.TestCase):
         self.assertAlmostEqual(
             left_pad_output,
-            utils.make_positions(left_pad_input, pad, left_pad=True),
+            utils.make_positions(left_pad_input, pad),
         )
         self.assertAlmostEqual(
             right_pad_output,
-            utils.make_positions(right_pad_input, pad, left_pad=False),
+            utils.make_positions(right_pad_input, pad),
         )

     def assertAlmostEqual(self, t1, t2):