Implementation of the WeCNLP abstract "Cross+Self-Attention for Transformer Models" (#1097)

Summary:
This PR implements a new attention module which combines cross-attention (encoder-decoder attention) and the decoder self-attention. This work was accepted as an abstract at WeCNLP 2019 (https://www.wecnlp.ai/wecnlp-2019).

Cross+Self-Attention reduces the amount of parameter and increases the inference speed without any degradation in translation quality.
More details can be found in the attached [abstract](https://github.com/pytorch/fairseq/files/3561282/paper.pdf)
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1097

Differential Revision: D17653168

Pulled By: myleott

fbshipit-source-id: deb834c2c78a229d7418ffbfea20ba3ce252991c
This commit is contained in:
Stephan Peitz 2019-09-29 05:08:24 -07:00 committed by Facebook Github Bot
parent ea1a410d59
commit 4ac2c5f2cc
4 changed files with 120 additions and 12 deletions

View File

@ -122,6 +122,13 @@ class TransformerModel(FairseqEncoderDecoderModel):
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
# args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
parser.add_argument('--no-cross-attention', default=False, action='store_true',
help='do not perform cross-attention')
parser.add_argument('--cross-self-attention', default=False, action='store_true',
help='perform cross+self-attention')
parser.add_argument('--layer-wise-attention', default=False, action='store_true',
help='perform layer-wise attention (cross-attention or cross+self-attention)')
# fmt: on
@classmethod
@ -180,7 +187,12 @@ class TransformerModel(FairseqEncoderDecoderModel):
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TransformerDecoder(args, tgt_dict, embed_tokens)
return TransformerDecoder(
args,
tgt_dict,
embed_tokens,
no_encoder_attn=getattr(args, 'no_cross_attention', False),
)
class TransformerEncoder(FairseqEncoder):
@ -211,6 +223,8 @@ class TransformerEncoder(FairseqEncoder):
learned=args.encoder_learned_pos,
) if not args.no_token_positional_embeddings else None
self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerEncoderLayer(args)
@ -230,13 +244,15 @@ class TransformerEncoder(FairseqEncoder):
x = F.dropout(x, p=self.dropout, training=self.training)
return x, embed
def forward(self, src_tokens, src_lengths, cls_input=None):
def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
Returns:
dict:
@ -244,7 +260,13 @@ class TransformerEncoder(FairseqEncoder):
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
if self.layer_wise_attention:
return_all_hiddens = True
x, encoder_embedding = self.forward_embedding(src_tokens)
# B x T x C -> T x B x C
@ -255,17 +277,24 @@ class TransformerEncoder(FairseqEncoder):
if not encoder_padding_mask.any():
encoder_padding_mask = None
encoder_states = [] if return_all_hiddens else None
# encoder layers
for layer in self.layers:
x = layer(x, encoder_padding_mask)
if return_all_hiddens:
encoder_states.append(x)
if self.layer_norm:
x = self.layer_norm(x)
if return_all_hiddens:
encoder_states[-1] = x
return {
'encoder_out': x, # T x B x C
'encoder_padding_mask': encoder_padding_mask, # B x T
'encoder_embedding': encoder_embedding, # B x T x C
'encoder_states': encoder_states, # List[T x B x C]
}
def reorder_encoder_out(self, encoder_out, new_order):
@ -285,6 +314,9 @@ class TransformerEncoder(FairseqEncoder):
if encoder_out['encoder_padding_mask'] is not None:
encoder_out['encoder_padding_mask'] = \
encoder_out['encoder_padding_mask'].index_select(0, new_order)
if encoder_out.get('encoder_states', None) is not None:
for idx, state in enumerate(encoder_out['encoder_states']):
encoder_out['encoder_states'][idx] = state.index_select(1, new_order)
return encoder_out
def max_positions(self):
@ -293,6 +325,14 @@ class TransformerEncoder(FairseqEncoder):
return self.max_source_positions
return min(self.max_source_positions, self.embed_positions.max_positions())
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
if self._future_mask.size(0) < dim:
self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
@ -350,6 +390,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
learned=args.decoder_learned_pos,
) if not args.no_token_positional_embeddings else None
self.cross_self_attention = getattr(args, 'cross_self_attention', False)
self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerDecoderLayer(args, no_encoder_attn)
@ -435,14 +478,26 @@ class TransformerDecoder(FairseqIncrementalDecoder):
inner_states = [x]
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
if not self_attn_padding_mask.any() and not self.cross_self_attention:
self_attn_padding_mask = None
# decoder layers
for layer in self.layers:
for idx, layer in enumerate(self.layers):
encoder_state = None
if encoder_out is not None:
if self.layer_wise_attention:
encoder_state = encoder_out['encoder_states'][idx]
else:
encoder_state = encoder_out['encoder_out']
x, attn = layer(
x,
encoder_out['encoder_out'] if encoder_out is not None else None,
encoder_state,
encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
incremental_state,
self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
self_attn_padding_mask=self_attn_padding_mask,
)
inner_states.append(x)
@ -553,6 +608,9 @@ def base_architecture(args):
args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
args.adaptive_input = getattr(args, 'adaptive_input', False)
args.no_cross_attention = getattr(args, 'no_cross_attention', False)
args.cross_self_attention = getattr(args, 'cross_self_attention', False)
args.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)

View File

@ -186,8 +186,15 @@ class MultiheadAttention(nn.Module):
v = prev_value
else:
v = torch.cat((prev_value, v), dim=1)
if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
prev_key_padding_mask = saved_state['prev_key_padding_mask']
if static_kv:
key_padding_mask = prev_key_padding_mask
else:
key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state['prev_key_padding_mask'] = key_padding_mask
self._set_input_buffer(incremental_state, saved_state)
@ -311,7 +318,8 @@ class MultiheadAttention(nn.Module):
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
for k in input_buffer.keys():
input_buffer[k] = input_buffer[k].index_select(0, new_order)
if input_buffer[k] is not None:
input_buffer[k] = input_buffer[k].index_select(0, new_order)
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state):

View File

@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
@ -134,13 +135,14 @@ class TransformerDecoderLayer(nn.Module):
def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.cross_self_attention = getattr(args, 'cross_self_attention', False)
self.self_attn = MultiheadAttention(
embed_dim=self.embed_dim,
num_heads=args.decoder_attention_heads,
dropout=args.attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=True
self_attention=not self.cross_self_attention,
)
self.dropout = args.dropout
self.activation_fn = utils.get_activation_fn(
@ -208,13 +210,27 @@ class TransformerDecoderLayer(nn.Module):
if prev_self_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_self_attn_state
prev_key, prev_value = prev_self_attn_state[:2]
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
if len(prev_self_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
self.self_attn._set_input_buffer(incremental_state, saved_state)
if self.cross_self_attention and not (incremental_state is not None and "prev_key" in self.self_attn._get_input_buffer(incremental_state)):
if self_attn_mask is not None:
self_attn_mask = torch.cat((x.new(x.size(0), encoder_out.size(0)).zero_(), self_attn_mask), dim=1)
if self_attn_padding_mask is not None:
if encoder_padding_mask is None:
encoder_padding_mask = self_attn_padding_mask.new(encoder_out.size(1), encoder_out.size(0)).zero_()
self_attn_padding_mask = torch.cat((encoder_padding_mask, self_attn_padding_mask), dim=1)
y = torch.cat((encoder_out, x), dim=0)
else:
y = x
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
@ -230,9 +246,12 @@ class TransformerDecoderLayer(nn.Module):
if prev_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_attn_state
prev_key, prev_value = prev_attn_state[:2]
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
@ -256,7 +275,10 @@ class TransformerDecoderLayer(nn.Module):
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
if self.onnx_trace and incremental_state is not None:
saved_state = self.self_attn._get_input_buffer(incremental_state)
self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
if self_attn_padding_mask is not None:
self_attn_state = saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"]
else:
self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
return x, attn, self_attn_state
return x, attn

View File

@ -154,6 +154,23 @@ class TestTranslation(unittest.TestCase):
], run_validation=True)
generate_main(data_dir)
def test_transformer_cross_self_attention(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_transformer_cross_self_attention') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'transformer_iwslt_de_en', [
'--encoder-layers', '2',
'--decoder-layers', '2',
'--encoder-embed-dim', '8',
'--decoder-embed-dim', '8',
'--decoder-embed-dim', '8',
'--no-cross-attention',
'--cross-self-attention',
'--layer-wise-attention',
], run_validation=True)
generate_main(data_dir, extra_flags=[])
def test_lightconv(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_lightconv') as data_dir:
@ -543,6 +560,10 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation'
def generate_main(data_dir, extra_flags=None):
if extra_flags is None:
extra_flags = [
'--print-alignment',
]
generate_parser = options.get_generation_parser()
generate_args = options.parse_args_and_arch(
generate_parser,
@ -554,7 +575,6 @@ def generate_main(data_dir, extra_flags=None):
'--max-len-b', '5',
'--gen-subset', 'valid',
'--no-progress-bar',
'--print-alignment',
] + (extra_flags or []),
)