From 86857a58bf2919c7bec3c29c58234aa4c434d566 Mon Sep 17 00:00:00 2001 From: Changhan Wang Date: Fri, 27 Sep 2019 13:56:47 -0700 Subject: [PATCH] Levenshtein Transformer paper code Summary: Code for our NeurIPS paper [Levenshtein Transformer](https://arxiv.org/abs/1905.11006) * Added Levenshtein Transformer model, task and criterion class * Added iterative NAT Transformer, insertion Transformer and CMLM Transformer model class for baselines * Add an option for prepending BOS to dictionary class and translation task class Reviewed By: myleott Differential Revision: D17297372 fbshipit-source-id: 54eca60831ae95dc721c2c34e882e1810ee575c7 --- README.md | 11 +- .../nonautoregressive_translation/README.md | 90 +++ .../nonautoregressive_translation/scripts.md | 148 ++++ fairseq/clib/libnat/edit_dist.cpp | 222 ++++++ fairseq/criterions/nat_loss.py | 190 ++++++ fairseq/data/dictionary.py | 5 +- fairseq/iterative_refinement_generator.py | 154 +++++ fairseq/models/cmlm_transformer.py | 136 ++++ fairseq/models/insertion_transformer.py | 259 +++++++ ...iterative_nonautoregressive_transformer.py | 196 ++++++ fairseq/models/levenshtein_transformer.py | 595 ++++++++++++++++ fairseq/models/model_utils.py | 62 ++ .../models/nonautoregressive_transformer.py | 640 ++++++++++++++++++ fairseq/models/transformer.py | 23 +- fairseq/modules/multihead_attention.py | 5 +- fairseq/modules/transformer_layer.py | 2 +- .../modules/transformer_sentence_encoder.py | 3 +- fairseq/options.py | 14 + fairseq/tasks/translation.py | 9 +- fairseq/tasks/translation_lev.py | 149 ++++ fairseq/utils.py | 8 + generate.py | 3 + setup.py | 8 + tests/test_binaries.py | 46 ++ train.py | 5 + 25 files changed, 2968 insertions(+), 15 deletions(-) create mode 100644 examples/nonautoregressive_translation/README.md create mode 100644 examples/nonautoregressive_translation/scripts.md create mode 100644 fairseq/clib/libnat/edit_dist.cpp create mode 100644 fairseq/criterions/nat_loss.py create mode 100644 fairseq/iterative_refinement_generator.py create mode 100644 fairseq/models/cmlm_transformer.py create mode 100644 fairseq/models/insertion_transformer.py create mode 100644 fairseq/models/iterative_nonautoregressive_transformer.py create mode 100644 fairseq/models/levenshtein_transformer.py create mode 100644 fairseq/models/model_utils.py create mode 100644 fairseq/models/nonautoregressive_transformer.py create mode 100644 fairseq/tasks/translation_lev.py diff --git a/README.md b/README.md index 45dce65cf..c39ff22c9 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ modeling and other text generation tasks. 
### What's New: +- September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) - August 2019: [WMT'19 models released](examples/wmt19/README.md) - July 2019: fairseq relicensed under MIT license - July 2019: [RoBERTa models and code released](examples/roberta/README.md) @@ -32,6 +33,13 @@ Fairseq provides reference implementations of various sequence-to-sequence model - [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) +- **Non-autoregressive Transformers** + - Non-Autoregressive Neural Machine Translation (Gu et al., 2017) + - Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) + - Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019) + - Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019) + - [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) + **Additionally:** - multi-GPU (distributed) training on one machine or across multiple machines @@ -50,7 +58,7 @@ translation and language modeling datasets. # Requirements and Installation -* [PyTorch](http://pytorch.org/) version >= 1.1.0 +* [PyTorch](http://pytorch.org/) version >= 1.2.0 * Python version >= 3.5 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl) * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library with the `--cuda_ext` option @@ -92,6 +100,7 @@ as well as example training and evaluation commands. - [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available We also have more detailed READMEs to reproduce results from specific papers: +- [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) - [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) diff --git a/examples/nonautoregressive_translation/README.md b/examples/nonautoregressive_translation/README.md new file mode 100644 index 000000000..5f030868c --- /dev/null +++ b/examples/nonautoregressive_translation/README.md @@ -0,0 +1,90 @@ +# Non-autoregressive Neural Machine Translation (NAT) + +This page mainly includes instructions for reproducing results from the paper +* [Levenshtein Transformer (Gu et al., 2019)](https://arxiv.org/abs/1905.11006). + +We also provided our own implementations for several popular non-autoregressive-based models as reference:
+* [Non-Autoregressive Neural Machine Translation (Gu et al., 2017)](https://arxiv.org/abs/1711.02281)
+* [Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al., 2018)](https://arxiv.org/abs/1802.06901)
+* [Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al., 2019)](https://arxiv.org/abs/1902.03249)
+* [Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)](https://arxiv.org/abs/1904.09324v2)
+
+## Dataset
+
+First, follow the [instructions to download and preprocess the WMT'14 En-De dataset](../translation#prepare-wmt14en2desh).
+Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.
+
+### Knowledge Distillation
+Following [Gu et al., 2019](https://arxiv.org/abs/1905.11006), [knowledge distillation](https://arxiv.org/abs/1606.07947) from an autoregressive model can effectively simplify the training data distribution, which is sometimes essential for NAT-based models to learn good translations.
+The easiest way to perform distillation is to follow the [instructions for training a standard transformer model](../translation) on the same data and then decode the training set to produce a distillation dataset for NAT.
+
+### Download
+We also provide the preprocessed [original](http://dl.fbaipublicfiles.com/nat/original_dataset.zip) and [distillation](http://dl.fbaipublicfiles.com/nat/distill_dataset.zip) datasets. Please build the binarized dataset on your own.
+
+
+## Train a model
+
+We can then train a nonautoregressive model using the `translation_lev` task and the new `nat_loss` criterion.
+Use the `--noise` flag to specify the input noise applied to the target sentences.
+By default, the task is configured for the *Levenshtein Transformer*, with `--noise='random_delete'`. Full scripts for the other models can be found [here](./scripts.md).
+
+The following command will train a *Levenshtein Transformer* on the binarized dataset.
+
+```bash
+fairseq-train \
+    data-bin/wmt14_en_de_distill \
+    --save-dir checkpoints \
+    --ddp-backend=no_c10d \
+    --task translation_lev \
+    --criterion nat_loss \
+    --arch levenshtein_transformer \
+    --noise random_delete \
+    --share-all-embeddings \
+    --optimizer adam --adam-betas '(0.9,0.98)' \
+    --lr 0.0005 --lr-scheduler inverse_sqrt \
+    --min-lr '1e-09' --warmup-updates 10000 \
+    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
+    --dropout 0.3 --weight-decay 0.01 \
+    --decoder-learned-pos \
+    --encoder-learned-pos \
+    --apply-bert-init \
+    --log-format 'simple' --log-interval 100 \
+    --fixed-validation-seed 7 \
+    --max-tokens 8000 \
+    --save-interval-updates 10000 \
+    --max-update 300000
+```
+
+## Translate
+
+Once a model is trained, we can generate translations with the `iterative_refinement_generator`, which starts from the model's initial output and iteratively refines the translation in a greedy fashion until (1) the model predicts the same translation for two consecutive iterations, or (2) the generator reaches the maximum number of iterations (`--iter-decode-max-iter`). Use `--print-step` to check the actual number of iterations used for each sentence; a minimal sketch of this stopping check is shown after the generation command below.
+
+For the *Levenshtein Transformer*, it sometimes helps to apply an `--iter-decode-eos-penalty` (typically 0~3) to penalize the model for finishing generation too early and producing translations that are too short.
+
+
+For example, to generate with `--iter-decode-max-iter=9`:
+```bash
+fairseq-generate \
+    data-bin/wmt14_en_de_distill \
+    --gen-subset test \
+    --task translation_lev \
+    --path checkpoints/checkpoint_best.pt \
+    --iter-decode-max-iter 9 \
+    --iter-decode-eos-penalty 0 \
+    --beam 1 --remove-bpe \
+    --print-step \
+    --batch-size 400
+```
+At the end of generation, the tokenized BLEU score of the translations is printed.
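+The early-stopping check in (1) simply compares the outputs of two consecutive refinement iterations. Below is a minimal, self-contained sketch of such a check (not the exact fairseq implementation; the function name and toy tensors are illustrative) that pads the shorter of the two outputs so sentences of different lengths can be compared:
+
+```python
+import torch
+
+
+def reached_fixed_point(prev_tokens, cur_tokens, pad_idx):
+    """Return one bool per sentence: True if the refinement step changed nothing."""
+    b = prev_tokens.size(0)
+    l_prev, l_cur = prev_tokens.size(1), cur_tokens.size(1)
+    # pad the shorter batch so both tensors have the same width
+    if l_prev > l_cur:
+        cur_tokens = torch.cat(
+            [cur_tokens, cur_tokens.new_full((b, l_prev - l_cur), pad_idx)], 1)
+    elif l_cur > l_prev:
+        prev_tokens = torch.cat(
+            [prev_tokens, prev_tokens.new_full((b, l_cur - l_prev), pad_idx)], 1)
+    # a sentence has converged when every position (padding included) matches
+    return (prev_tokens == cur_tokens).all(1)
+
+
+# toy example: the second sentence is unchanged, so it would stop refining
+prev = torch.tensor([[5, 6, 7, 1], [8, 9, 1, 1]])
+cur = torch.tensor([[5, 6, 10, 7], [8, 9, 1, 1]])
+print(reached_fixed_point(prev, cur, pad_idx=1))  # tensor([False,  True])
+```
+
+Sentences that pass this check (or hit `--iter-decode-max-iter`) are collected as final hypotheses, while the remaining ones go through another refinement pass.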
+
+
+## Citation
+
+```bibtex
+@article{gu2019levenshtein,
+  title={Levenshtein Transformer},
+  author={Gu, Jiatao and Wang, Changhan and Zhao, Jake},
+  journal={arXiv preprint arXiv:1905.11006},
+  year={2019}
+}
+```
diff --git a/examples/nonautoregressive_translation/scripts.md b/examples/nonautoregressive_translation/scripts.md
new file mode 100644
index 000000000..2fda7f620
--- /dev/null
+++ b/examples/nonautoregressive_translation/scripts.md
@@ -0,0 +1,148 @@
+# Examples of Training scripts for Non-autoregressive Machine Translation models
+
+### Non-autoregressive Transformer (NAT, Gu et al., 2017)
+Note that the model uses an additional module to perform "length prediction" of the target before generating the whole sequence; `--length-loss-factor` weights the length-prediction loss.
+```bash
+fairseq-train \
+    data-bin/wmt14_en_de_distill \
+    --save-dir checkpoints \
+    --ddp-backend=no_c10d \
+    --task translation_lev \
+    --criterion nat_loss \
+    --arch nonautoregressive_transformer \
+    --noise full_mask \
+    --share-all-embeddings \
+    --optimizer adam --adam-betas '(0.9,0.98)' \
+    --lr 0.0005 --lr-scheduler inverse_sqrt \
+    --min-lr '1e-09' --warmup-updates 10000 \
+    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
+    --dropout 0.3 --weight-decay 0.01 \
+    --decoder-learned-pos \
+    --encoder-learned-pos \
+    --pred-length-offset \
+    --length-loss-factor 0.1 \
+    --apply-bert-init \
+    --log-format 'simple' --log-interval 100 \
+    --fixed-validation-seed 7 \
+    --max-tokens 8000 \
+    --save-interval-updates 10000 \
+    --max-update 300000
+```
+
+### Non-autoregressive Transformer with Iterative Refinement (iNAT, Lee et al., 2018)
+Note that `--train-step` sets the number of refinement iterations used during training, and `--dae-ratio` controls the proportion of denoising auto-encoder training described in the original paper.
+```bash
+fairseq-train \
+    data-bin/wmt14_en_de_distill \
+    --save-dir checkpoints \
+    --ddp-backend=no_c10d \
+    --task translation_lev \
+    --criterion nat_loss \
+    --arch iterative_nonautoregressive_transformer \
+    --noise full_mask \
+    --share-all-embeddings \
+    --optimizer adam --adam-betas '(0.9,0.98)' \
+    --lr 0.0005 --lr-scheduler inverse_sqrt \
+    --min-lr '1e-09' --warmup-updates 10000 \
+    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
+    --dropout 0.3 --weight-decay 0.01 \
+    --decoder-learned-pos \
+    --encoder-learned-pos \
+    --pred-length-offset \
+    --length-loss-factor 0.1 \
+    --train-step 4 \
+    --dae-ratio 0.5 \
+    --stochastic-approx \
+    --apply-bert-init \
+    --log-format 'simple' --log-interval 100 \
+    --fixed-validation-seed 7 \
+    --max-tokens 8000 \
+    --save-interval-updates 10000 \
+    --max-update 300000
+```
+
+### Insertion Transformer (InsT, Stern et al., 2019)
+Note that we need to specify the "slot-loss" (uniform or balanced binary tree) described in the original paper: `--label-tau` controls the temperature of the balanced-tree weighting, and leaving it unset falls back to the uniform loss. The sketch below illustrates this weighting; the full training command follows it.
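+
+To make the balanced-tree weighting concrete, the following sketch mirrors the `NegativeDistanceScore` helper added in this patch (a simplified, illustrative version, not the library API): for one insertion slot with `L` ground-truth words, words closer to the centre of the slot receive larger weights as the temperature `tau` decreases.
+
+```python
+import numpy as np
+
+
+def slot_label_weights(L, tau=None):
+    """Weights over the L ground-truth words of a single insertion slot."""
+    if tau is None:  # uniform slot loss
+        return np.full(L, 1.0 / L)
+    # balanced-tree loss: a softmax over negative distances to the slot centre
+    s = np.array([-abs(L / 2 - i) / tau for i in range(L)])
+    s = np.exp(s - s.max())
+    return s / s.sum()
+
+
+print(slot_label_weights(4))           # uniform: [0.25 0.25 0.25 0.25]
+print(slot_label_weights(4, tau=1.0))  # centre-heavy weights
+```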
+ +```bash +fairseq-train \ + data-bin/wmt14_en_de_distill \ + --save-dir checkpoints \ + --ddp-backend=no_c10d \ + --task translation_lev \ + --criterion nat_loss \ + --arch insertion_transformer \ + --noise random_delete \ + --share-all-embeddings \ + --optimizer adam --adam-betas '(0.9,0.98)' \ + --lr 0.0005 --lr-scheduler inverse_sqrt \ + --min-lr '1e-09' --warmup-updates 10000 \ + --warmup-init-lr '1e-07' --label-smoothing 0.1 \ + --dropout 0.3 --weight-decay 0.01 \ + --decoder-learned-pos \ + --encoder-learned-pos \ + --pred-length-offset \ + --length-loss-factor 0.1 \ + --apply-bert-init \ + --log-format 'simple' --log-interval 100 \ + --fixed-validation-seed 7 \ + --max-tokens 8000 \ + --save-interval-updates 10000 \ + --max-update 300000 +``` + + +### Mask Predict (CMLM, Ghazvininejad et al., 2019) +```bash +fairseq-train \ + data-bin/wmt14_en_de_distill \ + --save-dir checkpoints \ + --ddp-backend=no_c10d \ + --task translation_lev \ + --criterion nat_loss \ + --arch cmlm_transformer \ + --noise random_mask \ + --share-all-embeddings \ + --optimizer adam --adam-betas '(0.9,0.98)' \ + --lr 0.0005 --lr-scheduler inverse_sqrt \ + --min-lr '1e-09' --warmup-updates 10000 \ + --warmup-init-lr '1e-07' --label-smoothing 0.1 \ + --dropout 0.3 --weight-decay 0.01 \ + --decoder-learned-pos \ + --encoder-learned-pos \ + --apply-bert-init \ + --log-format 'simple' --log-interval 100 \ + --fixed-validation-seed 7 \ + --max-tokens 8000 \ + --save-interval-updates 10000 \ + --max-update 300000 +``` + + + + +### Levenshtein Transformer (LevT, Gu et al., 2019) +```bash +fairseq-train \ + data-bin/wmt14_en_de_distill \ + --save-dir checkpoints \ + --ddp-backend=no_c10d \ + --task translation_lev \ + --criterion nat_loss \ + --arch levenshtein_transformer \ + --noise random_delete \ + --share-all-embeddings \ + --optimizer adam --adam-betas '(0.9,0.98)' \ + --lr 0.0005 --lr-scheduler inverse_sqrt \ + --min-lr '1e-09' --warmup-updates 10000 \ + --warmup-init-lr '1e-07' --label-smoothing 0.1 \ + --dropout 0.3 --weight-decay 0.01 \ + --decoder-learned-pos \ + --encoder-learned-pos \ + --apply-bert-init \ + --log-format 'simple' --log-interval 100 \ + --fixed-validation-seed 7 \ + --max-tokens 8000 \ + --save-interval-updates 10000 \ + --max-update 300000 +``` diff --git a/fairseq/clib/libnat/edit_dist.cpp b/fairseq/clib/libnat/edit_dist.cpp new file mode 100644 index 000000000..966e9083b --- /dev/null +++ b/fairseq/clib/libnat/edit_dist.cpp @@ -0,0 +1,222 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include // @manual=//caffe2:torch_extension +#include + +using namespace ::std; + +vector> edit_distance2_with_dp( + vector& x, + vector& y) { + uint32_t lx = x.size(); + uint32_t ly = y.size(); + vector> d(lx + 1, vector(ly + 1)); + for (uint32_t i = 0; i < lx + 1; i++) { + d[i][0] = i; + } + for (uint32_t j = 0; j < ly + 1; j++) { + d[0][j] = j; + } + for (uint32_t i = 1; i < lx + 1; i++) { + for (uint32_t j = 1; j < ly + 1; j++) { + d[i][j] = + min(min(d[i - 1][j], d[i][j - 1]) + 1, + d[i - 1][j - 1] + 2 * (x.at(i - 1) == y.at(j - 1) ? 
0 : 1)); + } + } + return d; +} + +vector> edit_distance2_backtracking( + vector>& d, + vector& x, + vector& y, + uint32_t terminal_symbol) { + vector seq; + vector> edit_seqs(x.size() + 2, vector()); + /* + edit_seqs: + 0~x.size() cell is the insertion sequences + last cell is the delete sequence + */ + + if (x.size() == 0) { + edit_seqs.at(0) = y; + return edit_seqs; + } + + uint32_t i = d.size() - 1; + uint32_t j = d.at(0).size() - 1; + + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; + } + + if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) { + seq.push_back(1); // insert + seq.push_back(y.at(j - 1)); + j--; + } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) { + seq.push_back(2); // delete + seq.push_back(x.at(i - 1)); + i--; + } else { + seq.push_back(3); // keep + seq.push_back(x.at(i - 1)); + i--; + j--; + } + } + + uint32_t prev_op, op, s, word; + prev_op = 0, s = 0; + for (uint32_t k = 0; k < seq.size() / 2; k++) { + op = seq.at(seq.size() - 2 * k - 2); + word = seq.at(seq.size() - 2 * k - 1); + if (prev_op != 1) { + s++; + } + if (op == 1) // insert + { + edit_seqs.at(s - 1).push_back(word); + } else if (op == 2) // delete + { + edit_seqs.at(x.size() + 1).push_back(1); + } else { + edit_seqs.at(x.size() + 1).push_back(0); + } + + prev_op = op; + } + + for (uint32_t k = 0; k < edit_seqs.size(); k++) { + if (edit_seqs[k].size() == 0) { + edit_seqs[k].push_back(terminal_symbol); + } + } + return edit_seqs; +} + +vector> edit_distance2_backtracking_with_delete( + vector>& d, + vector& x, + vector& y, + uint32_t terminal_symbol, + uint32_t deletion_symbol) { + vector seq; + vector> edit_seqs(x.size() + 1, vector()); + /* + edit_seqs: + 0~x.size() cell is the insertion sequences + last cell is the delete sequence + */ + + if (x.size() == 0) { + edit_seqs.at(0) = y; + return edit_seqs; + } + + uint32_t i = d.size() - 1; + uint32_t j = d.at(0).size() - 1; + + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; + } + + if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) { + seq.push_back(1); // insert + seq.push_back(y.at(j - 1)); + j--; + } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) { + seq.push_back(2); // delete + seq.push_back(x.at(i - 1)); + i--; + } else { + seq.push_back(3); // keep + seq.push_back(x.at(i - 1)); + i--; + j--; + } + } + + uint32_t prev_op, op, s, word; + prev_op = 0, s = 0; + for (uint32_t k = 0; k < seq.size() / 2; k++) { + op = seq.at(seq.size() - 2 * k - 2); + word = seq.at(seq.size() - 2 * k - 1); + if (prev_op != 1) { + s++; + } + if (op == 1) // insert + { + edit_seqs.at(s - 1).push_back(word); + } else if (op == 2) // delete + { + edit_seqs.at(s - 1).push_back(deletion_symbol); + } + + prev_op = op; + } + + for (uint32_t k = 0; k < edit_seqs.size(); k++) { + if (edit_seqs.at(k).size() == 0) { + edit_seqs.at(k).push_back(terminal_symbol); + } + } + return edit_seqs; +} + +vector compute_ed2( + vector>& xs, + vector>& ys) { + vector distances(xs.size()); + for (uint32_t i = 0; i < xs.size(); i++) { + vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i)); + distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size()); + } + return distances; +} + +vector>> suggested_ed2_path( + vector>& xs, + vector>& ys, + uint32_t terminal_symbol) { + vector>> seq(xs.size()); + for (uint32_t i = 0; i < xs.size(); i++) { + vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i)); + seq.at(i) = + edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol); + } + return seq; +} + +vector>> 
suggested_ed2_path_with_delete( + vector>& xs, + vector>& ys, + uint32_t terminal_symbol, + uint32_t deletion_symbol) { + vector>> seq(xs.size()); + for (uint32_t i = 0; i < xs.size(); i++) { + vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i)); + seq.at(i) = edit_distance2_backtracking_with_delete( + d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol); + } + return seq; +} + +PYBIND11_MODULE(libnat, m) { + m.def("compute_ed2", &compute_ed2, "compute_ed2"); + m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path"); + m.def( + "suggested_ed2_path_with_delete", + &suggested_ed2_path_with_delete, + "suggested_ed2_path_with_delete"); +} diff --git a/fairseq/criterions/nat_loss.py b/fairseq/criterions/nat_loss.py new file mode 100644 index 000000000..ccb25298f --- /dev/null +++ b/fairseq/criterions/nat_loss.py @@ -0,0 +1,190 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch.nn.functional as F +from fairseq import utils +from torch import Tensor + +from . import FairseqCriterion, register_criterion + + +@register_criterion("nat_loss") +class LabelSmoothedDualImitationCriterion(FairseqCriterion): + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + # fmt: off + parser.add_argument( + '--label-smoothing', + default=0., + type=float, + metavar='D', + help='epsilon for label smoothing, 0 means no label smoothing') + # fmt: on + + def _compute_loss( + self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0 + ): + """ + outputs: batch x len x d_model + targets: batch x len + masks: batch x len + + policy_logprob: if there is some policy + depends on the likelihood score as rewards. + """ + + def mean_ds(x: Tensor, dim=None) -> Tensor: + return ( + x.float().mean().type_as(x) + if dim is None + else x.float().mean(dim).type_as(x) + ) + + if masks is not None: + outputs, targets = outputs[masks], targets[masks] + + logits = F.log_softmax(outputs, dim=-1) + if targets.dim() == 1: + losses = F.nll_loss(logits, targets, reduction="none") + + else: # soft-labels + losses = F.kl_div(logits, targets, reduction="none") + losses = losses.float().sum(-1).type_as(losses) + + nll_loss = mean_ds(losses) + if label_smoothing > 0: + loss = nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing + else: + loss = nll_loss + + loss = loss * factor + return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor} + + def _custom_loss(self, loss, name="loss"): + return {"name": name, "loss": loss, "factor": 1} + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
+ Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + nsentences, ntokens = sample["nsentences"], sample["ntokens"] + + # B x T + src_tokens, src_lengths = ( + sample["net_input"]["src_tokens"], + sample["net_input"]["src_lengths"], + ) + tgt_tokens, prev_output_tokens = sample["target"], sample["prev_target"] + + outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens) + losses = [] + if "mask_ins_out" in outputs: + mask_ins_losses = self._compute_loss( + outputs["mask_ins_out"], + outputs["mask_ins_tgt"], + outputs["mask_ins_mask"], + name="m_ins-loss", + factor=1 if "mask_ins_w" not in outputs else outputs["mask_ins_w"], + ) + losses += [mask_ins_losses] + + if "word_ins_out" in outputs: + word_ins_losses = self._compute_loss( + outputs["word_ins_out"], + outputs["word_ins_tgt"], + outputs["word_ins_mask"], + self.args.label_smoothing, + name="w_ins-loss", + factor=1 if "word_ins_w" not in outputs else outputs["word_ins_w"], + ) + + losses += [word_ins_losses] + nll_loss = word_ins_losses["nll_loss"] + + if "word_del_out" in outputs: + word_del_losses = self._compute_loss( + outputs["word_del_out"], + outputs["word_del_tgt"], + outputs["word_del_mask"], + 0.01, + name="w_del-loss", + factor=1 if "word_del_w" not in outputs else outputs["word_del_w"], + ) + + losses += [word_del_losses] + + if "length_out" in outputs: + length_losses = self._compute_loss( + outputs["length_out"], + outputs["length_tgt"], + name="len-loss", + factor=1 if "length_w" not in outputs else outputs["length_w"], + ) + + losses += [length_losses] + + for w in outputs: + if "-loss" in w: + losses += [self._custom_loss(outputs[w], w)] + + loss = sum(l["loss"] for l in losses) + + # NOTE: as we are summing up per token mlm loss and per sentence nsp loss + # we don't need to use sample_size as denominator for the gradient + # here sample_size is just used for logging + sample_size = 1 + logging_output = { + "loss": utils.item(loss.data) if reduce else loss.data, + "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data, + "ntokens": ntokens, + "nsentences": nsentences, + "sample_size": sample_size, + } + + for l in losses: + logging_output[l["name"]] = ( + utils.item(l["loss"].data / l["factor"]) + if reduce + else l[["loss"]].data / l["factor"] + ) + + return loss, sample_size, logging_output + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + loss = sum(log.get("loss", 0) for log in logging_outputs) + nll_loss = sum(log.get("nll_loss", 0) for log in logging_outputs) + + results = { + "loss": loss / sample_size / math.log(2) if sample_size > 0 else 0.0, + "nll_loss": nll_loss / sample_size / math.log(2) + if sample_size > 0 + else 0.0, + "ntokens": ntokens, + "nsentences": nsentences, + "sample_size": sample_size, + } + + for key in logging_outputs[0]: + if key[-5:] == "-loss": + results[key[:-5]] = ( + sum(log.get(key, 0) for log in logging_outputs) + / sample_size + / math.log(2) + if sample_size > 0 + else 0.0 + ) + + return results diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 417105e50..5d135ba12 100644 --- a/fairseq/data/dictionary.py +++ 
b/fairseq/data/dictionary.py @@ -74,7 +74,10 @@ class Dictionary(object): else: return self[i] - sent = ' '.join(token_string(i) for i in tensor if i != self.eos()) + if hasattr(self, 'bos_index'): + sent = ' '.join(token_string(i) for i in tensor if (i != self.eos()) and (i != self.bos())) + else: + sent = ' '.join(token_string(i) for i in tensor if i != self.eos()) return data_utils.process_bpe_symbol(sent, bpe_symbol) def unk_string(self, escape=False): diff --git a/fairseq/iterative_refinement_generator.py b/fairseq/iterative_refinement_generator.py new file mode 100644 index 000000000..aee488418 --- /dev/null +++ b/fairseq/iterative_refinement_generator.py @@ -0,0 +1,154 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from fairseq.models.model_utils import skip_tensors as _skip + + +class IterativeRefinementGenerator(object): + def __init__(self, + tgt_dict, + eos_penalty=0., + max_iter=10, + max_ratio=2, + decoding_format=None, + retain_dropout=False, + adaptive=True): + """ + Generates translations based on iterative refinement. + + Args: + tgt_dict: target dictionary + eos_penalty: if > 0.0, it penalized early-stopping in decoding + max_iter: maximum number of refinement iterations + max_ratio: generate sequences of maximum length ax, where x is the source length + decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'} + retain_dropout: retaining dropout in the inference + adaptive: decoding with early stop + """ + self.bos = tgt_dict.bos() + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() + self.vocab_size = len(tgt_dict) + self.eos_penalty = eos_penalty + self.max_iter = max_iter + self.max_ratio = max_ratio + self.decoding_format = decoding_format + self.retain_dropout = retain_dropout + self.adaptive = adaptive + + @torch.no_grad() + def generate(self, models, sample, prefix_tokens=None): + + # TODO: model ensemble + assert len(models) == 1, 'only support single model' + model = models[0] + if not self.retain_dropout: + model.eval() + + # TODO: better encoder inputs? 
+ src_tokens = sample['net_input']['src_tokens'] + src_lengths = sample['net_input']['src_lengths'] + bsz, src_len = src_tokens.size() + sent_idxs = torch.arange(bsz, device=src_tokens.device) + + # encoding + encoder_out = model.forward_encoder([src_tokens, src_lengths]) + + # initialize buffers (very model specific, with length prediction or not) + prev_decoder_out = model.initialize_output_tokens( + encoder_out, src_tokens) + prev_out_tokens = prev_decoder_out['output_tokens'].clone() + + finalized = [[] for _ in range(bsz)] + + def is_a_loop(x, y, s, a): + b, l_x, l_y = x.size(0), x.size(1), y.size(1) + if l_x > l_y: + y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1) + s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1) + if a is not None: + a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1) + elif l_x < l_y: + x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1) + return (x == y).all(1), y, s, a + + def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn): + cutoff = prev_out_token.ne(self.pad) + tokens = prev_out_token[cutoff] + scores = prev_out_score[cutoff] + if prev_out_attn is None: + hypo_attn, alignment = None, None + else: + hypo_attn = prev_out_attn[cutoff] + alignment = hypo_attn.max(dim=1)[1] + return { + 'steps': step, + 'tokens': tokens, + 'positional_scores': scores, + 'score': scores.mean(), + 'hypo_attn': hypo_attn, + 'alignment': alignment, + } + + for step in range(self.max_iter + 1): + + decoder_options = { + 'eos_penalty': self.eos_penalty, + 'max_ratio': self.max_ratio, + 'decoding_format': self.decoding_format + } + prev_decoder_out['step'] = step + prev_decoder_out['max_step'] = self.max_iter + 1 + + decoder_out = model.forward_decoder( + prev_decoder_out, encoder_out, **decoder_options + ) + + if self.adaptive: + # terminate if there is a loop + terminated, out_tokens, out_scores, out_attn = is_a_loop( + prev_out_tokens, decoder_out['output_tokens'], + decoder_out['output_scores'], decoder_out['attn']) + decoder_out['output_tokens'] = out_tokens + decoder_out['output_scores'] = out_scores + decoder_out['attn'] = out_attn + + else: + terminated = decoder_out['output_tokens'].new_zeros( + decoder_out['output_tokens'].size(0)).bool() + + if step == self.max_iter: # reach last iteration, terminate + terminated.fill_(1) + + # collect finalized sentences + finalized_idxs = sent_idxs[terminated] + finalized_tokens = decoder_out['output_tokens'][terminated] + finalized_scores = decoder_out['output_scores'][terminated] + finalized_attn = None if decoder_out['attn'] is None else decoder_out['attn'][terminated] + + for i in range(finalized_idxs.size(0)): + finalized[finalized_idxs[i]] = [ + finalized_hypos( + step, + finalized_tokens[i], + finalized_scores[i], + None if finalized_attn is None else finalized_attn[i] + ) + ] + # check if all terminated + if terminated.sum() == terminated.size(0): + break + + # for next step + prev_decoder_out = _skip(decoder_out, ~terminated) + encoder_out = _skip(encoder_out, ~terminated) + sent_idxs = _skip(sent_idxs, ~terminated) + + prev_out_tokens = prev_decoder_out['output_tokens'].clone() + + return finalized diff --git a/fairseq/models/cmlm_transformer.py b/fairseq/models/cmlm_transformer.py new file mode 100644 index 000000000..f76c93fd0 --- /dev/null +++ b/fairseq/models/cmlm_transformer.py @@ -0,0 +1,136 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This file implements: +Ghazvininejad, Marjan, et al. +"Constant-time machine translation with conditional masked language models." +arXiv preprint arXiv:1904.09324 (2019). +""" + +import torch +from fairseq.models import register_model, register_model_architecture +from fairseq.models.nonautoregressive_transformer import NATransformerModel + + +def _skeptical_unmasking(output_scores, output_masks, p): + sorted_index = output_scores.sort(-1)[1] + boundary_len = ( + (output_masks.sum(1, keepdim=True).type_as(output_scores) - 2) * p + ).long() + skeptical_mask = ( + torch.arange(output_masks.size(1), device=output_masks.device)[None, :] + < boundary_len + ) + return skeptical_mask.scatter(1, sorted_index, skeptical_mask) + + +@register_model("cmlm_transformer") +class CMLMNATransformerModel(NATransformerModel): + @staticmethod + def add_args(parser): + NATransformerModel.add_args(parser) + + def forward( + self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs + ): + assert not self.decoder.src_embedding_copy, "do not support embedding copy." + + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + length_out, length_tgt = self.decoder.forward_length_prediction( + encoder_out, tgt_tokens + ) + + word_ins_out, word_ins_tgt, _ = self.decoder( + prev_output_tokens, encoder_out=encoder_out, tgt_tokens=tgt_tokens + ) + word_ins_mask = prev_output_tokens.eq(self.unk) + return { + "word_ins_out": word_ins_out, + "word_ins_tgt": word_ins_tgt, + "word_ins_mask": word_ins_mask, + "length_out": length_out, + "length_tgt": length_tgt, + "length_w": self.decoder.length_loss_factor, + } + + def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs): + + step = decoder_out["step"] + max_step = decoder_out["max_step"] + + output_tokens = decoder_out["output_tokens"] + output_scores = decoder_out["output_scores"] + + # execute the decoder + output_masks = output_tokens.eq(self.unk) + _scores, _tokens = self.decoder( + output_tokens, encoder_out=encoder_out, decoding_format=decoding_format + ) + output_tokens.masked_scatter_(output_masks, _tokens[output_masks]) + output_scores.masked_scatter_(output_masks, _scores[output_masks]) + + # skeptical decoding (depend on the maximum decoding steps.) 
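+        # The re-masking below implements the Mask-Predict schedule: at decoding
+        # step t (0-based) out of max_step iterations, roughly a fraction
+        # 1 - (t + 1) / max_step of the lowest-scoring non-special tokens is set
+        # back to <unk> and re-predicted in the next pass; nothing is re-masked
+        # at the final step.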
+ if (step + 1) < max_step: + skeptical_mask = _skeptical_unmasking( + output_scores, output_tokens.ne(self.pad), 1 - (step + 1) / max_step + ) + + output_tokens.masked_fill_(skeptical_mask, self.unk) + output_scores.masked_fill_(skeptical_mask, 0.0) + + return {"output_tokens": output_tokens, "output_scores": output_scores} + + +@register_model_architecture("cmlm_transformer", "cmlm_transformer") +def base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", True) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.apply_bert_init = getattr(args, "apply_bert_init", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + # --- special arguments --- + args.sg_length_pred = getattr(args, "sg_length_pred", False) + args.pred_length_offset = getattr(args, "pred_length_offset", False) + args.length_loss_factor = getattr(args, "length_loss_factor", 0.1) + args.ngram_predictor = getattr(args, "ngram_predictor", 1) + args.src_embedding_copy = getattr(args, "src_embedding_copy", False) + + +@register_model_architecture("cmlm_transformer", "cmlm_transformer_wmt_en_de") +def iter_nat_wmt_en_de(args): + base_architecture(args) diff --git a/fairseq/models/insertion_transformer.py b/fairseq/models/insertion_transformer.py new file mode 100644 index 000000000..5f5868a55 --- /dev/null +++ b/fairseq/models/insertion_transformer.py @@ -0,0 +1,259 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import numpy as np +import torch +import torch.nn.functional as F +from fairseq import libnat +from fairseq.models import register_model, register_model_architecture +from fairseq.models.levenshtein_transformer import ( + LevenshteinTransformerDecoder, + LevenshteinTransformerModel, +) +from fairseq.models.transformer import Linear, TransformerModel +from fairseq.modules.transformer_sentence_encoder import init_bert_params + + +class NegativeDistanceScore(object): + def __init__(self): + + # pre-compute some values + self.scores = {} + + self.scores[0.5] = self.compute_score_full(50, 0.5) + self.scores[1.0] = self.compute_score_full(50, 1.0) + self.scores[2.0] = self.compute_score_full(50, 2.0) + + def __call__(self, i, L, tau): + if (tau is None) or (tau > 1000): + return 1 / L + + if tau in self.scores: + if L < self.scores[tau].shape[0]: + return self.scores[tau][L - 1, i] + return self.compute_score(L, tau)[i] + + def compute_score(self, L, tau): + s = np.array([-abs(L / 2 - i) / tau for i in range(L)]) + s = np.exp(s - s.max()) + return s / s.sum() + + def compute_score_full(self, L, tau): + s = -abs(np.arange(0, L - 1)[:, None] / 2 - np.arange(L)[None, :]) / tau + s = np.tril(s, 0) + np.triu(s - float("inf"), 1) + s = np.exp(s - s.max(1, keepdims=True)) + return s / s.sum(1, keepdims=True) + + +neg_scorer = NegativeDistanceScore() + + +def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, tau=None): + B = in_tokens.size(0) + T = in_tokens.size(1) + V = vocab_size + + with torch.cuda.device_of(in_tokens): + in_tokens_list = [ + [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) + ] + out_tokens_list = [ + [t for t in s if t != padding_idx] + for i, s in enumerate(out_tokens.tolist()) + ] + + full_labels = libnat.suggested_ed2_path( + in_tokens_list, out_tokens_list, padding_idx + ) + insert_labels = [a[:-1] for a in full_labels] + + # numericalize1 + insert_label_tensors = in_tokens.new_zeros(B * (T - 1) * V).float() + insert_index, insert_labels = zip( + *[ + (w + (j + i * (T - 1)) * V, neg_scorer(k, len(label), tau)) + for i, labels in enumerate(insert_labels) + for j, label in enumerate(labels[1:-1]) + for k, w in enumerate(label) + ] + ) # HACK 1:-1 + insert_index, insert_labels = [ + torch.tensor(list(a), device=in_tokens.device) + for a in [insert_index, insert_labels] + ] + insert_label_tensors.scatter_(0, insert_index.long(), insert_labels) + insert_label_tensors = insert_label_tensors.view(B, T - 1, V) + + return insert_label_tensors + + +def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, padding_idx): + + padding_masks = in_tokens[:, 1:].eq(padding_idx) + word_ins_scores.masked_fill_(padding_masks, 0.0) + word_ins_pred.masked_fill_(padding_masks, padding_idx) + + in_coords = torch.arange(in_tokens.size(1), device=in_tokens.device) + in_coords = in_coords.unsqueeze(0).repeat(in_tokens.size(0), 1).type_as(in_scores) + + # shift all padding predictions to infinite + out_coords = (in_coords[:, 1:] - 0.5).masked_fill( + word_ins_pred.eq(padding_idx), float("inf") + ) + out_coords = torch.cat([in_coords, out_coords], 1).sort(-1)[1] + out_tokens = torch.cat([in_tokens, word_ins_pred], 1).gather(1, out_coords) + out_scores = torch.cat([in_scores, word_ins_scores], 1).gather(1, out_coords) + return out_tokens, out_scores + + +@register_model("insertion_transformer") +class InsertionTransformerModel(LevenshteinTransformerModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + 
@staticmethod + def add_args(parser): + TransformerModel.add_args(parser) + parser.add_argument( + "--apply-bert-init", + action="store_true", + help="use custom param initialization for BERT", + ) + parser.add_argument("--label-tau", default=None, type=float) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + decoder = InsertionTransformerDecoder(args, tgt_dict, embed_tokens) + if getattr(args, "apply_bert_init", False): + decoder.apply(init_bert_params) + return decoder + + def forward( + self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs + ): + + assert tgt_tokens is not None, "forward function only supports training." + + # encoding + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + + # generate training labels for insertion + word_ins_out = self.decoder.forward_word_ins( + prev_output_tokens, encoder_out=encoder_out + ) + word_ins_tgt = _get_ins_targets( + prev_output_tokens, + tgt_tokens, + self.pad, + self.unk, + len(self.tgt_dict), + tau=self.decoder.label_tau, + ).type_as(word_ins_out) + word_ins_masks = prev_output_tokens[:, 1:].ne(self.pad) + + return { + "word_ins_out": word_ins_out, + "word_ins_tgt": word_ins_tgt, + "word_ins_mask": word_ins_masks, + } + + def forward_decoder( + self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs + ): + + output_tokens = decoder_out["output_tokens"] + output_scores = decoder_out["output_scores"] + # TODO: decoding for InsertionTransformer + word_ins_out = self.decoder.forward_word_ins( + output_tokens, encoder_out=encoder_out + ) + word_ins_score = F.log_softmax(word_ins_out, 2) + if eos_penalty > 0.0: + word_ins_score[:, :, self.pad] -= eos_penalty + word_ins_score, word_ins_pred = word_ins_score.max(-1) + output_tokens, output_scores = _apply_ins_words( + output_tokens, output_scores, word_ins_pred, word_ins_score, self.pad + ) + + # delete some unnecessary paddings + cut_off = output_tokens.ne(self.pad).sum(1).max() + output_tokens = output_tokens[:, :cut_off] + output_scores = output_scores[:, :cut_off] + return {"output_tokens": output_tokens, "output_scores": output_scores} + + +class InsertionTransformerDecoder(LevenshteinTransformerDecoder): + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + # use the TransformerDecoder's __init__ + super(LevenshteinTransformerDecoder, self).__init__( + args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn + ) + + self.dictionary = dictionary + self.bos = dictionary.bos() + self.unk = dictionary.unk() + self.eos = dictionary.eos() + self.pool_out = Linear(self.output_embed_dim * 2, self.output_embed_dim) + + self.label_tau = getattr(args, "label_tau", None) + + def forward_word_ins(self, prev_output_tokens, encoder_out=None): + features, _ = self.extract_features(prev_output_tokens, encoder_out=encoder_out) + features = self.pool_out( + torch.cat([features[:, :-1, :], features[:, 1:, :]], 2) + ) + return self.output_layer(features) + + def forward_mask_ins(self, *args, **kwargs): + raise NotImplementedError + + def forward_word_del(self, *args, **kwargs): + raise NotImplementedError + + def forward_word_del_mask_ins(self, *args, **kwargs): + raise NotImplementedError + + +@register_model_architecture("insertion_transformer", "insertion_transformer") +def base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, 
"encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.apply_bert_init = getattr(args, "apply_bert_init", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + # special for insertion transformer + args.label_tau = getattr(args, "label_tau", None) diff --git a/fairseq/models/iterative_nonautoregressive_transformer.py b/fairseq/models/iterative_nonautoregressive_transformer.py new file mode 100644 index 000000000..73585db35 --- /dev/null +++ b/fairseq/models/iterative_nonautoregressive_transformer.py @@ -0,0 +1,196 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from fairseq.models import register_model, register_model_architecture +from fairseq.models.nonautoregressive_transformer import NATransformerModel + + +def _sequential_poisoning(s, V, beta=0.33, bos=2, eos=3, pad=1): + # s: input batch + # V: vocabulary size + rand_words = torch.randint(low=4, high=V, size=s.size(), device=s.device) + choices = torch.rand(size=s.size(), device=s.device) + choices.masked_fill_((s == pad) | (s == bos) | (s == eos), 1) + + replace = choices < beta / 3 + repeat = (choices >= beta / 3) & (choices < beta * 2 / 3) + swap = (choices >= beta * 2 / 3) & (choices < beta) + safe = choices >= beta + + for i in range(s.size(1) - 1): + rand_word = rand_words[:, i] + next_word = s[:, i + 1] + self_word = s[:, i] + + replace_i = replace[:, i] + swap_i = swap[:, i] & (next_word != 3) + repeat_i = repeat[:, i] & (next_word != 3) + safe_i = safe[:, i] | ((next_word == 3) & (~replace_i)) + + s[:, i] = ( + self_word * (safe_i | repeat_i).long() + + next_word * swap_i.long() + + rand_word * replace_i.long() + ) + s[:, i + 1] = ( + next_word * (safe_i | replace_i).long() + + self_word * (swap_i | repeat_i).long() + ) + return s + + +def gumbel_noise(input, TINY=1e-8): + return input.new_zeros(*input.size()).uniform_().add_( + TINY).log_().neg_().add_(TINY).log_().neg_() + + +@register_model("iterative_nonautoregressive_transformer") +class IterNATransformerModel(NATransformerModel): + @staticmethod + def add_args(parser): + NATransformerModel.add_args(parser) + parser.add_argument("--train-step", type=int, + help="number of refinement iterations during training") + parser.add_argument("--dae-ratio", type=float, + help="the probability of switching to the denoising auto-encoder loss") + parser.add_argument("--stochastic-approx", action="store_true", + help="sampling from the decoder as the inputs for next iteration") + + @classmethod + def build_model(cls, args, task): + model = super().build_model(args, task) + model.train_step = getattr(args, "train_step", 4) + model.dae_ratio = getattr(args, "dae_ratio", 0.5) + model.stochastic_approx = getattr(args, "stochastic_approx", False) + return model + + def forward( + self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs + ): + + B, T = prev_output_tokens.size() + + # encoding + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + length_out, length_tgt = self.decoder.forward_length_prediction( + encoder_out, tgt_tokens + ) + word_ins_outs, word_ins_tgts, word_ins_masks = [], [], [] + for t in range(self.train_step): + word_ins_out, word_ins_tgt, word_ins_mask = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + tgt_tokens=tgt_tokens, + step=t, + ) + + word_ins_outs.append(word_ins_out) + word_ins_tgts.append(word_ins_tgt) + word_ins_masks.append(word_ins_mask) + + if t < (self.train_step - 1): + # prediction for next iteration + if self.stochastic_approx: + word_ins_prediction = ( + word_ins_out + gumbel_noise(word_ins_out) + ).max(-1)[1] + else: + word_ins_prediction = word_ins_out.max(-1)[1] + + prev_output_tokens = prev_output_tokens.masked_scatter( + word_ins_mask, word_ins_prediction[word_ins_mask] + ) + + if self.dae_ratio > 0: + # we do not perform denoising for the first iteration + corrputed = ( + torch.rand(size=(B,), device=prev_output_tokens.device) + < self.dae_ratio + ) + corrputed_tokens = _sequential_poisoning( + tgt_tokens[corrputed], + len(self.tgt_dict), + 0.33, + self.bos, + self.eos, + self.pad, + ) + prev_output_tokens[corrputed] = 
corrputed_tokens + + # concat everything + word_ins_out = torch.cat(word_ins_outs, 0) + word_ins_tgt = torch.cat(word_ins_tgts, 0) + word_ins_mask = torch.cat(word_ins_masks, 0) + + return { + "word_ins_out": word_ins_out, + "word_ins_tgt": word_ins_tgt, + "word_ins_mask": word_ins_mask, + "length_out": length_out, + "length_tgt": length_tgt, + "length_w": self.decoder.length_loss_factor, + } + + +@register_model_architecture( + "iterative_nonautoregressive_transformer", "iterative_nonautoregressive_transformer" +) +def base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.apply_bert_init = getattr(args, "apply_bert_init", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + # --- special arguments --- + args.sg_length_pred = getattr(args, "sg_length_pred", False) + args.pred_length_offset = getattr(args, "pred_length_offset", False) + args.length_loss_factor = getattr(args, "length_loss_factor", 0.1) + args.ngram_predictor = getattr(args, "ngram_predictor", 1) + args.src_embedding_copy = getattr(args, "src_embedding_copy", False) + + args.train_step = getattr(args, "train_step", 4) + args.dae_ratio = getattr(args, "dae_ratio", 0.5) + args.stochastic_approx = getattr(args, "stochastic_approx", False) + + +@register_model_architecture( + "iterative_nonautoregressive_transformer", + "iterative_nonautoregressive_transformer_wmt_en_de", +) +def iter_nat_wmt_en_de(args): + base_architecture(args) diff --git a/fairseq/models/levenshtein_transformer.py b/fairseq/models/levenshtein_transformer.py new file mode 100644 index 000000000..876bf01a0 --- /dev/null +++ b/fairseq/models/levenshtein_transformer.py @@ -0,0 
+1,595 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from fairseq import libnat +from fairseq.models import register_model, register_model_architecture +from fairseq.models.model_utils import fill_tensors as _fill, skip_tensors as _skip +from fairseq.models.transformer import ( + Embedding, + TransformerDecoder, + TransformerEncoder, + TransformerModel, +) +from fairseq.modules.transformer_sentence_encoder import init_bert_params + + +def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx): + in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1) + + with torch.cuda.device_of(in_tokens): + in_tokens_list = [ + [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) + ] + out_tokens_list = [ + [t for t in s if t != padding_idx] + for i, s in enumerate(out_tokens.tolist()) + ] + + full_labels = libnat.suggested_ed2_path( + in_tokens_list, out_tokens_list, padding_idx + ) + mask_inputs = [ + [len(c) if c[0] != padding_idx else 0 for c in a[:-1]] for a in full_labels + ] + + # generate labels + masked_tgt_masks = [] + for mask_input in mask_inputs: + mask_label = [] + for beam_size in mask_input[1:-1]: # HACK 1:-1 + mask_label += [0] + [1 for _ in range(beam_size)] + masked_tgt_masks.append( + mask_label + [0 for _ in range(out_seq_len - len(mask_label))] + ) + mask_ins_targets = [ + mask_input[1:-1] + [0 for _ in range(in_seq_len - 1 - len(mask_input[1:-1]))] + for mask_input in mask_inputs + ] + + # transform to tensor + masked_tgt_masks = torch.tensor( + masked_tgt_masks, device=out_tokens.device + ).bool() + mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device) + masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx) + return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets + + +def _get_del_targets(in_tokens, out_tokens, padding_idx): + out_seq_len = out_tokens.size(1) + + with torch.cuda.device_of(in_tokens): + in_tokens_list = [ + [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) + ] + out_tokens_list = [ + [t for t in s if t != padding_idx] + for i, s in enumerate(out_tokens.tolist()) + ] + + full_labels = libnat.suggested_ed2_path( + in_tokens_list, out_tokens_list, padding_idx + ) + word_del_targets = [b[-1] for b in full_labels] + word_del_targets = [ + labels + [0 for _ in range(out_seq_len - len(labels))] + for labels in word_del_targets + ] + + # transform to tensor + word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device) + return word_del_targets + + +def _get_del_ins_targets(in_tokens, out_tokens, padding_idx): + in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1) + + with torch.cuda.device_of(in_tokens): + in_tokens_list = [ + [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) + ] + out_tokens_list = [ + [t for t in s if t != padding_idx] + for i, s in enumerate(out_tokens.tolist()) + ] + + full_labels = libnat.suggested_ed2_path( + in_tokens_list, out_tokens_list, padding_idx + ) + + word_del_targets = [b[-1] for b in full_labels] + word_del_targets = [ + labels + [0 for _ in range(out_seq_len - len(labels))] + for labels in word_del_targets + ] + + mask_inputs = [ + [len(c) if c[0] != padding_idx else 0 for c in a[:-1]] for a in full_labels + ] + mask_ins_targets = [ + mask_input[1:-1] + [0 for _ in range(in_seq_len - 1 - 
len(mask_input[1:-1]))] + for mask_input in mask_inputs + ] + + # transform to tensor + mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device) + word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device) + return word_del_targets, mask_ins_targets + + +def _apply_ins_masks( + in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx +): + + in_masks = in_tokens.ne(padding_idx) + in_lengths = in_masks.sum(1) + + # HACK: hacky way to shift all the paddings to eos first. + in_tokens.masked_fill_(~in_masks, eos_idx) + mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0) + + out_lengths = in_lengths + mask_ins_pred.sum(1) + out_max_len = out_lengths.max() + out_masks = ( + torch.arange(out_max_len, device=out_lengths.device)[None, :] + < out_lengths[:, None] + ) + + reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1) + out_tokens = ( + in_tokens.new_zeros(in_tokens.size(0), out_max_len) + .fill_(padding_idx) + .masked_fill_(out_masks, unk_idx) + ) + out_tokens[:, 0] = in_tokens[:, 0] + out_tokens.scatter_(1, reordering, in_tokens[:, 1:]) + + out_scores = None + if in_scores is not None: + in_scores.masked_fill_(~in_masks, 0) + out_scores = in_scores.new_zeros(*out_tokens.size()) + out_scores[:, 0] = in_scores[:, 0] + out_scores.scatter_(1, reordering, in_scores[:, 1:]) + + return out_tokens, out_scores + + +def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx): + word_ins_masks = in_tokens.eq(unk_idx) + out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks]) + + if in_scores is not None: + out_scores = in_scores.masked_scatter( + word_ins_masks, word_ins_scores[word_ins_masks] + ) + else: + out_scores = None + + return out_tokens, out_scores + + +def _apply_del_words( + in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx +): + # apply deletion to a tensor + in_masks = in_tokens.ne(padding_idx) + bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx) + + max_len = in_tokens.size(1) + word_del_pred.masked_fill_(~in_masks, 1) + word_del_pred.masked_fill_(bos_eos_masks, 0) + + reordering = ( + torch.arange(max_len, device=in_tokens.device)[None, :] + .expand_as(in_tokens) + .contiguous() + .masked_fill_(word_del_pred, max_len) + .sort(1)[1] + ) + + out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering) + + out_scores = None + if in_scores is not None: + out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering) + + out_attn = None + if in_attn is not None: + _mask = word_del_pred[:, :, None].expand_as(in_attn) + _reordering = reordering[:, :, None].expand_as(in_attn) + out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering) + + return out_tokens, out_scores, out_attn + + +@register_model("levenshtein_transformer") +class LevenshteinTransformerModel(TransformerModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + self.tgt_dict = decoder.dictionary + self.bos = decoder.dictionary.bos() + self.eos = decoder.dictionary.eos() + self.pad = decoder.dictionary.pad() + self.unk = decoder.dictionary.unk() + + @staticmethod + def add_args(parser): + TransformerModel.add_args(parser) + parser.add_argument( + "--apply-bert-init", + action="store_true", + help="use custom param initialization for BERT", + ) + parser.add_argument( + "--early-exit", + default="6,6,6", + type=str, + help="number of decoder layers before mask_ins, word_ins and word_del heads", + ) + + @classmethod + def 
build_decoder(cls, args, tgt_dict, embed_tokens):
+        decoder = LevenshteinTransformerDecoder(args, tgt_dict, embed_tokens)
+        if getattr(args, "apply_bert_init", False):
+            decoder.apply(init_bert_params)
+        return decoder
+
+    @classmethod
+    def build_encoder(cls, args, src_dict, embed_tokens):
+        encoder = TransformerEncoder(args, src_dict, embed_tokens)
+        if getattr(args, "apply_bert_init", False):
+            encoder.apply(init_bert_params)
+        return encoder
+
+    def forward(
+        self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
+    ):
+
+        assert tgt_tokens is not None, "the forward function only supports training."
+
+        # encoding
+        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
+
+        # generate training labels for insertion
+        masked_tgt_masks, masked_tgt_tokens, mask_ins_targets = _get_ins_targets(
+            prev_output_tokens, tgt_tokens, self.pad, self.unk
+        )
+        mask_ins_targets = mask_ins_targets.clamp(min=0, max=255)  # for safe prediction
+        mask_ins_masks = prev_output_tokens[:, 1:].ne(self.pad)
+
+        mask_ins_out, _ = self.decoder.forward_mask_ins(
+            prev_output_tokens, encoder_out=encoder_out
+        )
+        word_ins_out, _ = self.decoder.forward_word_ins(
+            masked_tgt_tokens, encoder_out=encoder_out
+        )
+
+        # make online prediction
+        word_predictions = F.log_softmax(word_ins_out, dim=-1).max(2)[1]
+        word_predictions.masked_scatter_(
+            ~masked_tgt_masks, tgt_tokens[~masked_tgt_masks]
+        )
+
+        # generate training labels for deletion
+        word_del_targets = _get_del_targets(word_predictions, tgt_tokens, self.pad)
+        word_del_out, _ = self.decoder.forward_word_del(
+            word_predictions, encoder_out)
+
+        return {
+            "mask_ins_out": mask_ins_out,
+            "mask_ins_tgt": mask_ins_targets,
+            "mask_ins_mask": mask_ins_masks,
+            "word_ins_out": word_ins_out,
+            "word_ins_tgt": tgt_tokens,
+            "word_ins_mask": masked_tgt_masks,
+            "word_del_out": word_del_out,
+            "word_del_tgt": word_del_targets,
+            "word_del_mask": word_predictions.ne(self.pad),
+        }
+
+    def forward_encoder(self, encoder_inputs):
+        return self.encoder(*encoder_inputs)
+
+    def forward_decoder(
+        self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
+    ):
+
+        output_tokens = decoder_out["output_tokens"]
+        output_scores = decoder_out["output_scores"]
+        attn = decoder_out["attn"]
+
+        if max_ratio is None:
+            max_lens = output_tokens.new(output_tokens.size(0)).fill_(255)
+        else:
+            max_lens = (
+                (~encoder_out["encoder_padding_mask"]).sum(1) * max_ratio
+            ).clamp(min=10)
+
+        # delete words
+        # do not delete when only bos/eos (and padding) are left
+        can_del_word = output_tokens.ne(self.pad).sum(1) > 2
+        if can_del_word.sum() != 0:  # skip the deletion head if nothing can be deleted
+            word_del_out, word_del_attn = self.decoder.forward_word_del(
+                _skip(output_tokens, can_del_word), _skip(encoder_out, can_del_word)
+            )
+            word_del_score = F.log_softmax(word_del_out, 2)
+            word_del_pred = word_del_score.max(-1)[1].bool()
+
+            _tokens, _scores, _attn = _apply_del_words(
+                output_tokens[can_del_word],
+                output_scores[can_del_word],
+                word_del_attn,
+                word_del_pred,
+                self.pad,
+                self.bos,
+                self.eos,
+            )
+            output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
+            output_scores = _fill(output_scores, can_del_word, _scores, 0)
+            attn = _fill(attn, can_del_word, _attn, 0.)
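The deletion and insertion steps in `forward_decoder` rely on `skip_tensors` (`_skip`) and `fill_tensors` (`_fill`) from `fairseq/models/model_utils.py` to run each head only on the sentences that can still change and then scatter the refined rows back into the full batch. A minimal, self-contained sketch of that slice-and-scatter pattern, using toy tensors and hand-picked values rather than the real helpers:

import torch

pad = 1
tokens = torch.tensor([[0, 5, 6, 7, 2],
                       [0, 8, 2, pad, pad]])   # B x T with 0 = bos, 2 = eos, 1 = pad
can_edit = tokens.ne(pad).sum(1) > 3           # only row 0 has something to delete

active = tokens[can_edit]                      # "skip": run the head on this subset only
refined = torch.tensor([[0, 5, 7, 2]])         # pretend the deletion head dropped token 6

out = tokens.clone()                           # "fill": write the refined rows back,
out[can_edit] = pad                            # re-padding since lengths may differ
out[can_edit, :refined.size(1)] = refined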
+ + # insert placeholders + can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens + if can_ins_mask.sum() != 0: + mask_ins_out, _ = self.decoder.forward_mask_ins( + _skip(output_tokens, can_ins_mask), _skip(encoder_out, can_ins_mask) + ) + mask_ins_score = F.log_softmax(mask_ins_out, 2) + if eos_penalty > 0.0: + mask_ins_score[:, :, 0] -= eos_penalty + mask_ins_pred = mask_ins_score.max(-1)[1] + mask_ins_pred = torch.min( + mask_ins_pred, max_lens[:, None].expand_as(mask_ins_pred) + ) + + _tokens, _scores = _apply_ins_masks( + output_tokens[can_ins_mask], + output_scores[can_ins_mask], + mask_ins_pred, + self.pad, + self.unk, + self.eos, + ) + output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad) + output_scores = _fill(output_scores, can_ins_mask, _scores, 0) + + # insert words + can_ins_word = output_tokens.eq(self.unk).sum(1) > 0 + if can_ins_word.sum() != 0: + word_ins_out, word_ins_attn = self.decoder.forward_word_ins( + _skip(output_tokens, can_ins_word), _skip(encoder_out, can_ins_word) + ) + word_ins_score = F.log_softmax(word_ins_out, 2) + word_ins_pred = word_ins_score.max(-1)[1] + + _tokens, _scores = _apply_ins_words( + output_tokens[can_ins_word], + output_scores[can_ins_word], + word_ins_pred, + word_ins_score, + self.unk, + ) + + output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad) + output_scores = _fill(output_scores, can_ins_word, _scores, 0) + attn = _fill(attn, can_ins_word, word_ins_attn, 0.) + + # delete some unnecessary paddings + cut_off = output_tokens.ne(self.pad).sum(1).max() + output_tokens = output_tokens[:, :cut_off] + output_scores = output_scores[:, :cut_off] + attn = None if attn is None else attn[:, :cut_off, :] + return { + "output_tokens": output_tokens, + "output_scores": output_scores, + "attn": attn, + } + + def initialize_output_tokens(self, encoder_out, src_tokens): + initial_output_tokens = src_tokens.new_zeros(src_tokens.size(0), 2) + initial_output_tokens[:, 0] = self.bos + initial_output_tokens[:, 1] = self.eos + + initial_output_scores = initial_output_tokens.new_zeros( + *initial_output_tokens.size() + ).type_as(encoder_out["encoder_out"]) + return { + "output_tokens": initial_output_tokens, + "output_scores": initial_output_scores, + "attn": None, + } + + +class LevenshteinTransformerDecoder(TransformerDecoder): + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + super().__init__( + args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn + ) + self.dictionary = dictionary + self.bos = dictionary.bos() + self.unk = dictionary.unk() + self.eos = dictionary.eos() + + self.embed_mask_ins = Embedding(256, self.output_embed_dim * 2, None) + self.embed_word_del = Embedding(2, self.output_embed_dim, None) + # del_word, ins_mask, ins_word + self.early_exit = [int(i) for i in args.early_exit.split(',')] + assert len(self.early_exit) == 3 + + def extract_features( + self, prev_output_tokens, encoder_out=None, early_exit=None, **unused + ): + """ + Similar to *forward* but only return features. 
+ + Inputs: + prev_output_tokens: Tensor(B, T) + encoder_out: a dictionary of hidden states and masks + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + the LevenshteinTransformer decoder has full-attention to all generated tokens + """ + # embed positions + positions = ( + self.embed_positions(prev_output_tokens) + if self.embed_positions is not None + else None + ) + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + attn = None + inner_states = [x] + + # decoder layers + decoder_padding_mask = prev_output_tokens.eq(self.padding_idx) + for i, layer in enumerate(self.layers): + + # early exit from the decoder. + if (early_exit is not None) and (i >= early_exit): + break + + x, attn = layer( + x, + encoder_out["encoder_out"] if encoder_out is not None else None, + encoder_out["encoder_padding_mask"] + if encoder_out is not None + else None, + self_attn_mask=None, + self_attn_padding_mask=decoder_padding_mask, + ) + inner_states.append(x) + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + return x, {"attn": attn, "inner_states": inner_states} + + def forward_mask_ins(self, prev_output_tokens, encoder_out=None): + features, extra = self.extract_features( + prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[1] + ) + features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2) + return F.linear(features_cat, self.embed_mask_ins.weight), extra['attn'] + + def forward_word_ins(self, prev_output_tokens, encoder_out=None): + features, extra = self.extract_features( + prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[2] + ) + return self.output_layer(features), extra['attn'] + + def forward_word_del(self, prev_output_tokens, encoder_out=None): + features, extra = self.extract_features( + prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[0] + ) + return F.linear(features, self.embed_word_del.weight), extra['attn'] + + def forward_word_del_mask_ins(self, prev_output_tokens, encoder_out=None): + # merge the word-deletion and mask insertion into one operation, + assert self.early_exit[0] == self.early_exit[1], "must the same depth." 
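The mask-insertion head scores how many placeholders to insert between each pair of adjacent positions, which is why `forward_mask_ins` concatenates the features at positions i and i+1 before projecting them with `embed_mask_ins` onto 256 possible insertion counts. A small shape-check sketch with toy dimensions (the batch size, length, and embedding dim below are illustrative):

import torch
import torch.nn.functional as F

B, T, C = 2, 5, 8                                  # toy batch, length, embedding dim
features = torch.randn(B, T, C)                    # stand-in for the decoder features
embed_mask_ins = torch.nn.Embedding(256, C * 2)    # same layout as in the decoder

pairs = torch.cat([features[:, :-1, :], features[:, 1:, :]], dim=2)   # B x (T-1) x 2C
mask_ins_logits = F.linear(pairs, embed_mask_ins.weight)              # B x (T-1) x 256
assert mask_ins_logits.shape == (B, T - 1, 256)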
+ features, extra = self.extract_features( + prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[2] + ) + features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2) + f_word_del = F.linear(features, self.embed_word_del.weight) + f_mask_ins = F.linear(features_cat, self.embed_mask_ins.weight) + return f_word_del, f_mask_ins, extra['attn'] + + +@register_model_architecture("levenshtein_transformer", "levenshtein_transformer") +def base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.apply_bert_init = getattr(args, "apply_bert_init", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + args.early_exit = getattr(args, "early_exit", "(6, 6, 6)") + + +@register_model_architecture( + "levenshtein_transformer", "levenshtein_transformer_wmt_en_de" +) +def levenshtein_transformer_wmt_en_de(args): + base_architecture(args) + + +# similar parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017) +@register_model_architecture( + "levenshtein_transformer", "levenshtein_transformer_vaswani_wmt_en_de_big" +) +def levenshtein_transformer_vaswani_wmt_en_de_big(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, 
"decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.3) + base_architecture(args) + + +# default parameters used in tensor2tensor implementation +@register_model_architecture( + "levenshtein_transformer", "levenshtein_transformer_wmt_en_de_big" +) +def levenshtein_transformer_wmt_en_de_big_t2t(args): + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.1) + levenshtein_transformer_vaswani_wmt_en_de_big(args) diff --git a/fairseq/models/model_utils.py b/fairseq/models/model_utils.py new file mode 100644 index 000000000..8217731c9 --- /dev/null +++ b/fairseq/models/model_utils.py @@ -0,0 +1,62 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def skip_tensors(x, mask): + """ + Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors. + """ + if isinstance(x, int): + return x + + if x is None: + return None + + if isinstance(x, torch.Tensor): + if x.size(0) == mask.size(0): + return x[mask] + elif x.size(1) == mask.size(0): + return x[:, mask] + + if isinstance(x, list): + return [skip_tensors(x_i, mask) for x_i in x] + + if isinstance(x, dict): + return {k: skip_tensors(v, mask) for k, v in x.items()} + + raise NotImplementedError + + +def fill_tensors(x, mask, y, padding_idx): + """ + Filling tensor x with y at masked positions (dim=0). + """ + if x is None: + return y + assert x.dim() == y.dim() and mask.size(0) == x.size(0) + assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2)) + n_selected = mask.sum() + assert n_selected == y.size(0) + + if n_selected == x.size(0): + return y + + if x.size(1) < y.size(1): + dims = [x.size(0), y.size(1) - x.size(1)] + if x.dim() == 3: + dims.append(x.size(2)) + x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1) + x[mask] = y + elif x.size(1) > y.size(1): + x[mask] = padding_idx + if x.dim() == 2: + x[mask, :y.size(1)] = y + else: + x[mask, :y.size(1), :] = y + else: + x[mask] = y + return x diff --git a/fairseq/models/nonautoregressive_transformer.py b/fairseq/models/nonautoregressive_transformer.py new file mode 100644 index 000000000..d45a5b443 --- /dev/null +++ b/fairseq/models/nonautoregressive_transformer.py @@ -0,0 +1,640 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.models import register_model, register_model_architecture +from fairseq.models.transformer import ( + Embedding, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerModel, +) +from fairseq.modules import MultiheadAttention +from fairseq.modules.transformer_sentence_encoder import init_bert_params + + +def _mean_pooling(enc_feats, src_masks): + # enc_feats: T x B x C + # src_masks: B x T or None + if src_masks is None: + enc_feats = enc_feats.mean(0) + else: + src_masks = (~src_masks).transpose(0, 1).type_as(enc_feats) + enc_feats = ( + (enc_feats / src_masks.sum(0)[None, :, None]) * src_masks[:, :, None] + ).sum(0) + return enc_feats + + +def _argmax(x, dim): + return (x == x.max(dim, keepdim=True)[0]).type_as(x) + + +def _dynamic_programming(tokens, scores): + N, B, T = tokens.size() + cum_scores = scores[:, :, 0].clone() # N x B + cum_choice = tokens.new_zeros(B, T) + + # forward + for t in range(T - 1): + score, choice = cum_scores.max(0) + cum_choice[:, t] = choice + cum_scores[0] = score + scores[0, :, t + 1] + cum_scores[1:] = cum_scores[:-1] + scores[1:, :, t + 1] + + # back-tracking + end_score, end_choice = cum_scores.max(0) + cum_choice[:, T - 1] = end_choice + for t in range(T - 2, -1, -1): + is_start = (cum_choice[:, t + 1] == 0).type_as(cum_choice) + cum_choice[:, t] = (cum_choice[:, t + 1] - 1) * ~is_start + cum_choice[ + :, t + ] * is_start + + # finalize the prediction + tokens = tokens.gather(0, cum_choice.unsqueeze(0)).squeeze(0) + scores = scores.gather(0, cum_choice.unsqueeze(0)).squeeze(0) + return scores, tokens + + +def _beam_search(tokens, scores, W=None): + N, B, T = tokens.size() + + if (W is None) or (W > N): + W = N + + +def _uniform_assignment(src_lens, trg_lens): + max_trg_len = trg_lens.max() + steps = (src_lens.float() - 1) / (trg_lens.float() - 1) # step-size + # max_trg_len + index_t = torch.arange(max_trg_len, device=trg_lens.device).float() + index_t = steps[:, None] * index_t[None, :] # batch_size X max_trg_len + index_t = torch.round(index_t).long().detach() + return index_t + + +@register_model("nonautoregressive_transformer") +class NATransformerModel(TransformerModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + self.tgt_dict = decoder.dictionary + self.bos = decoder.dictionary.bos() + self.eos = decoder.dictionary.eos() + self.pad = decoder.dictionary.pad() + self.unk = decoder.dictionary.unk() + + @staticmethod + def add_args(parser): + TransformerModel.add_args(parser) + parser.add_argument( + "--apply-bert-init", + action="store_true", + help="use custom param initialization for BERT", + ) + + # length prediction + parser.add_argument("--src-embedding-copy", action="store_true", + help="copy encoder word embeddings as the initial input of the decoder") + parser.add_argument("--pred-length-offset", action="store_true", + help="predicting the length difference between the target and source sentences") + parser.add_argument("--sg-length-pred", action="store_true", + help="stop the gradients back-propagated from the length predictor") + parser.add_argument("--length-loss-factor", type=float, + help="weights on the length prediction loss") + + # n-gram predictor + parser.add_argument( + "--ngram-predictor", + nargs="?", + const=4, + default=1, + type=int, + help="adding an additional n-gram predictor.", + ) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + decoder = 
NATransformerDecoder(args, tgt_dict, embed_tokens) + if getattr(args, "apply_bert_init", False): + decoder.apply(init_bert_params) + return decoder + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + encoder = TransformerEncoder(args, src_dict, embed_tokens) + if getattr(args, "apply_bert_init", False): + encoder.apply(init_bert_params) + return encoder + + def forward( + self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs + ): + # encoding + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + length_out, length_tgt = self.decoder.forward_length_prediction( + encoder_out, tgt_tokens + ) + + word_ins_out, word_ins_tgt, word_ins_mask = self.decoder( + prev_output_tokens, encoder_out=encoder_out, tgt_tokens=tgt_tokens + ) + + return { + "word_ins_out": word_ins_out, + "word_ins_tgt": word_ins_tgt, + "word_ins_mask": word_ins_mask, + "length_out": length_out, + "length_tgt": length_tgt, + "length_w": self.decoder.length_loss_factor, + } + + def forward_encoder(self, encoder_inputs): + return self.encoder(*encoder_inputs) + + def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs): + step = decoder_out["step"] + output_tokens = decoder_out["output_tokens"] + output_scores = decoder_out["output_scores"] + + # execute the decoder + output_masks = output_tokens.ne(self.pad) + _scores, _tokens = self.decoder( + output_tokens, + encoder_out=encoder_out, + decoding_format=decoding_format, + step=step, + ) + output_tokens.masked_scatter_(output_masks, _tokens[output_masks]) + output_scores.masked_scatter_(output_masks, _scores[output_masks]) + + return {"output_tokens": output_tokens, "output_scores": output_scores} + + def initialize_output_tokens(self, encoder_out, src_tokens): + # length prediction + _, length_tgt = self.decoder.forward_length_prediction(encoder_out) + max_length = length_tgt.max() + idx_length = torch.arange(max_length, device=src_tokens.device) + + initial_output_tokens = src_tokens.new_zeros( + src_tokens.size(0), max_length + ).fill_(self.pad) + initial_output_tokens.masked_fill_( + idx_length[None, :] < length_tgt[:, None], self.unk + ) + initial_output_tokens[:, 0] = self.bos + initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos) + + initial_output_scores = initial_output_tokens.new_zeros( + *initial_output_tokens.size() + ).type_as(encoder_out["encoder_out"]) + + return { + "output_tokens": initial_output_tokens, + "output_scores": initial_output_scores, + } + + +class NATransformerDecoder(TransformerDecoder): + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + super().__init__( + args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn + ) + + self.dictionary = dictionary + self.bos = dictionary.bos() + self.unk = dictionary.unk() + self.eos = dictionary.eos() + + self.encoder_embed_dim = args.encoder_embed_dim + self.sg_length_pred = getattr(args, "sg_length_pred", False) + self.pred_length_offset = getattr(args, "pred_length_offset", False) + self.length_loss_factor = getattr(args, "length_loss_factor", 0.1) + self.src_embedding_copy = getattr(args, "src_embedding_copy", False) + self.embed_length = Embedding(256, self.encoder_embed_dim, None) + + self.ngram_predictor = getattr(args, "ngram_predictor", 1) + self.ngram_layer = ( + None if (self.ngram_predictor == 1) else NgramDecoderLayer(args, True) + ) + + def forward( + self, + prev_output_tokens, + encoder_out=None, + tgt_tokens=None, + decoding_format=None, + step=0, + 
**kwargs + ): + + features, _ = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + embedding_copy=(step == 0) & self.src_embedding_copy, + ) + + if tgt_tokens is not None: + if self.ngram_layer is None: + word_ins_mask = tgt_tokens.ne(self.padding_idx) + word_ins_tgt = tgt_tokens + else: + context_embeds, context_masks = self.forward_ngram_context(tgt_tokens) + features = self.ngram_layer(features, context_embeds=context_embeds) + word_ins_tgt = tgt_tokens[:, :, None].repeat(1, 1, self.ngram_predictor) + word_ins_mask = word_ins_tgt.ne(self.padding_idx) & context_masks + + return self.output_layer(features), word_ins_tgt, word_ins_mask + + else: + if self.ngram_layer is None: + return F.log_softmax(self.output_layer(features), -1).max(-1) + else: + # inner iterations + return self.forward_ngram_decoding( + features, prev_output_tokens.eq(self.padding_idx), decoding_format + ) + + def extract_features( + self, + prev_output_tokens, + encoder_out=None, + early_exit=None, + embedding_copy=False, + **unused + ): + """ + Similar to *forward* but only return features. + + Inputs: + prev_output_tokens: Tensor(B, T) + encoder_out: a dictionary of hidden states and masks + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + the LevenshteinTransformer decoder has full-attention to all generated tokens + """ + # embedding + if embedding_copy: + src_embd = encoder_out["encoder_embedding"] + src_mask = encoder_out["encoder_padding_mask"] + src_mask = ( + ~src_mask + if src_mask is not None + else prev_output_tokens.new_ones(*src_embd.size()[:2]).bool() + ) + + x, decoder_padding_mask = self.forward_embedding( + prev_output_tokens, + self.forward_copying_source( + src_embd, src_mask, prev_output_tokens.ne(self.padding_idx) + ), + ) + + else: + + x, decoder_padding_mask = self.forward_embedding(prev_output_tokens) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + attn = None + inner_states = [x] + + # decoder layers + for i, layer in enumerate(self.layers): + + # early exit from the decoder. 
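+            # (when early_exit is set, the remaining layers are skipped and the
+            # features are taken from that intermediate layer; by default all layers run)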
+ if (early_exit is not None) and (i >= early_exit): + break + + x, attn = layer( + x, + encoder_out["encoder_out"] if encoder_out is not None else None, + encoder_out["encoder_padding_mask"] + if encoder_out is not None + else None, + self_attn_mask=None, + self_attn_padding_mask=decoder_padding_mask, + ) + inner_states.append(x) + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + return x, {"attn": attn, "inner_states": inner_states} + + def forward_ngram_context(self, tgt_tokens): + tgt_embeds = self.forward_embedding(tgt_tokens) + n_contexts = self.ngram_predictor - 1 + + # shifting the embeddings + # context_embeds: N x B x T x C + # context_masks: B x T x N + context_embeds = tgt_embeds.new_zeros(n_contexts, *tgt_embeds.size()) + context_masks = tgt_embeds.new_ones( + *tgt_embeds.size()[:2], self.ngram_predictor + ).bool() + + for k in range(n_contexts): + context_embeds[k, :, k + 1:] = tgt_embeds[:, : -k - 1] + context_masks[:, : k + 1, k + 1] = 0 + + return context_embeds, context_masks + + def forward_ngram_decoding(self, features, padding_mask=None, decoding_format=None): + context_embeds = None + scores, tokens = [], [] + ensemble_score = None + ensemble_index = None + + if decoding_format is None: + decoding_format = "ensemble" + + for k in range(self.ngram_predictor): + ngram_out = self.ngram_layer( + features, context_embeds=context_embeds, incremental=True + ) + ngram_scores = F.log_softmax(self.output_layer(ngram_out), -1) + max_score, max_token = ngram_scores.max(-1) + + if decoding_format == "vote": + ngram_scores = _argmax(ngram_scores, -1) + + if ensemble_score is None: + ensemble_score = ngram_scores + ensemble_index = ensemble_score.new_ones(*ensemble_score.size()[:2]) + else: + ensemble_index[:, k:] = ensemble_index[:, k:] + 1 + ensemble_score = ensemble_score + ngram_scores.masked_fill_( + (ensemble_index < k) + .unsqueeze(2) + .repeat(1, 1, ensemble_score.size(2)), + 0, + ) + max_score[:, :k] = float("-inf") + + if decoding_format == "unigram": + break + + scores.append(max_score.masked_fill_(padding_mask, 0)) + tokens.append(max_token.masked_fill_(padding_mask, self.padding_idx)) + + # context_embeds: N x B x T x C + if context_embeds is None: + context_embeds = self.forward_embedding(max_token).unsqueeze(0) + + else: + context_embeds = torch.cat( + [self.forward_embedding(max_token).unsqueeze(0), context_embeds], 0 + ) + + context_embeds[:, :, 1:] = context_embeds[:, :, :-1] + + if decoding_format != "dp": + ensemble_score = ensemble_score / ensemble_index.unsqueeze(2) + return ensemble_score.max(-1) + + else: + tokens = torch.cat([t.unsqueeze(0) for t in tokens], 0) + scores = torch.cat([s.unsqueeze(0) for s in scores], 0) + return _dynamic_programming(tokens, scores) + + def forward_embedding(self, prev_output_tokens, states=None): + # embed positions + positions = ( + self.embed_positions(prev_output_tokens) + if self.embed_positions is not None + else None + ) + + # embed tokens and positions + if states is None: + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + if self.project_in_dim is not None: + x = self.project_in_dim(x) + else: + x = states + + if positions is not None: + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + decoder_padding_mask = prev_output_tokens.eq(self.padding_idx) + return x, decoder_padding_mask + + def forward_copying_source(self, src_embeds, src_masks, tgt_masks): + 
length_sources = src_masks.sum(1)
+        length_targets = tgt_masks.sum(1)
+        mapped_inputs = _uniform_assignment(length_sources, length_targets).masked_fill(
+            ~tgt_masks, 0
+        )
+        copied_embedding = torch.gather(
+            src_embeds,
+            1,
+            mapped_inputs.unsqueeze(-1).expand(
+                *mapped_inputs.size(), src_embeds.size(-1)
+            ),
+        )
+        return copied_embedding
+
+    def forward_length_prediction(self, encoder_out, tgt_tokens=None):
+        enc_feats = encoder_out["encoder_out"]  # T x B x C
+        src_masks = encoder_out["encoder_padding_mask"]  # B x T or None
+
+        if self.pred_length_offset:
+            if src_masks is None:
+                src_lengs = enc_feats.new_ones(enc_feats.size(1)).fill_(
+                    enc_feats.size(0)
+                )
+            else:
+                src_lengs = (~src_masks).transpose(0, 1).type_as(enc_feats).sum(0)
+            src_lengs = src_lengs.long()
+
+        enc_feats = _mean_pooling(enc_feats, src_masks)
+        if self.sg_length_pred:
+            enc_feats = enc_feats.detach()
+
+        length_out = F.linear(enc_feats, self.embed_length.weight)
+
+        if tgt_tokens is not None:
+            # obtain the length target
+            tgt_lengs = tgt_tokens.ne(self.padding_idx).sum(1).long()
+            if self.pred_length_offset:
+                length_tgt = tgt_lengs - src_lengs + 128
+            else:
+                length_tgt = tgt_lengs
+            length_tgt = length_tgt.clamp(min=0, max=255)
+
+        else:
+            # predict the length target (greedy for now)
+            # TODO: implement length beam search
+            pred_lengs = length_out.max(-1)[1]
+            if self.pred_length_offset:
+                length_tgt = pred_lengs - 128 + src_lengs
+            else:
+                length_tgt = pred_lengs
+
+        return length_out, length_tgt
+
+
+class NgramDecoderLayer(TransformerDecoderLayer):
+    """
+    N-gram Decoder Layer.
+
+    This module can be plugged in as the last layer of any non-autoregressive model's decoder.
+    It provides an alternative way to capture local n-gram information by running the block multiple times.
+    """
+
+    def __init__(self, args, no_encoder_attn=False):
+        super(NgramDecoderLayer, self).__init__(args, no_encoder_attn=no_encoder_attn)
+        self.self_attn = MultiheadAttention(
+            embed_dim=self.embed_dim,
+            num_heads=1,  # the n-gram predictor probably does not need many heads.
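+            # (queries are the decoder features; keys/values are the shifted n-gram
+            # context embeddings, see forward() below)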
+ dropout=args.attention_dropout, + self_attention=False, + encoder_decoder_attention=True, + ) + + def forward( + self, + x, + encoder_out=None, + encoder_padding_mask=None, + context_embeds=None, + incremental=False, + ): + # x: T x B x C + # context_embeds: N x T x B x C + T, B, C = x.size() + + residual = x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) + x = x.contiguous().view(1, T * B, C).contiguous() + + if context_embeds is not None: + N = context_embeds.size(0) + context_embeds = context_embeds.view(N, T * B, C).contiguous() + + if not incremental: + assert context_embeds is not None, "we need context for training" + # attn_weights: (n_head x T x B) x 1 x N + # v: (n_head x T x B) x N x (dim / n_head) + # -- move the attention computation outside -- + attn_weights, values = self.self_attn( + query=x, key=context_embeds, value=context_embeds, before_softmax=True + ) + + attn_weights = attn_weights.repeat(1, N, 1) + attn_masks = attn_weights.new_ones(N, N).triu_(1).bool() + attn_masks = attn_masks.unsqueeze(0).repeat(attn_weights.size(0), 1, 1) + + attn_weights = attn_weights.masked_fill(attn_masks, float("-inf")) + attn_weights = utils.softmax(attn_weights, dim=-1).type_as(attn_weights) + attn_weights = F.dropout( + attn_weights, p=self.self_attn.dropout, training=self.training + ) + + # (n_head x T x B) x N x (dim / n_head) + attn = torch.bmm(attn_weights, values) + attn = attn.transpose(0, 1).contiguous() + attn = attn.view(N, T * B, C).contiguous() + attn = attn.transpose(1, 0).contiguous() + attn = attn.view(T, B, N, C) + + residual = residual.unsqueeze(2) + x = self.self_attn.out_proj(attn) + x = F.dropout(x, p=self.dropout, training=self.training) + x = torch.cat([residual, residual + x], 2) + + else: + if context_embeds is None: + x = residual + + else: + x, _ = self.self_attn(query=x, key=context_embeds, value=context_embeds) + x = x.view(T, B, C) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) + + if self.encoder_attn is not None: + raise NotImplementedError + + residual = x + x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) + x = self.activation_fn(self.fc1(x)) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) + return x + + +@register_model_architecture( + "nonautoregressive_transformer", "nonautoregressive_transformer" +) +def base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + 
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.apply_bert_init = getattr(args, "apply_bert_init", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + # --- special arguments --- + args.sg_length_pred = getattr(args, "sg_length_pred", False) + args.pred_length_offset = getattr(args, "pred_length_offset", False) + args.length_loss_factor = getattr(args, "length_loss_factor", 0.1) + args.src_embedding_copy = getattr(args, "src_embedding_copy", False) + args.ngram_predictor = getattr(args, "ngram_predictor", 1) + + +@register_model_architecture( + "nonautoregressive_transformer", "nonautoregressive_transformer_wmt_en_de" +) +def nonautoregressive_transformer_wmt_en_de(args): + base_architecture(args) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 7fedc7755..dd10ae535 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -172,7 +172,7 @@ class TransformerModel(FairseqEncoderDecoderModel): encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) - return TransformerModel(encoder, decoder) + return cls(encoder, decoder) @classmethod def build_encoder(cls, args, src_dict, embed_tokens): @@ -222,7 +222,15 @@ class TransformerEncoder(FairseqEncoder): else: self.layer_norm = None - def forward(self, src_tokens, src_lengths): + def forward_embedding(self, src_tokens): + # embed tokens and positions + embed = self.embed_scale * self.embed_tokens(src_tokens) + if self.embed_positions is not None: + x = embed + self.embed_positions(src_tokens) + x = F.dropout(x, p=self.dropout, training=self.training) + return x, embed + + def forward(self, src_tokens, src_lengths, cls_input=None): """ Args: src_tokens (LongTensor): tokens in the source language of shape @@ -237,11 +245,7 @@ class TransformerEncoder(FairseqEncoder): - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` """ - # embed tokens and positions - x = self.embed_scale * self.embed_tokens(src_tokens) - if self.embed_positions is not None: - x += self.embed_positions(src_tokens) - x = F.dropout(x, p=self.dropout, training=self.training) + x, encoder_embedding = self.forward_embedding(src_tokens) # B x T x C -> T x B x C x = x.transpose(0, 1) @@ -261,6 +265,7 @@ class TransformerEncoder(FairseqEncoder): return { 'encoder_out': x, # T x B x C 'encoder_padding_mask': encoder_padding_mask, # B x T + 'encoder_embedding': encoder_embedding, # B x T x C } def reorder_encoder_out(self, 
encoder_out, new_order): @@ -332,7 +337,7 @@ class TransformerDecoder(FairseqIncrementalDecoder): embed_dim = args.decoder_embed_dim self.output_embed_dim = args.decoder_output_dim - padding_idx = embed_tokens.padding_idx + self.padding_idx = embed_tokens.padding_idx self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens @@ -341,7 +346,7 @@ class TransformerDecoder(FairseqIncrementalDecoder): self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None self.embed_positions = PositionalEmbedding( - args.max_target_positions, embed_dim, padding_idx, + args.max_target_positions, embed_dim, self.padding_idx, learned=args.decoder_learned_pos, ) if not args.no_token_positional_embeddings else None diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index 4da628655..8c28255df 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -91,7 +91,7 @@ class MultiheadAttention(nn.Module): nn.init.xavier_normal_(self.bias_v) def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, - need_weights=True, static_kv=False, attn_mask=None): + need_weights=True, static_kv=False, attn_mask=None, before_softmax=False): """Input shape: Time x Batch x Channel Timesteps can be masked by supplying a T x T mask in the @@ -239,6 +239,9 @@ class MultiheadAttention(nn.Module): ) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + if before_softmax: + return attn_weights, v + attn_weights = utils.softmax( attn_weights, dim=-1, onnx_trace=self.onnx_trace, ).type_as(attn_weights) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 5da4909ca..f4a80ccee 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -83,7 +83,7 @@ class TransformerEncoderLayer(nn.Module): residual = x x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) if attn_mask is not None: - attn_mask = attn_mask.masked_fill(attn_mask.byte(), -1e8) + attn_mask = attn_mask.masked_fill(attn_mask.bool(), -1e8) # anything in original attn_mask = 1, becomes -1e8 # anything in original attn_mask = 0, becomes 0 # Note that we cannot use -inf here, because at some edge cases, diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index 169929125..9be7ab308 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -36,7 +36,8 @@ def init_bert_params(module): module.bias.data.zero_() if isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=0.02) - module.weight.data[module.padding_idx].zero_() + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() if isinstance(module, MultiheadAttention): module.in_proj_weight.data.normal_(mean=0.0, std=0.02) diff --git a/fairseq/options.py b/fairseq/options.py index 54c786390..bb1e27aeb 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -280,6 +280,8 @@ def add_dataset_args(parser, train=False, gen=False): ' (train, valid, valid1, test, test1)') group.add_argument('--validate-interval', type=int, default=1, metavar='N', help='validate every N epochs') + group.add_argument('--fixed-validation-seed', default=None, type=int, metavar='N', + help='specified random seed for validation') group.add_argument('--disable-validation', action='store_true', 
                           help='disable validation')
     group.add_argument('--max-tokens-valid', type=int, metavar='N',
@@ -493,6 +495,18 @@ def add_generation_args(parser):
                        help='strength of diversity penalty for Diverse Beam Search')
     group.add_argument('--print-alignment', action='store_true',
                        help='if set, uses attention feedback to compute and print alignment to source tokens')
+    group.add_argument('--print-step', action='store_true')
+
+    # arguments for iterative refinement generator
+    group.add_argument('--iter-decode-eos-penalty', default=0.0, type=float, metavar='N',
+                       help='if > 0.0, penalize early stopping during iterative decoding.')
+    group.add_argument('--iter-decode-max-iter', default=10, type=int, metavar='N',
+                       help='maximum number of iterations for iterative refinement.')
+    group.add_argument('--iter-decode-force-max-iter', action='store_true',
+                       help='if set, run exactly the maximum number of iterations without early stopping')
+
+    # special decoding format for advanced decoding.
+    group.add_argument('--decoding-format', default=None, type=str, choices=['unigram', 'ensemble', 'vote', 'dp', 'bs'])
     # fmt: on
     return group
diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py
index d3f51cb35..f3d60403b 100644
--- a/fairseq/tasks/translation.py
+++ b/fairseq/tasks/translation.py
@@ -12,6 +12,7 @@ from fairseq.data import (
     data_utils,
     indexed_dataset,
     LanguagePairDataset,
+    PrependTokenDataset,
 )
 
 from . import FairseqTask, register_task
@@ -22,7 +23,8 @@ def load_langpair_dataset(
     src, src_dict, tgt, tgt_dict,
     combine, dataset_impl, upsample_primary,
-    left_pad_source, left_pad_target, max_source_positions, max_target_positions,
+    left_pad_source, left_pad_target, max_source_positions,
+    max_target_positions, prepend_bos=False,
 ):
     def split_exists(split, src, tgt, lang, data_path):
         filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
@@ -67,6 +69,11 @@ def load_langpair_dataset(
         src_dataset = ConcatDataset(src_datasets, sample_ratios)
         tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
 
+    if prepend_bos:
+        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
+        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
+        tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
+
     return LanguagePairDataset(
         src_dataset, src_dataset.sizes, src_dict,
         tgt_dataset, tgt_dataset.sizes, tgt_dict,
diff --git a/fairseq/tasks/translation_lev.py b/fairseq/tasks/translation_lev.py
new file mode 100644
index 000000000..47d6a3ed4
--- /dev/null
+++ b/fairseq/tasks/translation_lev.py
@@ -0,0 +1,149 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from fairseq.tasks import register_task
+from fairseq.tasks.translation import TranslationTask, load_langpair_dataset
+
+
+@register_task('translation_lev')
+class TranslationLevenshteinTask(TranslationTask):
+    """
+    Translation (Sequence Generation) task for Levenshtein Transformer.
+    See `"Levenshtein Transformer" <https://arxiv.org/abs/1905.11006>`_.
+    """
+
+    @staticmethod
+    def add_args(parser):
+        """Add task-specific arguments to the parser."""
+        # fmt: off
+        TranslationTask.add_args(parser)
+        parser.add_argument(
+            '--noise',
+            default='random_delete',
+            choices=['random_delete', 'random_mask', 'no_noise', 'full_mask'])
+
+    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
+        """Load a given dataset split.
+ + Args: + split (str): name of the split (e.g., train, valid, test) + """ + paths = self.args.data.split(':') + assert len(paths) > 0 + data_path = paths[epoch % len(paths)] + + # infer langcode + src, tgt = self.args.source_lang, self.args.target_lang + + self.datasets[split] = load_langpair_dataset( + data_path, split, src, self.src_dict, tgt, self.tgt_dict, + combine=combine, dataset_impl=self.args.dataset_impl, + upsample_primary=self.args.upsample_primary, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + max_source_positions=self.args.max_source_positions, + max_target_positions=self.args.max_target_positions, + prepend_bos=True, + ) + + def inject_noise(self, target_tokens): + def _random_delete(target_tokens): + pad = self.tgt_dict.pad() + bos = self.tgt_dict.bos() + eos = self.tgt_dict.eos() + + max_len = target_tokens.size(1) + target_mask = target_tokens.eq(pad) + target_score = target_tokens.clone().float().uniform_() + target_score.masked_fill_( + target_tokens.eq(bos) | target_tokens.eq(eos), 0.0) + target_score.masked_fill_(target_mask, 1) + target_score, target_rank = target_score.sort(1) + target_length = target_mask.size(1) - target_mask.float().sum( + 1, keepdim=True) + + # do not delete and (we assign 0 score for them) + target_cutoff = 2 + ((target_length - 2) * target_score.new_zeros( + target_score.size(0), 1).uniform_()).long() + target_cutoff = target_score.sort(1)[1] >= target_cutoff + + prev_target_tokens = target_tokens.gather( + 1, target_rank).masked_fill_(target_cutoff, pad).gather( + 1, + target_rank.masked_fill_(target_cutoff, + max_len).sort(1)[1]) + prev_target_tokens = prev_target_tokens[:, :prev_target_tokens. + ne(pad).sum(1).max()] + + return prev_target_tokens + + def _random_mask(target_tokens): + pad = self.tgt_dict.pad() + bos = self.tgt_dict.bos() + eos = self.tgt_dict.eos() + unk = self.tgt_dict.unk() + + target_mask = target_tokens.eq(bos) | target_tokens.eq( + eos) | target_tokens.eq(pad) + target_score = target_tokens.clone().float().uniform_() + target_score.masked_fill_(target_mask, 1.0) + + prev_target_tokens = target_tokens.masked_fill( + target_score < target_score.new_zeros(target_score.size(0), + 1).uniform_(), unk) + return prev_target_tokens + + def _full_mask(target_tokens): + pad = self.tgt_dict.pad() + bos = self.tgt_dict.bos() + eos = self.tgt_dict.eos() + unk = self.tgt_dict.unk() + + target_mask = target_tokens.eq(bos) | target_tokens.eq( + eos) | target_tokens.eq(pad) + return target_tokens.masked_fill(~target_mask, unk) + + if self.args.noise == 'random_delete': + return _random_delete(target_tokens) + elif self.args.noise == 'random_mask': + return _random_mask(target_tokens) + elif self.args.noise == 'full_mask': + return _full_mask(target_tokens) + elif self.args.noise == 'no_noise': + return target_tokens + else: + raise NotImplementedError + + def build_generator(self, args): + from fairseq.iterative_refinement_generator import IterativeRefinementGenerator + return IterativeRefinementGenerator( + self.target_dictionary, + eos_penalty=getattr(args, 'iter_decode_eos_penalty', 0.0), + max_iter=getattr(args, 'iter_decode_max_iter', 10), + decoding_format=getattr(args, 'decoding_format', None), + adaptive=not getattr(args, 'iter_decode_force_max_iter', False)) + + def train_step(self, + sample, + model, + criterion, + optimizer, + ignore_grad=False): + model.train() + sample['prev_target'] = self.inject_noise(sample['target']) + loss, sample_size, logging_output = criterion(model, 
sample) + if ignore_grad: + loss *= 0 + optimizer.backward(loss) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + model.eval() + with torch.no_grad(): + sample['prev_target'] = self.inject_noise(sample['target']) + loss, sample_size, logging_output = criterion(model, sample) + return loss, sample_size, logging_output diff --git a/fairseq/utils.py b/fairseq/utils.py index 1af239443..80ecb6d08 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -359,3 +359,11 @@ def has_parameters(module): return True except StopIteration: return False + + +def set_torch_seed(seed): + # Set seed based on args.seed and the update number so that we get + # reproducible results when resuming from checkpoints + assert isinstance(seed, int) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) diff --git a/generate.py b/generate.py index c23cc7986..6de1a69ab 100644 --- a/generate.py +++ b/generate.py @@ -159,6 +159,9 @@ def main(args): ' '.join(map(lambda x: str(utils.item(x)), alignment)) )) + if args.print_step: + print('I-{}\t{}'.format(sample_id, hypo['steps'])) + # Score only the top hypothesis if has_target and j == 0: if align_dict is not None or args.remove_bpe is not None: diff --git a/setup.py b/setup.py index 8f4604be1..33849f810 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from setuptools import setup, find_packages, Extension +from torch.utils import cpp_extension import sys @@ -60,6 +61,12 @@ extensions = [ language='c++', extra_compile_args=extra_compile_args, ), + cpp_extension.CppExtension( + 'fairseq.libnat', + sources=[ + 'fairseq/clib/libnat/edit_dist.cpp', + ], + ) ] @@ -106,5 +113,6 @@ setup( 'fairseq-validate = fairseq_cli.validate:cli_main', ], }, + cmdclass={'build_ext': cpp_extension.BuildExtension}, zip_safe=False, ) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index b51727827..8cede3c9f 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -180,6 +180,52 @@ class TestTranslation(unittest.TestCase): ]) generate_main(data_dir) + def test_levenshtein_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory('test_levenshtein_transformer') as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model(data_dir, 'levenshtein_transformer', [ + '--apply-bert-init', '--early-exit', '6,6,6', + '--criterion', 'nat_loss' + ], task='translation_lev') + generate_main(data_dir) + + def test_nonautoregressive_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory('test_nonautoregressive_transformer') as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model(data_dir, 'nonautoregressive_transformer', [ + '--apply-bert-init', '--src-embedding-copy', '--criterion', + 'nat_loss', '--noise', 'full_mask', '--pred-length-offset', + '--length-loss-factor', '0.1' + ], task='translation_lev') + generate_main(data_dir) + + def test_iterative_nonautoregressive_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory('test_iterative_nonautoregressive_transformer') as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model(data_dir, 'iterative_nonautoregressive_transformer', [ + '--apply-bert-init', '--src-embedding-copy', '--criterion', + 'nat_loss', '--noise', 'full_mask', 
'--stochastic-approx', + '--dae-ratio', '0.5', '--train-step', '3' + ], task='translation_lev') + generate_main(data_dir) + + def test_insertion_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory('test_insertion_transformer') as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model(data_dir, 'insertion_transformer', [ + '--apply-bert-init', '--criterion', 'nat_loss', '--noise', + 'random_mask' + ], task='translation_lev') + generate_main(data_dir) + def test_mixture_of_experts(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory('test_moe') as data_dir: diff --git a/train.py b/train.py index db04dc219..3879375fe 100644 --- a/train.py +++ b/train.py @@ -194,6 +194,11 @@ def get_training_stats(trainer): def validate(args, trainer, task, epoch_itr, subsets): """Evaluate the model on the validation set(s) and return the losses.""" + + if args.fixed_validation_seed is not None: + # set fixed seed for every validation + utils.set_torch_seed(args.fixed_validation_seed) + valid_losses = [] for subset in subsets: # Initialize data iterator