Simplify fairseq multihead attention (#888)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/888

We want to simplify multihead attention and get rid of the dynamic in_proj_weight logic. Sending the diff early for feedback, will have further changes as I try to fix breaking tests

Reviewed By: edunov

Differential Revision: D17912661

fbshipit-source-id: 0e6319fc694d8ec5187d1c2fefe5839d9d522186
This commit is contained in:
Halil Akin 2019-10-25 09:02:19 -07:00 committed by Facebook Github Bot
parent 5b086a0c17
commit fdf4c3e900

View File

@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch import nn
from torch.nn import Parameter
@ -38,12 +39,9 @@ class MultiheadAttention(nn.Module):
assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
'value to be of the same size'
if self.qkv_same_dim:
self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
else:
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
if bias:
self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
@ -70,12 +68,19 @@ class MultiheadAttention(nn.Module):
else:
self.enable_torch_version = False
@property
def in_proj_weight(self):
# TODO: Remove this backward compatibility code (in_proj_weight)
return torch.cat((self.q_proj_weight, self.k_proj_weight, self.v_proj_weight))
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def reset_parameters(self):
if self.qkv_same_dim:
nn.init.xavier_uniform_(self.in_proj_weight)
nn.init.xavier_uniform_(self.k_proj_weight, gain=1/math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj_weight, gain=1/math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj_weight, gain=1/math.sqrt(2))
else:
nn.init.xavier_uniform_(self.k_proj_weight)
nn.init.xavier_uniform_(self.v_proj_weight)
@ -126,27 +131,17 @@ class MultiheadAttention(nn.Module):
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv:
if self.qkv_same_dim:
return F.multi_head_attention_forward(query, key, value,
self.embed_dim, self.num_heads,
self.in_proj_weight,
self.in_proj_bias, self.bias_k, self.bias_v,
self.add_zero_attn, self.dropout,
self.out_proj.weight, self.out_proj.bias,
self.training, key_padding_mask, need_weights,
attn_mask)
else:
return F.multi_head_attention_forward(query, key, value,
self.embed_dim, self.num_heads,
torch.empty([0]),
self.in_proj_bias, self.bias_k, self.bias_v,
self.add_zero_attn, self.dropout,
self.out_proj.weight, self.out_proj.bias,
self.training, key_padding_mask, need_weights,
attn_mask, use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight)
return F.multi_head_attention_forward(query, key, value,
self.embed_dim, self.num_heads,
torch.empty([0]),
self.in_proj_bias, self.bias_k, self.bias_v,
self.add_zero_attn, self.dropout,
self.out_proj.weight, self.out_proj.bias,
self.training, key_padding_mask, need_weights,
attn_mask, use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
@ -160,8 +155,9 @@ class MultiheadAttention(nn.Module):
saved_state = None
if self.self_attention:
# self-attention
q, k, v = self.in_proj_qkv(query)
q = self.in_proj_q(query)
k = self.in_proj_k(query)
v = self.in_proj_v(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.in_proj_q(query)
@ -288,45 +284,25 @@ class MultiheadAttention(nn.Module):
return attn, attn_weights
def in_proj_qkv(self, query):
return self._in_proj(query).chunk(3, dim=-1)
def in_proj_q(self, query):
if self.qkv_same_dim:
return self._in_proj(query, end=self.embed_dim)
else:
bias = self.in_proj_bias
if bias is not None:
bias = bias[:self.embed_dim]
return F.linear(query, self.q_proj_weight, bias)
bias = self.in_proj_bias
if bias is not None:
bias = bias[:self.embed_dim]
return F.linear(query, self.q_proj_weight, bias)
def in_proj_k(self, key):
if self.qkv_same_dim:
return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
else:
weight = self.k_proj_weight
bias = self.in_proj_bias
if bias is not None:
bias = bias[self.embed_dim:2 * self.embed_dim]
return F.linear(key, weight, bias)
weight = self.k_proj_weight
bias = self.in_proj_bias
if bias is not None:
bias = bias[self.embed_dim:2 * self.embed_dim]
return F.linear(key, weight, bias)
def in_proj_v(self, value):
if self.qkv_same_dim:
return self._in_proj(value, start=2 * self.embed_dim)
else:
weight = self.v_proj_weight
bias = self.in_proj_bias
if bias is not None:
bias = bias[2 * self.embed_dim:]
return F.linear(value, weight, bias)
def _in_proj(self, input, start=0, end=None):
weight = self.in_proj_weight
weight = self.v_proj_weight
bias = self.in_proj_bias
weight = weight[start:end, :]
if bias is not None:
bias = bias[start:end]
return F.linear(input, weight, bias)
bias = bias[2 * self.embed_dim:]
return F.linear(value, weight, bias)
def reorder_incremental_state(self, incremental_state, new_order):
"""Reorder buffered internal state (for incremental generation)."""
@ -354,3 +330,27 @@ class MultiheadAttention(nn.Module):
def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
return attn_weights
def upgrade_state_dict_named(self, state_dict, name):
# TODO: Remove this backward compatibility code (in_proj_weight)
# here, we convert in_proj_weight to individual q,k,v weights
prefix = name + '.' if name != '' else ''
items_to_add = {}
keys_to_remove = []
for k in state_dict.keys():
if k.endswith(prefix + 'in_proj_weight'):
# in_proj_weight used to be q + k + v with same dimensions
dim = int(state_dict[k].shape[0] / 3)
items_to_add[prefix + 'q_proj_weight'] = state_dict[k][:dim]
items_to_add[prefix + 'k_proj_weight'] = state_dict[k][dim:2*dim]
items_to_add[prefix + 'v_proj_weight'] = state_dict[k][2*dim:]
keys_to_remove.append(k)
for k in keys_to_remove:
del state_dict[k]
for key, value in items_to_add.items():
state_dict[key] = value
return state_dict