Script MultiheadAttention (#1002)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1002

Pull Request resolved: https://github.com/pytorch/translate/pull/681

Pull Request resolved: https://github.com/pytorch/fairseq/pull/1524

Make fairseq MultiheadAttention scriptable. Looking for feedbacks.

1. Add types
2. Move incremental state management logic from util functions to initializers. TorchScript in general doesn't support global dict. As a result modules with multihead attention in it would assign itself fairseq_instance_id in the initializer.
3. There might be opportunities to make assertions and annotations cleaner.

Reviewed By: myleott

Differential Revision: D18772594

fbshipit-source-id: 377aef4bbb7ef51da5b6bac9a87a6f7b03b16fe1
This commit is contained in:
Ning Dong 2020-01-21 18:24:00 -08:00 committed by Facebook Github Bot
parent 2535cab24f
commit 4e48c4ae5d
14 changed files with 177 additions and 80 deletions

View File

@ -0,0 +1,29 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq import utils
class FairseqIncrementalState(object):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
init_incremental_state(self)
def with_incremental_state(cls):
cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
return cls
# In most cases we should register incremental states using @with_incremental_state decorator
# instead of calling into this explicitly in initializer.
def init_incremental_state(obj):
obj.module_name = obj.__class__.__name__
utils.INCREMENTAL_STATE_INSTANCE_ID[obj.module_name] = (
utils.INCREMENTAL_STATE_INSTANCE_ID.get(obj.module_name, 0) + 1
)
obj._fairseq_instance_id = utils.INCREMENTAL_STATE_INSTANCE_ID[
obj.module_name
]

View File

@ -223,7 +223,7 @@ class IterativeRefinementGenerator(object):
finalized_tokens = decoder_out.output_tokens[terminated]
finalized_scores = decoder_out.output_scores[terminated]
finalized_attn = (
None if decoder_out.attn is None else decoder_out.attn[terminated]
None if (decoder_out.attn is None or decoder_out.attn.size(0) == 0) else decoder_out.attn[terminated]
)
if self.retain_history:
@ -259,8 +259,12 @@ class IterativeRefinementGenerator(object):
prev_decoder_out = decoder_out._replace(
output_tokens=decoder_out.output_tokens[not_terminated],
output_scores=decoder_out.output_scores[not_terminated],
attn=decoder_out.attn[not_terminated] if decoder_out.attn is not None else None,
history=[h[not_terminated] for h in decoder_out.history] if decoder_out.history is not None else None
attn=decoder_out.attn[not_terminated]
if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0)
else None,
history=[h[not_terminated] for h in decoder_out.history]
if decoder_out.history is not None
else None,
)
encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze())
sent_idxs = sent_idxs[not_terminated]

View File

@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
from fairseq import utils
@ -29,7 +28,9 @@ class FairseqDecoder(nn.Module):
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
x, extra = self.extract_features(
prev_output_tokens, encoder_out=encoder_out, **kwargs
)
x = self.output_layer(x)
return x, extra
@ -54,10 +55,10 @@ class FairseqDecoder(nn.Module):
def get_normalized_probs(self, net_output, log_probs, sample):
"""Get normalized probabilities (or log probs) from a net's output."""
if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
if sample is not None:
assert 'target' in sample
target = sample['target']
assert "target" in sample
target = sample["target"]
else:
target = None
out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)

View File

@ -4,8 +4,10 @@
# LICENSE file in the root directory of this source tree.
from fairseq.models import FairseqDecoder
from fairseq.incremental_decoding_utils import with_incremental_state
@with_incremental_state
class FairseqIncrementalDecoder(FairseqDecoder):
"""Base class for incremental decoders.

View File

@ -27,7 +27,7 @@ from fairseq.modules import (
LearnedPositionalEmbedding,
LinearizedConvolution,
)
from fairseq.incremental_decoding_utils import with_incremental_state
logger = logging.getLogger(__name__)
@ -291,6 +291,7 @@ class FConvEncoder(FairseqEncoder):
return self.embed_positions.max_positions()
@with_incremental_state
class FConvDecoder(FairseqDecoder):
"""Convolutional decoder"""
def __init__(

View File

@ -9,6 +9,7 @@ import torch.nn.functional as F
from fairseq import utils
from .unfold import unfold1d
from fairseq.incremental_decoding_utils import with_incremental_state
def DynamicConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
@ -38,6 +39,7 @@ def Linear(in_features, out_features, bias=True):
return m
@with_incremental_state
class DynamicConv1dTBC(nn.Module):
'''Dynamic lightweight convolution taking T x B x C inputs
Args:

View File

@ -9,6 +9,7 @@ import torch.nn.functional as F
from fairseq import utils
from fairseq.modules.unfold import unfold1d
from fairseq.incremental_decoding_utils import with_incremental_state
def LightweightConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
@ -99,6 +100,7 @@ class LightweightConv1d(nn.Module):
return output
@with_incremental_state
class LightweightConv1dTBC(nn.Module):
'''Lightweight Convolution assuming the input is TxBxC
Args:
@ -136,7 +138,6 @@ class LightweightConv1dTBC(nn.Module):
self.bias = None
self.reset_parameters()
self.onnx_trace = False
def reset_parameters(self):

View File

@ -7,10 +7,11 @@ import torch
import torch.nn.functional as F
from fairseq import utils
from .conv_tbc import ConvTBC
from fairseq.incremental_decoding_utils import with_incremental_state
@with_incremental_state
class LinearizedConvolution(ConvTBC):
"""An optimized version of nn.Conv1d.

View File

@ -4,14 +4,17 @@
# LICENSE file in the root directory of this source tree.
import math
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from fairseq import utils
from torch import nn
from torch import Tensor, nn
from torch.nn import Parameter
from fairseq.incremental_decoding_utils import with_incremental_state
@with_incremental_state
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
@ -102,16 +105,16 @@ class MultiheadAttention(nn.Module):
def forward(
self,
query,
key,
value,
key_padding_mask=None,
incremental_state=None,
need_weights=True,
static_kv=False,
attn_mask=None,
before_softmax=False,
need_head_weights=False,
):
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
@ -142,6 +145,7 @@ class MultiheadAttention(nn.Module):
and incremental_state is None
and not static_kv
):
assert key is not None and value is not None
return F.multi_head_attention_forward(
query,
key,
@ -168,7 +172,7 @@ class MultiheadAttention(nn.Module):
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if "prev_key" in saved_state:
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
@ -192,6 +196,7 @@ class MultiheadAttention(nn.Module):
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
@ -235,24 +240,30 @@ class MultiheadAttention(nn.Module):
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
prev_key = saved_state["prev_key"].view(
bsz * self.num_heads, -1, self.head_dim
)
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
k = torch.cat((prev_key, k), dim=1)
assert k is not None
k = torch.cat([prev_key, k], dim=1)
if "prev_value" in saved_state:
prev_value = saved_state["prev_value"].view(
bsz * self.num_heads, -1, self.head_dim
)
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
v = torch.cat((prev_value, v), dim=1)
key_padding_mask = self._append_prev_key_padding_mask(
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=saved_state.get("prev_key_padding_mask", None),
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
@ -261,14 +272,15 @@ class MultiheadAttention(nn.Module):
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
self._set_input_buffer(incremental_state, saved_state)
assert k is not None
src_len = k.size(1)
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
@ -276,6 +288,7 @@ class MultiheadAttention(nn.Module):
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
@ -295,7 +308,7 @@ class MultiheadAttention(nn.Module):
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
attn_weights = MultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
@ -309,7 +322,7 @@ class MultiheadAttention(nn.Module):
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
@ -325,7 +338,7 @@ class MultiheadAttention(nn.Module):
p=self.dropout,
training=self.training,
)
assert v is not None
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
if self.onnx_trace and attn.size(1) == 1:
@ -335,7 +348,7 @@ class MultiheadAttention(nn.Module):
else:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
@ -343,40 +356,49 @@ class MultiheadAttention(nn.Module):
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
else:
attn_weights = None
return attn, attn_weights
@staticmethod
def _append_prev_key_padding_mask(
key_padding_mask, prev_key_padding_mask, batch_size, src_len, static_kv
):
key_padding_mask: Optional[Tensor],
prev_key_padding_mask: Optional[Tensor],
batch_size: int,
src_len: int,
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
key_padding_mask = prev_key_padding_mask
new_key_padding_mask = prev_key_padding_mask
elif prev_key_padding_mask is not None and key_padding_mask is not None:
key_padding_mask = torch.cat(
(prev_key_padding_mask, key_padding_mask), dim=1
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
)
# During incremental decoding, as the padding token enters and
# leaves the frame, there will be a time when prev or current
# is None
elif prev_key_padding_mask is not None:
filler = torch.zeros(
batch_size, src_len - prev_key_padding_mask.size(1)
).bool()
filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
if prev_key_padding_mask.is_cuda:
filler = filler.cuda()
key_padding_mask = torch.cat((prev_key_padding_mask, filler), dim=1)
new_key_padding_mask = torch.cat(
[prev_key_padding_mask.float(), filler.float()], dim=1
)
elif key_padding_mask is not None:
filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1)).bool()
filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
if key_padding_mask.is_cuda:
filler = filler.cuda()
key_padding_mask = torch.cat((filler, key_padding_mask), dim=1)
return key_padding_mask
new_key_padding_mask = torch.cat(
[filler.float(), key_padding_mask.float()], dim=1
)
else:
new_key_padding_mask = prev_key_padding_mask
return new_key_padding_mask
def reorder_incremental_state(self, incremental_state, new_order):
def reorder_incremental_state(
self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order
):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
@ -385,13 +407,28 @@ class MultiheadAttention(nn.Module):
input_buffer[k] = input_buffer[k].index_select(0, new_order)
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state):
return utils.get_incremental_state(self, incremental_state, "attn_state") or {}
def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
empty_dict_annotated: Dict[str, Optional[Tensor]] = {}
if incremental_state is None:
return empty_dict_annotated
full_key = utils._get_full_incremental_state_key(self, "attn_state")
if full_key not in incremental_state:
return empty_dict_annotated
return incremental_state[full_key]
def _set_input_buffer(self, incremental_state, buffer):
utils.set_incremental_state(self, incremental_state, "attn_state", buffer)
def _set_input_buffer(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
buffer: Dict[str, Optional[Tensor]],
):
full_key = utils._get_full_incremental_state_key(
self, "attn_state"
)
incremental_state[full_key] = buffer
def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
return attn_weights
def upgrade_state_dict_named(self, state_dict, name):

View File

@ -225,10 +225,11 @@ class TransformerDecoderLayer(nn.Module):
if len(prev_self_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
self.self_attn._set_input_buffer(incremental_state, saved_state)
input_buffer = self.self_attn._get_input_buffer(incremental_state)
if self.cross_self_attention and not (
incremental_state is not None
and "prev_key" in self.self_attn._get_input_buffer(incremental_state)
and input_buffer is not None
and "prev_key" in input_buffer
):
if self_attn_mask is not None:
self_attn_mask = torch.cat(

View File

@ -274,6 +274,7 @@ class SequenceGenerator(object):
lprobs, avg_attn_scores = model.forward_decoder(
tokens[:, :step + 1], encoder_outs, temperature=self.temperature,
)
lprobs[lprobs != lprobs] = -math.inf
lprobs[:, self.pad] = -math.inf # never select pad
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty

View File

@ -13,11 +13,13 @@ import sys
import warnings
from collections import defaultdict
from itertools import accumulate
from typing import Callable, List
from typing import Callable, Dict, List, Optional
import torch
import torch.nn.functional as F
from fairseq.modules import gelu, gelu_accurate
from fairseq.modules.multihead_attention import MultiheadAttention
from torch import Tensor
logger = logging.getLogger(__name__)
@ -59,24 +61,22 @@ def move_to_cuda(sample):
return apply_to_sample(_move_to_cuda, sample)
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
INCREMENTAL_STATE_INSTANCE_ID = {}
def _get_full_incremental_state_key(module_instance, key):
module_name = module_instance.__class__.__name__
# assign a unique ID to each module instance, so that incremental state is
# not shared across module instances
if not hasattr(module_instance, "_fairseq_instance_id"):
INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
module_instance._fairseq_instance_id = INCREMENTAL_STATE_INSTANCE_ID[
module_name
]
return "{}.{}.{}".format(module_name, module_instance._fairseq_instance_id, key)
def _get_full_incremental_state_key(
module_instance: MultiheadAttention, key: str
) -> str:
return "{}.{}.{}".format(
module_instance.module_name, module_instance._fairseq_instance_id, key
)
def get_incremental_state(module, incremental_state, key):
def get_incremental_state(
module: MultiheadAttention,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
key: str,
) -> Optional[Dict[str, Optional[Tensor]]]:
"""Helper for getting incremental state for an nn.Module."""
full_key = _get_full_incremental_state_key(module, key)
if incremental_state is None or full_key not in incremental_state:
@ -84,7 +84,12 @@ def get_incremental_state(module, incremental_state, key):
return incremental_state[full_key]
def set_incremental_state(module, incremental_state, key, value):
def set_incremental_state(
module: MultiheadAttention,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
key: str,
value: Dict[str, Optional[Tensor]],
):
"""Helper for setting incremental state for an nn.Module."""
if incremental_state is not None:
full_key = _get_full_incremental_state_key(module, key)
@ -323,7 +328,7 @@ def import_user_module(args):
sys.path.pop(0)
def softmax(x, dim, onnx_trace=False):
def softmax(x, dim: int, onnx_trace: bool = False):
if onnx_trace:
return F.softmax(x.float(), dim=dim)
else:

12
tests/test_export.py Normal file
View File

@ -0,0 +1,12 @@
#!/usr/bin/env python3
import unittest
import torch
from fairseq.modules import multihead_attention
class TestExportModels(unittest.TestCase):
def test_export_multihead_attention(self):
module = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
torch.jit.script(module)

View File

@ -74,7 +74,7 @@ class TestReproducibility(unittest.TestCase):
self._test_reproducibility('test_reproducibility_fp16', [
'--fp16',
'--fp16-init-scale', '4096',
], delta=0.01)
], delta=0.011)
@unittest.skipIf(not torch.cuda.is_available(), 'test requires a GPU')
def test_reproducibility_memory_efficient_fp16(self):