make Multihead_attention scriptable (#4773)

Co-authored-by: moslehpour <moslehpour@meta.com>
Mohsen 2022-10-10 18:47:43 -07:00 committed by GitHub
parent a3bd672317
commit c20ba1fbe1


@@ -20,9 +20,9 @@ except ImportError:
     _xformers_available = False
 
 from fairseq import utils
-from fairseq.incremental_decoding_utils import with_incremental_state
 from fairseq.modules.fairseq_dropout import FairseqDropout
 from fairseq.modules.quant_noise import quant_noise
+from fairseq.models.fairseq_incremental_decoder import FairseqIncrementalDecoder
 
 
 # TODO: move this into xformers?
@@ -60,8 +60,7 @@ def _mask_for_xformers(mask: Tensor, to_dtype: Optional[torch.dtype] = None):
     return mask
 
 
-@with_incremental_state
-class MultiheadAttention(nn.Module):
+class MultiheadAttention(FairseqIncrementalDecoder):
     """Multi-headed attention.
 
     See "Attention Is All You Need" for more details.
@@ -79,6 +78,7 @@ class MultiheadAttention(nn.Module):
         add_zero_attn=False,
         self_attention=False,
         encoder_decoder_attention=False,
+        dictionary=None,
         q_noise=0.0,
         qn_block_size=8,
         # TODO: pass in config rather than string.
@@ -91,7 +91,7 @@ class MultiheadAttention(nn.Module):
             int
         ] = 16,  # This should be part of the config
     ):
-        super().__init__()
+        super().__init__(dictionary)
 
         xformers_att_config = utils.eval_str_dict(xformers_att_config)
         self.use_xformers = xformers_att_config is not None
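
The new dictionary kwarg exists to satisfy the new parent: FairseqIncrementalDecoder's constructor takes a dictionary, so MultiheadAttention now accepts one and forwards it, defaulting to None so existing call sites keep working. The shape of the pattern, as a standalone sketch (not library source):

    from typing import Optional

    class DecoderLike:
        def __init__(self, dictionary: Optional[object] = None):
            self.dictionary = dictionary

    class AttentionLike(DecoderLike):
        def __init__(self, embed_dim: int, num_heads: int, dictionary=None):
            super().__init__(dictionary)  # parent stores the (possibly None) dictionary
            self.embed_dim = embed_dim
            self.num_heads = num_heads

    attn = AttentionLike(16, 4)  # old call sites work without a dictionary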
@@ -160,6 +160,7 @@ class MultiheadAttention(nn.Module):
 
         self.onnx_trace = False
         self.skip_embed_dim_check = False
+        self.init_incremental_state()
 
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True
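
init_incremental_state() assigns the per-instance UUID that prefixes this module's keys in the shared incremental-state dict; calling it explicitly in __init__ keeps that setup intact now that the mixin's own __init__ no longer runs via the decorator. Roughly what it does, simplified from fairseq/incremental_decoding_utils.py (an approximation, not the exact source):

    import uuid

    class IncrementalStateSketch:
        def init_incremental_state(self):
            # each module instance gets a unique prefix for its cached state keys
            self._incremental_state_id = str(uuid.uuid4())

        def _get_full_incremental_state_key(self, key: str) -> str:
            return "{}.{}".format(self._incremental_state_id, key)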
@@ -467,7 +468,7 @@ class MultiheadAttention(nn.Module):
 
     def forward(
         self,
-        query,
+        query: Tensor,
         key: Optional[Tensor],
         value: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
@@ -739,6 +740,7 @@ class MultiheadAttention(nn.Module):
         attn_probs = self.dropout_module(attn_weights)
 
         assert v is not None
+        attn: Optional[Tensor] = None
         if self.encoder_decoder_attention and bsz != kv_bsz:
             attn = torch.einsum(
                 "bxhts,bhsd->bxhtd",
@@ -827,7 +829,7 @@ class MultiheadAttention(nn.Module):
     @torch.jit.export
     def reorder_incremental_state(
         self,
-        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
         new_order: Tensor,
     ):
         """Reorder buffered internal state (for incremental generation)."""
@@ -868,7 +870,7 @@ class MultiheadAttention(nn.Module):
 
     def _set_input_buffer(
         self,
-        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
         buffer: Dict[str, Optional[Tensor]],
     ):
         return self.set_incremental_state(incremental_state, "attn_state", buffer)
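
Taken together, the hunks make the module self-contained under torch.jit.script. A quick smoke test of the intended outcome (hypothetical sizes; assumes a fairseq checkout containing this commit):

    import torch
    from fairseq.modules.multihead_attention import MultiheadAttention

    mha = MultiheadAttention(embed_dim=16, num_heads=4, self_attention=True)
    scripted = torch.jit.script(mha)  # raised a compilation error before this change

    x = torch.randn(5, 2, 16)  # (seq_len, batch, embed_dim)
    attn_out, attn_weights = scripted(x, x, x)
    print(attn_out.shape)  # torch.Size([5, 2, 16])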