xformer integration (#2263)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?
This PR is a cleaned up version of https://github.com/fairinternal/fairseq-py/issues/2138. It is based on the `main` branch instead of the `gshard` branch. Removed call to xFormers MultiHeadDispatch, only using xFormers Attention.

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding �

X-link: https://github.com/fairinternal/fairseq-py/pull/2263

Reviewed By: blefaudeux

Differential Revision: D33800377

Pulled By: dianaml0

fbshipit-source-id: 658d52214c782212b12881b30c4d908a763b4cf2
This commit is contained in:
dianaml0 2022-05-04 09:15:36 -07:00 committed by Facebook GitHub Bot
parent 0b54d9fb2e
commit 51478ad3a1
10 changed files with 735 additions and 29 deletions

View File

@ -45,6 +45,17 @@ install_dep_fused_ops: &install_dep_fused_ops
cd Megatron-LM
pip install -e .
install_dep_xformers: &install_dep_xformers
- run:
name: Install xFormers Dependencies
working_directory: ~/
command: |
source activate fairseq
git clone https://github.com/facebookresearch/xformers.git
cd xformers
pip install -r requirements.txt
pip install -e .
install_dep_pt19: &install_dep_pt19
- run:
@ -123,6 +134,7 @@ jobs:
- <<: *install_dep_pt19
- <<: *install_dep_common
- <<: *install_dep_fused_ops
- <<: *install_dep_xformers
- save_cache:
paths:
- ~/miniconda/
@ -144,6 +156,7 @@ jobs:
- <<: *install_dep_pt18
- <<: *install_dep_common
- <<: *install_dep_fused_ops
- <<: *install_dep_xformers
- save_cache:
paths:
- ~/miniconda/

View File

@ -41,6 +41,8 @@ jobs:
run: |
python -m pip install iopath transformers pyarrow
python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
python -m pip install --progress-bar off git+https://github.com/facebookresearch/xformers.git@main
python -m pip install pytest
- name: Lint with flake8
run: |

View File

@ -69,6 +69,7 @@ We provide reference implementations of various sequence modeling papers:
</p></details>
### What's New:
* May 2022 [Integration with xFormers](https://github.com/facebookresearch/xformers)
* December 2021 [Released Direct speech-to-speech translation code](examples/speech_to_speech/README.md)
* October 2021 [Released VideoCLIP and VLM models](examples/MMPT/README.md)
* October 2021 [Released multilingual finetuned XLSR-53 model](examples/wav2vec/README.md)

View File

@ -0,0 +1,172 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import random
import torch
from torch.utils import benchmark
from fairseq.modules.multihead_attention import MultiheadAttention
BATCH = [20, 41, 97]
SEQ = 64
EMB = 48
HEADS = 4
DROP = 0.1
DEVICE = torch.device("cuda")
ATTN_MASK_DTYPE = [torch.uint8, torch.bool, torch.float]
KEY_PADDING_MASK_DTYPE = [torch.uint8, torch.bool]
def _reset_seeds():
torch.manual_seed(0)
random.seed(0)
def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
if to_dtype == torch.float:
mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)
def benchmark_multihead_attention(
label="",
attn_dtype=torch.uint8,
key_padding_dtype=torch.uint8,
add_bias_kv=False,
add_zero_attn=False,
static_kv=False,
batch_size=20,
embedding=EMB,
seq_len=SEQ,
num_heads=HEADS,
):
results = []
# device = torch.device("cuda")
xformers_att_config = '{"name": "scaled_dot_product"}'
attn_mask = _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len)
key_padding_mask = _get_mask(
to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len
)
q = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
k = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
v = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
_reset_seeds()
original_mha = MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
xformers_att_config=None,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
xformers_mha = MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
xformers_att_config=xformers_att_config,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
def original_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
original_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
def xformers_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
xformers_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
def original_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
output, _ = original_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
loss = torch.norm(output)
loss.backward()
def xformers_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
output, _ = xformers_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
loss = torch.norm(output)
loss.backward()
fns = [
original_bench_fw,
xformers_bench_fw,
original_bench_fw_bw,
xformers_bench_fw_bw,
]
for fn in fns:
results.append(
benchmark.Timer(
stmt="fn(q, k, v, key_padding_mask, attn_mask, static_kv)",
globals={
"q": q,
"k": k,
"v": v,
"key_padding_mask": key_padding_mask,
"attn_mask": attn_mask,
"static_kv": static_kv,
"fn": fn,
},
label="multihead fw + bw",
sub_label=f"{fn.__name__}",
description=label,
).blocked_autorange(min_run_time=1)
)
compare = benchmark.Compare(results)
compare.print()
def run_benchmarks():
for attn_dtype, key_padding_dtype, add_bias_kv, add_zero_attn in itertools.product(
ATTN_MASK_DTYPE, KEY_PADDING_MASK_DTYPE, [True, False], [True, False]
):
label = f"attn_dtype {attn_dtype}, key_padding_dtype {key_padding_dtype}, \
add_bias_kv {add_bias_kv}, add_zero_attn {add_zero_attn}"
benchmark_multihead_attention(
label=label,
attn_dtype=attn_dtype,
key_padding_dtype=key_padding_dtype,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
run_benchmarks()

View File

@ -49,6 +49,13 @@ class EncDecBaseConfig(FairseqDataclass):
default=None, metadata={"help": "which layers to *keep* when pruning"}
)
xformers_att_config: Optional[str] = field(
default=None,
metadata={
"help": "config for xFormers attention, defined in xformers.components.attention.AttentionConfig"
},
)
@dataclass
class DecoderConfig(EncDecBaseConfig):

View File

@ -7,6 +7,8 @@
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import II
from fairseq import options, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
@ -21,8 +23,6 @@ from fairseq.models.transformer import (
)
from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
from fairseq.utils import safe_getattr, safe_hasattr
from omegaconf import II
DEFAULT_MAX_TARGET_POSITIONS = 1024
@ -210,6 +210,15 @@ class TransformerLanguageModelConfig(FairseqDataclass):
default=False,
metadata={"help": "Learn a scale coefficient for each residual connection"},
)
# xFormers arguments
decoder_xformers_att_config: Optional[str] = field(
default=None,
metadata={
"help": "config for xFormers library attention, defined in xformers.components.attention.AttentionConfig",
},
)
# options from other parts of the config
add_bos_token: bool = II("task.add_bos_token")
tokens_per_sample: int = II("task.tokens_per_sample")

View File

@ -10,6 +10,8 @@ import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter
from xformers.components.attention import build_attention
from xformers.components.attention.utils import maybe_merge_masks
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
@ -17,6 +19,40 @@ from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
# TODO: move this into xformers?
# TODO: uint8 input type should just output a bool
def _mask_for_xformers(mask: Tensor, to_dtype: Optional[torch.dtype] = None):
"""
call to pytorch multihead accepts three mask types:
- ByteTensor where non-zero means to mask
- FloatTensor which is an additive mask
- BoolTensor where True means to mask
xFormers currently accepts boolean and additive maks. For boolean masks
the values have opposite meaning. For a BoolTensor True mean to keep the value.
"""
float_types = [torch.float, torch.float16]
# If an input mask is a float it is an additive mask. Otherwise it is either uint8 or bool.
additive = mask.dtype in float_types
# If to_dype is not specified, keep same dtype as mask.
to_dtype = mask.dtype if to_dtype is None else to_dtype
to_additive = to_dtype in float_types
if additive:
if to_additive:
return mask.to(to_dtype)
mask = mask < 0
if to_additive:
# return additive mask
new_mask = torch.zeros_like(mask, dtype=to_dtype)
new_mask = new_mask.masked_fill_(mask, -float("inf"))
return new_mask
# In xFormers True is value to keep rather than value to mask
mask = ~mask.to(torch.bool)
mask = mask.to(to_dtype)
return mask
@with_incremental_state
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
@ -38,8 +74,20 @@ class MultiheadAttention(nn.Module):
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
# TODO: pass in config rather than string.
# config defined in xformers.components.attention.AttentionConfig
xformers_att_config: Optional[str] = None,
xformers_blocksparse_layout: Optional[
torch.Tensor
] = None, # This should be part of the config
xformers_blocksparse_blocksize: Optional[
int
] = 16, # This should be part of the config
):
super().__init__()
xformers_att_config = utils.eval_str_dict(xformers_att_config)
self.use_xformers = xformers_att_config is not None
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
@ -87,6 +135,23 @@ class MultiheadAttention(nn.Module):
self.beam_size = 1
self.reset_parameters()
self.fp16_mask = False
if self.use_xformers:
xformers_att_config["dropout"] = xformers_att_config.get("dropout", dropout)
xformers_att_config["num_heads"] = xformers_att_config.get(
"num_heads", num_heads
)
if xformers_blocksparse_layout is not None:
# Could be part of a single config passed only once
xformers_att_config["block_size"] = xformers_blocksparse_blocksize
xformers_att_config["layout"] = xformers_blocksparse_layout
xformers_att_config["name"] = "blocksparse"
# Mask required to be float16
self.fp16_mask = True
self.attention = build_attention(xformers_att_config)
self.onnx_trace = False
self.skip_embed_dim_check = False
@ -296,6 +361,102 @@ class MultiheadAttention(nn.Module):
)
return k, v, key_padding_mask, attn_mask
def _xformers_attn_forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
tgt_len, bsz, embed_dim = query.size()
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == tgt_len
if self.self_attention:
key = query
value = query
elif self.encoder_decoder_attention:
value = key
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
if self.bias_k is not None:
assert self.bias_v is not None
k, v, attn_mask, key_padding_mask = self._add_bias(
k, v, attn_mask, key_padding_mask, bsz
)
def fold_heads(x):
return (
x.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
def split_heads(x):
return (
x.contiguous()
.view(-1, bsz, self.num_heads, self.head_dim)
.transpose(0, 1)
.transpose(1, 2)
)
massage = split_heads if self.attention.requires_head_dimension else fold_heads
q = massage(q)
if k is not None:
k = massage(k)
if v is not None:
v = massage(v)
if self.add_zero_attn:
k, v, key_padding_mask, attn_mask = self._append_zero_attn(
k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask
)
if attn_mask is not None:
to_dtype = torch.float16 if self.fp16_mask else q.dtype
attn_mask = _mask_for_xformers(attn_mask, to_dtype=to_dtype)
if key_padding_mask is not None:
to_dtype = torch.float16 if self.fp16_mask else torch.bool
key_padding_mask = _mask_for_xformers(key_padding_mask, to_dtype=to_dtype)
if not self.attention.requires_separate_masks:
attn_mask = maybe_merge_masks(
attn_mask,
key_padding_mask,
batch_size=bsz,
src_len=k.size(-2),
tgt_len=q.size(-2),
num_heads=self.num_heads,
)
key_padding_mask = None
y = self.attention(
q, k, v, att_mask=attn_mask, key_padding_mask=key_padding_mask
)
y = (
y.view(bsz, self.num_heads, tgt_len, self.head_dim)
.transpose(1, 2)
.flatten(start_dim=2, end_dim=3)
.transpose(0, 1)
)
assert list(y.size()) == [tgt_len, bsz, embed_dim]
# Dropout not needed because already applied in attention.
# It is applied to the attention weights before matmul with v.
y = self.out_proj(y)
# TODO: support returning attention weights if needed.
return y, None
def forward(
self,
query,
@ -359,29 +520,36 @@ class MultiheadAttention(nn.Module):
and not self.skip_embed_dim_check
):
assert key is not None and value is not None
return F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training or self.dropout_module.apply_during_inference,
key_padding_mask,
need_weights,
attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)
if self.use_xformers:
return self._xformers_attn_forward(
query, key, value, key_padding_mask, need_weights, attn_mask
)
else:
return F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training or self.dropout_module.apply_during_inference,
key_padding_mask,
need_weights,
attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)

View File

@ -141,6 +141,7 @@ class TransformerEncoderLayerBase(nn.Module):
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
xformers_att_config=cfg.encoder.xformers_att_config,
)
def residual_connection(self, x, residual):
@ -359,6 +360,7 @@ class TransformerDecoderLayerBase(nn.Module):
self_attention=not cfg.cross_self_attention,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
xformers_att_config=cfg.decoder.xformers_att_config,
)
def build_encoder_attention(self, embed_dim, cfg):
@ -371,6 +373,7 @@ class TransformerDecoderLayerBase(nn.Module):
encoder_decoder_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
xformers_att_config=cfg.encoder.xformers_att_config,
)
def prepare_for_onnx_export_(self):

View File

@ -173,7 +173,6 @@ else:
if "clean" in sys.argv[1:]:
# Source: https://bit.ly/2NLVsgE
print("deleting Cython files...")
import subprocess
subprocess.run(
["rm -f fairseq/*.so fairseq/**/*.so fairseq/*.pyd fairseq/**/*.pyd"],

View File

@ -3,11 +3,343 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random
import unittest
import pytest
import torch
from fairseq.modules.multihead_attention import MultiheadAttention, _mask_for_xformers
from fairseq.modules.multihead_attention import MultiheadAttention
BATCH = [20, 41, 97]
SEQ = [64]
EMB = [48]
HEADS = [4]
DROP = 0.1
DEVICE = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
ATTN_MASK_DTYPE = [None, torch.uint8, torch.bool, torch.float]
KEY_PADDING_MASK_DTYPE = [None, torch.uint8, torch.bool]
# FIXME: some tests fail when decimal=2, fix this and set decimal to 2
def assert_almost_equal(x, y, decimal=1, err_msg=""):
import numpy.testing as npt
if isinstance(x, torch.Tensor):
x = x.cpu().detach().numpy()
if isinstance(y, torch.Tensor):
y = y.cpu().detach().numpy()
npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)
def _reset_seeds():
torch.manual_seed(0)
torch.random.manual_seed(0)
random.seed(0)
torch.cuda.manual_seed_all(0)
def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
if to_dtype == torch.float:
mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)
def test_mask_for_xformers():
# Additive Mask
m_float_add = torch.tensor([float("-inf"), 0]).to(torch.float)
m_float_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float)
m_float16_add = torch.tensor([float("-inf"), 0]).to(torch.float16)
m_float16_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float16)
m_uint = torch.tensor([1, 0]).to(torch.uint8)
m_uint_flipped = torch.tensor([0, 1]).to(torch.uint8)
m_bool = torch.tensor([False, True])
assert torch.equal(_mask_for_xformers(m_float_add), m_float_add)
assert torch.equal(_mask_for_xformers(m_float16_add), m_float16_add)
assert torch.equal(_mask_for_xformers(m_uint), m_uint_flipped)
assert torch.equal(_mask_for_xformers(m_bool), ~m_bool)
assert torch.equal(
_mask_for_xformers(m_float_add, to_dtype=torch.float16), m_float16_add
)
assert torch.equal(
_mask_for_xformers(m_float_add, to_dtype=torch.float), m_float_add
)
assert torch.equal(_mask_for_xformers(m_float_add, to_dtype=torch.bool), m_bool)
assert torch.equal(
_mask_for_xformers(m_float_add, to_dtype=torch.uint8), m_uint_flipped
)
assert torch.equal(
_mask_for_xformers(m_float16_add, to_dtype=torch.float16), m_float16_add
)
assert torch.equal(
_mask_for_xformers(m_float16_add, to_dtype=torch.float), m_float_add
)
assert torch.equal(_mask_for_xformers(m_float16_add, to_dtype=torch.bool), m_bool)
assert torch.equal(
_mask_for_xformers(m_float16_add, to_dtype=torch.uint8), m_uint_flipped
)
assert torch.equal(
_mask_for_xformers(m_bool, to_dtype=torch.float16), m_float16_add_flipped
)
assert torch.equal(
_mask_for_xformers(m_bool, to_dtype=torch.float), m_float_add_flipped
)
assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.bool), ~m_bool)
assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.uint8), m_uint)
assert torch.equal(
_mask_for_xformers(m_uint, to_dtype=torch.float16), m_float16_add
)
assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.float), m_float_add)
assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.bool), m_bool)
assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.uint8), m_uint_flipped)
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("attn_dtype", ATTN_MASK_DTYPE)
@pytest.mark.parametrize("key_padding_dtype", KEY_PADDING_MASK_DTYPE)
@pytest.mark.parametrize("add_zero_attn", [False])
@pytest.mark.parametrize("batch_size", [20])
@pytest.mark.parametrize("embedding", [64])
@pytest.mark.parametrize("seq_len", [64])
@pytest.mark.parametrize("num_heads", [4])
def test_xformers_blocksparse_parity(
device,
attn_dtype,
key_padding_dtype,
add_zero_attn,
batch_size,
embedding,
seq_len,
num_heads,
):
xformers_att_config = '{"name": "scaled_dot_product"}'
xformers_blocksparse_blocksize = 16
xformers_blocksparse_layout = torch.ones(
seq_len // xformers_blocksparse_blocksize,
seq_len // xformers_blocksparse_blocksize,
)
attn_mask = (
None
if attn_dtype is None
else _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len).to(device)
)
key_padding_mask = (
None
if key_padding_dtype is None
else _get_mask(to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len).to(
device
)
)
q = torch.rand(seq_len, batch_size, embedding).to(device).half()
q.requires_grad = True
k = torch.rand(seq_len, batch_size, embedding).to(device).half()
k.requires_grad = True
v = torch.rand(seq_len, batch_size, embedding).to(device).half()
v.requires_grad = True
q_ = q.detach().clone().half()
q_.requires_grad = True
k_ = k.detach().clone().half()
k_.requires_grad = True
v_ = v.detach().clone().half()
v_.requires_grad = True
_reset_seeds()
xf_blocksparse_mha = (
MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
add_zero_attn=add_zero_attn,
xformers_att_config=xformers_att_config,
xformers_blocksparse_layout=xformers_blocksparse_layout,
xformers_blocksparse_blocksize=xformers_blocksparse_blocksize,
)
.to(device)
.half()
)
xf_blocksparse_output, _ = xf_blocksparse_mha(
q,
k,
v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
)
_reset_seeds()
xformers_mha = (
MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
add_zero_attn=add_zero_attn,
xformers_att_config=xformers_att_config,
xformers_blocksparse_layout=None,
)
.to(device)
.half()
)
xformers_output, _ = xformers_mha(
q_,
k_,
v_,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
)
# # account for when nan != nan
rand = random.uniform(0, 1)
xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
xf_blocksparse_output = xf_blocksparse_output.masked_fill(
xf_blocksparse_output.isnan(), rand
)
assert_almost_equal(xformers_output, xf_blocksparse_output)
loss_blocksparse = torch.norm(xformers_output)
loss_original = torch.norm(xf_blocksparse_output)
loss_blocksparse.backward()
loss_original.backward()
q.masked_fill(q.isnan(), rand)
q_.masked_fill(q_.isnan(), rand)
k.masked_fill(k.isnan(), rand)
k_.masked_fill(k_.isnan(), rand)
v.masked_fill(v.isnan(), rand)
v_.masked_fill(v_.isnan(), rand)
assert_almost_equal(q.grad, q_.grad)
assert_almost_equal(k.grad, k_.grad)
assert_almost_equal(v.grad, v_.grad)
@pytest.mark.parametrize("device", DEVICE)
@pytest.mark.parametrize("attn_dtype", ATTN_MASK_DTYPE)
@pytest.mark.parametrize("key_padding_dtype", KEY_PADDING_MASK_DTYPE)
@pytest.mark.parametrize("add_bias_kv", [True, False])
@pytest.mark.parametrize("add_zero_attn", [True, False])
# TODO: test with static_kv True
@pytest.mark.parametrize("static_kv", [False])
@pytest.mark.parametrize("batch_size", BATCH)
@pytest.mark.parametrize("embedding", EMB)
@pytest.mark.parametrize("seq_len", SEQ)
@pytest.mark.parametrize("num_heads", HEADS)
def test_xformers_single_forward_parity(
device,
attn_dtype,
key_padding_dtype,
add_bias_kv,
add_zero_attn,
static_kv,
batch_size,
embedding,
seq_len,
num_heads,
):
xformers_att_config = '{"name": "scaled_dot_product"}'
attn_mask = (
None
if attn_dtype is None
else _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len).to(device)
)
key_padding_mask = (
None
if key_padding_dtype is None
else _get_mask(to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len).to(
device
)
)
q = torch.rand(seq_len, batch_size, embedding).to(device)
q.requires_grad = True
k = torch.rand(seq_len, batch_size, embedding).to(device)
k.requires_grad = True
v = torch.rand(seq_len, batch_size, embedding).to(device)
v.requires_grad = True
q_ = q.detach().clone()
q_.requires_grad = True
k_ = k.detach().clone()
k_.requires_grad = True
v_ = v.detach().clone()
v_.requires_grad = True
# TODO: dropouts in the two implementations lead to different entries dropped.
_reset_seeds()
xformers_mha = MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
xformers_att_config=xformers_att_config,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
).to(device)
xformers_output, _ = xformers_mha(
q,
k,
v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
_reset_seeds()
original_mha = MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
xformers_att_config=None,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
).to(device)
original_output, _ = original_mha(
q_,
k_,
v_,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
# account for when nan != nan
if xformers_output.isnan().any() or original_output.isnan().any():
rand = random.uniform(0, 1)
xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
original_output = original_output.masked_fill(original_output.isnan(), rand)
# torch.equal works for cpu, on cuda allclose is needed.
assert torch.allclose(
xformers_output, original_output, atol=1e-06
), f"max diff is {torch.max(torch.abs(xformers_output - original_output))}"
loss_xformers = torch.norm(xformers_output)
loss_original = torch.norm(original_output)
loss_xformers.backward()
loss_original.backward()
# torch.equal works for cpu, on cuda allclose is needed.
assert torch.allclose(
q.grad, q_.grad
), f"max diff is {torch.max(torch.abs(q.grad - q_.grad))}"
assert torch.allclose(
k.grad, k_.grad
), f"max diff is {torch.max(torch.abs(k.grad - k_.grad))}"
assert torch.allclose(
v.grad, v_.grad
), f"max diff is {torch.max(torch.abs(v.grad - v_.grad))}"
def test_mask_padding_parity():
@ -28,7 +360,7 @@ def test_mask_padding_parity():
# values don't matter for this test.
mha = MultiheadAttention(
embedding=8,
embed_dim=8,
num_heads=2,
dropout=0.0,
add_bias_kv=True,
@ -50,7 +382,7 @@ def test_mask_padding_parity():
def test_add_bias_parity():
# values don't matter for this test.
mha = MultiheadAttention(
embedding=8,
embed_dim=8,
num_heads=2,
dropout=0.0,
add_bias_kv=True,