Merge TracingCompliantTransformer and regular Transformer, fix NAT tests

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/899

Differential Revision: D18373060

Pulled By: myleott

fbshipit-source-id: bb5510ec15799a0a10a7c0669e76d8200e1ba479
Myle Ott 2019-11-13 09:10:52 -08:00 committed by Facebook Github Bot
parent 2a9b4ec237
commit 27568a7ebe
14 changed files with 548 additions and 1188 deletions

View File

@ -48,7 +48,7 @@ class LabelSmoothedDualImitationCriterion(FairseqCriterion):
if masks is not None:
outputs, targets = outputs[masks], targets[masks]
if not masks.any():
if masks is not None and not masks.any():
nll_loss = torch.tensor(0)
loss = nll_loss
else:

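The guard above matters because `masks` can legitimately be None: the old code indexed under an `is not None` check but then called `masks.any()` unconditionally, which raises when no mask is given. A minimal, self-contained sketch of the corrected control flow (the wrapper name and loss call are illustrative, not the criterion's actual code):

import torch
import torch.nn.functional as F

def masked_nll(outputs, targets, masks=None):
    # Select only the masked positions when a mask is provided.
    if masks is not None:
        outputs, targets = outputs[masks], targets[masks]
    # The None check short-circuits, so .any() is never called on None.
    if masks is not None and not masks.any():
        return torch.tensor(0.0)  # nothing left to score in this batch
    return F.nll_loss(F.log_softmax(outputs, dim=-1), targets)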
View File

@ -3,11 +3,20 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import namedtuple
import torch
from fairseq import utils
from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
from fairseq.models.model_utils import script_skip_tensor_list, skip_tensors as _skip
from fairseq.models.nonautoregressive_ensembles import EnsembleLevT
DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
'output_tokens',
'output_scores',
'attn',
'step',
'max_step',
])
class IterativeRefinementGenerator(object):
@ -88,6 +97,8 @@ class IterativeRefinementGenerator(object):
@torch.no_grad()
def generate(self, models, sample, prefix_tokens=None):
from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
from fairseq.models.nonautoregressive_ensembles import EnsembleLevT
if len(models) == 1:
# Keep this path for NAT models that do not yet have ensemble wrappers; remove it once they all do.
@ -110,7 +121,7 @@ class IterativeRefinementGenerator(object):
# initialize buffers (very model specific, with length prediction or not)
prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
prev_output_tokens = prev_decoder_out[0].clone()
prev_output_tokens = prev_decoder_out.output_tokens.clone()
finalized = [[] for _ in range(bsz)]
@ -150,8 +161,10 @@ class IterativeRefinementGenerator(object):
"max_ratio": self.max_ratio,
"decoding_format": self.decoding_format,
}
prev_decoder_out[3] = step
prev_decoder_out[4] = self.max_iter + 1
prev_decoder_out = prev_decoder_out._replace(
step=step,
max_step=self.max_iter + 1,
)
decoder_out = model.forward_decoder(
prev_decoder_out, encoder_out, **decoder_options
@ -160,24 +173,26 @@ class IterativeRefinementGenerator(object):
if self.adaptive:
# terminate if there is a loop
terminated, out_tokens, out_scores, out_attn = is_a_loop(
prev_output_tokens, decoder_out[0], decoder_out[1], decoder_out[2]
prev_output_tokens, decoder_out.output_tokens, decoder_out.output_scores, decoder_out.attn
)
decoder_out = decoder_out._replace(
output_tokens=out_tokens,
output_scores=out_scores,
attn=out_attn,
)
decoder_out[0] = out_tokens
decoder_out[1] = out_scores
decoder_out[2] = out_attn
else:
terminated = decoder_out[0].new_zeros(decoder_out[0].size(0)).bool()
terminated = decoder_out.output_tokens.new_zeros(decoder_out.output_tokens.size(0)).bool()
if step == self.max_iter: # reach last iteration, terminate
terminated.fill_(1)
# collect finalized sentences
finalized_idxs = sent_idxs[terminated]
finalized_tokens = decoder_out[0][terminated]
finalized_scores = decoder_out[1][terminated]
finalized_tokens = decoder_out.output_tokens[terminated]
finalized_scores = decoder_out.output_scores[terminated]
finalized_attn = (
None if decoder_out[2] is None else decoder_out[2][terminated]
None if decoder_out.attn is None else decoder_out.attn[terminated]
)
for i in range(finalized_idxs.size(0)):
@ -194,10 +209,15 @@ class IterativeRefinementGenerator(object):
break
# for next step
prev_decoder_out = _skip(decoder_out, ~terminated)
encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
sent_idxs = _skip(sent_idxs, ~terminated)
not_terminated = ~terminated
prev_decoder_out = decoder_out._replace(
output_tokens=decoder_out.output_tokens[not_terminated],
output_scores=decoder_out.output_scores[not_terminated],
attn=decoder_out.attn[not_terminated] if decoder_out.attn is not None else None,
)
encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze())
sent_idxs = sent_idxs[not_terminated]
prev_output_tokens = prev_decoder_out[0].clone()
prev_output_tokens = prev_decoder_out.output_tokens.clone()
return finalized

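The generator changes above replace positional indexing into the decoder state (decoder_out[0], prev_decoder_out[3] = step, ...) with the DecoderOut namedtuple defined at the top of this file, updated immutably via _replace. A small self-contained illustration of the pattern (field values here are placeholders):

from collections import namedtuple

DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
    'output_tokens', 'output_scores', 'attn', 'step', 'max_step',
])

state = DecoderOut(output_tokens=None, output_scores=None, attn=None, step=0, max_step=0)
# _replace returns a new tuple with only the named fields swapped,
# instead of mutating indices 3 and 4 in place as the old list-based code did.
state = state._replace(step=1, max_step=11)
assert state.step == 1 and state.max_step == 11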
View File

@ -10,9 +10,9 @@ Ghazvininejad, Marjan, et al.
arXiv preprint arXiv:1904.09324 (2019).
"""
from fairseq.utils import new_arange
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nonautoregressive_transformer import NATransformerModel
from fairseq.utils import new_arange
def _skeptical_unmasking(output_scores, output_masks, p):
@ -55,11 +55,11 @@ class CMLMNATransformerModel(NATransformerModel):
def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
step = decoder_out["step"]
max_step = decoder_out["max_step"]
step = decoder_out.step
max_step = decoder_out.max_step
output_tokens = decoder_out["output_tokens"]
output_scores = decoder_out["output_scores"]
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
# execute the decoder
output_masks = output_tokens.eq(self.unk)
@ -78,7 +78,11 @@ class CMLMNATransformerModel(NATransformerModel):
output_tokens.masked_fill_(skeptical_mask, self.unk)
output_scores.masked_fill_(skeptical_mask, 0.0)
return {"output_tokens": output_tokens, "output_scores": output_scores}
return decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
)
@register_model_architecture("cmlm_transformer", "cmlm_transformer")

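Between refinement steps the CMLM decoder above re-masks low-confidence positions ("skeptical" unmasking) before the next pass. The body of _skeptical_unmasking is not shown in this excerpt, so the sketch below is only an illustrative stand-in, under the common mask-predict assumption that the re-masked fraction shrinks as step approaches max_step:

import torch

def skeptical_mask(output_scores, token_mask, p):
    # Re-mask roughly the lowest-scoring fraction p of the real (non-pad) tokens.
    num_to_mask = (token_mask.sum(1, keepdim=True).float() * p).long()
    sorted_idx = output_scores.masked_fill(~token_mask, float('inf')).sort(1)[1]
    ranks = sorted_idx.sort(1)[1]          # rank of each position, 0 = lowest score
    return (ranks < num_to_mask) & token_mask

scores = torch.tensor([[0.9, 0.1, 0.5, 0.3]])
mask = torch.ones_like(scores, dtype=torch.bool)
p = 1.0 - (1 + 1) / 10                     # e.g. step=1, max_step=10
print(skeptical_mask(scores, mask, p))     # the three lowest-scoring positions -> True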
View File

@ -6,7 +6,7 @@
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.utils import new_arange
from fairseq.models import register_model, register_model_architecture
from fairseq.models.levenshtein_transformer import (
LevenshteinTransformerDecoder,
@ -14,6 +14,7 @@ from fairseq.models.levenshtein_transformer import (
)
from fairseq.models.transformer import Linear, TransformerModel
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.utils import new_arange
class NegativeDistanceScore(object):
@ -116,8 +117,8 @@ def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, paddi
@register_model("insertion_transformer")
class InsertionTransformerModel(LevenshteinTransformerModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
def __init__(self, args, encoder, decoder):
super().__init__(args, encoder, decoder)
@staticmethod
def add_args(parser):
@ -169,8 +170,8 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
):
output_tokens = decoder_out["output_tokens"]
output_scores = decoder_out["output_scores"]
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
# TODO: decoding for InsertionTransformer
word_ins_out = self.decoder.forward_word_ins(
output_tokens, encoder_out=encoder_out
@ -187,7 +188,11 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
cut_off = output_tokens.ne(self.pad).sum(1).max()
output_tokens = output_tokens[:, :cut_off]
output_scores = output_scores[:, :cut_off]
return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None}
return decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
)
class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
@ -206,7 +211,7 @@ class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
self.label_tau = getattr(args, "label_tau", None)
def forward_word_ins(self, prev_output_tokens, encoder_out=None):
features, _ = self.extract_features(prev_output_tokens, encoder_out=encoder_out)
features = self.extract_features(prev_output_tokens, encoder_out=encoder_out)[0]
features = self.pool_out(
torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
)

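The forward_word_ins change above keeps the insertion transformer's slot construction: each of the T-1 insertion slots (between token i and i+1) is represented by the concatenation of the two neighbouring decoder features before pool_out scores it. A quick shape check of that concatenation on random tensors:

import torch

B, T, C = 2, 5, 8
features = torch.randn(B, T, C)                 # per-token decoder features
# Slot i is represented by [h_i ; h_{i+1}].
slots = torch.cat([features[:, :-1, :], features[:, 1:, :]], dim=2)
assert slots.shape == (B, T - 1, 2 * C)         # one vector per insertion slot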
View File

@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import torch
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nonautoregressive_transformer import NATransformerModel

View File

@ -1,52 +1,109 @@
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from __future__ import absolute_import, division, print_function, unicode_literals
from typing import Optional
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.iterative_refinement_generator import DecoderOut
from fairseq.models import register_model, register_model_architecture
from fairseq.models.tracing_compliant_transformer import (
TracingTransformerDecoder,
TracingTransformerEncoder,
TracingTransformerModel,
TransformerDecoderLayer,
from fairseq.models.transformer import (
Embedding,
TransformerDecoder,
TransformerEncoder,
TransformerModel,
TransformerDecoderLayer
)
from fairseq.models.model_utils import (
fill_tensors as _fill,
script_skip_tensor,
script_skip_tensor_list,
)
from fairseq.models.transformer import Embedding
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from torch import Tensor
from fairseq.utils import new_arange
# -------------- Helper Functions --------------------------------------------------- #
def _skip(x, mask):
"""
Slice a tensor along dim=0 by mask. Supports a tensor or a list/dict of tensors.
"""
if isinstance(x, int):
return x
if x is None:
return None
if isinstance(x, torch.Tensor):
if x.size(0) == mask.size(0):
return x[mask]
elif x.size(1) == mask.size(0):
return x[:, mask]
if isinstance(x, list):
return [_skip(x_i, mask) for x_i in x]
if isinstance(x, dict):
return {k: _skip(v, mask) for k, v in x.items()}
raise NotImplementedError
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
def _skip_encoder_out(encoder, encoder_out, mask):
if not mask.any():
return encoder_out
else:
return encoder.reorder_encoder_out(encoder_out, mask.nonzero().squeeze())
def _fill(x, mask, y, padding_idx):
"""
Fill the masked rows (dim=0) of tensor x with the rows of y.
"""
if x is None:
return y
assert x.dim() == y.dim() and mask.size(0) == x.size(0)
assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
n_selected = mask.sum()
assert n_selected == y.size(0)
if n_selected == x.size(0):
return y
if x.size(1) < y.size(1):
dims = [x.size(0), y.size(1) - x.size(1)]
if x.dim() == 3:
dims.append(x.size(2))
x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
x[mask] = y
elif x.size(1) > y.size(1):
x[mask] = padding_idx
if x.dim() == 2:
x[mask, :y.size(1)] = y
else:
x[mask, :y.size(1), :] = y
else:
x[mask] = y
return x
def load_libnat():
try:
from fairseq import libnat
except ImportError as e:
import sys
sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
raise e
return libnat
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
libnat = load_libnat()
in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
in_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
]
out_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist())
[t for t in s if t != padding_idx]
for i, s in enumerate(out_tokens.tolist())
]
full_labels = libnat.suggested_ed2_path(
@ -71,28 +128,27 @@ def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
]
# transform to tensor
masked_tgt_masks = torch.tensor(masked_tgt_masks, device=out_tokens.device).bool()
masked_tgt_masks = torch.tensor(
masked_tgt_masks, device=out_tokens.device
).bool()
mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
def _get_del_targets(in_tokens, out_tokens, padding_idx):
try:
from fairseq import libnat
except ImportError as e:
import sys
libnat = load_libnat()
sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
raise e
out_seq_len = out_tokens.size(1)
in_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
]
out_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist())
]
with torch.cuda.device_of(in_tokens):
in_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
]
out_tokens_list = [
[t for t in s if t != padding_idx]
for i, s in enumerate(out_tokens.tolist())
]
full_labels = libnat.suggested_ed2_path(
in_tokens_list, out_tokens_list, padding_idx
@ -104,26 +160,23 @@ def _get_del_targets(in_tokens, out_tokens, padding_idx):
]
# transform to tensor
word_del_targets = torch.tensor(word_del_targets)
word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
return word_del_targets
def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
try:
from fairseq import libnat
except ImportError as e:
import sys
libnat = load_libnat()
sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
raise e
in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
in_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
]
out_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist())
]
with torch.cuda.device_of(in_tokens):
in_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
]
out_tokens_list = [
[t for t in s if t != padding_idx]
for i, s in enumerate(out_tokens.tolist())
]
full_labels = libnat.suggested_ed2_path(
in_tokens_list, out_tokens_list, padding_idx
@ -144,15 +197,101 @@ def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
]
# transform to tensor
mask_ins_targets = torch.tensor(mask_ins_targets)
word_del_targets = torch.tensor(word_del_targets)
mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
return word_del_targets, mask_ins_targets
def _apply_ins_masks(
in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
):
in_masks = in_tokens.ne(padding_idx)
in_lengths = in_masks.sum(1)
# HACK: shift all paddings onto eos first.
in_tokens.masked_fill_(~in_masks, eos_idx)
mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
out_lengths = in_lengths + mask_ins_pred.sum(1)
out_max_len = out_lengths.max()
out_masks = (
new_arange(out_lengths, out_max_len)[None, :]
< out_lengths[:, None]
)
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
out_tokens = (
in_tokens.new_zeros(in_tokens.size(0), out_max_len)
.fill_(padding_idx)
.masked_fill_(out_masks, unk_idx)
)
out_tokens[:, 0] = in_tokens[:, 0]
out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
out_scores = None
if in_scores is not None:
in_scores.masked_fill_(~in_masks, 0)
out_scores = in_scores.new_zeros(*out_tokens.size())
out_scores[:, 0] = in_scores[:, 0]
out_scores.scatter_(1, reordering, in_scores[:, 1:])
return out_tokens, out_scores
def _apply_ins_words(
in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx
):
word_ins_masks = in_tokens.eq(unk_idx)
out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
if in_scores is not None:
out_scores = in_scores.masked_scatter(
word_ins_masks, word_ins_scores[word_ins_masks]
)
else:
out_scores = None
return out_tokens, out_scores
def _apply_del_words(
in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
):
# apply deletion to a tensor
in_masks = in_tokens.ne(padding_idx)
bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
max_len = in_tokens.size(1)
word_del_pred.masked_fill_(~in_masks, 1)
word_del_pred.masked_fill_(bos_eos_masks, 0)
reordering = (
new_arange(in_tokens)
.masked_fill_(word_del_pred, max_len)
.sort(1)[1]
)
out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
out_scores = None
if in_scores is not None:
out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
out_attn = None
if in_attn is not None:
_mask = word_del_pred[:, :, None].expand_as(in_attn)
_reordering = reordering[:, :, None].expand_as(in_attn)
out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
return out_tokens, out_scores, out_attn
# ------------------------------------------------------------------------------------- #
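As a sanity check on the word-insertion helper above: on toy tensors it simply scatters the decoder's predictions into the <unk> placeholder slots (the token ids below are made up):

import torch

unk = 3
in_tokens = torch.tensor([[0, 7, unk, unk, 2]])   # <s>, w, <unk>, <unk>, </s>
pred      = torch.tensor([[9, 9, 11, 12, 9]])     # per-position argmax predictions
slots = in_tokens.eq(unk)
out_tokens = in_tokens.masked_scatter(slots, pred[slots])
print(out_tokens)   # tensor([[ 0,  7, 11, 12,  2]])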
@register_model("levenshtein_transformer")
class LevenshteinTransformerModel(TracingTransformerModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
class LevenshteinTransformerModel(TransformerModel):
def __init__(self, args, encoder, decoder):
super().__init__(args, encoder, decoder)
self.tgt_dict = decoder.dictionary
self.bos = decoder.dictionary.bos()
self.eos = decoder.dictionary.eos()
@ -161,7 +300,7 @@ class LevenshteinTransformerModel(TracingTransformerModel):
@staticmethod
def add_args(parser):
TracingTransformerModel.add_args(parser)
TransformerModel.add_args(parser)
parser.add_argument(
"--apply-bert-init",
action="store_true",
@ -171,31 +310,27 @@ class LevenshteinTransformerModel(TracingTransformerModel):
"--early-exit",
default="6,6,6",
type=str,
help="number of decoder layers for del_word, ins_mask, ins_word",
help="number of decoder layers before word_del, mask_ins, word_ins",
)
parser.add_argument(
"--no-share-discriminator",
action="store_true",
help="addtional decoder-layers to learn deletion",
help="separate parameters for discriminator",
)
parser.add_argument(
"--no-share-maskpredictor",
action="store_true",
help="addtional decoder-layers to learn predicting masks",
help="separate parameters for mask-predictor",
)
parser.add_argument(
"--share-discriminator-maskpredictor",
action="store_true",
help="share the parameters for both mask-predictor and discriminator",
)
parser.add_argument(
"--sampling-for-deletion",
action="store_true",
help="instead of argmax, use sampling to predict the tokens",
)
# Added for compatibility
parser.add_argument(
"--decoder-out-embed-dim",
default=None,
type=int,
metavar="N",
help="decoder output embedding dimension (bottleneck layer before"
"output layer if specified.)",
action='store_true',
help='instead of argmax, use sampling to predict the tokens'
)
@classmethod
@ -207,7 +342,7 @@ class LevenshteinTransformerModel(TracingTransformerModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
encoder = TracingTransformerEncoder(args, src_dict, embed_tokens)
encoder = TransformerEncoder(args, src_dict, embed_tokens)
if getattr(args, "apply_bert_init", False):
encoder.apply(init_bert_params)
return encoder
@ -238,8 +373,8 @@ class LevenshteinTransformerModel(TracingTransformerModel):
# make online prediction
if self.decoder.sampling_for_deletion:
word_predictions = torch.multinomial(
F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1
).view(word_ins_out.size(0), -1)
F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1).view(
word_ins_out.size(0), -1)
else:
word_predictions = F.log_softmax(word_ins_out, dim=-1).max(2)[1]
@ -249,7 +384,10 @@ class LevenshteinTransformerModel(TracingTransformerModel):
# generate training labels for deletion
word_del_targets = _get_del_targets(word_predictions, tgt_tokens, self.pad)
word_del_out, _ = self.decoder.forward_word_del(word_predictions, encoder_out)
word_del_out, _ = self.decoder.forward_word_del(
word_predictions, encoder_out)
word_del_masks = word_predictions.ne(self.pad)
return {
"mask_ins_out": mask_ins_out,
"mask_ins_tgt": mask_ins_targets,
@ -259,7 +397,7 @@ class LevenshteinTransformerModel(TracingTransformerModel):
"word_ins_mask": masked_tgt_masks,
"word_del_out": word_del_out,
"word_del_tgt": word_del_targets,
"word_del_mask": word_predictions.ne(self.pad),
"word_del_mask": word_del_masks,
}
def forward_encoder(self, encoder_inputs):
@ -269,248 +407,123 @@ class LevenshteinTransformerModel(TracingTransformerModel):
self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
):
output_tokens = decoder_out[0]
output_scores = decoder_out[1]
attn = decoder_out[2]
if max_ratio is not None and encoder_out[1] is not None:
max_lengths = ((~encoder_out[1]).sum(1) * max_ratio).clamp(min=10)
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
attn = decoder_out.attn
bsz = output_tokens.size(0)
if max_ratio is None:
max_lens = torch.zeros_like(output_tokens).fill_(255)
else:
max_lengths = torch.zeros(output_tokens.size(0)).fill_(255)
@torch.jit.script
def del_word(
output_tokens,
output_scores,
attn: Tensor,
word_del_attn: Optional[Tensor],
word_del_out,
can_del_word,
pad_idx: int,
bos_idx: int,
eos_idx: int,
):
# delete words
# do not delete tokens if it is <s> </s>
if can_del_word.sum() != 0: # we cannot delete, skip
word_del_score = F.log_softmax(word_del_out, 2)
word_del_pred = torch.jit.Attribute(word_del_score.max(-1)[1], bool)
in_tokens = output_tokens[can_del_word]
in_scores = output_scores[can_del_word]
# apply deletion to a tensor
in_masks = in_tokens.ne(pad_idx)
bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
max_len = in_tokens.size(1)
word_del_pred.masked_fill_(~in_masks, 1)
word_del_pred.masked_fill_(bos_eos_masks, 0)
reordering = (
torch.arange(max_len)[None, :]
.expand_as(in_tokens)
.contiguous()
.masked_fill(word_del_pred, max_len)
.sort(1)[1]
)
_tokens = in_tokens.masked_fill(word_del_pred, pad_idx).gather(
1, reordering
)
_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
if word_del_attn is not None:
_mask = word_del_pred[:, :, None].expand_as(word_del_attn)
_reordering = reordering[:, :, None].expand_as(word_del_attn)
_attn = word_del_attn.masked_fill(_mask, 0.0).gather(1, _reordering)
attn = _fill(attn, can_del_word, _attn, 0)
output_tokens = _fill(output_tokens, can_del_word, _tokens, pad_idx)
output_scores = _fill(output_scores, can_del_word, _scores, 0)
return output_tokens, output_scores, attn
@torch.jit.script
def ins_placeholders(
output_tokens,
output_scores,
mask_ins_out,
can_ins_mask,
pad_idx: int,
unk_idx: int,
eos_idx: int,
max_ratio: float,
max_lengths,
):
# insert placeholders
if can_ins_mask.sum() != 0:
mask_ins_score = F.log_softmax(mask_ins_out, 2)
if eos_penalty > 0.0:
mask_ins_score[:, :, 0] -= eos_penalty
mask_ins_pred = mask_ins_score.max(-1)[1]
if max_ratio is not None and encoder_out[1] is not None:
mask_ins_pred = torch.min(
mask_ins_pred, max_lengths[can_ins_mask, None].expand_as(mask_ins_pred)
)
in_tokens = output_tokens[can_ins_mask]
in_scores = output_scores[can_ins_mask]
in_masks = in_tokens.ne(pad_idx)
in_lengths = in_masks.sum(1)
# HACK: hacky way to shift all the paddings to eos first.
in_tokens.masked_fill_(~in_masks, eos_idx)
mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
out_lengths = in_lengths + mask_ins_pred.sum(1)
out_max_len = out_lengths.max()
out_masks = (
torch.arange(out_max_len)[None, :].long() < out_lengths[:, None]
)
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
out_tokens = (
torch.zeros(in_tokens.size()[0], out_max_len)
.fill_(pad_idx)
.masked_fill_(out_masks, unk_idx)
)
out_tokens = torch.cat([in_tokens[:, :1], out_tokens[:, 1:]], 1)
out_tokens.scatter_(1, reordering, in_tokens[:, 1:].float())
if in_scores is not None:
in_scores.masked_fill_(~in_masks, 0)
out_scores = torch.zeros_like(out_tokens).to(in_scores)
out_tokens = torch.cat([in_tokens[:, :1], out_tokens[:, 1:]], 1)
out_scores.scatter_(1, reordering, in_scores[:, 1:])
else:
out_scores = None
output_tokens = _fill(output_tokens, can_ins_mask, out_tokens, pad_idx)
output_scores = _fill(output_scores, can_ins_mask, out_scores, 0)
return output_tokens, output_scores
@torch.jit.script
def ins_words(
output_tokens,
output_scores,
attn: Tensor,
word_ins_attn,
word_ins_out,
can_ins_word,
pad_idx: int,
unk_idx: int,
):
# insert words
if can_ins_word.sum() != 0:
word_ins_scores = F.log_softmax(word_ins_out, 2)
word_ins_pred = word_ins_scores.max(-1)[1]
in_tokens = output_tokens[can_ins_word]
in_scores = output_scores[can_ins_word]
word_ins_masks = in_tokens.eq(unk_idx)
out_tokens = in_tokens.masked_scatter(
word_ins_masks, word_ins_pred[word_ins_masks].float()
)
if in_scores is not None:
out_scores = in_scores.masked_scatter(
word_ins_masks, word_ins_scores[word_ins_masks]
)
else:
out_scores = None
output_tokens = _fill(output_tokens, can_ins_word, out_tokens, pad_idx)
output_scores = _fill(output_scores, can_ins_word, out_scores, 0)
attn = _fill(attn, can_ins_word, word_ins_attn, 0)
return output_tokens, output_scores, attn
if encoder_out.encoder_padding_mask is None:
max_src_len = encoder_out.encoder_out.size(1)
src_lens = encoder_out.encoder_out.new(bsz).fill_(max_src_len)
else:
src_lens = (~encoder_out.encoder_padding_mask).sum(1)
max_lens = (src_lens * max_ratio).clamp(min=10).long()
# delete words
# do not delete tokens if it is <s> </s>
can_del_word = output_tokens.ne(self.pad).sum(1) > 2
word_del_out, word_del_attn = self.decoder.forward_word_del(
script_skip_tensor(output_tokens, can_del_word),
script_skip_tensor_list(list(encoder_out), can_del_word),
)
if can_del_word.sum() != 0:  # skip this block if there is nothing to delete
word_del_out, word_del_attn = self.decoder.forward_word_del(
_skip(output_tokens, can_del_word),
_skip_encoder_out(self.encoder, encoder_out, can_del_word)
)
word_del_score = F.log_softmax(word_del_out, 2)
word_del_pred = word_del_score.max(-1)[1].bool()
output_tokens, output_scores, attn = del_word(
output_tokens,
output_scores,
attn,
word_del_attn,
word_del_out,
can_del_word,
self.pad,
self.bos,
self.eos,
)
_tokens, _scores, _attn = _apply_del_words(
output_tokens[can_del_word],
output_scores[can_del_word],
word_del_attn,
word_del_pred,
self.pad,
self.bos,
self.eos,
)
output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
output_scores = _fill(output_scores, can_del_word, _scores, 0)
attn = _fill(attn, can_del_word, _attn, 0.)
can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lengths
mask_ins_out, _ = self.decoder.forward_mask_ins(
script_skip_tensor(output_tokens, can_ins_mask),
script_skip_tensor_list(encoder_out, can_ins_mask),
)
output_tokens, output_scores = ins_placeholders(
output_tokens,
output_scores,
mask_ins_out,
can_ins_mask,
self.pad,
self.unk,
self.eos,
max_ratio,
max_lengths,
)
# insert placeholders
can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
if can_ins_mask.sum() != 0:
mask_ins_out, _ = self.decoder.forward_mask_ins(
_skip(output_tokens, can_ins_mask),
_skip_encoder_out(self.encoder, encoder_out, can_ins_mask)
)
mask_ins_score = F.log_softmax(mask_ins_out, 2)
if eos_penalty > 0.0:
mask_ins_score[:, :, 0] = mask_ins_score[:, :, 0] - eos_penalty
mask_ins_pred = mask_ins_score.max(-1)[1]
mask_ins_pred = torch.min(
mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
)
_tokens, _scores = _apply_ins_masks(
output_tokens[can_ins_mask],
output_scores[can_ins_mask],
mask_ins_pred,
self.pad,
self.unk,
self.eos,
)
output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
# insert words
can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
word_ins_out, word_ins_attn = self.decoder.forward_word_ins(
script_skip_tensor(output_tokens, can_ins_word),
script_skip_tensor_list(encoder_out, can_ins_word),
)
if can_ins_word.sum() != 0:
word_ins_out, word_ins_attn = self.decoder.forward_word_ins(
_skip(output_tokens, can_ins_word),
_skip_encoder_out(self.encoder, encoder_out, can_ins_word)
)
word_ins_score, word_ins_pred = F.log_softmax(word_ins_out, 2).max(-1)
word_ins_pred = word_ins_score.max(-1)[1]
_tokens, _scores = _apply_ins_words(
output_tokens[can_ins_word],
output_scores[can_ins_word],
word_ins_pred,
word_ins_score,
self.unk,
)
output_tokens, output_scores, attn = ins_words(
output_tokens,
output_scores,
attn,
word_ins_attn,
word_ins_out,
can_ins_word,
self.pad,
self.unk,
)
output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
output_scores = _fill(output_scores, can_ins_word, _scores, 0)
attn = _fill(attn, can_ins_word, word_ins_attn, 0.)
# delete some unnecessary paddings
cut_off = output_tokens.ne(self.pad).sum(1).max()
output_tokens = output_tokens[:, :cut_off]
output_scores = output_scores[:, :cut_off]
attn = None if attn is None else attn[:, :cut_off, :]
@torch.jit.script
def slice_wrap(x, l):
return x[:, :l]
@torch.jit.script
def slice_wrap_attn(x, l):
return x if x.size()[0] == 0 else x[:, :l, :]
output_tokens = slice_wrap(output_tokens, cut_off)
output_scores = slice_wrap(output_scores, cut_off)
attn = slice_wrap(attn, cut_off)
return [output_tokens, output_scores, attn, 0, 0]
return decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=attn,
)
def initialize_output_tokens(self, encoder_out, src_tokens):
initial_output_tokens = torch.cat(
[
torch.zeros(src_tokens.size(0), 1).fill_(self.bos),
torch.zeros(src_tokens.size(0), 1).fill_(self.eos),
],
1,
initial_output_tokens = src_tokens.new_zeros(src_tokens.size(0), 2)
initial_output_tokens[:, 0] = self.bos
initial_output_tokens[:, 1] = self.eos
initial_output_scores = initial_output_tokens.new_zeros(
*initial_output_tokens.size()
).type_as(encoder_out.encoder_out)
return DecoderOut(
output_tokens=initial_output_tokens,
output_scores=initial_output_scores,
attn=None,
step=0,
max_step=0,
)
initial_output_scores = torch.zeros_like(initial_output_tokens).to(
encoder_out[0]
)
initial_attn = torch.empty([0])
if getattr(self.decoder.layers[-1], "need_attn", True):
initial_attn = torch.zeros([src_tokens.size(0), 2, src_tokens.size(1)]).to(
initial_output_tokens
)
return [initial_output_tokens, initial_output_scores, initial_attn, 0, 0]
class LevenshteinTransformerDecoder(TracingTransformerDecoder):
class LevenshteinTransformerDecoder(TransformerDecoder):
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
super().__init__(
args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
@ -524,38 +537,32 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
self.embed_word_del = Embedding(2, self.output_embed_dim, None)
# del_word, ins_mask, ins_word
self.early_exit = [int(i) for i in args.early_exit.split(",")]
self.early_exit = [int(i) for i in args.early_exit.split(',')]
assert len(self.early_exit) == 3
# copy layers for mask-predict/deletion
self.layers_msk = None
if getattr(args, "no_share_maskpredictor", False):
self.layers_msk = nn.ModuleList(
[
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(self.early_exit[1])
]
)
self.layers_msk = nn.ModuleList([
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(self.early_exit[1])
])
self.layers_del = None
if getattr(args, "no_share_discriminator", False):
self.layers_del = nn.ModuleList(
[
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(self.early_exit[0])
]
)
self.layers_del = nn.ModuleList([
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(self.early_exit[0])
])
if getattr(args, "share_discriminator_maskpredictor", False):
assert getattr(args, "no_share_discriminator", False), "must set separate discriminator"
self.layers_msk = self.layers_del
def extract_features(
self,
prev_output_tokens,
encoder_out=None,
early_exit=None,
layers=None,
**unused
self, prev_output_tokens, encoder_out=None, early_exit=None, layers=None, **unused
):
"""
Similar to *forward* but only return features.
Inputs:
prev_output_tokens: Tensor(B, T)
encoder_out: a dictionary of hidden states and masks
@ -574,7 +581,7 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
)
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens.long())
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
@ -591,11 +598,11 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
layers = self.layers if layers is None else layers
early_exit = len(layers) if early_exit is None else early_exit
for _, layer in enumerate(layers[:early_exit]):
for _, layer in enumerate(layers[: early_exit]):
x, attn = layer(
x,
encoder_out[0] if encoder_out is not None else None,
encoder_out[1] if encoder_out is not None else None,
encoder_out.encoder_out if encoder_out is not None else None,
encoder_out.encoder_padding_mask if encoder_out is not None else None,
self_attn_mask=None,
self_attn_padding_mask=decoder_padding_mask,
)
@ -610,38 +617,26 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, attn, inner_states
return x, {"attn": attn, "inner_states": inner_states}
def forward_mask_ins(self, prev_output_tokens, encoder_out=None, **unused):
features, attn, _ = self.extract_features(
prev_output_tokens,
encoder_out=encoder_out,
early_exit=self.early_exit[1],
layers=self.layers_msk,
**unused
features, extra = self.extract_features(
prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[1], layers=self.layers_msk, **unused
)
features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
return F.linear(features_cat, self.embed_mask_ins.weight), attn
return F.linear(features_cat, self.embed_mask_ins.weight), extra['attn']
def forward_word_ins(self, prev_output_tokens, encoder_out=None, **unused):
features, attn, _ = self.extract_features(
prev_output_tokens,
encoder_out=encoder_out,
early_exit=self.early_exit[2],
layers=self.layers,
**unused
features, extra = self.extract_features(
prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[2], layers=self.layers, **unused
)
return self.output_layer(features), attn
return self.output_layer(features), extra['attn']
def forward_word_del(self, prev_output_tokens, encoder_out=None, **unused):
features, attn, _ = self.extract_features(
prev_output_tokens,
encoder_out=encoder_out,
early_exit=self.early_exit[0],
layers=self.layers_del,
**unused
features, extra = self.extract_features(
prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[0], layers=self.layers_del, **unused
)
return F.linear(features, self.embed_word_del.weight), attn
return F.linear(features, self.embed_word_del.weight), extra['attn']
@register_model_architecture("levenshtein_transformer", "levenshtein_transformer")
@ -671,7 +666,7 @@ def base_architecture(args):
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", True)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
@ -686,6 +681,8 @@ def base_architecture(args):
args.early_exit = getattr(args, "early_exit", "6,6,6")
args.no_share_discriminator = getattr(args, "no_share_discriminator", False)
args.no_share_maskpredictor = getattr(args, "no_share_maskpredictor", False)
args.share_discriminator_maskpredictor = getattr(args, "share_discriminator_maskpredictor", False)
args.no_share_last_layer = getattr(args, "no_share_last_layer", False)
@register_model_architecture(

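Structurally, the rewritten forward_decoder above drops the torch.jit.script inner functions and runs each edit phase (delete words, insert placeholders, fill words) only on the sentences that can change, then writes the subset results back into the full batch with _fill. A standalone toy version of that gather/edit/scatter pattern (the flip is just a stand-in for a real edit):

import torch

tokens = torch.tensor([[0, 5, 6, 2],
                       [0, 7, 2, 1],
                       [0, 8, 9, 2]])
active = torch.tensor([True, False, True])   # e.g. can_del_word / can_ins_word

sub = tokens[active]       # run the decoder pass on the active subset only
sub = sub.flip(1)          # stand-in for a real edit (delete / insert / fill)

out = tokens.clone()
out[active] = sub          # equal-width case of _fill(x, mask, y, pad)
print(out)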
View File

@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List
from typing import List, Optional
import torch
from torch import Tensor
@ -33,39 +33,6 @@ def script_skip_tensor(x: Tensor, mask):
return res
@torch.jit.script
def script_skip_tensor_dict(x: Dict[str, Tensor], mask):
outputs = {}
for s, t in x.items():
outputs[s] = t[mask] if t.size(0) == mask.size(0) else t[:, mask]
return outputs
def skip_tensors(x, mask):
"""
Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors.
"""
if isinstance(x, int):
return x
if x is None:
return None
if isinstance(x, torch.Tensor):
if x.size(0) == mask.size(0):
return x[mask]
elif x.size(1) == mask.size(0):
return x[:, mask]
if isinstance(x, list):
return [skip_tensors(x_i, mask) for x_i in x]
if isinstance(x, dict):
return {k: skip_tensors(v, mask) for k, v in x.items()}
raise NotImplementedError
@torch.jit.script
def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):
"""
@ -88,12 +55,17 @@ def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):
@torch.jit.script
def fill_tensors(x, mask, y, padding_idx: int):
def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor:
return x if x is not None else y
@torch.jit.script
def fill_tensors(x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int) -> Optional[Tensor]:
"""
Fill the masked rows (dim=0) of tensor x with the rows of y.
"""
if x is None or x.size()[0] == 0:
return torch.empty([0])
if x is None or x.size()[0] == 0 or y is None:
return x
assert x.dim() == y.dim() and mask.size(0) == x.size(0)
assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
@ -116,88 +88,3 @@ def fill_tensors(x, mask, y, padding_idx: int):
else:
x[mask] = y
return x
def _apply_ins_masks(
in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
):
in_masks = in_tokens.ne(padding_idx)
in_lengths = in_masks.sum(1)
# HACK: hacky way to shift all the paddings to eos first.
in_tokens.masked_fill_(~in_masks, eos_idx)
mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
out_lengths = in_lengths + mask_ins_pred.sum(1)
out_max_len = out_lengths.max()
out_masks = (
torch.arange(out_max_len, device=out_lengths.device)[None, :]
< out_lengths[:, None]
)
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
out_tokens = (
in_tokens.new_zeros(in_tokens.size(0), out_max_len)
.fill_(padding_idx)
.masked_fill_(out_masks, unk_idx)
)
out_tokens[:, 0] = in_tokens[:, 0]
out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
out_scores = None
if in_scores is not None:
in_scores.masked_fill_(~in_masks, 0)
out_scores = in_scores.new_zeros(*out_tokens.size())
out_scores[:, 0] = in_scores[:, 0]
out_scores.scatter_(1, reordering, in_scores[:, 1:])
return out_tokens, out_scores
def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx):
word_ins_masks = in_tokens.eq(unk_idx)
out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
if in_scores is not None:
out_scores = in_scores.masked_scatter(
word_ins_masks, word_ins_scores[word_ins_masks]
)
else:
out_scores = None
return out_tokens, out_scores
def _apply_del_words(
in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
):
# apply deletion to a tensor
in_masks = in_tokens.ne(padding_idx)
bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
max_len = in_tokens.size(1)
word_del_pred.masked_fill_(~in_masks, 1)
word_del_pred.masked_fill_(bos_eos_masks, 0)
reordering = (
torch.arange(max_len, device=in_tokens.device)[None, :]
.expand_as(in_tokens)
.contiguous()
.masked_fill_(word_del_pred, max_len)
.sort(1)[1]
)
out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
out_scores = None
if in_scores is not None:
out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
out_attn = None
if in_attn is not None:
_mask = word_del_pred[:, :, None].expand_as(in_attn)
_reordering = reordering[:, :, None].expand_as(in_attn)
out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
return out_tokens, out_scores, out_attn

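fill_tensors stays in model_utils (now typed with Optional and returning x when y is None); its job is to write updated rows back into a batch tensor whose width may no longer match after insertions. A minimal standalone sketch of that width-reconciling write-back, mirroring the x.size(1) < y.size(1) branch of the _fill/fill_tensors helpers (the helper name fill_rows and the token ids are hypothetical):

import torch

def fill_rows(x, mask, y, padding_idx):
    # Pad x on the right so the updated rows fit, then overwrite the masked rows.
    if x.size(1) < y.size(1):
        pad_cols = x.new_full((x.size(0), y.size(1) - x.size(1)), padding_idx)
        x = torch.cat([x, pad_cols], dim=1)
    x[mask] = y
    return x

x = torch.tensor([[0, 5, 2], [0, 6, 2]])
y = torch.tensor([[0, 5, 7, 7, 2]])        # one sentence grew after insertion
mask = torch.tensor([True, False])
print(fill_rows(x, mask, y, padding_idx=1))
# tensor([[0, 5, 7, 7, 2],
#         [0, 6, 2, 1, 1]])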
View File

@ -3,11 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
import math
from fairseq.models.model_utils import fill_tensors as _fill, skip_tensors as _skip
from fairseq.models.model_utils import _apply_del_words, _apply_ins_masks, _apply_ins_words
from fairseq.models.levenshtein_transformer import (
_skip,
_apply_ins_masks,
_apply_ins_words,
_apply_del_words,
)
from fairseq.models.model_utils import fill_tensors as _fill
class BasicEnsembleModel(torch.nn.Module):

View File

@ -5,7 +5,9 @@
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.iterative_refinement_generator import DecoderOut
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
Embedding,
@ -45,8 +47,8 @@ def _uniform_assignment(src_lens, trg_lens):
@register_model("nonautoregressive_transformer")
class NATransformerModel(TransformerModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
def __init__(self, args, encoder, decoder):
super().__init__(args, encoder, decoder)
self.tgt_dict = decoder.dictionary
self.bos = decoder.dictionary.bos()
self.eos = decoder.dictionary.eos()
@ -112,9 +114,9 @@ class NATransformerModel(TransformerModel):
return self.encoder(*encoder_inputs)
def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
step = decoder_out["step"]
output_tokens = decoder_out["output_tokens"]
output_scores = decoder_out["output_scores"]
step = decoder_out.step
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
# execute the decoder
output_masks = output_tokens.ne(self.pad)
@ -127,12 +129,16 @@ class NATransformerModel(TransformerModel):
output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
output_scores.masked_scatter_(output_masks, _scores[output_masks])
return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None}
return decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
)
def initialize_output_tokens(self, encoder_out, src_tokens):
# length prediction
_, length_tgt = self.decoder.forward_length_prediction(encoder_out)
max_length = length_tgt.max()
max_length = length_tgt.clamp_(min=2).max()
idx_length = utils.new_arange(src_tokens, max_length)
initial_output_tokens = src_tokens.new_zeros(
@ -146,13 +152,15 @@ class NATransformerModel(TransformerModel):
initial_output_scores = initial_output_tokens.new_zeros(
*initial_output_tokens.size()
).type_as(encoder_out["encoder_out"])
).type_as(encoder_out.encoder_out)
return {
"output_tokens": initial_output_tokens,
"output_scores": initial_output_scores,
"attn": None
}
return DecoderOut(
output_tokens=initial_output_tokens,
output_scores=initial_output_scores,
attn=None,
step=0,
max_step=0,
)
class NATransformerDecoder(TransformerDecoder):
@ -220,8 +228,8 @@ class NATransformerDecoder(TransformerDecoder):
"""
# embedding
if embedding_copy:
src_embd = encoder_out["encoder_embedding"]
src_mask = encoder_out["encoder_padding_mask"]
src_embd = encoder_out.encoder_embedding
src_mask = encoder_out.encoder_padding_mask
src_mask = (
~src_mask
if src_mask is not None
@ -253,10 +261,8 @@ class NATransformerDecoder(TransformerDecoder):
x, attn = layer(
x,
encoder_out["encoder_out"] if encoder_out is not None else None,
encoder_out["encoder_padding_mask"]
if encoder_out is not None
else None,
encoder_out.encoder_out if encoder_out is not None else None,
encoder_out.encoder_padding_mask if encoder_out is not None else None,
self_attn_mask=None,
self_attn_padding_mask=decoder_padding_mask,
)
@ -311,8 +317,8 @@ class NATransformerDecoder(TransformerDecoder):
return copied_embedding
def forward_length_prediction(self, encoder_out, tgt_tokens=None):
enc_feats = encoder_out["encoder_out"] # T x B x C
src_masks = encoder_out["encoder_padding_mask"] # B x T or None
enc_feats = encoder_out.encoder_out # T x B x C
src_masks = encoder_out.encoder_padding_mask # B x T or None
if self.pred_length_offset:
if src_masks is None:

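initialize_output_tokens above turns the predicted target lengths (clamped to at least 2) into an initial canvas of <unk> tokens bracketed by <s> and </s>. The construction is truncated in this excerpt, so the following is only an illustration of the idea with made-up token ids, not the file's exact code:

import torch

def unk_canvas(length_tgt, bos=0, eos=2, pad=1, unk=3):
    # Per-sentence canvases of predicted length: <s> <unk>*(L-2) </s>, padded
    # up to the longest predicted length in the batch.
    length_tgt = length_tgt.clamp(min=2)
    max_len = int(length_tgt.max())
    idx = torch.arange(max_len)
    tokens = torch.full((length_tgt.size(0), max_len), pad, dtype=torch.long)
    tokens.masked_fill_(idx[None, :] < length_tgt[:, None], unk)
    tokens[:, 0] = bos
    tokens.scatter_(1, (length_tgt - 1)[:, None], eos)
    return tokens

print(unk_canvas(torch.tensor([4, 6])))
# tensor([[0, 3, 3, 2, 1, 1],
#         [0, 3, 3, 3, 3, 2]])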
View File

@ -348,7 +348,7 @@ class RobertaEncoder(FairseqDecoder):
- a dictionary of additional data, where 'inner_states'
is a list of hidden states.
"""
x, extra = self.extract_features(src_tokens, return_all_hiddens)
x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens)
if not features_only:
x = self.output_layer(x, masked_tokens=masked_tokens)
return x, extra

View File

@ -1,625 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import Embedding, Linear, base_architecture
from fairseq.modules import (
AdaptiveSoftmax,
LayerNorm,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
TransformerDecoderLayer,
TransformerEncoderLayer,
)
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model('tracing_transformer')
class TracingTransformerModel(FairseqEncoderDecoderModel):
"""
Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
<https://arxiv.org/abs/1706.03762>`_.
Args:
encoder (TransformerEncoder): the encoder
decoder (TransformerDecoder): the decoder
The Transformer model provides the following named architectures and
command-line arguments:
.. argparse::
:ref: fairseq.models.transformer_parser
:prog:
"""
@classmethod
def hub_models(cls):
# fmt: off
return {
'transformer.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2',
'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
'transformer.wmt18.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz',
'transformer.wmt19.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz',
'transformer.wmt19.en-ru': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz',
'transformer.wmt19.de-en': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz',
'transformer.wmt19.ru-en': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz',
'transformer.wmt19.en-de.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz',
'transformer.wmt19.en-ru.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz',
'transformer.wmt19.de-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz',
'transformer.wmt19.ru-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz',
}
# fmt: on
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
self.supports_align_args = True
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--activation-fn',
choices=utils.get_available_activation_fns(),
help='activation function to use')
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
help='dropout probability after activation in FFN.')
parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
help='path to pre-trained encoder embedding')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads')
parser.add_argument('--encoder-normalize-before', action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--encoder-learned-pos', action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
help='path to pre-trained decoder embedding')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads')
parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--share-decoder-input-output-embed', action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
help='if set, disables positional embeddings (outside self attention)')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
# args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
parser.add_argument('--no-cross-attention', default=False, action='store_true',
help='do not perform cross-attention')
parser.add_argument('--cross-self-attention', default=False, action='store_true',
help='perform cross+self-attention')
parser.add_argument('--layer-wise-attention', default=False, action='store_true',
help='perform layer-wise attention (cross-attention or cross+self-attention)')
# fmt: on
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
if not hasattr(args, 'max_target_positions'):
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
def build_embedding(dictionary, embed_dim, path=None):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
emb = Embedding(num_embeddings, embed_dim, padding_idx)
# if provided, load from preloaded dictionaries
if path:
embed_dict = utils.parse_embedding(path)
utils.load_embedding(embed_dict, dictionary, emb)
return emb
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise ValueError('--share-all-embeddings requires a joined dictionary')
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
'--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
if args.decoder_embed_path and (
args.decoder_embed_path != args.encoder_embed_path):
raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
encoder_embed_tokens = build_embedding(
src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = build_embedding(
src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
decoder_embed_tokens = build_embedding(
tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
)
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
return cls(encoder, decoder)
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TracingTransformerEncoder(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TracingTransformerDecoder(
args,
tgt_dict,
embed_tokens,
no_encoder_attn=getattr(args, 'no_cross_attention', False),
)
class TracingTransformerEncoder(FairseqEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
"""
def __init__(self, args, dictionary, embed_tokens):
super().__init__(dictionary)
self.register_buffer('version', torch.Tensor([3]))
self.dropout = args.dropout
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.max_source_positions = args.max_source_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
args.max_source_positions, embed_dim, self.padding_idx,
learned=args.encoder_learned_pos,
) if not args.no_token_positional_embeddings else None
self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerEncoderLayer(args)
for i in range(args.encoder_layers)
])
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
def forward_embedding(self, src_tokens):
# embed tokens and positions
embed = self.embed_scale * self.embed_tokens(src_tokens)
if self.embed_positions is not None:
x = embed + self.embed_positions(src_tokens)
x = F.dropout(x, p=self.dropout, training=self.training)
return x, embed
def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
if self.layer_wise_attention:
return_all_hiddens = True
x, encoder_embedding = self.forward_embedding(src_tokens)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
encoder_states = [] if return_all_hiddens else None
# encoder layers
for layer in self.layers:
x = layer(x, encoder_padding_mask)
if return_all_hiddens:
encoder_states.append(x)
if self.layer_norm:
x = self.layer_norm(x)
if return_all_hiddens:
encoder_states[-1] = x
if encoder_states is not None:
return x, encoder_padding_mask, encoder_embedding, encoder_states
else:
return x, encoder_padding_mask, encoder_embedding
def reorder_encoder_out(self, encoder_out, new_order):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
# 0: encoder_out
# 1: encoder_padding_mask
# 2: encoder_states
if encoder_out[0] is not None:
encoder_out[0] = \
encoder_out[0].index_select(1, new_order)
if encoder_out[1] is not None:
encoder_out[1] = \
encoder_out[1].index_select(0, new_order)
if len(encoder_out) == 3 and encoder_out[2] is not None:
for idx, state in enumerate(encoder_out[2]):
encoder_out[2][idx] = state.index_select(1, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
if self.embed_positions is None:
return self.max_source_positions
return min(self.max_source_positions, self.embed_positions.max_positions())
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
not hasattr(self, '_future_mask')
or self._future_mask is None
or self._future_mask.device != tensor.device
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
return self._future_mask[:dim, :dim]
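# Illustrative sketch (not part of this commit): the buffered future mask is an
# upper-triangular matrix with -inf strictly above the diagonal, e.g. for dim=3:
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# A minimal standalone construction with plain torch (no fairseq utils):
def _example_future_mask(dim):
    import torch
    return torch.triu(torch.full((dim, dim), float('-inf')), 1)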
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = '{}.embed_positions.weights'.format(name)
if weights_key in state_dict:
del state_dict[weights_key]
state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
for i in range(len(self.layers)):
# update layer norms
self.layers[i].upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i))
version_key = '{}.version'.format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict
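# Illustrative sketch (not part of this commit): the key rewrite performed by
# upgrade_state_dict_named() for sinusoidal positional embeddings, shown on a
# plain dict with a hypothetical 'encoder' prefix.
def _example_upgrade_positions(state_dict):
    import torch
    weights_key = 'encoder.embed_positions.weights'
    if weights_key in state_dict:
        # sinusoidal weights are recomputed on the fly, so the stored buffer
        # is dropped and replaced by a dummy float tensor
        del state_dict[weights_key]
    state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
    return state_dict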
class TracingTransformerDecoder(FairseqIncrementalDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
super().__init__(dictionary)
self.register_buffer('version', torch.Tensor([3]))
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
input_embed_dim = embed_tokens.embedding_dim
embed_dim = args.decoder_embed_dim
self.output_embed_dim = args.decoder_output_dim
self.padding_idx = embed_tokens.padding_idx
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim
self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None
self.embed_positions = PositionalEmbedding(
args.max_target_positions, embed_dim, self.padding_idx,
learned=args.decoder_learned_pos,
) if not args.no_token_positional_embeddings else None
self.cross_self_attention = getattr(args, 'cross_self_attention', False)
self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
])
self.adaptive_softmax = None
self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False) \
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None
if args.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
self.output_embed_dim,
options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
dropout=args.adaptive_softmax_dropout,
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
factor=args.adaptive_softmax_factor,
tie_proj=args.tie_adaptive_proj,
)
elif not self.share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), self.output_embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)
if args.decoder_normalize_before and not getattr(args, 'no_decoder_final_norm', False):
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
def forward(
self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
features_only=False,
**extra_args,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens, encoder_out, incremental_state, **extra_args,
)
if not features_only:
x = self.output_layer(x)
return x, extra
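# Illustrative sketch (not part of this commit): calling the decoder with
# features_only=True to get pre-projection features. `decoder`,
# `prev_output_tokens` and `encoder_out` are assumed to exist.
def _example_decode_features(decoder, prev_output_tokens, encoder_out):
    features, extra = decoder(
        prev_output_tokens, encoder_out=encoder_out, features_only=True,
    )
    # features: B x T x C; extra['attn'] and extra['inner_states'] as documented
    return features, extra['inner_states']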
def extract_features(
self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
full_context_alignment=False,
alignment_layer=None,
alignment_heads=None,
**unused,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
if alignment_layer is None:
alignment_layer = len(self.layers) - 1
# embed positions
positions = self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
) if self.embed_positions is not None else None
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
if not self_attn_padding_mask.any() and not self.cross_self_attention:
self_attn_padding_mask = None
# decoder layers
attn = None
inner_states = [x]
for idx, layer in enumerate(self.layers):
encoder_state = None
if encoder_out is not None:
if self.layer_wise_attention:
encoder_state = encoder_out[3][idx]
else:
encoder_state = encoder_out[0]
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
x, layer_attn = layer(
x,
encoder_state,
encoder_out[1] if encoder_out is not None else None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=(idx == alignment_layer),
need_head_weights=(idx == alignment_layer),
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float()
if attn is not None:
if alignment_heads is not None:
attn = attn[:alignment_heads]
# average probabilities over heads
attn = attn.mean(dim=0)
if self.layer_norm:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {'attn': attn, 'inner_states': inner_states}
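# Illustrative sketch (not part of this commit): requesting an attention-based
# alignment averaged over the first 4 heads of layer 2, assuming the decoder
# has at least three layers. `decoder`, `prev_output_tokens` and `encoder_out`
# are assumed to exist.
def _example_alignment(decoder, prev_output_tokens, encoder_out):
    _, extra = decoder.extract_features(
        prev_output_tokens,
        encoder_out=encoder_out,
        alignment_layer=2,
        alignment_heads=4,
    )
    return extra['attn']  # B x tgt_len x src_len, averaged over the heads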
def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
if self.share_input_output_embed:
return F.linear(features, self.embed_tokens.weight)
else:
return F.linear(features, self.embed_out)
else:
return features
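# Illustrative sketch (not part of this commit): with shared input/output
# embeddings, the output layer is just a linear projection with the tied
# embedding matrix. `features` and `embed_tokens` are hypothetical inputs.
def _example_output_projection(features, embed_tokens):
    import torch.nn.functional as F
    # features: B x T x C, embed_tokens.weight: V x C -> logits: B x T x V
    return F.linear(features, embed_tokens.weight)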
def max_positions(self):
"""Maximum output length supported by the decoder."""
if self.embed_positions is None:
return self.max_target_positions
return min(self.max_target_positions, self.embed_positions.max_positions())
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
not hasattr(self, '_future_mask')
or self._future_mask is None
or self._future_mask.device != tensor.device
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = '{}.embed_positions.weights'.format(name)
if weights_key in state_dict:
del state_dict[weights_key]
state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
for i in range(len(self.layers)):
# update layer norms
layer_norm_map = {
'0': 'self_attn_layer_norm',
'1': 'encoder_attn_layer_norm',
'2': 'final_layer_norm'
}
for old, new in layer_norm_map.items():
for m in ('weight', 'bias'):
k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m)
if k in state_dict:
state_dict['{}.layers.{}.{}.{}'.format(name, i, new, m)] = state_dict[k]
del state_dict[k]
version_key = '{}.version'.format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict
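# Illustrative sketch (not part of this commit): the layer-norm key renaming
# performed above, shown on a toy state dict for a hypothetical 'decoder'
# prefix and layer 0.
def _example_rename_layer_norms(state_dict):
    layer_norm_map = {
        '0': 'self_attn_layer_norm',
        '1': 'encoder_attn_layer_norm',
        '2': 'final_layer_norm',
    }
    for old, new in layer_norm_map.items():
        for m in ('weight', 'bias'):
            k = 'decoder.layers.0.layer_norms.{}.{}'.format(old, m)
            if k in state_dict:
                state_dict['decoder.layers.0.{}.{}'.format(new, m)] = state_dict.pop(k)
    return state_dict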

View File

@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import namedtuple
import math
import torch
@ -279,6 +280,14 @@ class TransformerAlignModel(TransformerModel):
return decoder_out
EncoderOut = namedtuple('TransformerEncoderOut', [
'encoder_out', # T x B x C
'encoder_padding_mask', # B x T
'encoder_embedding', # B x T x C
'encoder_states', # List[T x B x C]
])
class TransformerEncoder(FairseqEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
@ -348,11 +357,13 @@ class TransformerEncoder(FairseqEncoder):
intermediate hidden states (default: False).
Returns:
dict:
namedtuple:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
@ -386,12 +397,12 @@ class TransformerEncoder(FairseqEncoder):
if return_all_hiddens:
encoder_states[-1] = x
return {
'encoder_out': x, # T x B x C
'encoder_padding_mask': encoder_padding_mask, # B x T
'encoder_embedding': encoder_embedding, # B x T x C
'encoder_states': encoder_states, # List[T x B x C]
}
return EncoderOut(
encoder_out=x, # T x B x C
encoder_padding_mask=encoder_padding_mask, # B x T
encoder_embedding=encoder_embedding, # B x T x C
encoder_states=encoder_states, # List[T x B x C]
)
def reorder_encoder_out(self, encoder_out, new_order):
"""
@ -404,15 +415,21 @@ class TransformerEncoder(FairseqEncoder):
Returns:
*encoder_out* rearranged according to *new_order*
"""
if encoder_out['encoder_out'] is not None:
encoder_out['encoder_out'] = \
encoder_out['encoder_out'].index_select(1, new_order)
if encoder_out['encoder_padding_mask'] is not None:
encoder_out['encoder_padding_mask'] = \
encoder_out['encoder_padding_mask'].index_select(0, new_order)
if encoder_out.get('encoder_states', None) is not None:
for idx, state in enumerate(encoder_out['encoder_states']):
encoder_out['encoder_states'][idx] = state.index_select(1, new_order)
if encoder_out.encoder_out is not None:
encoder_out = encoder_out._replace(
encoder_out=encoder_out.encoder_out.index_select(1, new_order)
)
if encoder_out.encoder_padding_mask is not None:
encoder_out = encoder_out._replace(
encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(0, new_order)
)
if encoder_out.encoder_embedding is not None:
encoder_out = encoder_out._replace(
encoder_embedding=encoder_out.encoder_embedding.index_select(0, new_order)
)
if encoder_out.encoder_states is not None:
for idx, state in enumerate(encoder_out.encoder_states):
encoder_out.encoder_states[idx] = state.index_select(1, new_order)
return encoder_out
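# Illustrative sketch (not part of this diff): namedtuple._replace returns a
# new EncoderOut with one field swapped, which is why the reordering above
# rebinds `encoder_out` instead of mutating it in place.
def _example_replace(encoder_out, new_order):
    return encoder_out._replace(
        encoder_out=encoder_out.encoder_out.index_select(1, new_order)
    )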
def max_positions(self):
@ -532,13 +549,13 @@ class TransformerDecoder(FairseqIncrementalDecoder):
encoder_out=None,
incremental_state=None,
features_only=False,
**extra_args,
**extra_args
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder_out (optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
@ -551,7 +568,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens, encoder_out, incremental_state, **extra_args,
prev_output_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
**extra_args
)
if not features_only:
x = self.output_layer(x)
@ -628,9 +648,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
encoder_state = None
if encoder_out is not None:
if self.layer_wise_attention:
encoder_state = encoder_out['encoder_states'][idx]
encoder_state = encoder_out.encoder_states[idx]
else:
encoder_state = encoder_out['encoder_out']
encoder_state = encoder_out.encoder_out
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
@ -643,7 +663,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
x, layer_attn = layer(
x,
encoder_state,
encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
encoder_out.encoder_padding_mask if encoder_out is not None else None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,

View File

@ -26,16 +26,15 @@ class MeanPoolGatingNetwork(torch.nn.Module):
def forward(self, encoder_out):
if not (
isinstance(encoder_out, dict)
and 'encoder_out' in encoder_out
and 'encoder_padding_mask' in encoder_out
and encoder_out['encoder_out'].size(2) == self.embed_dim
hasattr(encoder_out, 'encoder_out')
and hasattr(encoder_out, 'encoder_padding_mask')
and encoder_out.encoder_out.size(2) == self.embed_dim
):
raise ValueError('Unexpected format for encoder_out')
# mean pooling over time
encoder_padding_mask = encoder_out['encoder_padding_mask'] # B x T
encoder_out = encoder_out['encoder_out'].transpose(0, 1) # B x T x C
encoder_padding_mask = encoder_out.encoder_padding_mask # B x T
encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C
if encoder_padding_mask is not None:
encoder_out = encoder_out.clone() # required because of transpose above
encoder_out[encoder_padding_mask] = 0
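# Illustrative sketch (not part of this diff): masked mean pooling over time,
# following the zero-out-then-average pattern above. The names and the final
# normalization are made up for the example.
def _example_masked_mean_pool(encoder_out, encoder_padding_mask):
    # encoder_out: B x T x C (already transposed); encoder_padding_mask: B x T
    encoder_out = encoder_out.clone()
    encoder_out[encoder_padding_mask] = 0
    ntokens = (encoder_padding_mask == 0).sum(dim=1, keepdim=True).clamp(min=1)
    return encoder_out.sum(dim=1) / ntokens.type_as(encoder_out)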

View File

@ -197,51 +197,90 @@ class TestTranslation(unittest.TestCase):
])
generate_main(data_dir)
def test_cmlm_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_cmlm_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'cmlm_transformer', [
'--apply-bert-init',
'--criterion', 'nat_loss',
'--noise', 'full_mask',
'--pred-length-offset',
'--length-loss-factor', '0.1'
], task='translation_lev')
generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_levenshtein_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_levenshtein_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'levenshtein_transformer', [
'--apply-bert-init', '--early-exit', '6,6,6',
'--criterion', 'nat_loss'
], task='translation_lev')
generate_main(data_dir)
generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_nonautoregressive_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_nonautoregressive_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'nonautoregressive_transformer', [
'--apply-bert-init', '--src-embedding-copy', '--criterion',
'nat_loss', '--noise', 'full_mask', '--pred-length-offset',
'--length-loss-factor', '0.1'
], task='translation_lev')
generate_main(data_dir)
generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_iterative_nonautoregressive_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_iterative_nonautoregressive_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'iterative_nonautoregressive_transformer', [
'--apply-bert-init', '--src-embedding-copy', '--criterion',
'nat_loss', '--noise', 'full_mask', '--stochastic-approx',
'--dae-ratio', '0.5', '--train-step', '3'
], task='translation_lev')
generate_main(data_dir)
generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_insertion_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_insertion_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'insertion_transformer', [
'--apply-bert-init', '--criterion', 'nat_loss', '--noise',
'random_mask'
], task='translation_lev')
generate_main(data_dir)
generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_mixture_of_experts(self):
with contextlib.redirect_stdout(StringIO()):