diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py
index 0ff05d16d..cb3ae95d5 100644
--- a/fairseq/modules/multihead_attention.py
+++ b/fairseq/modules/multihead_attention.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 import torch
 from torch import nn
 from torch.nn import Parameter
@@ -38,12 +39,9 @@ class MultiheadAttention(nn.Module):
         assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
                                                              'value to be of the same size'
 
-        if self.qkv_same_dim:
-            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
-        else:
-            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
-            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
-            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+        self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+        self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+        self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
 
         if bias:
             self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
@@ -70,12 +68,19 @@ class MultiheadAttention(nn.Module):
         else:
             self.enable_torch_version = False
 
+    @property
+    def in_proj_weight(self):
+        # TODO: Remove this backward compatibility code (in_proj_weight)
+        return torch.cat((self.q_proj_weight, self.k_proj_weight, self.v_proj_weight))
+
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True
 
     def reset_parameters(self):
         if self.qkv_same_dim:
-            nn.init.xavier_uniform_(self.in_proj_weight)
+            nn.init.xavier_uniform_(self.k_proj_weight, gain=1/math.sqrt(2))
+            nn.init.xavier_uniform_(self.v_proj_weight, gain=1/math.sqrt(2))
+            nn.init.xavier_uniform_(self.q_proj_weight, gain=1/math.sqrt(2))
         else:
             nn.init.xavier_uniform_(self.k_proj_weight)
             nn.init.xavier_uniform_(self.v_proj_weight)
@@ -126,27 +131,17 @@ class MultiheadAttention(nn.Module):
         assert list(query.size()) == [tgt_len, bsz, embed_dim]
 
         if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv:
-            if self.qkv_same_dim:
-                return F.multi_head_attention_forward(query, key, value,
-                                                      self.embed_dim, self.num_heads,
-                                                      self.in_proj_weight,
-                                                      self.in_proj_bias, self.bias_k, self.bias_v,
-                                                      self.add_zero_attn, self.dropout,
-                                                      self.out_proj.weight, self.out_proj.bias,
-                                                      self.training, key_padding_mask, need_weights,
-                                                      attn_mask)
-            else:
-                return F.multi_head_attention_forward(query, key, value,
-                                                      self.embed_dim, self.num_heads,
-                                                      torch.empty([0]),
-                                                      self.in_proj_bias, self.bias_k, self.bias_v,
-                                                      self.add_zero_attn, self.dropout,
-                                                      self.out_proj.weight, self.out_proj.bias,
-                                                      self.training, key_padding_mask, need_weights,
-                                                      attn_mask, use_separate_proj_weight=True,
-                                                      q_proj_weight=self.q_proj_weight,
-                                                      k_proj_weight=self.k_proj_weight,
-                                                      v_proj_weight=self.v_proj_weight)
+            return F.multi_head_attention_forward(query, key, value,
+                                                  self.embed_dim, self.num_heads,
+                                                  torch.empty([0]),
+                                                  self.in_proj_bias, self.bias_k, self.bias_v,
+                                                  self.add_zero_attn, self.dropout,
+                                                  self.out_proj.weight, self.out_proj.bias,
+                                                  self.training, key_padding_mask, need_weights,
+                                                  attn_mask, use_separate_proj_weight=True,
+                                                  q_proj_weight=self.q_proj_weight,
+                                                  k_proj_weight=self.k_proj_weight,
+                                                  v_proj_weight=self.v_proj_weight)
 
         if incremental_state is not None:
             saved_state = self._get_input_buffer(incremental_state)
@@ -160,8 +155,9 @@ class MultiheadAttention(nn.Module):
             saved_state = None
 
         if self.self_attention:
-            # self-attention
-            q, k, v = self.in_proj_qkv(query)
+            q = self.in_proj_q(query)
+            k = self.in_proj_k(query)
+            v = self.in_proj_v(query)
         elif self.encoder_decoder_attention:
             # encoder-decoder attention
             q = self.in_proj_q(query)
@@ -288,45 +284,25 @@ class MultiheadAttention(nn.Module):
 
         return attn, attn_weights
 
-    def in_proj_qkv(self, query):
-        return self._in_proj(query).chunk(3, dim=-1)
-
     def in_proj_q(self, query):
-        if self.qkv_same_dim:
-            return self._in_proj(query, end=self.embed_dim)
-        else:
-            bias = self.in_proj_bias
-            if bias is not None:
-                bias = bias[:self.embed_dim]
-            return F.linear(query, self.q_proj_weight, bias)
+        bias = self.in_proj_bias
+        if bias is not None:
+            bias = bias[:self.embed_dim]
+        return F.linear(query, self.q_proj_weight, bias)
 
     def in_proj_k(self, key):
-        if self.qkv_same_dim:
-            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
-        else:
-            weight = self.k_proj_weight
-            bias = self.in_proj_bias
-            if bias is not None:
-                bias = bias[self.embed_dim:2 * self.embed_dim]
-            return F.linear(key, weight, bias)
+        weight = self.k_proj_weight
+        bias = self.in_proj_bias
+        if bias is not None:
+            bias = bias[self.embed_dim:2 * self.embed_dim]
+        return F.linear(key, weight, bias)
 
     def in_proj_v(self, value):
-        if self.qkv_same_dim:
-            return self._in_proj(value, start=2 * self.embed_dim)
-        else:
-            weight = self.v_proj_weight
-            bias = self.in_proj_bias
-            if bias is not None:
-                bias = bias[2 * self.embed_dim:]
-            return F.linear(value, weight, bias)
-
-    def _in_proj(self, input, start=0, end=None):
-        weight = self.in_proj_weight
+        weight = self.v_proj_weight
         bias = self.in_proj_bias
-        weight = weight[start:end, :]
         if bias is not None:
-            bias = bias[start:end]
-        return F.linear(input, weight, bias)
+            bias = bias[2 * self.embed_dim:]
+        return F.linear(value, weight, bias)
 
     def reorder_incremental_state(self, incremental_state, new_order):
         """Reorder buffered internal state (for incremental generation)."""
@@ -354,3 +330,27 @@ class MultiheadAttention(nn.Module):
 
     def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
         return attn_weights
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        # TODO: Remove this backward compatibility code (in_proj_weight)
+        # here, we convert in_proj_weight to individual q,k,v weights
+        prefix = name + '.' if name != '' else ''
+        items_to_add = {}
+        keys_to_remove = []
+        for k in state_dict.keys():
+            if k.endswith(prefix + 'in_proj_weight'):
+                # in_proj_weight used to be q + k + v with same dimensions
+                dim = int(state_dict[k].shape[0] / 3)
+                items_to_add[prefix + 'q_proj_weight'] = state_dict[k][:dim]
+                items_to_add[prefix + 'k_proj_weight'] = state_dict[k][dim:2*dim]
+                items_to_add[prefix + 'v_proj_weight'] = state_dict[k][2*dim:]
+
+                keys_to_remove.append(k)
+
+        for k in keys_to_remove:
+            del state_dict[k]
+
+        for key, value in items_to_add.items():
+            state_dict[key] = value
+
+        return state_dict
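
Note on the reset_parameters change: nn.init.xavier_uniform_ samples from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)). The old combined in_proj_weight had shape (3 * embed_dim, embed_dim), so fan_in + fan_out = 4 * embed_dim; each separate projection weight is (embed_dim, embed_dim), so the sum is 2 * embed_dim. A gain of 1/sqrt(2) makes the per-projection bound equal to the old combined bound, so the initialization scale is preserved for qkv_same_dim models. A minimal sanity check of that equivalence (the embed_dim value below is arbitrary, for illustration only):

import math

# Xavier-uniform bound: a = gain * sqrt(6 / (fan_in + fan_out)), weight shape (out, in).
embed_dim = 512  # arbitrary toy size
old_bound = math.sqrt(6.0 / (embed_dim + 3 * embed_dim))           # combined (3E, E) weight, gain=1
new_bound = (1 / math.sqrt(2)) * math.sqrt(6.0 / (2 * embed_dim))  # separate (E, E) weight, gain=1/sqrt(2)
assert math.isclose(old_bound, new_bound)  # same bound, so the init distribution is unchanged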
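
The new upgrade_state_dict_named hook rewrites old checkpoints on load: a stored in_proj_weight of shape (3 * embed_dim, embed_dim) is split row-wise into q, k, v (q rows first, then k, then v), matching the slicing the removed _in_proj helper used. A rough sketch of the same conversion applied to a toy state dict (the 'attn.' key prefix and tensor size are made up for illustration):

import torch

embed_dim = 4  # toy size, for illustration only
state = {'attn.in_proj_weight': torch.randn(3 * embed_dim, embed_dim)}

# Same row-wise split as upgrade_state_dict_named:
# rows [0, E) -> q, [E, 2E) -> k, [2E, 3E) -> v.
w = state.pop('attn.in_proj_weight')
dim = w.shape[0] // 3
state['attn.q_proj_weight'] = w[:dim]
state['attn.k_proj_weight'] = w[dim:2 * dim]
state['attn.v_proj_weight'] = w[2 * dim:]

# The backward-compat in_proj_weight property concatenates q, k, v again,
# so the combined matrix is recoverable exactly.
assert torch.equal(
    torch.cat((state['attn.q_proj_weight'],
               state['attn.k_proj_weight'],
               state['attn.v_proj_weight'])),
    w,
)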