diff --git a/docs/tutorial_classifying_names.rst b/docs/tutorial_classifying_names.rst
index b420d850b..e2b5a6716 100644
--- a/docs/tutorial_classifying_names.rst
+++ b/docs/tutorial_classifying_names.rst
@@ -282,8 +282,6 @@ following contents::
             tgt_sizes=torch.ones(len(labels)),  # targets have length 1
             tgt_dict=self.label_vocab,
             left_pad_source=False,
-            max_source_positions=self.args.max_positions,
-            max_target_positions=1,
             # Since our target is a single class label, there's no need for
             # teacher forcing. If we set this to ``True`` then our Model's
             # ``forward()`` method would receive an additional argument called
diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py
index 30c6e88d8..9bdb7a74a 100644
--- a/fairseq/data/__init__.py
+++ b/fairseq/data/__init__.py
@@ -12,6 +12,7 @@ from .base_wrapper_dataset import BaseWrapperDataset
 from .append_token_dataset import AppendTokenDataset
 from .audio.raw_audio_dataset import FileAudioDataset
 from .backtranslation_dataset import BacktranslationDataset
+from .bucket_pad_length_dataset import BucketPadLengthDataset
 from .colorize_dataset import ColorizeDataset
 from .concat_dataset import ConcatDataset
 from .concat_sentences_dataset import ConcatSentencesDataset
@@ -57,6 +58,7 @@ __all__ = [
     'AppendTokenDataset',
     'BacktranslationDataset',
     'BaseWrapperDataset',
+    'BucketPadLengthDataset',
     'ColorizeDataset',
     'ConcatDataset',
     'ConcatSentencesDataset',
diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py
new file mode 100644
index 000000000..6f53d0118
--- /dev/null
+++ b/fairseq/data/bucket_pad_length_dataset.py
@@ -0,0 +1,77 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch.nn.functional as F
+
+from fairseq.data import BaseWrapperDataset
+
+
+class BucketPadLengthDataset(BaseWrapperDataset):
+    """
+    Bucket and pad item lengths to the nearest bucket size. This can be used to
+    reduce the number of unique batch shapes, which is important on TPUs since
+    each new batch shape requires a recompilation.
+
+    Args:
+        dataset (FairseqDataset): dataset to bucket
+        sizes (List[int]): all item sizes
+        num_buckets (int): number of buckets to create
+        pad_idx (int): padding symbol
+        left_pad (bool): if True, pad on the left; otherwise right pad
+    """
+
+    def __init__(
+        self,
+        dataset,
+        sizes,
+        num_buckets,
+        pad_idx,
+        left_pad,
+    ):
+        super().__init__(dataset)
+        self.pad_idx = pad_idx
+        self.left_pad = left_pad
+
+        assert num_buckets > 0
+        self.buckets = np.unique(
+            np.percentile(
+                sizes,
+                np.linspace(0, 100, num_buckets + 1),
+                interpolation='lower',
+            )[1:]
+        )
+
+        def get_bucketed_sizes(orig_sizes, buckets):
+            sizes = np.copy(orig_sizes)
+            assert np.min(sizes) >= 0
+            start_val = -1
+            for end_val in buckets:
+                mask = (sizes > start_val) & (sizes <= end_val)
+                sizes[mask] = end_val
+                start_val = end_val
+            return sizes
+
+        self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
+
+    def __getitem__(self, index):
+        item = self.dataset[index]
+        bucket_size = self._bucketed_sizes[index]
+        num_pad = bucket_size - item.size(-1)
+        return F.pad(
+            item,
+            (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
+            value=self.pad_idx,
+        )
+
+    @property
+    def sizes(self):
+        return self._bucketed_sizes
+
+    def num_tokens(self, index):
+        return self._bucketed_sizes[index]
+
+    def size(self, index):
+        return self._bucketed_sizes[index]
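To make the bucketing concrete, here is a minimal standalone sketch (the sizes are made up) of what ``BucketPadLengthDataset`` computes: percentile-based bucket edges, the size-to-bucket mapping, and the padding applied in ``__getitem__``:

    import numpy as np
    import torch
    import torch.nn.functional as F

    sizes = np.array([5, 7, 8, 9, 12, 15, 18, 22, 30, 41])

    # same recipe as __init__: percentile edges, drop the 0th, dedupe
    buckets = np.unique(
        np.percentile(sizes, np.linspace(0, 100, 4 + 1), interpolation='lower')[1:]
    )
    print(buckets)  # [ 8 12 18 41]

    # map every size up to the upper edge of its bucket
    bucketed = np.copy(sizes)
    start = -1
    for end in buckets:
        bucketed[(bucketed > start) & (bucketed <= end)] = end
        start = end
    print(bucketed)  # [ 8  8  8 12 12 18 18 41 41 41]

    # pad one right-padded item up to its bucket size, as __getitem__ does
    item = torch.tensor([4, 5, 6, 2])  # length 4, bucket size 8
    padded = F.pad(item, (0, 8 - item.size(-1)), value=1)
    print(padded)  # tensor([4, 5, 6, 2, 1, 1, 1, 1])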
""" try: - from fairseq.data.data_utils_fast import batch_by_size_fast + from fairseq.data.data_utils_fast import ( + batch_by_size_fast, batch_fixed_shapes_fast, + ) except ImportError: raise ImportError( 'Please build Cython components with: `pip install --editable .` ' @@ -228,10 +233,21 @@ def batch_by_size( max_sentences = max_sentences if max_sentences is not None else -1 bsz_mult = required_batch_size_multiple - if isinstance(indices, types.GeneratorType): + if not isinstance(indices, np.ndarray): indices = np.fromiter(indices, dtype=np.int64, count=-1) - return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult) + if fixed_shapes is None: + return batch_by_size_fast( + indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult, + ) + else: + fixed_shapes = np.array(fixed_shapes, dtype=np.int64) + sort_order = np.lexsort([ + fixed_shapes[:, 1].argsort(), # length + fixed_shapes[:, 0].argsort(), # bsz + ]) + fixed_shapes_sorted = fixed_shapes[sort_order] + return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted) def process_bpe_symbol(sentence: str, bpe_symbol: str): diff --git a/fairseq/data/data_utils_fast.pyx b/fairseq/data/data_utils_fast.pyx index 6fa8acc09..c1f97bf5b 100644 --- a/fairseq/data/data_utils_fast.pyx +++ b/fairseq/data/data_utils_fast.pyx @@ -13,10 +13,10 @@ DTYPE = np.int64 ctypedef np.int64_t DTYPE_t -cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences): - if len(batch) == 0: +cdef _is_batch_full(long num_sentences, long num_tokens, long max_tokens, long max_sentences): + if num_sentences == 0: return 0 - if max_sentences > 0 and len(batch) == max_sentences: + if max_sentences > 0 and num_sentences == max_sentences: return 1 if max_tokens > 0 and num_tokens > max_tokens: return 1 @@ -53,7 +53,7 @@ cpdef list batch_by_size_fast( ) num_tokens = (len(batch) + 1) * sample_len - if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): + if _is_batch_full(len(batch), num_tokens, max_tokens, max_sentences): mod_len = max( bsz_mult * (len(batch) // bsz_mult), len(batch) % bsz_mult, @@ -66,3 +66,57 @@ cpdef list batch_by_size_fast( if len(batch) > 0: batches.append(batch) return batches + + +cdef _find_valid_shape( + DTYPE_t[:, :] shapes_view, + long num_sentences, + long num_tokens, +): + """Return index of first valid shape of -1 if none is found.""" + for i in range(shapes_view.shape[0]): + if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]: + return i + return -1 + + +@cython.cdivision(True) +cpdef list batch_fixed_shapes_fast( + np.ndarray[DTYPE_t, ndim=1] indices, + num_tokens_fn, + np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted, +): + cdef long sample_len = 0 + cdef list sample_lens = [] + cdef list batch = [] + cdef list batches = [] + cdef long mod_len + cdef long i + cdef long idx + cdef long num_tokens + cdef DTYPE_t[:] indices_view = indices + cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted + + for i in range(len(indices_view)): + idx = indices_view[i] + num_tokens = num_tokens_fn(idx) + sample_lens.append(num_tokens) + sample_len = max(sample_len, num_tokens) + + shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len) + if shape_idx == -1: + batches.append(batch) + batch = [] + sample_lens = [] + sample_len = 0 + shapes_view = fixed_shapes_sorted + elif shape_idx > 0: + # small optimization for the next call to _find_valid_shape + shapes_view = shapes_view[shape_idx:] + + batch.append(idx) + + if len(batch) > 0: + 
diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py
index fe5681be5..b03c90ed4 100644
--- a/fairseq/data/fairseq_dataset.py
+++ b/fairseq/data/fairseq_dataset.py
@@ -62,6 +62,66 @@ class FairseqDataset(torch.utils.data.Dataset, EpochListening):
         """Prefetch the data required for this epoch."""
         raise NotImplementedError
 
+    def get_batch_shapes(self):
+        """
+        Return a list of valid batch shapes, for example::
+
+            [(8, 512), (16, 256), (32, 128)]
+
+        The first dimension of each tuple is the batch size and can be ``None``
+        to automatically infer the max batch size based on ``--max-tokens``.
+        The second dimension of each tuple is the max supported length as given
+        by :func:`fairseq.data.FairseqDataset.num_tokens`.
+
+        This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
+        to restrict batch shapes. This is useful on TPUs to avoid too many
+        dynamic shapes (and recompilations).
+        """
+        return None
+
+    def batch_by_size(
+        self,
+        indices,
+        max_tokens=None,
+        max_sentences=None,
+        required_batch_size_multiple=1,
+    ):
+        """
+        Given an ordered set of indices, return batches according to
+        *max_tokens*, *max_sentences* and *required_batch_size_multiple*.
+        """
+        from fairseq.data import data_utils
+
+        fixed_shapes = self.get_batch_shapes()
+        if fixed_shapes is not None:
+
+            def adjust_bsz(bsz, num_tokens):
+                if bsz is None:
+                    assert max_tokens is not None, 'Must specify --max-tokens'
+                    bsz = max_tokens // num_tokens
+                if max_sentences is not None:
+                    bsz = min(bsz, max_sentences)
+                elif (
+                    bsz >= required_batch_size_multiple
+                    and bsz % required_batch_size_multiple != 0
+                ):
+                    bsz -= (bsz % required_batch_size_multiple)
+                return bsz
+
+            fixed_shapes = np.array([
+                [adjust_bsz(bsz, num_tokens), num_tokens]
+                for (bsz, num_tokens) in fixed_shapes
+            ])
+
+        return data_utils.batch_by_size(
+            indices,
+            num_tokens_fn=self.num_tokens,
+            max_tokens=max_tokens,
+            max_sentences=max_sentences,
+            required_batch_size_multiple=required_batch_size_multiple,
+            fixed_shapes=fixed_shapes,
+        )
+
 
 class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
     """For datasets that need to be read sequentially, usually because the data
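Note that ``batch_by_size`` above relies on ``np`` (numpy), which this module presumably imports at the top in a hunk not shown here. As a usage sketch, a hypothetical subclass only needs to override ``get_batch_shapes`` to opt in; everything else is inherited:

    import numpy as np

    from fairseq.data import FairseqDataset

    class FixedShapeDataset(FairseqDataset):  # hypothetical example
        def __init__(self, sizes):
            self._sizes = np.array(sizes)

        def __len__(self):
            return len(self._sizes)

        def num_tokens(self, index):
            return self._sizes[index]

        def get_batch_shapes(self):
            # (batch_size, max_length); None means: infer from --max-tokens,
            # e.g. max_tokens=4096 at length 64 gives 4096 // 64 = 64, which
            # adjust_bsz then rounds down to the batch-size multiple.
            return [(8, 512), (16, 256), (None, 64)]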
diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py
index d18a92d78..63c95911f 100644
--- a/fairseq/data/language_pair_dataset.py
+++ b/fairseq/data/language_pair_dataset.py
@@ -8,14 +8,18 @@ import logging
 import numpy as np
 import torch
 
-from . import data_utils, FairseqDataset
+from fairseq.data import data_utils, FairseqDataset
 
 
 logger = logging.getLogger(__name__)
 
 
 def collate(
-    samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
+    samples,
+    pad_idx,
+    eos_idx,
+    left_pad_source=True,
+    left_pad_target=False,
     input_feeding=True,
 ):
     if len(samples) == 0:
@@ -52,7 +56,9 @@
     id = torch.LongTensor([s['id'] for s in samples])
     src_tokens = merge('source', left_pad=left_pad_source)
     # sort by descending source length
-    src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
+    src_lengths = torch.LongTensor([
+        s['source'].ne(pad_idx).long().sum() for s in samples
+    ])
     src_lengths, sort_order = src_lengths.sort(descending=True)
     id = id.index_select(0, sort_order)
     src_tokens = src_tokens.index_select(0, sort_order)
@@ -62,8 +68,10 @@
     if samples[0].get('target', None) is not None:
         target = merge('target', left_pad=left_pad_target)
         target = target.index_select(0, sort_order)
-        tgt_lengths = torch.LongTensor([s['target'].numel() for s in samples]).index_select(0, sort_order)
-        ntokens = sum(len(s['target']) for s in samples)
+        tgt_lengths = torch.LongTensor([
+            s['target'].ne(pad_idx).long().sum() for s in samples
+        ]).index_select(0, sort_order)
+        ntokens = tgt_lengths.sum().item()
 
         if input_feeding:
             # we create a shifted version of targets for feeding the
@@ -75,7 +83,7 @@
             )
             prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
     else:
-        ntokens = sum(len(s['source']) for s in samples)
+        ntokens = src_lengths.sum().item()
 
     batch = {
         'id': id,
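The switch from ``numel()`` to ``ne(pad_idx).sum()`` in the hunks above matters once items arrive pre-padded by ``BucketPadLengthDataset``: the tensor length now includes bucket padding, so the true length must be recovered by counting non-pad symbols. A toy illustration (``pad_idx=1``, fairseq's default dictionary pad index):

    import torch

    source = torch.tensor([4, 5, 6, 2, 1, 1, 1, 1])  # bucketed to length 8

    print(source.numel())             # 8 -- counts the bucket padding
    print(source.ne(1).long().sum())  # tensor(4) -- the real length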
""" def __init__( self, src, src_sizes, src_dict, tgt=None, tgt_sizes=None, tgt_dict=None, left_pad_source=True, left_pad_target=False, - max_source_positions=1024, max_target_positions=1024, shuffle=True, input_feeding=True, remove_eos_from_source=False, append_eos_to_target=False, align_dataset=None, - append_bos=False, eos=None + append_bos=False, eos=None, + num_buckets=0, ): if tgt_dict is not None: assert src_dict.pad() == tgt_dict.pad() @@ -175,8 +181,6 @@ class LanguagePairDataset(FairseqDataset): self.tgt_dict = tgt_dict self.left_pad_source = left_pad_source self.left_pad_target = left_pad_target - self.max_source_positions = max_source_positions - self.max_target_positions = max_target_positions self.shuffle = shuffle self.input_feeding = input_feeding self.remove_eos_from_source = remove_eos_from_source @@ -187,6 +191,42 @@ class LanguagePairDataset(FairseqDataset): self.append_bos = append_bos self.eos = (eos if eos is not None else src_dict.eos()) + if num_buckets > 0: + from fairseq.data import BucketPadLengthDataset + self.src = BucketPadLengthDataset( + self.src, + sizes=self.src_sizes, + num_buckets=num_buckets, + pad_idx=self.src_dict.pad(), + left_pad=self.left_pad_source, + ) + self.src_sizes = self.src.sizes + logger.info('bucketing source lengths: {}'.format(list(self.src.buckets))) + if self.tgt is not None: + self.tgt = BucketPadLengthDataset( + self.tgt, + sizes=self.tgt_sizes, + num_buckets=num_buckets, + pad_idx=self.tgt_dict.pad(), + left_pad=self.left_pad_target, + ) + self.tgt_sizes = self.tgt.sizes + logger.info('bucketing target lengths: {}'.format(list(self.tgt.buckets))) + + # determine bucket sizes using self.num_tokens, which will return + # the padded lengths (thanks to BucketPadLengthDataset) + num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) + self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) + self.buckets = [ + (None, num_tokens) + for num_tokens in np.unique(self.bucketed_num_tokens) + ] + else: + self.buckets = None + + def get_batch_shapes(self): + return self.buckets + def __getitem__(self, index): tgt_item = self.tgt[index] if self.tgt is not None else None src_item = self.src[index] @@ -255,8 +295,11 @@ class LanguagePairDataset(FairseqDataset): on the left if *left_pad_target* is ``True``. 
""" return collate( - samples, pad_idx=self.src_dict.pad(), eos_idx=self.eos, - left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, + samples, + pad_idx=self.src_dict.pad(), + eos_idx=self.eos, + left_pad_source=self.left_pad_source, + left_pad_target=self.left_pad_target, input_feeding=self.input_feeding, ) @@ -277,9 +320,19 @@ class LanguagePairDataset(FairseqDataset): indices = np.random.permutation(len(self)) else: indices = np.arange(len(self)) - if self.tgt_sizes is not None: - indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')] - return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] + if self.buckets is None: + # sort by target length, then source length + if self.tgt_sizes is not None: + indices = indices[ + np.argsort(self.tgt_sizes[indices], kind='mergesort') + ] + return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] + else: + # sort by bucketed_num_tokens, which is: + # max(padded_src_len, padded_tgt_len) + return indices[ + np.argsort(self.bucketed_num_tokens[indices], kind='mergesort') + ] @property def supports_prefetch(self): diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 8fb08b3aa..854e2437c 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -89,30 +89,28 @@ class TransformerEncoderLayer(nn.Module): Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor): binary ByteTensor of shape - `(batch, src_len)` where padding elements are indicated by ``1``. - attn_mask (ByteTensor): binary tensor of shape (T_tgt, T_src), where - T_tgt is the length of query, while T_src is the length of key, - though here both query and key is x here, - attn_mask[t_tgt, t_src] = 1 means when calculating embedding - for t_tgt, t_src is excluded (or masked out), =0 means it is - included in attention + `(batch, seq_len)` where padding elements are indicated by ``1``. + attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`, + where `tgt_len` is the length of output and `src_len` is the + length of input, though here both are equal to `seq_len`. + `attn_mask[tgt_i, src_j] = 1` means that when calculating the + embedding for `tgt_i`, we exclude (mask out) `src_j`. This is + useful for strided self-attention. Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ - residual = x - if self.normalize_before: - x = self.self_attn_layer_norm(x) - if attn_mask is not None: - attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8) # anything in original attn_mask = 1, becomes -1e8 # anything in original attn_mask = 0, becomes 0 # Note that we cannot use -inf here, because at some edge cases, # the attention weight (before softmax) for some padded element in query # will become -inf, which results in NaN in model parameters - # TODO: to formally solve this problem, we need to change fairseq's - # MultiheadAttention. We will do this later on. 
diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py
index 8fb08b3aa..854e2437c 100644
--- a/fairseq/modules/transformer_layer.py
+++ b/fairseq/modules/transformer_layer.py
@@ -89,30 +89,28 @@ class TransformerEncoderLayer(nn.Module):
         Args:
             x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
             encoder_padding_mask (ByteTensor): binary ByteTensor of shape
-                `(batch, src_len)` where padding elements are indicated by ``1``.
-            attn_mask (ByteTensor): binary tensor of shape (T_tgt, T_src), where
-                T_tgt is the length of query, while T_src is the length of key,
-                though here both query and key is x here,
-                attn_mask[t_tgt, t_src] = 1 means when calculating embedding
-                for t_tgt, t_src is excluded (or masked out), =0 means it is
-                included in attention
+                `(batch, seq_len)` where padding elements are indicated by ``1``.
+            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
+                where `tgt_len` is the length of output and `src_len` is the
+                length of input, though here both are equal to `seq_len`.
+                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
+                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
+                useful for strided self-attention.
 
         Returns:
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
 
-        residual = x
-        if self.normalize_before:
-            x = self.self_attn_layer_norm(x)
-        if attn_mask is not None:
-            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
         # anything in original attn_mask = 1, becomes -1e8
         # anything in original attn_mask = 0, becomes 0
         # Note that we cannot use -inf here, because at some edge cases,
         # the attention weight (before softmax) for some padded element in query
         # will become -inf, which results in NaN in model parameters
-        # TODO: to formally solve this problem, we need to change fairseq's
-        # MultiheadAttention. We will do this later on.
+        if attn_mask is not None:
+            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
+        residual = x
+        if self.normalize_before:
+            x = self.self_attn_layer_norm(x)
         x, _ = self.self_attn(
             query=x,
             key=x,
diff --git a/fairseq/options.py b/fairseq/options.py
index 9ccd58941..1972d4b85 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -327,7 +327,8 @@ def add_dataset_args(parser, train=False, gen=False):
     group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                        help='maximum number of sentences in a batch')
     group.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N',
-                       help='batch size will be a multiplier of this value')
+                       help='batch size will either be less than this value, '
+                            'or a multiple of this value')
     parser.add_argument('--dataset-impl', metavar='FORMAT',
                         choices=get_available_dataset_impl(),
                         help='output dataset implementation')
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index bd9d75abd..f58ccea8c 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -173,9 +173,8 @@ class FairseqTask(object):
         )
 
         # create mini-batches with given size constraints
-        batch_sampler = data_utils.batch_by_size(
+        batch_sampler = dataset.batch_by_size(
             indices,
-            dataset.num_tokens,
             max_tokens=max_tokens,
             max_sentences=max_sentences,
             required_batch_size_multiple=required_batch_size_multiple,
diff --git a/fairseq/tasks/semisupervised_translation.py b/fairseq/tasks/semisupervised_translation.py
index bf770bfe1..3f919be6f 100644
--- a/fairseq/tasks/semisupervised_translation.py
+++ b/fairseq/tasks/semisupervised_translation.py
@@ -264,8 +264,6 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
                     tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
                     left_pad_source=self.args.left_pad_source,
                     left_pad_target=self.args.left_pad_target,
-                    max_source_positions=self.args.max_source_positions,
-                    max_target_positions=self.args.max_target_positions,
                 ),
                 self.dicts[src].eos(),
                 src,
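The reordered encoder-layer block keeps the existing workaround of masking with ``-1e8`` rather than ``-inf``: a fully padded query row would otherwise softmax over all ``-inf`` and yield NaNs that poison training. A quick demonstration:

    import torch

    scores = torch.zeros(1, 4)  # logits for one fully-masked query row
    mask = torch.ones(1, 4, dtype=torch.bool)

    print(scores.masked_fill(mask, float('-inf')).softmax(-1))  # all nan
    print(scores.masked_fill(mask, -1e8).softmax(-1))           # uniform 0.25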
diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py
index 6e6ea5596..c3237aa96 100644
--- a/fairseq/tasks/translation.py
+++ b/fairseq/tasks/translation.py
@@ -39,7 +39,8 @@ def load_langpair_dataset(
     combine, dataset_impl, upsample_primary,
     left_pad_source, left_pad_target, max_source_positions,
     max_target_positions, prepend_bos=False, load_alignments=False,
-    truncate_source=False, append_source_id=False
+    truncate_source=False, append_source_id=False,
+    num_buckets=0,
 ):
 
     def split_exists(split, src, tgt, lang, data_path):
@@ -124,9 +125,8 @@ def load_langpair_dataset(
         tgt_dataset, tgt_dataset_sizes, tgt_dict,
         left_pad_source=left_pad_source,
         left_pad_target=left_pad_target,
-        max_source_positions=max_source_positions,
-        max_target_positions=max_target_positions,
-        align_dataset=align_dataset, eos=eos
+        align_dataset=align_dataset, eos=eos,
+        num_buckets=num_buckets,
     )
 
 
@@ -176,6 +176,10 @@ class TranslationTask(FairseqTask):
                             help='amount to upsample primary dataset')
         parser.add_argument('--truncate-source', action='store_true', default=False,
                             help='truncate source to max-source-positions')
+        parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N',
+                            help='if >0, then bucket source and target lengths into N '
+                                 'buckets and pad accordingly; this is useful on TPUs '
+                                 'to minimize the number of compilations')
 
         # options for reporting BLEU during validation
         parser.add_argument('--eval-bleu', action='store_true',
@@ -255,6 +259,7 @@
             max_target_positions=self.args.max_target_positions,
             load_alignments=self.args.load_alignments,
             truncate_source=self.args.truncate_source,
+            num_buckets=self.args.num_batch_buckets,
         )
 
     def build_dataset_for_inference(self, src_tokens, src_lengths):
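End to end, the new option flows from ``--num-batch-buckets`` on the command line into ``LanguagePairDataset``. A usage sketch (the data path and the ``src_dict``/``tgt_dict`` objects are placeholders):

    from fairseq.tasks.translation import load_langpair_dataset

    dataset = load_langpair_dataset(
        'data-bin/iwslt14', 'train', 'de', src_dict, 'en', tgt_dict,
        combine=False, dataset_impl=None, upsample_primary=1,
        left_pad_source=True, left_pad_target=False,
        max_source_positions=1024, max_target_positions=1024,
        num_buckets=8,  # pad every item up to one of (at most) 8 lengths
    )
    print(dataset.get_batch_shapes())  # e.g. [(None, 14), ..., (None, 256)]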