Add code for mixture of experts (#521)

Summary:
Code for the paper: [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](https://arxiv.org/abs/1902.07816).
Pull Request resolved: https://github.com/pytorch/fairseq/pull/521

Differential Revision: D14188021

Pulled By: myleott

fbshipit-source-id: ed5b1ed5ad9a582359bd5215fa2ea26dc76c673e
Myle Ott 2019-02-22 13:11:22 -08:00 committed by Facebook Github Bot
parent b65c579bed
commit 4294c4f6d7
10 changed files with 435 additions and 18 deletions

README.md

@@ -5,19 +5,20 @@ developers to train custom models for translation, summarization, language
modeling and other text generation tasks. It provides reference implementations
of various sequence-to-sequence models, including:
- **Convolutional Neural Networks (CNN)**
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](https://arxiv.org/abs/1612.08083)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://arxiv.org/abs/1711.04956)
- [Fan et al. (2018): Hierarchical Neural Story Generation](https://arxiv.org/abs/1805.04833)
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/conv_lm/README.md)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
- **LightConv and DynamicConv models**
- **_New_** [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](https://openreview.net/pdf?id=SkVhlh09tX)
- **_New_** [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- **Long Short-Term Memory (LSTM) networks**
- [Luong et al. (2015): Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025)
- [Wiseman and Rush (2016): Sequence-to-Sequence Learning as Beam-Search Optimization](https://arxiv.org/abs/1606.02960)
- **Transformer (self-attention) networks**
- [Vaswani et al. (2017): Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- [Ott et al. (2018): Scaling Neural Machine Translation](https://arxiv.org/abs/1806.00187)
- [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md)
- [Edunov et al. (2018): Understanding Back-Translation at Scale](https://arxiv.org/abs/1808.09381)
- **_New_** [Shen et al. (2019) Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
Fairseq features:
- multi-GPU (distributed) training on one machine or across multiple machines
@@ -74,6 +75,7 @@ as well as example training and evaluation commands.
- [Language Modeling](examples/language_model/README.md): convolutional models are available
We also have more detailed READMEs to reproduce results from specific papers:
- [Shen et al. (2019) Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
- [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)

examples/translation/README.md

@@ -112,22 +112,22 @@ $ bash prepare-wmt14en2de.sh
$ cd ../..
# Binarize the dataset:
$ TEXT=examples/translation/wmt14_en_de
$ TEXT=examples/translation/wmt17_en_de
$ fairseq-preprocess --source-lang en --target-lang de \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/wmt14_en_de --thresholdtgt 0 --thresholdsrc 0
--destdir data-bin/wmt17_en_de --thresholdtgt 0 --thresholdsrc 0
# Train the model:
# If it runs out of memory, try to set --max-tokens 1500 instead
$ mkdir -p checkpoints/fconv_wmt_en_de
$ fairseq-train data-bin/wmt14_en_de \
$ fairseq-train data-bin/wmt17_en_de \
--lr 0.5 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--lr-scheduler fixed --force-anneal 50 \
--arch fconv_wmt_en_de --save-dir checkpoints/fconv_wmt_en_de
# Generate:
$ fairseq-generate data-bin/wmt14_en_de \
$ fairseq-generate data-bin/wmt17_en_de \
--path checkpoints/fconv_wmt_en_de/checkpoint_best.pt --beam 5 --remove-bpe
```

examples/translation/prepare-wmt14en2de.sh

@@ -41,6 +41,9 @@ if [ "$1" == "--icml17" ]; then
URLS[2]="http://statmt.org/wmt14/training-parallel-nc-v9.tgz"
FILES[2]="training-parallel-nc-v9.tgz"
CORPORA[2]="training/news-commentary-v9.de-en"
OUTDIR=wmt14_en_de
else
OUTDIR=wmt17_en_de
fi
if [ ! -d "$SCRIPTS" ]; then
@@ -51,7 +54,7 @@ fi
src=en
tgt=de
lang=en-de
prep=wmt14_en_de
prep=$OUTDIR
tmp=$prep/tmp
orig=orig
dev=dev/newstest2013

examples/translation_moe/README.md

@@ -0,0 +1,87 @@
# Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)
This page includes instructions for reproducing results from the paper [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](https://arxiv.org/abs/1902.07816).
## Training a new model on WMT'17 En-De
First, follow the [instructions to download and preprocess the WMT'17 En-De dataset](../translation#prepare-wmt14en2desh).
Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.
Then we can train a mixture of experts model using the `translation_moe` task.
Use the `--method` option to choose the MoE variant; we support hard mixtures with a learned or uniform prior (`--method hMoElp` and `hMoEup`, respectively) and soft mixtures (`--method sMoElp` and `sMoEup`).
To train a hard mixture of experts model with a learned prior (`hMoElp`) on 1 GPU:
```
$ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/wmt17_en_de \
--max-update 100000 \
--task translation_moe \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--arch transformer_vaswani_wmt_en_de --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
--lr 0.0007 --min-lr 1e-09 \
--dropout 0.1 --weight-decay 0.0 --criterion cross_entropy \
--max-tokens 3584 \
--update-freq 8
```
**Note**: the above command assumes 1 GPU, but accumulates gradients from 8 fwd/bwd passes to simulate training on 8 GPUs.
You can accelerate training on up to 8 GPUs by adjusting the `CUDA_VISIBLE_DEVICES` and `--update-freq` options accordingly; for example, with 4 GPUs use `--update-freq 2` so that each optimizer step still sees roughly 8 × 3584 tokens.
Once a model is trained, we can generate translations from different experts using the `--gen-expert` option.
For example, to generate from expert 0:
```
$ fairseq-generate data-bin/wmt17_en_de \
--path checkpoints/checkpoint_best.pt \
--beam 1 --remove-bpe \
--task translation_moe \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert 0
```
You can also use `scripts/score_moe.py` to compute pairwise BLEU and average oracle BLEU.
We'll first download a tokenized version of the multi-reference WMT'14 En-De dataset:
```
$ wget dl.fbaipublicfiles.com/fairseq/data/wmt14-en-de.extra_refs.tok
```
Next apply BPE on the fly and run generation for each expert:
```
$ BPEROOT=examples/translation/subword-nmt/
$ BPE_CODE=examples/translation/wmt17_en_de/code
$ for EXPERT in $(seq 0 2); do \
cat wmt14-en-de.extra_refs.tok | grep ^S | cut -f 2 | \
python $BPEROOT/apply_bpe.py -c $BPE_CODE | \
fairseq-interactive data-bin/wmt17_en_de \
--path checkpoints/checkpoint_best.pt \
--beam 1 --remove-bpe \
--buffer-size 500 --max-tokens 6000 \
--task translation_moe \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert $EXPERT ; \
done > wmt14-en-de.extra_refs.tok.gen.3experts
```
Finally compute pairwise BLEU and average oracle BLEU:
```
$ python scripts/score_moe.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok
pairwise BLEU: 48.26
avg oracle BLEU: 49.50
#refs covered: 2.11
```
This reproduces row 3 from Table 7 in the paper.
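For intuition, the two headline numbers can be understood roughly as follows: pairwise BLEU scores each expert's hypotheses against every other expert's hypotheses (lower means more diverse translations), while average oracle BLEU keeps, for each sentence, the best-scoring hypothesis among the experts (higher means better quality under ideal expert selection). The sketch below is illustrative only and is not the implementation in `scripts/score_moe.py`; `sentence_bleu` stands for any sentence-level BLEU function (e.g., from sacreBLEU) and is an assumption.
```python
from itertools import permutations

def pairwise_bleu(expert_hyps, sentence_bleu):
    """expert_hyps[k][i] is expert k's hypothesis for sentence i.
    Treat each expert's output as a "reference" for every other expert."""
    scores = [
        sentence_bleu(h, r)
        for hyps_k, hyps_j in permutations(expert_hyps, 2)
        for h, r in zip(hyps_k, hyps_j)
    ]
    return sum(scores) / len(scores)

def avg_oracle_bleu(expert_hyps, references, sentence_bleu):
    """references[i] is the list of reference translations for sentence i.
    For each sentence keep the best expert hypothesis, then average."""
    best_per_sentence = [
        max(
            sentence_bleu(hyps_k[i], ref)
            for hyps_k in expert_hyps
            for ref in refs
        )
        for i, refs in enumerate(references)
    ]
    return sum(best_per_sentence) / len(best_per_sentence)
```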
## Citation
```bibtex
@article{shen2019mixture,
title = {Mixture Models for Diverse Machine Translation: Tricks of the Trade},
author = {Tianxiao Shen and Myle Ott and Michael Auli and Marc'Aurelio Ranzato},
journal = {arXiv preprint arXiv:1902.07816},
year = 2019,
}
```

fairseq/criterions/cross_entropy.py

@@ -28,11 +28,7 @@ class CrossEntropyCriterion(FairseqCriterion):
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = model.get_targets(sample, net_output).view(-1)
loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
reduce=reduce)
loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
@@ -42,6 +38,14 @@ class CrossEntropyCriterion(FairseqCriterion):
}
return loss, sample_size, logging_output
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = model.get_targets(sample, net_output).view(-1)
loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
reduce=reduce)
return loss, loss
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""

fairseq/modules/__init__.py

@@ -17,6 +17,8 @@ from .highway import Highway
from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv1dTBC
from .linearized_convolution import LinearizedConvolution
from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork
from .multihead_attention import MultiheadAttention
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
@@ -35,6 +37,8 @@ __all__ = [
'LearnedPositionalEmbedding',
'LightweightConv1dTBC',
'LinearizedConvolution',
'LogSumExpMoE',
'MeanPoolGatingNetwork',
'MultiheadAttention',
'ScalarBias',
'SinusoidalPositionalEmbedding',

fairseq/modules/logsumexp_moe.py

@@ -0,0 +1,28 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
class LogSumExpMoE(torch.autograd.Function):
"""Standard LogSumExp forward pass, but use *posterior* for the backward.
See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
"""
@staticmethod
def forward(ctx, logp, posterior, dim=-1):
ctx.save_for_backward(posterior)
ctx.dim = dim
return torch.logsumexp(logp, dim=dim)
@staticmethod
def backward(ctx, grad_output):
posterior, = ctx.saved_tensors
grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
return grad_logp, None, None
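As a quick illustration (not part of the commit), the forward pass matches `torch.logsumexp`, while the backward pass routes the gradient through the externally supplied posterior instead of the softmax of `logp`:
```python
import torch
from fairseq.modules import LogSumExpMoE  # exported by this commit

logp = torch.randn(4, 3, requires_grad=True)          # B x K expert log-likelihoods
posterior = torch.softmax(torch.randn(4, 3), dim=1)   # B x K responsibilities p(z | x, y)

out = LogSumExpMoE.apply(logp, posterior, 1)
out.sum().backward()

print(torch.allclose(out, torch.logsumexp(logp, dim=1)))  # True: standard forward
print(torch.allclose(logp.grad, posterior))               # True: gradient is the posterior
```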

fairseq/modules/mean_pool_gating_network.py

@@ -0,0 +1,53 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn.functional as F
class MeanPoolGatingNetwork(torch.nn.Module):
"""A simple mean-pooling gating network for selecting experts.
This module applies mean pooling over an encoder's output and returns
responsibilities for each expert. The encoder format is expected to match
:class:`fairseq.models.transformer.TransformerEncoder`.
"""
def __init__(self, embed_dim, num_experts, dropout=None):
super().__init__()
self.embed_dim = embed_dim
self.num_experts = num_experts
self.fc1 = torch.nn.Linear(embed_dim, embed_dim)
self.dropout = torch.nn.Dropout(dropout) if dropout is not None else None
self.fc2 = torch.nn.Linear(embed_dim, num_experts)
def forward(self, encoder_out):
if not (
isinstance(encoder_out, dict)
and 'encoder_out' in encoder_out
and 'encoder_padding_mask' in encoder_out
and encoder_out['encoder_out'].size(2) == self.embed_dim
):
raise ValueError('Unexpected format for encoder_out')
# mean pooling over time
encoder_padding_mask = encoder_out['encoder_padding_mask'] # B x T
encoder_out = encoder_out['encoder_out'].transpose(0, 1) # B x T x C
if encoder_padding_mask is not None:
encoder_out = encoder_out.clone() # required because of transpose above
encoder_out[encoder_padding_mask] = 0
ntokens = torch.sum(1 - encoder_padding_mask, dim=1, keepdim=True)
x = torch.sum(encoder_out, dim=1) / ntokens.type_as(encoder_out)
else:
x = torch.mean(encoder_out, dim=1)
x = torch.tanh(self.fc1(x))
if self.dropout is not None:
x = self.dropout(x)
x = self.fc2(x)
return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x)
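For reference, a small usage sketch (illustrative, with made-up shapes): the module expects the `TransformerEncoder`-style dictionary with a `T x B x C` tensor under `encoder_out` and a `B x T` padding mask (or `None`), and returns `B x num_experts` log-priors over experts:
```python
import torch
from fairseq.modules import MeanPoolGatingNetwork  # exported by this commit

T, B, C, K = 7, 2, 16, 3
gating = MeanPoolGatingNetwork(embed_dim=C, num_experts=K, dropout=0.1)

encoder_out = {
    'encoder_out': torch.randn(T, B, C),   # T x B x C
    'encoder_padding_mask': None,          # or a B x T mask marking padded positions
}
log_prior = gating(encoder_out)            # B x K = log p(z | x)
print(log_prior.shape)                     # torch.Size([2, 3])
print(log_prior.exp().sum(dim=1))          # each row sums to ~1
```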

fairseq/sequence_generator.py

@@ -98,7 +98,14 @@ class SequenceGenerator(object):
self.search = search.BeamSearch(tgt_dict)
@torch.no_grad()
def generate(self, models, sample=None, net_input=None, prefix_tokens=None, **kwargs):
def generate(
self,
models,
sample,
prefix_tokens=None,
bos_token=None,
**kwargs
):
"""Generate a batch of translations.
Args:
@ -143,7 +150,7 @@ class SequenceGenerator(object):
scores_buf = scores.clone()
tokens = src_tokens.new(bsz * beam_size, max_len + 2).fill_(self.pad)
tokens_buf = tokens.clone()
tokens[:, 0] = self.eos
tokens[:, 0] = bos_token or self.eos
attn, attn_buf = None, None
nonpad_idxs = None

fairseq/tasks/translation_moe.py

@@ -0,0 +1,229 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import contextlib
import torch
from fairseq import modules, options, utils
from fairseq.data import (
ConcatDataset,
data_utils,
Dictionary,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
)
from . import register_task
from .translation import TranslationTask
@contextlib.contextmanager
def eval(model):
is_training = model.training
model.eval()
yield
model.train(is_training)
@register_task('translation_moe')
class TranslationMoETask(TranslationTask):
"""
Translation task for Mixture of Experts (MoE) models.
See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
(Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
Args:
src_dict (Dictionary): dictionary for the source language
tgt_dict (Dictionary): dictionary for the target language
.. note::
The translation task is compatible with :mod:`fairseq-train`,
:mod:`fairseq-generate` and :mod:`fairseq-interactive`.
The translation task provides the following additional command-line
arguments:
.. argparse::
:ref: fairseq.tasks.translation_parser
:prog:
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
TranslationTask.add_args(parser)
parser.add_argument('--method', required=True,
choices=['sMoElp', 'sMoEup', 'hMoElp', 'hMoEup'])
parser.add_argument('--num-experts', type=int, metavar='N', required=True,
help='number of experts')
parser.add_argument('--mean-pool-gating-network', action='store_true',
help='use a simple mean-pooling gating network')
parser.add_argument('--mean-pool-gating-network-dropout', type=float,
help='dropout for mean-pooling gating network')
parser.add_argument('--mean-pool-gating-network-encoder-dim', type=int,
help='encoder output dim for mean-pooling gating network')
parser.add_argument('--gen-expert', type=int, default=0,
help='which expert to use for generation')
# fmt: on
def __init__(self, args, src_dict, tgt_dict):
if args.method == 'sMoElp':
# soft MoE with learned prior
self.uniform_prior = False
self.hard_selection = False
elif args.method == 'sMoEup':
# soft MoE with uniform prior
self.uniform_prior = True
self.hard_selection = False
elif args.method == 'hMoElp':
# hard MoE with learned prior
self.uniform_prior = False
self.hard_selection = True
elif args.method == 'hMoEup':
# hard MoE with uniform prior
self.uniform_prior = True
self.hard_selection = True
# add indicator tokens for each expert
for i in range(args.num_experts):
# add to both dictionaries in case we're sharing embeddings
src_dict.add_symbol('<expert_{}>'.format(i))
tgt_dict.add_symbol('<expert_{}>'.format(i))
super().__init__(args, src_dict, tgt_dict)
def build_model(self, args):
from fairseq import models
model = models.build_model(args, self)
if not self.uniform_prior and not hasattr(model, 'gating_network'):
if self.args.mean_pool_gating_network:
if getattr(args, 'mean_pool_gating_network_encoder_dim', None):
encoder_dim = args.mean_pool_gating_network_encoder_dim
elif getattr(args, 'encoder_embed_dim', None):
# assume that encoder_embed_dim is the encoder's output dimension
encoder_dim = args.encoder_embed_dim
else:
raise ValueError('Must specify --mean-pool-gating-network-encoder-dim')
if getattr(args, 'mean_pool_gating_network_dropout', None):
dropout = args.mean_pool_gating_network_dropout
elif getattr(args, 'dropout', None):
dropout = args.dropout
else:
raise ValueError('Must specify --mean-pool-gating-network-dropout')
model.gating_network = modules.MeanPoolGatingNetwork(
encoder_dim, args.num_experts, dropout,
)
else:
raise ValueError(
'translation_moe task with learned prior requires the model to '
'have a gating network; try using --mean-pool-gating-network'
)
return model
def expert_index(self, i):
return i + self.tgt_dict.index('<expert_0>')
def _get_loss(self, sample, model, criterion):
assert hasattr(criterion, 'compute_loss'), \
'translation_moe task requires the criterion to implement the compute_loss() method'
k = self.args.num_experts
bsz = sample['target'].size(0)
def get_lprob_y(encoder_out, prev_output_tokens_k):
net_output = model.decoder(prev_output_tokens_k, encoder_out)
loss, _ = criterion.compute_loss(model, net_output, sample, reduce=False)
loss = loss.view(bsz, -1)
return -loss.sum(dim=1, keepdim=True) # -> B x 1
def get_lprob_yz(winners=None):
encoder_out = model.encoder(sample['net_input']['src_tokens'], sample['net_input']['src_lengths'])
if winners is None:
lprob_y = []
for i in range(k):
prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
assert not prev_output_tokens_k.requires_grad
prev_output_tokens_k[:, 0] = self.expert_index(i)
lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
lprob_y = torch.cat(lprob_y, dim=1) # -> B x K
else:
prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
prev_output_tokens_k[:, 0] = self.expert_index(winners)
lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k) # -> B
if self.uniform_prior:
lprob_yz = lprob_y
else:
lprob_z = model.gating_network(encoder_out) # B x K
if winners is not None:
lprob_z = lprob_z.gather(dim=1, index=winners.unsqueeze(-1))
lprob_yz = lprob_y + lprob_z.type_as(lprob_y) # B x K
return lprob_yz
# compute responsibilities without dropout
with eval(model): # disable dropout
with torch.no_grad(): # disable autograd
lprob_yz = get_lprob_yz() # B x K
prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
assert not prob_z_xy.requires_grad
# compute loss with dropout
if self.hard_selection:
winners = prob_z_xy.max(dim=1)[1]
loss = -get_lprob_yz(winners)
else:
lprob_yz = get_lprob_yz() # B x K
loss = -modules.LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
loss = loss.sum()
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': utils.item(loss.data),
'ntokens': sample['ntokens'],
'sample_size': sample_size,
'posterior': prob_z_xy.float().sum(dim=0).cpu(),
}
return loss, sample_size, logging_output
def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
model.train()
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
if ignore_grad:
loss *= 0
optimizer.backward(loss)
return loss, sample_size, logging_output
def valid_step(self, sample, model, criterion):
model.eval()
with torch.no_grad():
loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
return loss, sample_size, logging_output
def inference_step(self, generator, models, sample, prefix_tokens=None):
with torch.no_grad():
return generator.generate(
models,
sample,
prefix_tokens=prefix_tokens,
bos_token=self.expert_index(self.args.gen_expert),
)
def aggregate_logging_outputs(self, logging_outputs, criterion):
agg_logging_outputs = criterion.aggregate_logging_outputs(logging_outputs)
agg_logging_outputs['posterior'] = sum(
log['posterior'] for log in logging_outputs if 'posterior' in log
)
return agg_logging_outputs
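To make the hard and soft objectives computed in `_get_loss()` concrete, here is a standalone toy sketch (illustrative only; the random `lprob_y` matrix stands in for the per-expert decoder log-likelihoods obtained by prepending each `<expert_k>` token):
```python
import torch
from fairseq.modules import LogSumExpMoE

B, K = 4, 3
lprob_y = torch.randn(B, K, requires_grad=True)         # stand-in for log p(y | x, z=k)
lprob_z = torch.log_softmax(torch.randn(B, K), dim=1)   # learned prior log p(z | x); omit for a uniform prior
lprob_yz = lprob_y + lprob_z                             # log p(y, z | x)

# E-step: responsibilities p(z | x, y), detached (the task computes them with dropout disabled)
prob_z_xy = torch.softmax(lprob_yz.detach(), dim=1)

# hard selection (hMoElp / hMoEup): train only the winning expert per sentence
winners = prob_z_xy.max(dim=1)[1]
hard_loss = -lprob_yz.gather(1, winners.unsqueeze(1)).sum()

# soft selection (sMoElp / sMoEup): logsumexp forward, posterior-weighted backward
soft_loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1).sum()
```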