mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-26 17:32:57 +03:00
Merge TracingCompliantTransformer and regular Transformer, fix NAT tests
Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/899 Differential Revision: D18373060 Pulled By: myleott fbshipit-source-id: bb5510ec15799a0a10a7c0669e76d8200e1ba479
This commit is contained in:
parent
2a9b4ec237
commit
27568a7ebe
@ -48,7 +48,7 @@ class LabelSmoothedDualImitationCriterion(FairseqCriterion):
|
||||
if masks is not None:
|
||||
outputs, targets = outputs[masks], targets[masks]
|
||||
|
||||
if not masks.any():
|
||||
if masks is not None and not masks.any():
|
||||
nll_loss = torch.tensor(0)
|
||||
loss = nll_loss
|
||||
else:
|
||||
|
@ -3,11 +3,20 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
from fairseq import utils
|
||||
from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
|
||||
from fairseq.models.model_utils import script_skip_tensor_list, skip_tensors as _skip
|
||||
from fairseq.models.nonautoregressive_ensembles import EnsembleLevT
|
||||
|
||||
|
||||
DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
|
||||
'output_tokens',
|
||||
'output_scores',
|
||||
'attn',
|
||||
'step',
|
||||
'max_step',
|
||||
])
|
||||
|
||||
|
||||
class IterativeRefinementGenerator(object):
|
||||
@ -88,6 +97,8 @@ class IterativeRefinementGenerator(object):
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, models, sample, prefix_tokens=None):
|
||||
from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
|
||||
from fairseq.models.nonautoregressive_ensembles import EnsembleLevT
|
||||
|
||||
if len(models) == 1:
|
||||
# Keep this for other NAT models for which we have yet to implement ensemble wrappers. Later delete this.
|
||||
@ -110,7 +121,7 @@ class IterativeRefinementGenerator(object):
|
||||
|
||||
# initialize buffers (very model specific, with length prediction or not)
|
||||
prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
|
||||
prev_output_tokens = prev_decoder_out[0].clone()
|
||||
prev_output_tokens = prev_decoder_out.output_tokens.clone()
|
||||
|
||||
finalized = [[] for _ in range(bsz)]
|
||||
|
||||
@ -150,8 +161,10 @@ class IterativeRefinementGenerator(object):
|
||||
"max_ratio": self.max_ratio,
|
||||
"decoding_format": self.decoding_format,
|
||||
}
|
||||
prev_decoder_out[3] = step
|
||||
prev_decoder_out[4] = self.max_iter + 1
|
||||
prev_decoder_out = prev_decoder_out._replace(
|
||||
step=step,
|
||||
max_step=self.max_iter + 1,
|
||||
)
|
||||
|
||||
decoder_out = model.forward_decoder(
|
||||
prev_decoder_out, encoder_out, **decoder_options
|
||||
@ -160,24 +173,26 @@ class IterativeRefinementGenerator(object):
|
||||
if self.adaptive:
|
||||
# terminate if there is a loop
|
||||
terminated, out_tokens, out_scores, out_attn = is_a_loop(
|
||||
prev_output_tokens, decoder_out[0], decoder_out[1], decoder_out[2]
|
||||
prev_output_tokens, decoder_out.output_tokens, decoder_out.output_scores, decoder_out.attn
|
||||
)
|
||||
decoder_out = decoder_out._replace(
|
||||
output_tokens=out_tokens,
|
||||
output_scores=out_scores,
|
||||
attn=out_attn,
|
||||
)
|
||||
decoder_out[0] = out_tokens
|
||||
decoder_out[1] = out_scores
|
||||
decoder_out[2] = out_attn
|
||||
|
||||
else:
|
||||
terminated = decoder_out[0].new_zeros(decoder_out[0].size(0)).bool()
|
||||
terminated = decoder_out.output_tokens.new_zeros(decoder_out.output_tokens.size(0)).bool()
|
||||
|
||||
if step == self.max_iter: # reach last iteration, terminate
|
||||
terminated.fill_(1)
|
||||
|
||||
# collect finalized sentences
|
||||
finalized_idxs = sent_idxs[terminated]
|
||||
finalized_tokens = decoder_out[0][terminated]
|
||||
finalized_scores = decoder_out[1][terminated]
|
||||
finalized_tokens = decoder_out.output_tokens[terminated]
|
||||
finalized_scores = decoder_out.output_scores[terminated]
|
||||
finalized_attn = (
|
||||
None if decoder_out[2] is None else decoder_out[2][terminated]
|
||||
None if decoder_out.attn is None else decoder_out.attn[terminated]
|
||||
)
|
||||
|
||||
for i in range(finalized_idxs.size(0)):
|
||||
@ -194,10 +209,15 @@ class IterativeRefinementGenerator(object):
|
||||
break
|
||||
|
||||
# for next step
|
||||
prev_decoder_out = _skip(decoder_out, ~terminated)
|
||||
encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
|
||||
sent_idxs = _skip(sent_idxs, ~terminated)
|
||||
not_terminated = ~terminated
|
||||
prev_decoder_out = decoder_out._replace(
|
||||
output_tokens=decoder_out.output_tokens[not_terminated],
|
||||
output_scores=decoder_out.output_scores[not_terminated],
|
||||
attn=decoder_out.attn[not_terminated] if decoder_out.attn is not None else None,
|
||||
)
|
||||
encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze())
|
||||
sent_idxs = sent_idxs[not_terminated]
|
||||
|
||||
prev_output_tokens = prev_decoder_out[0].clone()
|
||||
prev_output_tokens = prev_decoder_out.output_tokens.clone()
|
||||
|
||||
return finalized
|
||||
|
@ -10,9 +10,9 @@ Ghazvininejad, Marjan, et al.
|
||||
arXiv preprint arXiv:1904.09324 (2019).
|
||||
"""
|
||||
|
||||
from fairseq.utils import new_arange
|
||||
from fairseq.models import register_model, register_model_architecture
|
||||
from fairseq.models.nonautoregressive_transformer import NATransformerModel
|
||||
from fairseq.utils import new_arange
|
||||
|
||||
|
||||
def _skeptical_unmasking(output_scores, output_masks, p):
|
||||
@ -55,11 +55,11 @@ class CMLMNATransformerModel(NATransformerModel):
|
||||
|
||||
def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
|
||||
|
||||
step = decoder_out["step"]
|
||||
max_step = decoder_out["max_step"]
|
||||
step = decoder_out.step
|
||||
max_step = decoder_out.max_step
|
||||
|
||||
output_tokens = decoder_out["output_tokens"]
|
||||
output_scores = decoder_out["output_scores"]
|
||||
output_tokens = decoder_out.output_tokens
|
||||
output_scores = decoder_out.output_scores
|
||||
|
||||
# execute the decoder
|
||||
output_masks = output_tokens.eq(self.unk)
|
||||
@ -78,7 +78,11 @@ class CMLMNATransformerModel(NATransformerModel):
|
||||
output_tokens.masked_fill_(skeptical_mask, self.unk)
|
||||
output_scores.masked_fill_(skeptical_mask, 0.0)
|
||||
|
||||
return {"output_tokens": output_tokens, "output_scores": output_scores}
|
||||
return decoder_out._replace(
|
||||
output_tokens=output_tokens,
|
||||
output_scores=output_scores,
|
||||
attn=None,
|
||||
)
|
||||
|
||||
|
||||
@register_model_architecture("cmlm_transformer", "cmlm_transformer")
|
||||
|
@ -6,7 +6,7 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairseq.utils import new_arange
|
||||
|
||||
from fairseq.models import register_model, register_model_architecture
|
||||
from fairseq.models.levenshtein_transformer import (
|
||||
LevenshteinTransformerDecoder,
|
||||
@ -14,6 +14,7 @@ from fairseq.models.levenshtein_transformer import (
|
||||
)
|
||||
from fairseq.models.transformer import Linear, TransformerModel
|
||||
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
||||
from fairseq.utils import new_arange
|
||||
|
||||
|
||||
class NegativeDistanceScore(object):
|
||||
@ -116,8 +117,8 @@ def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, paddi
|
||||
|
||||
@register_model("insertion_transformer")
|
||||
class InsertionTransformerModel(LevenshteinTransformerModel):
|
||||
def __init__(self, encoder, decoder):
|
||||
super().__init__(encoder, decoder)
|
||||
def __init__(self, args, encoder, decoder):
|
||||
super().__init__(args, encoder, decoder)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
@ -169,8 +170,8 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
|
||||
self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
|
||||
):
|
||||
|
||||
output_tokens = decoder_out["output_tokens"]
|
||||
output_scores = decoder_out["output_scores"]
|
||||
output_tokens = decoder_out.output_tokens
|
||||
output_scores = decoder_out.output_scores
|
||||
# TODO: decoding for InsertionTransformer
|
||||
word_ins_out = self.decoder.forward_word_ins(
|
||||
output_tokens, encoder_out=encoder_out
|
||||
@ -187,7 +188,11 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
|
||||
cut_off = output_tokens.ne(self.pad).sum(1).max()
|
||||
output_tokens = output_tokens[:, :cut_off]
|
||||
output_scores = output_scores[:, :cut_off]
|
||||
return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None}
|
||||
return decoder_out._replace(
|
||||
output_tokens=output_tokens,
|
||||
output_scores=output_scores,
|
||||
attn=None,
|
||||
)
|
||||
|
||||
|
||||
class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
|
||||
@ -206,7 +211,7 @@ class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
|
||||
self.label_tau = getattr(args, "label_tau", None)
|
||||
|
||||
def forward_word_ins(self, prev_output_tokens, encoder_out=None):
|
||||
features, _ = self.extract_features(prev_output_tokens, encoder_out=encoder_out)
|
||||
features = self.extract_features(prev_output_tokens, encoder_out=encoder_out)[0]
|
||||
features = self.pool_out(
|
||||
torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
|
||||
)
|
||||
|
@ -4,6 +4,7 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from fairseq.models import register_model, register_model_architecture
|
||||
from fairseq.models.nonautoregressive_transformer import NATransformerModel
|
||||
|
||||
|
@ -1,52 +1,109 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) 2017-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the license found in the LICENSE file in
|
||||
# the root directory of this source tree. An additional grant of patent rights
|
||||
# can be found in the PATENTS file in the same directory.
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
from typing import Optional
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fairseq.iterative_refinement_generator import DecoderOut
|
||||
from fairseq.models import register_model, register_model_architecture
|
||||
from fairseq.models.tracing_compliant_transformer import (
|
||||
TracingTransformerDecoder,
|
||||
TracingTransformerEncoder,
|
||||
TracingTransformerModel,
|
||||
TransformerDecoderLayer,
|
||||
from fairseq.models.transformer import (
|
||||
Embedding,
|
||||
TransformerDecoder,
|
||||
TransformerEncoder,
|
||||
TransformerModel,
|
||||
TransformerDecoderLayer
|
||||
)
|
||||
from fairseq.models.model_utils import (
|
||||
fill_tensors as _fill,
|
||||
script_skip_tensor,
|
||||
script_skip_tensor_list,
|
||||
)
|
||||
from fairseq.models.transformer import Embedding
|
||||
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
||||
from torch import Tensor
|
||||
from fairseq.utils import new_arange
|
||||
|
||||
# -------------- Helper Functions --------------------------------------------------- #
|
||||
def _skip(x, mask):
|
||||
"""
|
||||
Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors.
|
||||
"""
|
||||
if isinstance(x, int):
|
||||
return x
|
||||
|
||||
if x is None:
|
||||
return None
|
||||
|
||||
if isinstance(x, torch.Tensor):
|
||||
if x.size(0) == mask.size(0):
|
||||
return x[mask]
|
||||
elif x.size(1) == mask.size(0):
|
||||
return x[:, mask]
|
||||
|
||||
if isinstance(x, list):
|
||||
return [_skip(x_i, mask) for x_i in x]
|
||||
|
||||
if isinstance(x, dict):
|
||||
return {k: _skip(v, mask) for k, v in x.items()}
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
|
||||
def _skip_encoder_out(encoder, encoder_out, mask):
|
||||
if not mask.any():
|
||||
return encoder_out
|
||||
else:
|
||||
return encoder.reorder_encoder_out(encoder_out, mask.nonzero().squeeze())
|
||||
|
||||
|
||||
def _fill(x, mask, y, padding_idx):
|
||||
"""
|
||||
Filling tensor x with y at masked positions (dim=0).
|
||||
"""
|
||||
if x is None:
|
||||
return y
|
||||
assert x.dim() == y.dim() and mask.size(0) == x.size(0)
|
||||
assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
|
||||
n_selected = mask.sum()
|
||||
assert n_selected == y.size(0)
|
||||
|
||||
if n_selected == x.size(0):
|
||||
return y
|
||||
|
||||
if x.size(1) < y.size(1):
|
||||
dims = [x.size(0), y.size(1) - x.size(1)]
|
||||
if x.dim() == 3:
|
||||
dims.append(x.size(2))
|
||||
x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
|
||||
x[mask] = y
|
||||
elif x.size(1) > y.size(1):
|
||||
x[mask] = padding_idx
|
||||
if x.dim() == 2:
|
||||
x[mask, :y.size(1)] = y
|
||||
else:
|
||||
x[mask, :y.size(1), :] = y
|
||||
else:
|
||||
x[mask] = y
|
||||
return x
|
||||
|
||||
|
||||
def load_libnat():
|
||||
try:
|
||||
from fairseq import libnat
|
||||
except ImportError as e:
|
||||
import sys
|
||||
|
||||
sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
|
||||
raise e
|
||||
return libnat
|
||||
|
||||
|
||||
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
|
||||
libnat = load_libnat()
|
||||
|
||||
in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
|
||||
|
||||
in_tokens_list = [
|
||||
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
|
||||
]
|
||||
out_tokens_list = [
|
||||
[t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist())
|
||||
[t for t in s if t != padding_idx]
|
||||
for i, s in enumerate(out_tokens.tolist())
|
||||
]
|
||||
|
||||
full_labels = libnat.suggested_ed2_path(
|
||||
@ -71,27 +128,26 @@ def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
|
||||
]
|
||||
|
||||
# transform to tensor
|
||||
masked_tgt_masks = torch.tensor(masked_tgt_masks, device=out_tokens.device).bool()
|
||||
masked_tgt_masks = torch.tensor(
|
||||
masked_tgt_masks, device=out_tokens.device
|
||||
).bool()
|
||||
mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
|
||||
masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
|
||||
return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
|
||||
|
||||
|
||||
def _get_del_targets(in_tokens, out_tokens, padding_idx):
|
||||
try:
|
||||
from fairseq import libnat
|
||||
except ImportError as e:
|
||||
import sys
|
||||
libnat = load_libnat()
|
||||
|
||||
sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
|
||||
raise e
|
||||
out_seq_len = out_tokens.size(1)
|
||||
|
||||
with torch.cuda.device_of(in_tokens):
|
||||
in_tokens_list = [
|
||||
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
|
||||
]
|
||||
out_tokens_list = [
|
||||
[t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist())
|
||||
[t for t in s if t != padding_idx]
|
||||
for i, s in enumerate(out_tokens.tolist())
|
||||
]
|
||||
|
||||
full_labels = libnat.suggested_ed2_path(
|
||||
@ -104,25 +160,22 @@ def _get_del_targets(in_tokens, out_tokens, padding_idx):
|
||||
]
|
||||
|
||||
# transform to tensor
|
||||
word_del_targets = torch.tensor(word_del_targets)
|
||||
word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
|
||||
return word_del_targets
|
||||
|
||||
|
||||
def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
|
||||
try:
|
||||
from fairseq import libnat
|
||||
except ImportError as e:
|
||||
import sys
|
||||
libnat = load_libnat()
|
||||
|
||||
sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
|
||||
raise e
|
||||
in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
|
||||
|
||||
with torch.cuda.device_of(in_tokens):
|
||||
in_tokens_list = [
|
||||
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
|
||||
]
|
||||
out_tokens_list = [
|
||||
[t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist())
|
||||
[t for t in s if t != padding_idx]
|
||||
for i, s in enumerate(out_tokens.tolist())
|
||||
]
|
||||
|
||||
full_labels = libnat.suggested_ed2_path(
|
||||
@ -144,15 +197,101 @@ def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
|
||||
]
|
||||
|
||||
# transform to tensor
|
||||
mask_ins_targets = torch.tensor(mask_ins_targets)
|
||||
word_del_targets = torch.tensor(word_del_targets)
|
||||
mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
|
||||
word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
|
||||
return word_del_targets, mask_ins_targets
|
||||
|
||||
|
||||
def _apply_ins_masks(
|
||||
in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
|
||||
):
|
||||
|
||||
in_masks = in_tokens.ne(padding_idx)
|
||||
in_lengths = in_masks.sum(1)
|
||||
|
||||
# HACK: hacky way to shift all the paddings to eos first.
|
||||
in_tokens.masked_fill_(~in_masks, eos_idx)
|
||||
mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
|
||||
|
||||
out_lengths = in_lengths + mask_ins_pred.sum(1)
|
||||
out_max_len = out_lengths.max()
|
||||
out_masks = (
|
||||
new_arange(out_lengths, out_max_len)[None, :]
|
||||
< out_lengths[:, None]
|
||||
)
|
||||
|
||||
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
|
||||
out_tokens = (
|
||||
in_tokens.new_zeros(in_tokens.size(0), out_max_len)
|
||||
.fill_(padding_idx)
|
||||
.masked_fill_(out_masks, unk_idx)
|
||||
)
|
||||
out_tokens[:, 0] = in_tokens[:, 0]
|
||||
out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
|
||||
|
||||
out_scores = None
|
||||
if in_scores is not None:
|
||||
in_scores.masked_fill_(~in_masks, 0)
|
||||
out_scores = in_scores.new_zeros(*out_tokens.size())
|
||||
out_scores[:, 0] = in_scores[:, 0]
|
||||
out_scores.scatter_(1, reordering, in_scores[:, 1:])
|
||||
|
||||
return out_tokens, out_scores
|
||||
|
||||
|
||||
def _apply_ins_words(
|
||||
in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx
|
||||
):
|
||||
word_ins_masks = in_tokens.eq(unk_idx)
|
||||
out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
|
||||
|
||||
if in_scores is not None:
|
||||
out_scores = in_scores.masked_scatter(
|
||||
word_ins_masks, word_ins_scores[word_ins_masks]
|
||||
)
|
||||
else:
|
||||
out_scores = None
|
||||
|
||||
return out_tokens, out_scores
|
||||
|
||||
|
||||
def _apply_del_words(
|
||||
in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
|
||||
):
|
||||
# apply deletion to a tensor
|
||||
in_masks = in_tokens.ne(padding_idx)
|
||||
bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
|
||||
|
||||
max_len = in_tokens.size(1)
|
||||
word_del_pred.masked_fill_(~in_masks, 1)
|
||||
word_del_pred.masked_fill_(bos_eos_masks, 0)
|
||||
|
||||
reordering = (
|
||||
new_arange(in_tokens)
|
||||
.masked_fill_(word_del_pred, max_len)
|
||||
.sort(1)[1]
|
||||
)
|
||||
|
||||
out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
|
||||
|
||||
out_scores = None
|
||||
if in_scores is not None:
|
||||
out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
|
||||
|
||||
out_attn = None
|
||||
if in_attn is not None:
|
||||
_mask = word_del_pred[:, :, None].expand_as(in_attn)
|
||||
_reordering = reordering[:, :, None].expand_as(in_attn)
|
||||
out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
|
||||
|
||||
return out_tokens, out_scores, out_attn
|
||||
|
||||
# ------------------------------------------------------------------------------------- #
|
||||
|
||||
@register_model("levenshtein_transformer")
|
||||
class LevenshteinTransformerModel(TracingTransformerModel):
|
||||
def __init__(self, encoder, decoder):
|
||||
super().__init__(encoder, decoder)
|
||||
class LevenshteinTransformerModel(TransformerModel):
|
||||
def __init__(self, args, encoder, decoder):
|
||||
super().__init__(args, encoder, decoder)
|
||||
self.tgt_dict = decoder.dictionary
|
||||
self.bos = decoder.dictionary.bos()
|
||||
self.eos = decoder.dictionary.eos()
|
||||
@ -161,7 +300,7 @@ class LevenshteinTransformerModel(TracingTransformerModel):
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
TracingTransformerModel.add_args(parser)
|
||||
TransformerModel.add_args(parser)
|
||||
parser.add_argument(
|
||||
"--apply-bert-init",
|
||||
action="store_true",
|
||||
@ -171,31 +310,27 @@ class LevenshteinTransformerModel(TracingTransformerModel):
|
||||
"--early-exit",
|
||||
default="6,6,6",
|
||||
type=str,
|
||||
help="number of decoder layers for del_word, ins_mask, ins_word",
|
||||
help="number of decoder layers before word_del, mask_ins, word_ins",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-share-discriminator",
|
||||
action="store_true",
|
||||
help="addtional decoder-layers to learn deletion",
|
||||
help="separate parameters for discriminator",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-share-maskpredictor",
|
||||
action="store_true",
|
||||
help="addtional decoder-layers to learn predicting masks",
|
||||
help="separate parameters for mask-predictor",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--share-discriminator-maskpredictor",
|
||||
action="store_true",
|
||||
help="share the parameters for both mask-predictor and discriminator",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampling-for-deletion",
|
||||
action="store_true",
|
||||
help="instead of argmax, use sampling to predict the tokens",
|
||||
)
|
||||
# Added for compatibility
|
||||
parser.add_argument(
|
||||
"--decoder-out-embed-dim",
|
||||
default=None,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="decoder output embedding dimension (bottleneck layer before"
|
||||
"output layer if specified.)",
|
||||
action='store_true',
|
||||
help='instead of argmax, use sampling to predict the tokens'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -207,7 +342,7 @@ class LevenshteinTransformerModel(TracingTransformerModel):
|
||||
|
||||
@classmethod
|
||||
def build_encoder(cls, args, src_dict, embed_tokens):
|
||||
encoder = TracingTransformerEncoder(args, src_dict, embed_tokens)
|
||||
encoder = TransformerEncoder(args, src_dict, embed_tokens)
|
||||
if getattr(args, "apply_bert_init", False):
|
||||
encoder.apply(init_bert_params)
|
||||
return encoder
|
||||
@ -238,8 +373,8 @@ class LevenshteinTransformerModel(TracingTransformerModel):
|
||||
# make online prediction
|
||||
if self.decoder.sampling_for_deletion:
|
||||
word_predictions = torch.multinomial(
|
||||
F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1
|
||||
).view(word_ins_out.size(0), -1)
|
||||
F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1).view(
|
||||
word_ins_out.size(0), -1)
|
||||
else:
|
||||
word_predictions = F.log_softmax(word_ins_out, dim=-1).max(2)[1]
|
||||
|
||||
@ -249,7 +384,10 @@ class LevenshteinTransformerModel(TracingTransformerModel):
|
||||
|
||||
# generate training labels for deletion
|
||||
word_del_targets = _get_del_targets(word_predictions, tgt_tokens, self.pad)
|
||||
word_del_out, _ = self.decoder.forward_word_del(word_predictions, encoder_out)
|
||||
word_del_out, _ = self.decoder.forward_word_del(
|
||||
word_predictions, encoder_out)
|
||||
word_del_masks = word_predictions.ne(self.pad)
|
||||
|
||||
return {
|
||||
"mask_ins_out": mask_ins_out,
|
||||
"mask_ins_tgt": mask_ins_targets,
|
||||
@ -259,7 +397,7 @@ class LevenshteinTransformerModel(TracingTransformerModel):
|
||||
"word_ins_mask": masked_tgt_masks,
|
||||
"word_del_out": word_del_out,
|
||||
"word_del_tgt": word_del_targets,
|
||||
"word_del_mask": word_predictions.ne(self.pad),
|
||||
"word_del_mask": word_del_masks,
|
||||
}
|
||||
|
||||
def forward_encoder(self, encoder_inputs):
|
||||
@ -269,248 +407,123 @@ class LevenshteinTransformerModel(TracingTransformerModel):
|
||||
self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
|
||||
):
|
||||
|
||||
output_tokens = decoder_out[0]
|
||||
output_scores = decoder_out[1]
|
||||
attn = decoder_out[2]
|
||||
|
||||
if max_ratio is not None and encoder_out[1] is not None:
|
||||
max_lengths = ((~encoder_out[1]).sum(1) * max_ratio).clamp(min=10)
|
||||
output_tokens = decoder_out.output_tokens
|
||||
output_scores = decoder_out.output_scores
|
||||
attn = decoder_out.attn
|
||||
|
||||
bsz = output_tokens.size(0)
|
||||
if max_ratio is None:
|
||||
max_lens = torch.zeros_like(output_tokens).fill_(255)
|
||||
else:
|
||||
max_lengths = torch.zeros(output_tokens.size(0)).fill_(255)
|
||||
if encoder_out.encoder_padding_mask is None:
|
||||
max_src_len = encoder_out.encoder_out.size(1)
|
||||
src_lens = encoder_out.encoder_out.new(bsz).fill_(max_src_len)
|
||||
else:
|
||||
src_lens = (~encoder_out.encoder_padding_mask).sum(1)
|
||||
max_lens = (src_lens * max_ratio).clamp(min=10).long()
|
||||
|
||||
@torch.jit.script
|
||||
def del_word(
|
||||
output_tokens,
|
||||
output_scores,
|
||||
attn: Tensor,
|
||||
word_del_attn: Optional[Tensor],
|
||||
word_del_out,
|
||||
can_del_word,
|
||||
pad_idx: int,
|
||||
bos_idx: int,
|
||||
eos_idx: int,
|
||||
):
|
||||
# delete words
|
||||
# do not delete tokens if it is <s> </s>
|
||||
if can_del_word.sum() != 0: # we cannot delete, skip
|
||||
word_del_score = F.log_softmax(word_del_out, 2)
|
||||
word_del_pred = torch.jit.Attribute(word_del_score.max(-1)[1], bool)
|
||||
in_tokens = output_tokens[can_del_word]
|
||||
in_scores = output_scores[can_del_word]
|
||||
# apply deletion to a tensor
|
||||
in_masks = in_tokens.ne(pad_idx)
|
||||
bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
|
||||
|
||||
max_len = in_tokens.size(1)
|
||||
word_del_pred.masked_fill_(~in_masks, 1)
|
||||
word_del_pred.masked_fill_(bos_eos_masks, 0)
|
||||
|
||||
reordering = (
|
||||
torch.arange(max_len)[None, :]
|
||||
.expand_as(in_tokens)
|
||||
.contiguous()
|
||||
.masked_fill(word_del_pred, max_len)
|
||||
.sort(1)[1]
|
||||
)
|
||||
|
||||
_tokens = in_tokens.masked_fill(word_del_pred, pad_idx).gather(
|
||||
1, reordering
|
||||
)
|
||||
|
||||
_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
|
||||
if word_del_attn is not None:
|
||||
_mask = word_del_pred[:, :, None].expand_as(word_del_attn)
|
||||
_reordering = reordering[:, :, None].expand_as(word_del_attn)
|
||||
_attn = word_del_attn.masked_fill(_mask, 0.0).gather(1, _reordering)
|
||||
attn = _fill(attn, can_del_word, _attn, 0)
|
||||
|
||||
output_tokens = _fill(output_tokens, can_del_word, _tokens, pad_idx)
|
||||
output_scores = _fill(output_scores, can_del_word, _scores, 0)
|
||||
return output_tokens, output_scores, attn
|
||||
|
||||
@torch.jit.script
|
||||
def ins_placeholders(
|
||||
output_tokens,
|
||||
output_scores,
|
||||
mask_ins_out,
|
||||
can_ins_mask,
|
||||
pad_idx: int,
|
||||
unk_idx: int,
|
||||
eos_idx: int,
|
||||
max_ratio: float,
|
||||
max_lengths,
|
||||
):
|
||||
# insert placeholders
|
||||
if can_ins_mask.sum() != 0:
|
||||
mask_ins_score = F.log_softmax(mask_ins_out, 2)
|
||||
if eos_penalty > 0.0:
|
||||
mask_ins_score[:, :, 0] -= eos_penalty
|
||||
mask_ins_pred = mask_ins_score.max(-1)[1]
|
||||
if max_ratio is not None and encoder_out[1] is not None:
|
||||
mask_ins_pred = torch.min(
|
||||
mask_ins_pred, max_lengths[can_ins_mask, None].expand_as(mask_ins_pred)
|
||||
)
|
||||
in_tokens = output_tokens[can_ins_mask]
|
||||
in_scores = output_scores[can_ins_mask]
|
||||
in_masks = in_tokens.ne(pad_idx)
|
||||
in_lengths = in_masks.sum(1)
|
||||
|
||||
# HACK: hacky way to shift all the paddings to eos first.
|
||||
in_tokens.masked_fill_(~in_masks, eos_idx)
|
||||
mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
|
||||
|
||||
out_lengths = in_lengths + mask_ins_pred.sum(1)
|
||||
out_max_len = out_lengths.max()
|
||||
out_masks = (
|
||||
torch.arange(out_max_len)[None, :].long() < out_lengths[:, None]
|
||||
)
|
||||
|
||||
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
|
||||
out_tokens = (
|
||||
torch.zeros(in_tokens.size()[0], out_max_len)
|
||||
.fill_(pad_idx)
|
||||
.masked_fill_(out_masks, unk_idx)
|
||||
)
|
||||
out_tokens = torch.cat([in_tokens[:, :1], out_tokens[:, 1:]], 1)
|
||||
out_tokens.scatter_(1, reordering, in_tokens[:, 1:].float())
|
||||
|
||||
if in_scores is not None:
|
||||
in_scores.masked_fill_(~in_masks, 0)
|
||||
out_scores = torch.zeros_like(out_tokens).to(in_scores)
|
||||
out_tokens = torch.cat([in_tokens[:, :1], out_tokens[:, 1:]], 1)
|
||||
out_scores.scatter_(1, reordering, in_scores[:, 1:])
|
||||
else:
|
||||
out_scores = None
|
||||
output_tokens = _fill(output_tokens, can_ins_mask, out_tokens, pad_idx)
|
||||
output_scores = _fill(output_scores, can_ins_mask, out_scores, 0)
|
||||
return output_tokens, output_scores
|
||||
|
||||
@torch.jit.script
|
||||
def ins_words(
|
||||
output_tokens,
|
||||
output_scores,
|
||||
attn: Tensor,
|
||||
word_ins_attn,
|
||||
word_ins_out,
|
||||
can_ins_word,
|
||||
pad_idx: int,
|
||||
unk_idx: int,
|
||||
):
|
||||
# insert words
|
||||
if can_ins_word.sum() != 0:
|
||||
word_ins_scores = F.log_softmax(word_ins_out, 2)
|
||||
word_ins_pred = word_ins_scores.max(-1)[1]
|
||||
in_tokens = output_tokens[can_ins_word]
|
||||
in_scores = output_scores[can_ins_word]
|
||||
word_ins_masks = in_tokens.eq(unk_idx)
|
||||
out_tokens = in_tokens.masked_scatter(
|
||||
word_ins_masks, word_ins_pred[word_ins_masks].float()
|
||||
)
|
||||
|
||||
if in_scores is not None:
|
||||
out_scores = in_scores.masked_scatter(
|
||||
word_ins_masks, word_ins_scores[word_ins_masks]
|
||||
)
|
||||
else:
|
||||
out_scores = None
|
||||
output_tokens = _fill(output_tokens, can_ins_word, out_tokens, pad_idx)
|
||||
output_scores = _fill(output_scores, can_ins_word, out_scores, 0)
|
||||
attn = _fill(attn, can_ins_word, word_ins_attn, 0)
|
||||
return output_tokens, output_scores, attn
|
||||
|
||||
can_del_word = output_tokens.ne(self.pad).sum(1) > 2
|
||||
if can_del_word.sum() != 0: # we cannot delete, skip
|
||||
word_del_out, word_del_attn = self.decoder.forward_word_del(
|
||||
script_skip_tensor(output_tokens, can_del_word),
|
||||
script_skip_tensor_list(list(encoder_out), can_del_word),
|
||||
_skip(output_tokens, can_del_word),
|
||||
_skip_encoder_out(self.encoder, encoder_out, can_del_word)
|
||||
)
|
||||
word_del_score = F.log_softmax(word_del_out, 2)
|
||||
word_del_pred = word_del_score.max(-1)[1].bool()
|
||||
|
||||
output_tokens, output_scores, attn = del_word(
|
||||
output_tokens,
|
||||
output_scores,
|
||||
attn,
|
||||
_tokens, _scores, _attn = _apply_del_words(
|
||||
output_tokens[can_del_word],
|
||||
output_scores[can_del_word],
|
||||
word_del_attn,
|
||||
word_del_out,
|
||||
can_del_word,
|
||||
word_del_pred,
|
||||
self.pad,
|
||||
self.bos,
|
||||
self.eos,
|
||||
)
|
||||
output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
|
||||
output_scores = _fill(output_scores, can_del_word, _scores, 0)
|
||||
attn = _fill(attn, can_del_word, _attn, 0.)
|
||||
|
||||
can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lengths
|
||||
# insert placeholders
|
||||
can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
|
||||
if can_ins_mask.sum() != 0:
|
||||
mask_ins_out, _ = self.decoder.forward_mask_ins(
|
||||
script_skip_tensor(output_tokens, can_ins_mask),
|
||||
script_skip_tensor_list(encoder_out, can_ins_mask),
|
||||
_skip(output_tokens, can_ins_mask),
|
||||
_skip_encoder_out(self.encoder, encoder_out, can_ins_mask)
|
||||
)
|
||||
output_tokens, output_scores = ins_placeholders(
|
||||
output_tokens,
|
||||
output_scores,
|
||||
mask_ins_out,
|
||||
can_ins_mask,
|
||||
mask_ins_score = F.log_softmax(mask_ins_out, 2)
|
||||
if eos_penalty > 0.0:
|
||||
mask_ins_score[:, :, 0] = mask_ins_score[:, :, 0] - eos_penalty
|
||||
mask_ins_pred = mask_ins_score.max(-1)[1]
|
||||
mask_ins_pred = torch.min(
|
||||
mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
|
||||
)
|
||||
|
||||
_tokens, _scores = _apply_ins_masks(
|
||||
output_tokens[can_ins_mask],
|
||||
output_scores[can_ins_mask],
|
||||
mask_ins_pred,
|
||||
self.pad,
|
||||
self.unk,
|
||||
self.eos,
|
||||
max_ratio,
|
||||
max_lengths,
|
||||
)
|
||||
output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
|
||||
output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
|
||||
|
||||
# insert words
|
||||
can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
|
||||
if can_ins_word.sum() != 0:
|
||||
word_ins_out, word_ins_attn = self.decoder.forward_word_ins(
|
||||
script_skip_tensor(output_tokens, can_ins_word),
|
||||
script_skip_tensor_list(encoder_out, can_ins_word),
|
||||
_skip(output_tokens, can_ins_word),
|
||||
_skip_encoder_out(self.encoder, encoder_out, can_ins_word)
|
||||
)
|
||||
word_ins_score, word_ins_pred = F.log_softmax(word_ins_out, 2).max(-1)
|
||||
word_ins_pred = word_ins_score.max(-1)[1]
|
||||
|
||||
|
||||
output_tokens, output_scores, attn = ins_words(
|
||||
output_tokens,
|
||||
output_scores,
|
||||
attn,
|
||||
word_ins_attn,
|
||||
word_ins_out,
|
||||
can_ins_word,
|
||||
self.pad,
|
||||
_tokens, _scores = _apply_ins_words(
|
||||
output_tokens[can_ins_word],
|
||||
output_scores[can_ins_word],
|
||||
word_ins_pred,
|
||||
word_ins_score,
|
||||
self.unk,
|
||||
)
|
||||
|
||||
output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
|
||||
output_scores = _fill(output_scores, can_ins_word, _scores, 0)
|
||||
attn = _fill(attn, can_ins_word, word_ins_attn, 0.)
|
||||
|
||||
# delete some unnecessary paddings
|
||||
cut_off = output_tokens.ne(self.pad).sum(1).max()
|
||||
output_tokens = output_tokens[:, :cut_off]
|
||||
output_scores = output_scores[:, :cut_off]
|
||||
attn = None if attn is None else attn[:, :cut_off, :]
|
||||
|
||||
@torch.jit.script
|
||||
def slice_wrap(x, l):
|
||||
return x[:, :l]
|
||||
|
||||
@torch.jit.script
|
||||
def slice_wrap_attn(x, l):
|
||||
return x if x.size()[0] == 0 else x[:, :l, :]
|
||||
|
||||
output_tokens = slice_wrap(output_tokens, cut_off)
|
||||
output_scores = slice_wrap(output_scores, cut_off)
|
||||
attn = slice_wrap(attn, cut_off)
|
||||
return [output_tokens, output_scores, attn, 0, 0]
|
||||
return decoder_out._replace(
|
||||
output_tokens=output_tokens,
|
||||
output_scores=output_scores,
|
||||
attn=attn,
|
||||
)
|
||||
|
||||
def initialize_output_tokens(self, encoder_out, src_tokens):
|
||||
initial_output_tokens = torch.cat(
|
||||
[
|
||||
torch.zeros(src_tokens.size(0), 1).fill_(self.bos),
|
||||
torch.zeros(src_tokens.size(0), 1).fill_(self.eos),
|
||||
],
|
||||
1,
|
||||
initial_output_tokens = src_tokens.new_zeros(src_tokens.size(0), 2)
|
||||
initial_output_tokens[:, 0] = self.bos
|
||||
initial_output_tokens[:, 1] = self.eos
|
||||
|
||||
initial_output_scores = initial_output_tokens.new_zeros(
|
||||
*initial_output_tokens.size()
|
||||
).type_as(encoder_out.encoder_out)
|
||||
return DecoderOut(
|
||||
output_tokens=initial_output_tokens,
|
||||
output_scores=initial_output_scores,
|
||||
attn=None,
|
||||
step=0,
|
||||
max_step=0,
|
||||
)
|
||||
|
||||
initial_output_scores = torch.zeros_like(initial_output_tokens).to(
|
||||
encoder_out[0]
|
||||
)
|
||||
|
||||
initial_attn = torch.empty([0])
|
||||
if getattr(self.decoder.layers[-1], "need_attn", True):
|
||||
initial_attn = torch.zeros([src_tokens.size(0), 2, src_tokens.size(1)]).to(
|
||||
initial_output_tokens
|
||||
)
|
||||
|
||||
return [initial_output_tokens, initial_output_scores, initial_attn, 0, 0]
|
||||
|
||||
|
||||
class LevenshteinTransformerDecoder(TracingTransformerDecoder):
|
||||
class LevenshteinTransformerDecoder(TransformerDecoder):
|
||||
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
|
||||
super().__init__(
|
||||
args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
|
||||
@ -524,38 +537,32 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
|
||||
self.embed_word_del = Embedding(2, self.output_embed_dim, None)
|
||||
|
||||
# del_word, ins_mask, ins_word
|
||||
self.early_exit = [int(i) for i in args.early_exit.split(",")]
|
||||
self.early_exit = [int(i) for i in args.early_exit.split(',')]
|
||||
assert len(self.early_exit) == 3
|
||||
|
||||
# copy layers for mask-predict/deletion
|
||||
self.layers_msk = None
|
||||
if getattr(args, "no_share_maskpredictor", False):
|
||||
self.layers_msk = nn.ModuleList(
|
||||
[
|
||||
self.layers_msk = nn.ModuleList([
|
||||
TransformerDecoderLayer(args, no_encoder_attn)
|
||||
for _ in range(self.early_exit[1])
|
||||
]
|
||||
)
|
||||
])
|
||||
self.layers_del = None
|
||||
if getattr(args, "no_share_discriminator", False):
|
||||
self.layers_del = nn.ModuleList(
|
||||
[
|
||||
self.layers_del = nn.ModuleList([
|
||||
TransformerDecoderLayer(args, no_encoder_attn)
|
||||
for _ in range(self.early_exit[0])
|
||||
]
|
||||
)
|
||||
])
|
||||
|
||||
if getattr(args, "share_discriminator_maskpredictor", False):
|
||||
assert getattr(args, "no_share_discriminator", False), "must set saperate discriminator"
|
||||
self.layers_msk = self.layers_del
|
||||
|
||||
def extract_features(
|
||||
self,
|
||||
prev_output_tokens,
|
||||
encoder_out=None,
|
||||
early_exit=None,
|
||||
layers=None,
|
||||
**unused
|
||||
self, prev_output_tokens, encoder_out=None, early_exit=None, layers=None, **unused
|
||||
):
|
||||
"""
|
||||
Similar to *forward* but only return features.
|
||||
|
||||
Inputs:
|
||||
prev_output_tokens: Tensor(B, T)
|
||||
encoder_out: a dictionary of hidden states and masks
|
||||
@ -574,7 +581,7 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
|
||||
)
|
||||
|
||||
# embed tokens and positions
|
||||
x = self.embed_scale * self.embed_tokens(prev_output_tokens.long())
|
||||
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
|
||||
if self.project_in_dim is not None:
|
||||
x = self.project_in_dim(x)
|
||||
|
||||
@ -594,8 +601,8 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
|
||||
for _, layer in enumerate(layers[: early_exit]):
|
||||
x, attn = layer(
|
||||
x,
|
||||
encoder_out[0] if encoder_out is not None else None,
|
||||
encoder_out[1] if encoder_out is not None else None,
|
||||
encoder_out.encoder_out if encoder_out is not None else None,
|
||||
encoder_out.encoder_padding_mask if encoder_out is not None else None,
|
||||
self_attn_mask=None,
|
||||
self_attn_padding_mask=decoder_padding_mask,
|
||||
)
|
||||
@ -610,38 +617,26 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
|
||||
if self.project_out_dim is not None:
|
||||
x = self.project_out_dim(x)
|
||||
|
||||
return x, attn, inner_states
|
||||
return x, {"attn": attn, "inner_states": inner_states}
|
||||
|
||||
def forward_mask_ins(self, prev_output_tokens, encoder_out=None, **unused):
|
||||
features, attn, _ = self.extract_features(
|
||||
prev_output_tokens,
|
||||
encoder_out=encoder_out,
|
||||
early_exit=self.early_exit[1],
|
||||
layers=self.layers_msk,
|
||||
**unused
|
||||
features, extra = self.extract_features(
|
||||
prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[1], layers=self.layers_msk, **unused
|
||||
)
|
||||
features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
|
||||
return F.linear(features_cat, self.embed_mask_ins.weight), attn
|
||||
return F.linear(features_cat, self.embed_mask_ins.weight), extra['attn']
|
||||
|
||||
def forward_word_ins(self, prev_output_tokens, encoder_out=None, **unused):
|
||||
features, attn, _ = self.extract_features(
|
||||
prev_output_tokens,
|
||||
encoder_out=encoder_out,
|
||||
early_exit=self.early_exit[2],
|
||||
layers=self.layers,
|
||||
**unused
|
||||
features, extra = self.extract_features(
|
||||
prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[2], layers=self.layers, **unused
|
||||
)
|
||||
return self.output_layer(features), attn
|
||||
return self.output_layer(features), extra['attn']
|
||||
|
||||
def forward_word_del(self, prev_output_tokens, encoder_out=None, **unused):
|
||||
features, attn, _ = self.extract_features(
|
||||
prev_output_tokens,
|
||||
encoder_out=encoder_out,
|
||||
early_exit=self.early_exit[0],
|
||||
layers=self.layers_del,
|
||||
**unused
|
||||
features, extra = self.extract_features(
|
||||
prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[0], layers=self.layers_del, **unused
|
||||
)
|
||||
return F.linear(features, self.embed_word_del.weight), attn
|
||||
return F.linear(features, self.embed_word_del.weight), extra['attn']
|
||||
|
||||
|
||||
@register_model_architecture("levenshtein_transformer", "levenshtein_transformer")
|
||||
@ -671,7 +666,7 @@ def base_architecture(args):
|
||||
args.share_decoder_input_output_embed = getattr(
|
||||
args, "share_decoder_input_output_embed", False
|
||||
)
|
||||
args.share_all_embeddings = getattr(args, "share_all_embeddings", True)
|
||||
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
|
||||
args.no_token_positional_embeddings = getattr(
|
||||
args, "no_token_positional_embeddings", False
|
||||
)
|
||||
@ -686,6 +681,8 @@ def base_architecture(args):
|
||||
args.early_exit = getattr(args, "early_exit", "6,6,6")
|
||||
args.no_share_discriminator = getattr(args, "no_share_discriminator", False)
|
||||
args.no_share_maskpredictor = getattr(args, "no_share_maskpredictor", False)
|
||||
args.share_discriminator_maskpredictor = getattr(args, "share_discriminator_maskpredictor", False)
|
||||
args.no_share_last_layer = getattr(args, "no_share_last_layer", False)
|
||||
|
||||
|
||||
@register_model_architecture(
|
||||
|
@ -3,7 +3,7 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Dict, List
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -33,39 +33,6 @@ def script_skip_tensor(x: Tensor, mask):
|
||||
return res
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def script_skip_tensor_dict(x: Dict[str, Tensor], mask):
|
||||
outputs = {}
|
||||
for s, t in x.items():
|
||||
outputs[s] = t[mask] if t.size(0) == mask.size(0) else t[:, mask]
|
||||
return outputs
|
||||
|
||||
|
||||
def skip_tensors(x, mask):
|
||||
"""
|
||||
Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors.
|
||||
"""
|
||||
if isinstance(x, int):
|
||||
return x
|
||||
|
||||
if x is None:
|
||||
return None
|
||||
|
||||
if isinstance(x, torch.Tensor):
|
||||
if x.size(0) == mask.size(0):
|
||||
return x[mask]
|
||||
elif x.size(1) == mask.size(0):
|
||||
return x[:, mask]
|
||||
|
||||
if isinstance(x, list):
|
||||
return [skip_tensors(x_i, mask) for x_i in x]
|
||||
|
||||
if isinstance(x, dict):
|
||||
return {k: skip_tensors(v, mask) for k, v in x.items()}
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):
|
||||
"""
|
||||
@ -88,12 +55,17 @@ def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def fill_tensors(x, mask, y, padding_idx: int):
|
||||
def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor:
|
||||
return x if x is not None else y
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def fill_tensors(x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int) -> Optional[Tensor]:
|
||||
"""
|
||||
Filling tensor x with y at masked positions (dim=0).
|
||||
"""
|
||||
if x is None or x.size()[0] == 0:
|
||||
return torch.empty([0])
|
||||
if x is None or x.size()[0] == 0 or y is None:
|
||||
return x
|
||||
assert x.dim() == y.dim() and mask.size(0) == x.size(0)
|
||||
assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
|
||||
|
||||
@ -116,88 +88,3 @@ def fill_tensors(x, mask, y, padding_idx: int):
|
||||
else:
|
||||
x[mask] = y
|
||||
return x
|
||||
|
||||
|
||||
def _apply_ins_masks(
|
||||
in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
|
||||
):
|
||||
|
||||
in_masks = in_tokens.ne(padding_idx)
|
||||
in_lengths = in_masks.sum(1)
|
||||
|
||||
# HACK: hacky way to shift all the paddings to eos first.
|
||||
in_tokens.masked_fill_(~in_masks, eos_idx)
|
||||
mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
|
||||
|
||||
out_lengths = in_lengths + mask_ins_pred.sum(1)
|
||||
out_max_len = out_lengths.max()
|
||||
out_masks = (
|
||||
torch.arange(out_max_len, device=out_lengths.device)[None, :]
|
||||
< out_lengths[:, None]
|
||||
)
|
||||
|
||||
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
|
||||
out_tokens = (
|
||||
in_tokens.new_zeros(in_tokens.size(0), out_max_len)
|
||||
.fill_(padding_idx)
|
||||
.masked_fill_(out_masks, unk_idx)
|
||||
)
|
||||
out_tokens[:, 0] = in_tokens[:, 0]
|
||||
out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
|
||||
|
||||
out_scores = None
|
||||
if in_scores is not None:
|
||||
in_scores.masked_fill_(~in_masks, 0)
|
||||
out_scores = in_scores.new_zeros(*out_tokens.size())
|
||||
out_scores[:, 0] = in_scores[:, 0]
|
||||
out_scores.scatter_(1, reordering, in_scores[:, 1:])
|
||||
|
||||
return out_tokens, out_scores
|
||||
|
||||
|
||||
def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx):
|
||||
word_ins_masks = in_tokens.eq(unk_idx)
|
||||
out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
|
||||
|
||||
if in_scores is not None:
|
||||
out_scores = in_scores.masked_scatter(
|
||||
word_ins_masks, word_ins_scores[word_ins_masks]
|
||||
)
|
||||
else:
|
||||
out_scores = None
|
||||
|
||||
return out_tokens, out_scores
|
||||
|
||||
|
||||
def _apply_del_words(
|
||||
in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
|
||||
):
|
||||
# apply deletion to a tensor
|
||||
in_masks = in_tokens.ne(padding_idx)
|
||||
bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
|
||||
|
||||
max_len = in_tokens.size(1)
|
||||
word_del_pred.masked_fill_(~in_masks, 1)
|
||||
word_del_pred.masked_fill_(bos_eos_masks, 0)
|
||||
|
||||
reordering = (
|
||||
torch.arange(max_len, device=in_tokens.device)[None, :]
|
||||
.expand_as(in_tokens)
|
||||
.contiguous()
|
||||
.masked_fill_(word_del_pred, max_len)
|
||||
.sort(1)[1]
|
||||
)
|
||||
|
||||
out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
|
||||
|
||||
out_scores = None
|
||||
if in_scores is not None:
|
||||
out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
|
||||
|
||||
out_attn = None
|
||||
if in_attn is not None:
|
||||
_mask = word_del_pred[:, :, None].expand_as(in_attn)
|
||||
_reordering = reordering[:, :, None].expand_as(in_attn)
|
||||
out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
|
||||
|
||||
return out_tokens, out_scores, out_attn
|
||||
|
@ -3,11 +3,18 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from fairseq.models.model_utils import fill_tensors as _fill, skip_tensors as _skip
|
||||
from fairseq.models.model_utils import _apply_del_words, _apply_ins_masks, _apply_ins_words
|
||||
|
||||
from fairseq.models.levenshtein_transformer import (
|
||||
_skip,
|
||||
_apply_ins_masks,
|
||||
_apply_ins_words,
|
||||
_apply_del_words,
|
||||
)
|
||||
from fairseq.models.model_utils import fill_tensors as _fill
|
||||
|
||||
|
||||
class BasicEnsembleModel(torch.nn.Module):
|
||||
|
@ -5,7 +5,9 @@
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fairseq import utils
|
||||
from fairseq.iterative_refinement_generator import DecoderOut
|
||||
from fairseq.models import register_model, register_model_architecture
|
||||
from fairseq.models.transformer import (
|
||||
Embedding,
|
||||
@ -45,8 +47,8 @@ def _uniform_assignment(src_lens, trg_lens):
|
||||
|
||||
@register_model("nonautoregressive_transformer")
|
||||
class NATransformerModel(TransformerModel):
|
||||
def __init__(self, encoder, decoder):
|
||||
super().__init__(encoder, decoder)
|
||||
def __init__(self, args, encoder, decoder):
|
||||
super().__init__(args, encoder, decoder)
|
||||
self.tgt_dict = decoder.dictionary
|
||||
self.bos = decoder.dictionary.bos()
|
||||
self.eos = decoder.dictionary.eos()
|
||||
@ -112,9 +114,9 @@ class NATransformerModel(TransformerModel):
|
||||
return self.encoder(*encoder_inputs)
|
||||
|
||||
def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
|
||||
step = decoder_out["step"]
|
||||
output_tokens = decoder_out["output_tokens"]
|
||||
output_scores = decoder_out["output_scores"]
|
||||
step = decoder_out.step
|
||||
output_tokens = decoder_out.output_tokens
|
||||
output_scores = decoder_out.output_scores
|
||||
|
||||
# execute the decoder
|
||||
output_masks = output_tokens.ne(self.pad)
|
||||
@ -127,12 +129,16 @@ class NATransformerModel(TransformerModel):
|
||||
output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
|
||||
output_scores.masked_scatter_(output_masks, _scores[output_masks])
|
||||
|
||||
return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None}
|
||||
return decoder_out._replace(
|
||||
output_tokens=output_tokens,
|
||||
output_scores=output_scores,
|
||||
attn=None,
|
||||
)
|
||||
|
||||
def initialize_output_tokens(self, encoder_out, src_tokens):
|
||||
# length prediction
|
||||
_, length_tgt = self.decoder.forward_length_prediction(encoder_out)
|
||||
max_length = length_tgt.max()
|
||||
max_length = length_tgt.clamp_(min=2).max()
|
||||
idx_length = utils.new_arange(src_tokens, max_length)
|
||||
|
||||
initial_output_tokens = src_tokens.new_zeros(
|
||||
@ -146,13 +152,15 @@ class NATransformerModel(TransformerModel):
|
||||
|
||||
initial_output_scores = initial_output_tokens.new_zeros(
|
||||
*initial_output_tokens.size()
|
||||
).type_as(encoder_out["encoder_out"])
|
||||
).type_as(encoder_out.encoder_out)
|
||||
|
||||
return {
|
||||
"output_tokens": initial_output_tokens,
|
||||
"output_scores": initial_output_scores,
|
||||
"attn": None
|
||||
}
|
||||
return DecoderOut(
|
||||
output_tokens=initial_output_tokens,
|
||||
output_scores=initial_output_scores,
|
||||
attn=None,
|
||||
step=0,
|
||||
max_step=0,
|
||||
)
|
||||
|
||||
|
||||
class NATransformerDecoder(TransformerDecoder):
|
||||
@ -220,8 +228,8 @@ class NATransformerDecoder(TransformerDecoder):
|
||||
"""
|
||||
# embedding
|
||||
if embedding_copy:
|
||||
src_embd = encoder_out["encoder_embedding"]
|
||||
src_mask = encoder_out["encoder_padding_mask"]
|
||||
src_embd = encoder_out.encoder_embedding
|
||||
src_mask = encoder_out.encoder_padding_mask
|
||||
src_mask = (
|
||||
~src_mask
|
||||
if src_mask is not None
|
||||
@ -253,10 +261,8 @@ class NATransformerDecoder(TransformerDecoder):
|
||||
|
||||
x, attn = layer(
|
||||
x,
|
||||
encoder_out["encoder_out"] if encoder_out is not None else None,
|
||||
encoder_out["encoder_padding_mask"]
|
||||
if encoder_out is not None
|
||||
else None,
|
||||
encoder_out.encoder_out if encoder_out is not None else None,
|
||||
encoder_out.encoder_padding_mask if encoder_out is not None else None,
|
||||
self_attn_mask=None,
|
||||
self_attn_padding_mask=decoder_padding_mask,
|
||||
)
|
||||
@ -311,8 +317,8 @@ class NATransformerDecoder(TransformerDecoder):
|
||||
return copied_embedding
|
||||
|
||||
def forward_length_prediction(self, encoder_out, tgt_tokens=None):
|
||||
enc_feats = encoder_out["encoder_out"] # T x B x C
|
||||
src_masks = encoder_out["encoder_padding_mask"] # B x T or None
|
||||
enc_feats = encoder_out.encoder_out # T x B x C
|
||||
src_masks = encoder_out.encoder_padding_mask # B x T or None
|
||||
|
||||
if self.pred_length_offset:
|
||||
if src_masks is None:
|
||||
|
@ -348,7 +348,7 @@ class RobertaEncoder(FairseqDecoder):
|
||||
- a dictionary of additional data, where 'inner_states'
|
||||
is a list of hidden states.
|
||||
"""
|
||||
x, extra = self.extract_features(src_tokens, return_all_hiddens)
|
||||
x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens)
|
||||
if not features_only:
|
||||
x = self.output_layer(x, masked_tokens=masked_tokens)
|
||||
return x, extra
|
||||
|
@ -1,625 +0,0 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fairseq import options, utils
|
||||
from fairseq.models import (
|
||||
FairseqEncoder,
|
||||
FairseqIncrementalDecoder,
|
||||
FairseqEncoderDecoderModel,
|
||||
register_model,
|
||||
register_model_architecture,
|
||||
)
|
||||
from fairseq.models.transformer import Embedding, Linear, base_architecture
|
||||
from fairseq.modules import (
|
||||
AdaptiveSoftmax,
|
||||
LayerNorm,
|
||||
PositionalEmbedding,
|
||||
SinusoidalPositionalEmbedding,
|
||||
TransformerDecoderLayer,
|
||||
TransformerEncoderLayer,
|
||||
)
|
||||
|
||||
DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
||||
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
||||
|
||||
|
||||
@register_model('tracing_transformer')
|
||||
class TracingTransformerModel(FairseqEncoderDecoderModel):
|
||||
"""
|
||||
Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
|
||||
<https://arxiv.org/abs/1706.03762>`_.
|
||||
|
||||
Args:
|
||||
encoder (TransformerEncoder): the encoder
|
||||
decoder (TransformerDecoder): the decoder
|
||||
|
||||
The Transformer model provides the following named architectures and
|
||||
command-line arguments:
|
||||
|
||||
.. argparse::
|
||||
:ref: fairseq.models.transformer_parser
|
||||
:prog:
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def hub_models(cls):
|
||||
# fmt: off
|
||||
return {
|
||||
'transformer.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2',
|
||||
'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
|
||||
'transformer.wmt18.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz',
|
||||
'transformer.wmt19.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz',
|
||||
'transformer.wmt19.en-ru': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz',
|
||||
'transformer.wmt19.de-en': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz',
|
||||
'transformer.wmt19.ru-en': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz',
|
||||
'transformer.wmt19.en-de.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz',
|
||||
'transformer.wmt19.en-ru.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz',
|
||||
'transformer.wmt19.de-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz',
|
||||
'transformer.wmt19.ru-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz',
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
def __init__(self, encoder, decoder):
|
||||
super().__init__(encoder, decoder)
|
||||
self.supports_align_args = True
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add model-specific arguments to the parser."""
|
||||
# fmt: off
|
||||
parser.add_argument('--activation-fn',
|
||||
choices=utils.get_available_activation_fns(),
|
||||
help='activation function to use')
|
||||
parser.add_argument('--dropout', type=float, metavar='D',
|
||||
help='dropout probability')
|
||||
parser.add_argument('--attention-dropout', type=float, metavar='D',
|
||||
help='dropout probability for attention weights')
|
||||
parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
|
||||
help='dropout probability after activation in FFN.')
|
||||
parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
|
||||
help='path to pre-trained encoder embedding')
|
||||
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
|
||||
help='encoder embedding dimension')
|
||||
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
|
||||
help='encoder embedding dimension for FFN')
|
||||
parser.add_argument('--encoder-layers', type=int, metavar='N',
|
||||
help='num encoder layers')
|
||||
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
|
||||
help='num encoder attention heads')
|
||||
parser.add_argument('--encoder-normalize-before', action='store_true',
|
||||
help='apply layernorm before each encoder block')
|
||||
parser.add_argument('--encoder-learned-pos', action='store_true',
|
||||
help='use learned positional embeddings in the encoder')
|
||||
parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
|
||||
help='path to pre-trained decoder embedding')
|
||||
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
|
||||
help='decoder embedding dimension')
|
||||
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
|
||||
help='decoder embedding dimension for FFN')
|
||||
parser.add_argument('--decoder-layers', type=int, metavar='N',
|
||||
help='num decoder layers')
|
||||
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
|
||||
help='num decoder attention heads')
|
||||
parser.add_argument('--decoder-learned-pos', action='store_true',
|
||||
help='use learned positional embeddings in the decoder')
|
||||
parser.add_argument('--decoder-normalize-before', action='store_true',
|
||||
help='apply layernorm before each decoder block')
|
||||
parser.add_argument('--share-decoder-input-output-embed', action='store_true',
|
||||
help='share decoder input and output embeddings')
|
||||
parser.add_argument('--share-all-embeddings', action='store_true',
|
||||
help='share encoder, decoder and output embeddings'
|
||||
' (requires shared dictionary and embed dim)')
|
||||
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
|
||||
help='if set, disables positional embeddings (outside self attention)')
|
||||
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
|
||||
help='comma separated list of adaptive softmax cutoff points. '
|
||||
'Must be used with adaptive_loss criterion'),
|
||||
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
|
||||
help='sets adaptive softmax dropout for the tail projections')
|
||||
# args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
|
||||
parser.add_argument('--no-cross-attention', default=False, action='store_true',
|
||||
help='do not perform cross-attention')
|
||||
parser.add_argument('--cross-self-attention', default=False, action='store_true',
|
||||
help='perform cross+self-attention')
|
||||
parser.add_argument('--layer-wise-attention', default=False, action='store_true',
|
||||
help='perform layer-wise attention (cross-attention or cross+self-attention)')
|
||||
# fmt: on
|
||||
|
||||
    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Instantiates the embedding tables (optionally shared between encoder
        and decoder), then delegates to ``cls.build_encoder`` /
        ``cls.build_decoder``.

        Args:
            args (argparse.Namespace): parsed command-line arguments
            task: fairseq task providing ``source_dictionary`` and
                ``target_dictionary``
        """

        # make sure all arguments are present in older models
        base_architecture(args)

        # older checkpoints may lack these attributes; fall back to defaults
        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None):
            # Build an Embedding sized to the dictionary; the pad symbol's
            # index is passed so its row stays zero.
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        if args.share_all_embeddings:
            # Sharing one table across encoder/decoder requires a joint
            # vocabulary, matching dims, and (if any) a single embed path.
            if src_dict != tgt_dict:
                raise ValueError('--share-all-embeddings requires a joined dictionary')
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
            if args.decoder_embed_path and (
                    args.decoder_embed_path != args.encoder_embed_path):
                raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            # sharing all embeddings implies tying decoder input/output embeddings
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return cls(encoder, decoder)
|
||||
|
||||
@classmethod
|
||||
def build_encoder(cls, args, src_dict, embed_tokens):
|
||||
return TracingTransformerEncoder(args, src_dict, embed_tokens)
|
||||
|
||||
@classmethod
|
||||
def build_decoder(cls, args, tgt_dict, embed_tokens):
|
||||
return TracingTransformerDecoder(
|
||||
args,
|
||||
tgt_dict,
|
||||
embed_tokens,
|
||||
no_encoder_attn=getattr(args, 'no_cross_attention', False),
|
||||
)
|
||||
|
||||
|
||||
class TracingTransformerEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        # version buffer lets upgrade_state_dict_named() detect old checkpoints
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, embed_dim, self.padding_idx,
            learned=args.encoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        # when enabled, forward() always collects per-layer hidden states so a
        # decoder can attend to each encoder layer separately
        self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(args)
            for i in range(args.encoder_layers)
        ])

        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward_embedding(self, src_tokens):
        """Embed tokens (scaled) plus positions; returns (x, raw_embedding)."""
        # embed tokens and positions
        # FIX: assign x here as well -- previously x was only set inside the
        # `if`, so disabling positional embeddings raised UnboundLocalError.
        x = embed = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x, embed

    def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            tuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only present if *return_all_hiddens* is True.
        """
        if self.layer_wise_attention:
            # layer-wise attention needs every layer's output downstream
            return_all_hiddens = True

        x, encoder_embedding = self.forward_embedding(src_tokens)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)

        encoder_states = [] if return_all_hiddens else None

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)
            if return_all_hiddens:
                encoder_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)
            if return_all_hiddens:
                # keep the recorded final state consistent with the norm
                encoder_states[-1] = x
        if encoder_states is not None:
            return x, encoder_padding_mask, encoder_embedding, encoder_states
        else:
            return x, encoder_padding_mask, encoder_embedding

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order* (e.g. for beam search).

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        # 0: encoder_out (T x B x C -> batch is dim 1)
        # 1: encoder_padding_mask (B x T -> batch is dim 0)
        # 2: encoder_states
        if encoder_out[0] is not None:
            encoder_out[0] = \
                encoder_out[0].index_select(1, new_order)
        if encoder_out[1] is not None:
            encoder_out[1] = \
                encoder_out[1].index_select(0, new_order)
        if len(encoder_out) == 3 and encoder_out[2] is not None:
            for idx, state in enumerate(encoder_out[2]):
                encoder_out[2][idx] = state.index_select(1, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions())

    def buffered_future_mask(self, tensor):
        """Return a cached causal (-inf above diagonal) mask of size dim x dim.

        FIX: rebuild the cache under one merged condition, matching the
        decoder's version in this file, instead of the old second step that
        called in-place ``resize_`` on the cached mask.
        """
        dim = tensor.size(0)
        if (
            not hasattr(self, '_future_mask')
            or self._future_mask is None
            or self._future_mask.device != tensor.device
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # sinusoidal embeddings are recomputed; drop stored weights
            weights_key = '{}.embed_positions.weights'.format(name)
            if weights_key in state_dict:
                del state_dict[weights_key]
            state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
        for i in range(len(self.layers)):
            # update layer norms
            self.layers[i].upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i))

        version_key = '{}.version'.format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
|
||||
|
||||
|
||||
class TracingTransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(dictionary)
        # version buffer lets upgrade_state_dict_named() detect old checkpoints
        self.register_buffer('version', torch.Tensor([3]))

        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        self.output_embed_dim = args.decoder_output_dim

        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim

        # projection used only when the token embedding dim differs from the
        # model dim
        self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None

        self.embed_positions = PositionalEmbedding(
            args.max_target_positions, embed_dim, self.padding_idx,
            learned=args.decoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.cross_self_attention = getattr(args, 'cross_self_attention', False)
        self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(args.decoder_layers)
        ])

        self.adaptive_softmax = None

        self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False) \
            if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None

        if args.adaptive_softmax_cutoff is not None:
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            # separate output projection matrix when not tying with the input
            # embeddings
            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), self.output_embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)

        if args.decoder_normalize_before and not getattr(args, 'no_decoder_final_norm', False):
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(
        self,
        prev_output_tokens,
        encoder_out=None,
        incremental_state=None,
        features_only=False,
        # FIX: removed trailing comma after **extra_args -- a trailing comma
        # after a **kwargs parameter is a SyntaxError in a function definition
        **extra_args
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        x, extra = self.extract_features(
            prev_output_tokens, encoder_out, incremental_state, **extra_args
        )
        if not features_only:
            x = self.output_layer(x)
        return x, extra

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out=None,
        incremental_state=None,
        full_context_alignment=False,
        alignment_layer=None,
        alignment_heads=None,
        # FIX: removed trailing comma after **unused (SyntaxError in a def)
        **unused
    ):
        """
        Similar to *forward* but only return features.

        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer (int, optional): return mean alignment over
                heads at this layer (default: last layer).
            alignment_heads (int, optional): only average alignment over
                this many heads (default: all heads).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        if alignment_layer is None:
            alignment_layer = len(self.layers) - 1

        # embed positions
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=incremental_state,
        ) if self.embed_positions is not None else None

        if incremental_state is not None:
            # incremental decoding: only the newest token is fed through
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
        # skip the mask entirely when there is no padding (cross-self-attention
        # always needs it because source tokens are concatenated in)
        if not self_attn_padding_mask.any() and not self.cross_self_attention:
            self_attn_padding_mask = None

        # decoder layers
        attn = None
        inner_states = [x]
        for idx, layer in enumerate(self.layers):
            encoder_state = None
            if encoder_out is not None:
                # encoder_out is the tuple produced by
                # TracingTransformerEncoder.forward:
                # (out, padding_mask, embedding[, states])
                if self.layer_wise_attention:
                    encoder_state = encoder_out[3][idx]
                else:
                    encoder_state = encoder_out[0]

            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn = layer(
                x,
                encoder_state
                if encoder_state is not None else None,
                encoder_out[1]
                if encoder_out is not None else None,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=(idx == alignment_layer),
                need_head_weights=(idx == alignment_layer),
            )

            inner_states.append(x)
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float()

        if attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]

            # average probabilities over heads
            attn = attn.mean(dim=0)

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {'attn': attn, 'inner_states': inner_states}

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                return F.linear(features, self.embed_tokens.weight)
            else:
                return F.linear(features, self.embed_out)
        else:
            # adaptive softmax consumes raw features downstream
            return features

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions())

    def buffered_future_mask(self, tensor):
        """Return a cached causal (-inf above diagonal) mask of size dim x dim."""
        dim = tensor.size(0)
        if (
            not hasattr(self, '_future_mask')
            or self._future_mask is None
            or self._future_mask.device != tensor.device
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # sinusoidal embeddings are recomputed; drop stored weights
            weights_key = '{}.embed_positions.weights'.format(name)
            if weights_key in state_dict:
                del state_dict[weights_key]
            state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)

        for i in range(len(self.layers)):
            # update layer norms: very old checkpoints stored them under
            # numbered `layer_norms.{0,1,2}` keys
            layer_norm_map = {
                '0': 'self_attn_layer_norm',
                '1': 'encoder_attn_layer_norm',
                '2': 'final_layer_norm'
            }
            for old, new in layer_norm_map.items():
                for m in ('weight', 'bias'):
                    k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m)
                    if k in state_dict:
                        state_dict['{}.layers.{}.{}.{}'.format(name, i, new, m)] = state_dict[k]
                        del state_dict[k]

        version_key = '{}.version'.format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict
|
@ -3,6 +3,7 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import namedtuple
|
||||
import math
|
||||
|
||||
import torch
|
||||
@ -279,6 +280,14 @@ class TransformerAlignModel(TransformerModel):
|
||||
return decoder_out
|
||||
|
||||
|
||||
EncoderOut = namedtuple('TransformerEncoderOut', [
|
||||
'encoder_out', # T x B x C
|
||||
'encoder_padding_mask', # B x T
|
||||
'encoder_embedding', # B x T x C
|
||||
'encoder_states', # List[T x B x C]
|
||||
])
|
||||
|
||||
|
||||
class TransformerEncoder(FairseqEncoder):
|
||||
"""
|
||||
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
|
||||
@ -348,11 +357,13 @@ class TransformerEncoder(FairseqEncoder):
|
||||
intermediate hidden states (default: False).
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
namedtuple:
|
||||
- **encoder_out** (Tensor): the last encoder layer's output of
|
||||
shape `(src_len, batch, embed_dim)`
|
||||
- **encoder_padding_mask** (ByteTensor): the positions of
|
||||
padding elements of shape `(batch, src_len)`
|
||||
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
|
||||
of shape `(batch, src_len, embed_dim)`
|
||||
- **encoder_states** (List[Tensor]): all intermediate
|
||||
hidden states of shape `(src_len, batch, embed_dim)`.
|
||||
Only populated if *return_all_hiddens* is True.
|
||||
@ -386,12 +397,12 @@ class TransformerEncoder(FairseqEncoder):
|
||||
if return_all_hiddens:
|
||||
encoder_states[-1] = x
|
||||
|
||||
return {
|
||||
'encoder_out': x, # T x B x C
|
||||
'encoder_padding_mask': encoder_padding_mask, # B x T
|
||||
'encoder_embedding': encoder_embedding, # B x T x C
|
||||
'encoder_states': encoder_states, # List[T x B x C]
|
||||
}
|
||||
return EncoderOut(
|
||||
encoder_out=x, # T x B x C
|
||||
encoder_padding_mask=encoder_padding_mask, # B x T
|
||||
encoder_embedding=encoder_embedding, # B x T x C
|
||||
encoder_states=encoder_states, # List[T x B x C]
|
||||
)
|
||||
|
||||
def reorder_encoder_out(self, encoder_out, new_order):
|
||||
"""
|
||||
@ -404,15 +415,21 @@ class TransformerEncoder(FairseqEncoder):
|
||||
Returns:
|
||||
*encoder_out* rearranged according to *new_order*
|
||||
"""
|
||||
if encoder_out['encoder_out'] is not None:
|
||||
encoder_out['encoder_out'] = \
|
||||
encoder_out['encoder_out'].index_select(1, new_order)
|
||||
if encoder_out['encoder_padding_mask'] is not None:
|
||||
encoder_out['encoder_padding_mask'] = \
|
||||
encoder_out['encoder_padding_mask'].index_select(0, new_order)
|
||||
if encoder_out.get('encoder_states', None) is not None:
|
||||
for idx, state in enumerate(encoder_out['encoder_states']):
|
||||
encoder_out['encoder_states'][idx] = state.index_select(1, new_order)
|
||||
if encoder_out.encoder_out is not None:
|
||||
encoder_out = encoder_out._replace(
|
||||
encoder_out=encoder_out.encoder_out.index_select(1, new_order)
|
||||
)
|
||||
if encoder_out.encoder_padding_mask is not None:
|
||||
encoder_out = encoder_out._replace(
|
||||
encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(0, new_order)
|
||||
)
|
||||
if encoder_out.encoder_embedding is not None:
|
||||
encoder_out = encoder_out._replace(
|
||||
encoder_embedding=encoder_out.encoder_embedding.index_select(0, new_order)
|
||||
)
|
||||
if encoder_out.encoder_states is not None:
|
||||
for idx, state in enumerate(encoder_out.encoder_states):
|
||||
encoder_out.encoder_states[idx] = state.index_select(1, new_order)
|
||||
return encoder_out
|
||||
|
||||
def max_positions(self):
|
||||
@ -532,13 +549,13 @@ class TransformerDecoder(FairseqIncrementalDecoder):
|
||||
encoder_out=None,
|
||||
incremental_state=None,
|
||||
features_only=False,
|
||||
**extra_args,
|
||||
**extra_args
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
prev_output_tokens (LongTensor): previous decoder outputs of shape
|
||||
`(batch, tgt_len)`, for teacher forcing
|
||||
encoder_out (Tensor, optional): output from the encoder, used for
|
||||
encoder_out (optional): output from the encoder, used for
|
||||
encoder-side attention
|
||||
incremental_state (dict): dictionary used for storing state during
|
||||
:ref:`Incremental decoding`
|
||||
@ -551,7 +568,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
|
||||
- a dictionary with any model-specific outputs
|
||||
"""
|
||||
x, extra = self.extract_features(
|
||||
prev_output_tokens, encoder_out, incremental_state, **extra_args,
|
||||
prev_output_tokens,
|
||||
encoder_out=encoder_out,
|
||||
incremental_state=incremental_state,
|
||||
**extra_args
|
||||
)
|
||||
if not features_only:
|
||||
x = self.output_layer(x)
|
||||
@ -628,9 +648,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
|
||||
encoder_state = None
|
||||
if encoder_out is not None:
|
||||
if self.layer_wise_attention:
|
||||
encoder_state = encoder_out['encoder_states'][idx]
|
||||
encoder_state = encoder_out.encoder_states[idx]
|
||||
else:
|
||||
encoder_state = encoder_out['encoder_out']
|
||||
encoder_state = encoder_out.encoder_out
|
||||
|
||||
if incremental_state is None and not full_context_alignment:
|
||||
self_attn_mask = self.buffered_future_mask(x)
|
||||
@ -643,7 +663,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
|
||||
x, layer_attn = layer(
|
||||
x,
|
||||
encoder_state,
|
||||
encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
|
||||
encoder_out.encoder_padding_mask if encoder_out is not None else None,
|
||||
incremental_state,
|
||||
self_attn_mask=self_attn_mask,
|
||||
self_attn_padding_mask=self_attn_padding_mask,
|
||||
|
@ -26,16 +26,15 @@ class MeanPoolGatingNetwork(torch.nn.Module):
|
||||
|
||||
def forward(self, encoder_out):
|
||||
if not (
|
||||
isinstance(encoder_out, dict)
|
||||
and 'encoder_out' in encoder_out
|
||||
and 'encoder_padding_mask' in encoder_out
|
||||
and encoder_out['encoder_out'].size(2) == self.embed_dim
|
||||
hasattr(encoder_out, 'encoder_out')
|
||||
and hasattr(encoder_out, 'encoder_padding_mask')
|
||||
and encoder_out.encoder_out.size(2) == self.embed_dim
|
||||
):
|
||||
raise ValueError('Unexpected format for encoder_out')
|
||||
|
||||
# mean pooling over time
|
||||
encoder_padding_mask = encoder_out['encoder_padding_mask'] # B x T
|
||||
encoder_out = encoder_out['encoder_out'].transpose(0, 1) # B x T x C
|
||||
encoder_padding_mask = encoder_out.encoder_padding_mask # B x T
|
||||
encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C
|
||||
if encoder_padding_mask is not None:
|
||||
encoder_out = encoder_out.clone() # required because of transpose above
|
||||
encoder_out[encoder_padding_mask] = 0
|
||||
|
@ -197,51 +197,90 @@ class TestTranslation(unittest.TestCase):
|
||||
])
|
||||
generate_main(data_dir)
|
||||
|
||||
    def test_cmlm_transformer(self):
        # End-to-end smoke test for the conditional masked LM (CMLM) NAT
        # model: preprocess dummy data, train briefly, then decode with
        # iterative refinement under the translation_lev task.
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_cmlm_transformer') as data_dir:
                create_dummy_data(data_dir)
                # CMLM requires a joined source/target dictionary
                preprocess_translation_data(data_dir, ['--joined-dictionary'])
                train_translation_model(data_dir, 'cmlm_transformer', [
                    '--apply-bert-init',
                    '--criterion', 'nat_loss',
                    '--noise', 'full_mask',
                    '--pred-length-offset',
                    '--length-loss-factor', '0.1'
                ], task='translation_lev')
                generate_main(data_dir, [
                    '--task', 'translation_lev',
                    '--iter-decode-max-iter', '9',
                    '--iter-decode-eos-penalty', '0',
                    '--print-step',
                ])
|
||||
|
||||
def test_levenshtein_transformer(self):
|
||||
with contextlib.redirect_stdout(StringIO()):
|
||||
with tempfile.TemporaryDirectory('test_levenshtein_transformer') as data_dir:
|
||||
create_dummy_data(data_dir)
|
||||
preprocess_translation_data(data_dir)
|
||||
preprocess_translation_data(data_dir, ['--joined-dictionary'])
|
||||
train_translation_model(data_dir, 'levenshtein_transformer', [
|
||||
'--apply-bert-init', '--early-exit', '6,6,6',
|
||||
'--criterion', 'nat_loss'
|
||||
], task='translation_lev')
|
||||
generate_main(data_dir)
|
||||
generate_main(data_dir, [
|
||||
'--task', 'translation_lev',
|
||||
'--iter-decode-max-iter', '9',
|
||||
'--iter-decode-eos-penalty', '0',
|
||||
'--print-step',
|
||||
])
|
||||
|
||||
def test_nonautoregressive_transformer(self):
|
||||
with contextlib.redirect_stdout(StringIO()):
|
||||
with tempfile.TemporaryDirectory('test_nonautoregressive_transformer') as data_dir:
|
||||
create_dummy_data(data_dir)
|
||||
preprocess_translation_data(data_dir)
|
||||
preprocess_translation_data(data_dir, ['--joined-dictionary'])
|
||||
train_translation_model(data_dir, 'nonautoregressive_transformer', [
|
||||
'--apply-bert-init', '--src-embedding-copy', '--criterion',
|
||||
'nat_loss', '--noise', 'full_mask', '--pred-length-offset',
|
||||
'--length-loss-factor', '0.1'
|
||||
], task='translation_lev')
|
||||
generate_main(data_dir)
|
||||
generate_main(data_dir, [
|
||||
'--task', 'translation_lev',
|
||||
'--iter-decode-max-iter', '9',
|
||||
'--iter-decode-eos-penalty', '0',
|
||||
'--print-step',
|
||||
])
|
||||
|
||||
def test_iterative_nonautoregressive_transformer(self):
|
||||
with contextlib.redirect_stdout(StringIO()):
|
||||
with tempfile.TemporaryDirectory('test_iterative_nonautoregressive_transformer') as data_dir:
|
||||
create_dummy_data(data_dir)
|
||||
preprocess_translation_data(data_dir)
|
||||
preprocess_translation_data(data_dir, ['--joined-dictionary'])
|
||||
train_translation_model(data_dir, 'iterative_nonautoregressive_transformer', [
|
||||
'--apply-bert-init', '--src-embedding-copy', '--criterion',
|
||||
'nat_loss', '--noise', 'full_mask', '--stochastic-approx',
|
||||
'--dae-ratio', '0.5', '--train-step', '3'
|
||||
], task='translation_lev')
|
||||
generate_main(data_dir)
|
||||
generate_main(data_dir, [
|
||||
'--task', 'translation_lev',
|
||||
'--iter-decode-max-iter', '9',
|
||||
'--iter-decode-eos-penalty', '0',
|
||||
'--print-step',
|
||||
])
|
||||
|
||||
def test_insertion_transformer(self):
|
||||
with contextlib.redirect_stdout(StringIO()):
|
||||
with tempfile.TemporaryDirectory('test_insertion_transformer') as data_dir:
|
||||
create_dummy_data(data_dir)
|
||||
preprocess_translation_data(data_dir)
|
||||
preprocess_translation_data(data_dir, ['--joined-dictionary'])
|
||||
train_translation_model(data_dir, 'insertion_transformer', [
|
||||
'--apply-bert-init', '--criterion', 'nat_loss', '--noise',
|
||||
'random_mask'
|
||||
], task='translation_lev')
|
||||
generate_main(data_dir)
|
||||
generate_main(data_dir, [
|
||||
'--task', 'translation_lev',
|
||||
'--iter-decode-max-iter', '9',
|
||||
'--iter-decode-eos-penalty', '0',
|
||||
'--print-step',
|
||||
])
|
||||
|
||||
def test_mixture_of_experts(self):
|
||||
with contextlib.redirect_stdout(StringIO()):
|
||||
|
Loading…
Reference in New Issue
Block a user