fairseq/tests/test_transformer.py
Guillaume Wenzek 436166a00c fix MultiHeadAttention assert (#1798)
Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes https://github.com/fairinternal/fairseq-py/issues/1538.

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding!

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1798

Reviewed By: myleott

Differential Revision: D27710902

Pulled By: gwenzek

fbshipit-source-id: 2efdf645bb30e4cf6653c48371bfca8df6f94eaf
2021-04-14 04:59:59 -07:00


import argparse
import unittest
from typing import Any, Dict, Optional, Sequence

import torch

from fairseq.models import transformer
from tests.test_roberta import FakeTask


def mk_sample(tok: Optional[Sequence[int]] = None, batch_size: int = 2) -> Dict[str, Any]:
    # Build a dummy translation batch: the same token sequence repeated
    # `batch_size` times; the target is the sequence with the first token dropped.
    if not tok:
        tok = [10, 11, 12, 13, 14, 15, 2]
    batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor(
                [len(tok)] * batch_size, dtype=torch.long, device=batch.device
            ),
        },
        "target": batch[:, 1:],
    }
    return sample


def mk_transformer(**extra_args: Any):
    overrides = {
        # Use small, characteristic dimensions.
        "encoder_embed_dim": 12,
        "encoder_ffn_embed_dim": 14,
        "decoder_embed_dim": 12,
        "decoder_ffn_embed_dim": 14,
        # Disable dropout so the tests are deterministic and comparable.
        "dropout": 0,
        "attention_dropout": 0,
        "activation_dropout": 0,
        "encoder_layerdrop": 0,
    }
    overrides.update(extra_args)
    # Override the defaults from the argument parser.
    args = argparse.Namespace(**overrides)
    transformer.tiny_architecture(args)
    torch.manual_seed(0)
    task = FakeTask(args)
    return transformer.TransformerModel.build_model(args, task)


class TransformerTestCase(unittest.TestCase):
    def test_forward_backward(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=12)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        loss.backward()

    def test_different_encoder_decoder_embed_dim(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=16)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        loss.backward()
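

# Optional convenience entry point (an assumption, not required by fairseq's
# test runner): lets this file be run directly with
# `python tests/test_transformer.py` in addition to `python -m unittest`.
if __name__ == "__main__":
    unittest.main()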