More updates for PyTorch (#114)

This commit is contained in:
Myle Ott 2018-03-01 14:04:08 -05:00 committed by GitHub
parent 3bde773d66
commit 6e4d370af9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 5 additions and 7 deletions

View File

@ -96,9 +96,9 @@ class LSTMEncoder(FairseqEncoder):
def forward(self, src_tokens, src_lengths):
if LanguagePairDataset.LEFT_PAD_SOURCE:
# convert left-padding to right-padding
src_tokens.data = utils.convert_padding_direction(
src_tokens.data,
src_lengths.data,
src_tokens = utils.convert_padding_direction(
src_tokens,
src_lengths,
self.padding_idx,
left_to_right=True,
)

View File

@ -289,8 +289,6 @@ def convert_padding_direction(
right_to_left=False,
left_to_right=False,
):
assert not isinstance(src_tokens, Variable)
assert not isinstance(src_lengths, Variable)
assert right_to_left ^ left_to_right
pad_mask = src_tokens.eq(padding_idx)
if pad_mask.max() == 0:

View File

@ -1,4 +1,4 @@
cffi
numpy
torch>=0.4.0
torch
tqdm

View File

@ -61,7 +61,7 @@ class TestUtils(unittest.TestCase):
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertLess((t1 - t2).abs().max(), 1e-4)
self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4)
if __name__ == '__main__':