Opensource code for Deep Transformer with Latent Depth (#2703)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Open-source code for Deep Transformers with Latent Depth (https://arxiv.org/pdf/2009.13102.pdf).

New features and design choices made:

- New feature: allow the non-residual block to be weighted by a sample z (generated once per batch) instead of the plain `x = residual + x`.
- Design choice: move `x = residual + x` in transformer_layer.py into a `residual_connection()` method that the latent-depth subclass can override with `x = residual + z*x` (see the sketch after this list).

- New feature: allow TransformerEncoder and TransformerDecoder to carry additional logits parameters from which the samples z are generated.
- Design choice: added the subclasses LatentTransformerEncoder and LatentTransformerDecoder, which hold the logits parameters as additional attributes and instantiate the corresponding LatentTransformerEncoderLayer and LatentTransformerDecoderLayer.

- New feature: allow the multilingual_translation task to train with latent depth (the setting used for the results in the paper).
- Design choice:
  - added additional arguments to the multilingual_translation task.
  - added an option for multilingual_transformer to use LatentTransformerEncoder and LatentTransformerDecoder instead of the standard TransformerEncoder and TransformerDecoder.
  - added an option in the multilingual_translation task's `train_step` to generate the samples z and compute the KL (and sparsity) loss per batch.
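
To make the residual-gating design concrete, here is a minimal, self-contained sketch. The class names (`BaseLayer`, `LatentDepthLayer`) are illustrative stand-ins, not the actual fairseq classes; they only mirror the `residual_connection` override shown in the diffs below.

```python
import torch


class BaseLayer:
    """Mimics the refactored transformer_layer.py: the residual add is now a method."""

    def residual_connection(self, x, residual):
        return residual + x  # standard Transformer behaviour


class LatentDepthLayer(BaseLayer):
    """Mimics the latent-depth subclass: the block output is gated by a sample z."""

    def __init__(self, idx, layer_select):
        self.idx = idx                    # index of this layer
        self.layer_select = layer_select  # callable returning z for a given layer index

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)


# toy check: z = 0 skips the block entirely, z = 1 keeps it
z = torch.tensor([0.0, 1.0])
layer = LatentDepthLayer(idx=0, layer_select=lambda i: z[i])
residual, block_out = torch.ones(2, 4), torch.randn(2, 4)
assert torch.equal(layer.residual_connection(block_out, residual), residual)
```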

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding!

Pull Request resolved: https://github.com/pytorch/fairseq/pull/2703

Reviewed By: myleott

Differential Revision: D24155059

Pulled By: xianxl

fbshipit-source-id: f3e41639429f9664ec5565839709aa857a643668
Xian Li 2020-10-15 09:23:54 -07:00 committed by Facebook GitHub Bot
parent 3544f5f24e
commit 573c2f4b60
15 changed files with 672 additions and 12 deletions

View File

@@ -44,6 +44,7 @@ We provide reference implementations of various sequence modeling papers:
- [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
- [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
- [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
- [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
- **Non-autoregressive Transformers**
- Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
- Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
@@ -55,6 +56,7 @@ We provide reference implementations of various sequence modeling papers:
### What's New:
- October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
- October 2020: [Added CRISS models and code](examples/criss/README.md)
- September 2020: [Added Linformer code](examples/linformer/README.md)
- September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)

View File

@@ -0,0 +1,77 @@
# Deep Transformers with Latent Depth (Li et al., 2020)
[https://arxiv.org/abs/2009.13102](https://arxiv.org/abs/2009.13102)

## Introduction
We present a probabilistic framework to automatically learn which layer(s) to use by learning the posterior distributions of layer selection. As an extension of this framework, we propose a novel method to train one shared Transformer network for multilingual machine translation with different layer selection posteriors for each language pair.
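
Informally (the notation below is a paraphrase, not copied verbatim from the paper), each layer `l` of the shared network is gated by a sample `z_l` drawn from a learned, language-pair-specific posterior over layer selection, so the usual residual update becomes

```latex
x_l = x_{l-1} + z_l \cdot F_l(x_{l-1}), \qquad z_l \sim q_\phi(z_l \mid \text{language pair})
```

where `F_l` is the l-th layer's non-residual block. The `--sparsity-weight`, `--target-layers`, `--share-weight`, `--soft-update` and `--anneal-updates` options in the command below control the KL and sparsity regularizers applied to these posteriors.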
## Training a multilingual model with latent depth
Below is an example of training with latent depth in the decoder for one-to-many (O2M) related languages. We use the same preprocessed (numberized and binarized) TED8 dataset as in [Balancing Training for Multilingual Neural Machine Translation (Wang et al., 2020)](https://github.com/cindyxinyiwang/multiDDS), which can be generated with [the preparation script](https://github.com/cindyxinyiwang/multiDDS/blob/multiDDS/util_scripts/prepare_multilingual_data.sh) provided by the authors.
```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>
python fairseq_cli/train.py ${databin_dir} \
--user-dir examples/latent_depth/src \
--lang-pairs "${lang_pairs_str}" \
--arch multilingual_transformer_iwslt_de_en \
--task multilingual_translation_latent_depth \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--share-encoders \
--share-decoders \
--decoder-langtok \
--share-decoder-input-output-embed \
--dropout 0.3 --attention-dropout 0.3 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--lr-scheduler inverse_sqrt --min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \
--max-tokens 4096 --update-freq 1 \
--lr 0.0015 \
--clip-norm 1.0 \
--seed 2 \
--ddp-backend=no_c10d \
--encoder-layers 12 \
--decoder-layers 24 \
--decoder-latent-layer \
--sparsity-weight 0.1 \
--anneal-updates 5000 \
--soft-update 500 \
--target-layers 12 \
--share-weight 0.1
```
## Inference command
```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>
model_path=<path to checkpoint>
src_lang=<source language to translate from>
tgt_lang=<target language to translate to>
gen_data=<name of data split, e.g. valid, test, etc>
python fairseq_cli/generate.py ${databin_dir} \
--path ${model_path} \
--task multilingual_translation_latent_depth \
--decoder-latent-layer \
--lang-pairs "${lang_pairs_str}" \
-s ${src_lang} -t ${tgt_lang} \
--gen-subset $gen_data \
--scoring sacrebleu \
--remove-bpe 'sentencepiece' \
--lenpen 1.0 \
--beam 5 \
--decoder-langtok \
--max-tokens 4096
```
## Citation
```bibtex
@article{li2020deep,
title={Deep Transformers with Latent Depth},
author={Li, Xian and Stickland, Asa Cooper and Tang, Yuqing and Kong, Xiang},
journal={arXiv preprint arXiv:2009.13102},
year={2020}
}
```

View File

@@ -0,0 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .models import latent_multilingual_transformer # noqa
from .modules import latent_layers # noqa
from .loss import latent_depth # noqa
from . import multilingual_translation_latent_depth # noqa

View File

@@ -0,0 +1,86 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import math

from torch.nn.modules.loss import _Loss


class LatentLayersKLLoss(_Loss):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def forward(self, layer_samples, lang_idx, update_num, sample_size):
        prior = self.args.prior
        samples = layer_samples[lang_idx]
        eps = 1e-7
        if prior == "uniform":
            # uniform prior
            kl_loss = (samples * (
                torch.log(samples + eps) - math.log(0.5)
            )).sum(-1)
        elif prior == "agged_posterior":
            # aggregated posterior
            y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
            agged_q = torch.sum(y_t, dim=0)
            row_norm = agged_q.sum(-1)
            normed_agg_q = agged_q / row_norm
            kl_loss = (samples * (
                torch.log(samples + eps) - torch.log(normed_agg_q + eps))).sum(-1)
        else:
            raise NotImplementedError("The specified prior is not implemented.")

        # normalize by the number of layers
        kl_loss /= layer_samples[0].size()[0]
        kl_weight = min(
            self.args.sparsity_weight,
            (update_num - self.args.soft_update) * self.args.sparsity_weight / self.args.anneal_updates
        )
        kl_loss *= kl_weight * sample_size
        return kl_loss


class LatentLayersSparsityLoss(_Loss):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def is_valid(self, update_num):
        if self.args.target_layers <= 0:
            return False
        return update_num > (self.args.soft_update + self.args.anneal_updates)

    def forward(self, layer_samples_list, update_num, sample_size):
        batch_loss = 0
        share_loss = 0
        global_sparsity_loss = 0
        layer_samples = torch.stack(layer_samples_list, dim=0)
        if ((self.args.target_layers > 0 or self.args.share_weight > 0) and
                update_num > (self.args.soft_update + self.args.anneal_updates)):
            # anneal the sparsity weight
            if update_num < (self.args.anneal_updates + self.args.soft_update):
                weight_anneal = 0
            elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
                weight_anneal = (
                    (update_num - self.args.soft_update - self.args.anneal_updates)
                    * self.args.share_weight / self.args.anneal_updates
                )
            else:
                weight_anneal = 1
            # compute the layer utilization ratio across languages
            layer_utilization = torch.sum(layer_samples, dim=0)
            layer_utilization /= layer_samples.size()[0]
            if self.args.share_weight > 0:
                # encourage sharing across languages
                share_loss = sum(-1.0 * v * math.log(v) for v in layer_utilization if v > 0)
                batch_loss += weight_anneal * self.args.share_weight * sample_size * share_loss
            if self.args.target_layers > 0:
                # compute the expected number of selected layers
                expected_layers = sum(layer_utilization)
                # compute an L2 loss w.r.t. the target number of layers
                global_sparsity_loss = (expected_layers - self.args.target_layers) ** 2
                batch_loss += weight_anneal * self.args.share_weight * sample_size * global_sparsity_loss
        return batch_loss
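
A rough, hypothetical usage sketch for the two losses above (the argument values are illustrative, not recommended hyperparameters; the `Namespace` fields mirror the `--prior`, `--sparsity-weight`, `--soft-update`, `--anneal-updates`, `--target-layers` and `--share-weight` task arguments):

```python
from argparse import Namespace

import torch

args = Namespace(prior="uniform", sparsity_weight=0.1, soft_update=500,
                 anneal_updates=5000, target_layers=12, share_weight=0.1)
kl_loss_fn = LatentLayersKLLoss(args)
sparsity_loss_fn = LatentLayersSparsityLoss(args)

num_langs, num_layers = 8, 24
# one tensor of per-layer samples (as produced by LayerSelect) per language pair
layer_samples = [torch.rand(num_layers) for _ in range(num_langs)]

kl = kl_loss_fn(layer_samples, lang_idx=0, update_num=6000, sample_size=4096)
if sparsity_loss_fn.is_valid(update_num=6000):
    sparsity = sparsity_loss_fn(layer_samples, update_num=6000, sample_size=4096)
```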

View File

@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import (
    base_architecture,
    TransformerEncoder,
    TransformerDecoder,
)
from fairseq.models.multilingual_transformer import MultilingualTransformerModel

from .latent_transformer import (
    LatentTransformerEncoder,
    LatentTransformerDecoder,
)


@register_model('latent_multilingual_transformer')
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
    """A variant of the standard multilingual Transformer model whose encoder
    and/or decoder supports latent depth, as described in "Deep Transformers
    with Latent Depth" (https://arxiv.org/abs/2009.13102).
    """

    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        if is_encoder:
            if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
                return LatentTransformerEncoder(args, lang_dict, embed_tokens, num_logits=len(langs))
            else:
                return TransformerEncoder(args, lang_dict, embed_tokens)
        else:
            if hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
                return LatentTransformerDecoder(
                    args, lang_dict, embed_tokens, num_logits=len(langs)
                )
            else:
                return TransformerDecoder(args, lang_dict, embed_tokens)


@register_model_architecture('latent_multilingual_transformer', 'latent_multilingual_transformer')
def latent_multilingual_architecture(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
    args.encoder_layers = getattr(args, 'encoder_layers', 12)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
    args.decoder_layers = getattr(args, 'decoder_layers', 24)
    args.share_encoders = getattr(args, 'share_encoders', True)
    args.share_decoders = getattr(args, 'share_decoders', True)
    args.share_encoder_embeddings = getattr(args, 'share_encoder_embeddings', True)
    args.share_decoder_embeddings = getattr(args, 'share_decoder_embeddings', True)

    base_architecture(args)

View File

@@ -0,0 +1,130 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import torch.nn as nn
from torch import Tensor

from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer

from ..modules.latent_layers import LayerSelect


class LatentTransformerEncoder(TransformerEncoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerEncoder.
    """

    def __init__(self, args, dictionary, embed_tokens, num_logits=1):
        self.num_logits = num_logits
        self.num_layers = args.encoder_layers
        super().__init__(args, dictionary, embed_tokens)
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList([
            self._build_encoder_layer(args, idx)
            for idx in range(args.encoder_layers)
        ])

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_encoder_layer(self, args, idx=None):
        return LatentTransformerEncoderLayer(args, idx, layer_select=self.layer_select)

    def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
        self.layer_select.sample(self.lang_idx)
        return super().forward(src_tokens, src_lengths, return_all_hiddens)


class LatentTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer whose (non-residual) block is weighted by a sample drawn
    from a Bernoulli or Gumbel-Sigmoid distribution.

    Args:
        args (argparse.Namespace): parsed command-line arguments from the
            standard TransformerEncoderLayer.
        idx (int): layer index (used to retrieve samples).
        layer_select (LayerSelect, optional): instance of LayerSelect module
            with logits parameters and sampling method.
    """

    def __init__(self, args, idx, layer_select=None):
        super().__init__(args)
        self.idx = idx
        self.layer_select = layer_select

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)


class LatentTransformerDecoder(TransformerDecoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerDecoder.
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1):
        self.num_logits = num_logits
        self.num_layers = args.decoder_layers
        super().__init__(
            args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
        )
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList([
            self._build_decoder_layer(args, no_encoder_attn, idx)
            for idx in range(args.decoder_layers)
        ])

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
        return LatentTransformerDecoderLayer(args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn)

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        self.layer_select.sample(self.lang_idx)
        return super().forward(
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            features_only=features_only,
            alignment_layer=alignment_layer,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )


class LatentTransformerDecoderLayer(TransformerDecoderLayer):
    """Decoder layer whose (non-residual) blocks are weighted by a sample drawn
    from a Bernoulli or Gumbel-Sigmoid distribution.

    Args:
        args (argparse.Namespace): parsed command-line arguments from the
            standard TransformerDecoderLayer.
        idx (int): layer index (used to retrieve samples).
        layer_select (LayerSelect, optional): instance of LayerSelect module
            with logits parameters and sampling method.
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, idx, layer_select=None, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
        self.idx = idx
        self.layer_select = layer_select

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)

View File

@@ -0,0 +1,73 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class LayerSelect(nn.Module):
    """Compute samples (from a Gumbel-Sigmoid distribution) which are used as
    either (soft) weighting or (hard) selection of residual connections.
    https://arxiv.org/abs/2009.13102
    """

    def __init__(self, num_layers, num_logits, args):
        super(LayerSelect, self).__init__()
        self.args = args
        self.layer_logits = torch.nn.Parameter(
            torch.Tensor(num_logits, num_layers),
            requires_grad=True,
        )
        self.hard_select = not (hasattr(args, "soft_select") and args.soft_select)
        self.tau = getattr(args, "sampling_tau", 5)
        self.detach_grad = False
        self.layer_samples = [None] * num_logits

    @staticmethod
    def add_args(parser):
        parser.add_argument(
            '--soft-select',
            action='store_true',
            help='use soft samples in training and inference',
        )
        parser.add_argument('--sampling-tau', type=float, help='sampling temperature')

    def sample(self, logit_idx):
        """To leverage the efficiency of distributed training, samples for all
        layers are computed at once for each logit_idx. Logits are parameters
        learnt independently of each other.

        Args:
            logit_idx: The index of the logit parameters used for sampling.
        """
        assert logit_idx is not None
        self.samples = self._gumbel_sigmoid(
            self.layer_logits[logit_idx, :].detach() if self.detach_grad else self.layer_logits[logit_idx, :],
            dim=-1,
            tau=self.tau,
            hard=self.hard_select,
        )
        self.layer_samples[logit_idx] = self.samples

    def forward(self, i):
        sample = self.samples[i]
        return sample

    def _gumbel_sigmoid(self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5):
        # ~Gumbel(0, 1)
        gumbels1 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
        gumbels2 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
        # difference of two Gumbel samples, because we apply a sigmoid
        gumbels1 = (logits + gumbels1 - gumbels2) / tau
        y_soft = gumbels1.sigmoid()
        if hard:
            # straight-through estimator
            y_hard = torch.zeros_like(
                logits, memory_format=torch.legacy_contiguous_format
            ).masked_fill(y_soft > threshold, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
        else:
            # reparametrization trick
            ret = y_soft
        return ret
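
A rough, hypothetical usage sketch for `LayerSelect` (values are illustrative): there is one row of logits per language pair and one column per layer; `sample()` draws z for every layer of one language pair at once, and calling the module with a layer index returns the sample gating that layer.

```python
from argparse import Namespace

import torch

args = Namespace(soft_select=False, sampling_tau=5)
layer_select = LayerSelect(num_layers=24, num_logits=8, args=args)
torch.nn.init.zeros_(layer_select.layer_logits)  # the logits are normally learned

layer_select.sample(3)            # draw z_1..z_24 for language pair 3
z5 = layer_select(5)              # hard 0/1 sample gating layer 5

layer_select.hard_select = False  # the task flips this during the soft-update phase
layer_select.sample(3)
soft_z5 = layer_select(5)         # relaxed sample in (0, 1)
```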

View File

@@ -0,0 +1,156 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.tasks import register_task
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask

from .loss.latent_depth import LatentLayersKLLoss, LatentLayersSparsityLoss


@register_task('multilingual_translation_latent_depth')
class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
    """A task for multilingual translation with latent depth.

    See `"Deep Transformers with Latent Depth"
    (Li et al., 2020) <https://arxiv.org/pdf/2009.13102.pdf>`_.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        MultilingualTranslationTask.add_args(parser)
        parser.add_argument('--encoder-latent-layer', action='store_true', help='latent layer selection in encoder')
        parser.add_argument('--decoder-latent-layer', action='store_true', help='latent layer selection in decoder')
        parser.add_argument('--target-layers', default=-1, type=int,
                            help='number of effective layers to learn; -1 means no constraint')
        parser.add_argument('--sparsity-weight', default=0.0, type=float,
                            help='weight for sparsity loss')
        parser.add_argument('--share-weight', default=0.0, type=float,
                            help='weight for sharing loss')
        parser.add_argument('--soft-update', default=1, type=int,
                            help='number of updates with soft sampling')
        parser.add_argument('--anneal-updates', default=1, type=int,
                            help='number of updates to anneal the KL loss weight')
        parser.add_argument('--prior', default="uniform", type=str,
                            help='prior used for computing KL loss')
        # fmt: on

    def __init__(self, args, dicts, training):
        super().__init__(args, dicts, training)
        self.src_langs, self.tgt_langs = zip(*[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs])
        if self.training and self.encoder_latent_layer:
            assert self.args.share_encoders
        if self.training and self.decoder_latent_layer:
            assert self.args.share_decoders
        if training or self.encoder_latent_layer or self.decoder_latent_layer:
            self.lang_pairs = args.lang_pairs
        else:
            self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
        self.eval_lang_pairs = self.lang_pairs
        self.model_lang_pairs = self.lang_pairs
        if self.training and (self.encoder_latent_layer or self.decoder_latent_layer):
            self.kl_loss = LatentLayersKLLoss(self.args)
            self.sparsity_loss = LatentLayersSparsityLoss(self.args)

    def _per_lang_pair_train_loss(self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad):
        src, tgt = lang_pair.split("-")
        if self.encoder_latent_layer:
            src_lang_idx = self.src_lang_idx_dict[src]
            model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
            model.models[lang_pair].encoder.layer_select.hard_select = update_num > self.args.soft_update
        if self.decoder_latent_layer:
            tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
            model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
            model.models[lang_pair].decoder.layer_select.hard_select = update_num > self.args.soft_update

        loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
        if self.encoder_latent_layer:
            none_samples = sum(
                1 if x is None else 0 for x in model.models[lang_pair].encoder.layer_select.layer_samples
            )
            if none_samples == 0 or self.args.prior != "agged_posterior":
                loss += self.kl_loss(
                    model.models[lang_pair].encoder.layer_select.layer_samples,
                    src_lang_idx,
                    update_num,
                    sample_size
                )
        if self.decoder_latent_layer:
            none_samples = sum(
                1 if x is None else 0 for x in model.models[lang_pair].decoder.layer_select.layer_samples
            )
            if none_samples == 0 or self.args.prior != "agged_posterior":
                loss += self.kl_loss(
                    model.models[lang_pair].decoder.layer_select.layer_samples,
                    tgt_lang_idx,
                    update_num,
                    sample_size
                )
        if ignore_grad:
            loss *= 0

        if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
            # need to retain the graph if the sparsity loss is added later
            loss.backward(retain_graph=True)
        else:
            optimizer.backward(loss)
        return loss, sample_size, logging_output

    def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
        agg_loss, agg_sample_size, agg_logging_output = super().train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad)
        # compute the auxiliary loss from layer sparsity, based on samples from all languages
        if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
            sparsity_loss = 0
            if self.encoder_latent_layer:
                sparsity_loss += self.sparsity_loss(
                    next(iter(model.models.values())).encoder.layer_select.layer_samples, update_num, agg_sample_size)
            if self.decoder_latent_layer:
                sparsity_loss += self.sparsity_loss(
                    next(iter(model.models.values())).decoder.layer_select.layer_samples, update_num, agg_sample_size)
            if sparsity_loss > 0:
                optimizer.backward(sparsity_loss)
        return agg_loss, agg_sample_size, agg_logging_output

    def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample):
        src, tgt = lang_pair.split("-")
        if self.encoder_latent_layer:
            src_lang_idx = self.src_lang_idx_dict[src]
            model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
        if self.decoder_latent_layer:
            tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
            model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
        loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
        return loss, sample_size, logging_output

    def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None):
        if self.encoder_latent_layer or self.decoder_latent_layer:
            for model in models:
                if self.encoder_latent_layer:
                    assert model.encoder.layer_select is not None
                    src_lang_idx = self.src_lang_idx_dict[self.args.source_lang]
                    model.encoder.set_lang_idx(src_lang_idx)
                if self.decoder_latent_layer:
                    assert model.decoder.layer_select is not None
                    tgt_lang_idx = self.tgt_lang_idx_dict[self.args.target_lang]
                    model.decoder.set_lang_idx(tgt_lang_idx)
        return super().inference_step(generator, models, sample, prefix_tokens, constraints)

    @property
    def encoder_latent_layer(self):
        return hasattr(self.args, "encoder_latent_layer") and self.args.encoder_latent_layer

    @property
    def decoder_latent_layer(self):
        return hasattr(self.args, "decoder_latent_layer") and self.args.decoder_latent_layer

    @property
    def src_lang_idx_dict(self):
        return {lang: lang_idx for lang_idx, lang in enumerate(self.src_langs)}

    @property
    def tgt_lang_idx_dict(self):
        return {lang: lang_idx for lang_idx, lang in enumerate(self.tgt_langs)}

View File

@@ -136,7 +136,8 @@ class MultilingualTransformerModel(FairseqMultiModel):
                    encoder_embed_tokens = build_embedding(
                        task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path
                    )
                lang_encoders[lang] = TransformerEncoder(args, task.dicts[lang], encoder_embed_tokens)
                lang_encoders[lang] = cls._get_module_class(
                    True, args, task.dicts[lang], encoder_embed_tokens, src_langs)
            return lang_encoders[lang]

        def get_decoder(lang):
@@ -147,7 +148,8 @@ class MultilingualTransformerModel(FairseqMultiModel):
                    decoder_embed_tokens = build_embedding(
                        task.dicts[lang], args.decoder_embed_dim, args.decoder_embed_path
                    )
                lang_decoders[lang] = TransformerDecoder(args, task.dicts[lang], decoder_embed_tokens)
                lang_decoders[lang] = cls._get_module_class(
                    False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs)
            return lang_decoders[lang]

        # shared encoders/decoders (if applicable)
@@ -164,6 +166,11 @@ class MultilingualTransformerModel(FairseqMultiModel):

        return MultilingualTransformerModel(encoders, decoders)

    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        module_class = TransformerEncoder if is_encoder else TransformerDecoder
        return module_class(args, lang_dict, embed_tokens)

    def load_state_dict(self, state_dict, strict=True, args=None):
        state_dict_subset = state_dict.copy()
        for k, _ in state_dict.items():

View File

@@ -72,6 +72,9 @@ class TransformerEncoderLayer(nn.Module):
            qn_block_size=self.quant_noise_block_size,
        )

    def residual_connection(self, x, residual):
        return residual + x

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
@@ -121,7 +124,7 @@ class TransformerEncoderLayer(nn.Module):
            attn_mask=attn_mask,
        )
        x = self.dropout_module(x)
        x = residual + x
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

@@ -133,7 +136,7 @@ class TransformerEncoderLayer(nn.Module):
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = residual + x
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x
@@ -243,6 +246,9 @@ class TransformerDecoderLayer(nn.Module):
    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        return residual + x

    def forward(
        self,
        x,
@@ -320,7 +326,7 @@ class TransformerDecoderLayer(nn.Module):
            attn_mask=self_attn_mask,
        )
        x = self.dropout_module(x)
        x = residual + x
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

@@ -350,7 +356,7 @@ class TransformerDecoderLayer(nn.Module):
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = residual + x
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

@@ -362,7 +368,7 @@ class TransformerDecoderLayer(nn.Module):
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = residual + x
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:

View File

@@ -264,6 +264,13 @@ class MultilingualTranslationTask(LegacyFairseqTask):
            raise ValueError('MultilingualTranslationTask requires a FairseqMultiModel architecture')
        return model

    def _per_lang_pair_train_loss(self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad):
        loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
        model.train()
        from collections import defaultdict
@@ -285,10 +292,8 @@ class MultilingualTranslationTask(LegacyFairseqTask):
                else:
                    return contextlib.ExitStack()  # dummy contextmanager
            with maybe_no_sync():
                loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
                if ignore_grad:
                    loss *= 0
                optimizer.backward(loss)
                loss, sample_size, logging_output = self._per_lang_pair_train_loss(
                    lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad)
            agg_loss += loss.detach().item()
            # TODO make summing of the sample sizes configurable
            agg_sample_size += sample_size
@@ -297,6 +302,9 @@ class MultilingualTranslationTask(LegacyFairseqTask):
                agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k]
        return agg_loss, agg_sample_size, agg_logging_output

    def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample):
        return criterion(model.models[lang_pair], sample[lang_pair])

    def valid_step(self, sample, model, criterion):
        model.eval()
        with torch.no_grad():
@@ -305,7 +313,7 @@ class MultilingualTranslationTask(LegacyFairseqTask):
            for lang_pair in self.eval_lang_pairs:
                if lang_pair not in sample or sample[lang_pair] is None or len(sample[lang_pair]) == 0:
                    continue
                loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
                loss, sample_size, logging_output = self._per_lang_pair_valid_loss(lang_pair, model, criterion, sample)
                agg_loss += loss.data.item()
                # TODO make summing of the sample sizes configurable
                agg_sample_size += sample_size

View File

@@ -207,6 +207,52 @@ class TestTranslation(unittest.TestCase):
                            ] + enc_ltok_flag + dec_ltok_flag,
                        )

    def test_multilingual_translation_latent_depth(self):
        # test with latent depth in encoder, decoder, or both
        encoder_latent_layer = [[], ['--encoder-latent-layer']]
        decoder_latent_layer = [[], ['--decoder-latent-layer']]
        with contextlib.redirect_stdout(StringIO()):
            for i in range(len(encoder_latent_layer)):
                for j in range(len(decoder_latent_layer)):
                    if i == 0 and j == 0:
                        continue
                    enc_ll_flag = encoder_latent_layer[i]
                    dec_ll_flag = decoder_latent_layer[j]
                    with tempfile.TemporaryDirectory(f'test_multilingual_translation_latent_depth_{i}_{j}') as data_dir:
                        create_dummy_data(data_dir)
                        preprocess_translation_data(
                            data_dir,
                            extra_flags=['--joined-dictionary']
                        )
                        train_translation_model(
                            data_dir,
                            arch='latent_multilingual_transformer',
                            task='multilingual_translation_latent_depth',
                            extra_flags=[
                                '--user-dir', 'examples/latent_depth/src',
                                '--encoder-layers', '2',
                                '--decoder-layers', '2',
                                '--encoder-embed-dim', '8',
                                '--decoder-embed-dim', '8',
                                '--share-encoders',
                                '--share-decoders',
                                '--sparsity-weight', '0.1',
                            ] + enc_ll_flag + dec_ll_flag,
                            lang_flags=['--lang-pairs', 'in-out,out-in'],
                            run_validation=True,
                            extra_valid_flags=['--user-dir', 'examples/latent_depth/src'] + enc_ll_flag + dec_ll_flag,
                        )
                        generate_main(
                            data_dir,
                            extra_flags=[
                                '--user-dir', 'examples/latent_depth/src',
                                '--task', 'multilingual_translation_latent_depth',
                                '--lang-pairs', 'in-out,out-in',
                                '--source-lang', 'in',
                                '--target-lang', 'out',
                            ] + enc_ll_flag + dec_ll_flag,
                        )

    def test_translation_multi_simple_epoch(self):
        # test with all combinations of encoder/decoder lang tokens
        encoder_langtok_flags = [[], ['--encoder-langtok', 'src'], ['--encoder-langtok', 'tgt']]