fix formatting (#3350)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes # (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

X-link: https://github.com/fairinternal/fairseq-py/pull/3350

Reviewed By: shruti-bh

Differential Revision: D36009526

Pulled By: dianaml0

fbshipit-source-id: 9cdc3d53086b8d40a780bcb64cfe28108091ab98
Authored by Diana Liskovich on 2022-04-28 14:17:09 -07:00, committed by Facebook GitHub Bot
Parent: ab98e94046
Commit: 0b54d9fb2e
2 changed files with 14 additions and 8 deletions

fairseq/modules/multihead_attention.py

@@ -272,7 +272,9 @@ class MultiheadAttention(nn.Module):
         assert self.bias_v is not None
         k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
         v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
-        key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+        key_padding_mask, attn_mask = self._pad_masks(
+            key_padding_mask=key_padding_mask, attn_mask=attn_mask
+        )
         return k, v, key_padding_mask, attn_mask
 
     def _append_zero_attn(
@@ -289,7 +291,9 @@ class MultiheadAttention(nn.Module):
         v = torch.cat(
             [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2
         )
-        key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+        key_padding_mask, attn_mask = self._pad_masks(
+            key_padding_mask=key_padding_mask, attn_mask=attn_mask
+        )
         return k, v, key_padding_mask, attn_mask
 
     def forward(
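
Both hunks above only re-wrap the `self._pad_masks(...)` call across lines; behavior is unchanged. For context, here is a minimal sketch of what such a mask-padding step does, inferred from the `old_padding_code` reference implementation in the test file below. The name `pad_masks_sketch` and the exact tensor ops are assumptions for illustration, not the fairseq source:

```python
import torch


def pad_masks_sketch(key_padding_mask, attn_mask):
    """Illustrative only: pad both masks by one column so they cover the
    extra key/value slot appended for bias_kv or zero-attention."""
    if attn_mask is not None:
        # new column of zeros -> the extra position is always attendable
        attn_mask = torch.cat(
            [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
        )
    if key_padding_mask is not None:
        # new column of False/0 -> the extra position is never treated as padding
        key_padding_mask = torch.cat(
            [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)],
            dim=1,
        )
    return key_padding_mask, attn_mask
```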

tests/test_multihead_attention.py

@@ -6,11 +6,11 @@
 import unittest
 import torch
 from fairseq.modules.multihead_attention import MultiheadAttention
 def test_mask_padding_parity():
     def old_padding_code(key_padding_mask, attn_mask):
         if attn_mask is not None:
             attn_mask = torch.cat(
@@ -20,9 +20,7 @@ def test_mask_padding_parity():
             key_padding_mask = torch.cat(
                 [
                     key_padding_mask,
-                    torch.zeros(key_padding_mask.size(0), 1).type_as(
-                        key_padding_mask
-                    ),
+                    torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
                 ],
                 dim=1,
             )
@@ -84,8 +82,12 @@ def test_add_bias_parity():
     k = torch.rand((seq_len, bsz, embedding))
     v = torch.rand((seq_len, bsz, embedding))
 
-    k_orig, v_orig, kp_mask_orig, a_mask_orig = old_bias_code(k, v, key_padding_mask, attn_mask, bsz)
-    k_new, v_new, kp_mask_new, a_mask_new = mha._add_bias(k, v, key_padding_mask, attn_mask, bsz)
+    k_orig, v_orig, kp_mask_orig, a_mask_orig = old_bias_code(
+        k, v, key_padding_mask, attn_mask, bsz
+    )
+    k_new, v_new, kp_mask_new, a_mask_new = mha._add_bias(
+        k, v, key_padding_mask, attn_mask, bsz
+    )
 
     assert torch.equal(k_orig, k_new)
     assert torch.equal(v_orig, v_new)
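
As a quick manual sanity check (a hypothetical snippet, not part of this commit; the sizes below are made up), the re-wrapped `_add_bias` call can be exercised directly: the learned bias adds one key/value position and both masks are padded to match.

```python
import torch
from fairseq.modules.multihead_attention import MultiheadAttention

embed_dim, num_heads, seq_len, bsz = 16, 4, 5, 2  # assumed toy sizes
mha = MultiheadAttention(embed_dim, num_heads, add_bias_kv=True)

k = torch.rand((seq_len, bsz, embed_dim))
v = torch.rand((seq_len, bsz, embed_dim))
key_padding_mask = torch.zeros((bsz, seq_len), dtype=torch.bool)
attn_mask = torch.zeros((seq_len, seq_len))

k_new, v_new, kp_mask, a_mask = mha._add_bias(k, v, key_padding_mask, attn_mask, bsz)
assert k_new.size(0) == seq_len + 1 and v_new.size(0) == seq_len + 1  # one extra position
assert kp_mask.size(1) == seq_len + 1 and a_mask.size(-1) == seq_len + 1  # masks padded to match
```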