Add new Datasets

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/757

Differential Revision: D16418305

Pulled By: myleott

fbshipit-source-id: 25f293a2792509f7a75c688e4bf8cff02e6bba2e
This commit is contained in:
Myle Ott 2019-07-22 08:55:12 -07:00 committed by Facebook Github Bot
parent 51ba35217d
commit 654affc03d
11 changed files with 569 additions and 0 deletions

View File

@ -9,15 +9,26 @@ from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset
from .base_wrapper_dataset import BaseWrapperDataset
from .audio.raw_audio_dataset import RawAudioDataset
from .backtranslation_dataset import BacktranslationDataset
from .concat_dataset import ConcatDataset
from .id_dataset import IdDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
from .language_pair_dataset import LanguagePairDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .lru_cache_dataset import LRUCacheDataset
from .mask_tokens_dataset import MaskTokensDataset
from .monolingual_dataset import MonolingualDataset
from .nested_dictionary_dataset import NestedDictionaryDataset
from .noising import NoisingDataset
from .numel_dataset import NumelDataset
from .num_samples_dataset import NumSamplesDataset
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
from .prepend_token_dataset import PrependTokenDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sort_dataset import SortDataset
from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
@ -31,23 +42,35 @@ from .iterators import (
__all__ = [
'BacktranslationDataset',
'BaseWrapperDataset',
'ConcatDataset',
'CountingIterator',
'Dictionary',
'EpochBatchIterator',
'FairseqDataset',
'GroupedIterator',
'IdDataset',
'IndexedCachedDataset',
'IndexedDataset',
'IndexedRawTextDataset',
'LanguagePairDataset',
'LeftPadDataset',
'LMContextWindowDataset',
'LRUCacheDataset',
'MaskTokensDataset',
'MMapIndexedDataset',
'MonolingualDataset',
'NestedDictionaryDataset',
'NoisingDataset',
'NumelDataset',
'NumSamplesDataset',
'PadDataset',
'PrependTokenDataset',
'RawAudioDataset',
'RightPadDataset',
'RoundRobinZipDatasets',
'ShardedIterator',
'SortDataset',
'TokenBlockDataset',
'TransformEosDataset',
'TransformEosLangPairDataset',

View File

@ -0,0 +1,53 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from torch.utils.data._utils.collate import default_collate
from . import FairseqDataset
class BaseWrapperDataset(FairseqDataset):
def __init__(self, dataset):
super().__init__()
self.dataset = dataset
def __getitem__(self, index):
return self.dataset[index]
def __len__(self):
return len(self.dataset)
def collater(self, samples):
if hasattr(self.dataset, 'collater'):
return self.dataset.collater(samples)
else:
return default_collate(samples)
@property
def sizes(self):
return self.dataset.sizes
def num_tokens(self, index):
return self.dataset.num_tokens(index)
def size(self, index):
return self.dataset.size(index)
def ordered_indices(self):
return self.dataset.ordered_indices()
@property
def supports_prefetch(self):
return getattr(self.dataset, 'supports_prefetch', False)
def prefetch(self, indices):
self.dataset.prefetch(indices)
def set_epoch(self, epoch):
super().set_epoch(epoch)
self.dataset.set_epoch(epoch)

View File

@ -0,0 +1,22 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from . import FairseqDataset
class IdDataset(FairseqDataset):
def __getitem__(self, index):
return index
def __len__(self):
return 0
def collater(self, samples):
return torch.tensor(samples)

View File

@ -0,0 +1,24 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from functools import lru_cache
from . import BaseWrapperDataset
class LRUCacheDataset(BaseWrapperDataset):
def __init__(self, dataset, token=None):
super().__init__(dataset)
@lru_cache(maxsize=8)
def __getitem__(self, index):
return self.dataset[index]
@lru_cache(maxsize=8)
def collater(self, samples):
return self.dataset.collater(samples)

View File

@ -0,0 +1,174 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from functools import lru_cache
import numpy as np
import torch
from fairseq.data import data_utils, Dictionary
from . import BaseWrapperDataset, LRUCacheDataset
class MaskTokensDataset(BaseWrapperDataset):
"""
A wrapper Dataset for masked language modeling.
Input items are masked according to the specified masking probability.
Args:
dataset: Dataset to wrap.
sizes: Sentence lengths
vocab: Dictionary with the vocabulary and special tokens.
pad_idx: Id of pad token in vocab
mask_idx: Id of mask token in vocab
return_masked_tokens: controls whether to return the non-masked tokens
(the default) or to return a tensor with the original masked token
IDs (and *pad_idx* elsewhere). The latter is useful as targets for
masked LM training.
seed: Seed for random number generator for reproducibility.
mask_prob: probability of replacing a token with *mask_idx*.
leave_unmasked_prob: probability that a masked token is unmasked.
random_token_prob: probability of replacing a masked token with a
random token from the vocabulary.
freq_weighted_replacement: sample random replacement words based on
word frequencies in the vocab.
mask_whole_words: only mask whole words. This should be a byte mask
over vocab indices, indicating whether it is the beginning of a
word. We will extend any mask to encompass the whole word.
bpe: BPE to use for whole-word masking.
"""
@classmethod
def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs):
"""Return the source and target datasets for masked LM training."""
dataset = LRUCacheDataset(dataset)
return (
LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)),
LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)),
)
def __init__(
self,
dataset: torch.utils.data.Dataset,
vocab: Dictionary,
pad_idx: int,
mask_idx: int,
return_masked_tokens: bool = False,
seed: int = 1,
mask_prob: float = 0.15,
leave_unmasked_prob: float = 0.1,
random_token_prob: float = 0.1,
freq_weighted_replacement: bool = False,
mask_whole_words: torch.Tensor = None,
):
assert 0.0 < mask_prob < 1.0
assert 0.0 <= random_token_prob <= 1.0
assert 0.0 <= leave_unmasked_prob <= 1.0
assert random_token_prob + leave_unmasked_prob <= 1.0
self.dataset = dataset
self.vocab = vocab
self.pad_idx = pad_idx
self.mask_idx = mask_idx
self.return_masked_tokens = return_masked_tokens
self.seed = seed
self.mask_prob = mask_prob
self.leave_unmasked_prob = leave_unmasked_prob
self.random_token_prob = random_token_prob
self.mask_whole_words = mask_whole_words
if random_token_prob > 0.0:
if freq_weighted_replacement:
weights = np.array(self.vocab.count)
else:
weights = np.ones(len(self.vocab))
weights[:self.vocab.nspecial] = 0
self.weights = weights / weights.sum()
self.epoch = 0
def set_epoch(self, epoch, **unused):
self.epoch = epoch
@lru_cache(maxsize=8)
def __getitem__(self, index: int):
with data_utils.numpy_seed(self.seed, self.epoch, index):
item = self.dataset[index]
sz = len(item)
assert self.mask_idx not in item, \
'Dataset contains mask_idx (={}), this is not expected!'.format(
self.mask_idx,
)
if self.mask_whole_words is not None:
word_begins_mask = self.mask_whole_words.gather(0, item)
word_begins_idx = word_begins_mask.nonzero().view(-1)
sz = len(word_begins_idx)
words = np.split(word_begins_mask, word_begins_idx)[1:]
assert len(words) == sz
word_lens = list(map(len, words))
# decide elements to mask
mask = np.full(sz, False)
num_mask = int(
# add a random number for probabilistic rounding
self.mask_prob * sz + np.random.rand()
)
mask[np.random.choice(sz, num_mask, replace=False)] = True
if self.return_masked_tokens:
# exit early if we're just returning the masked tokens
# (i.e., the targets for masked LM training)
if self.mask_whole_words is not None:
mask = np.repeat(mask, word_lens)
new_item = np.full(len(mask), self.pad_idx)
new_item[mask] = item[torch.from_numpy(mask)]
return torch.from_numpy(new_item)
# decide unmasking and random replacement
rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
if rand_or_unmask_prob > 0.0:
rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
if self.random_token_prob == 0.0:
unmask = rand_or_unmask
rand_mask = None
elif self.leave_unmasked_prob == 0.0:
unmask = None
rand_mask = rand_or_unmask
else:
unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
decision = np.random.rand(sz) < unmask_prob
unmask = rand_or_unmask & decision
rand_mask = rand_or_unmask & (~decision)
else:
unmask = rand_mask = None
if unmask is not None:
mask = mask ^ unmask
if self.mask_whole_words is not None:
mask = np.repeat(mask, word_lens)
new_item = np.copy(item)
new_item[mask] = self.mask_idx
if rand_mask is not None:
num_rand = rand_mask.sum()
if num_rand > 0:
if self.mask_whole_words is not None:
rand_mask = np.repeat(rand_mask, word_lens)
num_rand = rand_mask.sum()
new_item[rand_mask] = np.random.choice(
len(self.vocab),
num_rand,
p=self.weights,
)
return torch.from_numpy(new_item)

View File

@ -0,0 +1,118 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import OrderedDict
import torch
from torch.utils.data.dataloader import default_collate
from . import FairseqDataset
def _flatten(dico, prefix=None):
"""Flatten a nested dictionary."""
new_dico = OrderedDict()
if isinstance(dico, dict):
prefix = prefix + '.' if prefix is not None else ''
for k, v in dico.items():
if v is None:
continue
new_dico.update(_flatten(v, prefix + k))
elif isinstance(dico, list):
for i, v in enumerate(dico):
new_dico.update(_flatten(v, prefix + '.[' + str(i) + ']'))
else:
new_dico = OrderedDict({prefix: dico})
return new_dico
def _unflatten(dico):
"""Unflatten a flattened dictionary into a nested dictionary."""
new_dico = OrderedDict()
for full_k, v in dico.items():
full_k = full_k.split('.')
node = new_dico
for k in full_k[:-1]:
if k.startswith('[') and k.endswith(']'):
k = int(k[1:-1])
if k not in node:
node[k] = OrderedDict()
node = node[k]
node[full_k[-1]] = v
return new_dico
class NestedDictionaryDataset(FairseqDataset):
def __init__(self, defn, sizes=None):
super().__init__()
self.defn = _flatten(defn)
self.sizes = [sizes] if not isinstance(sizes, (list, tuple)) else sizes
first = None
for v in self.defn.values():
if not isinstance(v, (FairseqDataset, torch.utils.data.Dataset, )):
raise ValueError('Expected Dataset but found: {}'.format(v.__class__))
first = first or v
if len(v) > 0:
assert len(v) == len(first), 'dataset lengths must match'
self._len = len(first)
def __getitem__(self, index):
return OrderedDict((k, ds[index]) for k, ds in self.defn.items())
def __len__(self):
return self._len
def collater(self, samples):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch suitable for forwarding with a Model
"""
if len(samples) == 0:
return {}
sample = OrderedDict()
for k, ds in self.defn.items():
try:
sample[k] = ds.collater([s[k] for s in samples])
except NotImplementedError:
sample[k] = default_collate([s[k] for s in samples])
return _unflatten(sample)
def num_tokens(self, index):
"""Return the number of tokens in a sample. This value is used to
enforce ``--max-tokens`` during batching."""
return max(s[index] for s in self.sizes)
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
if len(self.sizes) == 1:
return self.sizes[0][index]
else:
return (s[index] for s in self.sizes)
@property
def supports_prefetch(self):
"""Whether this dataset supports prefetching."""
return any(ds.supports_prefetch for ds in self.defn.values())
def prefetch(self, indices):
"""Prefetch the data required for this epoch."""
for ds in self.defn.values():
if getattr(ds, 'supports_prefetch', False):
ds.prefetch(indices)
def set_epoch(self, epoch):
super().set_epoch(epoch)
for ds in self.defn.values():
ds.set_epoch(epoch)

View File

@ -0,0 +1,20 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from . import FairseqDataset
class NumSamplesDataset(FairseqDataset):
def __getitem__(self, index):
return 1
def __len__(self):
return 0
def collater(self, samples):
return sum(samples)

View File

@ -0,0 +1,34 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import torch
from . import BaseWrapperDataset
class NumelDataset(BaseWrapperDataset):
def __init__(self, dataset, reduce=False):
super().__init__(dataset)
self.reduce = reduce
def __getitem__(self, index):
item = self.dataset[index]
if torch.is_tensor(item):
return torch.numel(item)
else:
return np.size(item)
def __len__(self):
return len(self.dataset)
def collater(self, samples):
if self.reduce:
return sum(samples)
else:
return torch.tensor(samples)

View File

@ -0,0 +1,33 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq.data import data_utils
from . import BaseWrapperDataset
class PadDataset(BaseWrapperDataset):
def __init__(self, dataset, pad_idx, left_pad):
super().__init__(dataset)
self.pad_idx = pad_idx
self.left_pad = left_pad
def collater(self, samples):
return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad)
class LeftPadDataset(PadDataset):
def __init__(self, dataset, pad_idx):
super().__init__(dataset, pad_idx, left_pad=True)
class RightPadDataset(PadDataset):
def __init__(self, dataset, pad_idx):
super().__init__(dataset, pad_idx, left_pad=False)

View File

@ -0,0 +1,44 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import torch
from . import BaseWrapperDataset
class PrependTokenDataset(BaseWrapperDataset):
def __init__(self, dataset, token=None):
super().__init__(dataset)
self.token = token
if token is not None:
self._sizes = np.array(dataset.sizes) + 1
else:
self._sizes = dataset.sizes
def __getitem__(self, idx):
item = self.dataset[idx]
if self.token is not None:
item = torch.cat([item.new([self.token]), item])
return item
@property
def sizes(self):
return self._sizes
def num_tokens(self, index):
n = self.dataset.num_tokens(index)
if self.token is not None:
n += 1
return n
def size(self, index):
n = self.dataset.size(index)
if self.token is not None:
n += 1
return n

View File

@ -0,0 +1,24 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
from . import BaseWrapperDataset
class SortDataset(BaseWrapperDataset):
def __init__(self, dataset, sort_order):
super().__init__(dataset)
if not isinstance(sort_order, (list, tuple)):
sort_order = [sort_order]
self.sort_order = sort_order
assert all(len(so) == len(dataset) for so in sort_order)
def ordered_indices(self):
return np.lexsort(self.sort_order)