Fix indexing in TokenBlockDataset

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/719

Differential Revision: D15258483

Pulled By: myleott

fbshipit-source-id: dd00daa6f1c87264c1196a77dfffc8c876ebde7f
This commit is contained in:
Myle Ott 2019-05-08 06:11:05 -07:00 committed by Facebook Github Bot
parent 0cb45bcb12
commit eddcdf08e1

View File

@@ -70,7 +70,7 @@ class TokenBlockDataset(FairseqDataset):
if not torch.is_tensor(sizes):
sizes = torch.tensor(sizes)
cumsum = torch.cumsum(sizes, dim=0)
-self.slice_indices[0, 1] = sizes[0]
+self.slice_indices[0] = [0, sizes[0]]
self.slice_indices[1:] = cumsum.unfold(0, 2, 1)
else:
raise ValueError('Invalid break_mode: ' + break_mode)