Refactor fairseq/test_noising with a word shuffle helper function (#340)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/340

This allows us to do a lot less copy paste when adding new word shuffle function tests

Reviewed By: xianxl

Differential Revision: D12810304

fbshipit-source-id: a56b5df093d17be2b73837897c526978cab92b70
This commit is contained in:
Liezl Puzon 2018-11-01 17:11:02 -07:00 committed by Facebook Github Bot
parent 0b05467dd8
commit b1521f962e

View File

@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.
import unittest
from typing import Dict, List
import tests.utils as test_utils
import torch
@ -116,82 +117,145 @@ class TestDataNoising(unittest.TestCase):
)
self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
def assert_no_shuffle_with_0_distance(self, x, x_noised, x_len, l_noised):
"""
Applies word shuffle with 0 max_shuffle_distance and asserts that no
shuffling happened
"""
for i in range(len(x_len)):
for j in range(x_len[i]):
self.assertEqual(x[j][i], x_noised[j][i])
self.assertEqual(x_len[0], l_noised[0])
def generate_unchanged_shuffle_map(self, length):
return {i: i for i in range(length)}
def assert_word_shuffle_with_distance_3(self, x, x_noised, x_len, l_noised):
def assert_word_shuffle_matches_expected(
self,
x,
x_len,
max_shuffle_distance: int,
vocab: Dictionary,
expected_shufle_maps: List[Dict[int, int]],
expect_eos_at_end: bool,
):
"""
Applies word shuffle with max_shuffle_distance = 3 and asserts that the
shuffling result is as expected. If test data changes, update this func
"""
# Expect the second example has the last three tokens shuffled
# 6, 7, 8, 9 => 6, 8, 9, 7, where (8, 9) is a word
for i in range(x_len[0]):
self.assertEqual(x[i][0], x_noised[i][0])
shuffle_map = {0: 0, 1: 3, 2: 1, 3: 2}
for k, v in shuffle_map.items():
self.assertEqual(x[k][1], x_noised[v][1])
self.assertEqual(x_len[0], l_noised[0])
self.assertEqual(x_len[1], l_noised[1])
This verifies that with a given x, x_len, max_shuffle_distance, and
vocab, we get the expected shuffle result.
def assert_nonbpe_shuffle_with_distance_3(self, x, x_noised, x_len, l_noised):
Args:
x: Tensor of shape (T x B) = (sequence_length, batch_size)
x_len: Tensor of length B = batch_size
max_shuffle_distance: arg to pass to noising
expected_shuffle_maps: List[mapping] where mapping is a
Dict[old_index, new_index], mapping x's elements from their
old positions in x to their new positions in x.
expect_eos_at_end: if True, check the output to make sure there is
an EOS at the end.
"""
Applies word shuffle with max_shuffle_distance = 3 and asserts that the
shuffling result is as expected. If test data changes, update this func
"""
# Expect the first example has the last two tokens shuffled
# Expect the secon example has the second and third tokens shuffled
shuffle_map = {0: 0, 1: 1, 2: 3, 3: 2}
for k, v in shuffle_map.items():
self.assertEqual(x[k][0], x_noised[v][0])
shuffle_map = {0: 0, 1: 2, 2: 1, 3: 3, 4: 4}
for k, v in shuffle_map.items():
self.assertEqual(x[k][1], x_noised[v][1])
self.assertEqual(x_len[0], l_noised[0])
self.assertEqual(x_len[1], l_noised[1])
with data_utils.numpy_seed(1234):
word_shuffle = noising.WordShuffle(vocab)
x_noised, l_noised = word_shuffle.noising(
x, x_len, max_shuffle_distance=max_shuffle_distance
)
# For every example, we have a different expected shuffle map. We check
# that each example is shuffled as expected according to each
# corresponding shuffle map.
for i in range(len(expected_shufle_maps)):
shuffle_map = expected_shufle_maps[i]
for k, v in shuffle_map.items():
self.assertEqual(x[k][i], x_noised[v][i])
# Shuffling should not affect the length of each example
for pre_shuffle_length, post_shuffle_length in zip(x_len, l_noised):
self.assertEqual(pre_shuffle_length, post_shuffle_length)
if expect_eos_at_end:
self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
def test_word_shuffle_with_eos(self):
vocab, x, x_len = self._get_test_data(append_eos=True)
with data_utils.numpy_seed(1234):
word_shuffle = noising.WordShuffle(vocab)
# Assert word shuffle with max shuffle distance 0 causes input to be
# unchanged
self.assert_word_shuffle_matches_expected(
x=x,
x_len=x_len,
max_shuffle_distance=0,
vocab=vocab,
expected_shufle_maps=[
self.generate_unchanged_shuffle_map(example_len)
for example_len in x_len
],
expect_eos_at_end=True,
)
x_noised, l_noised = word_shuffle.noising(x, x_len, 0)
self.assert_no_shuffle_with_0_distance(
x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
)
self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
x_noised, l_noised = word_shuffle.noising(x, x_len, 3)
self.assert_word_shuffle_with_distance_3(
x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
)
self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
# Assert word shuffle with max shuffle distance 3 matches our expected
# shuffle order
self.assert_word_shuffle_matches_expected(
x=x,
x_len=x_len,
vocab=vocab,
max_shuffle_distance=3,
expected_shufle_maps=[
self.generate_unchanged_shuffle_map(x_len[0]),
{0: 0, 1: 3, 2: 1, 3: 2},
],
expect_eos_at_end=True,
)
def test_word_shuffle_with_eos_nonbpe(self):
vocab, x, x_len = self._get_test_data(append_eos=True, bpe=False)
with data_utils.numpy_seed(1234):
word_shuffle = noising.WordShuffle(vocab, bpe_cont_marker=None)
# Assert word shuffle with max shuffle distance 0 causes input to be
# unchanged
self.assert_word_shuffle_matches_expected(
x=x,
x_len=x_len,
max_shuffle_distance=0,
vocab=vocab,
expected_shufle_maps=[
self.generate_unchanged_shuffle_map(example_len)
for example_len in x_len
],
expect_eos_at_end=True,
)
x_noised, l_noised = word_shuffle.noising(x, x_len, 0)
self.assert_no_shuffle_with_0_distance(
x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
)
self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
# Assert word shuffle with max shuffle distance 3 matches our expected
# shuffle order
self.assert_word_shuffle_matches_expected(
x=x,
x_len=x_len,
vocab=vocab,
max_shuffle_distance=3,
expected_shufle_maps=[
{0: 0, 1: 1, 2: 3, 3: 2},
{0: 0, 1: 2, 2: 1, 3: 3, 4: 4},
],
expect_eos_at_end=True,
)
x_noised, l_noised = word_shuffle.noising(x, x_len, 3)
self.assert_nonbpe_shuffle_with_distance_3(
x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
)
self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
def test_word_shuffle_without_eos(self):
"""Same result as word shuffle with eos except no EOS at end"""
vocab, x, x_len = self._get_test_data(append_eos=False)
# Assert word shuffle with max shuffle distance 0 causes input to be
# unchanged
self.assert_word_shuffle_matches_expected(
x=x,
x_len=x_len,
max_shuffle_distance=0,
vocab=vocab,
expected_shufle_maps=[
self.generate_unchanged_shuffle_map(example_len)
for example_len in x_len
],
expect_eos_at_end=False,
)
# Assert word shuffle with max shuffle distance 3 matches our expected
# shuffle order
self.assert_word_shuffle_matches_expected(
x=x,
x_len=x_len,
vocab=vocab,
max_shuffle_distance=3,
expected_shufle_maps=[
self.generate_unchanged_shuffle_map(x_len[0]),
{0: 0, 1: 3, 2: 1, 3: 2},
],
expect_eos_at_end=False,
)
def assert_no_eos_at_end(self, x, x_len, eos):
"""Asserts that the last token of each sentence in x is not EOS """
@ -228,25 +292,6 @@ class TestDataNoising(unittest.TestCase):
)
self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
def test_word_shuffle_without_eos(self):
"""Same result as word shuffle with eos except no EOS at end"""
vocab, x, x_len = self._get_test_data(append_eos=False)
with data_utils.numpy_seed(1234):
word_shuffle = noising.WordShuffle(vocab)
x_noised, l_noised = word_shuffle.noising(x, x_len, 0)
self.assert_no_shuffle_with_0_distance(
x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
)
self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
x_noised, l_noised = word_shuffle.noising(x, x_len, 3)
self.assert_word_shuffle_with_distance_3(
x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
)
self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
def _get_noising_dataset_batch(
self, src_tokens_no_pad, src_dict, use_append_eos_dataset=False
):