fairseq/tests/tasks/test_masked_lm.py
Alexander Jipa 355ffbe4e2 add masked_lm test (#4344)
Summary:
# Before submitting

- [X] Was this discussed/approved via a GitHub issue? (not needed for typo fixes or doc improvements)
- [X] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [X] Did you make sure to update the docs?
- [X] Did you write any new necessary tests?

## What does this PR do?
Fixes https://github.com/pytorch/fairseq/issues/4300

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Big time!

Note:
I had to update `black` because of [this known issue](https://github.com/psf/black/issues/2964):
```
black....................................................................Failed
- hook id: black
- exit code: 1
Traceback (most recent call last):
  File "/Users/azzhipa/.cache/pre-commit/repoxt83whf2/py_env-python3.8/bin/black", line 8, in <module>
    sys.exit(patched_main())
  File "/Users/azzhipa/.cache/pre-commit/repoxt83whf2/py_env-python3.8/lib/python3.8/site-packages/black/__init__.py", line 1423, in patched_main
    patch_click()
  File "/Users/azzhipa/.cache/pre-commit/repoxt83whf2/py_env-python3.8/lib/python3.8/site-packages/black/__init__.py", line 1409, in patch_click
    from click import _unicodefun
ImportError: cannot import name '_unicodefun' from 'click' (/Users/azzhipa/.cache/pre-commit/repoxt83whf2/py_env-python3.8/lib/python3.8/site-packages/click/__init__.py)
```

Pull Request resolved: https://github.com/pytorch/fairseq/pull/4344

Reviewed By: zhengwy888

Differential Revision: D35691648

Pulled By: dianaml0

fbshipit-source-id: 4bdf408bc9d9cca76c9c08e138cf85b1d00d14d4
2022-04-18 14:47:00 -07:00

79 lines
2.9 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from tempfile import TemporaryDirectory
from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.tasks.masked_lm import MaskedLMConfig, MaskedLMTask
from tests.utils import build_vocab, make_data
class TestMaskedLM(unittest.TestCase):
    """End-to-end check of MaskedLMTask masking on a small binarized corpus."""

    def test_masks_tokens(self):
        """Masked source positions must line up with the non-pad target tokens.

        The config below sets ``random_token_prob=0`` and
        ``leave_unmasked_prob=0`` so that every selected position is replaced
        by ``<mask>``. For each sample, the non-pad target tokens must
        therefore equal the original (unmasked) tokens taken at the positions
        where the masked source holds ``<mask>``.
        """
        with TemporaryDirectory() as dirname:
            # prep input file
            raw_file = os.path.join(dirname, "raw")
            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)

            # binarize
            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
            split = "train"
            bin_file = os.path.join(dirname, split)
            FileBinarizer.multiprocess_dataset(
                input_file=raw_file,
                binarizer=binarizer,
                dataset_impl="mmap",
                vocab_size=len(vocab),
                output_prefix=bin_file,
            )

            # setup task
            cfg = MaskedLMConfig(
                data=dirname,
                seed=42,
                mask_prob=0.5,  # increasing the odds of masking
                random_token_prob=0,  # avoiding random tokens for exact match
                leave_unmasked_prob=0,  # always masking for exact match
            )
            task = MaskedLMTask(cfg, binarizer.dict)
            # Unmasked view of the same split, used as ground truth below.
            original_dataset = task._load_dataset_split(bin_file, 1, False)

            # load datasets
            task.load_dataset(split)
            masked_dataset = task.dataset(split)

            mask_index = task.source_dictionary.index("<mask>")
            pad_index = task.source_dictionary.pad()
            iterator = task.get_batch_iterator(
                dataset=masked_dataset,
                max_tokens=65_536,
                max_positions=4_096,
            ).next_epoch_itr(shuffle=False)

            num_masked = 0
            for batch in iterator:
                net_input = batch["net_input"]
                # `batch` is a dict of fields, so len(batch) would count its
                # keys; the per-sample "id" entry gives the real batch size.
                for sample in range(len(batch["id"])):
                    masked_src_tokens = net_input["src_tokens"][sample]
                    masked_src_length = net_input["src_lengths"][sample]
                    masked_tgt_tokens = batch["target"][sample]
                    sample_id = int(batch["id"][sample])
                    original_tokens = original_dataset[sample_id]
                    # Original tokens at the positions the source masked out.
                    original_tokens = original_tokens.masked_select(
                        masked_src_tokens[:masked_src_length] == mask_index
                    )
                    # Target keeps real tokens only at masked positions.
                    masked_tokens = masked_tgt_tokens.masked_select(
                        masked_tgt_tokens != pad_index
                    )
                    self.assertTrue(
                        masked_tokens.equal(original_tokens),
                        f"masked targets do not match originals for id {sample_id}",
                    )
                    num_masked += masked_tokens.numel()

            # Guard against the test passing vacuously if nothing was masked.
            self.assertGreater(num_masked, 0)
if __name__ == "__main__":
    # Only taken when this file is executed directly (not when a test runner
    # imports it); hands control to unittest's command-line entry point.
    unittest.main()