xFormers integration (#2263)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?
This PR is a cleaned-up version of https://github.com/fairinternal/fairseq-py/issues/2138. It is based on the `main` branch instead of the `gshard` branch, and it removes the call to xFormers MultiHeadDispatch, using only xFormers Attention.
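Below is a minimal usage sketch of the new hook, mirroring the unit tests added in this PR (the dimensions and config string are illustrative only, not prescriptive):

```python
import torch
from fairseq.modules.multihead_attention import MultiheadAttention

# Passing an xFormers attention config string selects the xFormers backend;
# passing None keeps the original fairseq/PyTorch attention path.
mha = MultiheadAttention(
    embed_dim=48,
    num_heads=4,
    dropout=0.0,
    xformers_att_config='{"name": "scaled_dot_product"}',
)

# fairseq attention modules expect (seq_len, batch, embed_dim) inputs.
q = k = v = torch.rand(64, 20, 48)
out, attn_weights = mha(query=q, key=k, value=v)  # attn_weights is None on the xFormers path
```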

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

X-link: https://github.com/fairinternal/fairseq-py/pull/2263

Reviewed By: blefaudeux

Differential Revision: D33800377

Pulled By: dianaml0

fbshipit-source-id: 658d52214c782212b12881b30c4d908a763b4cf2
This commit is contained in:
dianaml0 2022-05-04 09:15:36 -07:00 committed by Facebook GitHub Bot
parent 0b54d9fb2e
commit 51478ad3a1
10 changed files with 735 additions and 29 deletions

View File

@@ -45,6 +45,17 @@ install_dep_fused_ops: &install_dep_fused_ops
cd Megatron-LM
pip install -e .
install_dep_xformers: &install_dep_xformers
- run:
name: Install xFormers Dependencies
working_directory: ~/
command: |
source activate fairseq
git clone https://github.com/facebookresearch/xformers.git
cd xformers
pip install -r requirements.txt
pip install -e .
install_dep_pt19: &install_dep_pt19
- run:
@@ -123,6 +134,7 @@ jobs:
- <<: *install_dep_pt19
- <<: *install_dep_common
- <<: *install_dep_fused_ops
- <<: *install_dep_xformers
- save_cache:
paths:
- ~/miniconda/
@@ -144,6 +156,7 @@ jobs:
- <<: *install_dep_pt18
- <<: *install_dep_common
- <<: *install_dep_fused_ops
- <<: *install_dep_xformers
- save_cache:
paths:
- ~/miniconda/

View File

@@ -41,6 +41,8 @@ jobs:
run: |
python -m pip install iopath transformers pyarrow
python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
python -m pip install --progress-bar off git+https://github.com/facebookresearch/xformers.git@main
python -m pip install pytest
- name: Lint with flake8
run: |

View File

@@ -69,6 +69,7 @@ We provide reference implementations of various sequence modeling papers:
</p></details>
### What's New:
* May 2022 [Integration with xFormers](https://github.com/facebookresearch/xformers)
* December 2021 [Released Direct speech-to-speech translation code](examples/speech_to_speech/README.md)
* October 2021 [Released VideoCLIP and VLM models](examples/MMPT/README.md)
* October 2021 [Released multilingual finetuned XLSR-53 model](examples/wav2vec/README.md)

View File

@@ -0,0 +1,172 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import random
import torch
from torch.utils import benchmark
from fairseq.modules.multihead_attention import MultiheadAttention
BATCH = [20, 41, 97]
SEQ = 64
EMB = 48
HEADS = 4
DROP = 0.1
DEVICE = torch.device("cuda")
ATTN_MASK_DTYPE = [torch.uint8, torch.bool, torch.float]
KEY_PADDING_MASK_DTYPE = [torch.uint8, torch.bool]
def _reset_seeds():
torch.manual_seed(0)
random.seed(0)
def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
if to_dtype == torch.float:
mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)
def benchmark_multihead_attention(
label="",
attn_dtype=torch.uint8,
key_padding_dtype=torch.uint8,
add_bias_kv=False,
add_zero_attn=False,
static_kv=False,
batch_size=20,
embedding=EMB,
seq_len=SEQ,
num_heads=HEADS,
):
results = []
# device = torch.device("cuda")
xformers_att_config = '{"name": "scaled_dot_product"}'
attn_mask = _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len)
key_padding_mask = _get_mask(
to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len
)
q = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
k = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
v = torch.rand(seq_len, batch_size, embedding, requires_grad=True)
_reset_seeds()
original_mha = MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
xformers_att_config=None,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
xformers_mha = MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
xformers_att_config=xformers_att_config,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
def original_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
original_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
def xformers_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
xformers_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
def original_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
output, _ = original_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
loss = torch.norm(output)
loss.backward()
def xformers_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
output, _ = xformers_mha(
query=q,
key=k,
value=v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
loss = torch.norm(output)
loss.backward()
fns = [
original_bench_fw,
xformers_bench_fw,
original_bench_fw_bw,
xformers_bench_fw_bw,
]
for fn in fns:
results.append(
benchmark.Timer(
stmt="fn(q, k, v, key_padding_mask, attn_mask, static_kv)",
globals={
"q": q,
"k": k,
"v": v,
"key_padding_mask": key_padding_mask,
"attn_mask": attn_mask,
"static_kv": static_kv,
"fn": fn,
},
label="multihead fw + bw",
sub_label=f"{fn.__name__}",
description=label,
).blocked_autorange(min_run_time=1)
)
compare = benchmark.Compare(results)
compare.print()
def run_benchmarks():
for attn_dtype, key_padding_dtype, add_bias_kv, add_zero_attn in itertools.product(
ATTN_MASK_DTYPE, KEY_PADDING_MASK_DTYPE, [True, False], [True, False]
):
label = f"attn_dtype {attn_dtype}, key_padding_dtype {key_padding_dtype}, \
add_bias_kv {add_bias_kv}, add_zero_attn {add_zero_attn}"
benchmark_multihead_attention(
label=label,
attn_dtype=attn_dtype,
key_padding_dtype=key_padding_dtype,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
run_benchmarks()

View File

@@ -49,6 +49,13 @@ class EncDecBaseConfig(FairseqDataclass):
default=None, metadata={"help": "which layers to *keep* when pruning"}
)
xformers_att_config: Optional[str] = field(
default=None,
metadata={
"help": "config for xFormers attention, defined in xformers.components.attention.AttentionConfig"
},
)
@dataclass
class DecoderConfig(EncDecBaseConfig):

View File

@@ -7,6 +7,8 @@
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import II
from fairseq import options, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
@@ -21,8 +23,6 @@ from fairseq.models.transformer import (
)
from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
from fairseq.utils import safe_getattr, safe_hasattr
from omegaconf import II
DEFAULT_MAX_TARGET_POSITIONS = 1024
@@ -210,6 +210,15 @@ class TransformerLanguageModelConfig(FairseqDataclass):
default=False,
metadata={"help": "Learn a scale coefficient for each residual connection"},
)
# xFormers arguments
decoder_xformers_att_config: Optional[str] = field(
default=None,
metadata={
"help": "config for xFormers library attention, defined in xformers.components.attention.AttentionConfig",
},
)
# options from other parts of the config
add_bos_token: bool = II("task.add_bos_token")
tokens_per_sample: int = II("task.tokens_per_sample")

View File

@@ -10,6 +10,8 @@ import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter
from xformers.components.attention import build_attention
from xformers.components.attention.utils import maybe_merge_masks
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
@@ -17,6 +19,40 @@ from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
# TODO: move this into xformers?
# TODO: uint8 input type should just output a bool
def _mask_for_xformers(mask: Tensor, to_dtype: Optional[torch.dtype] = None):
"""
call to pytorch multihead accepts three mask types:
- ByteTensor where non-zero means to mask
- FloatTensor which is an additive mask
- BoolTensor where True means to mask
xFormers currently accepts boolean and additive maks. For boolean masks
the values have opposite meaning. For a BoolTensor True mean to keep the value.
"""
float_types = [torch.float, torch.float16]
# If an input mask is a float it is an additive mask. Otherwise it is either uint8 or bool.
additive = mask.dtype in float_types
# If to_dtype is not specified, keep the same dtype as the mask.
to_dtype = mask.dtype if to_dtype is None else to_dtype
to_additive = to_dtype in float_types
if additive:
if to_additive:
return mask.to(to_dtype)
mask = mask < 0
if to_additive:
# return additive mask
new_mask = torch.zeros_like(mask, dtype=to_dtype)
new_mask = new_mask.masked_fill_(mask, -float("inf"))
return new_mask
# In xFormers True is value to keep rather than value to mask
mask = ~mask.to(torch.bool)
mask = mask.to(to_dtype)
return mask
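# Worked example (illustration only, not part of this commit): for a PyTorch-style
# boolean mask where True means "mask out",
#   m = torch.tensor([False, True])
#   _mask_for_xformers(m)                        -> tensor([True, False])  (True now means "keep")
#   _mask_for_xformers(m, to_dtype=torch.float)  -> tensor([0., -inf])     (additive form)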
@with_incremental_state
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
@@ -38,8 +74,20 @@ class MultiheadAttention(nn.Module):
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
# TODO: pass in config rather than string.
# config defined in xformers.components.attention.AttentionConfig
xformers_att_config: Optional[str] = None,
xformers_blocksparse_layout: Optional[
torch.Tensor
] = None, # This should be part of the config
xformers_blocksparse_blocksize: Optional[
int
] = 16, # This should be part of the config
):
super().__init__()
xformers_att_config = utils.eval_str_dict(xformers_att_config)
self.use_xformers = xformers_att_config is not None
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
@@ -87,6 +135,23 @@ class MultiheadAttention(nn.Module):
self.beam_size = 1
self.reset_parameters()
self.fp16_mask = False
if self.use_xformers:
xformers_att_config["dropout"] = xformers_att_config.get("dropout", dropout)
xformers_att_config["num_heads"] = xformers_att_config.get(
"num_heads", num_heads
)
if xformers_blocksparse_layout is not None:
# Could be part of a single config passed only once
xformers_att_config["block_size"] = xformers_blocksparse_blocksize
xformers_att_config["layout"] = xformers_blocksparse_layout
xformers_att_config["name"] = "blocksparse"
# Mask required to be float16
self.fp16_mask = True
self.attention = build_attention(xformers_att_config)
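# Usage sketch (illustration only, based on the unit tests added in this PR):
# for a 64-token sequence with 16x16 blocks, a dense block layout can be built as
#   layout = torch.ones(64 // 16, 64 // 16)
# and passed as xformers_blocksparse_layout=layout together with
# xformers_att_config='{"name": "scaled_dot_product"}'; the constructor above then
# switches the attention name to "blocksparse" and forces float16 masks.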
self.onnx_trace = False
self.skip_embed_dim_check = False
@@ -296,6 +361,102 @@
)
return k, v, key_padding_mask, attn_mask
def _xformers_attn_forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
tgt_len, bsz, embed_dim = query.size()
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == tgt_len
if self.self_attention:
key = query
value = query
elif self.encoder_decoder_attention:
value = key
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
if self.bias_k is not None:
assert self.bias_v is not None
k, v, attn_mask, key_padding_mask = self._add_bias(
k, v, attn_mask, key_padding_mask, bsz
)
def fold_heads(x):
return (
x.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
def split_heads(x):
return (
x.contiguous()
.view(-1, bsz, self.num_heads, self.head_dim)
.transpose(0, 1)
.transpose(1, 2)
)
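# Shape note: with inputs of shape (tgt_len, bsz, embed_dim) and head_dim = embed_dim // num_heads,
# fold_heads yields (bsz * num_heads, tgt_len, head_dim) and
# split_heads yields (bsz, num_heads, tgt_len, head_dim).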
massage = split_heads if self.attention.requires_head_dimension else fold_heads
q = massage(q)
if k is not None:
k = massage(k)
if v is not None:
v = massage(v)
if self.add_zero_attn:
k, v, key_padding_mask, attn_mask = self._append_zero_attn(
k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask
)
if attn_mask is not None:
to_dtype = torch.float16 if self.fp16_mask else q.dtype
attn_mask = _mask_for_xformers(attn_mask, to_dtype=to_dtype)
if key_padding_mask is not None:
to_dtype = torch.float16 if self.fp16_mask else torch.bool
key_padding_mask = _mask_for_xformers(key_padding_mask, to_dtype=to_dtype)
if not self.attention.requires_separate_masks:
attn_mask = maybe_merge_masks(
attn_mask,
key_padding_mask,
batch_size=bsz,
src_len=k.size(-2),
tgt_len=q.size(-2),
num_heads=self.num_heads,
)
key_padding_mask = None
y = self.attention(
q, k, v, att_mask=attn_mask, key_padding_mask=key_padding_mask
)
y = (
y.view(bsz, self.num_heads, tgt_len, self.head_dim)
.transpose(1, 2)
.flatten(start_dim=2, end_dim=3)
.transpose(0, 1)
)
assert list(y.size()) == [tgt_len, bsz, embed_dim]
# Dropout not needed because already applied in attention.
# It is applied to the attention weights before matmul with v.
y = self.out_proj(y)
# TODO: support returning attention weights if needed.
return y, None
def forward(
self,
query,
@@ -359,29 +520,36 @@
and not self.skip_embed_dim_check
):
assert key is not None and value is not None
if self.use_xformers:
return self._xformers_attn_forward(
query, key, value, key_padding_mask, need_weights, attn_mask
)
else:
return F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training or self.dropout_module.apply_during_inference,
key_padding_mask,
need_weights,
attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)

View File

@@ -141,6 +141,7 @@ class TransformerEncoderLayerBase(nn.Module):
self_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
xformers_att_config=cfg.encoder.xformers_att_config,
)
def residual_connection(self, x, residual):
@@ -359,6 +360,7 @@
self_attention=not cfg.cross_self_attention,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
xformers_att_config=cfg.decoder.xformers_att_config,
)
def build_encoder_attention(self, embed_dim, cfg):
@@ -371,6 +373,7 @@
encoder_decoder_attention=True,
q_noise=self.quant_noise,
qn_block_size=self.quant_noise_block_size,
xformers_att_config=cfg.encoder.xformers_att_config,
)
def prepare_for_onnx_export_(self):

View File

@@ -173,7 +173,6 @@ else:
if "clean" in sys.argv[1:]:
# Source: https://bit.ly/2NLVsgE
print("deleting Cython files...")
import subprocess
subprocess.run(
["rm -f fairseq/*.so fairseq/**/*.so fairseq/*.pyd fairseq/**/*.pyd"],

View File

@@ -3,11 +3,343 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random
import unittest
import pytest
import torch
from fairseq.modules.multihead_attention import MultiheadAttention, _mask_for_xformers
BATCH = [20, 41, 97]
SEQ = [64]
EMB = [48]
HEADS = [4]
DROP = 0.1
DEVICE = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
ATTN_MASK_DTYPE = [None, torch.uint8, torch.bool, torch.float]
KEY_PADDING_MASK_DTYPE = [None, torch.uint8, torch.bool]
# FIXME: some tests fail when decimal=2, fix this and set decimal to 2
def assert_almost_equal(x, y, decimal=1, err_msg=""):
import numpy.testing as npt
if isinstance(x, torch.Tensor):
x = x.cpu().detach().numpy()
if isinstance(y, torch.Tensor):
y = y.cpu().detach().numpy()
npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)
def _reset_seeds():
torch.manual_seed(0)
torch.random.manual_seed(0)
random.seed(0)
torch.cuda.manual_seed_all(0)
def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
if to_dtype == torch.float:
mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)
def test_mask_for_xformers():
# Additive Mask
m_float_add = torch.tensor([float("-inf"), 0]).to(torch.float)
m_float_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float)
m_float16_add = torch.tensor([float("-inf"), 0]).to(torch.float16)
m_float16_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float16)
m_uint = torch.tensor([1, 0]).to(torch.uint8)
m_uint_flipped = torch.tensor([0, 1]).to(torch.uint8)
m_bool = torch.tensor([False, True])
assert torch.equal(_mask_for_xformers(m_float_add), m_float_add)
assert torch.equal(_mask_for_xformers(m_float16_add), m_float16_add)
assert torch.equal(_mask_for_xformers(m_uint), m_uint_flipped)
assert torch.equal(_mask_for_xformers(m_bool), ~m_bool)
assert torch.equal(
_mask_for_xformers(m_float_add, to_dtype=torch.float16), m_float16_add
)
assert torch.equal(
_mask_for_xformers(m_float_add, to_dtype=torch.float), m_float_add
)
assert torch.equal(_mask_for_xformers(m_float_add, to_dtype=torch.bool), m_bool)
assert torch.equal(
_mask_for_xformers(m_float_add, to_dtype=torch.uint8), m_uint_flipped
)
assert torch.equal(
_mask_for_xformers(m_float16_add, to_dtype=torch.float16), m_float16_add
)
assert torch.equal(
_mask_for_xformers(m_float16_add, to_dtype=torch.float), m_float_add
)
assert torch.equal(_mask_for_xformers(m_float16_add, to_dtype=torch.bool), m_bool)
assert torch.equal(
_mask_for_xformers(m_float16_add, to_dtype=torch.uint8), m_uint_flipped
)
assert torch.equal(
_mask_for_xformers(m_bool, to_dtype=torch.float16), m_float16_add_flipped
)
assert torch.equal(
_mask_for_xformers(m_bool, to_dtype=torch.float), m_float_add_flipped
)
assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.bool), ~m_bool)
assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.uint8), m_uint)
assert torch.equal(
_mask_for_xformers(m_uint, to_dtype=torch.float16), m_float16_add
)
assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.float), m_float_add)
assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.bool), m_bool)
assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.uint8), m_uint_flipped)
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("attn_dtype", ATTN_MASK_DTYPE)
@pytest.mark.parametrize("key_padding_dtype", KEY_PADDING_MASK_DTYPE)
@pytest.mark.parametrize("add_zero_attn", [False])
@pytest.mark.parametrize("batch_size", [20])
@pytest.mark.parametrize("embedding", [64])
@pytest.mark.parametrize("seq_len", [64])
@pytest.mark.parametrize("num_heads", [4])
def test_xformers_blocksparse_parity(
device,
attn_dtype,
key_padding_dtype,
add_zero_attn,
batch_size,
embedding,
seq_len,
num_heads,
):
xformers_att_config = '{"name": "scaled_dot_product"}'
xformers_blocksparse_blocksize = 16
xformers_blocksparse_layout = torch.ones(
seq_len // xformers_blocksparse_blocksize,
seq_len // xformers_blocksparse_blocksize,
)
attn_mask = (
None
if attn_dtype is None
else _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len).to(device)
)
key_padding_mask = (
None
if key_padding_dtype is None
else _get_mask(to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len).to(
device
)
)
q = torch.rand(seq_len, batch_size, embedding).to(device).half()
q.requires_grad = True
k = torch.rand(seq_len, batch_size, embedding).to(device).half()
k.requires_grad = True
v = torch.rand(seq_len, batch_size, embedding).to(device).half()
v.requires_grad = True
q_ = q.detach().clone().half()
q_.requires_grad = True
k_ = k.detach().clone().half()
k_.requires_grad = True
v_ = v.detach().clone().half()
v_.requires_grad = True
_reset_seeds()
xf_blocksparse_mha = (
MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
add_zero_attn=add_zero_attn,
xformers_att_config=xformers_att_config,
xformers_blocksparse_layout=xformers_blocksparse_layout,
xformers_blocksparse_blocksize=xformers_blocksparse_blocksize,
)
.to(device)
.half()
)
xf_blocksparse_output, _ = xf_blocksparse_mha(
q,
k,
v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
)
_reset_seeds()
xformers_mha = (
MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
add_zero_attn=add_zero_attn,
xformers_att_config=xformers_att_config,
xformers_blocksparse_layout=None,
)
.to(device)
.half()
)
xformers_output, _ = xformers_mha(
q_,
k_,
v_,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
)
# account for when nan != nan
rand = random.uniform(0, 1)
xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
xf_blocksparse_output = xf_blocksparse_output.masked_fill(
xf_blocksparse_output.isnan(), rand
)
assert_almost_equal(xformers_output, xf_blocksparse_output)
loss_blocksparse = torch.norm(xformers_output)
loss_original = torch.norm(xf_blocksparse_output)
loss_blocksparse.backward()
loss_original.backward()
q.masked_fill(q.isnan(), rand)
q_.masked_fill(q_.isnan(), rand)
k.masked_fill(k.isnan(), rand)
k_.masked_fill(k_.isnan(), rand)
v.masked_fill(v.isnan(), rand)
v_.masked_fill(v_.isnan(), rand)
assert_almost_equal(q.grad, q_.grad)
assert_almost_equal(k.grad, k_.grad)
assert_almost_equal(v.grad, v_.grad)
@pytest.mark.parametrize("device", DEVICE)
@pytest.mark.parametrize("attn_dtype", ATTN_MASK_DTYPE)
@pytest.mark.parametrize("key_padding_dtype", KEY_PADDING_MASK_DTYPE)
@pytest.mark.parametrize("add_bias_kv", [True, False])
@pytest.mark.parametrize("add_zero_attn", [True, False])
# TODO: test with static_kv True
@pytest.mark.parametrize("static_kv", [False])
@pytest.mark.parametrize("batch_size", BATCH)
@pytest.mark.parametrize("embedding", EMB)
@pytest.mark.parametrize("seq_len", SEQ)
@pytest.mark.parametrize("num_heads", HEADS)
def test_xformers_single_forward_parity(
device,
attn_dtype,
key_padding_dtype,
add_bias_kv,
add_zero_attn,
static_kv,
batch_size,
embedding,
seq_len,
num_heads,
):
xformers_att_config = '{"name": "scaled_dot_product"}'
attn_mask = (
None
if attn_dtype is None
else _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len).to(device)
)
key_padding_mask = (
None
if key_padding_dtype is None
else _get_mask(to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len).to(
device
)
)
q = torch.rand(seq_len, batch_size, embedding).to(device)
q.requires_grad = True
k = torch.rand(seq_len, batch_size, embedding).to(device)
k.requires_grad = True
v = torch.rand(seq_len, batch_size, embedding).to(device)
v.requires_grad = True
q_ = q.detach().clone()
q_.requires_grad = True
k_ = k.detach().clone()
k_.requires_grad = True
v_ = v.detach().clone()
v_.requires_grad = True
# TODO: dropouts in the two implementations lead to different entries dropped.
_reset_seeds()
xformers_mha = MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
xformers_att_config=xformers_att_config,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
).to(device)
xformers_output, _ = xformers_mha(
q,
k,
v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
_reset_seeds()
original_mha = MultiheadAttention(
embedding,
num_heads,
dropout=0.0,
xformers_att_config=None,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
).to(device)
original_output, _ = original_mha(
q_,
k_,
v_,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
static_kv=static_kv,
)
# account for when nan != nan
if xformers_output.isnan().any() or original_output.isnan().any():
rand = random.uniform(0, 1)
xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
original_output = original_output.masked_fill(original_output.isnan(), rand)
# torch.equal works for cpu, on cuda allclose is needed.
assert torch.allclose(
xformers_output, original_output, atol=1e-06
), f"max diff is {torch.max(torch.abs(xformers_output - original_output))}"
loss_xformers = torch.norm(xformers_output)
loss_original = torch.norm(original_output)
loss_xformers.backward()
loss_original.backward()
# torch.equal works for cpu, on cuda allclose is needed.
assert torch.allclose(
q.grad, q_.grad
), f"max diff is {torch.max(torch.abs(q.grad - q_.grad))}"
assert torch.allclose(
k.grad, k_.grad
), f"max diff is {torch.max(torch.abs(k.grad - k_.grad))}"
assert torch.allclose(
v.grad, v_.grad
), f"max diff is {torch.max(torch.abs(v.grad - v_.grad))}"
def test_mask_padding_parity():
@@ -28,7 +360,7 @@ def test_mask_padding_parity():
# values don't matter for this test.
mha = MultiheadAttention(
embed_dim=8,
num_heads=2,
dropout=0.0,
add_bias_kv=True,
@@ -50,7 +382,7 @@
def test_add_bias_parity():
# values don't matter for this test.
mha = MultiheadAttention(
embed_dim=8,
num_heads=2,
dropout=0.0,
add_bias_kv=True,