Fix bidirectional lstm

Myle Ott 2018-06-12 14:12:50 -06:00
parent 55dc4842b2
commit bfcc6ec739
2 changed files with 30 additions and 20 deletions

@@ -134,7 +134,7 @@ class LSTMEncoder(FairseqEncoder):
             input_size=embed_dim,
             hidden_size=hidden_size,
             num_layers=num_layers,
-            dropout=self.dropout_out,
+            dropout=self.dropout_out if num_layers > 1 else 0.,
             bidirectional=bidirectional,
         )
         self.left_pad = left_pad
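Context for the dropout change above (my note, not part of the commit): torch.nn.LSTM only applies dropout between stacked layers, so a nonzero dropout with num_layers=1 has no effect and emits a UserWarning; guarding the argument keeps single-layer configurations quiet. A minimal standalone sketch:

import torch.nn as nn

dropout_out = 0.1
num_layers = 1

# Warns: dropout is applied after all but the last layer, so it needs num_layers > 1.
noisy = nn.LSTM(input_size=8, hidden_size=16, num_layers=num_layers, dropout=dropout_out)

# Silent: only pass dropout through when it can actually be applied.
quiet = nn.LSTM(input_size=8, hidden_size=16, num_layers=num_layers,
                dropout=dropout_out if num_layers > 1 else 0.)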
@@ -172,29 +172,23 @@ class LSTMEncoder(FairseqEncoder):
             state_size = self.num_layers, bsz, self.hidden_size
         h0 = x.data.new(*state_size).zero_()
         c0 = x.data.new(*state_size).zero_()
-        packed_outs, (final_hiddens, final_cells) = self.lstm(
-            packed_x,
-            (h0, c0),
-        )
+        packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0))
 
         # unpack outputs and apply dropout
-        x, _ = nn.utils.rnn.pad_packed_sequence(
-            packed_outs, padding_value=self.padding_value)
+        x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_value)
         x = F.dropout(x, p=self.dropout_out, training=self.training)
         assert list(x.size()) == [seqlen, bsz, self.output_units]
 
         if self.bidirectional:
-            bi_final_hiddens, bi_final_cells = [], []
-            for i in range(self.num_layers):
-                bi_final_hiddens.append(
-                    torch.cat(
-                        (final_hiddens[2 * i], final_hiddens[2 * i + 1]),
-                        dim=0).view(bsz, self.output_units))
-                bi_final_cells.append(
-                    torch.cat(
-                        (final_cells[2 * i], final_cells[2 * i + 1]),
-                        dim=0).view(bsz, self.output_units))
-            return x, bi_final_hiddens, bi_final_cells
+
+            def combine_bidir(outs):
+                return torch.cat([
+                    torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(1, bsz, self.output_units)
+                    for i in range(self.num_layers)
+                ], dim=0)
+
+            final_hiddens = combine_bidir(final_hiddens)
+            final_cells = combine_bidir(final_cells)
 
         encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
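A note on the bidirectional fix above (my reading of the diff, not from the commit message): the old branch built Python lists of per-layer tensors and returned early, before the padding mask below was computed; combine_bidir instead folds each layer's two directions into a single tensor of shape (num_layers, bsz, 2 * hidden_size), so the bidirectional path returns the same kinds of values as the unidirectional one. For reference, a standalone sketch of the nn.LSTM state layout that combine_bidir indexes into:

import torch
import torch.nn as nn

num_layers, bsz, hidden_size = 2, 3, 4
lstm = nn.LSTM(input_size=5, hidden_size=hidden_size,
               num_layers=num_layers, bidirectional=True)
x = torch.randn(7, bsz, 5)  # (seq_len, batch, input_size)

_, (h_n, c_n) = lstm(x)
# h_n is (num_layers * 2, bsz, hidden_size); layer l's forward state is at
# index 2 * l and its backward state at 2 * l + 1.
assert h_n.shape == (num_layers * 2, bsz, hidden_size)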
@@ -262,7 +256,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         self.encoder_output_units = encoder_output_units
         assert encoder_output_units == hidden_size, \
-            '{} {}'.format(encoder_output_units, hidden_size)
+            'encoder_output_units ({}) != hidden_size ({})'.format(encoder_output_units, hidden_size)
         # TODO another Linear layer if not equal
         self.layers = nn.ModuleList([
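The TODO above hints at projecting mismatched sizes instead of asserting; a purely hypothetical sketch of that idea (the helper name and shape are mine, not from the commit):

from typing import Optional
import torch.nn as nn

def maybe_encoder_proj(encoder_output_units: int, hidden_size: int) -> Optional[nn.Module]:
    # Only add a projection when the encoder and decoder widths differ.
    if encoder_output_units == hidden_size:
        return None
    return nn.Linear(encoder_output_units, hidden_size)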

@@ -55,7 +55,23 @@ class TestTranslation(unittest.TestCase):
         with tempfile.TemporaryDirectory('test_lstm') as data_dir:
             create_dummy_data(data_dir)
             preprocess_translation_data(data_dir)
-            train_translation_model(data_dir, 'lstm_wiseman_iwslt_de_en')
+            train_translation_model(data_dir, 'lstm_wiseman_iwslt_de_en', [
+                '--encoder-layers', '2',
+                '--decoder-layers', '2',
+            ])
             generate_main(data_dir)
 
+    def test_lstm_bidirectional(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_lstm_bidirectional') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(data_dir, 'lstm', [
+                    '--encoder-layers', '2',
+                    '--encoder-bidirectional',
+                    '--encoder-hidden-size', '256',
+                    '--decoder-layers', '2',
+                ])
+                generate_main(data_dir)
+
     def test_transformer(self):
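A note on the new test_lstm_bidirectional flags above (my inference, the commit does not explain them): with --encoder-bidirectional the encoder's output width doubles to 2 * encoder_hidden_size, and LSTMDecoder asserts that this equals its own hidden size (see the LSTMDecoder hunk above), so --encoder-hidden-size 256 presumably pairs the bidirectional encoder with a 512-unit decoder. A quick check under that assumption:

# Assumes the 'lstm' architecture's default decoder hidden size is 512 (not shown in this diff).
encoder_hidden_size = 256
encoder_output_units = 2 * encoder_hidden_size      # doubled by --encoder-bidirectional
decoder_hidden_size = 512                            # assumed default
assert encoder_output_units == decoder_hidden_size   # satisfies LSTMDecoder's assert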