Mirror of https://github.com/facebookresearch/fairseq.git, synced 2024-10-05 13:17:39 +03:00
Preprocess Split (#2738)
Summary: This is the equivalent of PR https://github.com/fairinternal/fairseq-py/issues/2697 but on top of main instead of gshard (cherry-picked and merged the squash):

* reorganize preprocess.py code a bit
* use Binarizer objects in the multiprocess code
* clean up the make_binary
* multiprocess logic
* learn to count
* format and doc string
* add basic test for vocab binarizer
* generalize to one line
* move multiprocess in binarizer

Testing:

```
python -m fairseq_cli.preprocess --only-source --trainpref ~/fixathon/small_vocab_test/train.in --destdir ~/fixathon/small_vocab_test/data-bin.cherry --workers 20
python -m fairseq_cli.preprocess --only-source --trainpref ~/fixathon/small_vocab_test/train.in --destdir ~/fixathon/small_vocab_test/data-bin.main --workers 20
```

```
md5sum ~/fixathon/small_vocab_test/data-bin.cherry/train.bin == md5sum ~/fixathon/small_vocab_test/data-bin.main/train.bin
```

```
diff ~/fixathon/small_vocab_test/data-bin.main/dict.txt ~/fixathon/small_vocab_test/data-bin.cherry/dict.tx
```

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2738

Reviewed By: sshleifer, dianaml0

Differential Revision: D32830875

Pulled By: Mortimerp9

fbshipit-source-id: e7463d5cdd96a877691bf39666daa319ebb3dcb8
This commit is contained in: parent b3fa5100c6, commit 279796224f
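For orientation before the diff: a minimal sketch (untested, with made-up paths) of how the refactored pieces fit together. `VocabularyDatasetBinarizer` and `FileBinarizer.multiprocess_dataset` are introduced in `fairseq/binarizer.py` below, and this call mirrors what `_make_binary_dataset` in `fairseq_cli/preprocess.py` now does.

```python
# Minimal sketch of the new API (hypothetical paths, not part of the commit).
from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.data import Dictionary

vocab = Dictionary.load("data-bin/dict.txt")  # a previously built fairseq dictionary
binarizer = VocabularyDatasetBinarizer(vocab, append_eos=True)

# Splits train.in into byte-offset chunks, binarizes them in parallel, and
# writes data-bin/train.bin + data-bin/train.idx.
summary = FileBinarizer.multiprocess_dataset(
    "train.in",          # raw text, one sentence per line
    "mmap",              # dataset_impl
    binarizer,
    "data-bin/train",    # output prefix
    vocab_size=len(vocab),
    num_workers=20,
)
print(summary)  # e.g. "1000 sents, 25000 tokens, 0.1% replaced"
```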
fairseq/binarizer.py

```diff
@@ -3,78 +3,379 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 import os
+import typing as tp
+from abc import ABC, abstractmethod
 from collections import Counter
-from typing import Dict
+from dataclasses import dataclass
+from multiprocessing import Pool
 
 import torch
 
-from fairseq.file_chunker_utils import Chunker
+from fairseq.data import Dictionary, indexed_dataset
+from fairseq.file_chunker_utils import Chunker, find_offsets
 from fairseq.file_io import PathManager
 from fairseq.tokenizer import tokenize_line
 
+logger = logging.getLogger("binarizer")
+
+
+@dataclass
+class BinarizeSummary:
+    """
+    Keep track of what's going on in the binarizer
+    """
+
+    num_seq: int = 0
+    replaced: tp.Optional[Counter] = None
+    num_tok: int = 0
+
+    @property
+    def num_replaced(self) -> int:
+        if self.replaced is None:
+            return 0
+        return sum(self.replaced.values())
+
+    @property
+    def replaced_percent(self) -> float:
+        return 100 * self.num_replaced / self.num_tok
+
+    def __str__(self) -> str:
+        base = f"{self.num_seq} sents, {self.num_tok} tokens"
+        if self.replaced is None:
+            return base
+
+        return f"{base}, {self.replaced_percent:.3}% replaced"
+
+    def merge(self, other: "BinarizeSummary"):
+        replaced = None
+        if self.replaced is not None:
+            replaced = self.replaced
+        if other.replaced is not None:
+            if replaced is None:
+                replaced = other.replaced
+            else:
+                replaced += other.replaced
+        self.replaced = replaced
+        self.num_seq += other.num_seq
+        self.num_tok += other.num_tok
+
+
+class Binarizer(ABC):
+    """
+    a binarizer describes how to take a string and build a tensor out of it
+    """
+
+    @abstractmethod
+    def binarize_line(
+        self,
+        line: str,
+        summary: BinarizeSummary,
+    ) -> torch.IntTensor:
+        ...
+
+
+def _worker_prefix(output_prefix: str, worker_id: int):
+    return f"{output_prefix}.pt{worker_id}"
+
+
+class FileBinarizer:
+    """
+    An file binarizer can take a file, tokenize it, and binarize each line to a tensor
+    """
+
+    @classmethod
+    def multiprocess_dataset(
+        cls,
+        input_file: str,
+        dataset_impl: str,
+        binarizer: Binarizer,
+        output_prefix: str,
+        vocab_size=None,
+        num_workers=1,
+    ) -> BinarizeSummary:
+        final_summary = BinarizeSummary()
+
+        offsets = find_offsets(input_file, num_workers)
+        # find_offsets returns a list of position [pos1, pos2, pos3, pos4] but we would want pairs:
+        # [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the chunks with start/end info
+        # we zip the list with itself shifted by one to get all the pairs.
+        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
+        pool = None
+        if num_workers > 1:
+            pool = Pool(processes=num_workers - 1)
+            worker_results = [
+                pool.apply_async(
+                    cls._binarize_chunk_and_finalize,
+                    args=(
+                        binarizer,
+                        input_file,
+                        start_offset,
+                        end_offset,
+                        _worker_prefix(
+                            output_prefix,
+                            worker_id,
+                        ),
+                        dataset_impl,
+                    ),
+                    kwds={
+                        "vocab_size": vocab_size,
+                    }
+                    if vocab_size is not None
+                    else {},
+                )
+                for worker_id, (start_offset, end_offset) in enumerate(
+                    more_chunks, start=1
+                )
+            ]
+
+            pool.close()
+            pool.join()
+            for r in worker_results:
+                summ = r.get()
+                final_summary.merge(summ)
+
+        # do not close the bin file as we need to merge the worker results in
+        final_ds, summ = cls._binarize_file_chunk(
+            binarizer,
+            input_file,
+            offset_start=first_chunk[0],
+            offset_end=first_chunk[1],
+            output_prefix=output_prefix,
+            dataset_impl=dataset_impl,
+            vocab_size=vocab_size if vocab_size is not None else None,
+        )
+        final_summary.merge(summ)
+
+        if num_workers > 1:
+            for worker_id in range(1, num_workers):
+                # merge the worker outputs
+                worker_output_prefix = _worker_prefix(
+                    output_prefix,
+                    worker_id,
+                )
+                final_ds.merge_file_(worker_output_prefix)
+                try:
+                    os.remove(indexed_dataset.data_file_path(worker_output_prefix))
+                    os.remove(indexed_dataset.index_file_path(worker_output_prefix))
+                except Exception as e:
+                    logger.error(
+                        f"couldn't remove {worker_output_prefix}.*", exc_info=e
+                    )
+
+        # now we can close the file
+        idx_file = indexed_dataset.index_file_path(output_prefix)
+        final_ds.finalize(idx_file)
+        return final_summary
+
-class Binarizer:
     @staticmethod
-    def binarize(
-        filename,
-        dict,
-        consumer,
-        tokenize=tokenize_line,
-        append_eos=True,
-        reverse_order=False,
-        offset=0,
-        end=-1,
-        already_numberized=False,
-    ) -> Dict[str, int]:
-        nseq, ntok = 0, 0
-        replaced = Counter()
+    def _binarize_file_chunk(
+        binarizer: Binarizer,
+        filename: str,
+        offset_start: int,
+        offset_end: int,
+        output_prefix: str,
+        dataset_impl: str,
+        vocab_size=None,
+    ) -> tp.Tuple[tp.Any, BinarizeSummary]:  # (dataset builder, BinarizeSummary)
+        """
+        creates a dataset builder and append binarized items to it. This function does not
+        finalize the builder, this is useful if you want to do other things with your bin file
+        like appending/merging other files
+        """
+        bin_file = indexed_dataset.data_file_path(output_prefix)
+        ds = indexed_dataset.make_builder(
+            bin_file,
+            impl=dataset_impl,
+            vocab_size=vocab_size,
+        )
+        summary = BinarizeSummary()
+
+        with Chunker(
+            PathManager.get_local_path(filename), offset_start, offset_end
+        ) as line_iterator:
+            for line in line_iterator:
+                ds.add_item(binarizer.binarize_line(line, summary))
+
+        return ds, summary
+
+    @classmethod
+    def _binarize_chunk_and_finalize(
+        cls,
+        binarizer: Binarizer,
+        filename: str,
+        offset_start: int,
+        offset_end: int,
+        output_prefix: str,
+        dataset_impl: str,
+        vocab_size=None,
+    ):
+        """
+        same as above, but also finalizes the builder
+        """
+        ds, summ = cls._binarize_file_chunk(
+            binarizer,
+            filename,
+            offset_start,
+            offset_end,
+            output_prefix,
+            dataset_impl,
+            vocab_size=vocab_size,
+        )
+
+        idx_file = indexed_dataset.index_file_path(output_prefix)
+        ds.finalize(idx_file)
+
+        return summ
+
+
+class VocabularyDatasetBinarizer(Binarizer):
+    """
+    Takes a Dictionary/Vocabulary, assign ids to each
+    token using the dictionary encode_line function.
+    """
+
+    def __init__(
+        self,
+        dict: Dictionary,
+        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
+        append_eos: bool = True,
+        reverse_order: bool = False,
+        already_numberized: bool = False,
+    ) -> None:
+        self.dict = dict
+        self.tokenize = tokenize
+        self.append_eos = append_eos
+        self.reverse_order = reverse_order
+        self.already_numberized = already_numberized
+        super().__init__()
+
+    def binarize_line(
+        self,
+        line: str,
+        summary: BinarizeSummary,
+    ):
+        if summary.replaced is None:
+            summary.replaced = Counter()
 
         def replaced_consumer(word, idx):
-            if idx == dict.unk_index and word != dict.unk_word:
-                replaced.update([word])
+            if idx == self.dict.unk_index and word != self.dict.unk_word:
+                summary.replaced.update([word])
 
-        with Chunker(
-            PathManager.get_local_path(filename), offset, end
-        ) as line_iterator:
-            for line in line_iterator:
-                if already_numberized:
-                    id_strings = line.strip().split()
-                    id_list = [int(id_string) for id_string in id_strings]
-                    if reverse_order:
-                        id_list.reverse()
-                    if append_eos:
-                        id_list.append(dict.eos())
-                    ids = torch.IntTensor(id_list)
-                else:
-                    ids = dict.encode_line(
-                        line=line,
-                        line_tokenizer=tokenize,
-                        add_if_not_exist=False,
-                        consumer=replaced_consumer,
-                        append_eos=append_eos,
-                        reverse_order=reverse_order,
-                    )
-                nseq += 1
-                ntok += len(ids)
-                consumer(ids)
-        return {
-            "nseq": nseq,
-            "nunk": sum(replaced.values()),
-            "ntok": ntok,
-            "replaced": replaced,
-        }
+        if self.already_numberized:
+            id_strings = line.strip().split()
+            id_list = [int(id_string) for id_string in id_strings]
+            if self.reverse_order:
+                id_list.reverse()
+            if self.append_eos:
+                id_list.append(self.dict.eos())
+            ids = torch.IntTensor(id_list)
+        else:
+            ids = self.dict.encode_line(
+                line=line,
+                line_tokenizer=self.tokenize,
+                add_if_not_exist=False,
+                consumer=replaced_consumer,
+                append_eos=self.append_eos,
+                reverse_order=self.reverse_order,
+            )
+
+        summary.num_seq += 1
+        summary.num_tok += len(ids)
+        return ids
+
+
+class AlignmentDatasetBinarizer(Binarizer):
+    """
+    binarize by parsing a set of alignments and packing
+    them in a tensor (see utils.parse_alignment)
+    """
+
+    def __init__(
+        self,
+        alignment_parser: tp.Callable[[str], torch.IntTensor],
+    ) -> None:
+        super().__init__()
+        self.alignment_parser = alignment_parser
+
+    def binarize_line(
+        self,
+        line: str,
+        summary: BinarizeSummary,
+    ):
+        ids = self.alignment_parser(line)
+        summary.num_seq += 1
+        summary.num_tok += len(ids)
+        return ids
+
+
+class LegacyBinarizer:
+    @classmethod
+    def binarize(
+        cls,
+        filename: str,
+        dico: Dictionary,
+        consumer: tp.Callable[[torch.IntTensor], None],
+        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
+        append_eos: bool = True,
+        reverse_order: bool = False,
+        offset: int = 0,
+        end: int = -1,
+        already_numberized: bool = False,
+    ) -> tp.Dict[str, int]:
+        binarizer = VocabularyDatasetBinarizer(
+            dict=dico,
+            tokenize=tokenize,
+            append_eos=append_eos,
+            reverse_order=reverse_order,
+            already_numberized=already_numberized,
+        )
+        return cls._consume_file(
+            filename,
+            binarizer,
+            consumer,
+            offset_start=offset,
+            offset_end=end,
+        )
+
+    @classmethod
+    def binarize_alignments(
+        cls,
+        filename: str,
+        alignment_parser: tp.Callable[[str], torch.IntTensor],
+        consumer: tp.Callable[[torch.IntTensor], None],
+        offset: int = 0,
+        end: int = -1,
+    ) -> tp.Dict[str, int]:
+        binarizer = AlignmentDatasetBinarizer(alignment_parser)
+        return cls._consume_file(
+            filename,
+            binarizer,
+            consumer,
+            offset_start=offset,
+            offset_end=end,
+        )
 
     @staticmethod
-    def binarize_alignments(
-        filename, alignment_parser, consumer, offset=0, end=-1
-    ) -> Dict[str, int]:
-        nseq = 0
+    def _consume_file(
+        filename: str,
+        binarizer: Binarizer,
+        consumer: tp.Callable[[torch.IntTensor], None],
+        offset_start: int,
+        offset_end: int,
+    ) -> tp.Dict[str, int]:
+        summary = BinarizeSummary()
 
         with Chunker(
-            PathManager.get_local_path(filename), offset, end
+            PathManager.get_local_path(filename), offset_start, offset_end
         ) as line_iterator:
             for line in line_iterator:
-                ids = alignment_parser(line)
-                nseq += 1
-                consumer(ids)
-        return {"nseq": nseq}
+                consumer(binarizer.binarize_line(line, summary))
+
+        return {
+            "nseq": summary.num_seq,
+            "nunk": summary.num_replaced,
+            "ntok": summary.num_tok,
+            "replaced": summary.replaced,
+        }
```
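A note on the offset pairing used by `multiprocess_dataset` above: `find_offsets(input_file, num_workers)` returns `num_workers + 1` byte positions, and zipping that list against itself shifted by one turns them into `(start, end)` chunks, the first handled in-process and the rest dispatched to the pool. A standalone illustration with made-up offsets:

```python
# Illustration of the pairing idiom from multiprocess_dataset (example values).
offsets = [0, 100, 250, 400]  # what find_offsets might return for 3 workers
(first_chunk, *more_chunks) = zip(offsets, offsets[1:])
assert first_chunk == (0, 100)                   # binarized in the main process
assert more_chunks == [(100, 250), (250, 400)]   # sent to the worker pool
```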
fairseq_cli/preprocess.py

```diff
@@ -11,14 +11,17 @@ import logging
 import os
 import shutil
 import sys
-from collections import Counter
+import typing as tp
+from argparse import Namespace
 from itertools import zip_longest
-from multiprocessing import Pool
 
 from fairseq import options, tasks, utils
-from fairseq.binarizer import Binarizer
-from fairseq.data import indexed_dataset
-from fairseq.file_chunker_utils import find_offsets
+from fairseq.binarizer import (
+    AlignmentDatasetBinarizer,
+    FileBinarizer,
+    VocabularyDatasetBinarizer,
+)
+from fairseq.data import Dictionary
 
 logging.basicConfig(
     format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
@@ -28,8 +31,251 @@ logging.basicConfig(
 )
 logger = logging.getLogger("fairseq_cli.preprocess")
 
+#####################################################################
+# file name tools
+#####################################################################
+
+
+def _train_path(lang, trainpref):
+    return "{}{}".format(trainpref, ("." + lang) if lang else "")
+
+
+def _file_name(prefix, lang):
+    fname = prefix
+    if lang is not None:
+        fname += ".{lang}".format(lang=lang)
+    return fname
+
+
+def _dest_path(prefix, lang, destdir):
+    return os.path.join(destdir, _file_name(prefix, lang))
+
+
+def _dict_path(lang, destdir):
+    return _dest_path("dict", lang, destdir) + ".txt"
+
+
+def dataset_dest_prefix(args, output_prefix, lang):
+    base = os.path.join(args.destdir, output_prefix)
+    if lang is not None:
+        lang_part = f".{args.source_lang}-{args.target_lang}.{lang}"
+    elif args.only_source:
+        lang_part = ""
+    else:
+        lang_part = f".{args.source_lang}-{args.target_lang}"
+
+    return "{}{}".format(base, lang_part)
+
+
+def dataset_dest_file(args, output_prefix, lang, extension):
+    return "{}.{}".format(dataset_dest_prefix(args, output_prefix, lang), extension)
+
+
+#####################################################################
+# dictionary tools
+#####################################################################
+
+
+def _build_dictionary(
+    filenames,
+    task,
+    args,
+    src=False,
+    tgt=False,
+):
+    assert src ^ tgt
+    return task.build_dictionary(
+        filenames,
+        workers=args.workers,
+        threshold=args.thresholdsrc if src else args.thresholdtgt,
+        nwords=args.nwordssrc if src else args.nwordstgt,
+        padding_factor=args.padding_factor,
+    )
+
+
+#####################################################################
+# bin file creation logic
+#####################################################################
+
+
+def _make_binary_dataset(
+    vocab: Dictionary,
+    input_prefix: str,
+    output_prefix: str,
+    lang: tp.Optional[str],
+    num_workers: int,
+    args: Namespace,
+):
+    logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))
+
+    binarizer = VocabularyDatasetBinarizer(
+        vocab,
+        append_eos=True,
+    )
+
+    input_file = "{}{}".format(input_prefix, ("." + lang) if lang is not None else "")
+    full_output_prefix = dataset_dest_prefix(args, output_prefix, lang)
+
+    final_summary = FileBinarizer.multiprocess_dataset(
+        input_file,
+        args.dataset_impl,
+        binarizer,
+        full_output_prefix,
+        vocab_size=len(vocab),
+        num_workers=num_workers,
+    )
+
+    logger.info(f"[{lang}] {input_file}: {final_summary} (by {vocab.unk_word})")
+
+
+def _make_binary_alignment_dataset(
+    input_prefix: str, output_prefix: str, num_workers: int, args: Namespace
+):
+
+    binarizer = AlignmentDatasetBinarizer(utils.parse_alignment)
+
+    input_file = input_prefix
+    full_output_prefix = dataset_dest_prefix(args, output_prefix, lang=None)
+
+    final_summary = FileBinarizer.multiprocess_dataset(
+        input_file,
+        args.dataset_impl,
+        binarizer,
+        full_output_prefix,
+        vocab_size=None,
+        num_workers=num_workers,
+    )
+
+    logger.info(
+        "[alignments] {}: parsed {} alignments".format(
+            input_file, final_summary.num_seq
+        )
+    )
+
+
+#####################################################################
+# routing logic
+#####################################################################
+
+
+def _make_dataset(
+    vocab: Dictionary,
+    input_prefix: str,
+    output_prefix: str,
+    lang: tp.Optional[str],
+    args: Namespace,
+    num_workers: int,
+):
+    if args.dataset_impl == "raw":
+        # Copy original text file to destination folder
+        output_text_file = _dest_path(
+            output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
+            lang,
+            args.destdir,
+        )
+        shutil.copyfile(_file_name(input_prefix, lang), output_text_file)
+    else:
+        _make_binary_dataset(
+            vocab, input_prefix, output_prefix, lang, num_workers, args
+        )
+
+
+def _make_all(lang, vocab, args):
+    if args.trainpref:
+        _make_dataset(
+            vocab, args.trainpref, "train", lang, args=args, num_workers=args.workers
+        )
+    if args.validpref:
+        for k, validpref in enumerate(args.validpref.split(",")):
+            outprefix = "valid{}".format(k) if k > 0 else "valid"
+            _make_dataset(
+                vocab, validpref, outprefix, lang, args=args, num_workers=args.workers
+            )
+    if args.testpref:
+        for k, testpref in enumerate(args.testpref.split(",")):
+            outprefix = "test{}".format(k) if k > 0 else "test"
+            _make_dataset(
+                vocab, testpref, outprefix, lang, args=args, num_workers=args.workers
+            )
+
+
+def _make_all_alignments(args):
+    if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix):
+        _make_binary_alignment_dataset(
+            args.trainpref + "." + args.align_suffix,
+            "train.align",
+            num_workers=args.workers,
+            args=args,
+        )
+    if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix):
+        _make_binary_alignment_dataset(
+            args.validpref + "." + args.align_suffix,
+            "valid.align",
+            num_workers=args.workers,
+            args=args,
+        )
+    if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix):
+        _make_binary_alignment_dataset(
+            args.testpref + "." + args.align_suffix,
+            "test.align",
+            num_workers=args.workers,
+            args=args,
+        )
+
+
+#####################################################################
+# align
+#####################################################################
+
+
+def _align_files(args, src_dict, tgt_dict):
+    assert args.trainpref, "--trainpref must be set if --alignfile is specified"
+    src_file_name = _train_path(args.source_lang, args.trainpref)
+    tgt_file_name = _train_path(args.target_lang, args.trainpref)
+    freq_map = {}
+    with open(args.alignfile, "r", encoding="utf-8") as align_file:
+        with open(src_file_name, "r", encoding="utf-8") as src_file:
+            with open(tgt_file_name, "r", encoding="utf-8") as tgt_file:
+                for a, s, t in zip_longest(align_file, src_file, tgt_file):
+                    si = src_dict.encode_line(s, add_if_not_exist=False)
+                    ti = tgt_dict.encode_line(t, add_if_not_exist=False)
+                    ai = list(map(lambda x: tuple(x.split("-")), a.split()))
+                    for sai, tai in ai:
+                        srcidx = si[int(sai)]
+                        tgtidx = ti[int(tai)]
+                        if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
+                            assert srcidx != src_dict.pad()
+                            assert srcidx != src_dict.eos()
+                            assert tgtidx != tgt_dict.pad()
+                            assert tgtidx != tgt_dict.eos()
+                            if srcidx not in freq_map:
+                                freq_map[srcidx] = {}
+                            if tgtidx not in freq_map[srcidx]:
+                                freq_map[srcidx][tgtidx] = 1
+                            else:
+                                freq_map[srcidx][tgtidx] += 1
+    align_dict = {}
+    for srcidx in freq_map.keys():
+        align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)
+    with open(
+        os.path.join(
+            args.destdir,
+            "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
+        ),
+        "w",
+        encoding="utf-8",
+    ) as f:
+        for k, v in align_dict.items():
+            print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
+
+
+#####################################################################
+# MAIN
+#####################################################################
+
 
 def main(args):
+    # setup some basic things
     utils.import_user_module(args)
 
     os.makedirs(args.destdir, exist_ok=True)
@@ -45,39 +291,21 @@ def main(args):
         args.dataset_impl != "huffman"
     ), "preprocessing.py doesn't support Huffman yet, use HuffmanCodeBuilder directly."
 
-    task = tasks.get_task(args.task)
-
-    def train_path(lang):
-        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")
-
-    def file_name(prefix, lang):
-        fname = prefix
-        if lang is not None:
-            fname += ".{lang}".format(lang=lang)
-        return fname
-
-    def dest_path(prefix, lang):
-        return os.path.join(args.destdir, file_name(prefix, lang))
-
-    def dict_path(lang):
-        return dest_path("dict", lang) + ".txt"
-
-    def build_dictionary(filenames, src=False, tgt=False):
-        assert src ^ tgt
-        return task.build_dictionary(
-            filenames,
-            workers=args.workers,
-            threshold=args.thresholdsrc if src else args.thresholdtgt,
-            nwords=args.nwordssrc if src else args.nwordstgt,
-            padding_factor=args.padding_factor,
-        )
+    # build dictionaries
 
     target = not args.only_source
 
-    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
-        raise FileExistsError(dict_path(args.source_lang))
-    if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
-        raise FileExistsError(dict_path(args.target_lang))
+    if not args.srcdict and os.path.exists(_dict_path(args.source_lang, args.destdir)):
+        raise FileExistsError(_dict_path(args.source_lang, args.destdir))
+
+    if (
+        target
+        and not args.tgtdict
+        and os.path.exists(_dict_path(args.target_lang, args.destdir))
+    ):
+        raise FileExistsError(_dict_path(args.target_lang, args.destdir))
+
+    task = tasks.get_task(args.task)
 
     if args.joined_dictionary:
         assert (
@@ -92,8 +320,13 @@ def main(args):
         assert (
             args.trainpref
         ), "--trainpref must be set if --srcdict is not specified"
-        src_dict = build_dictionary(
-            {train_path(lang) for lang in [args.source_lang, args.target_lang]},
+        src_dict = _build_dictionary(
+            {
+                _train_path(lang, args.trainpref)
+                for lang in [args.source_lang, args.target_lang]
+            },
+            task=task,
+            args=args,
             src=True,
         )
         tgt_dict = src_dict
@@ -104,7 +337,12 @@ def main(args):
         assert (
             args.trainpref
         ), "--trainpref must be set if --srcdict is not specified"
-        src_dict = build_dictionary([train_path(args.source_lang)], src=True)
+        src_dict = _build_dictionary(
+            [_train_path(args.source_lang, args.trainpref)],
+            task=task,
+            args=args,
+            src=True,
+        )
 
     if target:
         if args.tgtdict:
@@ -113,292 +351,36 @@ def main(args):
             assert (
                 args.trainpref
             ), "--trainpref must be set if --tgtdict is not specified"
-            tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
+            tgt_dict = _build_dictionary(
+                [_train_path(args.target_lang, args.trainpref)],
+                task=task,
+                args=args,
+                tgt=True,
+            )
     else:
         tgt_dict = None
 
-    src_dict.save(dict_path(args.source_lang))
+    # save dictionaries
+
+    src_dict.save(_dict_path(args.source_lang, args.destdir))
     if target and tgt_dict is not None:
-        tgt_dict.save(dict_path(args.target_lang))
+        tgt_dict.save(_dict_path(args.target_lang, args.destdir))
 
     if args.dict_only:
         return
 
-    def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
-        logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))
-        n_seq_tok = [0, 0]
-        replaced = Counter()
-
-        def merge_result(worker_result):
-            replaced.update(worker_result["replaced"])
-            n_seq_tok[0] += worker_result["nseq"]
-            n_seq_tok[1] += worker_result["ntok"]
-
-        input_file = "{}{}".format(
-            input_prefix, ("." + lang) if lang is not None else ""
-        )
-        offsets = find_offsets(input_file, num_workers)
-        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
-        pool = None
-        if num_workers > 1:
-            pool = Pool(processes=num_workers - 1)
-            for worker_id, (start_offset, end_offset) in enumerate(
-                more_chunks, start=1
-            ):
-                prefix = "{}{}".format(output_prefix, worker_id)
-                pool.apply_async(
-                    binarize,
-                    (
-                        args,
-                        input_file,
-                        vocab,
-                        prefix,
-                        lang,
-                        start_offset,
-                        end_offset,
-                    ),
-                    callback=merge_result,
-                )
-            pool.close()
-
-        ds = indexed_dataset.make_builder(
-            dataset_dest_file(args, output_prefix, lang, "bin"),
-            impl=args.dataset_impl,
-            vocab_size=len(vocab),
-        )
-        merge_result(
-            Binarizer.binarize(
-                input_file,
-                vocab,
-                lambda t: ds.add_item(t),
-                offset=first_chunk[0],
-                end=first_chunk[1],
-            )
-        )
-        if num_workers > 1:
-            pool.join()
-            for worker_id in range(1, num_workers):
-                prefix = "{}{}".format(output_prefix, worker_id)
-                temp_file_path = dataset_dest_prefix(args, prefix, lang)
-                ds.merge_file_(temp_file_path)
-                os.remove(indexed_dataset.data_file_path(temp_file_path))
-                os.remove(indexed_dataset.index_file_path(temp_file_path))
-
-        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
-
-        logger.info(
-            "[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
-                lang,
-                input_file,
-                n_seq_tok[0],
-                n_seq_tok[1],
-                100 * sum(replaced.values()) / n_seq_tok[1],
-                vocab.unk_word,
-            )
-        )
-
-    def make_binary_alignment_dataset(input_prefix, output_prefix, num_workers):
-        nseq = [0]
-
-        def merge_result(worker_result):
-            nseq[0] += worker_result["nseq"]
-
-        input_file = input_prefix
-        offsets = find_offsets(input_file, num_workers)
-        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
-        pool = None
-        if num_workers > 1:
-            pool = Pool(processes=num_workers - 1)
-            for worker_id, (start_offset, end_offset) in enumerate(
-                more_chunks, start=1
-            ):
-                prefix = "{}{}".format(output_prefix, worker_id)
-                pool.apply_async(
-                    binarize_alignments,
-                    (
-                        args,
-                        input_file,
-                        utils.parse_alignment,
-                        prefix,
-                        start_offset,
-                        end_offset,
-                    ),
-                    callback=merge_result,
-                )
-            pool.close()
-
-        ds = indexed_dataset.make_builder(
-            dataset_dest_file(args, output_prefix, None, "bin"), impl=args.dataset_impl
-        )
-
-        merge_result(
-            Binarizer.binarize_alignments(
-                input_file,
-                utils.parse_alignment,
-                lambda t: ds.add_item(t),
-                offset=first_chunk[0],
-                end=first_chunk[1],
-            )
-        )
-        if num_workers > 1:
-            pool.join()
-            for worker_id in range(1, num_workers):
-                prefix = "{}{}".format(output_prefix, worker_id)
-                temp_file_path = dataset_dest_prefix(args, prefix, None)
-                ds.merge_file_(temp_file_path)
-                os.remove(indexed_dataset.data_file_path(temp_file_path))
-                os.remove(indexed_dataset.index_file_path(temp_file_path))
-
-        ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))
-
-        logger.info("[alignments] {}: parsed {} alignments".format(input_file, nseq[0]))
-
-    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
-        if args.dataset_impl == "raw":
-            # Copy original text file to destination folder
-            output_text_file = dest_path(
-                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
-                lang,
-            )
-            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
-        else:
-            make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)
-
-    def make_all(lang, vocab):
-        if args.trainpref:
-            make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers)
-        if args.validpref:
-            for k, validpref in enumerate(args.validpref.split(",")):
-                outprefix = "valid{}".format(k) if k > 0 else "valid"
-                make_dataset(
-                    vocab, validpref, outprefix, lang, num_workers=args.workers
-                )
-        if args.testpref:
-            for k, testpref in enumerate(args.testpref.split(",")):
-                outprefix = "test{}".format(k) if k > 0 else "test"
-                make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers)
-
-    def make_all_alignments():
-        if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix):
-            make_binary_alignment_dataset(
-                args.trainpref + "." + args.align_suffix,
-                "train.align",
-                num_workers=args.workers,
-            )
-        if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix):
-            make_binary_alignment_dataset(
-                args.validpref + "." + args.align_suffix,
-                "valid.align",
-                num_workers=args.workers,
-            )
-        if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix):
-            make_binary_alignment_dataset(
-                args.testpref + "." + args.align_suffix,
-                "test.align",
-                num_workers=args.workers,
-            )
-
-    make_all(args.source_lang, src_dict)
+    _make_all(args.source_lang, src_dict, args)
     if target:
-        make_all(args.target_lang, tgt_dict)
+        _make_all(args.target_lang, tgt_dict, args)
 
+    # align the datasets if needed
     if args.align_suffix:
-        make_all_alignments()
+        _make_all_alignments(args)
 
     logger.info("Wrote preprocessed data to {}".format(args.destdir))
 
     if args.alignfile:
-        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
-        src_file_name = train_path(args.source_lang)
-        tgt_file_name = train_path(args.target_lang)
-        freq_map = {}
-        with open(args.alignfile, "r", encoding="utf-8") as align_file:
-            with open(src_file_name, "r", encoding="utf-8") as src_file:
-                with open(tgt_file_name, "r", encoding="utf-8") as tgt_file:
-                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
-                        si = src_dict.encode_line(s, add_if_not_exist=False)
-                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
-                        ai = list(map(lambda x: tuple(x.split("-")), a.split()))
-                        for sai, tai in ai:
-                            srcidx = si[int(sai)]
-                            tgtidx = ti[int(tai)]
-                            if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
-                                assert srcidx != src_dict.pad()
-                                assert srcidx != src_dict.eos()
-                                assert tgtidx != tgt_dict.pad()
-                                assert tgtidx != tgt_dict.eos()
-
-                                if srcidx not in freq_map:
-                                    freq_map[srcidx] = {}
-                                if tgtidx not in freq_map[srcidx]:
-                                    freq_map[srcidx][tgtidx] = 1
-                                else:
-                                    freq_map[srcidx][tgtidx] += 1
-
-        align_dict = {}
-        for srcidx in freq_map.keys():
-            align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)
-
-        with open(
-            os.path.join(
-                args.destdir,
-                "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
-            ),
-            "w",
-            encoding="utf-8",
-        ) as f:
-            for k, v in align_dict.items():
-                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
-
-
-def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True):
-    ds = indexed_dataset.make_builder(
-        dataset_dest_file(args, output_prefix, lang, "bin"),
-        impl=args.dataset_impl,
-        vocab_size=len(vocab),
-    )
-
-    def consumer(tensor):
-        ds.add_item(tensor)
-
-    res = Binarizer.binarize(
-        filename, vocab, consumer, append_eos=append_eos, offset=offset, end=end
-    )
-    ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
-    return res
-
-
-def binarize_alignments(args, filename, parse_alignment, output_prefix, offset, end):
-    ds = indexed_dataset.make_builder(
-        dataset_dest_file(args, output_prefix, None, "bin"),
-        impl=args.dataset_impl,
-        vocab_size=None,
-    )
-
-    def consumer(tensor):
-        ds.add_item(tensor)
-
-    res = Binarizer.binarize_alignments(
-        filename, parse_alignment, consumer, offset=offset, end=end
-    )
-    ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))
-    return res
-
-
-def dataset_dest_prefix(args, output_prefix, lang):
-    base = "{}/{}".format(args.destdir, output_prefix)
-    if lang is not None:
-        lang_part = ".{}-{}.{}".format(args.source_lang, args.target_lang, lang)
-    elif args.only_source:
-        lang_part = ""
-    else:
-        lang_part = ".{}-{}".format(args.source_lang, args.target_lang)
-
-    return "{}{}".format(base, lang_part)
-
-
-def dataset_dest_file(args, output_prefix, lang, extension):
-    base = dataset_dest_prefix(args, output_prefix, lang)
-    return "{}.{}".format(base, extension)
+        _align_files(args, src_dict=src_dict, tgt_dict=tgt_dict)
 
 
 def cli_main():
```
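To make the naming logic in `dataset_dest_prefix` above concrete, here is a small sketch (made-up `args`, POSIX paths assumed):

```python
# Sketch of the output naming implemented by dataset_dest_prefix above.
from argparse import Namespace

args = Namespace(
    destdir="data-bin", source_lang="de", target_lang="en", only_source=False
)
# bilingual dataset: data-bin/train.de-en.de
assert dataset_dest_prefix(args, "train", "de") == "data-bin/train.de-en.de"

# --only-source with no lang, as in the testing commands above: data-bin/train
args.only_source = True
assert dataset_dest_prefix(args, "train", None) == "data-bin/train"
```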
tests/test_binarizer.py (new file, 122 lines)

```diff
@@ -0,0 +1,122 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import os
+import typing as tp
+import unittest
+from tempfile import TemporaryDirectory
+
+from fairseq.binarizer import BinarizeSummary, FileBinarizer, VocabularyDatasetBinarizer
+from fairseq.data import Dictionary, indexed_dataset
+from tests.utils import make_data, sizes
+
+
+def build_vocab(data: tp.List[tp.List[str]]) -> Dictionary:
+    d = Dictionary()
+    for s in data:
+        for token in s:
+            d.add_symbol(token)
+    d.finalize()
+    return d
+
+
+class TestBinarizer(unittest.TestCase):
+    def compare_ds_data(self, summary, data, prefix, impl, vocab):
+        self.assertEqual(summary.num_seq, len(data))
+        self.assertEqual(summary.num_tok, sum([len(s) for s in data]))
+
+        dataset = indexed_dataset.make_dataset(prefix, impl)
+
+        self.assertEqual(len(dataset), len(data))
+        decoded = [vocab.string(dataset[i]).split() for i in range(0, len(dataset))]
+
+        self.assertEqual(decoded, data)
+        data_sizes = [i.item() for i in dataset.sizes]
+        self.assertEqual(data_sizes, sizes(data))
+
+    def test_can_binarize_line(self):
+        data = make_data(length=1)
+        vocab = build_vocab(data)
+
+        binarizer = VocabularyDatasetBinarizer(
+            vocab,
+        )
+
+        sentence = data[0]
+        summary = BinarizeSummary()
+
+        tensor = binarizer.binarize_line(
+            " ".join(sentence),
+            summary,
+        )
+
+        self.assertEqual(len(tensor), len(sentence) + 1)
+
+        self.assertEqual(summary.num_tok, len(sentence) + 1)
+        self.assertEqual(summary.num_seq, 1)
+
+    def test_can_binarize_file_chunk(self):
+        # test without multiprocess logic
+        with TemporaryDirectory() as dirname:
+            raw_file = os.path.join(dirname, "raw1")
+            prefix = os.path.join(dirname, "test1")
+            impl = "mmap"
+
+            data = make_data(out_file=raw_file)
+            vocab = build_vocab(data)
+
+            binarizer = VocabularyDatasetBinarizer(
+                vocab,
+                append_eos=False,
+            )
+
+            summary = FileBinarizer._binarize_chunk_and_finalize(
+                binarizer,
+                raw_file,
+                offset_start=0,
+                offset_end=-1,
+                output_prefix=prefix,
+                dataset_impl=impl,
+                vocab_size=len(vocab),
+            )
+
+            self.compare_ds_data(summary, data, prefix, impl, vocab)
+
+    def test_can_multiprocess(self):
+        with TemporaryDirectory() as dirname:
+            raw_file = os.path.join(dirname, "raw1")
+            prefix = os.path.join(dirname, "test1")
+            impl = "mmap"
+            data = make_data(out_file=raw_file)
+            vocab = build_vocab(data)
+            binarizer = VocabularyDatasetBinarizer(
+                vocab,
+                append_eos=False,
+            )
+            # with one worker
+            summary = FileBinarizer.multiprocess_dataset(
+                raw_file,
+                impl,
+                binarizer,
+                output_prefix=prefix,
+                vocab_size=len(vocab),
+                num_workers=1,
+            )
+
+            self.compare_ds_data(summary, data, prefix, impl, vocab)
+
+            # with multiple worker
+            prefix_multi = os.path.join(dirname, "test2")
+            summary = FileBinarizer.multiprocess_dataset(
+                raw_file,
+                impl,
+                binarizer,
+                output_prefix=prefix_multi,
+                vocab_size=len(vocab),
+                num_workers=3,
+            )
+
+            self.compare_ds_data(summary, data, prefix_multi, impl, vocab)
```
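The new test module is self-contained, so it should be runnable on its own with the standard unittest runner from the repository root (assumption, not stated in the commit):

```
python -m unittest tests.test_binarizer -v
```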
tests/test_huffman.py

```diff
@@ -4,8 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
-import random
-import string
 import typing as tp
 import unittest
 from collections import Counter
@@ -18,23 +16,7 @@ from fairseq.data.huffman import (
     HuffmanMMapIndexedDataset,
     HuffmanMMapIndexedDatasetBuilder,
 )
-
-POPULATION = string.ascii_letters + string.digits
-
-
-def make_sentence() -> tp.List[str]:
-    length = random.randint(10, 50)
-    return random.choices(
-        population=POPULATION, k=length, weights=range(1, len(POPULATION) + 1)
-    )
-
-
-def make_data(length=1000) -> tp.List[tp.List[str]]:
-    return (
-        [make_sentence() for _ in range(0, length)]
-        # add all the symbols at least once
-        + [list(string.ascii_letters), list(string.digits)]
-    )
+from tests.utils import POPULATION, make_data, sizes
 
 
 def make_counts(data: tp.List[tp.List[str]]) -> Counter:
@@ -112,10 +94,6 @@ def build_dataset(prefix, data, coder):
         builder.add_item(sentence)
 
 
-def sizes(data):
-    return [len(sentence) for sentence in data]
-
-
 class TestHuffmanDataset(unittest.TestCase):
     def test_huffman_can_encode_decode(self):
         data = make_data()
```
tests/utils.py

```diff
@@ -8,7 +8,9 @@ import json
 import os
 import random
 import shutil
+import string
 import sys
+import typing as tp
 from io import StringIO
 
 import torch
@@ -756,3 +758,31 @@ def train_language_model(
         + (extra_valid_flags or []),
     )
     validate.main(validate_args)
+
+
+def sizes(data):
+    return [len(sentence) for sentence in data]
+
+
+POPULATION = string.ascii_letters + string.digits
+
+
+def make_sentence() -> tp.List[str]:
+    length = random.randint(10, 50)
+    return random.choices(
+        population=POPULATION, k=length, weights=range(1, len(POPULATION) + 1)
+    )
+
+
+def make_data(length=1000, out_file=None) -> tp.List[tp.List[str]]:
+    data = (
+        [make_sentence() for _ in range(0, length)]
+        # add all the symbols at least once
+        + [list(string.ascii_letters), list(string.digits)]
+    )
+    if out_file is not None:
+        with open(out_file, "w", encoding="utf-8") as out:
+            for s in data:
+                print(" ".join(s), file=out)
+
+    return data
```
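The `out_file` parameter added to `make_data` is what lets the binarizer tests above write a raw text file and keep the in-memory reference in one call. A quick round-trip sketch (hypothetical /tmp path):

```python
# Round trip through the shared test helpers (hypothetical path).
from tests.utils import make_data, sizes

data = make_data(length=10, out_file="/tmp/raw.txt")  # also writes one sentence per line
with open("/tmp/raw.txt", encoding="utf-8") as f:
    lines = [line.split() for line in f]
assert lines == data
assert sizes(data) == [len(sentence) for sentence in data]
```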