Indexed Huffman Coded dataset (#2029)

Summary:
## What does this PR do?

Currently, binarized datasets are stored as a binary representation of int tensors. At best, each int is coded as a uint16 on disk.

When coding a fixed-size-vocabulary dataset where we know the frequency of each symbol, and where some symbols are more common than others, we can do better. This happens in particular when binarizing a dataset split into subword units, as the most common "tokenizers" like BPE and SentencePiece (spm) pick subwords with high frequencies over subwords with low frequencies.

In practice, if we know the frequency of all symbols (or a good estimate of it), we can use entropy encoding methods to compress the data. The idea is to assign a compressed representation where frequent symbols get shorter codes than infrequent ones.
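
For intuition, here is a toy sketch using the code builder added in this PR (the symbols and counts are made up for illustration): the very frequent symbol ends up with a much shorter code than the rare one.
```
from fairseq.data.huffman import HuffmanCodeBuilder

code_builder = HuffmanCodeBuilder()
code_builder.increment("the", 1000)      # very frequent symbol
code_builder.increment("cat", 100)
code_builder.increment("sat", 100)
code_builder.increment("xylophone", 1)   # rare symbol

coder = code_builder.build_code()
# get_code returns the bitarray assigned to a symbol
print(len(coder.get_code("the")))        # short code: few bits
print(len(coder.get_code("xylophone")))  # longer code: more bits
```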

In this PR, we build a Huffman code from a frequency table and use this code to encode a dataset. The PR provides the Huffman coder implementation (using the single-queue approach, since we usually start with a sorted set of symbols) as well as a memory-mapped dataset implementation that stores the data compressed with a Huffman code and can return indexed tensors from it.

Over a whole dataset, depending on how many symbols we sample to estimate the frequencies, we can save between 25% and 30% of storage space.
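
As a back-of-envelope check (with illustrative numbers, not measurements from this PR): a uint16 representation spends 16 bits per token, so an average Huffman code length of roughly 11-12 bits over the vocabulary lands in that 25-30% range, ignoring the bit padding and the index overhead.
```
fixed_bits = 16       # uint16 per token on disk today
avg_code_bits = 11.5  # assumed average Huffman code length (illustrative)
print(1 - avg_code_bits / fixed_bits)  # ~0.28, i.e. ~28% smaller data file
```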

## Follow Ups

Currently, the binarizer/preprocess script makes too many assumptions about the dataset writers, so the Huffman dataset writer cannot be used with it straight out of the box. I will make follow-up PRs to provide easy-to-use scripts to build such datasets. But it's as simple as doing:
```
code_builder = HuffmanCodeBuilder()
with open(sample_file, 'r', encoding="utf-8") as input:
    for line in input:
        code_builder.add(*line.strip().split(" "))

coder = code_builder.build_code()

with HuffmanMMapIndexedDatasetBuilder('/tmp/testing_huffman', coder) as builder:
    with open(dataset_file, 'r', encoding="utf-8") as input:
        for line in input:
            builder.add_item(line.strip().split(' '))
```
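
Reading the dataset back is just as direct; a minimal sketch reusing the prefix from above:
```
from fairseq.data.huffman import HuffmanMMapIndexedDataset

dataset = HuffmanMMapIndexedDataset('/tmp/testing_huffman')
print(len(dataset))                  # number of sentences
print(dataset[0])                    # int64 tensor of symbol ids for the first sentence
print(list(dataset.get_symbols(0)))  # the original tokens, decoded with the stored coder
```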

A lot of the `HuffmanMMapIndexedDataset` code comes from the normal `MMapIndexedDataset`, and we could probably extract the commonalities into a base class.

The `HuffmanCoder` is also really a special kind of `Dictionary`; again, a common base class could be abstracted out of them.

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2029

Reviewed By: dianaml0

Differential Revision: D29557468

Pulled By: Mortimerp9

fbshipit-source-id: a01b6d98f38f937934cadebb3786133e257adefe
Pierre Andrews 2021-08-31 01:11:34 -07:00 committed by Facebook GitHub Bot
parent 5277ec47bd
commit 68a81202a3
9 changed files with 788 additions and 2 deletions


@@ -267,7 +267,7 @@ class Dictionary:
self.add_symbol(word, n=count, overwrite=overwrite)
except ValueError:
raise ValueError(
"Incorrect dictionary format, expected '<token> <cnt> [flags]'"
f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{line}\""
)
def _save(self, f, kv_iterator):


@@ -0,0 +1,21 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .huffman_coder import HuffmanCodeBuilder, HuffmanCoder
from .huffman_mmap_indexed_dataset import (
HuffmanMMapIndex,
HuffmanMMapIndexedDataset,
HuffmanMMapIndexedDatasetBuilder,
vocab_file_path,
)
__all__ = [
"HuffmanCoder",
"HuffmanCodeBuilder",
"HuffmanMMapIndexedDatasetBuilder",
"HuffmanMMapIndexedDataset",
"HuffmanMMapIndex",
"vocab_file_path",
]


@@ -0,0 +1,265 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import re
import typing as tp
from collections import Counter, deque
from dataclasses import dataclass
from bitarray import bitarray, util
from fairseq.data import Dictionary
# basically we have to write to addressable bytes for the memory mapped
# dataset loader. Sentences that get encoded to a length that is not a
# multiple of BLOCKSIZE (a byte) will be padded to fit. (see _pad in the coder)
BLOCKSIZE = 8
class HuffmanCoder:
def __init__(
self, root: "HuffmanNode", bos="<s>", pad="<pad>", eos="</s>", unk="<unk>"
):
self.root = root
self.table = root.code_table()
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
def _pad(self, a: bitarray) -> bitarray:
"""
bitpadding, 1 then 0.
If the array is already a multiple of blocksize, we add a full block.
"""
pad_len = BLOCKSIZE - (len(a) % BLOCKSIZE) - 1
padding = bitarray("1" + "0" * pad_len)
return a + padding
def _unpad(self, a: bitarray) -> bitarray:
"""
remove the bitpadding.
There will be a run of 0s preceded by a 1 at the end of the bitarray; we remove all of that.
"""
# count the 0 padding at the end until we find the first 1
# we want to remove the one too
remove_cnt = util.rindex(a, 1)
return a[:remove_cnt]
def encode(self, iter: tp.List[str]) -> bytes:
"""
encode a list of tokens and return bytes. We use bitpadding to make sure the encoded bits fit in bytes.
"""
a = bitarray()
for token in iter:
code = self.get_code(token)
if code is None:
if self.unk_word is None:
raise Exception(f"unknown token {token} cannot be encoded.")
else:
token = self.unk_word
a = a + self.get_code(token)
return self._pad(a).tobytes()
def decode(self, bits: bytes) -> tp.Iterator["HuffmanNode"]:
"""
take bitpadded bytes and decode them to a sequence of leaf nodes. You can then use each node to find the symbol/id
"""
a = bitarray()
a.frombytes(bits)
return self.root.decode(self._unpad(a))
def get_code(self, symbol: str) -> tp.Optional[bitarray]:
node = self.get_node(symbol)
return None if node is None else node.code
def get_node(self, symbol: str) -> "HuffmanNode":
return self.table.get(symbol)
@classmethod
def from_file(
cls,
filename: str,
bos="<s>",
pad="<pad>",
eos="</s>",
unk="<unk>",
) -> "HuffmanCoder":
builder = HuffmanCodeBuilder.from_file(filename)
return builder.build_code(bos=bos, pad=pad, eos=eos, unk=unk)
def to_file(self, filename, sep="\t"):
nodes = list(self.table.values())
nodes.sort(key=lambda n: n.id)
with open(filename, "w", encoding="utf-8") as output:
for n in nodes:
output.write(f"{n.symbol}{sep}{n.count}\n")
def __iter__(self):
for n in self.table.values():
yield n
def merge(self, other_coder: "HuffmanCoder") -> "HuffmanCoder":
builder = HuffmanCodeBuilder()
for n in self:
builder.increment(n.symbol, n.count)
for n in other_coder:
builder.increment(n.symbol, n.count)
return builder.build_code()
def __eq__(self, other: "HuffmanCoder") -> bool:
return self.table == other.table
def __len__(self) -> int:
return len(self.table)
def __contains__(self, sym: str) -> bool:
return sym in self.table
def to_dictionary(self) -> Dictionary:
dictionary = Dictionary(bos=self.bos_word, unk=self.unk_word, pad=self.pad_word, eos=self.eos_word)
for n in self:
dictionary.add_symbol(n.symbol, n=n.count)
dictionary.finalize()
return dictionary
@dataclass
class HuffmanNode:
"""
a node in a Huffman tree
"""
id: int
count: int
symbol: tp.Optional[str] = None
left: tp.Optional["HuffmanNode"] = None
right: tp.Optional["HuffmanNode"] = None
code: tp.Optional[bitarray] = None
def is_leaf(self) -> bool:
return self.left is None and self.right is None
def code_table(self, prefix: tp.Optional[bitarray] = None) -> tp.Dict[str, "HuffmanNode"]:
defaulted_prefix = prefix if prefix is not None else bitarray()
if self.is_leaf():
self.code = (
defaulted_prefix if len(defaulted_prefix) > 0 else bitarray("0")
) # leaf could be the root if there is only one symbol
return {self.symbol: self}
codes_right = self.right.code_table(defaulted_prefix + bitarray([0]))
codes_left = self.left.code_table(defaulted_prefix + bitarray([1]))
return {**codes_left, **codes_right}
def decode(self, bits: bitarray) -> tp.Iterator["HuffmanNode"]:
current_node = self
for bit in bits:
if bit == 0: # go right
current_node = current_node.right
else: # go left
current_node = current_node.left
if current_node is None:
# we shouldn't be on a leaf here
raise Exception("fell off a leaf")
if current_node.is_leaf():
yield current_node
current_node = self
if current_node != self:
raise Exception("couldn't decode all the bits")
class HuffmanCodeBuilder:
"""
build a dictionary with occurrence counts and then build the Huffman code for it.
"""
def __init__(self):
self.symbols = Counter()
def add_symbols(self, *syms) -> None:
self.symbols.update(syms)
def increment(self, symbol: str, cnt: int) -> None:
self.symbols[symbol] += cnt
@classmethod
def from_file(cls, filename):
c = cls()
with open(filename, "r", encoding="utf-8") as input:
for line in input:
split = re.split(r"[\s]+", line)
c.increment(split[0], int(split[1]))
return c
def to_file(self, filename, sep="\t"):
with open(filename, "w", encoding="utf-8") as output:
for (tok, cnt) in self.symbols.most_common():
output.write(f"{tok}{sep}{cnt}\n")
def _smallest(self, q1: deque, q2: deque) -> HuffmanNode:
if len(q1) == 0:
return q2.pop()
if len(q2) == 0:
return q1.pop()
if q1[-1].count < q2[-1].count:
return q1.pop()
return q2.pop()
def __add__(self, c: "HuffmanCodeBuilder") -> "HuffmanCodeBuilder":
new_c = self.symbols + c.symbols
new_b = HuffmanCodeBuilder()
new_b.symbols = new_c
return new_b
def build_code(
self,
bos="<s>",
pad="<pad>",
eos="</s>",
unk="<unk>",
) -> HuffmanCoder:
assert len(self.symbols) > 0, "cannot build code from empty list of symbols"
if self.symbols[bos] == 0:
self.add_symbols(bos)
if self.symbols[pad] == 0:
self.add_symbols(pad)
if self.symbols[eos] == 0:
self.add_symbols(eos)
if self.symbols[unk] == 0:
self.add_symbols(unk)
node_id = 0
leaves_queue = deque(
[
HuffmanNode(symbol=symbol, count=count, id=idx)
for idx, (symbol, count) in enumerate(self.symbols.most_common())
]
) # left are the most common, right are the least common
if len(leaves_queue) == 1:
root = leaves_queue.pop()
root.id = 0
return HuffmanCoder(root)
nodes_queue = deque()
while len(leaves_queue) > 0 or len(nodes_queue) != 1:
# get the lowest two nodes at the head of each queue
node1 = self._smallest(leaves_queue, nodes_queue)
node2 = self._smallest(leaves_queue, nodes_queue)
# add new node
nodes_queue.appendleft(
HuffmanNode(
count=node1.count + node2.count, left=node1, right=node2, id=node_id
)
)
node_id += 1
# we are left with the root
return HuffmanCoder(nodes_queue.pop(), bos=bos, pad=pad, eos=eos, unk=unk)


@@ -0,0 +1,287 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import mmap
import os
import shutil
import struct
import typing as tp
from functools import lru_cache
import numpy as np
import torch
from fairseq.data import indexed_dataset
from fairseq.data.huffman import HuffmanCoder
from fairseq.file_io import PathManager
class HuffmanMMapIndex:
"""
keep an index of the offsets in the huffman binary file.
First a header, then the list of sizes (num tokens) for each instance and finally
the addresses of each instance.
"""
_HDR_MAGIC = b"HUFFIDX\x00\x00"
_VERSION = 1
@classmethod
def writer(cls, path: str, data_len: int):
class _Writer:
def __enter__(self):
self._file = open(path, "wb")
# write header (magic + version)
self._file.write(cls._HDR_MAGIC)
self._file.write(struct.pack("<Q", cls._VERSION))
self._file.write(struct.pack("<Q", data_len))
return self
def write(self, sizes, pointers):
# add number of items in the index to the header
self._file.write(struct.pack("<Q", len(sizes)))
# write sizes
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order="C"))
del sizes
# write address pointers
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order="C"))
del pointers
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path):
with open(path, "rb") as stream:
# read headers
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
"Index file doesn't match expected format. "
"Make sure that --dataset-impl is configured properly."
)
(version,) = struct.unpack("<Q", stream.read(8))
assert (
self._VERSION == version
), "Unexpected file version f{version} != code version f{self._VERSION}"
# read length of data file
(self._data_len,) = struct.unpack("<Q", stream.read(8))
# read number of items in data file/index
(self._len,) = struct.unpack("<Q", stream.read(8))
offset = stream.tell()
indexed_dataset._warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
self._sizes = np.frombuffer(
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
)
self._pointers = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._len,
offset=offset + self._sizes.nbytes,
)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
def __iter__(self):
for i in range(self._len):
yield self[i]
@property
def data_len(self):
return self._data_len
@property
def sizes(self):
return self._sizes
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def vocab_file_path(prefix_path):
return prefix_path + ".vocab"
class HuffmanMMapIndexedDataset(torch.utils.data.Dataset):
"""
an indexed dataset that uses mmap and memoryview to access data from disk
that was compressed with a HuffmanCoder.
"""
def __init__(self, prefix_path):
super().__init__()
self._prefix_path = None
self._index = None
self._bin_buffer = None
self._coder = None
self._file = None
self._bin_buffer_mmap = None
self._do_init(prefix_path)
def __getstate__(self):
return self._prefix_path
def __setstate__(self, state):
self._do_init(state)
def _do_init(self, prefix_path):
self._prefix_path = prefix_path
self._index = HuffmanMMapIndex(
indexed_dataset.index_file_path(self._prefix_path)
)
self._coder = HuffmanCoder.from_file(vocab_file_path(self._prefix_path))
indexed_dataset._warmup_mmap_file(
indexed_dataset.data_file_path(self._prefix_path)
)
self._file = os.open(
indexed_dataset.data_file_path(self._prefix_path), os.O_RDONLY
)
self._bin_buffer_mmap = mmap.mmap(
self._file,
self._index.data_len,
access=mmap.ACCESS_READ,
)
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
del self._bin_buffer
if self._file:
os.close(self._file)
del self._index
def __len__(self):
return len(self._index)
def _decode(self, i):
ptr, _ = self._index[i]
if i == 0:
raw_bytes = self._bin_buffer[:ptr]
else:
(prev_ptr, _) = self._index[i - 1]
raw_bytes = self._bin_buffer[prev_ptr:ptr]
return self._coder.decode(raw_bytes.tobytes())
@lru_cache(maxsize=8)
def __getitem__(self, i):
nodes = self._decode(i)
return torch.tensor([n.id for n in nodes], dtype=torch.int64)
def __iter__(self):
for idx in range(len(self)):
yield self[idx]
def get_symbols(self, i):
nodes = self._decode(i)
for n in nodes:
yield n.symbol
@property
def sizes(self):
return self._index.sizes
@property
def supports_prefetch(self):
return False
@property
def coder(self):
return self._coder
@staticmethod
def exists(prefix_path):
return (
PathManager.exists(indexed_dataset.index_file_path(prefix_path))
and PathManager.exists(indexed_dataset.data_file_path(prefix_path))
and PathManager.exists(vocab_file_path(prefix_path))
)
class HuffmanMMapIndexedDatasetBuilder:
"""
Helper to build a memory-mapped dataset with a Huffman encoder.
You can either open/close this manually or use it as a ContextManager.
Provide your own coder; it will then be stored alongside the dataset.
The builder will first write the vocab file, then open the binary file so you can stream
into it; finally, the index will be written when the builder is closed (your index should fit in memory).
"""
def __init__(self, path_prefix: str, coder: HuffmanCoder) -> None:
self._path_prefix = path_prefix
self._coder = coder
self._sizes = []
self._ptrs = []
self._data_len = 0
def open(self):
self._coder.to_file(vocab_file_path(self._path_prefix))
self._data_file = open(indexed_dataset.data_file_path(self._path_prefix), "wb")
def __enter__(self) -> "HuffmanMMapIndexedDatasetBuilder":
self.open()
return self
def add_item(self, tokens: tp.List[str]) -> None:
"""
add a list of tokens to the dataset; they will be compressed with the
provided coder before being written to file.
"""
encoded = self._coder.encode(tokens)
code_len = len(encoded)
last_ptr = 0
if len(self._ptrs) > 0:
last_ptr = self._ptrs[-1]
self._sizes.append(len(tokens))
self._ptrs.append(last_ptr + code_len)
self._data_len += code_len
self._data_file.write(encoded)
def append(self, other_dataset_path_prefix: str) -> None:
"""
append an existing dataset.
Beware, if it wasn't built with the same coder, you are in trouble.
"""
other_index = HuffmanMMapIndex(
indexed_dataset.index_file_path(other_dataset_path_prefix)
)
for (ptr, size) in other_index:
self._ptrs.append(ptr + self._data_len)
self._sizes.append(size)
# Concatenate data
with open(indexed_dataset.data_file_path(other_dataset_path_prefix), "rb") as f:
shutil.copyfileobj(f, self._data_file)
self._data_len += other_index.data_len
def close(self):
self._data_file.close()
with HuffmanMMapIndex.writer(
indexed_dataset.index_file_path(self._path_prefix), self._data_len
) as index:
index.write(self._sizes, self._ptrs)
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()


@@ -12,6 +12,7 @@ import torch
from fairseq.dataclass.constants import DATASET_IMPL_CHOICES
from fairseq.data.fasta_dataset import FastaDataset
from fairseq.file_io import PathManager
from fairseq.data.huffman import HuffmanMMapIndexedDataset, HuffmanMMapIndex
from . import FairseqDataset
@@ -48,6 +49,8 @@ def infer_dataset_impl(path):
return "cached"
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
return "mmap"
elif magic == HuffmanMMapIndex._HDR_MAGIC[:8]:
return "huffman"
else:
return None
elif FastaDataset.exists(path):
@@ -63,6 +66,8 @@ def make_builder(out_file, impl, vocab_size=None):
)
elif impl == "fasta":
raise NotImplementedError
elif impl == "huffman":
raise ValueError("Use HuffmanCodeBuilder directly as it has a different interface.")
else:
return IndexedDatasetBuilder(out_file)
@@ -81,6 +86,8 @@ def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
from fairseq.data.fasta_dataset import EncodedFastaDataset
return EncodedFastaDataset(path, dictionary)
elif impl == "huffman" and HuffmanMMapIndexedDataset.exists(path):
return HuffmanMMapIndexedDataset(path)
return None
@@ -89,6 +96,8 @@ def dataset_exists(path, impl):
return IndexedRawTextDataset.exists(path)
elif impl == "mmap":
return MMapIndexedDataset.exists(path)
elif impl == "huffman":
return HuffmanMMapIndexedDataset.exists(path)
else:
return IndexedDataset.exists(path)


@@ -44,7 +44,7 @@ DDP_BACKEND_CHOICES = ChoiceEnum([
"slow_mo",
])
DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"])
DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta"])
DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"])
GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"])
GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum(
["unigram", "ensemble", "vote", "dp", "bs"]


@@ -41,6 +41,8 @@ def main(args):
)
logger.info(args)
assert args.dataset_impl != "huffman", "preprocessing.py doesn't support Huffman yet, use HuffmanCodeBuilder directly."
task = tasks.get_task(args.task)
def train_path(lang):


@@ -209,6 +209,7 @@ def do_setup(package_data):
"sacrebleu>=1.4.12",
"torch",
"tqdm",
"bitarray",
],
dependency_links=dependency_links,
packages=find_packages(

tests/test_huffman.py

@@ -0,0 +1,201 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import random
import string
import typing as tp
import unittest
from collections import Counter
from tempfile import NamedTemporaryFile, TemporaryDirectory
from fairseq.data import Dictionary, indexed_dataset
from fairseq.data.huffman import (
HuffmanCodeBuilder,
HuffmanCoder,
HuffmanMMapIndexedDataset,
HuffmanMMapIndexedDatasetBuilder,
)
POPULATION = string.ascii_letters + string.digits
def make_sentence() -> tp.List[str]:
length = random.randint(10, 50)
return random.choices(
population=POPULATION, k=length, weights=range(1, len(POPULATION) + 1)
)
def make_data(length=1000) -> tp.List[tp.List[str]]:
return (
[make_sentence() for _ in range(0, length)]
# add all the symbols at least once
+ [list(string.ascii_letters), list(string.digits)]
)
def make_counts(data: tp.List[tp.List[str]]) -> Counter:
return Counter([symbol for sentence in data for symbol in sentence])
def make_code_builder(data: tp.List[tp.List[str]]) -> HuffmanCodeBuilder:
builder = HuffmanCodeBuilder()
for sentence in data:
builder.add_symbols(*sentence)
return builder
class TestCodeBuilder(unittest.TestCase):
def test_code_builder_can_count(self):
data = make_data()
counts = make_counts(data)
builder = make_code_builder(data)
self.assertEqual(builder.symbols, counts)
def test_code_builder_can_add(self):
data = make_data()
counts = make_counts(data)
builder = make_code_builder(data)
new_builder = builder + builder
self.assertEqual(new_builder.symbols, counts + counts)
def test_code_builder_can_io(self):
data = make_data()
builder = make_code_builder(data)
with NamedTemporaryFile() as tmp_fp:
builder.to_file(tmp_fp.name)
other_builder = HuffmanCodeBuilder.from_file(tmp_fp.name)
self.assertEqual(builder.symbols, other_builder.symbols)
class TestCoder(unittest.TestCase):
def test_coder_can_io(self):
data = make_data()
builder = make_code_builder(data)
coder = builder.build_code()
with NamedTemporaryFile() as tmp_fp:
coder.to_file(tmp_fp.name)
other_coder = HuffmanCoder.from_file(tmp_fp.name)
self.assertEqual(coder, other_coder)
def test_coder_can_encode_decode(self):
data = make_data()
builder = make_code_builder(data)
coder = builder.build_code()
encoded = [coder.encode(sentence) for sentence in data]
decoded = [[n.symbol for n in coder.decode(enc)] for enc in encoded]
self.assertEqual(decoded, data)
unseen_data = make_data()
unseen_encoded = [coder.encode(sentence) for sentence in unseen_data]
unseen_decoded = [
[n.symbol for n in coder.decode(enc)] for enc in unseen_encoded
]
self.assertEqual(unseen_decoded, unseen_data)
def build_dataset(prefix, data, coder):
with HuffmanMMapIndexedDatasetBuilder(prefix, coder) as builder:
for sentence in data:
builder.add_item(sentence)
def sizes(data):
return [len(sentence) for sentence in data]
class TestHuffmanDataset(unittest.TestCase):
def test_huffman_can_encode_decode(self):
data = make_data()
builder = make_code_builder(data)
coder = builder.build_code()
with TemporaryDirectory() as dirname:
prefix = os.path.join(dirname, "test1")
build_dataset(prefix, data, coder)
dataset = HuffmanMMapIndexedDataset(prefix)
self.assertEqual(len(dataset), len(data))
decoded = [list(dataset.get_symbols(i)) for i in range(0, len(dataset))]
self.assertEqual(decoded, data)
data_sizes = [i.item() for i in dataset.sizes]
self.assertEqual(data_sizes, sizes(data))
def test_huffman_compresses(self):
data = make_data()
builder = make_code_builder(data)
coder = builder.build_code()
with TemporaryDirectory() as dirname:
prefix = os.path.join(dirname, "huffman")
build_dataset(prefix, data, coder)
prefix_mmap = os.path.join(dirname, "mmap")
mmap_builder = indexed_dataset.make_builder(
indexed_dataset.data_file_path(prefix_mmap),
"mmap",
vocab_size=len(POPULATION),
)
dictionary = Dictionary()
for c in POPULATION:
dictionary.add_symbol(c)
dictionary.finalize()
for sentence in data:
mmap_builder.add_item(dictionary.encode_line(" ".join(sentence)))
mmap_builder.finalize(indexed_dataset.index_file_path(prefix_mmap))
huff_size = os.stat(indexed_dataset.data_file_path(prefix)).st_size
mmap_size = os.stat(indexed_dataset.data_file_path(prefix_mmap)).st_size
self.assertLess(huff_size, mmap_size)
def test_huffman_can_append(self):
data1 = make_data()
builder = make_code_builder(data1)
coder = builder.build_code()
with TemporaryDirectory() as dirname:
prefix1 = os.path.join(dirname, "test1")
build_dataset(prefix1, data1, coder)
data2 = make_data()
prefix2 = os.path.join(dirname, "test2")
build_dataset(prefix2, data2, coder)
prefix3 = os.path.join(dirname, "test3")
with HuffmanMMapIndexedDatasetBuilder(prefix3, coder) as builder:
builder.append(prefix1)
builder.append(prefix2)
dataset = HuffmanMMapIndexedDataset(prefix3)
self.assertEqual(len(dataset), len(data1) + len(data2))
decoded1 = [list(dataset.get_symbols(i)) for i in range(0, len(data1))]
self.assertEqual(decoded1, data1)
decoded2 = [
list(dataset.get_symbols(i)) for i in range(len(data1), len(dataset))
]
self.assertEqual(decoded2, data2)
data_sizes = [i.item() for i in dataset.sizes]
self.assertEqual(data_sizes[: len(data1)], sizes(data1))
self.assertEqual(data_sizes[len(data1) : len(dataset)], sizes(data2))
if __name__ == "__main__":
unittest.main()