Summary:
# Before submitting

- [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.m)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?
  too many of them actually ^^

## What does this PR do?
This is a rewrite of https://github.com/fairinternal/fairseq-py/issues/1538 following the discussion there, and taking into account the proposed https://github.com/fairinternal/fairseq-py/issues/1560 from Myle.
it brings online backtranslation to fairseq.
It adds a RobertaEncDec to fairseq. RobertaEncDec can be built from a pretrained Roberta model allowing to do transfer learning. This is crucial for backtranslation.

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding �

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1614

Reviewed By: myleott

Differential Revision: D27157296

Pulled By: gwenzek

fbshipit-source-id: 43020bc27743419bd4b138716165bf5764117c21
This commit is contained in:
Guillaume Wenzek 2021-03-30 09:54:22 -07:00 committed by Facebook GitHub Bot
parent 7dafb05754
commit c2e8904b60
16 changed files with 1472 additions and 28 deletions

View File

@ -296,6 +296,8 @@ class NoisingDataset(torch.utils.data.Dataset):
**kwargs,
)
)
self.sizes = src_dataset.sizes
def __getitem__(self, index):
"""

View File

@ -141,7 +141,7 @@ class RoundRobinZipDatasets(FairseqDataset):
f"{len(ignored)} samples from {key} have invalid sizes and will be skipped, "
f"max_positions={max_positions[key]}, first few sample ids={ignored[:10]}"
)
# Since we are modifiying in place the _ordered_indices,
# Since we are modifying in place the _ordered_indices,
# it's not possible anymore to return valid ignored indices.
# Hopefully the extra debug information print above should be enough to debug.
# Ideally we would receive ignore_invalid_inputs so that we could have

View File

@ -50,6 +50,9 @@ class TransformEosLangPairDataset(FairseqDataset):
def collater(self, samples, **extra_args):
samples = self.dataset.collater(samples, **extra_args)
if 'net_input' not in samples:
return samples
if self.new_src_eos is not None:
if self.dataset.left_pad_source:
assert (

View File

@ -5,6 +5,7 @@
from .hub_interface import * # noqa
from .model import * # noqa
from .enc_dec import * # noqa
from .model_camembert import * # noqa
from .model_gottbert import * # noqa
from .model_xlmr import * # noqa

View File

@ -0,0 +1,192 @@
import argparse
import logging
import torch.nn as nn
import fairseq.checkpoint_utils
from fairseq.models import (
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import TransformerDecoder
from fairseq.models.roberta import model as roberta
logger = logging.getLogger(__name__)
@register_model("roberta_enc_dec")
class RobertaEncDecModel(FairseqEncoderDecoderModel):
@staticmethod
def add_args(parser):
parser.add_argument(
"--pretrained-mlm-checkpoint",
default=None,
type=str,
metavar="PRETRAINED",
help="path to pretrained mlm checkpoint",
)
parser.add_argument(
"--pretrained-decoder", action="store_true", help="reload decoder"
)
parser.add_argument(
"--hack-layernorm-embedding",
action="store_true",
help="hack to reload old models trained with encoder-normalize-before=False (no equivalent to encoder-normalize-before=False and layernorm_embedding=False",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument(
"--share-all-embeddings",
action="store_true",
help="share encoder, decoder and output embeddings"
" (requires shared dictionary and embed dim)",
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present
base_enc_dec_architecture(args)
if args.pretrained_mlm_checkpoint:
arg_overrides = None
if args.hack_layernorm_embedding:
arg_overrides = {"layernorm_embedding": False}
loaded = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[args.pretrained_mlm_checkpoint], arg_overrides=arg_overrides
)
([roberta_enc], _cfg, _task) = loaded
else:
# Do we need to edit untie_weights here ?
share_in_out = (
args.share_decoder_input_output_embed or args.share_all_embeddings
)
args.untie_weights_roberta = not share_in_out
if args.hack_layernorm_embedding:
args.layernorm_embedding = False
args.encoder_normalize_before = False
roberta_enc = roberta.RobertaModel.build_model(args, task)
return cls.from_roberta(roberta_enc, args, task.source_dictionary)
@staticmethod
def from_roberta(roberta_enc: roberta.RobertaModel, args, dictionary):
encoder = roberta_enc.encoder.sentence_encoder
vocab_size, embed_dim = encoder.embed_tokens.weight.shape
if args.share_all_embeddings:
lm_head = roberta_enc.encoder.lm_head
assert encoder.embed_tokens.weight is lm_head.weight, (
"Can't use --share-all-embeddings with a model "
"that was pretraiend with --untie-weights-roberta_enc"
)
else:
lm_head = roberta.RobertaLMHead(
embed_dim, vocab_size, roberta_enc.args.activation_fn
)
dec_embs = nn.Embedding(vocab_size, embed_dim, dictionary.pad())
if args.share_all_embeddings or args.share_decoder_input_output_embed:
# Note: I wasn't able to use Embedding _weight parameter to achive this sharing.
dec_embs.weight = lm_head.weight
decoder = TransformerDecoder(
RobertaEncDecModel.read_args_from_roberta(roberta_enc.args),
dictionary,
dec_embs,
no_encoder_attn=False,
output_projection=lm_head,
)
if getattr(args, "pretrained_decoder", False):
decoder_dict = encoder.state_dict()
# TODO: hide setting "encoder_attn" layers behind a flag.
for k, w in list(decoder_dict.items()):
if ".self_attn" in k:
k_enc_attn = k.replace(".self_attn", ".encoder_attn")
decoder_dict[k_enc_attn] = w.detach().clone()
for k, w in lm_head.state_dict().items():
decoder_dict["output_projection." + k] = w
missing_keys, unexpected_keys = decoder.load_state_dict(
decoder_dict, strict=False
)
# missing_keys = [m for m in missing_keys if ".encoder_attn" not in m]
assert not missing_keys and not unexpected_keys, (
"Failed to load state dict. "
f"Missing keys: {missing_keys}. "
f"Unexpected keys: {unexpected_keys}."
)
if args.share_all_embeddings:
assert decoder.output_projection.weight is decoder.embed_tokens.weight
assert encoder.embed_tokens.weight is decoder.embed_tokens.weight
elif args.share_decoder_input_output_embed:
assert decoder.output_projection.weight is decoder.embed_tokens.weight
assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight
else:
assert decoder.output_projection.weight is not decoder.embed_tokens.weight
assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight
return RobertaEncDecModel(encoder, decoder)
@staticmethod
def read_args_from_roberta(roberta_args: argparse.Namespace):
# TODO: this would become easier if encoder/decoder where using a similar
# TransformerConfig object
args = argparse.Namespace(**vars(roberta_args))
attr_map = [
("encoder_attention_heads", "decoder_attention_heads"),
("encoder_embed_dim", "decoder_embed_dim"),
("encoder_embed_dim", "decoder_output_dim"),
("encoder_normalize_before", "decoder_normalize_before"),
("encoder_layers_to_keep", "decoder_layers_to_keep"),
("encoder_ffn_embed_dim", "decoder_ffn_embed_dim"),
("encoder_layerdrop", "decoder_layerdrop"),
("encoder_layers", "decoder_layers"),
("encoder_learned_pos", "decoder_learned_pos"),
# should this be set from here ?
("max_positions", "max_target_positions"),
]
for k1, k2 in attr_map:
setattr(args, k2, getattr(roberta_args, k1))
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = not roberta_args.untie_weights_roberta
return args
def upgrade_state_dict_named(self, state_dict, name):
prefix = name + "." if name != "" else ""
super().upgrade_state_dict_named(state_dict, name)
old_keys = list(state_dict.keys())
# rename decoder -> encoder before upgrading children modules
for k in old_keys:
if k.startswith(prefix + "encoder.lm_head"):
state_dict.pop(k)
continue
new_k = k
new_k = new_k.replace(".sentence_encoder.", ".")
new_k = new_k.replace("decoder.lm_head.", "decoder.output_projection.")
if k == new_k:
continue
# print(k, "->", new_k)
state_dict[new_k] = state_dict.pop(k)
@register_model_architecture("roberta_enc_dec", "roberta_enc_dec")
def base_enc_dec_architecture(args):
args.hack_layernorm_embedding = getattr(args, "hack_layernorm_embedding", False)
args.pretrained_mlm_checkpoint = getattr(args, "pretrained_mlm_checkpoint", None)
args.pretrained_decoder = getattr(args, "pretrained_decoder", None)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
roberta.base_architecture(args)

View File

@ -204,7 +204,7 @@ class RobertaModel(FairseqEncoderModel):
features_only=False,
return_all_hiddens=False,
classification_head_name=None,
**kwargs
**kwargs,
):
if classification_head_name is not None:
features_only = True
@ -259,7 +259,7 @@ class RobertaModel(FairseqEncoderModel):
checkpoint_file="model.pt",
data_name_or_path=".",
bpe="gpt2",
**kwargs
**kwargs,
):
from fairseq import hub_utils
@ -464,7 +464,7 @@ class RobertaEncoder(FairseqEncoder):
features_only=False,
return_all_hiddens=False,
masked_tokens=None,
**unused
**unused,
):
"""
Args:

View File

@ -645,7 +645,14 @@ class TransformerDecoder(FairseqIncrementalDecoder):
(default: False).
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
def __init__(
self,
args,
dictionary,
embed_tokens,
no_encoder_attn=False,
output_projection=None,
):
self.args = args
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
@ -727,7 +734,11 @@ class TransformerDecoder(FairseqIncrementalDecoder):
)
self.adaptive_softmax = None
self.output_projection = None
self.output_projection = output_projection
if self.output_projection is None:
self.build_output_projection(args, dictionary, embed_tokens)
def build_output_projection(self, args, dictionary, embed_tokens):
if args.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
@ -789,7 +800,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention
encoder-side attention, should be of size T x B x C
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
@ -802,6 +813,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens,
encoder_out=encoder_out,
@ -810,6 +822,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
)
if not features_only:
x = self.output_layer(x)
return x, extra
@ -866,9 +879,19 @@ class TransformerDecoder(FairseqIncrementalDecoder):
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
bs, slen = prev_output_tokens.size()
if alignment_layer is None:
alignment_layer = self.num_layers - 1
enc: Optional[Tensor] = None
padding_mask: Optional[Tensor] = None
if encoder_out is not None:
enc = encoder_out["encoder_out"][0]
padding_mask = encoder_out["encoder_padding_mask"][0]
assert (
enc.size()[1] == bs
), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
# embed positions
positions = None
if self.embed_positions is not None:
@ -916,15 +939,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
x, layer_attn, _ = layer(
x,
encoder_out["encoder_out"][0]
if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
else None,
encoder_out["encoder_padding_mask"][0]
if (
encoder_out is not None
and len(encoder_out["encoder_padding_mask"]) > 0
)
else None,
enc,
padding_mask,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,

View File

@ -147,8 +147,16 @@ class MultiheadAttention(nn.Module):
is_tpu = query.device.type == "xla"
tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if key is not None:
src_len, key_bsz, key_embed_dim = key.size()
if not torch.jit.is_scripting():
assert (key_bsz, key_embed_dim) == (bsz, embed_dim)
assert value is not None
assert (src_len, bsz, embed_dim) == value.shape
if (
not self.onnx_trace
@ -262,6 +270,7 @@ class MultiheadAttention(nn.Module):
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
src_len = k.size(1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
@ -290,7 +299,7 @@ class MultiheadAttention(nn.Module):
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
src_len = k.size(1)
assert k.size(1) == src_len
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.

View File

@ -4,7 +4,8 @@
# LICENSE file in the root directory of this source tree.
import argparse
from typing import Callable, List, Optional
from pathlib import Path
from typing import Callable, List, Optional, Union
import torch
from fairseq import utils
@ -361,3 +362,18 @@ def add_model_args(parser):
help='model architecture')
# fmt: on
return group
def get_args(
data: Union[str, Path],
task: str = "translation",
arch: str = "transformer",
**overrides
):
parser = get_training_parser(task)
args = parse_args_and_arch(parser, [str(data), "--task", task, "--arch", arch])
for k, v in overrides.items():
setattr(args, k, v)
return args

View File

@ -23,6 +23,7 @@ class SequenceGenerator(nn.Module):
beam_size=1,
max_len_a=0,
max_len_b=200,
max_len=0,
min_len=1,
normalize_scores=True,
len_penalty=1.0,
@ -44,6 +45,8 @@ class SequenceGenerator(nn.Module):
beam_size (int, optional): beam width (default: 1)
max_len_a/b (int, optional): generate sequences of maximum length
ax + b, where x is the source length
max_len (int, optional): the maximum length of the generated output
(not including end-of-sentence)
min_len (int, optional): the minimum length of the generated output
(not including end-of-sentence)
normalize_scores (bool, optional): normalize scores by the length
@ -79,6 +82,7 @@ class SequenceGenerator(nn.Module):
self.max_len_a = max_len_a
self.max_len_b = max_len_b
self.min_len = min_len
self.max_len = max_len or self.model.max_decoder_positions()
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
@ -166,7 +170,7 @@ class SequenceGenerator(nn.Module):
yield id, src, ref, hypos[i]
@torch.no_grad()
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]:
"""Generate translations. Match the api of other fairseq generators.
Args:
@ -232,8 +236,7 @@ class SequenceGenerator(nn.Module):
else:
max_len = min(
int(self.max_len_a * src_len + self.max_len_b),
# exclude the EOS marker
self.model.max_decoder_positions() - 1,
self.max_len - 1,
)
assert (
self.min_len <= max_len
@ -275,9 +278,8 @@ class SequenceGenerator(nn.Module):
[torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step
finished = [
False for i in range(bsz)
] # a boolean array indicating if the sentence at the index is finished or not
# a boolean array indicating if the sentence at the index is finished or not
finished = [False for i in range(bsz)]
num_remaining_sent = bsz # number of sentences remaining
# number of candidate hypos per step

View File

@ -0,0 +1,677 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import json
import logging
import math
import os
from argparse import Namespace
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, Sequence, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import fairseq
from fairseq import metrics, options, utils
from fairseq.data import (
FairseqDataset,
LanguagePairDataset,
NoisingDataset,
PrependTokenDataset,
RoundRobinZipDatasets,
TransformEosLangPairDataset,
data_utils,
encoders,
)
from fairseq.sequence_generator import SequenceGenerator
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask, load_langpair_dataset
logger = logging.getLogger(__name__)
class PiecewiseLinearFn:
"""Piecewise linear function. Can be configured with a string."""
def __init__(self, pieces: Sequence[Tuple[int, float]]):
assert pieces == sorted(
pieces
), f"PiecewiseLinearFn configuration should be sorted, received: {pieces}"
self.pieces = pieces
def __call__(self, x: int) -> float:
for i, (x_a, y_a) in enumerate(self.pieces[:-1]):
x_b, y_b = self.pieces[i + 1]
if x_a <= x <= x_b:
return y_a + (x - x_a) * (y_b - y_a) / (x_b - x_a)
return self.pieces[-1][1]
@staticmethod
def from_string(configuration: str) -> "PiecewiseLinearFn":
"""
Parse the configuration of lambda coefficient (for scheduling).
x = "3" # lambda will be a constant equal to x
x = "0:1,1000:0" # lambda will start from 1 and linearly decrease
# to 0 during the first 1000 iterations
x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000
# iterations, then will linearly increase to 1 until iteration 2000
"""
if isinstance(configuration, float):
return PiecewiseLinearFn([(0, configuration)])
try:
parts = configuration.split(",")
if len(parts) == 1:
v = float(configuration)
return PiecewiseLinearFn([(0, v)])
split = [s.split(":") for s in parts]
pieces = [(int(t), float(v)) for t, v in split]
return PiecewiseLinearFn(pieces)
except Exception:
raise ValueError(
f"Invalid PiecewiseLinearFn configuration: {configuration!r}"
)
@staticmethod
def one() -> "PiecewiseLinearFn":
return PiecewiseLinearFn([(0, 1.0)])
@register_task("online_backtranslation")
class OnlineBackTranslationTask(TranslationTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
# Generic translation args
parser.add_argument('data', help='colon separated path to data directories list, \
will be iterated upon during epochs in round-robin manner; \
however, valid and test data are always in the first directory to \
avoid the need for repeating them in all directories')
parser.add_argument('--mono-langs', metavar='MONO_LANGS',
help='monolingual languages for training')
parser.add_argument('--valid-lang-pairs', default=None, metavar='VALID_LANG_PAIRS',
help='language pairs for validation')
parser.add_argument('--load-alignments', action='store_true',
help='load the binarized alignments')
parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL',
help='pad the source on the left')
parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
help='pad the target on the left')
parser.add_argument('--upsample-primary', default=1, type=int,
help='amount to upsample primary dataset')
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
parser.add_argument('--truncate-source', action='store_true', default=False,
help='truncate source to max-source-positions')
parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N',
help='if >0, then bucket source and target lengths into N '
'buckets and pad accordingly; this is useful on TPUs '
'to minimize the number of compilations')
# Denoising args
parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
help='maximum word shuffle distance for denoising autoencoding data generation')
parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
help='word dropout probability for denoising autoencoding data generation')
parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
help='word blanking probability for denoising autoencoding data generation')
# Backtranslation args
parser.add_argument('--lambda-bt', default="1.0", type=str, metavar='N',
help='back-translation weight')
parser.add_argument('--lambda-dae', default="1.0", type=str, metavar='N',
help='denoising auto-encoder weight')
# Evaluation args
parser.add_argument('--generate-one-by-one', action='store_true',
help='generate one sentence at a time for backtranslation')
parser.add_argument('--eval-bleu', action='store_true',
help='evaluation with BLEU scores')
parser.add_argument('--eval-bleu-detok', type=str, default="space",
help='detokenize before computing BLEU (e.g., "moses"); '
'required if using --eval-bleu; use "space" to '
'disable detokenization; see fairseq.data.encoders '
'for other options')
parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
help='args for building the tokenizer, if needed')
parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
help='compute tokenized BLEU instead of sacrebleu')
parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE before computing BLEU')
parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
help='generation args for BLUE scoring, '
'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
parser.add_argument('--eval-bleu-print-samples', action='store_true',
help='print sample generations during validation')
# fmt: on
def __init__(self, args, common_dict, mono_langs, valid_lang_pairs):
super().__init__(args, common_dict, common_dict)
self.common_dict = common_dict
self.mono_langs = mono_langs
self.valid_lang_pairs = valid_lang_pairs
self.SHOW_SAMPLES_INTERVAL = 1000
# Start by showing samples
self._show_samples_ctr = self.SHOW_SAMPLES_INTERVAL
self.SHOW_SAMPLES_NUMBER = 5
self.lambda_bt = PiecewiseLinearFn.from_string(args.lambda_bt)
self.lambda_dae = PiecewiseLinearFn.from_string(args.lambda_dae)
self.args = args
self.data = utils.split_paths(self.args.data)
if len(self.data) == 1:
shards = list(Path(self.data[0]).glob("shard*"))
if len(shards) > 0:
# keep this as strings, since it can also be a manifold path
old_data = self.data
self.data = [str(shard) for shard in shards]
logging.warning(f"Expanded data directory {old_data} to {self.data}")
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
"""
args.left_pad_source = options.eval_bool(args.left_pad_source)
args.left_pad_target = options.eval_bool(args.left_pad_target)
paths = utils.split_paths(args.data)
assert len(paths) > 0
assert args.mono_langs is not None
mono_langs = args.mono_langs.split(",")
valid_lang_pairs = args.valid_lang_pairs.split(",")
# load dictionary
dict_path = os.path.join(paths[0], "dict.txt")
common_dict = cls.load_dictionary(dict_path)
return cls(args, common_dict, mono_langs, valid_lang_pairs)
def load_dataset(self, split, epoch=1, combine=False, **kwargs) -> FairseqDataset:
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
if split == "train":
data_path = self.data[(epoch - 1) % len(self.data)]
dataset = self.load_train_dataset(data_path)
else:
# valid/test should always be the same.
dataset = self.load_translation_dataset(split, self.data[0])
self.datasets[split] = dataset
return dataset
def load_train_dataset(self, data_path: str) -> FairseqDataset:
"""The training dataset is made of backtranslation dataset and denoising dataset."""
data = []
for lang in self.mono_langs:
train_path = os.path.join(data_path, lang, "train")
# TODO: could we do the BT using denoise sample ?
# this would half the data loading work
data.append((f"{lang}-BT", self.load_bt_dataset(train_path, lang)))
data.append(
(f"{lang}-DENOISE", self.load_denoise_dataset(train_path, lang))
)
return RoundRobinZipDatasets(OrderedDict(data))
def _langpair_dataset(
self, src: FairseqDataset, tgt: FairseqDataset
) -> LanguagePairDataset:
return LanguagePairDataset(
src,
src.sizes,
self.dictionary,
tgt=tgt,
tgt_sizes=tgt.sizes,
tgt_dict=self.dictionary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
# TODO: should we shuffle ? we are already sorting batch by sizes so ?
# shuffle=True,
)
def _prepend_lang_bos_to_target(
self, dataset: LanguagePairDataset, lang: str
) -> LanguagePairDataset:
bos = _lang_token_index(self.dictionary, lang)
return TransformEosLangPairDataset(
dataset,
src_eos=self.dictionary.eos(),
new_src_eos=self.dictionary.eos(),
tgt_bos=self.dictionary.eos(),
new_tgt_bos=bos,
)
def load_bt_dataset(self, data_path: str, lang: str) -> FairseqDataset:
"""The BT dataset is generated with (tgt, tgt) pairs.
The actual translation to a (generated_src, tgt) pair
is done on the fly during training.
"""
mono_dataset = data_utils.load_indexed_dataset(
data_path, self.common_dict, self.args.dataset_impl
)
assert mono_dataset is not None, f"No dataset found for {lang}"
mono_dataset_src = PrependTokenDataset(
mono_dataset, _lang_token_index(self.dictionary, lang)
)
mono_dataset_bt = self._langpair_dataset(mono_dataset_src, mono_dataset)
logger.info(
f"mono_lang = {lang} "
f"lang token index = {_lang_token_index(self.dictionary, lang)} "
f"lang token = {_lang_token(lang)}"
)
mono_dataset_bt = self._prepend_lang_bos_to_target(mono_dataset_bt, lang)
return mono_dataset_bt
def load_denoise_dataset(self, data_path: str, lang: str) -> FairseqDataset:
"""Classic denoising dataset"""
dataset = data_utils.load_indexed_dataset(
data_path, self.common_dict, self.args.dataset_impl
)
noisy_dataset = NoisingDataset(
dataset,
self.dictionary,
seed=1,
max_word_shuffle_distance=self.args.max_word_shuffle_distance,
word_dropout_prob=self.args.word_dropout_prob,
word_blanking_prob=self.args.word_blanking_prob,
)
noisy_dataset = PrependTokenDataset(
noisy_dataset, _lang_token_index(self.dictionary, lang)
)
clean_dataset = data_utils.load_indexed_dataset(
data_path, self.common_dict, self.args.dataset_impl
)
denoising_dataset = self._langpair_dataset(noisy_dataset, clean_dataset)
denoising_dataset = self._prepend_lang_bos_to_target(denoising_dataset, lang)
return denoising_dataset
def load_translation_dataset(
self, split: str, data_path: str, combine: bool = False
):
# only judging with one language pair for the moment,
# since ConcatDataset doesn't work as expected
assert len(self.valid_lang_pairs) == 1, "For now..."
valid_lang_pair = self.valid_lang_pairs[0]
src, tgt = valid_lang_pair.split("-")
# use the same function than TranslationTask
src_tgt_dt = load_langpair_dataset(
data_path,
split,
src,
self.common_dict,
tgt,
self.common_dict,
combine=combine,
dataset_impl=self.args.dataset_impl,
upsample_primary=self.args.upsample_primary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
load_alignments=self.args.load_alignments,
truncate_source=self.args.truncate_source,
num_buckets=self.args.num_batch_buckets,
shuffle=(split != "test"),
prepend_bos_src=_lang_token_index(self.dictionary, src),
)
src_tgt_eos_dt = self._prepend_lang_bos_to_target(src_tgt_dt, tgt)
src_tgt_eos_dt.args = self.args
return src_tgt_eos_dt
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
raise NotImplementedError
def build_model(self, args):
# torch.autograd.set_detect_anomaly(True)
model = super().build_model(args)
add_secial_tokens_to_dict_and_model(self.common_dict, model, self.mono_langs)
self.sequence_generators = {}
for mono_lang in self.mono_langs:
self.sequence_generators[mono_lang] = SequenceGenerator(
[model],
tgt_dict=self.dictionary,
beam_size=1,
max_len_a=1.3,
max_len_b=5,
min_len=5,
# keep 1 to be able to prepend bos
max_len=model.max_decoder_positions() - 1,
)
if getattr(args, "eval_bleu", False):
assert getattr(args, "eval_bleu_detok", None) is not None, (
"--eval-bleu-detok is required if using --eval-bleu; "
"try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
"to disable detokenization, e.g., when using sentencepiece)"
)
detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}")
self.tokenizer = encoders.build_tokenizer(
Namespace(
tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args
)
)
gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}")
self.bleu_sequence_generator = self.build_generator(
[model], Namespace(**gen_args)
)
return model
def max_positions(self):
"""Return the max sentence length allowed by the task."""
return (self.args.max_source_positions, self.args.max_target_positions)
@property
def dictionary(self):
"""Return the source :class:`~fairseq.data.Dictionary`."""
return self.common_dict
def display_samples_once_in_a_while(self, smp, mono_lang, other_lang):
self._show_samples_ctr += 1
if self._show_samples_ctr < self.SHOW_SAMPLES_INTERVAL:
return
self._show_samples_ctr = 0
ln = smp["net_input"]["src_tokens"].shape[0]
logger.info(
f"(r:{self.args.distributed_rank}) : "
f"{other_lang} ---> {mono_lang} "
f"({other_lang} was generated by back-translation.) {ln} samples"
)
for i in range(min(ln, self.SHOW_SAMPLES_NUMBER)):
src_tokens = smp["net_input"]["src_tokens"][i]
tgt_tokens = smp["target"][i]
src_str = self.dictionary.string(src_tokens, "sentencepiece")
tgt_str = self.dictionary.string(tgt_tokens, "sentencepiece")
logger.info(
f"\n{i}\t\t[{other_lang} generated] {src_str}\n"
f"\t\t[{mono_lang} original ] {tgt_str}\n"
f"\t\t[ src tokens] {src_tokens}\n"
)
def backtranslate_sample(self, smp, orig_lang, other_lang) -> None:
"""
* WARNING: smp is modified in place.
* At the start of this function, `smp` has the same input and target:
|--------------------------------------------------------|
| smp['net_input']['src_tokens'] | smp['target'] |
| (from data) __en__ hello world | __en__ hello world |
|--------------------------------------------------------|
* We call generator.generate(smp, bos_token = token("ro")),
and copy the result as input
* At the end, `smp` has the translation to other language.
|--------------------------------------------------------|
| smp['net_input']['src_tokens'] | smp['target'] |
| (generated) __ro__ salut lume | __en__ hello world |
|--------------------------------------------------------|
"""
bos_token = _lang_token_index(self.dictionary, other_lang)
generated = self.sequence_generators[orig_lang].generate(
models=[], sample=smp, bos_token=bos_token
)
max_lngth = max([gn[0]["tokens"].size(0) for gn in generated])
net_input = smp["net_input"]
n_src_tokens = torch.empty(
size=(len(generated), max_lngth + 1), dtype=net_input["src_tokens"].dtype
)
n_src_lengths = torch.empty(
len(generated), dtype=net_input["src_lengths"].dtype
)
for i, gn in enumerate(generated):
tokens = gn[0]["tokens"]
tokens_size = tokens.size(0)
padding_needed = max_lngth - tokens_size
tokens = torch.cat([tokens.new([bos_token]), tokens])
tokens = F.pad(tokens, (0, padding_needed), value=self.dictionary.pad())
n_src_tokens[i] = tokens
n_src_lengths[i] = tokens_size + 1
device = net_input["src_tokens"].device
# This seems to be important
del net_input["src_tokens"]
del net_input["src_lengths"]
net_input["src_tokens"] = n_src_tokens.to(device)
net_input["src_lengths"] = n_src_lengths.to(device)
def generate(self, smp, model):
model.eval()
orig_lang = (
self.dictionary[smp["net_input"]["src_tokens"][0][0]]
.replace(" ", "")
.replace("_", "")
)
bos_token = smp["net_input"]["prev_output_tokens"][0][0]
with torch.no_grad():
generated = self.sequence_generators[orig_lang].generate(
models=[model], sample=smp, bos_token=bos_token
)
return generated
def get_other_lang(self, lang):
# TODO: allow more complex mapping
if lang != self.mono_langs[0]:
return self.mono_langs[0]
if len(self.mono_langs) == 2:
return self.mono_langs[1]
return self.mono_langs[np.random.randint(1, len(self.mono_langs))]
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
model.train()
model.set_num_updates(update_num)
agg_loss, agg_sample_size = 0.0, 0.0
agg_logging_output: Dict[str, float] = defaultdict(float)
dataset_keys = self.datasets["train"].datasets.keys()
weights = {
"BT": self.lambda_bt(update_num),
"DENOISE": self.lambda_dae(update_num),
}
log_keys = {"BT": "bt_", "DENOISE": "dae_"}
for dataset_key in dataset_keys:
smp = sample[dataset_key]
mono_lang, task_subtype = dataset_key.split("-")
if weights[task_subtype] == 0:
continue
if task_subtype == "BT":
with torch.autograd.profiler.record_function("backtranslation"):
model.eval()
# TODO: Could we translate to several language at once ?
# this would allow to share encoder_out and maximize GPU usage.
other_lang = self.get_other_lang(mono_lang)
self.backtranslate_sample(smp, mono_lang, other_lang)
self.display_samples_once_in_a_while(smp, mono_lang, other_lang)
model.train()
# Like in FairseqTask.train_step
with torch.autograd.profiler.record_function("forward"):
loss, sample_size, logging_output = criterion(model, smp)
loss *= weights[task_subtype]
if ignore_grad:
loss *= 0
with torch.autograd.profiler.record_function("backward"):
optimizer.backward(loss)
agg_loss += loss.item()
agg_sample_size += sample_size
for k in logging_output:
agg_logging_output[log_keys[task_subtype] + k] += logging_output[k]
agg_logging_output[k] += logging_output[k]
return agg_loss, agg_sample_size, agg_logging_output
def get_bos_token_from_sample(self, sample):
net_input = sample["net_input"]
source_lang_token_id = torch.unique(net_input["src_tokens"][:, 0]).item()
source_lang_token = self.dictionary[source_lang_token_id].replace("_", "")
target_lang_token_id = _lang_token_index(
self.dictionary, self.get_other_lang(source_lang_token)
)
return target_lang_token_id
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
bt_sample_size = sum(x.get("bt_sample_size", 0) for x in logging_outputs)
if bt_sample_size:
bt_loss_sum = sum(x.get("bt_loss", 0) for x in logging_outputs)
bt_loss_sum *= 1 / bt_sample_size / math.log(2)
metrics.log_scalar("bt_loss", bt_loss_sum, bt_sample_size, round=3)
bt_nll_loss_sum = sum(x.get("bt_nll_loss", 0) for x in logging_outputs)
bt_ntokens = sum(x.get("bt_ntokens", 0) for x in logging_outputs)
bt_nll_loss_sum *= 1 / bt_ntokens / math.log(2)
metrics.log_scalar("bt_nll_loss", bt_nll_loss_sum, bt_ntokens, round=3)
metrics.log_derived(
"bt_ppl", lambda meters: utils.get_perplexity(meters["bt_nll_loss"].avg)
)
dae_sample_size = sum(x.get("dae_sample_size", 0) for x in logging_outputs)
if dae_sample_size:
dae_loss_sum = sum(x.get("dae_loss", 0) for x in logging_outputs)
dae_loss_sum *= 1 / dae_sample_size / math.log(2)
metrics.log_scalar("dae_loss", dae_loss_sum, dae_sample_size, round=3)
dae_nll_loss_sum = sum(x.get("dae_nll_loss", 0) for x in logging_outputs)
dae_ntokens = sum(x.get("dae_ntokens", 0) for x in logging_outputs)
dae_nll_loss_sum *= 1 / dae_ntokens / math.log(2)
metrics.log_scalar("dae_nll_loss", dae_nll_loss_sum, dae_ntokens, round=3)
metrics.log_derived(
"dae_ppl",
lambda meters: utils.get_perplexity(meters["dae_nll_loss"].avg),
)
@torch.no_grad()
def extend_embedding(
emb: nn.Module, new_vocab_size: int, copy_from_token_id: int
) -> None:
old_emb_data = emb.weight.data
(old_vocab_size, dim) = old_emb_data.shape
assert new_vocab_size >= old_vocab_size
if new_vocab_size > old_vocab_size:
emb.weight.data = torch.zeros((new_vocab_size, dim))
emb.weight.data[:old_vocab_size, :] = old_emb_data
# initialize new embeddings
emb.weight.data[old_vocab_size:, :] = old_emb_data[copy_from_token_id]
if hasattr(emb, "num_embeddings"):
emb.num_embeddings = new_vocab_size
if hasattr(emb, "out_features"):
emb.out_features = new_vocab_size
if getattr(emb, "bias", None) is None:
return
# Fix the bias.
# Bias shape can be different from the previous vocab size
# if the weight matrix was shared and alread extended but not the bias.
(old_vocab_size,) = emb.bias.shape
assert new_vocab_size >= old_vocab_size
if new_vocab_size > old_vocab_size:
old_bias = emb.bias.data
new_bias = torch.zeros(
(new_vocab_size,), dtype=old_bias.dtype, device=old_bias.device
)
new_bias[:old_vocab_size] = old_bias
emb.bias.data = new_bias
def add_secial_tokens_to_dict_and_model(
dictionary: "fairseq.data.Dictionary",
model: nn.Module,
mono_langs: Sequence[str],
) -> None:
embs = model.encoder.embed_tokens
vocab_size, embedding_dim = embs.weight.shape
# The model may or may not have a '<mask>' embedding yet
assert (
len(dictionary) <= vocab_size <= len(dictionary) + 1
), f"Dictionary len ({len(dictionary)}) doesn't match embs shape ({embs.weight.shape})"
# TODO: we should reuse the pretrained model dict which already has <mask>
dictionary.add_symbol("<mask>")
for lang in mono_langs:
lang_token = _lang_token(lang)
dictionary.add_symbol(lang_token)
logger.info(
f"dictionary: {len(dictionary)} -> {vocab_size} tokens "
f"after adding {len(mono_langs)} lang tokens."
)
if len(dictionary) <= vocab_size:
return
extend_embedding(embs, len(dictionary), dictionary.bos())
dec_embs = model.decoder.embed_tokens
extend_embedding(dec_embs, len(dictionary), dictionary.bos())
lm_head = model.decoder.output_projection
extend_embedding(lm_head, len(dictionary), dictionary.bos())
assert lm_head.weight.shape == (len(dictionary), embedding_dim)
def _lang_token(lang: str) -> str:
return f"__{lang}__"
def _lang_token_index(dictionary, lang: str) -> int:
return dictionary.index(_lang_token(lang))
@contextlib.contextmanager
def assert_weights_have_changed(model: nn.Module):
def checksum(model: nn.Module) -> float:
return sum(p.sum().item() for p in model.parameters())
initial_checksum = checksum(model)
yield model
final_checksum = checksum(model)
logger.info(
f"initial_checksum={initial_checksum} -> final_checksum={final_checksum}"
)
assert initial_checksum != final_checksum, "Model hasn't changed !"

View File

@ -57,6 +57,7 @@ def load_langpair_dataset(
num_buckets=0,
shuffle=True,
pad_to_multiple=1,
prepend_bos_src=None,
):
def split_exists(split, src, tgt, lang, data_path):
filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
@ -128,6 +129,9 @@ def load_langpair_dataset(
src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
if tgt_dataset is not None:
tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
elif prepend_bos_src is not None:
logger.info(f"prepending src bos: {prepend_bos_src}")
src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)
eos = None
if append_source_id:

View File

@ -38,7 +38,7 @@ class TranslationFromPretrainedBARTTask(TranslationTask):
"""Add task-specific arguments to the parser."""
# fmt: off
TranslationTask.add_args(parser)
parser.add_argument('--langs', required=True, metavar='LANG',
parser.add_argument('--langs', type=str, metavar='LANG',
help='comma-separated list of monolingual language, '
'for example, "en,de,fr". These should match the '
'langs from pretraining (and be in the same order). '

View File

@ -80,9 +80,6 @@ def main(cfg: FairseqConfig) -> None:
# Setup task, e.g., translation, language modeling, etc.
task = tasks.setup_task(cfg.task)
# Load valid dataset (we load training data below, based on the latest checkpoint)
for valid_sub_split in cfg.dataset.valid_subset.split(","):
task.load_dataset(valid_sub_split, combine=False, epoch=1)
assert cfg.criterion, "Please specify criterion to train a model"
@ -111,6 +108,11 @@ def main(cfg: FairseqConfig) -> None:
)
)
# Load valid dataset (we load training data below, based on the latest checkpoint)
# We load the valid dataset AFTER building the model
for valid_sub_split in cfg.dataset.valid_subset.split(","):
task.load_dataset(valid_sub_split, combine=False, epoch=1)
# (optionally) Configure quantization
if cfg.common.quantization_config_path is not None:
quantizer = quantization_utils.Quantizer(

View File

@ -0,0 +1,206 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import tempfile
import unittest
from pathlib import Path
from typing import Any, Dict, Sequence
import fairseq.data.indexed_dataset as indexed_dataset
import fairseq.options
import fairseq.tasks.online_backtranslation as obt
import torch
from tests import utils
def mk_sample(tokens: Sequence[int], batch_size: int = 2) -> Dict[str, Any]:
batch = torch.stack([torch.tensor(tokens, dtype=torch.long)] * batch_size)
sample = {
"net_input": {
"src_tokens": batch,
"prev_output_tokens": batch,
"src_lengths": torch.tensor([len(tokens)] * batch_size, dtype=torch.long),
},
"target": batch[:, 1:],
}
return sample
def mk_dataset(num_samples: int, max_len: int, output: Path):
output.parent.mkdir(exist_ok=True)
idx = indexed_dataset.IndexedDatasetBuilder(str(output))
data = torch.randint(5, 100, (num_samples, max_len))
lengths = torch.randint(3, max_len, (num_samples,))
for d, l in zip(data, lengths):
d[0] = 0
idx.add_item(d[:l])
idx.finalize(output.with_suffix(".idx"))
assert output.exists()
assert output.with_suffix(".idx").exists()
class OnlineBacktranslationTest(unittest.TestCase):
tmp_dir = Path(tempfile.mkdtemp(suffix="OnlineBacktranslationTest"))
@classmethod
def obt_task(
cls, languages: Sequence[str], data: Path = None, language_mapping: str = None
):
dict_path = cls.tmp_dir / "dict.txt"
if not dict_path.exists():
dictionary = utils.dummy_dictionary(100)
dictionary.save(str(dict_path))
if data is not None:
(data / "dict.txt").write_text(dict_path.read_text())
else:
data = cls.tmp_dir
assert len(languages) >= 2
kwargs = {
"arch": "transformer",
# --max-sentences=1 for better predictability of batches
"max_sentences": 1,
# Use characteristics dimensions
"encoder_layers": 3,
"encoder_embed_dim": 12,
"encoder_ffn_embed_dim": 14,
"encoder_attention_heads": 4,
"decoder_layers": 3,
"decoder_embed_dim": 12,
"decoder_output_dim": 12,
"decoder_ffn_embed_dim": 14,
"decoder_attention_heads": 4,
# Disable dropout so we have comparable tests.
"dropout": 0,
"attention_dropout": 0,
"activation_dropout": 0,
"encoder_layerdrop": 0,
}
args = fairseq.options.get_args(
data,
task="online_backtranslation",
mono_langs=",".join(languages),
valid_lang_pairs=f"{languages[0]}-{languages[1]}",
tokens_per_sample=256,
language_mapping=language_mapping,
**kwargs,
)
task = obt.OnlineBackTranslationTask.setup_task(args)
# we need to build the model to have the correct dictionary
model = task.build_model(task.args)
return task, model
def tmp_path(self, test_case: str) -> Path:
return Path(tempfile.mkdtemp(test_case, dir=self.tmp_dir))
def test_lang_tokens(self):
task, model = self.obt_task(["en", "ro", "zh"])
assert obt._lang_token("en") in task.dictionary
assert obt._lang_token("ro") in task.dictionary
assert obt._lang_token("zh") in task.dictionary
en_bos = obt._lang_token_index(task.common_dict, "en")
assert "en" == task.common_dict[en_bos].strip("_")
zh_bos = obt._lang_token_index(task.common_dict, "zh")
assert "zh" == task.common_dict[zh_bos].strip("_")
zh_sample = mk_sample([zh_bos, 16, 14, 12, 10])
# we expect to receive the bos token for translation
assert task.get_bos_token_from_sample(zh_sample) == en_bos
def test_backtranslate_sample(self):
task, model = self.obt_task(["en", "ro", "zh"])
en_bos = obt._lang_token_index(task.common_dict, "en")
zh_bos = obt._lang_token_index(task.common_dict, "zh")
sample = mk_sample([zh_bos, 16, 14, 12, 10])
task.backtranslate_sample(sample, "zh", "en")
target_zh = list(sample["target"][0])
assert target_zh == [16, 14, 12, 10] # original zh sentence
generated_en = sample["net_input"]["src_tokens"][0]
assert generated_en[0] == en_bos
def test_train_dataset(self):
data = self.tmp_path("test_train_dataset")
mk_dataset(20, 10, data / "en" / "train.bin")
mk_dataset(10, 10, data / "zh" / "train.bin")
task, model = self.obt_task(["en", "zh"], data)
task.load_dataset("train")
en_bos = obt._lang_token_index(task.common_dict, "en")
zh_bos = obt._lang_token_index(task.common_dict, "zh")
train = task.datasets["train"]
train.ordered_indices()
train.prefetch([0, 19])
sample_0 = train[0]
sample_19 = train[19]
self.assertEqual(
set(sample_0.keys()), {"en-BT", "en-DENOISE", "zh-BT", "zh-DENOISE"}
)
for sample in (sample_0, sample_19):
self.assertEqual(sample["en-BT"]["source"][0], en_bos)
# bt target isn't ready to look at.
self.assertEqual(sample["en-DENOISE"]["source"][0], en_bos)
# TODO What could we check on the target side ?
for i in range(10):
# Zh dataset is shorter, and is wrapped around En dataset.
train.prefetch([i, i + 10])
self.assertEqual(
list(train[i]["zh-DENOISE"]["source"]),
list(train[i + 10]["zh-DENOISE"]["source"]),
)
self.assertEqual(train[i]["zh-DENOISE"]["source"][0].item(), zh_bos)
# Sorted by increasing len
self.assertLess(
len(sample_0["en-BT"]["source"]), len(sample_19["en-BT"]["source"])
)
def test_valid_dataset(self):
data = self.tmp_path("test_valid_dataset")
mk_dataset(10, 21, data / "valid.en-zh.en.bin")
mk_dataset(10, 21, data / "valid.en-zh.zh.bin")
task, model = self.obt_task(["en", "zh"], data)
valid = task.load_dataset("valid")
en_bos = obt._lang_token_index(task.common_dict, "en")
assert valid is not None
valid.prefetch(range(10))
sample_0 = valid[0]
sample_9 = valid[9]
self.assertEqual(sample_0["id"], 0)
self.assertEqual(sample_9["id"], 9)
self.assertEqual(sample_0["source"][0], en_bos)
self.assertEqual(sample_9["source"][0], en_bos)
# TODO: could we test the target side ?
def assertFnMatch(self, fn, values):
for x, y in values.items():
fn_x = fn(x)
self.assertEqual(fn_x, y, f"Fn has wrong value: fn({x}) = {fn_x} != {y}")
def test_piecewise_linear_fn(self):
self.assertFnMatch(
obt.PiecewiseLinearFn.from_string("1.0"), {0: 1, 100: 1, 500: 1, 1000: 1}
)
self.assertFnMatch(
obt.PiecewiseLinearFn.from_string("0:1,1000:0"),
{0: 1, 500: 0.5, 1000: 0, 2000: 0},
)
self.assertFnMatch(
obt.PiecewiseLinearFn.from_string("0:0,1000:1"),
{0: 0, 500: 0.5, 1000: 1, 2000: 1},
)
self.assertFnMatch(
obt.PiecewiseLinearFn.from_string("0:0,1000:1,2000:0"),
{0: 0, 500: 0.5, 1000: 1, 1500: 0.5, 2000: 0, 3000: 0},
)

314
tests/test_roberta.py Normal file
View File

@ -0,0 +1,314 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import unittest
from typing import Any, Dict, Sequence
import fairseq
import fairseq.options
import fairseq.tasks
import torch
from tests.utils import dummy_dictionary
VOCAB_SIZE = 100
@fairseq.tasks.register_task("fake_task")
class FakeTask(fairseq.tasks.LegacyFairseqTask):
def __init__(self, args):
super().__init__(args)
self.dictionary = dummy_dictionary(VOCAB_SIZE - 4)
assert len(self.dictionary) == VOCAB_SIZE
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary
@functools.lru_cache()
def get_toy_model(
device: str,
architecture: str = "roberta_enc_dec",
**extra_args: Any,
):
assert device in ("gpu", "cpu")
kwargs = {
"arch": architecture,
# Use characteristics dimensions
"encoder_layers": 3,
"encoder_embed_dim": 12,
"encoder_ffn_embed_dim": 14,
"encoder_attention_heads": 4,
"decoder_layers": 3,
"decoder_embed_dim": 12,
"decoder_ffn_embed_dim": 14,
"decoder_attention_heads": 4,
# Disable dropout so we have comparable tests.
"dropout": 0,
"attention_dropout": 0,
"activation_dropout": 0,
"encoder_layerdrop": 0,
# required args
"tokens_per_sample": 256,
"data": "/tmp/test_roberta",
}
kwargs.update(extra_args)
fake_task = FakeTask(kwargs)
args = fairseq.options.get_args(
task="online_backtranslation",
mono_langs="en,ro",
valid_lang_pairs="en-ro",
**kwargs,
)
torch.manual_seed(0)
model = fake_task.build_model(args)
if device == "gpu":
model.cuda()
return fake_task, model
def mk_sample(
lang: str, device: str, tok: Sequence[int] = None, batch_size: int = 2
) -> Dict[str, Any]:
assert device in ("gpu", "cpu")
if not tok:
if lang == "en":
tok = [10, 11, 12, 13, 14, 15, 2]
else:
tok = [20, 21, 22, 23, 24, 25, 26, 27, 2]
batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
if device == "gpu":
batch = batch.cuda()
sample = {
"net_input": {
"src_tokens": batch,
"prev_output_tokens": batch,
"src_lengths": torch.tensor(
[len(tok)] * batch_size, dtype=torch.long, device=batch.device
),
},
"target": batch[:, 1:],
}
return sample
def cpu_gpu(fn):
def helper(self):
fn(self, "cpu")
if torch.cuda.is_available():
fn(self, "gpu")
return helper
def architectures(fn):
def helper(self):
for arch in ["roberta_enc_dec", "transformer"]:
fn(self, arch)
return helper
class RobertaTest(unittest.TestCase):
def assertTensorEqual(self, t1, t2, delta: float = 1e-6):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
if delta == 0.0:
self.assertEqual(t1.ne(t2).long().sum(), 0)
else:
self.assertEqual(((t2 - t1).abs() > delta).long().sum(), 0)
def assertSharing(self, model, link_groups: Sequence[Sequence[str]]):
ids = {}
for group in link_groups:
group_ids = {name: id(params(model, name)) for name in group}
shared_id = group_ids[group[0]]
self.assertEqual(group_ids, {name: shared_id for name in group})
self.assertNotIn(shared_id, ids)
ids[shared_id] = group
def test_roberta_shared_params(self):
_, roberta = get_toy_model("cpu", architecture="roberta")
self.assertSharing(
roberta,
[
[
"encoder.sentence_encoder.embed_tokens.weight",
"encoder.lm_head.weight",
]
],
)
_, roberta = get_toy_model(
"cpu", architecture="roberta", untie_weights_roberta=True
)
self.assertSharing(
roberta,
[
["encoder.sentence_encoder.embed_tokens.weight"],
["encoder.lm_head.weight"],
],
)
def test_roberta_enc_dec_shared_params(self):
# 3 distinct embeddings
_, enc_dec = get_toy_model("cpu", architecture="roberta_enc_dec")
self.assertSharing(
enc_dec,
[
["encoder.embed_tokens.weight"],
["decoder.embed_tokens.weight"],
["decoder.output_projection.weight"],
],
)
# 2 distinct embeddings, one for encoder, one for decoder
_, enc_dec = get_toy_model(
"cpu", architecture="roberta_enc_dec", share_decoder_input_output_embed=True
)
self.assertSharing(
enc_dec,
[
["encoder.embed_tokens.weight"],
[
"decoder.embed_tokens.weight",
"decoder.output_projection.weight",
],
],
)
# shared embeddings
_, enc_dec = get_toy_model(
"cpu", architecture="roberta_enc_dec", share_all_embeddings=True
)
self.assertSharing(
enc_dec,
[
[
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
"decoder.output_projection.weight",
]
],
)
def test_roberta_max_positions_is_correctly_set(self):
device = "cpu"
task, model = get_toy_model(device)
max_pos = model.max_decoder_positions()
self.assertEqual(max_pos, 256)
self.assertEqual(max_pos, model.decoder.max_positions())
self.assertEqual(max_pos, model.encoder.max_positions())
self.assertEqual(max_pos, model.encoder.embed_positions.max_positions)
sentence = [31 for _ in range(max_pos)]
sample = mk_sample("en", device, sentence, batch_size=1)
self.assertEqual(list(sample["net_input"]["src_lengths"]), [max_pos])
self.assertEqual(len(sample["net_input"]["src_tokens"][0]), max_pos)
x, _ = model.forward(**sample["net_input"])
self.assertEqual(x.shape, (1, max_pos, VOCAB_SIZE))
@cpu_gpu
def test_roberta_forward_backward(self, device: str):
_, model = get_toy_model(device)
sample = mk_sample("en", device)
en_tokens = sample["net_input"]["src_tokens"]
(bs, l) = en_tokens.shape
# Forward
logits, _ = model(**sample["net_input"])
self.assertEqual(logits.shape, (bs, l, VOCAB_SIZE))
# Backward
loss = logits.sum()
loss.backward()
@cpu_gpu
def test_roberta_forward_backward_bs1(self, device: str):
_, model = get_toy_model(device)
sample = mk_sample("en", device, batch_size=1)
o, _ = model.forward(**sample["net_input"])
loss = o.sum()
sample2 = mk_sample("ro", device, batch_size=1)
o, _ = model.forward(**sample2["net_input"])
loss += o.sum()
loss.backward()
@cpu_gpu
def test_roberta_batching(self, device: str):
"""
Checks that the batch of size 2 give twice the same results than the batch of size 1.
"""
_, model = get_toy_model(device)
sample = mk_sample("en", device, batch_size=1)
slen = sample["net_input"]["src_lengths"][0]
sample2 = mk_sample("en", device, batch_size=2)
with torch.no_grad():
z = model.encoder.forward(
sample["net_input"]["src_tokens"], sample["net_input"]["src_lengths"]
)
z = z["encoder_out"][-1]
logits, _ = model.forward(**sample["net_input"])
z2 = model.encoder.forward(
sample2["net_input"]["src_tokens"], sample["net_input"]["src_lengths"]
)
z2 = z2["encoder_out"][-1]
logits2, _ = model.forward(**sample2["net_input"])
self.assertEqual(z.shape, (slen, 1, 12))
self.assertEqual(z2.shape, (slen, 2, 12))
self.assertTensorEqual(logits2[0], logits2[1])
self.assertTensorEqual(logits[0], logits2[0])
@cpu_gpu
def test_roberta_incremental_decoder(self, device: str):
"""
Checks that incremental decoding yields the same result than non incremental one.
"""
task, model = get_toy_model(device)
en_sample = mk_sample("en", device)
en_tokens = en_sample["net_input"]["src_tokens"]
ro_sample = mk_sample("ro", device)
ro_tokens = ro_sample["net_input"]["src_tokens"]
en_enc = model.encoder.forward(
en_tokens, src_lengths=en_sample["net_input"]["src_lengths"]
)
(bs, tgt_len) = ro_tokens.shape
# Decode without incremental state
ro_dec, _ = model.decoder.forward(ro_tokens, encoder_out=en_enc)
self.assertEqual(ro_dec.shape, (bs, tgt_len, VOCAB_SIZE))
self.assertTensorEqual(ro_dec[0], ro_dec[1])
# Decode with incremental state
inc_state = {}
ro_dec_inc = []
for l in range(tgt_len):
ro, _ = model.decoder.forward(
ro_tokens[:, : l + 1], encoder_out=en_enc, incremental_state=inc_state
)
self.assertEqual(ro.shape, (bs, 1, VOCAB_SIZE))
ro_dec_inc.append(ro)
for l in range(tgt_len):
# Intra-batch
self.assertTensorEqual(ro_dec_inc[l][0], ro_dec_inc[l][1])
# Incremental vs non-incremental
self.assertTensorEqual(ro_dec_inc[l][:, 0], ro_dec[:, l])
def params(model, name):
if "." not in name:
return getattr(model, name)
prefix, name = name.split(".", 1)
return params(getattr(model, prefix), name)