mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-05 13:17:39 +03:00
5f2515e676
Summary: # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)? - [ ] Did you make sure to update the docs? - [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). ## PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? Make sure you had fun coding 🙃 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/3065 Reviewed By: Mortimerp9 Differential Revision: D34144674 Pulled By: dianaml0 fbshipit-source-id: 842b0d29c9c85d4b56b640f2823fcb4e3f912f98
55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
from fairseq.data import MonolingualDataset
|
|
from fairseq.tasks.language_modeling import LanguageModelingConfig, LanguageModelingTask
|
|
from tests import utils as test_utils
|
|
|
|
|
|
class TestLMContextWindow(unittest.TestCase):
    """Tests for ``LanguageModelingTask.eval_lm_dataloader`` with a context window."""

    def test_eval_dataloader(self):
        """Check the src/target tokens of each eval batch when ``context_window=2``.

        The dataset holds three samples ([4, 5, 6, 7], [8, 9, 10, 11], [12, 13])
        and is read with ``tokens_per_sample=4`` and ``batch_size=1``. The
        expected values below show two things about every batch after the first:
        ``src_tokens`` begins with the last two tokens of the previous sample, and
        the matching positions in ``target`` hold the pad index (1). Presumably
        this keeps those context tokens from being scored.

        ``self.assertEqual`` is used instead of bare ``assert`` for two reasons:
        the checks are not stripped under ``python -O``, and failures report both
        the expected and the actual values.
        """
        dictionary = test_utils.dummy_dictionary(10)
        self.assertEqual(len(dictionary), 14)  # 4 extra special symbols
        self.assertEqual(dictionary.pad(), 1)

        dataset = test_utils.TestDataset(
            [
                torch.tensor([4, 5, 6, 7], dtype=torch.long),
                torch.tensor([8, 9, 10, 11], dtype=torch.long),
                torch.tensor([12, 13], dtype=torch.long),
            ]
        )
        dataset = MonolingualDataset(dataset, sizes=[4, 4, 2], src_vocab=dictionary)

        config = LanguageModelingConfig(tokens_per_sample=4)
        task = LanguageModelingTask(config, dictionary)

        eval_dataloader = task.eval_lm_dataloader(
            dataset=dataset,
            batch_size=1,
            context_window=2,
            num_workers=0,
        )

        # (expected src_tokens, expected target) for the first row of each batch,
        # in the order the dataloader yields them.
        expected_batches = [
            ([4, 5, 6, 7, 1, 1], [4, 5, 6, 7, 1, 1]),
            ([6, 7, 8, 9, 10, 11], [1, 1, 8, 9, 10, 11]),
            ([10, 11, 12, 13], [1, 1, 12, 13]),
        ]
        for batch_idx, (expected_src, expected_target) in enumerate(expected_batches):
            with self.subTest(batch_idx=batch_idx):
                # Advance the iterator before asserting so that a failure in one
                # batch does not desynchronize the checks for later batches.
                batch = next(eval_dataloader)
                self.assertEqual(
                    batch["net_input"]["src_tokens"][0].tolist(), expected_src
                )
                self.assertEqual(batch["target"][0].tolist(), expected_target)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()
|