Mirror of https://github.com/facebookresearch/fairseq.git (synced 2024-09-22 06:39:29 +03:00)
Refactor PaddingCollater
commit 2ad5888562 (parent 4593ebfaf9)
123 fairseq/data.py
--- a/fairseq/data.py
+++ b/fairseq/data.py
@@ -70,8 +70,8 @@ def load(path, src, dst):
         dataset.splits[prefix] = LanguagePairDataset(
             IndexedInMemoryDataset(src_path),
             IndexedInMemoryDataset(fmt_path('{}.{}.{}', prefix, langcode, dst)),
-            padding_value=dataset.src_dict.pad(),
-            eos=dataset.src_dict.eos(),
+            pad_idx=dataset.src_dict.pad(),
+            eos_idx=dataset.src_dict.eos(),
         )

     return dataset
@@ -85,6 +85,10 @@ class LanguageDatasets(object):
         self.dst_dict = dst_dict
         self.splits = {}

+        assert self.src_dict.pad() == self.dst_dict.pad()
+        assert self.src_dict.eos() == self.dst_dict.eos()
+        assert self.src_dict.unk() == self.dst_dict.unk()
+
     def dataloader(self, split, batch_size=1, num_workers=0,
                    max_tokens=None, seed=None, epoch=1,
                    sample_without_replacement=0, max_positions=1024):
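The new asserts pin down the invariant the refactored collation relies on: both dictionaries must agree on the indices of the special symbols, because LanguagePairDataset.collate (later in this diff) pads source and target batches with a single pad_idx. A minimal sketch of the invariant with a hypothetical stub (not fairseq's actual Dictionary class):

    class StubDictionary:
        """Hypothetical dictionary exposing only the special-symbol accessors."""

        def __init__(self, pad=1, eos=2, unk=3):
            self._pad, self._eos, self._unk = pad, eos, unk

        def pad(self):
            return self._pad

        def eos(self):
            return self._eos

        def unk(self):
            return self._unk

    src_dict, dst_dict = StubDictionary(), StubDictionary()
    # Mirrors the checks added above: a single pad_idx/eos_idx can then
    # safely serve both sides of the language pair.
    assert src_dict.pad() == dst_dict.pad()
    assert src_dict.eos() == dst_dict.eos()
    assert src_dict.unk() == dst_dict.unk()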
@@ -105,8 +109,9 @@ class LanguageDatasets(object):
         return torch.utils.data.DataLoader(
             dataset,
             num_workers=num_workers,
-            collate_fn=PaddingCollater(self.src_dict.pad()),
-            batch_sampler=batch_sampler)
+            collate_fn=dataset.collater,
+            batch_sampler=batch_sampler,
+        )


 def skip_group_enumerator(it, ngpus, offset=0):
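Passing the bound method dataset.collater as collate_fn is what lets each dataset own its batching logic instead of sharing one PaddingCollater. A minimal runnable sketch of that wiring; the toy dataset and its fields are hypothetical stand-ins, not part of this commit:

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyPairs(Dataset):
        """Hypothetical stand-in for LanguagePairDataset."""

        def __init__(self, sizes):
            # one variable-length LongTensor per sample
            self.items = [torch.full((n,), i, dtype=torch.long)
                          for i, n in enumerate(sizes)]

        def __getitem__(self, i):
            return {'id': torch.tensor(i), 'target': self.items[i]}

        def __len__(self):
            return len(self.items)

        def collater(self, samples):
            # dataset-owned batching, analogous to LanguagePairDataset.collate
            return {
                'id': torch.LongTensor([s['id'].item() for s in samples]),
                'ntokens': sum(len(s['target']) for s in samples),
            }

    ds = ToyPairs([3, 5, 2])
    loader = DataLoader(ds, batch_size=2, collate_fn=ds.collater)
    for batch in loader:
        print(batch['id'], batch['ntokens'])

The DataLoader hands each list of samples drawn by the batch sampler straight to the bound method, so no padding index needs to be threaded through the loader itself.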
@@ -124,67 +129,83 @@ def skip_group_enumerator(it, ngpus, offset=0):
         yield (idx, res)


-class PaddingCollater(object):
-    def __init__(self, padding_value=1):
-        self.padding_value = padding_value
-
-    def __call__(self, samples):
-        def merge(key, pad_begin):
-            return self.merge_with_pad([s[key] for s in samples], pad_begin)
-
-        ntokens = sum(len(s['target']) for s in samples)
-
-        return {
-            'id': torch.LongTensor([s['id'].item() for s in samples]),
-            'input_tokens': merge('input_tokens', pad_begin=True),
-            'input_positions': merge('input_positions', pad_begin=True),
-            'target': merge('target', pad_begin=True),
-            'src_tokens': merge('src_tokens', pad_begin=False),
-            'src_positions': merge('src_positions', pad_begin=False),
-            'ntokens': ntokens,
-        }
-
-    def merge_with_pad(self, values, pad_begin):
-        size = max(v.size(0) for v in values)
-        res = values[0].new(len(values), size).fill_(self.padding_value)
-        for i, v in enumerate(values):
-            if pad_begin:
-                res[i][size-len(v):].copy_(v)
-            else:
-                res[i][:len(v)].copy_(v)
-        return res
-
-
 class LanguagePairDataset(object):
-    def __init__(self, src, dst, padding_value=1, eos=2):
+    def __init__(self, src, dst, pad_idx, eos_idx):
         self.src = src
         self.dst = dst
-        self.padding_value = padding_value
-        self.eos = eos
+        self.pad_idx = pad_idx
+        self.eos_idx = eos_idx

     def __getitem__(self, i):
-        src = self.src[i].long() - 1
+        # subtract 1 for 0-based indexing
+        source = self.src[i].long() - 1
         target = self.dst[i].long() - 1
-        input = target.new(target.size())
-        input[0] = self.eos
-        input[1:].copy_(target[:-1])

         return {
             'id': i,
-            'input_tokens': input,
-            'input_positions': self.make_positions(input),
+            'source': source,
             'target': target,
-            'src_tokens': src,
-            'src_positions': self.make_positions(src),
         }

-    def make_positions(self, x):
-        start = self.padding_value + 1
-        return torch.arange(start, start + len(x)).type_as(x)
-
     def __len__(self):
         return len(self.src)

+    def collater(self, samples):
+        return LanguagePairDataset.collate(samples, self.pad_idx, self.eos_idx)
+
+    @staticmethod
+    def collate(samples, pad_idx, eos_idx):
+        def merge(key, left_pad, move_eos_to_beginning=False):
+            return LanguagePairDataset.collate_tokens(
+                [s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning)
+
+        def merge_positions(key, left_pad):
+            return LanguagePairDataset.collate_positions([s[key] for s in samples], pad_idx, left_pad)
+
+        ntokens = sum(len(s['target']) for s in samples)
+        return {
+            'id': torch.LongTensor([s['id'].item() for s in samples]),
+            'input_tokens': merge('target', left_pad=True, move_eos_to_beginning=True),
+            'input_positions': merge_positions('target', left_pad=True),
+            'target': merge('target', left_pad=True),
+            'src_tokens': merge('source', left_pad=False),
+            'src_positions': merge_positions('source', left_pad=False),
+            'ntokens': ntokens,
+        }
+
+    @staticmethod
+    def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning):
+        size = max(v.size(0) for v in values)
+        res = values[0].new(len(values), size).fill_(pad_idx)
+
+        def copy_tensor(src, dst):
+            assert dst.numel() == src.numel()
+            if move_eos_to_beginning:
+                assert src[-1] == eos_idx
+                dst[0] = eos_idx
+                dst[1:] = src[:-1]
+            else:
+                dst.copy_(src)
+
+        for i, v in enumerate(values):
+            if left_pad:
+                copy_tensor(v, res[i][size-len(v):])
+            else:
+                copy_tensor(v, res[i][:len(v)])
+        return res
+
+    @staticmethod
+    def collate_positions(values, pad_idx, left_pad):
+        start = pad_idx + 1
+        size = max(v.size(0) for v in values)
+        res = values[0].new(len(values), size).fill_(pad_idx)
+        for i, v in enumerate(values):
+            if left_pad:
+                torch.arange(start, start + len(v), out=res[i][size-len(v):])
+            else:
+                torch.arange(start, start + len(v), out=res[i][:len(v)])
+        return res


 def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positions=1024):
     """Returns batches of indices sorted by size. Sequences of different lengths
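To see the new semantics end to end, the two collate helpers can be exercised standalone. The code below copies them out of the diff and runs them on two hypothetical target sequences with pad_idx=1 and eos_idx=2 (the defaults the removed PaddingCollater used); note how move_eos_to_beginning turns a padded target row into the shifted decoder input:

    import torch

    def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning):
        # copied from LanguagePairDataset.collate_tokens above
        size = max(v.size(0) for v in values)
        res = values[0].new(len(values), size).fill_(pad_idx)

        def copy_tensor(src, dst):
            assert dst.numel() == src.numel()
            if move_eos_to_beginning:
                assert src[-1] == eos_idx
                dst[0] = eos_idx
                dst[1:] = src[:-1]
            else:
                dst.copy_(src)

        for i, v in enumerate(values):
            if left_pad:
                copy_tensor(v, res[i][size-len(v):])
            else:
                copy_tensor(v, res[i][:len(v)])
        return res

    def collate_positions(values, pad_idx, left_pad):
        # copied from LanguagePairDataset.collate_positions above
        start = pad_idx + 1
        size = max(v.size(0) for v in values)
        res = values[0].new(len(values), size).fill_(pad_idx)
        for i, v in enumerate(values):
            if left_pad:
                torch.arange(start, start + len(v), out=res[i][size-len(v):])
            else:
                torch.arange(start, start + len(v), out=res[i][:len(v)])
        return res

    targets = [torch.LongTensor([4, 5, 6, 2]), torch.LongTensor([7, 2])]

    # 'target': left-padded, eos stays at the end of each sequence.
    print(collate_tokens(targets, 1, 2, left_pad=True, move_eos_to_beginning=False))
    # tensor([[4, 5, 6, 2],
    #         [1, 1, 7, 2]])

    # 'input_tokens': the same rows shifted right with eos moved to the front,
    # i.e. the previous-output tokens fed to the decoder.
    print(collate_tokens(targets, 1, 2, left_pad=True, move_eos_to_beginning=True))
    # tensor([[2, 4, 5, 6],
    #         [1, 1, 2, 7]])

    # 'input_positions': consecutive indices starting at pad_idx + 1, so the
    # padding slots keep the pad index itself.
    print(collate_positions(targets, 1, left_pad=True))
    # tensor([[2, 3, 4, 5],
    #         [1, 1, 2, 3]])

Compared with the old PaddingCollater, the shifted decoder input is now derived inside collation from 'target' alone, so __getitem__ no longer has to precompute 'input_tokens' and 'input_positions' per sample.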