Mirror of https://github.com/facebookresearch/fairseq.git (synced 2024-11-13 07:41:39 +03:00)
make Multihead_attention scriptable (#4773)

Co-authored-by: moslehpour <moslehpour@meta.com>

commit c20ba1fbe1 (parent a3bd672317)
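The commit reworks fairseq/modules/multihead_attention.py so the module compiles under TorchScript: MultiheadAttention stops being a plain nn.Module decorated with @with_incremental_state and instead subclasses FairseqIncrementalDecoder, and several signatures get explicit type annotations. A minimal sketch of what this enables; the hyperparameters and input shapes are illustrative assumptions, not taken from the commit:

import torch
from fairseq.modules.multihead_attention import MultiheadAttention

# Build a self-attention module; embed_dim/num_heads are arbitrary here.
mha = MultiheadAttention(embed_dim=512, num_heads=8, self_attention=True)

# The point of the commit: TorchScript compilation should now succeed.
scripted = torch.jit.script(mha)

x = torch.randn(10, 2, 512)  # (seq_len, batch, embed_dim)
attn_out, attn_weights = scripted(x, x, x)  # forward(query, key, value)
print(attn_out.shape)  # torch.Size([10, 2, 512])

The hunks below show the change itself.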
fairseq/modules/multihead_attention.py
@@ -20,9 +20,9 @@ except ImportError:
     _xformers_available = False
 
 from fairseq import utils
-from fairseq.incremental_decoding_utils import with_incremental_state
 from fairseq.modules.fairseq_dropout import FairseqDropout
 from fairseq.modules.quant_noise import quant_noise
+from fairseq.models.fairseq_incremental_decoder import FairseqIncrementalDecoder
 
 
 # TODO: move this into xformers?
@@ -60,8 +60,7 @@ def _mask_for_xformers(mask: Tensor, to_dtype: Optional[torch.dtype] = None):
     return mask
 
 
-@with_incremental_state
-class MultiheadAttention(nn.Module):
+class MultiheadAttention(FairseqIncrementalDecoder):
     """Multi-headed attention.
 
     See "Attention Is All You Need" for more details.
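This base-class swap is the heart of the change. The @with_incremental_state decorator rewires the decorated class's bases at runtime (mixing FairseqIncrementalState into cls.__bases__), which the static TorchScript compiler cannot follow reliably. Subclassing FairseqIncrementalDecoder, which carries the same incremental-state helpers through an ordinary statically declared hierarchy, puts those methods where the compiler can resolve them. A simplified illustration of the inheritance pattern being adopted, using stand-in classes rather than the real fairseq code:

import torch
from torch import Tensor

class StateBase(torch.nn.Module):
    # Helpers declared on a static base class are visible to the
    # TorchScript compiler when a subclass is scripted.
    def scale(self) -> float:
        return 0.5

class Attention(StateBase):
    def forward(self, x: Tensor) -> Tensor:
        return x * self.scale()  # resolves through the base class

scripted = torch.jit.script(Attention())  # compiles cleanly
print(scripted(torch.ones(2)))  # tensor([0.5000, 0.5000])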
@@ -79,6 +78,7 @@ class MultiheadAttention(nn.Module):
         add_zero_attn=False,
         self_attention=False,
         encoder_decoder_attention=False,
+        dictionary=None,
         q_noise=0.0,
         qn_block_size=8,
         # TODO: pass in config rather than string.
@@ -91,7 +91,7 @@ class MultiheadAttention(nn.Module):
             int
         ] = 16,  # This should be part of the config
     ):
-        super().__init__()
+        super().__init__(dictionary)
 
         xformers_att_config = utils.eval_str_dict(xformers_att_config)
         self.use_xformers = xformers_att_config is not None
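The new dictionary=None parameter exists because the adopted base class requires one: FairseqIncrementalDecoder forwards its dictionary argument to FairseqDecoder's constructor, so MultiheadAttention now threads it through super().__init__(dictionary). Roughly the shape involved, paraphrased rather than the verbatim fairseq source (see fairseq/models/fairseq_decoder.py and fairseq/models/fairseq_incremental_decoder.py):

import torch

class FairseqDecoder(torch.nn.Module):
    def __init__(self, dictionary):
        super().__init__()
        self.dictionary = dictionary  # may be None for attention-only use

class FairseqIncrementalDecoder(FairseqDecoder):
    def __init__(self, dictionary):
        super().__init__(dictionary)

Defaulting to None keeps existing MultiheadAttention call sites working, since an attention layer has no vocabulary of its own.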
@@ -160,6 +160,7 @@ class MultiheadAttention(nn.Module):
 
         self.onnx_trace = False
         self.skip_embed_dim_check = False
+        self.init_incremental_state()
 
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True
@@ -467,7 +468,7 @@ class MultiheadAttention(nn.Module):
 
     def forward(
         self,
-        query,
+        query: Tensor,
         key: Optional[Tensor],
         value: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
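Annotating query: Tensor makes the forward signature fully explicit for TorchScript. Unannotated parameters are assumed to be Tensor, but anything else, including the Optional[Tensor] siblings here, must be spelled out, and Optional arguments need a None check before use. A minimal illustration with a hypothetical module:

import torch
from torch import Tensor
from typing import Optional

class Toy(torch.nn.Module):
    # Optional[Tensor] must be annotated and refined before arithmetic.
    def forward(self, query: Tensor, key: Optional[Tensor] = None) -> Tensor:
        if key is None:
            return query
        return query + key

scripted = torch.jit.script(Toy())
print(scripted(torch.ones(3)))                 # tensor([1., 1., 1.])
print(scripted(torch.ones(3), torch.ones(3)))  # tensor([2., 2., 2.])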
@@ -739,6 +740,7 @@ class MultiheadAttention(nn.Module):
         attn_probs = self.dropout_module(attn_weights)
 
         assert v is not None
+        attn: Optional[Tensor] = None
         if self.encoder_decoder_attention and bsz != kv_bsz:
             attn = torch.einsum(
                 "bxhts,bhsd->bxhtd",
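Pre-declaring attn: Optional[Tensor] = None plausibly exists to give TorchScript a single declared type for the variable before the branch-dependent assignments that follow; the compiler infers one static type per name, so a variable assigned under different branches and checked against None later is safest declared up front. An illustrative standalone sketch, not the commit's code:

import torch
from torch import Tensor
from typing import Optional

@torch.jit.script
def branchy(x: Tensor, flag: bool) -> Tensor:
    # Declared once, with the type the branches and None checks agree on.
    attn: Optional[Tensor] = None
    if flag:
        attn = x * 2.0
    if attn is None:  # refinement: past this check, attn is a Tensor
        return x
    return attn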
@@ -827,7 +829,7 @@ class MultiheadAttention(nn.Module):
     @torch.jit.export
     def reorder_incremental_state(
         self,
-        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
         new_order: Tensor,
     ):
         """Reorder buffered internal state (for incremental generation)."""
@@ -868,7 +870,7 @@ class MultiheadAttention(nn.Module):
 
     def _set_input_buffer(
         self,
-        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
         buffer: Dict[str, Optional[Tensor]],
    ):
         return self.set_incremental_state(incremental_state, "attn_state", buffer)
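Loosening incremental_state to Optional[...] in these two hunks matches how callers actually use the cache: ordinary forward passes pass None, while step-wise decoding threads one dict through every call. A sketch of that flow, assuming a fairseq install with this commit and illustrative shapes:

import torch
from torch import Tensor
from typing import Dict, Optional
from fairseq.modules.multihead_attention import MultiheadAttention

mha = MultiheadAttention(embed_dim=512, num_heads=8, self_attention=True)

# One cache per generation; the module keys its own buffers inside it.
cache: Dict[str, Dict[str, Optional[Tensor]]] = {}

for _ in range(3):  # decode three steps, one token at a time
    q = torch.randn(1, 2, 512)  # (1 new step, batch=2, embed_dim)
    out, _ = mha(q, q, q, incremental_state=cache)

# Beam reordering permutes the cached buffers along the batch dimension:
mha.reorder_incremental_state(cache, torch.tensor([1, 0]))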