mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-26 17:32:57 +03:00
Indexed Huffman Coded dataset (#2029)
Summary: ## What does this PR do? Currently, binarized dataset are stored as a bin representation of int tensors. At best, each int is coded as uint16 on disk. When coding a fixed size vocabulary dataset where we know the frequency of each symbol and where some symbols are more common than other, we can do better. This happens in particular when binarizing a dataset split in subword units as the most common "tokenizers" like bpe and spm will choose subwords with high frequencies over subwords with low frequencies. In practice, if we know the frequency of all symbols (or a good estimate), we can use entropy encoding methods to compress the data. The idea is to assign a compressed representation where frequent symbols will have shorter representations than unfrequent symbols. In this PR, we build a Huffman code from a frequency table and use this code to encode a dataset. The PR provides the huffman coder implementation (using the single queue approach as we usually start with a sorted set of symbols) as well as a memory map implementation of a dataset that stores the data compressed with a huffman code and can return indexed tensors from it. Over a whole dataset, depending on how many symbols we sample to evaluate the frequency, we can save between 25% and 30% of storage space. ## Follow Ups currently the binarizer/preprocess script make too many assumptions about the dataset writers so the huffman dataset writer cannot be used straight out of the box with it. I will make follow ups PRs to provide easy to use scripts to build such datasets. But it's as simple as doing: ``` code_builder = HuffmanCodeBuilder() with open(sample_file, 'r', encoding="utf-8") as input: for line in input: code_builder.add(*line.strip().split(" ")) coder = code_builder.build_code() with HuffmanMMapIndexedDatasetBuilder('/tmp/testing_huffman', coder) as builder: with open(dataset_file, 'r', encoding="utf-8") as input: for line in input: builder.add_item(line.strip().split(' ')) ``` a lot of the `HuffmanMMapIndexedDataset` code comes from the normal `MMapIndexedDataset` and we could probably extract commonalities in a base class the `HuffmanCoder` is also really a special kind of `Dictionary` and again, a common base class could be abstracted out of them. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2029 Reviewed By: dianaml0 Differential Revision: D29557468 Pulled By: Mortimerp9 fbshipit-source-id: a01b6d98f38f937934cadebb3786133e257adefe
This commit is contained in:
parent
5277ec47bd
commit
68a81202a3
@ -267,7 +267,7 @@ class Dictionary:
|
||||
self.add_symbol(word, n=count, overwrite=overwrite)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"Incorrect dictionary format, expected '<token> <cnt> [flags]'"
|
||||
f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{line}\""
|
||||
)
|
||||
|
||||
def _save(self, f, kv_iterator):
|
||||
|
21
fairseq/data/huffman/__init__.py
Normal file
21
fairseq/data/huffman/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .huffman_coder import HuffmanCodeBuilder, HuffmanCoder
|
||||
from .huffman_mmap_indexed_dataset import (
|
||||
HuffmanMMapIndex,
|
||||
HuffmanMMapIndexedDataset,
|
||||
HuffmanMMapIndexedDatasetBuilder,
|
||||
vocab_file_path,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"HuffmanCoder",
|
||||
"HuffmanCodeBuilder",
|
||||
"HuffmanMMapIndexedDatasetBuilder",
|
||||
"HuffmanMMapIndexedDataset",
|
||||
"HuffmanMMapIndex",
|
||||
"vocab_file_path",
|
||||
]
|
265
fairseq/data/huffman/huffman_coder.py
Normal file
265
fairseq/data/huffman/huffman_coder.py
Normal file
@ -0,0 +1,265 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import re
|
||||
import typing as tp
|
||||
from collections import Counter, deque
|
||||
from dataclasses import dataclass
|
||||
|
||||
from bitarray import bitarray, util
|
||||
from fairseq.data import Dictionary
|
||||
|
||||
# basically we have to write to addressable bytes for the memory mapped
|
||||
# dataset loader. Sentences that get encoded to a length that is not a
|
||||
# multiple of BLOCKSIZE (a byte) will be padded to fit. (see _pad in the coder)
|
||||
BLOCKSIZE = 8
|
||||
|
||||
|
||||
class HuffmanCoder:
|
||||
def __init__(
|
||||
self, root: "HuffmanNode", bos="<s>", pad="<pad>", eos="</s>", unk="<unk>"
|
||||
):
|
||||
self.root = root
|
||||
self.table = root.code_table()
|
||||
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
|
||||
|
||||
def _pad(self, a: bitarray) -> bitarray:
|
||||
"""
|
||||
bitpadding, 1 then 0.
|
||||
|
||||
If the array is already a multiple of blocksize, we add a full block.
|
||||
"""
|
||||
pad_len = BLOCKSIZE - (len(a) % BLOCKSIZE) - 1
|
||||
padding = bitarray("1" + "0" * pad_len)
|
||||
return a + padding
|
||||
|
||||
def _unpad(self, a: bitarray) -> bitarray:
|
||||
"""
|
||||
remove the bitpadding.
|
||||
|
||||
There will be a set of 0s preceded by a 1 at the end of the bitarray, we remove that
|
||||
"""
|
||||
# count the 0 padding at the end until we find the first 1
|
||||
# we want to remove the one too
|
||||
remove_cnt = util.rindex(a, 1)
|
||||
return a[:remove_cnt]
|
||||
|
||||
def encode(self, iter: tp.List[str]) -> bytes:
|
||||
"""
|
||||
encode a list of tokens a return bytes. We use bitpadding to make sure the encoded bits fit in bytes.
|
||||
"""
|
||||
a = bitarray()
|
||||
for token in iter:
|
||||
code = self.get_code(token)
|
||||
if code is None:
|
||||
if self.unk_word is None:
|
||||
raise Exception(f"unknown token {token} cannot be encoded.")
|
||||
else:
|
||||
token = self.unk_word
|
||||
a = a + self.get_code(token)
|
||||
return self._pad(a).tobytes()
|
||||
|
||||
def decode(self, bits: bytes) -> tp.Iterator["HuffmanNode"]:
|
||||
"""
|
||||
take bitpadded bytes and decode it to a set of leaves. You can then use each node to find the symbol/id
|
||||
"""
|
||||
a = bitarray()
|
||||
a.frombytes(bits)
|
||||
return self.root.decode(self._unpad(a))
|
||||
|
||||
def get_code(self, symbol: str) -> tp.Optional[bitarray]:
|
||||
node = self.get_node(symbol)
|
||||
return None if node is None else node.code
|
||||
|
||||
def get_node(self, symbol: str) -> "HuffmanNode":
|
||||
return self.table.get(symbol)
|
||||
|
||||
@classmethod
|
||||
def from_file(
|
||||
cls,
|
||||
filename: str,
|
||||
bos="<s>",
|
||||
pad="<pad>",
|
||||
eos="</s>",
|
||||
unk="<unk>",
|
||||
) -> "HuffmanCoder":
|
||||
builder = HuffmanCodeBuilder.from_file(filename)
|
||||
return builder.build_code(bos=bos, pad=pad, eos=eos, unk=unk)
|
||||
|
||||
def to_file(self, filename, sep="\t"):
|
||||
nodes = list(self.table.values())
|
||||
nodes.sort(key=lambda n: n.id)
|
||||
with open(filename, "w", encoding="utf-8") as output:
|
||||
for n in nodes:
|
||||
output.write(f"{n.symbol}{sep}{n.count}\n")
|
||||
|
||||
def __iter__(self):
|
||||
for n in self.table.values():
|
||||
yield n
|
||||
|
||||
def merge(self, other_coder: "HuffmanCoder") -> "HuffmanCoder":
|
||||
builder = HuffmanCodeBuilder()
|
||||
for n in self:
|
||||
builder.increment(n.symbol, n.count)
|
||||
for n in other_coder:
|
||||
builder.increment(n.symbol, n.count)
|
||||
return builder.build_code()
|
||||
|
||||
def __eq__(self, other: "HuffmanCoder") -> bool:
|
||||
return self.table == other.table
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.table)
|
||||
|
||||
def __contains__(self, sym: str) -> bool:
|
||||
return sym in self.table
|
||||
|
||||
def to_dictionary(self) -> Dictionary:
|
||||
dictionary = Dictionary(bos=self.bos, unk=self.unk, pad=self.pad, eos=self.eos)
|
||||
for n in self:
|
||||
dictionary.add_symbol(n.symbol, n=n.count)
|
||||
dictionary.finalize()
|
||||
return dictionary
|
||||
|
||||
|
||||
@dataclass
|
||||
class HuffmanNode:
|
||||
"""
|
||||
a node in a Huffman tree
|
||||
"""
|
||||
|
||||
id: int
|
||||
count: int
|
||||
symbol: tp.Optional[str] = None
|
||||
left: tp.Optional["HuffmanNode"] = None
|
||||
right: tp.Optional["HuffmanNode"] = None
|
||||
code: tp.Optional[bitarray] = None
|
||||
|
||||
def is_leaf(self) -> bool:
|
||||
return self.left is None and self.right is None
|
||||
|
||||
def code_table(self, prefix: tp.Optional[bitarray] = None) -> tp.Dict[str, "HuffmanNode"]:
|
||||
defaulted_prefix = prefix if prefix is not None else bitarray()
|
||||
if self.is_leaf():
|
||||
self.code = (
|
||||
defaulted_prefix if len(defaulted_prefix) > 0 else bitarray("0")
|
||||
) # leaf could be the root if there is only one symbol
|
||||
return {self.symbol: self}
|
||||
|
||||
codes_right = self.right.code_table(defaulted_prefix + bitarray([0]))
|
||||
codes_left = self.left.code_table(defaulted_prefix + bitarray([1]))
|
||||
return {**codes_left, **codes_right}
|
||||
|
||||
def decode(self, bits: bitarray) -> tp.Iterator["HuffmanNode"]:
|
||||
current_node = self
|
||||
for bit in bits:
|
||||
if bit == 0: # go right
|
||||
current_node = current_node.right
|
||||
else: # go left
|
||||
current_node = current_node.left
|
||||
if current_node is None:
|
||||
# we shouldn't be on a leaf here
|
||||
raise Exception("fell off a leaf")
|
||||
if current_node.is_leaf():
|
||||
yield current_node
|
||||
current_node = self
|
||||
if current_node != self:
|
||||
raise Exception("couldn't decode all the bits")
|
||||
|
||||
|
||||
class HuffmanCodeBuilder:
|
||||
"""
|
||||
build a dictionary with occurence count and then build the Huffman code for it.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.symbols = Counter()
|
||||
|
||||
def add_symbols(self, *syms) -> None:
|
||||
self.symbols.update(syms)
|
||||
|
||||
def increment(self, symbol: str, cnt: int) -> None:
|
||||
self.symbols[symbol] += cnt
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, filename):
|
||||
c = cls()
|
||||
with open(filename, "r", encoding="utf-8") as input:
|
||||
for line in input:
|
||||
split = re.split(r"[\s]+", line)
|
||||
c.increment(split[0], int(split[1]))
|
||||
return c
|
||||
|
||||
def to_file(self, filename, sep="\t"):
|
||||
with open(filename, "w", encoding="utf-8") as output:
|
||||
for (tok, cnt) in self.symbols.most_common():
|
||||
output.write(f"{tok}{sep}{cnt}\n")
|
||||
|
||||
def _smallest(self, q1: deque, q2: deque) -> HuffmanNode:
|
||||
if len(q1) == 0:
|
||||
return q2.pop()
|
||||
|
||||
if len(q2) == 0:
|
||||
return q1.pop()
|
||||
|
||||
if q1[-1].count < q2[-1].count:
|
||||
return q1.pop()
|
||||
|
||||
return q2.pop()
|
||||
|
||||
def __add__(self, c: "HuffmanCodeBuilder") -> "HuffmanCodeBuilder":
|
||||
new_c = self.symbols + c.symbols
|
||||
new_b = HuffmanCodeBuilder()
|
||||
new_b.symbols = new_c
|
||||
return new_b
|
||||
|
||||
def build_code(
|
||||
self,
|
||||
bos="<s>",
|
||||
pad="<pad>",
|
||||
eos="</s>",
|
||||
unk="<unk>",
|
||||
) -> HuffmanCoder:
|
||||
assert len(self.symbols) > 0, "cannot build code from empty list of symbols"
|
||||
|
||||
if self.symbols[bos] == 0:
|
||||
self.add_symbols(bos)
|
||||
if self.symbols[pad] == 0:
|
||||
self.add_symbols(pad)
|
||||
if self.symbols[eos] == 0:
|
||||
self.add_symbols(eos)
|
||||
if self.symbols[unk] == 0:
|
||||
self.add_symbols(unk)
|
||||
|
||||
node_id = 0
|
||||
leaves_queue = deque(
|
||||
[
|
||||
HuffmanNode(symbol=symbol, count=count, id=idx)
|
||||
for idx, (symbol, count) in enumerate(self.symbols.most_common())
|
||||
]
|
||||
) # left are the most common, right are the least common
|
||||
|
||||
if len(leaves_queue) == 1:
|
||||
root = leaves_queue.pop()
|
||||
root.id = 0
|
||||
return HuffmanCoder(root)
|
||||
|
||||
nodes_queue = deque()
|
||||
|
||||
while len(leaves_queue) > 0 or len(nodes_queue) != 1:
|
||||
# get the lowest two nodes at the head of each queue
|
||||
node1 = self._smallest(leaves_queue, nodes_queue)
|
||||
node2 = self._smallest(leaves_queue, nodes_queue)
|
||||
|
||||
# add new node
|
||||
nodes_queue.appendleft(
|
||||
HuffmanNode(
|
||||
count=node1.count + node2.count, left=node1, right=node2, id=node_id
|
||||
)
|
||||
)
|
||||
node_id += 1
|
||||
|
||||
# we are left with the root
|
||||
return HuffmanCoder(nodes_queue.pop(), bos=bos, pad=pad, eos=eos, unk=unk)
|
287
fairseq/data/huffman/huffman_mmap_indexed_dataset.py
Normal file
287
fairseq/data/huffman/huffman_mmap_indexed_dataset.py
Normal file
@ -0,0 +1,287 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import mmap
|
||||
import os
|
||||
import shutil
|
||||
import struct
|
||||
import typing as tp
|
||||
from functools import lru_cache
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq.data import indexed_dataset
|
||||
from fairseq.data.huffman import HuffmanCoder
|
||||
from fairseq.file_io import PathManager
|
||||
|
||||
|
||||
class HuffmanMMapIndex:
|
||||
"""
|
||||
keep an index of the offsets in the huffman binary file.
|
||||
First a header, then the list of sizes (num tokens) for each instance and finally
|
||||
the addresses of each instance.
|
||||
"""
|
||||
|
||||
_HDR_MAGIC = b"HUFFIDX\x00\x00"
|
||||
_VERSION = 1
|
||||
|
||||
@classmethod
|
||||
def writer(cls, path: str, data_len: int):
|
||||
class _Writer:
|
||||
def __enter__(self):
|
||||
self._file = open(path, "wb")
|
||||
|
||||
# write header (magic + version)
|
||||
self._file.write(cls._HDR_MAGIC)
|
||||
self._file.write(struct.pack("<Q", cls._VERSION))
|
||||
self._file.write(struct.pack("<Q", data_len))
|
||||
|
||||
return self
|
||||
|
||||
def write(self, sizes, pointers):
|
||||
# add number of items in the index to the header
|
||||
self._file.write(struct.pack("<Q", len(sizes)))
|
||||
|
||||
# write sizes
|
||||
sizes = np.array(sizes, dtype=np.int32)
|
||||
self._file.write(sizes.tobytes(order="C"))
|
||||
del sizes
|
||||
|
||||
# write address pointers
|
||||
pointers = np.array(pointers, dtype=np.int64)
|
||||
self._file.write(pointers.tobytes(order="C"))
|
||||
del pointers
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._file.close()
|
||||
|
||||
return _Writer()
|
||||
|
||||
def __init__(self, path):
|
||||
with open(path, "rb") as stream:
|
||||
# read headers
|
||||
magic_test = stream.read(9)
|
||||
assert self._HDR_MAGIC == magic_test, (
|
||||
"Index file doesn't match expected format. "
|
||||
"Make sure that --dataset-impl is configured properly."
|
||||
)
|
||||
(version,) = struct.unpack("<Q", stream.read(8))
|
||||
assert (
|
||||
self._VERSION == version
|
||||
), "Unexpected file version f{version} != code version f{self._VERSION}"
|
||||
|
||||
# read length of data file
|
||||
(self._data_len,) = struct.unpack("<Q", stream.read(8))
|
||||
# read number of items in data file/index
|
||||
(self._len,) = struct.unpack("<Q", stream.read(8))
|
||||
offset = stream.tell()
|
||||
|
||||
indexed_dataset._warmup_mmap_file(path)
|
||||
|
||||
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
self._sizes = np.frombuffer(
|
||||
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
|
||||
)
|
||||
self._pointers = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._len,
|
||||
offset=offset + self._sizes.nbytes,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(self._len):
|
||||
yield self[i]
|
||||
|
||||
@property
|
||||
def data_len(self):
|
||||
return self._data_len
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._sizes
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def __getitem__(self, i):
|
||||
return self._pointers[i], self._sizes[i]
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
|
||||
def vocab_file_path(prefix_path):
|
||||
return prefix_path + ".vocab"
|
||||
|
||||
|
||||
class HuffmanMMapIndexedDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
an indexed dataset that use mmap and memoryview to access data from disk
|
||||
that was compressed with a HuffmanCoder.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix_path):
|
||||
super().__init__()
|
||||
|
||||
self._prefix_path = None
|
||||
self._index = None
|
||||
self._bin_buffer = None
|
||||
self._coder = None
|
||||
self._file = None
|
||||
|
||||
self._bin_buffer_mmap = None
|
||||
|
||||
self._do_init(prefix_path)
|
||||
|
||||
def __getstate__(self):
|
||||
return self._prefix_path
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._do_init(state)
|
||||
|
||||
def _do_init(self, prefix_path):
|
||||
self._prefix_path = prefix_path
|
||||
self._index = HuffmanMMapIndex(
|
||||
indexed_dataset.index_file_path(self._prefix_path)
|
||||
)
|
||||
self._coder = HuffmanCoder.from_file(vocab_file_path(self._prefix_path))
|
||||
|
||||
indexed_dataset._warmup_mmap_file(
|
||||
indexed_dataset.data_file_path(self._prefix_path)
|
||||
)
|
||||
self._file = os.open(
|
||||
indexed_dataset.data_file_path(self._prefix_path), os.O_RDONLY
|
||||
)
|
||||
self._bin_buffer_mmap = mmap.mmap(
|
||||
self._file,
|
||||
self._index.data_len,
|
||||
access=mmap.ACCESS_READ,
|
||||
)
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
|
||||
def __del__(self):
|
||||
del self._bin_buffer
|
||||
if self._file:
|
||||
os.close(self._file)
|
||||
del self._index
|
||||
|
||||
def __len__(self):
|
||||
return len(self._index)
|
||||
|
||||
def _decode(self, i):
|
||||
ptr, _ = self._index[i]
|
||||
if i == 0:
|
||||
raw_bytes = self._bin_buffer[:ptr]
|
||||
else:
|
||||
(prev_ptr, _) = self._index[i - 1]
|
||||
raw_bytes = self._bin_buffer[prev_ptr:ptr]
|
||||
|
||||
return self._coder.decode(raw_bytes.tobytes())
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def __getitem__(self, i):
|
||||
nodes = self._decode(i)
|
||||
return torch.tensor([n.id for n in nodes], dtype=torch.int64)
|
||||
|
||||
def __iter__(self):
|
||||
for idx in range(len(self)):
|
||||
yield self[idx]
|
||||
|
||||
def get_symbols(self, i):
|
||||
nodes = self._decode(i)
|
||||
for n in nodes:
|
||||
yield n.symbol
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._index.sizes
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def coder(self):
|
||||
return self._coder
|
||||
|
||||
@staticmethod
|
||||
def exists(prefix_path):
|
||||
return (
|
||||
PathManager.exists(indexed_dataset.index_file_path(prefix_path))
|
||||
and PathManager.exists(indexed_dataset.data_file_path(prefix_path))
|
||||
and PathManager.exists(vocab_file_path(prefix_path))
|
||||
)
|
||||
|
||||
|
||||
class HuffmanMMapIndexedDatasetBuilder:
|
||||
"""
|
||||
Helper to build a memory mapped datasets with a huffman encoder.
|
||||
You can either open/close this manually or use it as a ContextManager.
|
||||
Provide your own coder, it will then be stored alongside the dataset.
|
||||
The builder will first write the vocab file, then open the binary file so you can stream
|
||||
into it, finally the index will be written when the builder is closed (your index should fit in memory).
|
||||
"""
|
||||
|
||||
def __init__(self, path_prefix: str, coder: HuffmanCoder) -> None:
|
||||
self._path_prefix = path_prefix
|
||||
self._coder = coder
|
||||
self._sizes = []
|
||||
self._ptrs = []
|
||||
self._data_len = 0
|
||||
|
||||
def open(self):
|
||||
self._coder.to_file(vocab_file_path(self._path_prefix))
|
||||
self._data_file = open(indexed_dataset.data_file_path(self._path_prefix), "wb")
|
||||
|
||||
def __enter__(self) -> "HuffmanMMapIndexedDatasetBuilder":
|
||||
self.open()
|
||||
return self
|
||||
|
||||
def add_item(self, tokens: tp.List[str]) -> None:
|
||||
"""
|
||||
add a list of tokens to the dataset, they will compressed with the
|
||||
provided coder before being written to file.
|
||||
"""
|
||||
encoded = self._coder.encode(tokens)
|
||||
code_len = len(encoded)
|
||||
last_ptr = 0
|
||||
if len(self._ptrs) > 0:
|
||||
last_ptr = self._ptrs[-1]
|
||||
self._sizes.append(len(tokens))
|
||||
self._ptrs.append(last_ptr + code_len)
|
||||
self._data_len += code_len
|
||||
self._data_file.write(encoded)
|
||||
|
||||
def append(self, other_dataset_path_prefix: str) -> None:
|
||||
"""
|
||||
append an existing dataset.
|
||||
Beware, if it wasn't built with the same coder, you are in trouble.
|
||||
"""
|
||||
other_index = HuffmanMMapIndex(
|
||||
indexed_dataset.index_file_path(other_dataset_path_prefix)
|
||||
)
|
||||
for (ptr, size) in other_index:
|
||||
self._ptrs.append(ptr + self._data_len)
|
||||
self._sizes.append(size)
|
||||
|
||||
# Concatenate data
|
||||
with open(indexed_dataset.data_file_path(other_dataset_path_prefix), "rb") as f:
|
||||
shutil.copyfileobj(f, self._data_file)
|
||||
|
||||
self._data_len += other_index.data_len
|
||||
|
||||
def close(self):
|
||||
self._data_file.close()
|
||||
with HuffmanMMapIndex.writer(
|
||||
indexed_dataset.index_file_path(self._path_prefix), self._data_len
|
||||
) as index:
|
||||
index.write(self._sizes, self._ptrs)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
self.close()
|
@ -12,6 +12,7 @@ import torch
|
||||
from fairseq.dataclass.constants import DATASET_IMPL_CHOICES
|
||||
from fairseq.data.fasta_dataset import FastaDataset
|
||||
from fairseq.file_io import PathManager
|
||||
from fairseq.data.huffman import HuffmanMMapIndexedDataset, HuffmanMMapIndex
|
||||
|
||||
from . import FairseqDataset
|
||||
|
||||
@ -48,6 +49,8 @@ def infer_dataset_impl(path):
|
||||
return "cached"
|
||||
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
|
||||
return "mmap"
|
||||
elif magic == HuffmanMMapIndex._HDR_MAGIC[:8]:
|
||||
return "huffman"
|
||||
else:
|
||||
return None
|
||||
elif FastaDataset.exists(path):
|
||||
@ -63,6 +66,8 @@ def make_builder(out_file, impl, vocab_size=None):
|
||||
)
|
||||
elif impl == "fasta":
|
||||
raise NotImplementedError
|
||||
elif impl == "huffman":
|
||||
raise ValueError("Use HuffmanCodeBuilder directly as it has a different interface.")
|
||||
else:
|
||||
return IndexedDatasetBuilder(out_file)
|
||||
|
||||
@ -81,6 +86,8 @@ def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
|
||||
from fairseq.data.fasta_dataset import EncodedFastaDataset
|
||||
|
||||
return EncodedFastaDataset(path, dictionary)
|
||||
elif impl == "huffman" and HuffmanMMapIndexedDataset.exists(path):
|
||||
return HuffmanMMapIndexedDataset(path)
|
||||
return None
|
||||
|
||||
|
||||
@ -89,6 +96,8 @@ def dataset_exists(path, impl):
|
||||
return IndexedRawTextDataset.exists(path)
|
||||
elif impl == "mmap":
|
||||
return MMapIndexedDataset.exists(path)
|
||||
elif impl == "huffman":
|
||||
return HuffmanMMapIndexedDataset.exists(path)
|
||||
else:
|
||||
return IndexedDataset.exists(path)
|
||||
|
||||
|
@ -44,7 +44,7 @@ DDP_BACKEND_CHOICES = ChoiceEnum([
|
||||
"slow_mo",
|
||||
])
|
||||
DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"])
|
||||
DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta"])
|
||||
DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"])
|
||||
GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"])
|
||||
GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum(
|
||||
["unigram", "ensemble", "vote", "dp", "bs"]
|
||||
|
@ -41,6 +41,8 @@ def main(args):
|
||||
)
|
||||
logger.info(args)
|
||||
|
||||
assert args.dataset_impl != "huffman", "preprocessing.py doesn't support Huffman yet, use HuffmanCodeBuilder directly."
|
||||
|
||||
task = tasks.get_task(args.task)
|
||||
|
||||
def train_path(lang):
|
||||
|
1
setup.py
1
setup.py
@ -209,6 +209,7 @@ def do_setup(package_data):
|
||||
"sacrebleu>=1.4.12",
|
||||
"torch",
|
||||
"tqdm",
|
||||
"bitarray",
|
||||
],
|
||||
dependency_links=dependency_links,
|
||||
packages=find_packages(
|
||||
|
201
tests/test_huffman.py
Normal file
201
tests/test_huffman.py
Normal file
@ -0,0 +1,201 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import typing as tp
|
||||
import unittest
|
||||
from collections import Counter
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
|
||||
from fairseq.data import Dictionary, indexed_dataset
|
||||
from fairseq.data.huffman import (
|
||||
HuffmanCodeBuilder,
|
||||
HuffmanCoder,
|
||||
HuffmanMMapIndexedDataset,
|
||||
HuffmanMMapIndexedDatasetBuilder,
|
||||
)
|
||||
|
||||
POPULATION = string.ascii_letters + string.digits
|
||||
|
||||
|
||||
def make_sentence() -> tp.List[str]:
|
||||
length = random.randint(10, 50)
|
||||
return random.choices(
|
||||
population=POPULATION, k=length, weights=range(1, len(POPULATION) + 1)
|
||||
)
|
||||
|
||||
|
||||
def make_data(length=1000) -> tp.List[tp.List[str]]:
|
||||
return (
|
||||
[make_sentence() for _ in range(0, length)]
|
||||
# add all the symbols at least once
|
||||
+ [list(string.ascii_letters), list(string.digits)]
|
||||
)
|
||||
|
||||
|
||||
def make_counts(data: tp.List[tp.List[str]]) -> Counter:
|
||||
return Counter([symbol for sentence in data for symbol in sentence])
|
||||
|
||||
|
||||
def make_code_builder(data: tp.List[tp.List[str]]) -> HuffmanCodeBuilder:
|
||||
builder = HuffmanCodeBuilder()
|
||||
for sentence in data:
|
||||
builder.add_symbols(*sentence)
|
||||
return builder
|
||||
|
||||
|
||||
class TestCodeBuilder(unittest.TestCase):
|
||||
def test_code_builder_can_count(self):
|
||||
data = make_data()
|
||||
counts = make_counts(data)
|
||||
builder = make_code_builder(data)
|
||||
|
||||
self.assertEqual(builder.symbols, counts)
|
||||
|
||||
def test_code_builder_can_add(self):
|
||||
data = make_data()
|
||||
counts = make_counts(data)
|
||||
builder = make_code_builder(data)
|
||||
|
||||
new_builder = builder + builder
|
||||
|
||||
self.assertEqual(new_builder.symbols, counts + counts)
|
||||
|
||||
def test_code_builder_can_io(self):
|
||||
data = make_data()
|
||||
builder = make_code_builder(data)
|
||||
|
||||
with NamedTemporaryFile() as tmp_fp:
|
||||
builder.to_file(tmp_fp.name)
|
||||
other_builder = HuffmanCodeBuilder.from_file(tmp_fp.name)
|
||||
|
||||
self.assertEqual(builder.symbols, other_builder.symbols)
|
||||
|
||||
|
||||
class TestCoder(unittest.TestCase):
|
||||
def test_coder_can_io(self):
|
||||
data = make_data()
|
||||
builder = make_code_builder(data)
|
||||
coder = builder.build_code()
|
||||
|
||||
with NamedTemporaryFile() as tmp_fp:
|
||||
coder.to_file(tmp_fp.name)
|
||||
other_coder = HuffmanCoder.from_file(tmp_fp.name)
|
||||
|
||||
self.assertEqual(coder, other_coder)
|
||||
|
||||
def test_coder_can_encode_decode(self):
|
||||
data = make_data()
|
||||
builder = make_code_builder(data)
|
||||
coder = builder.build_code()
|
||||
|
||||
encoded = [coder.encode(sentence) for sentence in data]
|
||||
decoded = [[n.symbol for n in coder.decode(enc)] for enc in encoded]
|
||||
|
||||
self.assertEqual(decoded, data)
|
||||
|
||||
unseen_data = make_data()
|
||||
unseen_encoded = [coder.encode(sentence) for sentence in unseen_data]
|
||||
unseen_decoded = [
|
||||
[n.symbol for n in coder.decode(enc)] for enc in unseen_encoded
|
||||
]
|
||||
self.assertEqual(unseen_decoded, unseen_data)
|
||||
|
||||
|
||||
def build_dataset(prefix, data, coder):
|
||||
with HuffmanMMapIndexedDatasetBuilder(prefix, coder) as builder:
|
||||
for sentence in data:
|
||||
builder.add_item(sentence)
|
||||
|
||||
|
||||
def sizes(data):
|
||||
return [len(sentence) for sentence in data]
|
||||
|
||||
|
||||
class TestHuffmanDataset(unittest.TestCase):
|
||||
def test_huffman_can_encode_decode(self):
|
||||
data = make_data()
|
||||
builder = make_code_builder(data)
|
||||
coder = builder.build_code()
|
||||
|
||||
with TemporaryDirectory() as dirname:
|
||||
prefix = os.path.join(dirname, "test1")
|
||||
build_dataset(prefix, data, coder)
|
||||
dataset = HuffmanMMapIndexedDataset(prefix)
|
||||
|
||||
self.assertEqual(len(dataset), len(data))
|
||||
decoded = [list(dataset.get_symbols(i)) for i in range(0, len(dataset))]
|
||||
|
||||
self.assertEqual(decoded, data)
|
||||
data_sizes = [i.item() for i in dataset.sizes]
|
||||
self.assertEqual(data_sizes, sizes(data))
|
||||
|
||||
def test_huffman_compresses(self):
|
||||
data = make_data()
|
||||
builder = make_code_builder(data)
|
||||
coder = builder.build_code()
|
||||
|
||||
with TemporaryDirectory() as dirname:
|
||||
prefix = os.path.join(dirname, "huffman")
|
||||
build_dataset(prefix, data, coder)
|
||||
|
||||
prefix_mmap = os.path.join(dirname, "mmap")
|
||||
mmap_builder = indexed_dataset.make_builder(
|
||||
indexed_dataset.data_file_path(prefix_mmap),
|
||||
"mmap",
|
||||
vocab_size=len(POPULATION),
|
||||
)
|
||||
dictionary = Dictionary()
|
||||
for c in POPULATION:
|
||||
dictionary.add_symbol(c)
|
||||
dictionary.finalize()
|
||||
for sentence in data:
|
||||
mmap_builder.add_item(dictionary.encode_line(" ".join(sentence)))
|
||||
mmap_builder.finalize(indexed_dataset.index_file_path(prefix_mmap))
|
||||
|
||||
huff_size = os.stat(indexed_dataset.data_file_path(prefix)).st_size
|
||||
mmap_size = os.stat(indexed_dataset.data_file_path(prefix_mmap)).st_size
|
||||
self.assertLess(huff_size, mmap_size)
|
||||
|
||||
def test_huffman_can_append(self):
|
||||
data1 = make_data()
|
||||
builder = make_code_builder(data1)
|
||||
coder = builder.build_code()
|
||||
|
||||
with TemporaryDirectory() as dirname:
|
||||
prefix1 = os.path.join(dirname, "test1")
|
||||
build_dataset(prefix1, data1, coder)
|
||||
|
||||
data2 = make_data()
|
||||
prefix2 = os.path.join(dirname, "test2")
|
||||
build_dataset(prefix2, data2, coder)
|
||||
|
||||
prefix3 = os.path.join(dirname, "test3")
|
||||
|
||||
with HuffmanMMapIndexedDatasetBuilder(prefix3, coder) as builder:
|
||||
builder.append(prefix1)
|
||||
builder.append(prefix2)
|
||||
|
||||
dataset = HuffmanMMapIndexedDataset(prefix3)
|
||||
|
||||
self.assertEqual(len(dataset), len(data1) + len(data2))
|
||||
|
||||
decoded1 = [list(dataset.get_symbols(i)) for i in range(0, len(data1))]
|
||||
self.assertEqual(decoded1, data1)
|
||||
|
||||
decoded2 = [
|
||||
list(dataset.get_symbols(i)) for i in range(len(data1), len(dataset))
|
||||
]
|
||||
self.assertEqual(decoded2, data2)
|
||||
|
||||
data_sizes = [i.item() for i in dataset.sizes]
|
||||
self.assertEqual(data_sizes[: len(data1)], sizes(data1))
|
||||
self.assertEqual(data_sizes[len(data1) : len(dataset)], sizes(data2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user