fairseq/tests/tasks/test_masked_lm.py

add masked_lm test (#4344), 2022-04-19 00:47:00 +03:00. Fixes https://github.com/pytorch/fairseq/issues/4300. Pull Request resolved: https://github.com/pytorch/fairseq/pull/4344

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from tempfile import TemporaryDirectory

from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.tasks.masked_lm import MaskedLMConfig, MaskedLMTask
from tests.utils import build_vocab, make_data


class TestMaskedLM(unittest.TestCase):
    def test_masks_tokens(self):
        with TemporaryDirectory() as dirname:

            # prep input file
            raw_file = os.path.join(dirname, "raw")
            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)

            # binarize
            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
            split = "train"
            bin_file = os.path.join(dirname, split)
            FileBinarizer.multiprocess_dataset(
                input_file=raw_file,
                binarizer=binarizer,
                dataset_impl="mmap",
                vocab_size=len(vocab),
                output_prefix=bin_file,
            )

            # setup task
            cfg = MaskedLMConfig(
                data=dirname,
                seed=42,
                mask_prob=0.5,  # increasing the odds of masking
                random_token_prob=0,  # avoiding random tokens for exact match
                leave_unmasked_prob=0,  # always masking for exact match
            )
            task = MaskedLMTask(cfg, binarizer.dict)

            original_dataset = task._load_dataset_split(bin_file, 1, False)

            # load datasets
            task.load_dataset(split)
            masked_dataset = task.dataset(split)

            mask_index = task.source_dictionary.index("<mask>")
            iterator = task.get_batch_iterator(
                dataset=masked_dataset,
                max_tokens=65_536,
                max_positions=4_096,
            ).next_epoch_itr(shuffle=False)
            for batch in iterator:
                # `batch` is a dict, so len(batch) would count its keys rather
                # than its samples; iterate over the sample ids instead
                for sample in range(len(batch["id"])):
                    net_input = batch["net_input"]
                    masked_src_tokens = net_input["src_tokens"][sample]
                    masked_src_length = net_input["src_lengths"][sample]
                    masked_tgt_tokens = batch["target"][sample]

                    sample_id = batch["id"][sample]
                    original_tokens = original_dataset[sample_id]
                    # original tokens at the positions replaced by <mask>
                    # (padding beyond the source length is ignored)
                    original_tokens = original_tokens.masked_select(
                        masked_src_tokens[:masked_src_length] == mask_index
                    )
                    # the target is pad everywhere except at masked positions
                    masked_tokens = masked_tgt_tokens.masked_select(
                        masked_tgt_tokens != task.source_dictionary.pad()
                    )

                    assert masked_tokens.equal(original_tokens)


if __name__ == "__main__":
    unittest.main()