From 51478ad3a19feed51d4bc4df5416870b7cee5347 Mon Sep 17 00:00:00 2001 From: dianaml0 <82468439+dianaml0@users.noreply.github.com> Date: Wed, 4 May 2022 09:15:36 -0700 Subject: [PATCH] xformer integration (#2263) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [x] Did you write any new necessary tests? ## What does this PR do? This PR is a cleaned up version of https://github.com/fairinternal/fairseq-py/issues/2138. It is based on the `main` branch instead of the `gshard` branch. Removed call to xFormers MultiHeadDispatch, only using xFormers Attention. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding � X-link: https://github.com/fairinternal/fairseq-py/pull/2263 Reviewed By: blefaudeux Differential Revision: D33800377 Pulled By: dianaml0 fbshipit-source-id: 658d52214c782212b12881b30c4d908a763b4cf2 --- .circleci/config.yml | 13 + .github/workflows/build.yml | 2 + README.md | 1 + .../benchmark_multihead_attention.py | 172 +++++++++ .../models/transformer/transformer_config.py | 7 + fairseq/models/transformer_lm.py | 13 +- fairseq/modules/multihead_attention.py | 214 +++++++++-- fairseq/modules/transformer_layer.py | 3 + setup.py | 1 - tests/test_multihead_attention.py | 338 +++++++++++++++++- 10 files changed, 735 insertions(+), 29 deletions(-) create mode 100644 fairseq/benchmark/benchmark_multihead_attention.py diff --git a/.circleci/config.yml b/.circleci/config.yml index de40a6e9c..c29b534ff 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -45,6 +45,17 @@ install_dep_fused_ops: &install_dep_fused_ops cd Megatron-LM pip install -e . +install_dep_xformers: &install_dep_xformers + - run: + name: Install xFormers Dependencies + working_directory: ~/ + command: | + source activate fairseq + git clone https://github.com/facebookresearch/xformers.git + cd xformers + pip install -r requirements.txt + pip install -e . + install_dep_pt19: &install_dep_pt19 - run: @@ -123,6 +134,7 @@ jobs: - <<: *install_dep_pt19 - <<: *install_dep_common - <<: *install_dep_fused_ops + - <<: *install_dep_xformers - save_cache: paths: - ~/miniconda/ @@ -144,6 +156,7 @@ jobs: - <<: *install_dep_pt18 - <<: *install_dep_common - <<: *install_dep_fused_ops + - <<: *install_dep_xformers - save_cache: paths: - ~/miniconda/ diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9abd0690f..16b42974b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -41,6 +41,8 @@ jobs: run: | python -m pip install iopath transformers pyarrow python -m pip install git+https://github.com/facebookresearch/fairscale.git@main + python -m pip install --progress-bar off git+https://github.com/facebookresearch/xformers.git@main + python -m pip install pytest - name: Lint with flake8 run: | diff --git a/README.md b/README.md index b4a848ecf..a354e1b9e 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ We provide reference implementations of various sequence modeling papers:

### What's New: +* May 2022 [Integration with xFormers](https://github.com/facebookresearch/xformers) * December 2021 [Released Direct speech-to-speech translation code](examples/speech_to_speech/README.md) * October 2021 [Released VideoCLIP and VLM models](examples/MMPT/README.md) * October 2021 [Released multilingual finetuned XLSR-53 model](examples/wav2vec/README.md) diff --git a/fairseq/benchmark/benchmark_multihead_attention.py b/fairseq/benchmark/benchmark_multihead_attention.py new file mode 100644 index 000000000..a44847f25 --- /dev/null +++ b/fairseq/benchmark/benchmark_multihead_attention.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import random + +import torch +from torch.utils import benchmark + +from fairseq.modules.multihead_attention import MultiheadAttention + +BATCH = [20, 41, 97] +SEQ = 64 +EMB = 48 +HEADS = 4 +DROP = 0.1 +DEVICE = torch.device("cuda") +ATTN_MASK_DTYPE = [torch.uint8, torch.bool, torch.float] +KEY_PADDING_MASK_DTYPE = [torch.uint8, torch.bool] + + +def _reset_seeds(): + torch.manual_seed(0) + random.seed(0) + + +def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int): + if to_dtype == torch.float: + mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool) + return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf")) + return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype) + + +def benchmark_multihead_attention( + label="", + attn_dtype=torch.uint8, + key_padding_dtype=torch.uint8, + add_bias_kv=False, + add_zero_attn=False, + static_kv=False, + batch_size=20, + embedding=EMB, + seq_len=SEQ, + num_heads=HEADS, +): + + results = [] + # device = torch.device("cuda") + + xformers_att_config = '{"name": "scaled_dot_product"}' + + attn_mask = _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len) + key_padding_mask = _get_mask( + to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len + ) + + q = torch.rand(seq_len, batch_size, embedding, requires_grad=True) + k = torch.rand(seq_len, batch_size, embedding, requires_grad=True) + v = torch.rand(seq_len, batch_size, embedding, requires_grad=True) + + _reset_seeds() + + original_mha = MultiheadAttention( + embedding, + num_heads, + dropout=0.0, + xformers_att_config=None, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + xformers_mha = MultiheadAttention( + embedding, + num_heads, + dropout=0.0, + xformers_att_config=xformers_att_config, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + def original_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv): + original_mha( + query=q, + key=k, + value=v, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + static_kv=static_kv, + ) + + def xformers_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv): + xformers_mha( + query=q, + key=k, + value=v, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + static_kv=static_kv, + ) + + def original_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv): + output, _ = original_mha( + query=q, + key=k, + value=v, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + static_kv=static_kv, + ) + loss = torch.norm(output) + loss.backward() + + def xformers_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv): + output, _ = xformers_mha( + query=q, + key=k, + value=v, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + static_kv=static_kv, + ) + 
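+        # torch.norm(output) provides a scalar stand-in loss so that loss.backward()
+        # is included in the fw + bw timings for both attention implementations.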
loss = torch.norm(output) + loss.backward() + + fns = [ + original_bench_fw, + xformers_bench_fw, + original_bench_fw_bw, + xformers_bench_fw_bw, + ] + + for fn in fns: + results.append( + benchmark.Timer( + stmt="fn(q, k, v, key_padding_mask, attn_mask, static_kv)", + globals={ + "q": q, + "k": k, + "v": v, + "key_padding_mask": key_padding_mask, + "attn_mask": attn_mask, + "static_kv": static_kv, + "fn": fn, + }, + label="multihead fw + bw", + sub_label=f"{fn.__name__}", + description=label, + ).blocked_autorange(min_run_time=1) + ) + + compare = benchmark.Compare(results) + compare.print() + + +def run_benchmarks(): + for attn_dtype, key_padding_dtype, add_bias_kv, add_zero_attn in itertools.product( + ATTN_MASK_DTYPE, KEY_PADDING_MASK_DTYPE, [True, False], [True, False] + ): + label = f"attn_dtype {attn_dtype}, key_padding_dtype {key_padding_dtype}, \ + add_bias_kv {add_bias_kv}, add_zero_attn {add_zero_attn}" + benchmark_multihead_attention( + label=label, + attn_dtype=attn_dtype, + key_padding_dtype=key_padding_dtype, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + +run_benchmarks() diff --git a/fairseq/models/transformer/transformer_config.py b/fairseq/models/transformer/transformer_config.py index 4ebd292b0..119b030b0 100644 --- a/fairseq/models/transformer/transformer_config.py +++ b/fairseq/models/transformer/transformer_config.py @@ -49,6 +49,13 @@ class EncDecBaseConfig(FairseqDataclass): default=None, metadata={"help": "which layers to *keep* when pruning"} ) + xformers_att_config: Optional[str] = field( + default=None, + metadata={ + "help": "config for xFormers attention, defined in xformers.components.attention.AttentionConfig" + }, + ) + @dataclass class DecoderConfig(EncDecBaseConfig): diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index 5a23888b1..1e3aa72d3 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -7,6 +7,8 @@ from dataclasses import dataclass, field from typing import Optional +from omegaconf import II + from fairseq import options, utils from fairseq.dataclass import ChoiceEnum, FairseqDataclass from fairseq.models import ( @@ -21,8 +23,6 @@ from fairseq.models.transformer import ( ) from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder from fairseq.utils import safe_getattr, safe_hasattr -from omegaconf import II - DEFAULT_MAX_TARGET_POSITIONS = 1024 @@ -210,6 +210,15 @@ class TransformerLanguageModelConfig(FairseqDataclass): default=False, metadata={"help": "Learn a scale coefficient for each residual connection"}, ) + + # xFormers arguments + decoder_xformers_att_config: Optional[str] = field( + default=None, + metadata={ + "help": "config for xFormers library attention, defined in xformers.components.attention.AttentionConfig", + }, + ) + # options from other parts of the config add_bos_token: bool = II("task.add_bos_token") tokens_per_sample: int = II("task.tokens_per_sample") diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index fd33ba446..8d08331e4 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -10,6 +10,8 @@ import torch import torch.nn.functional as F from torch import Tensor, nn from torch.nn import Parameter +from xformers.components.attention import build_attention +from xformers.components.attention.utils import maybe_merge_masks from fairseq import utils from fairseq.incremental_decoding_utils import with_incremental_state @@ -17,6 +19,40 @@ from 
fairseq.modules.fairseq_dropout import FairseqDropout from fairseq.modules.quant_noise import quant_noise +# TODO: move this into xformers? +# TODO: uint8 input type should just output a bool +def _mask_for_xformers(mask: Tensor, to_dtype: Optional[torch.dtype] = None): + """ + call to pytorch multihead accepts three mask types: + - ByteTensor where non-zero means to mask + - FloatTensor which is an additive mask + - BoolTensor where True means to mask + xFormers currently accepts boolean and additive maks. For boolean masks + the values have opposite meaning. For a BoolTensor True mean to keep the value. + """ + float_types = [torch.float, torch.float16] + # If an input mask is a float it is an additive mask. Otherwise it is either uint8 or bool. + additive = mask.dtype in float_types + # If to_dype is not specified, keep same dtype as mask. + to_dtype = mask.dtype if to_dtype is None else to_dtype + to_additive = to_dtype in float_types + + if additive: + if to_additive: + return mask.to(to_dtype) + mask = mask < 0 + + if to_additive: + # return additive mask + new_mask = torch.zeros_like(mask, dtype=to_dtype) + new_mask = new_mask.masked_fill_(mask, -float("inf")) + return new_mask + + # In xFormers True is value to keep rather than value to mask + mask = ~mask.to(torch.bool) + mask = mask.to(to_dtype) + return mask + @with_incremental_state class MultiheadAttention(nn.Module): """Multi-headed attention. @@ -38,8 +74,20 @@ class MultiheadAttention(nn.Module): encoder_decoder_attention=False, q_noise=0.0, qn_block_size=8, + # TODO: pass in config rather than string. + # config defined in xformers.components.attention.AttentionConfig + xformers_att_config: Optional[str] = None, + xformers_blocksparse_layout: Optional[ + torch.Tensor + ] = None, # This should be part of the config + xformers_blocksparse_blocksize: Optional[ + int + ] = 16, # This should be part of the config ): super().__init__() + + xformers_att_config = utils.eval_str_dict(xformers_att_config) + self.use_xformers = xformers_att_config is not None self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim @@ -87,6 +135,23 @@ class MultiheadAttention(nn.Module): self.beam_size = 1 self.reset_parameters() + self.fp16_mask = False + if self.use_xformers: + xformers_att_config["dropout"] = xformers_att_config.get("dropout", dropout) + xformers_att_config["num_heads"] = xformers_att_config.get( + "num_heads", num_heads + ) + + if xformers_blocksparse_layout is not None: + # Could be part of a single config passed only once + xformers_att_config["block_size"] = xformers_blocksparse_blocksize + xformers_att_config["layout"] = xformers_blocksparse_layout + xformers_att_config["name"] = "blocksparse" + # Mask required to be float16 + self.fp16_mask = True + + self.attention = build_attention(xformers_att_config) + self.onnx_trace = False self.skip_embed_dim_check = False @@ -296,6 +361,102 @@ class MultiheadAttention(nn.Module): ) return k, v, key_padding_mask, attn_mask + def _xformers_attn_forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + + tgt_len, bsz, embed_dim = query.size() + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == tgt_len + + if self.self_attention: + key = query + value = query + elif 
self.encoder_decoder_attention: + value = key + + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + if self.bias_k is not None: + assert self.bias_v is not None + k, v, attn_mask, key_padding_mask = self._add_bias( + k, v, attn_mask, key_padding_mask, bsz + ) + + def fold_heads(x): + return ( + x.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + def split_heads(x): + return ( + x.contiguous() + .view(-1, bsz, self.num_heads, self.head_dim) + .transpose(0, 1) + .transpose(1, 2) + ) + + massage = split_heads if self.attention.requires_head_dimension else fold_heads + q = massage(q) + if k is not None: + k = massage(k) + if v is not None: + v = massage(v) + + if self.add_zero_attn: + k, v, key_padding_mask, attn_mask = self._append_zero_attn( + k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + + if attn_mask is not None: + to_dtype = torch.float16 if self.fp16_mask else q.dtype + attn_mask = _mask_for_xformers(attn_mask, to_dtype=to_dtype) + + if key_padding_mask is not None: + to_dtype = torch.float16 if self.fp16_mask else torch.bool + key_padding_mask = _mask_for_xformers(key_padding_mask, to_dtype=to_dtype) + if not self.attention.requires_separate_masks: + attn_mask = maybe_merge_masks( + attn_mask, + key_padding_mask, + batch_size=bsz, + src_len=k.size(-2), + tgt_len=q.size(-2), + num_heads=self.num_heads, + ) + key_padding_mask = None + + y = self.attention( + q, k, v, att_mask=attn_mask, key_padding_mask=key_padding_mask + ) + + y = ( + y.view(bsz, self.num_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .flatten(start_dim=2, end_dim=3) + .transpose(0, 1) + ) + assert list(y.size()) == [tgt_len, bsz, embed_dim] + + # Dropout not needed because already applied in attention. + # It is applied to the attention weights before matmul with v. + y = self.out_proj(y) + + # TODO: support returning attention weights if needed. 
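+        # need_weights is accepted but not yet honoured on this path, so callers
+        # always receive None for the attention weights (see the TODO above).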
+ return y, None + def forward( self, query, @@ -359,29 +520,36 @@ class MultiheadAttention(nn.Module): and not self.skip_embed_dim_check ): assert key is not None and value is not None - return F.multi_head_attention_forward( - query, - key, - value, - self.embed_dim, - self.num_heads, - torch.empty([0]), - torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), - self.bias_k, - self.bias_v, - self.add_zero_attn, - self.dropout_module.p, - self.out_proj.weight, - self.out_proj.bias, - self.training or self.dropout_module.apply_during_inference, - key_padding_mask, - need_weights, - attn_mask, - use_separate_proj_weight=True, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - ) + + if self.use_xformers: + return self._xformers_attn_forward( + query, key, value, key_padding_mask, need_weights, attn_mask + ) + + else: + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) if incremental_state is not None: saved_state = self._get_input_buffer(incremental_state) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 50025e045..2e687b94d 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -141,6 +141,7 @@ class TransformerEncoderLayerBase(nn.Module): self_attention=True, q_noise=self.quant_noise, qn_block_size=self.quant_noise_block_size, + xformers_att_config=cfg.encoder.xformers_att_config, ) def residual_connection(self, x, residual): @@ -359,6 +360,7 @@ class TransformerDecoderLayerBase(nn.Module): self_attention=not cfg.cross_self_attention, q_noise=self.quant_noise, qn_block_size=self.quant_noise_block_size, + xformers_att_config=cfg.decoder.xformers_att_config, ) def build_encoder_attention(self, embed_dim, cfg): @@ -371,6 +373,7 @@ class TransformerDecoderLayerBase(nn.Module): encoder_decoder_attention=True, q_noise=self.quant_noise, qn_block_size=self.quant_noise_block_size, + xformers_att_config=cfg.encoder.xformers_att_config, ) def prepare_for_onnx_export_(self): diff --git a/setup.py b/setup.py index c5591915f..e2e44570c 100644 --- a/setup.py +++ b/setup.py @@ -173,7 +173,6 @@ else: if "clean" in sys.argv[1:]: # Source: https://bit.ly/2NLVsgE print("deleting Cython files...") - import subprocess subprocess.run( ["rm -f fairseq/*.so fairseq/**/*.so fairseq/*.pyd fairseq/**/*.pyd"], diff --git a/tests/test_multihead_attention.py b/tests/test_multihead_attention.py index ebed9c903..dd075ea88 100644 --- a/tests/test_multihead_attention.py +++ b/tests/test_multihead_attention.py @@ -3,11 +3,343 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
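+# The tests added below cover the _mask_for_xformers conversions as well as forward
+# and backward parity between the xFormers-backed and the original MultiheadAttention.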
+import random import unittest +import pytest import torch +from fairseq.modules.multihead_attention import MultiheadAttention, _mask_for_xformers -from fairseq.modules.multihead_attention import MultiheadAttention +BATCH = [20, 41, 97] +SEQ = [64] +EMB = [48] +HEADS = [4] +DROP = 0.1 +DEVICE = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] +ATTN_MASK_DTYPE = [None, torch.uint8, torch.bool, torch.float] +KEY_PADDING_MASK_DTYPE = [None, torch.uint8, torch.bool] + + +# FIXME: some tests fail when decimal=2, fix this and set decimal to 2 +def assert_almost_equal(x, y, decimal=1, err_msg=""): + import numpy.testing as npt + + if isinstance(x, torch.Tensor): + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + y = y.cpu().detach().numpy() + npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal) + + +def _reset_seeds(): + torch.manual_seed(0) + torch.random.manual_seed(0) + random.seed(0) + torch.cuda.manual_seed_all(0) + + +def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int): + if to_dtype == torch.float: + mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool) + return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf")) + return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype) + + +def test_mask_for_xformers(): + # Additive Mask + m_float_add = torch.tensor([float("-inf"), 0]).to(torch.float) + m_float_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float) + m_float16_add = torch.tensor([float("-inf"), 0]).to(torch.float16) + m_float16_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float16) + m_uint = torch.tensor([1, 0]).to(torch.uint8) + m_uint_flipped = torch.tensor([0, 1]).to(torch.uint8) + m_bool = torch.tensor([False, True]) + + assert torch.equal(_mask_for_xformers(m_float_add), m_float_add) + assert torch.equal(_mask_for_xformers(m_float16_add), m_float16_add) + assert torch.equal(_mask_for_xformers(m_uint), m_uint_flipped) + assert torch.equal(_mask_for_xformers(m_bool), ~m_bool) + + assert torch.equal( + _mask_for_xformers(m_float_add, to_dtype=torch.float16), m_float16_add + ) + assert torch.equal( + _mask_for_xformers(m_float_add, to_dtype=torch.float), m_float_add + ) + assert torch.equal(_mask_for_xformers(m_float_add, to_dtype=torch.bool), m_bool) + assert torch.equal( + _mask_for_xformers(m_float_add, to_dtype=torch.uint8), m_uint_flipped + ) + + assert torch.equal( + _mask_for_xformers(m_float16_add, to_dtype=torch.float16), m_float16_add + ) + assert torch.equal( + _mask_for_xformers(m_float16_add, to_dtype=torch.float), m_float_add + ) + assert torch.equal(_mask_for_xformers(m_float16_add, to_dtype=torch.bool), m_bool) + assert torch.equal( + _mask_for_xformers(m_float16_add, to_dtype=torch.uint8), m_uint_flipped + ) + + assert torch.equal( + _mask_for_xformers(m_bool, to_dtype=torch.float16), m_float16_add_flipped + ) + assert torch.equal( + _mask_for_xformers(m_bool, to_dtype=torch.float), m_float_add_flipped + ) + assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.bool), ~m_bool) + assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.uint8), m_uint) + + assert torch.equal( + _mask_for_xformers(m_uint, to_dtype=torch.float16), m_float16_add + ) + assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.float), m_float_add) + assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.bool), m_bool) + assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.uint8), m_uint_flipped) + + +@pytest.mark.parametrize("device", ["cuda"]) +@pytest.mark.parametrize("attn_dtype", 
ATTN_MASK_DTYPE) +@pytest.mark.parametrize("key_padding_dtype", KEY_PADDING_MASK_DTYPE) +@pytest.mark.parametrize("add_zero_attn", [False]) +@pytest.mark.parametrize("batch_size", [20]) +@pytest.mark.parametrize("embedding", [64]) +@pytest.mark.parametrize("seq_len", [64]) +@pytest.mark.parametrize("num_heads", [4]) +def test_xformers_blocksparse_parity( + device, + attn_dtype, + key_padding_dtype, + add_zero_attn, + batch_size, + embedding, + seq_len, + num_heads, +): + + xformers_att_config = '{"name": "scaled_dot_product"}' + xformers_blocksparse_blocksize = 16 + xformers_blocksparse_layout = torch.ones( + seq_len // xformers_blocksparse_blocksize, + seq_len // xformers_blocksparse_blocksize, + ) + + attn_mask = ( + None + if attn_dtype is None + else _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len).to(device) + ) + + key_padding_mask = ( + None + if key_padding_dtype is None + else _get_mask(to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len).to( + device + ) + ) + + q = torch.rand(seq_len, batch_size, embedding).to(device).half() + q.requires_grad = True + k = torch.rand(seq_len, batch_size, embedding).to(device).half() + k.requires_grad = True + v = torch.rand(seq_len, batch_size, embedding).to(device).half() + v.requires_grad = True + + q_ = q.detach().clone().half() + q_.requires_grad = True + k_ = k.detach().clone().half() + k_.requires_grad = True + v_ = v.detach().clone().half() + v_.requires_grad = True + + _reset_seeds() + xf_blocksparse_mha = ( + MultiheadAttention( + embedding, + num_heads, + dropout=0.0, + add_zero_attn=add_zero_attn, + xformers_att_config=xformers_att_config, + xformers_blocksparse_layout=xformers_blocksparse_layout, + xformers_blocksparse_blocksize=xformers_blocksparse_blocksize, + ) + .to(device) + .half() + ) + + xf_blocksparse_output, _ = xf_blocksparse_mha( + q, + k, + v, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + + _reset_seeds() + xformers_mha = ( + MultiheadAttention( + embedding, + num_heads, + dropout=0.0, + add_zero_attn=add_zero_attn, + xformers_att_config=xformers_att_config, + xformers_blocksparse_layout=None, + ) + .to(device) + .half() + ) + + xformers_output, _ = xformers_mha( + q_, + k_, + v_, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + + # # account for when nan != nan + rand = random.uniform(0, 1) + xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand) + xf_blocksparse_output = xf_blocksparse_output.masked_fill( + xf_blocksparse_output.isnan(), rand + ) + + assert_almost_equal(xformers_output, xf_blocksparse_output) + + loss_blocksparse = torch.norm(xformers_output) + loss_original = torch.norm(xf_blocksparse_output) + loss_blocksparse.backward() + loss_original.backward() + + q.masked_fill(q.isnan(), rand) + q_.masked_fill(q_.isnan(), rand) + k.masked_fill(k.isnan(), rand) + k_.masked_fill(k_.isnan(), rand) + v.masked_fill(v.isnan(), rand) + v_.masked_fill(v_.isnan(), rand) + + assert_almost_equal(q.grad, q_.grad) + assert_almost_equal(k.grad, k_.grad) + assert_almost_equal(v.grad, v_.grad) + + +@pytest.mark.parametrize("device", DEVICE) +@pytest.mark.parametrize("attn_dtype", ATTN_MASK_DTYPE) +@pytest.mark.parametrize("key_padding_dtype", KEY_PADDING_MASK_DTYPE) +@pytest.mark.parametrize("add_bias_kv", [True, False]) +@pytest.mark.parametrize("add_zero_attn", [True, False]) +# TODO: test with static_kv True +@pytest.mark.parametrize("static_kv", [False]) +@pytest.mark.parametrize("batch_size", BATCH) +@pytest.mark.parametrize("embedding", EMB) 
+@pytest.mark.parametrize("seq_len", SEQ) +@pytest.mark.parametrize("num_heads", HEADS) +def test_xformers_single_forward_parity( + device, + attn_dtype, + key_padding_dtype, + add_bias_kv, + add_zero_attn, + static_kv, + batch_size, + embedding, + seq_len, + num_heads, +): + + xformers_att_config = '{"name": "scaled_dot_product"}' + + attn_mask = ( + None + if attn_dtype is None + else _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len).to(device) + ) + key_padding_mask = ( + None + if key_padding_dtype is None + else _get_mask(to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len).to( + device + ) + ) + + q = torch.rand(seq_len, batch_size, embedding).to(device) + q.requires_grad = True + k = torch.rand(seq_len, batch_size, embedding).to(device) + k.requires_grad = True + v = torch.rand(seq_len, batch_size, embedding).to(device) + v.requires_grad = True + + q_ = q.detach().clone() + q_.requires_grad = True + k_ = k.detach().clone() + k_.requires_grad = True + v_ = v.detach().clone() + v_.requires_grad = True + + # TODO: dropouts in the two implementations lead to different entries dropped. + _reset_seeds() + xformers_mha = MultiheadAttention( + embedding, + num_heads, + dropout=0.0, + xformers_att_config=xformers_att_config, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ).to(device) + xformers_output, _ = xformers_mha( + q, + k, + v, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + static_kv=static_kv, + ) + + _reset_seeds() + original_mha = MultiheadAttention( + embedding, + num_heads, + dropout=0.0, + xformers_att_config=None, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ).to(device) + original_output, _ = original_mha( + q_, + k_, + v_, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + static_kv=static_kv, + ) + + # account for when nan != nan + if xformers_output.isnan().any() or original_output.isnan().any(): + rand = random.uniform(0, 1) + xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand) + original_output = original_output.masked_fill(original_output.isnan(), rand) + + # torch.equal works for cpu, on cuda allclose is needed. + assert torch.allclose( + xformers_output, original_output, atol=1e-06 + ), f"max diff is {torch.max(torch.abs(xformers_output - original_output))}" + + loss_xformers = torch.norm(xformers_output) + loss_original = torch.norm(original_output) + loss_xformers.backward() + loss_original.backward() + + # torch.equal works for cpu, on cuda allclose is needed. + assert torch.allclose( + q.grad, q_.grad + ), f"max diff is {torch.max(torch.abs(q.grad - q_.grad))}" + assert torch.allclose( + k.grad, k_.grad + ), f"max diff is {torch.max(torch.abs(k.grad - k_.grad))}" + assert torch.allclose( + v.grad, v_.grad + ), f"max diff is {torch.max(torch.abs(v.grad - v_.grad))}" def test_mask_padding_parity(): @@ -28,7 +360,7 @@ def test_mask_padding_parity(): # values don't matter for this test. mha = MultiheadAttention( - embedding=8, + embed_dim=8, num_heads=2, dropout=0.0, add_bias_kv=True, @@ -50,7 +382,7 @@ def test_mask_padding_parity(): def test_add_bias_parity(): # values don't matter for this test. mha = MultiheadAttention( - embedding=8, + embed_dim=8, num_heads=2, dropout=0.0, add_bias_kv=True,
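For reviewers who want to try the new hook directly, the sketch below mirrors `test_xformers_single_forward_parity` above: it builds a `MultiheadAttention` with an `xformers_att_config` string and runs a single forward pass. It assumes xFormers is installed (as the CI changes in this patch arrange); the shapes, the `scaled_dot_product` name, and `dropout=0.0` are taken from the tests and benchmark in this diff, so treat it as an illustrative sketch rather than additional documented API.

```python
import torch

from fairseq.modules.multihead_attention import MultiheadAttention

seq_len, batch_size, embed_dim, num_heads = 64, 20, 48, 4

# JSON-style string parsed via utils.eval_str_dict(); passing None keeps the
# original fairseq attention path, so the integration is strictly opt-in.
xformers_att_config = '{"name": "scaled_dot_product"}'

mha = MultiheadAttention(
    embed_dim,
    num_heads,
    dropout=0.0,  # dropout kept off, as in the parity tests above
    xformers_att_config=xformers_att_config,
)

# fairseq tensors use (seq_len, batch, embed_dim) ordering, as in the tests above.
q = torch.rand(seq_len, batch_size, embed_dim)
k = torch.rand(seq_len, batch_size, embed_dim)
v = torch.rand(seq_len, batch_size, embed_dim)

output, attn_weights = mha(query=q, key=k, value=v)
assert output.shape == (seq_len, batch_size, embed_dim)
assert attn_weights is None  # this path does not return attention weights yet
```

The same `xformers_att_config` string can also be supplied through the new `encoder.xformers_att_config` / `decoder.xformers_att_config` fields in the transformer config (or `decoder_xformers_att_config` for `transformer_lm`), which is how the transformer layers in this patch pass it down to `MultiheadAttention`.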