Preprocess Split (#2738)

Summary:
This is the equivalent to PR https://github.com/fairinternal/fairseq-py/issues/2697 but on top of main instead of gshard (cherry-picked and merged the squash):

* reorganize preprocess.py code a bit
* use Binarizers objects in the multiprocess code
* clean up the make_binary
* multiprocess logic
* learn to count
* format and doc string
* add basic test for vocab binarizer
* generalize to one line
* move multiprocess in binarizer

Testing:
```
python -m fairseq_cli.preprocess --only-source --trainpref ~/fixathon/small_vocab_test/train.in --destdir ~/fixathon/small_vocab_test/data-bin.cherry --workers 20
python -m fairseq_cli.preprocess --only-source --trainpref ~/fixathon/small_vocab_test/train.in --destdir ~/fixathon/small_vocab_test/data-bin.main --workers 20
```

```
 md5sum ~/fixathon/small_vocab_test/data-bin.cherry/train.bin == md5sum ~/fixathon/small_vocab_test/data-bin.main/train.bin
```

```
diff ~/fixathon/small_vocab_test/data-bin.main/dict.txt ~/fixathon/small_vocab_test/data-bin.cherry/dict.txt
```

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2738

Reviewed By: sshleifer, dianaml0

Differential Revision: D32830875

Pulled By: Mortimerp9

fbshipit-source-id: e7463d5cdd96a877691bf39666daa319ebb3dcb8
This commit is contained in:
Pierre Andrews 2022-01-11 11:55:43 -08:00 committed by Facebook GitHub Bot
parent b3fa5100c6
commit 279796224f
5 changed files with 803 additions and 390 deletions

View File

@ -3,78 +3,379 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import typing as tp
from abc import ABC, abstractmethod
from collections import Counter
from typing import Dict
from dataclasses import dataclass
from multiprocessing import Pool
import torch
from fairseq.file_chunker_utils import Chunker
from fairseq.data import Dictionary, indexed_dataset
from fairseq.file_chunker_utils import Chunker, find_offsets
from fairseq.file_io import PathManager
from fairseq.tokenizer import tokenize_line
logger = logging.getLogger("binarizer")
@dataclass
class BinarizeSummary:
    """
    Keep track of what's going on in the binarizer.

    Counts the sequences and tokens produced so far and, optionally, which
    words were replaced (``replaced`` stays ``None`` for binarizers that do
    not track replacements).
    """

    num_seq: int = 0
    replaced: tp.Optional[Counter] = None
    num_tok: int = 0

    @property
    def num_replaced(self) -> int:
        """Total number of replaced tokens (0 when replacements are not tracked)."""
        if self.replaced is None:
            return 0
        return sum(self.replaced.values())

    @property
    def replaced_percent(self) -> float:
        """Percentage of tokens that were replaced; 0.0 when no token was seen."""
        # an empty input produces zero tokens: report 0% instead of raising
        # ZeroDivisionError (which would also break __str__ / logging)
        if self.num_tok == 0:
            return 0.0
        return 100 * self.num_replaced / self.num_tok

    def __str__(self) -> str:
        base = f"{self.num_seq} sents, {self.num_tok} tokens"
        if self.replaced is None:
            return base

        return f"{base}, {self.replaced_percent:.3}% replaced"

    def merge(self, other: "BinarizeSummary") -> None:
        """
        Accumulate the counts of ``other`` into this summary (in place).

        ``other`` is left untouched.
        """
        if other.replaced is not None:
            if self.replaced is None:
                # copy instead of aliasing: later merges into self must not
                # mutate the counter owned by ``other``
                self.replaced = Counter(other.replaced)
            else:
                self.replaced += other.replaced
        self.num_seq += other.num_seq
        self.num_tok += other.num_tok
class Binarizer(ABC):
    """
    Abstract interface: a binarizer describes how to take a string and
    build a tensor out of it.
    """

    @abstractmethod
    def binarize_line(
        self, line: str, summary: BinarizeSummary
    ) -> torch.IntTensor:
        """
        Turn one ``line`` into a tensor; implementations record their
        progress in ``summary``.
        """
        ...
def _worker_prefix(output_prefix: str, worker_id: int):
return f"{output_prefix}.pt{worker_id}"
class FileBinarizer:
    """
    A file binarizer can take a file, tokenize it, and binarize each line to a tensor
    """

    @classmethod
    def multiprocess_dataset(
        cls,
        input_file: str,
        dataset_impl: str,
        binarizer: Binarizer,
        output_prefix: str,
        vocab_size=None,
        num_workers=1,
    ) -> BinarizeSummary:
        """
        Binarize ``input_file`` into the dataset at ``output_prefix``.

        The file is cut in chunks with ``find_offsets``. The first chunk is
        binarized by the calling process; when ``num_workers > 1`` the other
        chunks are binarized concurrently by a pool of ``num_workers - 1``
        processes, each writing to its own ``_worker_prefix`` files which are
        then merged into the final dataset and removed.

        Returns the merged BinarizeSummary of all the chunks.
        """
        final_summary = BinarizeSummary()

        offsets = find_offsets(input_file, num_workers)
        # find_offsets returns a list of position [pos1, pos2, pos3, pos4] but we would want pairs:
        # [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the chunks with start/end info
        # we zip the list with itself shifted by one to get all the pairs.
        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])

        pool = None
        worker_results = []
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            worker_results = [
                pool.apply_async(
                    cls._binarize_chunk_and_finalize,
                    args=(
                        binarizer,
                        input_file,
                        start_offset,
                        end_offset,
                        _worker_prefix(output_prefix, worker_id),
                        dataset_impl,
                    ),
                    kwds={"vocab_size": vocab_size},
                )
                for worker_id, (start_offset, end_offset) in enumerate(
                    more_chunks, start=1
                )
            ]
            # nothing else will be submitted to the pool
            pool.close()

        # binarize the first chunk here while the workers handle theirs,
        # instead of idling until the pool is done.
        # do not close the bin file as we need to merge the worker results in
        final_ds, first_summary = cls._binarize_file_chunk(
            binarizer,
            input_file,
            offset_start=first_chunk[0],
            offset_end=first_chunk[1],
            output_prefix=output_prefix,
            dataset_impl=dataset_impl,
            vocab_size=vocab_size,
        )

        if pool is not None:
            pool.join()
        for result in worker_results:
            # get() re-raises any exception that happened in the worker
            final_summary.merge(result.get())
        final_summary.merge(first_summary)

        # merge exactly the outputs of the tasks that were submitted
        for worker_id in range(1, len(worker_results) + 1):
            worker_output_prefix = _worker_prefix(output_prefix, worker_id)
            final_ds.merge_file_(worker_output_prefix)
            try:
                os.remove(indexed_dataset.data_file_path(worker_output_prefix))
                os.remove(indexed_dataset.index_file_path(worker_output_prefix))
            except Exception as e:
                # best effort cleanup: the merged dataset is still valid
                logger.error(
                    f"couldn't remove {worker_output_prefix}.*", exc_info=e
                )

        # now we can close the file
        idx_file = indexed_dataset.index_file_path(output_prefix)
        final_ds.finalize(idx_file)
        return final_summary

    @staticmethod
    def _binarize_file_chunk(
        binarizer: Binarizer,
        filename: str,
        offset_start: int,
        offset_end: int,
        output_prefix: str,
        dataset_impl: str,
        vocab_size=None,
    ) -> tp.Tuple[tp.Any, BinarizeSummary]:  # (dataset builder, BinarizeSummary)
        """
        creates a dataset builder and append binarized items to it. This function does not
        finalize the builder, this is useful if you want to do other things with your bin file
        like appending/merging other files
        """
        bin_file = indexed_dataset.data_file_path(output_prefix)
        ds = indexed_dataset.make_builder(
            bin_file,
            impl=dataset_impl,
            vocab_size=vocab_size,
        )
        summary = BinarizeSummary()

        with Chunker(
            PathManager.get_local_path(filename), offset_start, offset_end
        ) as line_iterator:
            for line in line_iterator:
                ds.add_item(binarizer.binarize_line(line, summary))

        return ds, summary

    @classmethod
    def _binarize_chunk_and_finalize(
        cls,
        binarizer: Binarizer,
        filename: str,
        offset_start: int,
        offset_end: int,
        output_prefix: str,
        dataset_impl: str,
        vocab_size=None,
    ):
        """
        same as _binarize_file_chunk, but also finalizes the builder and
        only returns the BinarizeSummary
        """
        ds, summ = cls._binarize_file_chunk(
            binarizer,
            filename,
            offset_start,
            offset_end,
            output_prefix,
            dataset_impl,
            vocab_size=vocab_size,
        )

        idx_file = indexed_dataset.index_file_path(output_prefix)
        ds.finalize(idx_file)

        return summ
class VocabularyDatasetBinarizer(Binarizer):
    """
    Takes a Dictionary/Vocabulary, assign ids to each
    token using the dictionary encode_line function.
    """

    def __init__(
        self,
        dict: Dictionary,
        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
        append_eos: bool = True,
        reverse_order: bool = False,
        already_numberized: bool = False,
    ) -> None:
        self.dict = dict
        self.tokenize = tokenize
        self.append_eos = append_eos
        self.reverse_order = reverse_order
        self.already_numberized = already_numberized
        super().__init__()

    def binarize_line(
        self,
        line: str,
        summary: BinarizeSummary,
    ):
        """
        Convert ``line`` to a tensor of ids and update ``summary``.

        When ``already_numberized`` is set, the line is read as whitespace
        separated integers; otherwise it goes through ``dict.encode_line``
        and every word mapped to the unknown index is counted in
        ``summary.replaced``.
        """
        if summary.replaced is None:
            summary.replaced = Counter()

        def replaced_consumer(word, idx):
            # the unk word itself is not a replacement
            if idx == self.dict.unk_index and word != self.dict.unk_word:
                summary.replaced.update([word])

        if self.already_numberized:
            id_strings = line.strip().split()
            id_list = [int(id_string) for id_string in id_strings]
            if self.reverse_order:
                id_list.reverse()
            if self.append_eos:
                id_list.append(self.dict.eos())
            ids = torch.IntTensor(id_list)
        else:
            ids = self.dict.encode_line(
                line=line,
                line_tokenizer=self.tokenize,
                add_if_not_exist=False,
                consumer=replaced_consumer,
                append_eos=self.append_eos,
                reverse_order=self.reverse_order,
            )

        summary.num_seq += 1
        summary.num_tok += len(ids)
        return ids
class AlignmentDatasetBinarizer(Binarizer):
    """
    Binarizer for alignment files: each line is handed to the
    ``alignment_parser`` callable given at construction and the parsed
    result is returned as is (see utils.parse_alignment).
    """

    def __init__(
        self, alignment_parser: tp.Callable[[str], torch.IntTensor]
    ) -> None:
        super().__init__()
        self.alignment_parser = alignment_parser

    def binarize_line(self, line: str, summary: BinarizeSummary):
        """Parse one alignment line and count it (and its length) in ``summary``."""
        parsed = self.alignment_parser(line)
        summary.num_seq += 1
        summary.num_tok += len(parsed)
        return parsed
class LegacyBinarizer:
    """
    Consumer-callback based entry points, implemented on top of the
    Binarizer objects; results are reported as a plain dict.
    """

    @classmethod
    def binarize(
        cls,
        filename: str,
        dico: Dictionary,
        consumer: tp.Callable[[torch.IntTensor], None],
        tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line,
        append_eos: bool = True,
        reverse_order: bool = False,
        offset: int = 0,
        end: int = -1,
        already_numberized: bool = False,
    ) -> tp.Dict[str, int]:
        """
        Binarize the lines of ``filename`` between ``offset`` and ``end``
        with a VocabularyDatasetBinarizer, passing each tensor to ``consumer``.
        """
        binarizer = VocabularyDatasetBinarizer(
            dict=dico,
            tokenize=tokenize,
            append_eos=append_eos,
            reverse_order=reverse_order,
            already_numberized=already_numberized,
        )
        return cls._consume_file(
            filename,
            binarizer,
            consumer,
            offset_start=offset,
            offset_end=end,
        )

    @classmethod
    def binarize_alignments(
        cls,
        filename: str,
        alignment_parser: tp.Callable[[str], torch.IntTensor],
        consumer: tp.Callable[[torch.IntTensor], None],
        offset: int = 0,
        end: int = -1,
    ) -> tp.Dict[str, int]:
        """
        Parse the lines of ``filename`` between ``offset`` and ``end`` with
        ``alignment_parser``, passing each result to ``consumer``.
        """
        binarizer = AlignmentDatasetBinarizer(alignment_parser)
        return cls._consume_file(
            filename,
            binarizer,
            consumer,
            offset_start=offset,
            offset_end=end,
        )

    @staticmethod
    def _consume_file(
        filename: str,
        binarizer: Binarizer,
        consumer: tp.Callable[[torch.IntTensor], None],
        offset_start: int,
        offset_end: int,
    ) -> tp.Dict[str, int]:
        """
        Feed every binarized line of the chunk to ``consumer`` and return the
        summary as a dict with keys "nseq", "nunk", "ntok" and "replaced".
        """
        summary = BinarizeSummary()

        with Chunker(
            PathManager.get_local_path(filename), offset_start, offset_end
        ) as line_iterator:
            for line in line_iterator:
                consumer(binarizer.binarize_line(line, summary))

        return {
            "nseq": summary.num_seq,
            "nunk": summary.num_replaced,
            "ntok": summary.num_tok,
            "replaced": summary.replaced,
        }

View File

@ -11,14 +11,17 @@ import logging
import os
import shutil
import sys
from collections import Counter
import typing as tp
from argparse import Namespace
from itertools import zip_longest
from multiprocessing import Pool
from fairseq import options, tasks, utils
from fairseq.binarizer import Binarizer
from fairseq.data import indexed_dataset
from fairseq.file_chunker_utils import find_offsets
from fairseq.binarizer import (
AlignmentDatasetBinarizer,
FileBinarizer,
VocabularyDatasetBinarizer,
)
from fairseq.data import Dictionary
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
@ -28,8 +31,251 @@ logging.basicConfig(
)
logger = logging.getLogger("fairseq_cli.preprocess")
#####################################################################
# file name tools
#####################################################################
def _train_path(lang, trainpref):
return "{}{}".format(trainpref, ("." + lang) if lang else "")
def _file_name(prefix, lang):
fname = prefix
if lang is not None:
fname += ".{lang}".format(lang=lang)
return fname
def _dest_path(prefix, lang, destdir):
    """Return the path of the ``prefix``/``lang`` file name inside ``destdir``."""
    fname = _file_name(prefix, lang)
    return os.path.join(destdir, fname)
def _dict_path(lang, destdir):
    """Return the path of the ``dict[.lang].txt`` dictionary file inside ``destdir``."""
    stem = _dest_path("dict", lang, destdir)
    return stem + ".txt"
def dataset_dest_prefix(args, output_prefix, lang):
    """
    Build the destination prefix (no extension) of a dataset in ``args.destdir``.

    The ``.src-tgt`` language pair is appended unless ``lang`` is None and
    ``args.only_source`` is set; ``.lang`` is appended when ``lang`` is given.
    """
    base = os.path.join(args.destdir, output_prefix)
    if lang is None and args.only_source:
        return base
    pair = f".{args.source_lang}-{args.target_lang}"
    if lang is None:
        return base + pair
    return base + pair + f".{lang}"
def dataset_dest_file(args, output_prefix, lang, extension):
    """Return the dataset destination prefix followed by ``.extension``."""
    prefix = dataset_dest_prefix(args, output_prefix, lang)
    return "{}.{}".format(prefix, extension)
#####################################################################
# dictionary tools
#####################################################################
def _build_dictionary(
filenames,
task,
args,
src=False,
tgt=False,
):
assert src ^ tgt
return task.build_dictionary(
filenames,
workers=args.workers,
threshold=args.thresholdsrc if src else args.thresholdtgt,
nwords=args.nwordssrc if src else args.nwordstgt,
padding_factor=args.padding_factor,
)
#####################################################################
# bin file creation logic
#####################################################################
def _make_binary_dataset(
    vocab: Dictionary,
    input_prefix: str,
    output_prefix: str,
    lang: tp.Optional[str],
    num_workers: int,
    args: Namespace,
):
    """
    Binarize ``input_prefix[.lang]`` with ``vocab`` (appending eos) into the
    dataset destination derived from ``args``, then log the summary.
    """
    logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))

    lang_suffix = ("." + lang) if lang is not None else ""
    input_file = "{}{}".format(input_prefix, lang_suffix)

    final_summary = FileBinarizer.multiprocess_dataset(
        input_file,
        args.dataset_impl,
        VocabularyDatasetBinarizer(vocab, append_eos=True),
        dataset_dest_prefix(args, output_prefix, lang),
        vocab_size=len(vocab),
        num_workers=num_workers,
    )

    logger.info(f"[{lang}] {input_file}: {final_summary} (by {vocab.unk_word})")
def _make_binary_alignment_dataset(
    input_prefix: str, output_prefix: str, num_workers: int, args: Namespace
):
    """
    Binarize the alignment file ``input_prefix`` (parsed with
    utils.parse_alignment) into the dataset destination derived from
    ``args``, then log how many alignments were parsed.
    """
    final_summary = FileBinarizer.multiprocess_dataset(
        input_prefix,
        args.dataset_impl,
        AlignmentDatasetBinarizer(utils.parse_alignment),
        dataset_dest_prefix(args, output_prefix, lang=None),
        vocab_size=None,
        num_workers=num_workers,
    )

    message = "[alignments] {}: parsed {} alignments".format(
        input_prefix, final_summary.num_seq
    )
    logger.info(message)
#####################################################################
# routing logic
#####################################################################
def _make_dataset(
    vocab: Dictionary,
    input_prefix: str,
    output_prefix: str,
    lang: tp.Optional[str],
    args: Namespace,
    num_workers: int,
):
    """
    Produce one dataset: binarize it, or, for the "raw" dataset
    implementation, copy the original text file to the destination folder.
    """
    if args.dataset_impl != "raw":
        _make_binary_dataset(
            vocab, input_prefix, output_prefix, lang, num_workers, args
        )
        return

    # Copy original text file to destination folder
    pair_prefix = output_prefix + ".{}-{}".format(args.source_lang, args.target_lang)
    output_text_file = _dest_path(pair_prefix, lang, args.destdir)
    shutil.copyfile(_file_name(input_prefix, lang), output_text_file)
def _make_all(lang, vocab, args):
    """
    Build the train, valid and test datasets of ``lang`` that are configured
    in ``args``.

    ``args.validpref`` and ``args.testpref`` may hold several comma separated
    prefixes; outputs after the first one are numbered (valid, valid1, ...).
    """

    def build(input_prefix, output_prefix):
        _make_dataset(
            vocab, input_prefix, output_prefix, lang, args=args, num_workers=args.workers
        )

    if args.trainpref:
        build(args.trainpref, "train")

    for split, prefixes in (("valid", args.validpref), ("test", args.testpref)):
        if not prefixes:
            continue
        for k, pref in enumerate(prefixes.split(",")):
            build(pref, "{}{}".format(split, k) if k > 0 else split)
def _make_all_alignments(args):
    """
    Binarize the train/valid/test alignment files (``<pref>.<align_suffix>``)
    for every configured prefix whose alignment file exists.
    """
    splits = (
        (args.trainpref, "train.align"),
        (args.validpref, "valid.align"),
        (args.testpref, "test.align"),
    )
    for pref, output_name in splits:
        if not pref:
            continue
        align_path = pref + "." + args.align_suffix
        if os.path.exists(align_path):
            _make_binary_alignment_dataset(
                align_path,
                output_name,
                num_workers=args.workers,
                args=args,
            )
#####################################################################
# align
#####################################################################
def _align_files(args, src_dict, tgt_dict):
    """
    Write ``alignment.<src>-<tgt>.txt`` in ``args.destdir``.

    Reads ``args.alignfile`` together with the source and target training
    files, counts for every aligned (source id, target id) pair how often it
    occurs (pairs involving an unknown word are ignored), and writes for each
    source word the target word it is most frequently aligned to.
    """
    assert args.trainpref, "--trainpref must be set if --alignfile is specified"
    src_file_name = _train_path(args.source_lang, args.trainpref)
    tgt_file_name = _train_path(args.target_lang, args.trainpref)

    # source id -> {target id -> number of times they were aligned}
    freq_map = {}
    with open(args.alignfile, "r", encoding="utf-8") as align_file, open(
        src_file_name, "r", encoding="utf-8"
    ) as src_file, open(tgt_file_name, "r", encoding="utf-8") as tgt_file:
        for a, s, t in zip_longest(align_file, src_file, tgt_file):
            si = src_dict.encode_line(s, add_if_not_exist=False)
            ti = tgt_dict.encode_line(t, add_if_not_exist=False)
            # each alignment token looks like "<src position>-<tgt position>"
            pairs = [tuple(token.split("-")) for token in a.split()]
            for sai, tai in pairs:
                srcidx = si[int(sai)]
                tgtidx = ti[int(tai)]
                if srcidx == src_dict.unk() or tgtidx == tgt_dict.unk():
                    continue
                assert srcidx != src_dict.pad()
                assert srcidx != src_dict.eos()
                assert tgtidx != tgt_dict.pad()
                assert tgtidx != tgt_dict.eos()
                if srcidx not in freq_map:
                    freq_map[srcidx] = {}
                targets = freq_map[srcidx]
                if tgtidx in targets:
                    targets[tgtidx] += 1
                else:
                    targets[tgtidx] = 1

    align_dict = {
        srcidx: max(targets, key=targets.get) for srcidx, targets in freq_map.items()
    }

    out_path = os.path.join(
        args.destdir,
        "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
    )
    with open(out_path, "w", encoding="utf-8") as f:
        for k, v in align_dict.items():
            print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
#####################################################################
# MAIN
#####################################################################
def main(args):
# setup some basic things
utils.import_user_module(args)
os.makedirs(args.destdir, exist_ok=True)
@ -45,39 +291,21 @@ def main(args):
args.dataset_impl != "huffman"
), "preprocessing.py doesn't support Huffman yet, use HuffmanCodeBuilder directly."
task = tasks.get_task(args.task)
def train_path(lang):
return "{}{}".format(args.trainpref, ("." + lang) if lang else "")
def file_name(prefix, lang):
fname = prefix
if lang is not None:
fname += ".{lang}".format(lang=lang)
return fname
def dest_path(prefix, lang):
return os.path.join(args.destdir, file_name(prefix, lang))
def dict_path(lang):
return dest_path("dict", lang) + ".txt"
def build_dictionary(filenames, src=False, tgt=False):
assert src ^ tgt
return task.build_dictionary(
filenames,
workers=args.workers,
threshold=args.thresholdsrc if src else args.thresholdtgt,
nwords=args.nwordssrc if src else args.nwordstgt,
padding_factor=args.padding_factor,
)
# build dictionaries
target = not args.only_source
if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
raise FileExistsError(dict_path(args.source_lang))
if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
raise FileExistsError(dict_path(args.target_lang))
if not args.srcdict and os.path.exists(_dict_path(args.source_lang, args.destdir)):
raise FileExistsError(_dict_path(args.source_lang, args.destdir))
if (
target
and not args.tgtdict
and os.path.exists(_dict_path(args.target_lang, args.destdir))
):
raise FileExistsError(_dict_path(args.target_lang, args.destdir))
task = tasks.get_task(args.task)
if args.joined_dictionary:
assert (
@ -92,8 +320,13 @@ def main(args):
assert (
args.trainpref
), "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary(
{train_path(lang) for lang in [args.source_lang, args.target_lang]},
src_dict = _build_dictionary(
{
_train_path(lang, args.trainpref)
for lang in [args.source_lang, args.target_lang]
},
task=task,
args=args,
src=True,
)
tgt_dict = src_dict
@ -104,7 +337,12 @@ def main(args):
assert (
args.trainpref
), "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary([train_path(args.source_lang)], src=True)
src_dict = _build_dictionary(
[_train_path(args.source_lang, args.trainpref)],
task=task,
args=args,
src=True,
)
if target:
if args.tgtdict:
@ -113,292 +351,36 @@ def main(args):
assert (
args.trainpref
), "--trainpref must be set if --tgtdict is not specified"
tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
tgt_dict = _build_dictionary(
[_train_path(args.target_lang, args.trainpref)],
task=task,
args=args,
tgt=True,
)
else:
tgt_dict = None
src_dict.save(dict_path(args.source_lang))
# save dictionaries
src_dict.save(_dict_path(args.source_lang, args.destdir))
if target and tgt_dict is not None:
tgt_dict.save(dict_path(args.target_lang))
tgt_dict.save(_dict_path(args.target_lang, args.destdir))
if args.dict_only:
return
def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))
n_seq_tok = [0, 0]
replaced = Counter()
def merge_result(worker_result):
replaced.update(worker_result["replaced"])
n_seq_tok[0] += worker_result["nseq"]
n_seq_tok[1] += worker_result["ntok"]
input_file = "{}{}".format(
input_prefix, ("." + lang) if lang is not None else ""
)
offsets = find_offsets(input_file, num_workers)
(first_chunk, *more_chunks) = zip(offsets, offsets[1:])
pool = None
if num_workers > 1:
pool = Pool(processes=num_workers - 1)
for worker_id, (start_offset, end_offset) in enumerate(
more_chunks, start=1
):
prefix = "{}{}".format(output_prefix, worker_id)
pool.apply_async(
binarize,
(
args,
input_file,
vocab,
prefix,
lang,
start_offset,
end_offset,
),
callback=merge_result,
)
pool.close()
ds = indexed_dataset.make_builder(
dataset_dest_file(args, output_prefix, lang, "bin"),
impl=args.dataset_impl,
vocab_size=len(vocab),
)
merge_result(
Binarizer.binarize(
input_file,
vocab,
lambda t: ds.add_item(t),
offset=first_chunk[0],
end=first_chunk[1],
)
)
if num_workers > 1:
pool.join()
for worker_id in range(1, num_workers):
prefix = "{}{}".format(output_prefix, worker_id)
temp_file_path = dataset_dest_prefix(args, prefix, lang)
ds.merge_file_(temp_file_path)
os.remove(indexed_dataset.data_file_path(temp_file_path))
os.remove(indexed_dataset.index_file_path(temp_file_path))
ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
logger.info(
"[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
lang,
input_file,
n_seq_tok[0],
n_seq_tok[1],
100 * sum(replaced.values()) / n_seq_tok[1],
vocab.unk_word,
)
)
def make_binary_alignment_dataset(input_prefix, output_prefix, num_workers):
nseq = [0]
def merge_result(worker_result):
nseq[0] += worker_result["nseq"]
input_file = input_prefix
offsets = find_offsets(input_file, num_workers)
(first_chunk, *more_chunks) = zip(offsets, offsets[1:])
pool = None
if num_workers > 1:
pool = Pool(processes=num_workers - 1)
for worker_id, (start_offset, end_offset) in enumerate(
more_chunks, start=1
):
prefix = "{}{}".format(output_prefix, worker_id)
pool.apply_async(
binarize_alignments,
(
args,
input_file,
utils.parse_alignment,
prefix,
start_offset,
end_offset,
),
callback=merge_result,
)
pool.close()
ds = indexed_dataset.make_builder(
dataset_dest_file(args, output_prefix, None, "bin"), impl=args.dataset_impl
)
merge_result(
Binarizer.binarize_alignments(
input_file,
utils.parse_alignment,
lambda t: ds.add_item(t),
offset=first_chunk[0],
end=first_chunk[1],
)
)
if num_workers > 1:
pool.join()
for worker_id in range(1, num_workers):
prefix = "{}{}".format(output_prefix, worker_id)
temp_file_path = dataset_dest_prefix(args, prefix, None)
ds.merge_file_(temp_file_path)
os.remove(indexed_dataset.data_file_path(temp_file_path))
os.remove(indexed_dataset.index_file_path(temp_file_path))
ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))
logger.info("[alignments] {}: parsed {} alignments".format(input_file, nseq[0]))
def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
if args.dataset_impl == "raw":
# Copy original text file to destination folder
output_text_file = dest_path(
output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
lang,
)
shutil.copyfile(file_name(input_prefix, lang), output_text_file)
else:
make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)
def make_all(lang, vocab):
if args.trainpref:
make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers)
if args.validpref:
for k, validpref in enumerate(args.validpref.split(",")):
outprefix = "valid{}".format(k) if k > 0 else "valid"
make_dataset(
vocab, validpref, outprefix, lang, num_workers=args.workers
)
if args.testpref:
for k, testpref in enumerate(args.testpref.split(",")):
outprefix = "test{}".format(k) if k > 0 else "test"
make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers)
def make_all_alignments():
if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix):
make_binary_alignment_dataset(
args.trainpref + "." + args.align_suffix,
"train.align",
num_workers=args.workers,
)
if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix):
make_binary_alignment_dataset(
args.validpref + "." + args.align_suffix,
"valid.align",
num_workers=args.workers,
)
if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix):
make_binary_alignment_dataset(
args.testpref + "." + args.align_suffix,
"test.align",
num_workers=args.workers,
)
make_all(args.source_lang, src_dict)
_make_all(args.source_lang, src_dict, args)
if target:
make_all(args.target_lang, tgt_dict)
_make_all(args.target_lang, tgt_dict, args)
# align the datasets if needed
if args.align_suffix:
make_all_alignments()
_make_all_alignments(args)
logger.info("Wrote preprocessed data to {}".format(args.destdir))
if args.alignfile:
assert args.trainpref, "--trainpref must be set if --alignfile is specified"
src_file_name = train_path(args.source_lang)
tgt_file_name = train_path(args.target_lang)
freq_map = {}
with open(args.alignfile, "r", encoding="utf-8") as align_file:
with open(src_file_name, "r", encoding="utf-8") as src_file:
with open(tgt_file_name, "r", encoding="utf-8") as tgt_file:
for a, s, t in zip_longest(align_file, src_file, tgt_file):
si = src_dict.encode_line(s, add_if_not_exist=False)
ti = tgt_dict.encode_line(t, add_if_not_exist=False)
ai = list(map(lambda x: tuple(x.split("-")), a.split()))
for sai, tai in ai:
srcidx = si[int(sai)]
tgtidx = ti[int(tai)]
if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
assert srcidx != src_dict.pad()
assert srcidx != src_dict.eos()
assert tgtidx != tgt_dict.pad()
assert tgtidx != tgt_dict.eos()
if srcidx not in freq_map:
freq_map[srcidx] = {}
if tgtidx not in freq_map[srcidx]:
freq_map[srcidx][tgtidx] = 1
else:
freq_map[srcidx][tgtidx] += 1
align_dict = {}
for srcidx in freq_map.keys():
align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)
with open(
os.path.join(
args.destdir,
"alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
),
"w",
encoding="utf-8",
) as f:
for k, v in align_dict.items():
print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True):
ds = indexed_dataset.make_builder(
dataset_dest_file(args, output_prefix, lang, "bin"),
impl=args.dataset_impl,
vocab_size=len(vocab),
)
def consumer(tensor):
ds.add_item(tensor)
res = Binarizer.binarize(
filename, vocab, consumer, append_eos=append_eos, offset=offset, end=end
)
ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
return res
def binarize_alignments(args, filename, parse_alignment, output_prefix, offset, end):
ds = indexed_dataset.make_builder(
dataset_dest_file(args, output_prefix, None, "bin"),
impl=args.dataset_impl,
vocab_size=None,
)
def consumer(tensor):
ds.add_item(tensor)
res = Binarizer.binarize_alignments(
filename, parse_alignment, consumer, offset=offset, end=end
)
ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))
return res
def dataset_dest_prefix(args, output_prefix, lang):
base = "{}/{}".format(args.destdir, output_prefix)
if lang is not None:
lang_part = ".{}-{}.{}".format(args.source_lang, args.target_lang, lang)
elif args.only_source:
lang_part = ""
else:
lang_part = ".{}-{}".format(args.source_lang, args.target_lang)
return "{}{}".format(base, lang_part)
def dataset_dest_file(args, output_prefix, lang, extension):
base = dataset_dest_prefix(args, output_prefix, lang)
return "{}.{}".format(base, extension)
_align_files(args, src_dict=src_dict, tgt_dict=tgt_dict)
def cli_main():

122
tests/test_binarizer.py Normal file
View File

@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import typing as tp
import unittest
from tempfile import TemporaryDirectory
from fairseq.binarizer import BinarizeSummary, FileBinarizer, VocabularyDatasetBinarizer
from fairseq.data import Dictionary, indexed_dataset
from tests.utils import make_data, sizes
def build_vocab(data: tp.List[tp.List[str]]) -> Dictionary:
    """Create a finalized Dictionary holding every token of ``data``."""
    vocab = Dictionary()
    for token in (tok for sentence in data for tok in sentence):
        vocab.add_symbol(token)
    vocab.finalize()
    return vocab
class TestBinarizer(unittest.TestCase):
    """Tests for VocabularyDatasetBinarizer and FileBinarizer."""

    def compare_ds_data(self, summary, data, prefix, impl, vocab):
        """Check the summary counts and that the dataset at ``prefix`` decodes back to ``data``."""
        self.assertEqual(summary.num_seq, len(data))
        self.assertEqual(summary.num_tok, sum(len(s) for s in data))

        dataset = indexed_dataset.make_dataset(prefix, impl)

        self.assertEqual(len(dataset), len(data))
        decoded = [vocab.string(dataset[i]).split() for i in range(len(dataset))]

        self.assertEqual(decoded, data)
        data_sizes = [size.item() for size in dataset.sizes]
        self.assertEqual(data_sizes, sizes(data))

    def test_can_binarize_line(self):
        data = make_data(length=1)
        vocab = build_vocab(data)
        binarizer = VocabularyDatasetBinarizer(vocab)

        sentence = data[0]
        summary = BinarizeSummary()

        tensor = binarizer.binarize_line(" ".join(sentence), summary)

        # eos is appended by default, hence the extra token
        expected_len = len(sentence) + 1
        self.assertEqual(len(tensor), expected_len)

        self.assertEqual(summary.num_tok, expected_len)
        self.assertEqual(summary.num_seq, 1)

    def test_can_binarize_file_chunk(self):
        # test without multiprocess logic
        with TemporaryDirectory() as dirname:
            raw_file = os.path.join(dirname, "raw1")
            prefix = os.path.join(dirname, "test1")
            impl = "mmap"

            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)
            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)

            summary = FileBinarizer._binarize_chunk_and_finalize(
                binarizer,
                raw_file,
                offset_start=0,
                offset_end=-1,
                output_prefix=prefix,
                dataset_impl=impl,
                vocab_size=len(vocab),
            )

            self.compare_ds_data(summary, data, prefix, impl, vocab)

    def test_can_multiprocess(self):
        with TemporaryDirectory() as dirname:
            raw_file = os.path.join(dirname, "raw1")
            prefix = os.path.join(dirname, "test1")
            impl = "mmap"

            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)
            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)

            # with one worker
            summary = FileBinarizer.multiprocess_dataset(
                raw_file,
                impl,
                binarizer,
                output_prefix=prefix,
                vocab_size=len(vocab),
                num_workers=1,
            )
            self.compare_ds_data(summary, data, prefix, impl, vocab)

            # with multiple worker
            prefix_multi = os.path.join(dirname, "test2")
            summary = FileBinarizer.multiprocess_dataset(
                raw_file,
                impl,
                binarizer,
                output_prefix=prefix_multi,
                vocab_size=len(vocab),
                num_workers=3,
            )
            self.compare_ds_data(summary, data, prefix_multi, impl, vocab)

View File

@ -4,8 +4,6 @@
# LICENSE file in the root directory of this source tree.
import os
import random
import string
import typing as tp
import unittest
from collections import Counter
@ -18,23 +16,7 @@ from fairseq.data.huffman import (
HuffmanMMapIndexedDataset,
HuffmanMMapIndexedDatasetBuilder,
)
POPULATION = string.ascii_letters + string.digits
def make_sentence() -> tp.List[str]:
length = random.randint(10, 50)
return random.choices(
population=POPULATION, k=length, weights=range(1, len(POPULATION) + 1)
)
def make_data(length=1000) -> tp.List[tp.List[str]]:
return (
[make_sentence() for _ in range(0, length)]
# add all the symbols at least once
+ [list(string.ascii_letters), list(string.digits)]
)
from tests.utils import POPULATION, make_data, sizes
def make_counts(data: tp.List[tp.List[str]]) -> Counter:
@ -112,10 +94,6 @@ def build_dataset(prefix, data, coder):
builder.add_item(sentence)
def sizes(data):
return [len(sentence) for sentence in data]
class TestHuffmanDataset(unittest.TestCase):
def test_huffman_can_encode_decode(self):
data = make_data()

View File

@ -8,7 +8,9 @@ import json
import os
import random
import shutil
import string
import sys
import typing as tp
from io import StringIO
import torch
@ -756,3 +758,31 @@ def train_language_model(
+ (extra_valid_flags or []),
)
validate.main(validate_args)
def sizes(data):
    """Return the length of every sentence in ``data``, in order."""
    return list(map(len, data))
POPULATION = string.ascii_letters + string.digits
def make_sentence() -> tp.List[str]:
    """
    Draw a random sentence of 10 to 50 symbols from POPULATION, with
    replacement; symbols later in POPULATION get a higher weight.
    """
    n_symbols = random.randint(10, 50)
    symbol_weights = range(1, len(POPULATION) + 1)
    return random.choices(population=POPULATION, k=n_symbols, weights=symbol_weights)
def make_data(length=1000, out_file=None) -> tp.List[tp.List[str]]:
    """
    Generate ``length`` random sentences plus two sentences holding every
    letter and every digit, so that all symbols appear at least once.

    When ``out_file`` is given, the sentences are also written to it, one
    per line with space separated symbols.
    """
    sentences = [make_sentence() for _ in range(0, length)]
    # add all the symbols at least once
    sentences = sentences + [list(string.ascii_letters), list(string.digits)]
    if out_file is None:
        return sentences
    with open(out_file, "w", encoding="utf-8") as out:
        for sentence in sentences:
            print(" ".join(sentence), file=out)
    return sentences