commit 279796224f
Summary: This is the equivalent of PR https://github.com/fairinternal/fairseq-py/issues/2697, but on top of main instead of gshard (cherry-picked and merged the squash):

* reorganize preprocess.py code a bit
* use Binarizer objects in the multiprocess code
* clean up the make_binary
* multiprocess logic
* learn to count
* format and doc string
* add basic test for vocab binarizer
* generalize to one line
* move multiprocess in binarizer

Testing:
```
python -m fairseq_cli.preprocess --only-source --trainpref ~/fixathon/small_vocab_test/train.in --destdir ~/fixathon/small_vocab_test/data-bin.cherry --workers 20
python -m fairseq_cli.preprocess --only-source --trainpref ~/fixathon/small_vocab_test/train.in --destdir ~/fixathon/small_vocab_test/data-bin.main --workers 20
```
```
md5sum ~/fixathon/small_vocab_test/data-bin.cherry/train.bin == md5sum ~/fixathon/small_vocab_test/data-bin.main/train.bin
```
```
diff ~/fixathon/small_vocab_test/data-bin.main/dict.txt ~/fixathon/small_vocab_test/data-bin.cherry/dict.txt
```

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2738
Reviewed By: sshleifer, dianaml0
Differential Revision: D32830875
Pulled By: Mortimerp9
fbshipit-source-id: e7463d5cdd96a877691bf39666daa319ebb3dcb8
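For context, here is a minimal sketch of the refactored API that the test file below exercises: a `VocabularyDatasetBinarizer` does the per-line token-to-index work, and `FileBinarizer.multiprocess_dataset` owns the multiprocess logic that previously lived in preprocess.py. The input/output paths and the worker count are illustrative, not taken from the PR.

```
# Sketch only: paths and num_workers are illustrative values.
from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.data import Dictionary

vocab = Dictionary.load("dict.txt")            # a previously built fairseq dictionary
binarizer = VocabularyDatasetBinarizer(vocab)  # appends EOS to each line by default

# The multiprocess fan-out now lives behind this single call:
summary = FileBinarizer.multiprocess_dataset(
    "train.in",                  # raw text, one sentence per line
    "mmap",                      # indexed dataset implementation
    binarizer,
    output_prefix="data-bin/train",
    vocab_size=len(vocab),
    num_workers=20,
)
print(summary.num_seq, summary.num_tok)  # counts accumulated during binarization
```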
123 lines | 3.7 KiB | Python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import typing as tp
import unittest
from tempfile import TemporaryDirectory

from fairseq.binarizer import BinarizeSummary, FileBinarizer, VocabularyDatasetBinarizer
from fairseq.data import Dictionary, indexed_dataset
from tests.utils import make_data, sizes


def build_vocab(data: tp.List[tp.List[str]]) -> Dictionary:
    """Build a fairseq Dictionary over every token that appears in the data."""
    d = Dictionary()
    for s in data:
        for token in s:
            d.add_symbol(token)
    d.finalize()
    return d


class TestBinarizer(unittest.TestCase):
    def compare_ds_data(self, summary, data, prefix, impl, vocab):
        """Check that the dataset written to disk round-trips to the original data."""
        self.assertEqual(summary.num_seq, len(data))
        self.assertEqual(summary.num_tok, sum([len(s) for s in data]))

        dataset = indexed_dataset.make_dataset(prefix, impl)

        self.assertEqual(len(dataset), len(data))
        decoded = [vocab.string(dataset[i]).split() for i in range(0, len(dataset))]

        self.assertEqual(decoded, data)
        data_sizes = [i.item() for i in dataset.sizes]
        self.assertEqual(data_sizes, sizes(data))

    def test_can_binarize_line(self):
        data = make_data(length=1)
        vocab = build_vocab(data)

        binarizer = VocabularyDatasetBinarizer(
            vocab,
        )

        sentence = data[0]
        summary = BinarizeSummary()

        tensor = binarizer.binarize_line(
            " ".join(sentence),
            summary,
        )

        # +1 accounts for the EOS token the binarizer appends by default
        self.assertEqual(len(tensor), len(sentence) + 1)

        self.assertEqual(summary.num_tok, len(sentence) + 1)
        self.assertEqual(summary.num_seq, 1)

    def test_can_binarize_file_chunk(self):
        # test without multiprocess logic
        with TemporaryDirectory() as dirname:
            raw_file = os.path.join(dirname, "raw1")
            prefix = os.path.join(dirname, "test1")
            impl = "mmap"

            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)

            binarizer = VocabularyDatasetBinarizer(
                vocab,
                append_eos=False,
            )

            summary = FileBinarizer._binarize_chunk_and_finalize(
                binarizer,
                raw_file,
                offset_start=0,
                offset_end=-1,
                output_prefix=prefix,
                dataset_impl=impl,
                vocab_size=len(vocab),
            )

            self.compare_ds_data(summary, data, prefix, impl, vocab)

    def test_can_multiprocess(self):
        with TemporaryDirectory() as dirname:
            raw_file = os.path.join(dirname, "raw1")
            prefix = os.path.join(dirname, "test1")
            impl = "mmap"
            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)
            binarizer = VocabularyDatasetBinarizer(
                vocab,
                append_eos=False,
            )
            # with one worker
            summary = FileBinarizer.multiprocess_dataset(
                raw_file,
                impl,
                binarizer,
                output_prefix=prefix,
                vocab_size=len(vocab),
                num_workers=1,
            )

            self.compare_ds_data(summary, data, prefix, impl, vocab)

            # with multiple workers
            prefix_multi = os.path.join(dirname, "test2")
            summary = FileBinarizer.multiprocess_dataset(
                raw_file,
                impl,
                binarizer,
                output_prefix=prefix_multi,
                vocab_size=len(vocab),
                num_workers=3,
            )

            self.compare_ds_data(summary, data, prefix_multi, impl, vocab)