TPU support for Translation (#2245)

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2245

Reviewed By: ngoyal2707

Differential Revision: D22070745

Pulled By: myleott

fbshipit-source-id: e43a96a585366b10d997a12522e8cd6496294ad2
Authored by Myle Ott on 2020-06-24 09:54:46 -07:00; committed by Facebook GitHub Bot
parent a12c5c5de8
commit da94e58c70
12 changed files with 314 additions and 53 deletions

View File

@@ -282,8 +282,6 @@ following contents::
tgt_sizes=torch.ones(len(labels)), # targets have length 1
tgt_dict=self.label_vocab,
left_pad_source=False,
max_source_positions=self.args.max_positions,
max_target_positions=1,
# Since our target is a single class label, there's no need for
# teacher forcing. If we set this to ``True`` then our Model's
# ``forward()`` method would receive an additional argument called

View File

@@ -12,6 +12,7 @@ from .base_wrapper_dataset import BaseWrapperDataset
from .append_token_dataset import AppendTokenDataset
from .audio.raw_audio_dataset import FileAudioDataset
from .backtranslation_dataset import BacktranslationDataset
from .bucket_pad_length_dataset import BucketPadLengthDataset
from .colorize_dataset import ColorizeDataset
from .concat_dataset import ConcatDataset
from .concat_sentences_dataset import ConcatSentencesDataset
@@ -57,6 +58,7 @@ __all__ = [
'AppendTokenDataset',
'BacktranslationDataset',
'BaseWrapperDataset',
'BucketPadLengthDataset',
'ColorizeDataset',
'ConcatDataset',
'ConcatSentencesDataset',

View File

@@ -0,0 +1,77 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch.nn.functional as F
from fairseq.data import BaseWrapperDataset
class BucketPadLengthDataset(BaseWrapperDataset):
"""
Bucket and pad item lengths to the nearest bucket size. This can be used to
reduce the number of unique batch shapes, which is important on TPUs since
each new batch shape requires a recompilation.
Args:
dataset (FairseqDataset): dataset to bucket
sizes (List[int]): all item sizes
num_buckets (int): number of buckets to create
pad_idx (int): padding symbol
left_pad (bool): if True, pad on the left; otherwise right pad
"""
def __init__(
self,
dataset,
sizes,
num_buckets,
pad_idx,
left_pad,
):
super().__init__(dataset)
self.pad_idx = pad_idx
self.left_pad = left_pad
assert num_buckets > 0
self.buckets = np.unique(
np.percentile(
sizes,
np.linspace(0, 100, num_buckets + 1),
interpolation='lower',
)[1:]
)
def get_bucketed_sizes(orig_sizes, buckets):
sizes = np.copy(orig_sizes)
assert np.min(sizes) >= 0
start_val = -1
for end_val in buckets:
mask = (sizes > start_val) & (sizes <= end_val)
sizes[mask] = end_val
start_val = end_val
return sizes
self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
def __getitem__(self, index):
item = self.dataset[index]
bucket_size = self._bucketed_sizes[index]
num_pad = bucket_size - item.size(-1)
return F.pad(
item,
(num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
value=self.pad_idx,
)
@property
def sizes(self):
return self._bucketed_sizes
def num_tokens(self, index):
return self._bucketed_sizes[index]
def size(self, index):
return self._bucketed_sizes[index]
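
The bucketing recipe above is self-contained, so it can be illustrated without fairseq. Below is a minimal NumPy-only sketch (the item lengths are made up for illustration) of how the constructor collapses raw lengths into a handful of bucket sizes, which is what keeps the number of distinct batch shapes small on TPUs:

    import numpy as np

    # hypothetical item lengths, e.g. token counts for 10k sentences
    sizes = np.random.randint(1, 200, size=10000)
    num_buckets = 8

    # same recipe as BucketPadLengthDataset.__init__: evenly spaced
    # percentiles of the length distribution, dropping the 0th percentile
    buckets = np.unique(
        np.percentile(
            sizes,
            np.linspace(0, 100, num_buckets + 1),
            interpolation='lower',  # newer NumPy spells this method='lower'
        )[1:]
    )

    # map every length up to the smallest bucket that can hold it
    bucketed = np.copy(sizes)
    start = -1
    for end in buckets:
        bucketed[(bucketed > start) & (bucketed <= end)] = end
        start = end

    print('unique raw lengths:     ', len(np.unique(sizes)))      # ~200
    print('unique bucketed lengths:', len(np.unique(bucketed)))   # <= num_buckets

Each item is then padded up to its bucket size in ``__getitem__``, so any two items from the same bucket produce tensors of identical length.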

View File

@@ -199,7 +199,7 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False):
def batch_by_size(
indices, num_tokens_fn, max_tokens=None, max_sentences=None,
required_batch_size_multiple=1,
required_batch_size_multiple=1, fixed_shapes=None,
):
"""
Yield mini-batches of indices bucketed by size. Batches may contain
@@ -214,10 +214,15 @@ def batch_by_size(
max_sentences (int, optional): max number of sentences in each
batch (default: None).
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N (default: 1).
be less than N or a multiple of N (default: 1).
fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
only be created with the given shapes. *max_sentences* and
*required_batch_size_multiple* will be ignored (default: None).
"""
try:
from fairseq.data.data_utils_fast import batch_by_size_fast
from fairseq.data.data_utils_fast import (
batch_by_size_fast, batch_fixed_shapes_fast,
)
except ImportError:
raise ImportError(
'Please build Cython components with: `pip install --editable .` '
@@ -228,10 +233,21 @@ def batch_by_size(
max_sentences = max_sentences if max_sentences is not None else -1
bsz_mult = required_batch_size_multiple
if isinstance(indices, types.GeneratorType):
if not isinstance(indices, np.ndarray):
indices = np.fromiter(indices, dtype=np.int64, count=-1)
return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult)
if fixed_shapes is None:
return batch_by_size_fast(
indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult,
)
else:
fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
sort_order = np.lexsort([
fixed_shapes[:, 1].argsort(), # length
fixed_shapes[:, 0].argsort(), # bsz
])
fixed_shapes_sorted = fixed_shapes[sort_order]
return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
def process_bpe_symbol(sentence: str, bpe_symbol: str):
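
If fairseq is installed with its Cython extensions built (``pip install --editable .``), the new *fixed_shapes* path can be exercised directly. The toy lengths and shapes below are made up; only ``data_utils.batch_by_size`` and its arguments come from the change above:

    import numpy as np
    from fairseq.data import data_utils

    lengths = np.array([3, 5, 6, 7, 9, 12, 14, 15], dtype=np.int64)  # toy item sizes
    indices = np.arange(len(lengths), dtype=np.int64)                 # already length-sorted

    batches = data_utils.batch_by_size(
        indices,
        num_tokens_fn=lambda i: int(lengths[i]),
        # (batch_size, max_len) pairs; *max_sentences* and
        # *required_batch_size_multiple* are ignored on this path
        fixed_shapes=[(4, 8), (2, 16)],
    )
    print(batches)  # [[0, 1, 2, 3], [4, 5], [6, 7]] -- every batch fits one of the two shapes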

View File

@@ -13,10 +13,10 @@ DTYPE = np.int64
ctypedef np.int64_t DTYPE_t
cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences):
if len(batch) == 0:
cdef _is_batch_full(long num_sentences, long num_tokens, long max_tokens, long max_sentences):
if num_sentences == 0:
return 0
if max_sentences > 0 and len(batch) == max_sentences:
if max_sentences > 0 and num_sentences == max_sentences:
return 1
if max_tokens > 0 and num_tokens > max_tokens:
return 1
@@ -53,7 +53,7 @@ cpdef list batch_by_size_fast(
)
num_tokens = (len(batch) + 1) * sample_len
if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
if _is_batch_full(len(batch), num_tokens, max_tokens, max_sentences):
mod_len = max(
bsz_mult * (len(batch) // bsz_mult),
len(batch) % bsz_mult,
@@ -66,3 +66,57 @@ cpdef list batch_by_size_fast(
if len(batch) > 0:
batches.append(batch)
return batches
cdef _find_valid_shape(
DTYPE_t[:, :] shapes_view,
long num_sentences,
long num_tokens,
):
"""Return index of first valid shape of -1 if none is found."""
for i in range(shapes_view.shape[0]):
if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]:
return i
return -1
@cython.cdivision(True)
cpdef list batch_fixed_shapes_fast(
np.ndarray[DTYPE_t, ndim=1] indices,
num_tokens_fn,
np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
cdef long sample_len = 0
cdef list sample_lens = []
cdef list batch = []
cdef list batches = []
cdef long mod_len
cdef long i
cdef long idx
cdef long num_tokens
cdef DTYPE_t[:] indices_view = indices
cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted
for i in range(len(indices_view)):
idx = indices_view[i]
num_tokens = num_tokens_fn(idx)
sample_lens.append(num_tokens)
sample_len = max(sample_len, num_tokens)
shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
if shape_idx == -1:
batches.append(batch)
batch = []
sample_lens = []
sample_len = 0
shapes_view = fixed_shapes_sorted
elif shape_idx > 0:
# small optimization for the next call to _find_valid_shape
shapes_view = shapes_view[shape_idx:]
batch.append(idx)
if len(batch) > 0:
batches.append(batch)
return batches
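
For readers who would rather not trace the Cython, here is a simplified pure-Python re-implementation of the same greedy idea (not a line-for-line port; it also skips the ``shapes_view`` slicing optimization): grow the current batch while some allowed ``(batch_size, max_len)`` shape still holds it, and close the batch as soon as none does:

    def batch_fixed_shapes(indices, num_tokens_fn, shapes):
        """shapes: list of (max_sentences, max_tokens_per_sentence) pairs."""
        batches, batch, max_len = [], [], 0
        for idx in indices:
            max_len = max(max_len, num_tokens_fn(idx))
            fits = any(
                len(batch) + 1 <= bsz and max_len <= slen
                for bsz, slen in shapes
            )
            if not fits:
                # no allowed shape can hold the grown batch: emit it and start over
                batches.append(batch)
                batch, max_len = [], num_tokens_fn(idx)
            batch.append(idx)
        if batch:
            batches.append(batch)
        return batches

    lengths = [3, 5, 6, 7, 9, 12, 14, 15]
    print(batch_fixed_shapes(range(len(lengths)), lengths.__getitem__, [(4, 8), (2, 16)]))
    # [[0, 1, 2, 3], [4, 5], [6, 7]] -- same grouping as the data_utils example above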

View File

@@ -62,6 +62,66 @@ class FairseqDataset(torch.utils.data.Dataset, EpochListening):
"""Prefetch the data required for this epoch."""
raise NotImplementedError
def get_batch_shapes(self):
"""
Return a list of valid batch shapes, for example::
[(8, 512), (16, 256), (32, 128)]
The first dimension of each tuple is the batch size and can be ``None``
to automatically infer the max batch size based on ``--max-tokens``.
The second dimension of each tuple is the max supported length as given
by :func:`fairseq.data.FairseqDataset.num_tokens`.
This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size`
to restrict batch shapes. This is useful on TPUs to avoid too many
dynamic shapes (and recompilations).
"""
return None
def batch_by_size(
self,
indices,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
):
"""
Given an ordered set of indices, return batches according to
*max_tokens*, *max_sentences* and *required_batch_size_multiple*.
"""
from fairseq.data import data_utils
fixed_shapes = self.get_batch_shapes()
if fixed_shapes is not None:
def adjust_bsz(bsz, num_tokens):
if bsz is None:
assert max_tokens is not None, 'Must specify --max-tokens'
bsz = max_tokens // num_tokens
if max_sentences is not None:
bsz = min(bsz, max_sentences)
elif (
bsz >= required_batch_size_multiple
and bsz % required_batch_size_multiple != 0
):
bsz -= (bsz % required_batch_size_multiple)
return bsz
fixed_shapes = np.array([
[adjust_bsz(bsz, num_tokens), num_tokens]
for (bsz, num_tokens) in fixed_shapes
])
return data_utils.batch_by_size(
indices,
num_tokens_fn=self.num_tokens,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
fixed_shapes=fixed_shapes,
)
class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
"""For datasets that need to be read sequentially, usually because the data

View File

@@ -8,14 +8,18 @@ import logging
import numpy as np
import torch
from . import data_utils, FairseqDataset
from fairseq.data import data_utils, FairseqDataset
logger = logging.getLogger(__name__)
def collate(
samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
samples,
pad_idx,
eos_idx,
left_pad_source=True,
left_pad_target=False,
input_feeding=True,
):
if len(samples) == 0:
@@ -52,7 +56,9 @@ def collate(
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=left_pad_source)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
src_lengths = torch.LongTensor([
s['source'].ne(pad_idx).long().sum() for s in samples
])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
@@ -62,8 +68,10 @@ def collate(
if samples[0].get('target', None) is not None:
target = merge('target', left_pad=left_pad_target)
target = target.index_select(0, sort_order)
tgt_lengths = torch.LongTensor([s['target'].numel() for s in samples]).index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
tgt_lengths = torch.LongTensor([
s['target'].ne(pad_idx).long().sum() for s in samples
]).index_select(0, sort_order)
ntokens = tgt_lengths.sum().item()
if input_feeding:
# we create a shifted version of targets for feeding the
@@ -75,7 +83,7 @@ def collate(
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
else:
ntokens = sum(len(s['source']) for s in samples)
ntokens = src_lengths.sum().item()
batch = {
'id': id,
@@ -133,10 +141,6 @@ class LanguagePairDataset(FairseqDataset):
(default: True).
left_pad_target (bool, optional): pad target tensors on the left side
(default: False).
max_source_positions (int, optional): max number of tokens in the
source sentence (default: 1024).
max_target_positions (int, optional): max number of tokens in the
target sentence (default: 1024).
shuffle (bool, optional): shuffle dataset elements before batching
(default: True).
input_feeding (bool, optional): create a shifted version of the targets
@@ -149,17 +153,19 @@ class LanguagePairDataset(FairseqDataset):
containing alignments.
append_bos (bool, optional): if set, appends bos to the beginning of
source/target sentence.
num_buckets (int, optional): if set to a value greater than 0, then
batches will be bucketed into the given number of batch shapes.
"""
def __init__(
self, src, src_sizes, src_dict,
tgt=None, tgt_sizes=None, tgt_dict=None,
left_pad_source=True, left_pad_target=False,
max_source_positions=1024, max_target_positions=1024,
shuffle=True, input_feeding=True,
remove_eos_from_source=False, append_eos_to_target=False,
align_dataset=None,
append_bos=False, eos=None
append_bos=False, eos=None,
num_buckets=0,
):
if tgt_dict is not None:
assert src_dict.pad() == tgt_dict.pad()
@@ -175,8 +181,6 @@ class LanguagePairDataset(FairseqDataset):
self.tgt_dict = tgt_dict
self.left_pad_source = left_pad_source
self.left_pad_target = left_pad_target
self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions
self.shuffle = shuffle
self.input_feeding = input_feeding
self.remove_eos_from_source = remove_eos_from_source
@@ -187,6 +191,42 @@ class LanguagePairDataset(FairseqDataset):
self.append_bos = append_bos
self.eos = (eos if eos is not None else src_dict.eos())
if num_buckets > 0:
from fairseq.data import BucketPadLengthDataset
self.src = BucketPadLengthDataset(
self.src,
sizes=self.src_sizes,
num_buckets=num_buckets,
pad_idx=self.src_dict.pad(),
left_pad=self.left_pad_source,
)
self.src_sizes = self.src.sizes
logger.info('bucketing source lengths: {}'.format(list(self.src.buckets)))
if self.tgt is not None:
self.tgt = BucketPadLengthDataset(
self.tgt,
sizes=self.tgt_sizes,
num_buckets=num_buckets,
pad_idx=self.tgt_dict.pad(),
left_pad=self.left_pad_target,
)
self.tgt_sizes = self.tgt.sizes
logger.info('bucketing target lengths: {}'.format(list(self.tgt.buckets)))
# determine bucket sizes using self.num_tokens, which will return
# the padded lengths (thanks to BucketPadLengthDataset)
num_tokens = np.vectorize(self.num_tokens, otypes=[np.long])
self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
self.buckets = [
(None, num_tokens)
for num_tokens in np.unique(self.bucketed_num_tokens)
]
else:
self.buckets = None
def get_batch_shapes(self):
return self.buckets
def __getitem__(self, index):
tgt_item = self.tgt[index] if self.tgt is not None else None
src_item = self.src[index]
@@ -255,8 +295,11 @@ class LanguagePairDataset(FairseqDataset):
on the left if *left_pad_target* is ``True``.
"""
return collate(
samples, pad_idx=self.src_dict.pad(), eos_idx=self.eos,
left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
samples,
pad_idx=self.src_dict.pad(),
eos_idx=self.eos,
left_pad_source=self.left_pad_source,
left_pad_target=self.left_pad_target,
input_feeding=self.input_feeding,
)
@@ -277,9 +320,19 @@ class LanguagePairDataset(FairseqDataset):
indices = np.random.permutation(len(self))
else:
indices = np.arange(len(self))
if self.tgt_sizes is not None:
indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
if self.buckets is None:
# sort by target length, then source length
if self.tgt_sizes is not None:
indices = indices[
np.argsort(self.tgt_sizes[indices], kind='mergesort')
]
return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
else:
# sort by bucketed_num_tokens, which is:
# max(padded_src_len, padded_tgt_len)
return indices[
np.argsort(self.bucketed_num_tokens[indices], kind='mergesort')
]
@property
def supports_prefetch(self):
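
The switch from ``numel()`` to ``ne(pad_idx).sum()`` in ``collate()`` matters precisely because of the bucket padding above: a bucketed item already carries padding, so its tensor length overstates the number of real tokens. A small torch sketch (the pad index and token ids are made up):

    import torch

    pad_idx = 1
    # a 5-token sentence that BucketPadLengthDataset padded out to a bucket size of 8
    item = torch.tensor([7, 23, 5, 9, 2, pad_idx, pad_idx, pad_idx])

    print(item.numel())                   # 8 -> would overcount tokens
    print(item.ne(pad_idx).long().sum())  # 5 -> true (unpadded) length

Counting non-pad symbols keeps ``src_lengths`` and ``ntokens`` (and therefore the per-token loss normalization) identical whether or not ``--num-batch-buckets`` is enabled.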

View File

@@ -89,30 +89,28 @@ class TransformerEncoderLayer(nn.Module):
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
attn_mask (ByteTensor): binary tensor of shape (T_tgt, T_src), where
T_tgt is the length of query, while T_src is the length of key,
though here both query and key is x here,
attn_mask[t_tgt, t_src] = 1 means when calculating embedding
for t_tgt, t_src is excluded (or masked out), =0 means it is
included in attention
`(batch, seq_len)` where padding elements are indicated by ``1``.
attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
where `tgt_len` is the length of output and `src_len` is the
length of input, though here both are equal to `seq_len`.
`attn_mask[tgt_i, src_j] = 1` means that when calculating the
embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
useful for strided self-attention.
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
# anything in original attn_mask = 1, becomes -1e8
# anything in original attn_mask = 0, becomes 0
# Note that we cannot use -inf here, because at some edge cases,
# the attention weight (before softmax) for some padded element in query
# will become -inf, which results in NaN in model parameters
# TODO: to formally solve this problem, we need to change fairseq's
# MultiheadAttention. We will do this later on.
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
x, _ = self.self_attn(
query=x,
key=x,
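
The comment about avoiding ``-inf`` can be reproduced in a few lines. The sketch below uses a toy 3x3 mask and a plain softmax rather than fairseq's MultiheadAttention; it shows why a fully masked query row turns into NaN with ``-inf`` but stays finite with ``-1e8``:

    import torch
    import torch.nn.functional as F

    attn_mask = torch.tensor([
        [0, 1, 1],
        [0, 0, 1],
        [1, 1, 1],   # every key masked out for this query position
    ], dtype=torch.uint8)

    scores = torch.zeros(3, 3)  # stand-in for the raw attention scores (q @ k^T)

    bad = scores.masked_fill(attn_mask.to(torch.bool), float('-inf'))
    ok = scores.masked_fill(attn_mask.to(torch.bool), -1e8)

    print(F.softmax(bad, dim=-1))  # last row is all NaN (softmax over only -inf)
    print(F.softmax(ok, dim=-1))   # last row stays finite, so gradients stay usable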

View File

@@ -327,7 +327,8 @@ def add_dataset_args(parser, train=False, gen=False):
group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
help='maximum number of sentences in a batch')
group.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N',
help='batch size will be a multiplier of this value')
help='batch size will either be less than this value, '
'or a multiple of this value')
parser.add_argument('--dataset-impl', metavar='FORMAT',
choices=get_available_dataset_impl(),
help='output dataset implementation')
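
The reworded help string matches the rounding done in the Cython batcher (the ``mod_len`` computation shown earlier): a batch is shrunk to the largest multiple of N that fits, unless it is already smaller than N. A tiny standalone illustration with hypothetical batch sizes:

    def round_batch_size(n_items, multiple):
        # largest multiple of `multiple` not exceeding n_items,
        # or the leftover itself when n_items < multiple
        return max(multiple * (n_items // multiple), n_items % multiple)

    for n in (3, 8, 11, 17, 24):
        print(n, '->', round_batch_size(n, 8))
    # 3 -> 3, 8 -> 8, 11 -> 8, 17 -> 16, 24 -> 24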

View File

@@ -173,9 +173,8 @@ class FairseqTask(object):
)
# create mini-batches with given size constraints
batch_sampler = data_utils.batch_by_size(
batch_sampler = dataset.batch_by_size(
indices,
dataset.num_tokens,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,

View File

@@ -264,8 +264,6 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
),
self.dicts[src].eos(),
src,

View File

@@ -39,7 +39,8 @@ def load_langpair_dataset(
combine, dataset_impl, upsample_primary,
left_pad_source, left_pad_target, max_source_positions,
max_target_positions, prepend_bos=False, load_alignments=False,
truncate_source=False, append_source_id=False
truncate_source=False, append_source_id=False,
num_buckets=0,
):
def split_exists(split, src, tgt, lang, data_path):
@@ -124,9 +125,8 @@ def load_langpair_dataset(
tgt_dataset, tgt_dataset_sizes, tgt_dict,
left_pad_source=left_pad_source,
left_pad_target=left_pad_target,
max_source_positions=max_source_positions,
max_target_positions=max_target_positions,
align_dataset=align_dataset, eos=eos
align_dataset=align_dataset, eos=eos,
num_buckets=num_buckets,
)
@@ -176,6 +176,10 @@ class TranslationTask(FairseqTask):
help='amount to upsample primary dataset')
parser.add_argument('--truncate-source', action='store_true', default=False,
help='truncate source to max-source-positions')
parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N',
help='if >0, then bucket source and target lengths into N '
'buckets and pad accordingly; this is useful on TPUs '
'to minimize the number of compilations')
# options for reporting BLEU during validation
parser.add_argument('--eval-bleu', action='store_true',
@@ -255,6 +259,7 @@ class TranslationTask(FairseqTask):
max_target_positions=self.args.max_target_positions,
load_alignments=self.args.load_alignments,
truncate_source=self.args.truncate_source,
num_buckets=self.args.num_batch_buckets,
)
def build_dataset_for_inference(self, src_tokens, src_lengths):
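
A quick way to sanity-check the new flag is to parse a training command with fairseq's option parser; with ``--num-batch-buckets`` set, ``TranslationTask.load_dataset`` passes it through as ``num_buckets``, as shown above. This is only a sketch: the data directory is hypothetical, and actually loading the dataset would require a real ``fairseq-preprocess`` output.

    from fairseq import options

    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser, input_args=[
        'data-bin/wmt16_en_de',               # hypothetical preprocessed data dir
        '--task', 'translation',
        '--arch', 'transformer',
        '--max-tokens', '4096',
        '--required-batch-size-multiple', '8',
        '--num-batch-buckets', '8',           # bucket src/tgt lengths into 8 padded shapes
    ])
    print(args.num_batch_buckets, args.required_batch_size_multiple)  # 8 8
    # fairseq.tasks.setup_task(args).load_dataset('train') would then build a
    # LanguagePairDataset with num_buckets=8 (requires the data dir to exist)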