Refactor PaddingCollater

Myle Ott 2017-09-27 19:33:32 -07:00
parent 4593ebfaf9
commit 2ad5888562


@@ -70,8 +70,8 @@ def load(path, src, dst):
         dataset.splits[prefix] = LanguagePairDataset(
             IndexedInMemoryDataset(src_path),
             IndexedInMemoryDataset(fmt_path('{}.{}.{}', prefix, langcode, dst)),
-            padding_value=dataset.src_dict.pad(),
-            eos=dataset.src_dict.eos(),
+            pad_idx=dataset.src_dict.pad(),
+            eos_idx=dataset.src_dict.eos(),
         )
     return dataset
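
If any external code constructs LanguagePairDataset directly, its call sites change the same way. A hypothetical migration sketch (src_data, dst_data, and src_dict are placeholders, not names from this diff):

    # before: LanguagePairDataset(src_data, dst_data,
    #                             padding_value=src_dict.pad(), eos=src_dict.eos())
    dataset = LanguagePairDataset(
        src_data, dst_data,
        pad_idx=src_dict.pad(),  # index reserved for padding
        eos_idx=src_dict.eos(),  # index of the end-of-sentence symbol
    )
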
@@ -85,6 +85,10 @@ class LanguageDatasets(object):
         self.dst_dict = dst_dict
         self.splits = {}
 
+        assert self.src_dict.pad() == self.dst_dict.pad()
+        assert self.src_dict.eos() == self.dst_dict.eos()
+        assert self.src_dict.unk() == self.dst_dict.unk()
+
     def dataloader(self, split, batch_size=1, num_workers=0,
                    max_tokens=None, seed=None, epoch=1,
                    sample_without_replacement=0, max_positions=1024):
@@ -105,8 +109,9 @@ class LanguageDatasets(object):
         return torch.utils.data.DataLoader(
             dataset,
             num_workers=num_workers,
-            collate_fn=PaddingCollater(self.src_dict.pad()),
-            batch_sampler=batch_sampler)
+            collate_fn=dataset.collater,
+            batch_sampler=batch_sampler,
+        )
 
 
 def skip_group_enumerator(it, ngpus, offset=0):
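
The point of this hunk: collation is now a bound method on the dataset rather than a separate PaddingCollater object, so the padding index travels with the dataset instead of being passed around. A self-contained sketch of the same pattern (ToyDataset is illustrative, not fairseq code):

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        """Owns its own collater, like LanguagePairDataset after this commit."""
        def __init__(self, seqs, pad_idx):
            self.seqs = [torch.LongTensor(s) for s in seqs]
            self.pad_idx = pad_idx

        def __getitem__(self, i):
            return self.seqs[i]

        def __len__(self):
            return len(self.seqs)

        def collater(self, samples):
            # right-pad a list of 1-D tensors up to the longest one
            size = max(len(s) for s in samples)
            res = samples[0].new(len(samples), size).fill_(self.pad_idx)
            for i, s in enumerate(samples):
                res[i][:len(s)] = s
            return res

    dataset = ToyDataset([[4, 5, 2], [6, 2]], pad_idx=1)
    loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collater)
    print(next(iter(loader)))
    # tensor([[4, 5, 2],
    #         [6, 2, 1]])
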
@@ -124,67 +129,83 @@ def skip_group_enumerator(it, ngpus, offset=0):
         yield (idx, res)
 
 
-class PaddingCollater(object):
-    def __init__(self, padding_value=1):
-        self.padding_value = padding_value
-
-    def __call__(self, samples):
-        def merge(key, pad_begin):
-            return self.merge_with_pad([s[key] for s in samples], pad_begin)
-
-        ntokens = sum(len(s['target']) for s in samples)
-
-        return {
-            'id': torch.LongTensor([s['id'].item() for s in samples]),
-            'input_tokens': merge('input_tokens', pad_begin=True),
-            'input_positions': merge('input_positions', pad_begin=True),
-            'target': merge('target', pad_begin=True),
-            'src_tokens': merge('src_tokens', pad_begin=False),
-            'src_positions': merge('src_positions', pad_begin=False),
-            'ntokens': ntokens,
-        }
-
-    def merge_with_pad(self, values, pad_begin):
-        size = max(v.size(0) for v in values)
-        res = values[0].new(len(values), size).fill_(self.padding_value)
-        for i, v in enumerate(values):
-            if pad_begin:
-                res[i][size-len(v):].copy_(v)
-            else:
-                res[i][:len(v)].copy_(v)
-        return res
-
-
 class LanguagePairDataset(object):
-    def __init__(self, src, dst, padding_value=1, eos=2):
+    def __init__(self, src, dst, pad_idx, eos_idx):
         self.src = src
         self.dst = dst
-        self.padding_value = padding_value
-        self.eos = eos
+        self.pad_idx = pad_idx
+        self.eos_idx = eos_idx
 
     def __getitem__(self, i):
-        src = self.src[i].long() - 1
+        # subtract 1 for 0-based indexing
+        source = self.src[i].long() - 1
         target = self.dst[i].long() - 1
-        input = target.new(target.size())
-        input[0] = self.eos
-        input[1:].copy_(target[:-1])
-
         return {
             'id': i,
-            'input_tokens': input,
-            'input_positions': self.make_positions(input),
+            'source': source,
             'target': target,
-            'src_tokens': src,
-            'src_positions': self.make_positions(src),
         }
 
-    def make_positions(self, x):
-        start = self.padding_value + 1
-        return torch.arange(start, start + len(x)).type_as(x)
-
     def __len__(self):
         return len(self.src)
 
+    def collater(self, samples):
+        return LanguagePairDataset.collate(samples, self.pad_idx, self.eos_idx)
+
+    @staticmethod
+    def collate(samples, pad_idx, eos_idx):
+        def merge(key, left_pad, move_eos_to_beginning=False):
+            return LanguagePairDataset.collate_tokens(
+                [s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning)
+
+        def merge_positions(key, left_pad):
+            return LanguagePairDataset.collate_positions([s[key] for s in samples], pad_idx, left_pad)
+
+        ntokens = sum(len(s['target']) for s in samples)
+
+        return {
+            'id': torch.LongTensor([s['id'].item() for s in samples]),
+            'input_tokens': merge('target', left_pad=True, move_eos_to_beginning=True),
+            'input_positions': merge_positions('target', left_pad=True),
+            'target': merge('target', left_pad=True),
+            'src_tokens': merge('source', left_pad=False),
+            'src_positions': merge_positions('source', left_pad=False),
+            'ntokens': ntokens,
+        }
+
+    @staticmethod
+    def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning):
+        size = max(v.size(0) for v in values)
+        res = values[0].new(len(values), size).fill_(pad_idx)
+
+        def copy_tensor(src, dst):
+            assert dst.numel() == src.numel()
+            if move_eos_to_beginning:
+                assert src[-1] == eos_idx
+                dst[0] = eos_idx
+                dst[1:] = src[:-1]
+            else:
+                dst.copy_(src)
+
+        for i, v in enumerate(values):
+            if left_pad:
+                copy_tensor(v, res[i][size-len(v):])
+            else:
+                copy_tensor(v, res[i][:len(v)])
+        return res
+
+    @staticmethod
+    def collate_positions(values, pad_idx, left_pad):
+        start = pad_idx + 1
+        size = max(v.size(0) for v in values)
+        res = values[0].new(len(values), size).fill_(pad_idx)
+        for i, v in enumerate(values):
+            if left_pad:
+                torch.arange(start, start + len(v), out=res[i][size-len(v):])
+            else:
+                torch.arange(start, start + len(v), out=res[i][:len(v)])
+        return res
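
To make left_pad and move_eos_to_beginning concrete, a small worked example (pad=1 and eos=2 match the old hard-coded defaults; the import assumes this module is reachable as fairseq.data):

    import torch
    from fairseq.data import LanguagePairDataset

    pad, eos = 1, 2
    target = [torch.LongTensor([4, 5, 6, 2]), torch.LongTensor([7, 2])]

    # Targets are left-padded so the final time step lines up across the batch.
    print(LanguagePairDataset.collate_tokens(
        target, pad, eos, left_pad=True, move_eos_to_beginning=False))
    # tensor([[4, 5, 6, 2],
    #         [1, 1, 7, 2]])

    # Decoder inputs are the targets shifted right: eos moves to the front.
    print(LanguagePairDataset.collate_tokens(
        target, pad, eos, left_pad=True, move_eos_to_beginning=True))
    # tensor([[2, 4, 5, 6],
    #         [1, 1, 2, 7]])

    # Positions count up from pad+1; padded slots keep the pad index.
    print(LanguagePairDataset.collate_positions(target, pad, left_pad=True))
    # tensor([[2, 3, 4, 5],
    #         [1, 1, 2, 3]])
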
 
 
 def batches_by_size(src, batch_size=None, max_tokens=None, dst=None, max_positions=1024):
     """Returns batches of indices sorted by size. Sequences of different lengths