Move reorder_encoder_out to FairseqEncoder and fix non-incremental decoding

Myle Ott 2018-06-21 14:34:29 -04:00
parent e9967cd334
commit 6ec5022e57
9 changed files with 58 additions and 45 deletions
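
The change is easiest to read as an interface move: reordering cached encoder output when beam search re-ranks hypotheses is now the encoder's responsibility, declared on FairseqEncoder and implemented by every concrete encoder below, rather than a method on FairseqIncrementalDecoder. The sketch below is a minimal, self-contained illustration of that contract in plain PyTorch; ToyEncoder, its tensor shapes, and the example permutation are assumptions made for illustration, not code from this commit.

import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Illustrative encoder that follows the reorder_encoder_out contract."""

    def __init__(self, vocab_size=100, dim=16, padding_idx=0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim, padding_idx=padding_idx)
        self.padding_idx = padding_idx

    def forward(self, src_tokens, src_lengths=None):
        return {
            'encoder_out': self.embed(src_tokens),                     # B x T x C
            'encoder_padding_mask': src_tokens.eq(self.padding_idx),   # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        # Re-align every batch-indexed tensor with the new beam order.
        encoder_out['encoder_out'] = encoder_out['encoder_out'].index_select(0, new_order)
        encoder_out['encoder_padding_mask'] = \
            encoder_out['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out


if __name__ == '__main__':
    enc = ToyEncoder()
    src = torch.randint(1, 100, (2, 5))   # two source sentences, length 5
    out = enc(src)
    # During beam search the generator periodically re-ranks hypotheses and
    # produces an index over the (batch * beam) dimension; here we simply swap
    # the two sentences to show the mechanics.
    out = enc.reorder_encoder_out(out, torch.tensor([1, 0]))
    print(out['encoder_out'].shape)       # torch.Size([2, 5, 16])
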

View File

@@ -26,6 +26,12 @@ class CompositeEncoder(FairseqEncoder):
            encoder_out[key] = self.encoders[key](src_tokens, src_lengths)
        return encoder_out

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder encoder output according to new_order."""
        for key in self.encoders:
            encoder_out[key] = self.encoders[key].reorder_encoder_out(encoder_out[key], new_order)
        return encoder_out

    def max_positions(self):
        return min([self.encoders[key].max_positions() for key in self.encoders])

View File

@@ -18,6 +18,10 @@ class FairseqEncoder(nn.Module):
    def forward(self, src_tokens, src_lengths):
        raise NotImplementedError

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder encoder output according to new_order."""
        raise NotImplementedError

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        raise NotImplementedError

View File

@@ -32,9 +32,6 @@ class FairseqIncrementalDecoder(FairseqDecoder):
                )
        self.apply(apply_reorder_incremental_state)

    def reorder_encoder_out(self, encoder_out, new_order):
        return encoder_out

    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""
        if getattr(self, '_beam_size', -1) != beam_size:

View File

@@ -268,6 +268,17 @@ class FConvEncoder(FairseqEncoder):
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        if encoder_out_dict['encoder_out'] is not None:
            encoder_out_dict['encoder_out'] = (
                encoder_out_dict['encoder_out'][0].index_select(0, new_order),
                encoder_out_dict['encoder_out'][1].index_select(0, new_order),
            )
        if encoder_out_dict['encoder_padding_mask'] is not None:
            encoder_out_dict['encoder_padding_mask'] = \
                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out_dict

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()
@@ -496,12 +507,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
            encoder_out = tuple(eo.index_select(0, new_order) for eo in encoder_out)
            utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out)

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        if encoder_out_dict['encoder_padding_mask'] is not None:
            encoder_out_dict['encoder_padding_mask'] = \
                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out_dict

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions() if self.embed_positions is not None else float('inf')

View File

@@ -226,6 +226,19 @@ class FConvEncoder(FairseqEncoder):
            'encoder_out': (x, y),
        }

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        encoder_out_dict['encoder_out'] = tuple(
            eo.index_select(0, new_order) for eo in encoder_out_dict['encoder_out']
        )
        if 'pretrained' in encoder_out_dict:
            encoder_out_dict['pretrained']['encoder_out'] = tuple(
                eo.index_select(0, new_order)
                for eo in encoder_out_dict['pretrained']['encoder_out']
            )
        return encoder_out_dict

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()
@@ -409,30 +422,12 @@ class FConvDecoder(FairseqDecoder):
        else:
            return x, avg_attn_scores

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        super().reorder_incremental_state(incremental_state, new_order)

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        encoder_out_dict['encoder']['encoder_out'] = tuple(
            eo.index_select(0, new_order) for eo in encoder_out_dict['encoder']['encoder_out']
        )
        if 'pretrained' in encoder_out_dict:
            encoder_out_dict['pretrained']['encoder']['encoder_out'] = tuple(
                eo.index_select(0, new_order)
                for eo in encoder_out_dict['pretrained']['encoder']['encoder_out']
            )
        return encoder_out_dict

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()

    def _split_encoder_out(self, encoder_out):
        """Split and transpose encoder outputs.
        """
        """Split and transpose encoder outputs."""
        # transpose only once to speed up attention layers
        encoder_a, encoder_b = encoder_out
        encoder_a = encoder_a.transpose(0, 1).contiguous()

View File

@@ -197,6 +197,16 @@ class LSTMEncoder(FairseqEncoder):
            'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
        }

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        encoder_out_dict['encoder_out'] = tuple(
            eo.index_select(1, new_order)
            for eo in encoder_out_dict['encoder_out']
        )
        if encoder_out_dict['encoder_padding_mask'] is not None:
            encoder_out_dict['encoder_padding_mask'] = \
                encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
        return encoder_out_dict

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return int(1e5)  # an arbitrary large number
@@ -366,16 +376,6 @@ class LSTMDecoder(FairseqIncrementalDecoder):
        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        encoder_out_dict['encoder_out'] = tuple(
            eo.index_select(1, new_order)
            for eo in encoder_out_dict['encoder_out']
        )
        if encoder_out_dict['encoder_padding_mask'] is not None:
            encoder_out_dict['encoder_padding_mask'] = \
                encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
        return encoder_out_dict

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return int(1e5)  # an arbitrary large number
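
One detail worth noting across these implementations: the axis handed to index_select depends on each model's tensor layout. The fconv encoders keep batch-major tensors and reorder along dim 0, the LSTM encoder's outputs and padding mask are time-major and reorder along dim 1, and the Transformer encoder below mixes the two (a T x B x C output with a B x T mask). A tiny illustration, with shapes chosen only for this example:

import torch

new_order = torch.tensor([2, 0, 1])    # permutation of the batch/beam dimension
batch_major = torch.randn(3, 7, 16)    # B x T x C, fconv-style layout
time_major = torch.randn(7, 3, 16)     # T x B x C, lstm-style layout

# The same logical reorder, applied along whichever axis holds the batch:
batch_major = batch_major.index_select(0, new_order)
time_major = time_major.index_select(1, new_order)
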

View File

@@ -164,6 +164,15 @@ class TransformerEncoder(FairseqEncoder):
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        if encoder_out_dict['encoder_out'] is not None:
            encoder_out_dict['encoder_out'] = \
                encoder_out_dict['encoder_out'].index_select(1, new_order)
        if encoder_out_dict['encoder_padding_mask'] is not None:
            encoder_out_dict['encoder_padding_mask'] = \
                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out_dict

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()
@@ -245,12 +254,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
        return x, attn

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        if encoder_out_dict['encoder_padding_mask'] is not None:
            encoder_out_dict['encoder_padding_mask'] = \
                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out_dict

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()

View File

@@ -268,7 +268,7 @@ class SequenceGenerator(object):
                for i, model in enumerate(self.models):
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
                        encoder_outs[i] = model.decoder.reorder_encoder_out(encoder_outs[i], reorder_state)
                    encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)

            probs, avg_attn_scores = self._decode(
                tokens[:, :step + 1], encoder_outs, incremental_states)
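
The one-line change above is what fixes non-incremental decoding: reorder_encoder_out used to be a decoder method (its default is removed from FairseqIncrementalDecoder earlier in this commit) and was only reached for incremental decoders, so models with plain decoders never had their cached encoder outputs re-aligned after beam re-ranking. The following is a hedged, self-contained sketch of the corrected control flow; the stub classes and tensors are illustrative stand-ins, not fairseq code.

import torch


class Encoder:
    """Minimal stand-in exposing the new encoder-side reorder hook."""

    def reorder_encoder_out(self, encoder_out, new_order):
        return encoder_out.index_select(0, new_order)


class IncrementalDecoder:
    """Stand-in for a decoder with per-step caches to permute."""

    def reorder_incremental_state(self, incremental_state, new_order):
        incremental_state['cache'] = incremental_state['cache'].index_select(0, new_order)


class PlainDecoder:
    """Stand-in for a non-incremental decoder: no caches, no reorder hooks."""


class Model:
    def __init__(self, decoder):
        self.encoder = Encoder()
        self.decoder = decoder


models = [Model(IncrementalDecoder()), Model(PlainDecoder())]
encoder_outs = [torch.arange(3.).unsqueeze(1) for _ in models]     # one B x 1 tensor per model
incremental_states = {models[0]: {'cache': torch.zeros(3, 2)}, models[1]: None}
reorder_state = torch.tensor([2, 0, 1])                            # new beam order

for i, model in enumerate(models):
    if isinstance(model.decoder, IncrementalDecoder):
        model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
    # The fix: this call is no longer guarded by the isinstance check and is
    # dispatched to the encoder, so the second (non-incremental) model's
    # encoder output is reordered as well.
    encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)

print([eo.squeeze(1).tolist() for eo in encoder_outs])             # [[2.0, 0.0, 1.0], [2.0, 0.0, 1.0]]
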

View File

@@ -108,6 +108,9 @@ class TestEncoder(FairseqEncoder):
    def forward(self, src_tokens, src_lengths):
        return src_tokens

    def reorder_encoder_out(self, encoder_out, new_order):
        return encoder_out.index_select(0, new_order)


class TestIncrementalDecoder(FairseqIncrementalDecoder):
    def __init__(self, args, dictionary):