mirror of https://github.com/facebookresearch/fairseq.git, synced 2024-10-27 01:41:27 +03:00
53bf2b1293
Summary:

## What does this PR do?
There are a few places where we chunk a file so that a single file can be processed by multiple workers, but the code is partly in Binarizer and partly just duplicated here and there. This PR extracts the file chunking/reading logic into one place. The multiprocessing logic could probably be extracted too, but I haven't found a good abstraction for it yet.

## Testing
Added tests for this reading logic, and possibly fixed a bug where the last part of a file could get dropped (though this is unclear under the current stopping logic).

Tested by running the preprocessing script as follows:
```
python -m fairseq_cli.preprocess --source-lang de --target-lang en --trainpref ...train.spm.clean.de_en --srcdict ...fairseq.dict --tgtdict .../fairseq.dict --destdir ... --workers 60
```

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1955
Reviewed By: myleott
Differential Revision: D29065473
Pulled By: Mortimerp9
fbshipit-source-id: c60843de8cfd45a63b3dbb8290f57ef3df3bf983
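For context on the chunking idea: each worker is handed a byte range of the input file, with boundaries snapped forward to the next newline so no line is split between workers and the tail of the file is not silently dropped. The sketch below is a minimal illustration of that pattern under assumed names; `find_offsets` and `read_chunk` are hypothetical here, not necessarily the API this PR introduces.

```
# Hypothetical sketch -- find_offsets/read_chunk are illustrative names,
# not the PR's actual API.
import os


def find_offsets(filename, num_chunks):
    """Return num_chunks + 1 byte offsets delimiting newline-aligned chunks."""
    with open(filename, "rb") as f:
        size = os.fstat(f.fileno()).st_size
        offsets = [0]
        for i in range(1, num_chunks):
            f.seek(size * i // num_chunks)
            f.readline()  # snap forward to the next line boundary
            offsets.append(f.tell())
        offsets.append(size)  # explicit end offset so the tail is never dropped
    return offsets


def read_chunk(filename, start, end):
    """Yield every line whose first byte falls in [start, end)."""
    with open(filename, "rb") as f:
        f.seek(start)
        while f.tell() < end:
            line = f.readline()
            if not line:
                break
            yield line.decode("utf-8")
```

A driver would then hand each consecutive `(offsets[i], offsets[i + 1])` pair to one worker process and merge the partial results, which is the pattern the `--workers 60` run above exercises.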
146 lines
4.4 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import io
import os
import string
import tempfile
import unittest

import torch

from fairseq import tokenizer
from fairseq.data import Dictionary


class TestDictionary(unittest.TestCase):
    def test_finalize(self):
        txt = [
            "A B C D",
            "B C D",
            "C D",
            "D",
        ]
        ref_ids1 = list(
            map(
                torch.IntTensor,
                [
                    [4, 5, 6, 7, 2],
                    [5, 6, 7, 2],
                    [6, 7, 2],
                    [7, 2],
                ],
            )
        )
        ref_ids2 = list(
            map(
                torch.IntTensor,
                [
                    [7, 6, 5, 4, 2],
                    [6, 5, 4, 2],
                    [5, 4, 2],
                    [4, 2],
                ],
            )
        )

        # build dictionary
        d = Dictionary()
        for line in txt:
            d.encode_line(line, add_if_not_exist=True)

        def get_ids(dictionary):
            ids = []
            for line in txt:
                ids.append(dictionary.encode_line(line, add_if_not_exist=False))
            return ids

        def assertMatch(ids, ref_ids):
            for toks, ref_toks in zip(ids, ref_ids):
                self.assertEqual(toks.size(), ref_toks.size())
                self.assertEqual(0, (toks != ref_toks).sum().item())

        ids = get_ids(d)
        assertMatch(ids, ref_ids1)

        # check finalized dictionary
        d.finalize()
        finalized_ids = get_ids(d)
        assertMatch(finalized_ids, ref_ids2)

        # write to disk and reload
        with tempfile.NamedTemporaryFile(mode="w") as tmp_dict:
            d.save(tmp_dict.name)
            d = Dictionary.load(tmp_dict.name)
            reload_ids = get_ids(d)
            assertMatch(reload_ids, ref_ids2)
            assertMatch(finalized_ids, reload_ids)

    def test_overwrite(self):
        # for example, Camembert overwrites <unk>, <s> and </s>
        dict_file = io.StringIO(
            "<unk> 999 #fairseq:overwrite\n"
            "<s> 999 #fairseq:overwrite\n"
            "</s> 999 #fairseq:overwrite\n"
            ", 999\n"
            "▁de 999\n"
        )
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index("<pad>"), 1)
        self.assertEqual(d.index("foo"), 3)
        self.assertEqual(d.index("<unk>"), 4)
        self.assertEqual(d.index("<s>"), 5)
        self.assertEqual(d.index("</s>"), 6)
        self.assertEqual(d.index(","), 7)
        self.assertEqual(d.index("▁de"), 8)

    def test_no_overwrite(self):
        # for example, Camembert overwrites <unk>, <s> and </s>
        dict_file = io.StringIO(
            "<unk> 999\n" "<s> 999\n" "</s> 999\n" ", 999\n" "▁de 999\n"
        )
        d = Dictionary()
        with self.assertRaisesRegex(RuntimeError, "Duplicate"):
            d.add_from_file(dict_file)

    def test_space(self):
        # for example, character models treat space as a symbol
        dict_file = io.StringIO(" 999\n" "a 999\n" "b 999\n")
        d = Dictionary()
        d.add_from_file(dict_file)
        self.assertEqual(d.index(" "), 4)
        self.assertEqual(d.index("a"), 5)
        self.assertEqual(d.index("b"), 6)

    def test_add_file_to_dict(self):
        counts = {}
        num_lines = 100
        per_line = 10
        with tempfile.TemporaryDirectory("test_sampling") as data_dir:
            filename = os.path.join(data_dir, "dummy.txt")
            with open(filename, "w", encoding="utf-8") as data:
                for c in string.ascii_letters:
                    line = f"{c} " * per_line
                    for _ in range(num_lines):
                        data.write(f"{line}\n")
                    counts[c] = per_line * num_lines
                    per_line += 5

            dict = Dictionary()
            Dictionary.add_file_to_dictionary(
                filename, dict, tokenizer.tokenize_line, 10
            )
            dict.finalize(threshold=0, nwords=-1, padding_factor=8)

            for c in string.ascii_letters:
                count = dict.get_count(dict.index(c))
                self.assertEqual(
                    counts[c], count, f"{c} count is {count} but should be {counts[c]}"
                )


if __name__ == "__main__":
    unittest.main()
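The suite is self-contained via the `unittest.main()` entry point; assuming the file lives at tests/test_dictionary.py in a fairseq checkout with fairseq installed, it can be run directly:
```
python tests/test_dictionary.py
```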