TokenBlockDataset np type promotion issue (#1658)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1658

Reviewed By: jxmsML

Differential Revision: D26701840

Pulled By: sshleifer

fbshipit-source-id: 90d631c3cd775ab847366fe7a05136c29d90cd63
This commit is contained in:
Sam Shleifer 2021-02-26 20:59:22 -08:00 committed by Facebook GitHub Bot
parent f569c024ae
commit 4f881a760e
4 changed files with 41 additions and 10 deletions

View File

@ -18,9 +18,9 @@ from . import FairseqDataset
from typing import Union
def best_fitting_uint_dtype(
def best_fitting_int_dtype(
max_int_to_represent,
) -> Union[np.uint16, np.uint32, np.uint64]:
) -> Union[np.uint16, np.uint32, np.int64]:
if max_int_to_represent is None:
return np.uint32 # Safe guess
@ -29,7 +29,9 @@ def best_fitting_uint_dtype(
elif max_int_to_represent < 4294967295:
return np.uint32
else:
return np.uint64
return np.int64
# we avoid np.uint64 because it doesn't save space and its type promotion behaves unexpectedly
# https://github.com/numpy/numpy/issues/5745
def get_available_dataset_impl():
@ -57,7 +59,7 @@ def infer_dataset_impl(path):
def make_builder(out_file, impl, vocab_size=None):
if impl == "mmap":
return MMapIndexedDatasetBuilder(
out_file, dtype=best_fitting_uint_dtype(vocab_size)
out_file, dtype=best_fitting_int_dtype(vocab_size)
)
elif impl == "fasta":
raise NotImplementedError

View File

@ -6,7 +6,8 @@
import numpy as np
import torch
from fairseq.data import FairseqDataset, plasma_utils
from fairseq.data.indexed_dataset import best_fitting_uint_dtype
from fairseq.data.indexed_dataset import best_fitting_int_dtype
class TokenBlockDataset(FairseqDataset):
"""Break a Dataset of tokens into blocks.
@ -95,15 +96,18 @@ class TokenBlockDataset(FairseqDataset):
)
else:
block_to_dataset_index = _get_block_to_dataset_index_fast(
sizes,
slice_indices,
sizes, slice_indices,
)
size_dtype = np.uint16 if block_size < 65535 else np.uint32
slice_indices_dtype = best_fitting_uint_dtype(slice_indices[-1].max())
slice_indices_dtype = best_fitting_int_dtype(slice_indices[-1].max())
self._slice_indices = plasma_utils.PlasmaArray(slice_indices.astype(slice_indices_dtype))
self._slice_indices = plasma_utils.PlasmaArray(
slice_indices.astype(slice_indices_dtype)
)
self._sizes = plasma_utils.PlasmaArray(self._sizes.astype(size_dtype))
self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index.astype(slice_indices_dtype))
self._block_to_dataset_index = plasma_utils.PlasmaArray(
block_to_dataset_index.astype(slice_indices_dtype)
)
@property
def slice_indices(self):

View File

@ -394,6 +394,18 @@ def transformer_lm_gpt2_small(args):
base_lm_architecture(args)
@register_model_architecture("transformer_lm", "transformer_lm_gpt2_tiny")
def transformer_lm_gpt2_tiny(args):
    """Tiny GPT-2-style LM architecture (64-dim, 2 layers, 1 head).

    Small enough to build and run in unit tests; any attribute already
    present on ``args`` takes precedence over these defaults.
    """
    defaults = {
        "decoder_embed_dim": 64,
        "decoder_ffn_embed_dim": 64,
        "decoder_layers": 2,
        "decoder_attention_heads": 1,
        "dropout": 0.1,
        "attention_dropout": 0.1,
        "activation_fn": "gelu",
    }
    # Mirror the usual `args.x = getattr(args, "x", default)` pattern:
    # always (re)assign, falling back to the default only when unset.
    for name, default in defaults.items():
        setattr(args, name, getattr(args, name, default))
    base_lm_architecture(args)
@register_model_architecture("transformer_lm", "transformer_lm_gpt2_medium")
def transformer_lm_gpt2_medium(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)

View File

@ -74,6 +74,19 @@ class TestTokenBlockDataset(unittest.TestCase):
self.assertEqual(ds[1].tolist(), [5, 1, 1])
self.assertEqual(ds[2].tolist(), [6, 1])
def test_4billion_tokens(self):
    """Regression test for numpy type promotion issue https://github.com/numpy/numpy/issues/5745"""
    # 430,000 documents of 10,000 tokens each -> 4.3e9 tokens, enough to
    # overflow a uint32 offset. The list repeats one tensor, so this is cheap.
    corpus = [torch.tensor(list(range(10000)), dtype=torch.long)] * 430000
    dataset = self._build_dataset(
        corpus, block_size=6, pad=0, eos=1, break_mode="complete"
    )
    dataset[-1]  # __getitem__ works
    start, end = dataset.slice_indices[-1]
    assert end > 4294967295  # data must be sufficiently large to overflow uint32
    # np.uint64 + python int promotes to float64 (numpy#5745); with int64
    # offsets the sum stays integral.
    assert not isinstance(end + 1, float)
# Allow running this test module directly (e.g. `python tests/test_token_block_dataset.py`).
if __name__ == "__main__":
    unittest.main()