fix broken build and docs (#3362)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
- [x] formatting fix
- [x] optional import of xFormers (see the sketch below)
- [x] enabled doc building as part of CI
- [x] remove mask arguments for attentions that do not support them
- [x] remove masks for blocksparse tests, no longer supported
- [ ] use pytest instead of deprecated `setup.py test`
- [ ] CircleCI xFormers tests

Will submit without the last two done, to unblock people using the repo.
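
For context, the optional xFormers import adopted here (shown in full in the diff below) follows the usual soft-dependency pattern: the import is wrapped in try/except, a module-level flag records availability, and an error is raised only when an xFormers attention config is actually requested. A minimal, self-contained sketch of the pattern (the class name below is a hypothetical stand-in, not the fairseq module):

```python
# Soft-dependency sketch: importing this module never fails when xFormers is
# absent; only constructing with an xFormers config does.
try:
    from xformers.components.attention import build_attention  # noqa: F401

    _xformers_available = True
except ImportError:
    _xformers_available = False


class AttentionStub:  # hypothetical stand-in for MultiheadAttention
    def __init__(self, embed_dim, num_heads, xformers_att_config=None):
        self.use_xformers = xformers_att_config is not None
        if self.use_xformers and not _xformers_available:
            raise ImportError("Please install xFormers to use xformers_att_config.")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
```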

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

X-link: https://github.com/fairinternal/fairseq-py/pull/3362

Reviewed By: blefaudeux

Differential Revision: D36169572

Pulled By: dianaml0

fbshipit-source-id: 3b20ae5f377144a0854e016771af703f0d0d694b
Authored by dianaml0 on 2022-05-05 15:18:53 -07:00, committed by Facebook GitHub Bot
parent 51478ad3a1
commit e71c4d04d7
4 changed files with 24 additions and 32 deletions

View File

@@ -56,7 +56,6 @@ install_dep_xformers: &install_dep_xformers
pip install -r requirements.txt
pip install -e .
install_dep_pt19: &install_dep_pt19
- run:
name: Install Pytorch Dependencies
@@ -120,6 +119,7 @@ create_conda_env: &create_conda_env
# -------------------------------------------------------------------------------------
jobs:
gpu_tests_pt19:
<<: *gpu
@@ -134,7 +134,6 @@ jobs:
- <<: *install_dep_pt19
- <<: *install_dep_common
- <<: *install_dep_fused_ops
- <<: *install_dep_xformers
- save_cache:
paths:
- ~/miniconda/
@@ -156,7 +155,6 @@ jobs:
- <<: *install_dep_pt18
- <<: *install_dep_common
- <<: *install_dep_fused_ops
- <<: *install_dep_xformers
- save_cache:
paths:
- ~/miniconda/

View File

@@ -54,7 +54,7 @@ jobs:
- name: Run tests
run: |
python setup.py test
python setup.py test
- name: Lint with black
run: |

View File

@@ -10,8 +10,14 @@ import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter
from xformers.components.attention import build_attention
from xformers.components.attention.utils import maybe_merge_masks
try:
    from xformers.components.attention import build_attention
    from xformers.components.attention.utils import maybe_merge_masks

    _xformers_available = True
except ImportError:
    _xformers_available = False
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
@@ -53,6 +59,7 @@ def _mask_for_xformers(mask: Tensor, to_dtype: Optional[torch.dtype] = None):
    mask = mask.to(to_dtype)
    return mask
@with_incremental_state
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
@@ -88,6 +95,8 @@ class MultiheadAttention(nn.Module):
        xformers_att_config = utils.eval_str_dict(xformers_att_config)
        self.use_xformers = xformers_att_config is not None
        if self.use_xformers and not _xformers_available:
            raise ImportError("\n\n Please install xFormers.")
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
@@ -420,9 +429,12 @@ class MultiheadAttention(nn.Module):
k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask
)
if attn_mask is not None:
kwargs = {}
if attn_mask is not None and self.attention.supports_attention_mask:
to_dtype = torch.float16 if self.fp16_mask else q.dtype
attn_mask = _mask_for_xformers(attn_mask, to_dtype=to_dtype)
kwargs["att_mask"] = attn_mask
if key_padding_mask is not None:
to_dtype = torch.float16 if self.fp16_mask else torch.bool
@@ -437,10 +449,11 @@ class MultiheadAttention(nn.Module):
num_heads=self.num_heads,
)
key_padding_mask = None
kwargs["att_mask"] = attn_mask
if self.attention.supports_key_padding_mask:
kwargs["key_padding_mask"] = key_padding_mask
y = self.attention(
q, k, v, att_mask=attn_mask, key_padding_mask=key_padding_mask
)
y = self.attention(q, k, v, **kwargs)
y = (
y.view(bsz, self.num_heads, tgt_len, self.head_dim)
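
The hunk above replaces the unconditional `att_mask=` / `key_padding_mask=` call with a kwargs dict that is only populated when the wrapped xFormers attention supports the corresponding mask (`supports_attention_mask`, `supports_key_padding_mask`). A condensed sketch of that call pattern, with the mask dtype conversion omitted and the helper name being illustrative only:

```python
# Condensed sketch (not the verbatim fairseq code): masks are forwarded only
# to attentions that advertise support for them, instead of always being passed.
def call_attention_with_supported_masks(attention, q, k, v, attn_mask=None, key_padding_mask=None):
    kwargs = {}
    if attn_mask is not None and attention.supports_attention_mask:
        kwargs["att_mask"] = attn_mask
    if key_padding_mask is not None and attention.supports_key_padding_mask:
        kwargs["key_padding_mask"] = key_padding_mask
    # Attentions that accept no masks (e.g. the blocksparse path) just get q, k, v.
    return attention(q, k, v, **kwargs)
```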

View File

@@ -8,6 +8,7 @@ import unittest
import pytest
import torch
from fairseq.modules.multihead_attention import MultiheadAttention, _mask_for_xformers
BATCH = [20, 41, 97]
@@ -99,9 +100,8 @@ def test_mask_for_xformers():
assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.uint8), m_uint_flipped)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="blocksparse requires gpu")
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("attn_dtype", ATTN_MASK_DTYPE)
@pytest.mark.parametrize("key_padding_dtype", KEY_PADDING_MASK_DTYPE)
@pytest.mark.parametrize("add_zero_attn", [False])
@pytest.mark.parametrize("batch_size", [20])
@pytest.mark.parametrize("embedding", [64])
@@ -109,8 +109,6 @@ def test_mask_for_xformers():
@pytest.mark.parametrize("num_heads", [4])
def test_xformers_blocksparse_parity(
device,
attn_dtype,
key_padding_dtype,
add_zero_attn,
batch_size,
embedding,
@@ -123,20 +121,7 @@ def test_xformers_blocksparse_parity(
xformers_blocksparse_layout = torch.ones(
seq_len // xformers_blocksparse_blocksize,
seq_len // xformers_blocksparse_blocksize,
)
attn_mask = (
None
if attn_dtype is None
else _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len).to(device)
)
key_padding_mask = (
None
if key_padding_dtype is None
else _get_mask(to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len).to(
device
)
dtype=torch.int32,
)
q = torch.rand(seq_len, batch_size, embedding).to(device).half()
@@ -172,8 +157,6 @@ test_xformers_blocksparse_parity(
q,
k,
v,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
)
_reset_seeds()
@@ -194,8 +177,6 @@ test_xformers_blocksparse_parity(
q_,
k_,
v_,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
)
# # account for when nan != nan
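
As the trimmed test above shows, the blocksparse parity check now exercises the attention without any mask arguments. A rough, self-contained illustration of that call pattern (using torch.nn.MultiheadAttention only as a placeholder module, not the fairseq blocksparse path):

```python
import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="blocksparse requires gpu")
@pytest.mark.parametrize("batch_size", [20])
@pytest.mark.parametrize("embedding", [64])
@pytest.mark.parametrize("seq_len", [64])
@pytest.mark.parametrize("num_heads", [4])
def test_attention_called_without_masks(batch_size, embedding, seq_len, num_heads):
    # Placeholder attention module; the real test builds a fairseq
    # MultiheadAttention with an xFormers blocksparse config.
    attn = torch.nn.MultiheadAttention(embedding, num_heads).to("cuda").half()
    q = torch.rand(seq_len, batch_size, embedding, device="cuda").half()
    k = torch.rand(seq_len, batch_size, embedding, device="cuda").half()
    v = torch.rand(seq_len, batch_size, embedding, device="cuda").half()
    # No attn_mask / key_padding_mask arguments -- the point of the change.
    out, _ = attn(q, k, v)
    assert out.shape == (seq_len, batch_size, embedding)
```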