mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-26 17:32:57 +03:00
436166a00c
Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes https://github.com/fairinternal/fairseq-py/issues/1538. ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1798 Reviewed By: myleott Differential Revision: D27710902 Pulled By: gwenzek fbshipit-source-id: 2efdf645bb30e4cf6653c48371bfca8df6f94eaf
66 lines
1.9 KiB
Python
66 lines
1.9 KiB
Python
import argparse
import unittest
from typing import Any, Dict, Optional, Sequence

import torch

from fairseq.models import transformer

from tests.test_roberta import FakeTask
def mk_sample(
    tok: Optional[Sequence[int]] = None, batch_size: int = 2
) -> Dict[str, Any]:
    """Build a minimal fairseq-style seq2seq sample batch for testing.

    Args:
        tok: token ids for one sequence; every row of the batch is a copy of
            it. Defaults to a fixed 7-token sequence ending in EOS (2).
        batch_size: number of identical rows stacked into the batch.

    Returns:
        A sample dict with "net_input" (src_tokens, prev_output_tokens,
        src_lengths) and "target" (the source shifted left by one position).
    """
    # Use an explicit None check so a caller-supplied empty sequence is not
    # silently replaced by the default (the old `if not tok` truthiness test
    # conflated "not given" with "empty").
    if tok is None:
        tok = [10, 11, 12, 13, 14, 15, 2]

    batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor(
                [len(tok)] * batch_size, dtype=torch.long, device=batch.device
            ),
        },
        # Standard teacher-forcing target: source shifted left by one token.
        "target": batch[:, 1:],
    }
    return sample
def mk_transformer(**extra_args: Any):
    """Build a deterministic, tiny TransformerModel for tests.

    Keyword arguments override the small default dimensions below before
    ``tiny_architecture`` fills in the remaining defaults.
    """
    # Small fixed dimensions keep the model cheap to build; all dropout is
    # zeroed out so repeated runs are comparable.
    settings: Dict[str, Any] = dict(
        encoder_embed_dim=12,
        encoder_ffn_embed_dim=14,
        decoder_embed_dim=12,
        decoder_ffn_embed_dim=14,
        dropout=0,
        attention_dropout=0,
        activation_dropout=0,
        encoder_layerdrop=0,
    )
    settings.update(extra_args)

    # argparse.Namespace stands in for parsed CLI args; tiny_architecture
    # supplies defaults for everything not set above.
    args = argparse.Namespace(**settings)
    transformer.tiny_architecture(args)

    torch.manual_seed(0)  # deterministic weight initialization
    return transformer.TransformerModel.build_model(args, FakeTask(args))
class TransformerTestCase(unittest.TestCase):
    """Smoke tests: a tiny transformer supports forward and backward passes."""

    def _assert_forward_backward(self, model) -> None:
        # Shared check for both tests: run one forward pass on a canned
        # sample and backprop a scalar loss; any shape mismatch or autograd
        # breakage raises here.
        sample = mk_sample()
        out, _ = model.forward(**sample["net_input"])
        loss = out.sum()
        loss.backward()

    def test_forward_backward(self):
        self._assert_forward_backward(
            mk_transformer(encoder_embed_dim=12, decoder_embed_dim=12)
        )

    def test_different_encoder_decoder_embed_dim(self):
        # Encoder and decoder embedding dims are allowed to differ
        # (regression test for fairseq-py issue #1538).
        self._assert_forward_backward(
            mk_transformer(encoder_embed_dim=12, decoder_embed_dim=16)
        )