Fix full-context alignment with transformer_align model (#2675)

Summary:
Fixes https://github.com/pytorch/fairseq/issues/2673: `TransformerDecoder.forward` now accepts and forwards `full_context_alignment`, `--full-context-alignment` is parsed as a real boolean flag (`action='store_true'` instead of `type=bool`), and `EnsembleModelWithAlignment` unwraps the attention list before averaging alignments.

# Before submitting

- [x] Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes https://github.com/pytorch/fairseq/issues/2673.

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: https://github.com/pytorch/fairseq/pull/2675

Reviewed By: ngoyal2707

Differential Revision: D24001793

Pulled By: myleott

fbshipit-source-id: 6b4e9270e5f5a31ba1b65ae2ae717019108af913
Author: Seppo Enarvi (2020-10-01 12:35:54 -07:00), committed by Facebook GitHub Bot
Parent: 7d2a3e10a9
Commit: c049749c7a
4 changed files with 29 additions and 3 deletions


@@ -641,6 +641,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         encoder_out: Optional[EncoderOut] = None,
         incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
         features_only: bool = False,
+        full_context_alignment: bool = False,
         alignment_layer: Optional[int] = None,
         alignment_heads: Optional[int] = None,
         src_lengths: Optional[Any] = None,
@@ -656,6 +657,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                 :ref:`Incremental decoding`
             features_only (bool, optional): only return features without
                 applying output layer (default: False).
+            full_context_alignment (bool, optional): don't apply
+                auto-regressive mask to self-attention (default: False).
 
         Returns:
             tuple:
@@ -666,6 +669,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             prev_output_tokens,
             encoder_out=encoder_out,
             incremental_state=incremental_state,
+            full_context_alignment=full_context_alignment,
             alignment_layer=alignment_layer,
             alignment_heads=alignment_heads,
         )

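For readers tracing the new keyword further: inside `extract_features`, `full_context_alignment` decides whether the causal future mask is built for decoder self-attention (that is what the docstring above means by not applying the auto-regressive mask). The helper below is a minimal, self-contained sketch of that choice, not fairseq API; when the flag is set, no mask is applied, so the alignment pass can attend to the full target context.

```python
from typing import Optional

import torch


def decoder_self_attn_mask(tgt_len: int, full_context_alignment: bool) -> Optional[torch.Tensor]:
    """Mirror of the decoder's mask choice: causal mask normally, none for full context."""
    if full_context_alignment:
        return None  # every target position may attend to every target token
    # Upper-triangular -inf mask, analogous to the decoder's buffered future mask.
    return torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)


print(decoder_self_attn_mask(4, full_context_alignment=False))
print(decoder_self_attn_mask(4, full_context_alignment=True))  # None
```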

@@ -32,7 +32,7 @@ class TransformerAlignModel(TransformerModel):
                             help='Number of cross attention heads per layer to supervised with alignments')
         parser.add_argument('--alignment-layer', type=int, metavar='D',
                             help='Layer number which has to be supervised. 0 corresponding to the bottommost layer.')
-        parser.add_argument('--full-context-alignment', type=bool, metavar='D',
+        parser.add_argument('--full-context-alignment', action='store_true',
                             help='Whether or not alignment is supervised conditioned on the full target context.')
         # fmt: on

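The argparse change matters because `type=bool` applies Python's `bool()` to the raw argument string, so any non-empty value, including the string "False", parses as `True`; `action='store_true'` turns the option into a plain presence flag. A standalone illustration (the `--old-flag` option is hypothetical, only there to show the pitfall):

```python
import argparse

parser = argparse.ArgumentParser()
# Old style: bool('False') is True, so the option can never be switched off on the CLI.
parser.add_argument('--old-flag', type=bool, default=False)
# New style: present means True, absent means False.
parser.add_argument('--full-context-alignment', action='store_true')

args = parser.parse_args(['--old-flag', 'False'])
print(args.old_flag)                 # True -- the surprising part
print(args.full_context_alignment)   # False (flag not given)

args = parser.parse_args(['--full-context-alignment'])
print(args.full_context_alignment)   # True
```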

@@ -952,7 +952,7 @@ class EnsembleModelWithAlignment(EnsembleModel):
         avg_attn = None
         for model in self.models:
             decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
-            attn = decoder_out[1]["attn"]
+            attn = decoder_out[1]["attn"][0]
             if avg_attn is None:
                 avg_attn = attn
             else:

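The extra `[0]` is needed because the decoder's extra output stores attention as a one-element list (a TorchScript-friendly `List[Optional[Tensor]]`) rather than a bare tensor, so the old code ended up trying to average Python lists. A toy version of the averaging loop, assuming that shape convention (the helper name is illustrative, not fairseq API):

```python
import torch


def average_alignment_attn(per_model_attn):
    """Average (batch, tgt_len, src_len) alignment attention over ensemble members,
    unwrapping the one-element list each decoder returns."""
    avg_attn = None
    for attn_list in per_model_attn:
        attn = attn_list[0]  # List[Optional[Tensor]] -> Tensor, the fixed indexing
        avg_attn = attn.clone() if avg_attn is None else avg_attn + attn
    return avg_attn.div_(len(per_model_attn))


# Two fake "models", each exposing a single attention tensor wrapped in a list.
a = [torch.rand(2, 5, 7)]
b = [torch.rand(2, 5, 7)]
print(average_alignment_attn([a, b]).shape)  # torch.Size([2, 5, 7])
```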

@@ -446,7 +446,29 @@ class TestTranslation(unittest.TestCase):
                         '--decoder-embed-dim', '8',
                         '--load-alignments',
                         '--alignment-layer', '1',
-                        '--criterion', 'label_smoothed_cross_entropy_with_alignment'
+                        '--criterion', 'label_smoothed_cross_entropy_with_alignment',
+                    ],
+                    run_validation=True,
+                )
+                generate_main(data_dir)
+
+    def test_alignment_full_context(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_alignment') as data_dir:
+                create_dummy_data(data_dir, alignment=True)
+                preprocess_translation_data(data_dir, ['--align-suffix', 'align'])
+                train_translation_model(
+                    data_dir,
+                    'transformer_align',
+                    [
+                        '--encoder-layers', '2',
+                        '--decoder-layers', '2',
+                        '--encoder-embed-dim', '8',
+                        '--decoder-embed-dim', '8',
+                        '--load-alignments',
+                        '--alignment-layer', '1',
+                        '--criterion', 'label_smoothed_cross_entropy_with_alignment',
+                        '--full-context-alignment',
                     ],
                     run_validation=True,
                 )
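
A quick way to exercise just the new test case, assuming the hunk above belongs to fairseq's `tests/test_binaries.py` (the usual home of these binary tests) and the repo is installed in the current environment:

```python
import unittest

# Path assumption: TestTranslation is defined in tests/test_binaries.py.
from tests.test_binaries import TestTranslation

suite = unittest.TestSuite()
suite.addTest(TestTranslation("test_alignment_full_context"))
unittest.TextTestRunner(verbosity=2).run(suite)
```

The new test mirrors `test_alignment` with `--full-context-alignment` added, so it covers the code path this PR repairs.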