Mirror of https://github.com/facebookresearch/fairseq.git (synced 2024-10-05 13:17:39 +03:00)
Obt 2 (#1614)
Summary:

# Before submitting

- [x] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests? too many of them actually ^^

## What does this PR do?

This is a rewrite of https://github.com/fairinternal/fairseq-py/issues/1538, following the discussion there and taking into account the proposal in https://github.com/fairinternal/fairseq-py/issues/1560 from Myle. It brings online backtranslation to fairseq. It also adds a RobertaEncDec model to fairseq: a RobertaEncDec can be built from a pretrained Roberta model, allowing transfer learning. This is crucial for backtranslation.

## PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?

Make sure you had fun coding.

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1614

Reviewed By: myleott

Differential Revision: D27157296

Pulled By: gwenzek

fbshipit-source-id: 43020bc27743419bd4b138716165bf5764117c21
This commit is contained in:
parent
7dafb05754
commit
c2e8904b60
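To make the scope of the change concrete, the sketch below shows how the new task could be driven from Python, mirroring the unit tests added in this PR. It is a minimal, hypothetical example: the data path, language pair, and lambda schedules are placeholders and are not part of the commit.

```python
# Hypothetical usage sketch (not part of the commit): paths and hyper-parameters
# are placeholders, mirroring tests/test_online_backtranslation.py.
import fairseq.options
from fairseq.tasks.online_backtranslation import OnlineBackTranslationTask

# get_args() is the helper added to fairseq/options.py in this PR: it parses a
# training command line for the given data/task/arch and applies keyword overrides.
args = fairseq.options.get_args(
    "/path/to/binarized-data",       # expects <data>/<lang>/train.bin and a shared dict.txt
    task="online_backtranslation",
    arch="transformer",
    mono_langs="en,ro",              # monolingual languages used for BT + denoising
    valid_lang_pairs="en-ro",        # parallel data used for validation only
    lambda_bt="1.0",                 # constant back-translation weight
    lambda_dae="0:1,100000:0.1",     # piecewise-linear denoising schedule
)

task = OnlineBackTranslationTask.setup_task(args)
# build_model() also adds the <mask> and __lang__ tokens to the shared dictionary
# and extends the embedding matrices accordingly.
model = task.build_model(args)
```

The same flags (`--mono-langs`, `--valid-lang-pairs`, `--lambda-bt`, `--lambda-dae`) are registered by `OnlineBackTranslationTask.add_args` in the task diff below.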
@@ -296,6 +296,8 @@ class NoisingDataset(torch.utils.data.Dataset):
**kwargs,
)
)
self.sizes = src_dataset.sizes

def __getitem__(self, index):
"""
@@ -141,7 +141,7 @@ class RoundRobinZipDatasets(FairseqDataset):
f"{len(ignored)} samples from {key} have invalid sizes and will be skipped, "
f"max_positions={max_positions[key]}, first few sample ids={ignored[:10]}"
)
# Since we are modifiying in place the _ordered_indices,
# Since we are modifying in place the _ordered_indices,
# it's not possible anymore to return valid ignored indices.
# Hopefully the extra debug information print above should be enough to debug.
# Ideally we would receive ignore_invalid_inputs so that we could have
@@ -50,6 +50,9 @@ class TransformEosLangPairDataset(FairseqDataset):
def collater(self, samples, **extra_args):
samples = self.dataset.collater(samples, **extra_args)

if 'net_input' not in samples:
return samples

if self.new_src_eos is not None:
if self.dataset.left_pad_source:
assert (
@@ -5,6 +5,7 @@

from .hub_interface import * # noqa
from .model import * # noqa
from .enc_dec import * # noqa
from .model_camembert import * # noqa
from .model_gottbert import * # noqa
from .model_xlmr import * # noqa
fairseq/models/roberta/enc_dec.py (new file, 192 lines)
@@ -0,0 +1,192 @@
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import torch.nn as nn
|
||||
import fairseq.checkpoint_utils
|
||||
from fairseq.models import (
|
||||
FairseqEncoderDecoderModel,
|
||||
register_model,
|
||||
register_model_architecture,
|
||||
)
|
||||
from fairseq.models.transformer import TransformerDecoder
|
||||
from fairseq.models.roberta import model as roberta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_model("roberta_enc_dec")
|
||||
class RobertaEncDecModel(FairseqEncoderDecoderModel):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
parser.add_argument(
|
||||
"--pretrained-mlm-checkpoint",
|
||||
default=None,
|
||||
type=str,
|
||||
metavar="PRETRAINED",
|
||||
help="path to pretrained mlm checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained-decoder", action="store_true", help="reload decoder"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hack-layernorm-embedding",
|
||||
action="store_true",
|
||||
help="hack to reload old models trained with encoder-normalize-before=False (no equivalent to encoder-normalize-before=False and layernorm_embedding=False",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--share-decoder-input-output-embed",
|
||||
action="store_true",
|
||||
help="share decoder input and output embeddings",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--share-all-embeddings",
|
||||
action="store_true",
|
||||
help="share encoder, decoder and output embeddings"
|
||||
" (requires shared dictionary and embed dim)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args, task):
|
||||
"""Build a new model instance."""
|
||||
|
||||
# make sure all arguments are present
|
||||
base_enc_dec_architecture(args)
|
||||
if args.pretrained_mlm_checkpoint:
|
||||
arg_overrides = None
|
||||
if args.hack_layernorm_embedding:
|
||||
arg_overrides = {"layernorm_embedding": False}
|
||||
loaded = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
||||
[args.pretrained_mlm_checkpoint], arg_overrides=arg_overrides
|
||||
)
|
||||
([roberta_enc], _cfg, _task) = loaded
|
||||
else:
|
||||
# Do we need to edit untie_weights here ?
|
||||
share_in_out = (
|
||||
args.share_decoder_input_output_embed or args.share_all_embeddings
|
||||
)
|
||||
args.untie_weights_roberta = not share_in_out
|
||||
if args.hack_layernorm_embedding:
|
||||
args.layernorm_embedding = False
|
||||
args.encoder_normalize_before = False
|
||||
roberta_enc = roberta.RobertaModel.build_model(args, task)
|
||||
|
||||
return cls.from_roberta(roberta_enc, args, task.source_dictionary)
|
||||
|
||||
@staticmethod
|
||||
def from_roberta(roberta_enc: roberta.RobertaModel, args, dictionary):
|
||||
encoder = roberta_enc.encoder.sentence_encoder
|
||||
vocab_size, embed_dim = encoder.embed_tokens.weight.shape
|
||||
|
||||
if args.share_all_embeddings:
|
||||
lm_head = roberta_enc.encoder.lm_head
|
||||
assert encoder.embed_tokens.weight is lm_head.weight, (
|
||||
"Can't use --share-all-embeddings with a model "
|
||||
"that was pretraiend with --untie-weights-roberta_enc"
|
||||
)
|
||||
else:
|
||||
lm_head = roberta.RobertaLMHead(
|
||||
embed_dim, vocab_size, roberta_enc.args.activation_fn
|
||||
)
|
||||
|
||||
dec_embs = nn.Embedding(vocab_size, embed_dim, dictionary.pad())
|
||||
if args.share_all_embeddings or args.share_decoder_input_output_embed:
|
||||
# Note: I wasn't able to use Embedding _weight parameter to achieve this sharing.
|
||||
dec_embs.weight = lm_head.weight
|
||||
|
||||
decoder = TransformerDecoder(
|
||||
RobertaEncDecModel.read_args_from_roberta(roberta_enc.args),
|
||||
dictionary,
|
||||
dec_embs,
|
||||
no_encoder_attn=False,
|
||||
output_projection=lm_head,
|
||||
)
|
||||
if getattr(args, "pretrained_decoder", False):
|
||||
decoder_dict = encoder.state_dict()
|
||||
|
||||
# TODO: hide setting "encoder_attn" layers behind a flag.
|
||||
for k, w in list(decoder_dict.items()):
|
||||
if ".self_attn" in k:
|
||||
k_enc_attn = k.replace(".self_attn", ".encoder_attn")
|
||||
decoder_dict[k_enc_attn] = w.detach().clone()
|
||||
|
||||
for k, w in lm_head.state_dict().items():
|
||||
decoder_dict["output_projection." + k] = w
|
||||
|
||||
missing_keys, unexpected_keys = decoder.load_state_dict(
|
||||
decoder_dict, strict=False
|
||||
)
|
||||
# missing_keys = [m for m in missing_keys if ".encoder_attn" not in m]
|
||||
assert not missing_keys and not unexpected_keys, (
|
||||
"Failed to load state dict. "
|
||||
f"Missing keys: {missing_keys}. "
|
||||
f"Unexpected keys: {unexpected_keys}."
|
||||
)
|
||||
|
||||
if args.share_all_embeddings:
|
||||
assert decoder.output_projection.weight is decoder.embed_tokens.weight
|
||||
assert encoder.embed_tokens.weight is decoder.embed_tokens.weight
|
||||
elif args.share_decoder_input_output_embed:
|
||||
assert decoder.output_projection.weight is decoder.embed_tokens.weight
|
||||
assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight
|
||||
else:
|
||||
assert decoder.output_projection.weight is not decoder.embed_tokens.weight
|
||||
assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight
|
||||
|
||||
return RobertaEncDecModel(encoder, decoder)
|
||||
|
||||
@staticmethod
|
||||
def read_args_from_roberta(roberta_args: argparse.Namespace):
|
||||
# TODO: this would become easier if encoder/decoder were using a similar
|
||||
# TransformerConfig object
|
||||
args = argparse.Namespace(**vars(roberta_args))
|
||||
attr_map = [
|
||||
("encoder_attention_heads", "decoder_attention_heads"),
|
||||
("encoder_embed_dim", "decoder_embed_dim"),
|
||||
("encoder_embed_dim", "decoder_output_dim"),
|
||||
("encoder_normalize_before", "decoder_normalize_before"),
|
||||
("encoder_layers_to_keep", "decoder_layers_to_keep"),
|
||||
("encoder_ffn_embed_dim", "decoder_ffn_embed_dim"),
|
||||
("encoder_layerdrop", "decoder_layerdrop"),
|
||||
("encoder_layers", "decoder_layers"),
|
||||
("encoder_learned_pos", "decoder_learned_pos"),
|
||||
# should this be set from here ?
|
||||
("max_positions", "max_target_positions"),
|
||||
]
|
||||
for k1, k2 in attr_map:
|
||||
setattr(args, k2, getattr(roberta_args, k1))
|
||||
|
||||
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
|
||||
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
||||
args.share_decoder_input_output_embed = not roberta_args.untie_weights_roberta
|
||||
return args
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
prefix = name + "." if name != "" else ""
|
||||
super().upgrade_state_dict_named(state_dict, name)
|
||||
old_keys = list(state_dict.keys())
|
||||
|
||||
# rename decoder -> encoder before upgrading children modules
|
||||
for k in old_keys:
|
||||
if k.startswith(prefix + "encoder.lm_head"):
|
||||
state_dict.pop(k)
|
||||
continue
|
||||
new_k = k
|
||||
new_k = new_k.replace(".sentence_encoder.", ".")
|
||||
new_k = new_k.replace("decoder.lm_head.", "decoder.output_projection.")
|
||||
if k == new_k:
|
||||
continue
|
||||
# print(k, "->", new_k)
|
||||
state_dict[new_k] = state_dict.pop(k)
|
||||
|
||||
|
||||
@register_model_architecture("roberta_enc_dec", "roberta_enc_dec")
|
||||
def base_enc_dec_architecture(args):
|
||||
args.hack_layernorm_embedding = getattr(args, "hack_layernorm_embedding", False)
|
||||
args.pretrained_mlm_checkpoint = getattr(args, "pretrained_mlm_checkpoint", None)
|
||||
args.pretrained_decoder = getattr(args, "pretrained_decoder", None)
|
||||
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
|
||||
args.share_decoder_input_output_embed = getattr(
|
||||
args, "share_decoder_input_output_embed", False
|
||||
)
|
||||
|
||||
roberta.base_architecture(args)
|
@@ -204,7 +204,7 @@ class RobertaModel(FairseqEncoderModel):
features_only=False,
return_all_hiddens=False,
classification_head_name=None,
**kwargs
**kwargs,
):
if classification_head_name is not None:
features_only = True
@@ -259,7 +259,7 @@ class RobertaModel(FairseqEncoderModel):
checkpoint_file="model.pt",
data_name_or_path=".",
bpe="gpt2",
**kwargs
**kwargs,
):
from fairseq import hub_utils

@@ -464,7 +464,7 @@ class RobertaEncoder(FairseqEncoder):
features_only=False,
return_all_hiddens=False,
masked_tokens=None,
**unused
**unused,
):
"""
Args:
@@ -645,7 +645,14 @@ class TransformerDecoder(FairseqIncrementalDecoder):
(default: False).
"""

def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
def __init__(
self,
args,
dictionary,
embed_tokens,
no_encoder_attn=False,
output_projection=None,
):
self.args = args
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
@@ -727,7 +734,11 @@ class TransformerDecoder(FairseqIncrementalDecoder):
)

self.adaptive_softmax = None
self.output_projection = None
self.output_projection = output_projection
if self.output_projection is None:
self.build_output_projection(args, dictionary, embed_tokens)

def build_output_projection(self, args, dictionary, embed_tokens):
if args.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
@ -789,7 +800,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
|
||||
prev_output_tokens (LongTensor): previous decoder outputs of shape
|
||||
`(batch, tgt_len)`, for teacher forcing
|
||||
encoder_out (optional): output from the encoder, used for
|
||||
encoder-side attention
|
||||
encoder-side attention, should be of size T x B x C
|
||||
incremental_state (dict): dictionary used for storing state during
|
||||
:ref:`Incremental decoding`
|
||||
features_only (bool, optional): only return features without
|
||||
@ -802,6 +813,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
|
||||
- the decoder's output of shape `(batch, tgt_len, vocab)`
|
||||
- a dictionary with any model-specific outputs
|
||||
"""
|
||||
|
||||
x, extra = self.extract_features(
|
||||
prev_output_tokens,
|
||||
encoder_out=encoder_out,
|
||||
@ -810,6 +822,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
|
||||
alignment_layer=alignment_layer,
|
||||
alignment_heads=alignment_heads,
|
||||
)
|
||||
|
||||
if not features_only:
|
||||
x = self.output_layer(x)
|
||||
return x, extra
|
||||
@ -866,9 +879,19 @@ class TransformerDecoder(FairseqIncrementalDecoder):
|
||||
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
|
||||
- a dictionary with any model-specific outputs
|
||||
"""
|
||||
bs, slen = prev_output_tokens.size()
|
||||
if alignment_layer is None:
|
||||
alignment_layer = self.num_layers - 1
|
||||
|
||||
enc: Optional[Tensor] = None
|
||||
padding_mask: Optional[Tensor] = None
|
||||
if encoder_out is not None:
|
||||
enc = encoder_out["encoder_out"][0]
|
||||
padding_mask = encoder_out["encoder_padding_mask"][0]
|
||||
assert (
|
||||
enc.size()[1] == bs
|
||||
), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
|
||||
|
||||
# embed positions
|
||||
positions = None
|
||||
if self.embed_positions is not None:
|
||||
@ -916,15 +939,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
|
||||
|
||||
x, layer_attn, _ = layer(
|
||||
x,
|
||||
encoder_out["encoder_out"][0]
|
||||
if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
|
||||
else None,
|
||||
encoder_out["encoder_padding_mask"][0]
|
||||
if (
|
||||
encoder_out is not None
|
||||
and len(encoder_out["encoder_padding_mask"]) > 0
|
||||
)
|
||||
else None,
|
||||
enc,
|
||||
padding_mask,
|
||||
incremental_state,
|
||||
self_attn_mask=self_attn_mask,
|
||||
self_attn_padding_mask=self_attn_padding_mask,
|
||||
|
@ -147,8 +147,16 @@ class MultiheadAttention(nn.Module):
|
||||
is_tpu = query.device.type == "xla"
|
||||
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
src_len = tgt_len
|
||||
assert embed_dim == self.embed_dim
|
||||
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||||
if key is not None:
|
||||
src_len, key_bsz, key_embed_dim = key.size()
|
||||
if not torch.jit.is_scripting():
|
||||
assert (key_bsz, key_embed_dim) == (bsz, embed_dim)
|
||||
assert value is not None
|
||||
assert (src_len, bsz, embed_dim) == value.shape
|
||||
|
||||
|
||||
if (
|
||||
not self.onnx_trace
|
||||
@ -262,6 +270,7 @@ class MultiheadAttention(nn.Module):
|
||||
else:
|
||||
assert k is not None
|
||||
k = torch.cat([prev_key, k], dim=1)
|
||||
src_len = k.size(1)
|
||||
if "prev_value" in saved_state:
|
||||
_prev_value = saved_state["prev_value"]
|
||||
assert _prev_value is not None
|
||||
@ -290,7 +299,7 @@ class MultiheadAttention(nn.Module):
|
||||
assert incremental_state is not None
|
||||
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
||||
assert k is not None
|
||||
src_len = k.size(1)
|
||||
assert k.size(1) == src_len
|
||||
|
||||
# This is part of a workaround to get around fork/join parallelism
|
||||
# not supporting Optional types.
|
||||
|
@@ -4,7 +4,8 @@
# LICENSE file in the root directory of this source tree.

import argparse
from typing import Callable, List, Optional
from pathlib import Path
from typing import Callable, List, Optional, Union

import torch
from fairseq import utils
@@ -361,3 +362,18 @@ def add_model_args(parser):
help='model architecture')
# fmt: on
return group


def get_args(
data: Union[str, Path],
task: str = "translation",
arch: str = "transformer",
**overrides
):
parser = get_training_parser(task)
args = parse_args_and_arch(parser, [str(data), "--task", task, "--arch", arch])

for k, v in overrides.items():
setattr(args, k, v)

return args
@ -23,6 +23,7 @@ class SequenceGenerator(nn.Module):
|
||||
beam_size=1,
|
||||
max_len_a=0,
|
||||
max_len_b=200,
|
||||
max_len=0,
|
||||
min_len=1,
|
||||
normalize_scores=True,
|
||||
len_penalty=1.0,
|
||||
@ -44,6 +45,8 @@ class SequenceGenerator(nn.Module):
|
||||
beam_size (int, optional): beam width (default: 1)
|
||||
max_len_a/b (int, optional): generate sequences of maximum length
|
||||
ax + b, where x is the source length
|
||||
max_len (int, optional): the maximum length of the generated output
|
||||
(not including end-of-sentence)
|
||||
min_len (int, optional): the minimum length of the generated output
|
||||
(not including end-of-sentence)
|
||||
normalize_scores (bool, optional): normalize scores by the length
|
||||
@ -79,6 +82,7 @@ class SequenceGenerator(nn.Module):
|
||||
self.max_len_a = max_len_a
|
||||
self.max_len_b = max_len_b
|
||||
self.min_len = min_len
|
||||
self.max_len = max_len or self.model.max_decoder_positions()
|
||||
|
||||
self.normalize_scores = normalize_scores
|
||||
self.len_penalty = len_penalty
|
||||
@ -166,7 +170,7 @@ class SequenceGenerator(nn.Module):
|
||||
yield id, src, ref, hypos[i]
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
|
||||
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]:
|
||||
"""Generate translations. Match the api of other fairseq generators.
|
||||
|
||||
Args:
|
||||
@ -232,8 +236,7 @@ class SequenceGenerator(nn.Module):
|
||||
else:
|
||||
max_len = min(
|
||||
int(self.max_len_a * src_len + self.max_len_b),
|
||||
# exclude the EOS marker
|
||||
self.model.max_decoder_positions() - 1,
|
||||
self.max_len - 1,
|
||||
)
|
||||
assert (
|
||||
self.min_len <= max_len
|
||||
@ -275,9 +278,8 @@ class SequenceGenerator(nn.Module):
|
||||
[torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
|
||||
) # contains lists of dictionaries of information about the hypothesis being finalized at each step
|
||||
|
||||
finished = [
|
||||
False for i in range(bsz)
|
||||
] # a boolean array indicating if the sentence at the index is finished or not
|
||||
# a boolean array indicating if the sentence at the index is finished or not
|
||||
finished = [False for i in range(bsz)]
|
||||
num_remaining_sent = bsz # number of sentences remaining
|
||||
|
||||
# number of candidate hypos per step
|
||||
|
fairseq/tasks/online_backtranslation.py (new file, 677 lines)
@@ -0,0 +1,677 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from collections import OrderedDict, defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import fairseq
|
||||
from fairseq import metrics, options, utils
|
||||
from fairseq.data import (
|
||||
FairseqDataset,
|
||||
LanguagePairDataset,
|
||||
NoisingDataset,
|
||||
PrependTokenDataset,
|
||||
RoundRobinZipDatasets,
|
||||
TransformEosLangPairDataset,
|
||||
data_utils,
|
||||
encoders,
|
||||
)
|
||||
from fairseq.sequence_generator import SequenceGenerator
|
||||
from fairseq.tasks import register_task
|
||||
from fairseq.tasks.translation import TranslationTask, load_langpair_dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PiecewiseLinearFn:
|
||||
"""Piecewise linear function. Can be configured with a string."""
|
||||
|
||||
def __init__(self, pieces: Sequence[Tuple[int, float]]):
|
||||
assert pieces == sorted(
|
||||
pieces
|
||||
), f"PiecewiseLinearFn configuration should be sorted, received: {pieces}"
|
||||
|
||||
self.pieces = pieces
|
||||
|
||||
def __call__(self, x: int) -> float:
|
||||
for i, (x_a, y_a) in enumerate(self.pieces[:-1]):
|
||||
x_b, y_b = self.pieces[i + 1]
|
||||
if x_a <= x <= x_b:
|
||||
return y_a + (x - x_a) * (y_b - y_a) / (x_b - x_a)
|
||||
|
||||
return self.pieces[-1][1]
|
||||
|
||||
@staticmethod
|
||||
def from_string(configuration: str) -> "PiecewiseLinearFn":
|
||||
"""
|
||||
Parse the configuration of lambda coefficient (for scheduling).
|
||||
x = "3" # lambda will be a constant equal to x
|
||||
x = "0:1,1000:0" # lambda will start from 1 and linearly decrease
|
||||
# to 0 during the first 1000 iterations
|
||||
x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000
|
||||
# iterations, then will linearly increase to 1 until iteration 2000
|
||||
"""
|
||||
if isinstance(configuration, float):
|
||||
return PiecewiseLinearFn([(0, configuration)])
|
||||
|
||||
try:
|
||||
parts = configuration.split(",")
|
||||
if len(parts) == 1:
|
||||
v = float(configuration)
|
||||
return PiecewiseLinearFn([(0, v)])
|
||||
|
||||
split = [s.split(":") for s in parts]
|
||||
pieces = [(int(t), float(v)) for t, v in split]
|
||||
return PiecewiseLinearFn(pieces)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Invalid PiecewiseLinearFn configuration: {configuration!r}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def one() -> "PiecewiseLinearFn":
|
||||
return PiecewiseLinearFn([(0, 1.0)])
|
||||
|
||||
|
||||
@register_task("online_backtranslation")
|
||||
class OnlineBackTranslationTask(TranslationTask):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add task-specific arguments to the parser."""
|
||||
# fmt: off
|
||||
# Generic translation args
|
||||
parser.add_argument('data', help='colon separated path to data directories list, \
|
||||
will be iterated upon during epochs in round-robin manner; \
|
||||
however, valid and test data are always in the first directory to \
|
||||
avoid the need for repeating them in all directories')
|
||||
parser.add_argument('--mono-langs', metavar='MONO_LANGS',
|
||||
help='monolingual languages for training')
|
||||
parser.add_argument('--valid-lang-pairs', default=None, metavar='VALID_LANG_PAIRS',
|
||||
help='language pairs for validation')
|
||||
parser.add_argument('--load-alignments', action='store_true',
|
||||
help='load the binarized alignments')
|
||||
parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL',
|
||||
help='pad the source on the left')
|
||||
parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
|
||||
help='pad the target on the left')
|
||||
parser.add_argument('--upsample-primary', default=1, type=int,
|
||||
help='amount to upsample primary dataset')
|
||||
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
|
||||
help='max number of tokens in the source sequence')
|
||||
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
|
||||
help='max number of tokens in the target sequence')
|
||||
parser.add_argument('--truncate-source', action='store_true', default=False,
|
||||
help='truncate source to max-source-positions')
|
||||
parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N',
|
||||
help='if >0, then bucket source and target lengths into N '
|
||||
'buckets and pad accordingly; this is useful on TPUs '
|
||||
'to minimize the number of compilations')
|
||||
|
||||
# Denoising args
|
||||
parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
|
||||
help='maximum word shuffle distance for denoising autoencoding data generation')
|
||||
parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
|
||||
help='word dropout probability for denoising autoencoding data generation')
|
||||
parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
|
||||
help='word blanking probability for denoising autoencoding data generation')
|
||||
|
||||
# Backtranslation args
|
||||
parser.add_argument('--lambda-bt', default="1.0", type=str, metavar='N',
|
||||
help='back-translation weight')
|
||||
parser.add_argument('--lambda-dae', default="1.0", type=str, metavar='N',
|
||||
help='denoising auto-encoder weight')
|
||||
|
||||
# Evaluation args
|
||||
parser.add_argument('--generate-one-by-one', action='store_true',
|
||||
help='generate one sentence at a time for backtranslation')
|
||||
|
||||
parser.add_argument('--eval-bleu', action='store_true',
|
||||
help='evaluation with BLEU scores')
|
||||
parser.add_argument('--eval-bleu-detok', type=str, default="space",
|
||||
help='detokenize before computing BLEU (e.g., "moses"); '
|
||||
'required if using --eval-bleu; use "space" to '
|
||||
'disable detokenization; see fairseq.data.encoders '
|
||||
'for other options')
|
||||
parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
|
||||
help='args for building the tokenizer, if needed')
|
||||
parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
|
||||
help='compute tokenized BLEU instead of sacrebleu')
|
||||
parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
|
||||
help='remove BPE before computing BLEU')
|
||||
parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
|
||||
help='generation args for BLEU scoring, '
|
||||
'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
|
||||
parser.add_argument('--eval-bleu-print-samples', action='store_true',
|
||||
help='print sample generations during validation')
|
||||
# fmt: on
|
||||
|
||||
def __init__(self, args, common_dict, mono_langs, valid_lang_pairs):
|
||||
super().__init__(args, common_dict, common_dict)
|
||||
self.common_dict = common_dict
|
||||
self.mono_langs = mono_langs
|
||||
self.valid_lang_pairs = valid_lang_pairs
|
||||
|
||||
self.SHOW_SAMPLES_INTERVAL = 1000
|
||||
# Start by showing samples
|
||||
self._show_samples_ctr = self.SHOW_SAMPLES_INTERVAL
|
||||
self.SHOW_SAMPLES_NUMBER = 5
|
||||
self.lambda_bt = PiecewiseLinearFn.from_string(args.lambda_bt)
|
||||
self.lambda_dae = PiecewiseLinearFn.from_string(args.lambda_dae)
|
||||
|
||||
self.args = args
|
||||
self.data = utils.split_paths(self.args.data)
|
||||
if len(self.data) == 1:
|
||||
shards = list(Path(self.data[0]).glob("shard*"))
|
||||
if len(shards) > 0:
|
||||
# keep this as strings, since it can also be a manifold path
|
||||
old_data = self.data
|
||||
self.data = [str(shard) for shard in shards]
|
||||
logging.warning(f"Expanded data directory {old_data} to {self.data}")
|
||||
|
||||
@classmethod
|
||||
def setup_task(cls, args, **kwargs):
|
||||
"""Setup the task (e.g., load dictionaries).
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): parsed command-line arguments
|
||||
"""
|
||||
args.left_pad_source = options.eval_bool(args.left_pad_source)
|
||||
args.left_pad_target = options.eval_bool(args.left_pad_target)
|
||||
|
||||
paths = utils.split_paths(args.data)
|
||||
assert len(paths) > 0
|
||||
assert args.mono_langs is not None
|
||||
|
||||
mono_langs = args.mono_langs.split(",")
|
||||
valid_lang_pairs = args.valid_lang_pairs.split(",")
|
||||
|
||||
# load dictionary
|
||||
dict_path = os.path.join(paths[0], "dict.txt")
|
||||
common_dict = cls.load_dictionary(dict_path)
|
||||
|
||||
return cls(args, common_dict, mono_langs, valid_lang_pairs)
|
||||
|
||||
def load_dataset(self, split, epoch=1, combine=False, **kwargs) -> FairseqDataset:
|
||||
"""Load a given dataset split.
|
||||
|
||||
Args:
|
||||
split (str): name of the split (e.g., train, valid, test)
|
||||
"""
|
||||
if split == "train":
|
||||
data_path = self.data[(epoch - 1) % len(self.data)]
|
||||
dataset = self.load_train_dataset(data_path)
|
||||
else:
|
||||
# valid/test should always be the same.
|
||||
dataset = self.load_translation_dataset(split, self.data[0])
|
||||
|
||||
self.datasets[split] = dataset
|
||||
return dataset
|
||||
|
||||
def load_train_dataset(self, data_path: str) -> FairseqDataset:
|
||||
"""The training dataset is made of backtranslation dataset and denoising dataset."""
|
||||
data = []
|
||||
for lang in self.mono_langs:
|
||||
train_path = os.path.join(data_path, lang, "train")
|
||||
# TODO: could we do the BT using denoise sample ?
|
||||
# this would halve the data loading work
|
||||
data.append((f"{lang}-BT", self.load_bt_dataset(train_path, lang)))
|
||||
data.append(
|
||||
(f"{lang}-DENOISE", self.load_denoise_dataset(train_path, lang))
|
||||
)
|
||||
|
||||
return RoundRobinZipDatasets(OrderedDict(data))
|
||||
|
||||
def _langpair_dataset(
|
||||
self, src: FairseqDataset, tgt: FairseqDataset
|
||||
) -> LanguagePairDataset:
|
||||
return LanguagePairDataset(
|
||||
src,
|
||||
src.sizes,
|
||||
self.dictionary,
|
||||
tgt=tgt,
|
||||
tgt_sizes=tgt.sizes,
|
||||
tgt_dict=self.dictionary,
|
||||
left_pad_source=self.args.left_pad_source,
|
||||
left_pad_target=self.args.left_pad_target,
|
||||
# TODO: should we shuffle ? we are already sorting batch by sizes so ?
|
||||
# shuffle=True,
|
||||
)
|
||||
|
||||
def _prepend_lang_bos_to_target(
|
||||
self, dataset: LanguagePairDataset, lang: str
|
||||
) -> LanguagePairDataset:
|
||||
bos = _lang_token_index(self.dictionary, lang)
|
||||
return TransformEosLangPairDataset(
|
||||
dataset,
|
||||
src_eos=self.dictionary.eos(),
|
||||
new_src_eos=self.dictionary.eos(),
|
||||
tgt_bos=self.dictionary.eos(),
|
||||
new_tgt_bos=bos,
|
||||
)
|
||||
|
||||
def load_bt_dataset(self, data_path: str, lang: str) -> FairseqDataset:
|
||||
"""The BT dataset is generated with (tgt, tgt) pairs.
|
||||
The actual translation to a (generated_src, tgt) pair
|
||||
is done on the fly during training.
|
||||
"""
|
||||
mono_dataset = data_utils.load_indexed_dataset(
|
||||
data_path, self.common_dict, self.args.dataset_impl
|
||||
)
|
||||
assert mono_dataset is not None, f"No dataset found for {lang}"
|
||||
|
||||
mono_dataset_src = PrependTokenDataset(
|
||||
mono_dataset, _lang_token_index(self.dictionary, lang)
|
||||
)
|
||||
|
||||
mono_dataset_bt = self._langpair_dataset(mono_dataset_src, mono_dataset)
|
||||
logger.info(
|
||||
f"mono_lang = {lang} "
|
||||
f"lang token index = {_lang_token_index(self.dictionary, lang)} "
|
||||
f"lang token = {_lang_token(lang)}"
|
||||
)
|
||||
|
||||
mono_dataset_bt = self._prepend_lang_bos_to_target(mono_dataset_bt, lang)
|
||||
return mono_dataset_bt
|
||||
|
||||
def load_denoise_dataset(self, data_path: str, lang: str) -> FairseqDataset:
|
||||
"""Classic denoising dataset"""
|
||||
dataset = data_utils.load_indexed_dataset(
|
||||
data_path, self.common_dict, self.args.dataset_impl
|
||||
)
|
||||
noisy_dataset = NoisingDataset(
|
||||
dataset,
|
||||
self.dictionary,
|
||||
seed=1,
|
||||
max_word_shuffle_distance=self.args.max_word_shuffle_distance,
|
||||
word_dropout_prob=self.args.word_dropout_prob,
|
||||
word_blanking_prob=self.args.word_blanking_prob,
|
||||
)
|
||||
noisy_dataset = PrependTokenDataset(
|
||||
noisy_dataset, _lang_token_index(self.dictionary, lang)
|
||||
)
|
||||
|
||||
clean_dataset = data_utils.load_indexed_dataset(
|
||||
data_path, self.common_dict, self.args.dataset_impl
|
||||
)
|
||||
denoising_dataset = self._langpair_dataset(noisy_dataset, clean_dataset)
|
||||
denoising_dataset = self._prepend_lang_bos_to_target(denoising_dataset, lang)
|
||||
return denoising_dataset
|
||||
|
||||
def load_translation_dataset(
|
||||
self, split: str, data_path: str, combine: bool = False
|
||||
):
|
||||
# only judging with one language pair for the moment,
|
||||
# since ConcatDataset doesn't work as expected
|
||||
assert len(self.valid_lang_pairs) == 1, "For now..."
|
||||
valid_lang_pair = self.valid_lang_pairs[0]
|
||||
src, tgt = valid_lang_pair.split("-")
|
||||
|
||||
# use the same function as TranslationTask
|
||||
src_tgt_dt = load_langpair_dataset(
|
||||
data_path,
|
||||
split,
|
||||
src,
|
||||
self.common_dict,
|
||||
tgt,
|
||||
self.common_dict,
|
||||
combine=combine,
|
||||
dataset_impl=self.args.dataset_impl,
|
||||
upsample_primary=self.args.upsample_primary,
|
||||
left_pad_source=self.args.left_pad_source,
|
||||
left_pad_target=self.args.left_pad_target,
|
||||
max_source_positions=self.args.max_source_positions,
|
||||
max_target_positions=self.args.max_target_positions,
|
||||
load_alignments=self.args.load_alignments,
|
||||
truncate_source=self.args.truncate_source,
|
||||
num_buckets=self.args.num_batch_buckets,
|
||||
shuffle=(split != "test"),
|
||||
prepend_bos_src=_lang_token_index(self.dictionary, src),
|
||||
)
|
||||
|
||||
src_tgt_eos_dt = self._prepend_lang_bos_to_target(src_tgt_dt, tgt)
|
||||
src_tgt_eos_dt.args = self.args
|
||||
return src_tgt_eos_dt
|
||||
|
||||
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def build_model(self, args):
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
model = super().build_model(args)
|
||||
|
||||
add_secial_tokens_to_dict_and_model(self.common_dict, model, self.mono_langs)
|
||||
|
||||
self.sequence_generators = {}
|
||||
for mono_lang in self.mono_langs:
|
||||
self.sequence_generators[mono_lang] = SequenceGenerator(
|
||||
[model],
|
||||
tgt_dict=self.dictionary,
|
||||
beam_size=1,
|
||||
max_len_a=1.3,
|
||||
max_len_b=5,
|
||||
min_len=5,
|
||||
# keep 1 to be able to prepend bos
|
||||
max_len=model.max_decoder_positions() - 1,
|
||||
)
|
||||
|
||||
if getattr(args, "eval_bleu", False):
|
||||
assert getattr(args, "eval_bleu_detok", None) is not None, (
|
||||
"--eval-bleu-detok is required if using --eval-bleu; "
|
||||
"try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
|
||||
"to disable detokenization, e.g., when using sentencepiece)"
|
||||
)
|
||||
detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}")
|
||||
self.tokenizer = encoders.build_tokenizer(
|
||||
Namespace(
|
||||
tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args
|
||||
)
|
||||
)
|
||||
|
||||
gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}")
|
||||
self.bleu_sequence_generator = self.build_generator(
|
||||
[model], Namespace(**gen_args)
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
def max_positions(self):
|
||||
"""Return the max sentence length allowed by the task."""
|
||||
return (self.args.max_source_positions, self.args.max_target_positions)
|
||||
|
||||
@property
|
||||
def dictionary(self):
|
||||
"""Return the source :class:`~fairseq.data.Dictionary`."""
|
||||
return self.common_dict
|
||||
|
||||
def display_samples_once_in_a_while(self, smp, mono_lang, other_lang):
|
||||
self._show_samples_ctr += 1
|
||||
if self._show_samples_ctr < self.SHOW_SAMPLES_INTERVAL:
|
||||
return
|
||||
self._show_samples_ctr = 0
|
||||
|
||||
ln = smp["net_input"]["src_tokens"].shape[0]
|
||||
|
||||
logger.info(
|
||||
f"(r:{self.args.distributed_rank}) : "
|
||||
f"{other_lang} ---> {mono_lang} "
|
||||
f"({other_lang} was generated by back-translation.) {ln} samples"
|
||||
)
|
||||
|
||||
for i in range(min(ln, self.SHOW_SAMPLES_NUMBER)):
|
||||
src_tokens = smp["net_input"]["src_tokens"][i]
|
||||
tgt_tokens = smp["target"][i]
|
||||
|
||||
src_str = self.dictionary.string(src_tokens, "sentencepiece")
|
||||
tgt_str = self.dictionary.string(tgt_tokens, "sentencepiece")
|
||||
logger.info(
|
||||
f"\n{i}\t\t[{other_lang} generated] {src_str}\n"
|
||||
f"\t\t[{mono_lang} original ] {tgt_str}\n"
|
||||
f"\t\t[ src tokens] {src_tokens}\n"
|
||||
)
|
||||
|
||||
def backtranslate_sample(self, smp, orig_lang, other_lang) -> None:
|
||||
"""
|
||||
* WARNING: smp is modified in place.
|
||||
* At the start of this function, `smp` has the same input and target:
|
||||
|--------------------------------------------------------|
|
||||
| smp['net_input']['src_tokens'] | smp['target'] |
|
||||
| (from data) __en__ hello world | __en__ hello world |
|
||||
|--------------------------------------------------------|
|
||||
|
||||
* We call generator.generate(smp, bos_token = token("ro")),
|
||||
and copy the result as input
|
||||
* At the end, `smp` has the translation to other language.
|
||||
|--------------------------------------------------------|
|
||||
| smp['net_input']['src_tokens'] | smp['target'] |
|
||||
| (generated) __ro__ salut lume | __en__ hello world |
|
||||
|--------------------------------------------------------|
|
||||
|
||||
"""
|
||||
bos_token = _lang_token_index(self.dictionary, other_lang)
|
||||
generated = self.sequence_generators[orig_lang].generate(
|
||||
models=[], sample=smp, bos_token=bos_token
|
||||
)
|
||||
|
||||
max_lngth = max([gn[0]["tokens"].size(0) for gn in generated])
|
||||
net_input = smp["net_input"]
|
||||
n_src_tokens = torch.empty(
|
||||
size=(len(generated), max_lngth + 1), dtype=net_input["src_tokens"].dtype
|
||||
)
|
||||
n_src_lengths = torch.empty(
|
||||
len(generated), dtype=net_input["src_lengths"].dtype
|
||||
)
|
||||
|
||||
for i, gn in enumerate(generated):
|
||||
tokens = gn[0]["tokens"]
|
||||
tokens_size = tokens.size(0)
|
||||
padding_needed = max_lngth - tokens_size
|
||||
tokens = torch.cat([tokens.new([bos_token]), tokens])
|
||||
tokens = F.pad(tokens, (0, padding_needed), value=self.dictionary.pad())
|
||||
n_src_tokens[i] = tokens
|
||||
n_src_lengths[i] = tokens_size + 1
|
||||
|
||||
device = net_input["src_tokens"].device
|
||||
# This seems to be important
|
||||
del net_input["src_tokens"]
|
||||
del net_input["src_lengths"]
|
||||
net_input["src_tokens"] = n_src_tokens.to(device)
|
||||
net_input["src_lengths"] = n_src_lengths.to(device)
|
||||
|
||||
def generate(self, smp, model):
|
||||
model.eval()
|
||||
orig_lang = (
|
||||
self.dictionary[smp["net_input"]["src_tokens"][0][0]]
|
||||
.replace(" ", "")
|
||||
.replace("_", "")
|
||||
)
|
||||
bos_token = smp["net_input"]["prev_output_tokens"][0][0]
|
||||
with torch.no_grad():
|
||||
generated = self.sequence_generators[orig_lang].generate(
|
||||
models=[model], sample=smp, bos_token=bos_token
|
||||
)
|
||||
return generated
|
||||
|
||||
def get_other_lang(self, lang):
|
||||
# TODO: allow more complex mapping
|
||||
if lang != self.mono_langs[0]:
|
||||
return self.mono_langs[0]
|
||||
if len(self.mono_langs) == 2:
|
||||
return self.mono_langs[1]
|
||||
return self.mono_langs[np.random.randint(1, len(self.mono_langs))]
|
||||
|
||||
def train_step(
|
||||
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
|
||||
):
|
||||
|
||||
model.train()
|
||||
model.set_num_updates(update_num)
|
||||
|
||||
agg_loss, agg_sample_size = 0.0, 0.0
|
||||
agg_logging_output: Dict[str, float] = defaultdict(float)
|
||||
|
||||
dataset_keys = self.datasets["train"].datasets.keys()
|
||||
|
||||
weights = {
|
||||
"BT": self.lambda_bt(update_num),
|
||||
"DENOISE": self.lambda_dae(update_num),
|
||||
}
|
||||
log_keys = {"BT": "bt_", "DENOISE": "dae_"}
|
||||
|
||||
for dataset_key in dataset_keys:
|
||||
smp = sample[dataset_key]
|
||||
mono_lang, task_subtype = dataset_key.split("-")
|
||||
if weights[task_subtype] == 0:
|
||||
continue
|
||||
|
||||
if task_subtype == "BT":
|
||||
with torch.autograd.profiler.record_function("backtranslation"):
|
||||
model.eval()
|
||||
# TODO: Could we translate to several languages at once?
|
||||
# this would allow to share encoder_out and maximize GPU usage.
|
||||
other_lang = self.get_other_lang(mono_lang)
|
||||
self.backtranslate_sample(smp, mono_lang, other_lang)
|
||||
self.display_samples_once_in_a_while(smp, mono_lang, other_lang)
|
||||
model.train()
|
||||
|
||||
# Like in FairseqTask.train_step
|
||||
with torch.autograd.profiler.record_function("forward"):
|
||||
loss, sample_size, logging_output = criterion(model, smp)
|
||||
loss *= weights[task_subtype]
|
||||
if ignore_grad:
|
||||
loss *= 0
|
||||
with torch.autograd.profiler.record_function("backward"):
|
||||
optimizer.backward(loss)
|
||||
|
||||
agg_loss += loss.item()
|
||||
agg_sample_size += sample_size
|
||||
for k in logging_output:
|
||||
agg_logging_output[log_keys[task_subtype] + k] += logging_output[k]
|
||||
agg_logging_output[k] += logging_output[k]
|
||||
|
||||
return agg_loss, agg_sample_size, agg_logging_output
|
||||
|
||||
def get_bos_token_from_sample(self, sample):
|
||||
net_input = sample["net_input"]
|
||||
source_lang_token_id = torch.unique(net_input["src_tokens"][:, 0]).item()
|
||||
source_lang_token = self.dictionary[source_lang_token_id].replace("_", "")
|
||||
target_lang_token_id = _lang_token_index(
|
||||
self.dictionary, self.get_other_lang(source_lang_token)
|
||||
)
|
||||
|
||||
return target_lang_token_id
|
||||
|
||||
def reduce_metrics(self, logging_outputs, criterion):
|
||||
super().reduce_metrics(logging_outputs, criterion)
|
||||
bt_sample_size = sum(x.get("bt_sample_size", 0) for x in logging_outputs)
|
||||
if bt_sample_size:
|
||||
bt_loss_sum = sum(x.get("bt_loss", 0) for x in logging_outputs)
|
||||
bt_loss_sum *= 1 / bt_sample_size / math.log(2)
|
||||
metrics.log_scalar("bt_loss", bt_loss_sum, bt_sample_size, round=3)
|
||||
|
||||
bt_nll_loss_sum = sum(x.get("bt_nll_loss", 0) for x in logging_outputs)
|
||||
bt_ntokens = sum(x.get("bt_ntokens", 0) for x in logging_outputs)
|
||||
bt_nll_loss_sum *= 1 / bt_ntokens / math.log(2)
|
||||
metrics.log_scalar("bt_nll_loss", bt_nll_loss_sum, bt_ntokens, round=3)
|
||||
metrics.log_derived(
|
||||
"bt_ppl", lambda meters: utils.get_perplexity(meters["bt_nll_loss"].avg)
|
||||
)
|
||||
|
||||
dae_sample_size = sum(x.get("dae_sample_size", 0) for x in logging_outputs)
|
||||
if dae_sample_size:
|
||||
dae_loss_sum = sum(x.get("dae_loss", 0) for x in logging_outputs)
|
||||
dae_loss_sum *= 1 / dae_sample_size / math.log(2)
|
||||
metrics.log_scalar("dae_loss", dae_loss_sum, dae_sample_size, round=3)
|
||||
|
||||
dae_nll_loss_sum = sum(x.get("dae_nll_loss", 0) for x in logging_outputs)
|
||||
dae_ntokens = sum(x.get("dae_ntokens", 0) for x in logging_outputs)
|
||||
dae_nll_loss_sum *= 1 / dae_ntokens / math.log(2)
|
||||
metrics.log_scalar("dae_nll_loss", dae_nll_loss_sum, dae_ntokens, round=3)
|
||||
metrics.log_derived(
|
||||
"dae_ppl",
|
||||
lambda meters: utils.get_perplexity(meters["dae_nll_loss"].avg),
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def extend_embedding(
|
||||
emb: nn.Module, new_vocab_size: int, copy_from_token_id: int
|
||||
) -> None:
|
||||
old_emb_data = emb.weight.data
|
||||
(old_vocab_size, dim) = old_emb_data.shape
|
||||
assert new_vocab_size >= old_vocab_size
|
||||
|
||||
if new_vocab_size > old_vocab_size:
|
||||
emb.weight.data = torch.zeros((new_vocab_size, dim))
|
||||
emb.weight.data[:old_vocab_size, :] = old_emb_data
|
||||
# initialize new embeddings
|
||||
emb.weight.data[old_vocab_size:, :] = old_emb_data[copy_from_token_id]
|
||||
if hasattr(emb, "num_embeddings"):
|
||||
emb.num_embeddings = new_vocab_size
|
||||
if hasattr(emb, "out_features"):
|
||||
emb.out_features = new_vocab_size
|
||||
|
||||
if getattr(emb, "bias", None) is None:
|
||||
return
|
||||
|
||||
# Fix the bias.
|
||||
# Bias shape can be different from the previous vocab size
|
||||
# if the weight matrix was shared and already extended but not the bias.
|
||||
(old_vocab_size,) = emb.bias.shape
|
||||
assert new_vocab_size >= old_vocab_size
|
||||
if new_vocab_size > old_vocab_size:
|
||||
old_bias = emb.bias.data
|
||||
new_bias = torch.zeros(
|
||||
(new_vocab_size,), dtype=old_bias.dtype, device=old_bias.device
|
||||
)
|
||||
new_bias[:old_vocab_size] = old_bias
|
||||
emb.bias.data = new_bias
|
||||
|
||||
|
||||
def add_secial_tokens_to_dict_and_model(
|
||||
dictionary: "fairseq.data.Dictionary",
|
||||
model: nn.Module,
|
||||
mono_langs: Sequence[str],
|
||||
) -> None:
|
||||
embs = model.encoder.embed_tokens
|
||||
vocab_size, embedding_dim = embs.weight.shape
|
||||
|
||||
# The model may or may not have a '<mask>' embedding yet
|
||||
assert (
|
||||
len(dictionary) <= vocab_size <= len(dictionary) + 1
|
||||
), f"Dictionary len ({len(dictionary)}) doesn't match embs shape ({embs.weight.shape})"
|
||||
# TODO: we should reuse the pretrained model dict which already has <mask>
|
||||
dictionary.add_symbol("<mask>")
|
||||
|
||||
for lang in mono_langs:
|
||||
lang_token = _lang_token(lang)
|
||||
dictionary.add_symbol(lang_token)
|
||||
logger.info(
|
||||
f"dictionary: {len(dictionary)} -> {vocab_size} tokens "
|
||||
f"after adding {len(mono_langs)} lang tokens."
|
||||
)
|
||||
|
||||
if len(dictionary) <= vocab_size:
|
||||
return
|
||||
|
||||
extend_embedding(embs, len(dictionary), dictionary.bos())
|
||||
dec_embs = model.decoder.embed_tokens
|
||||
extend_embedding(dec_embs, len(dictionary), dictionary.bos())
|
||||
lm_head = model.decoder.output_projection
|
||||
extend_embedding(lm_head, len(dictionary), dictionary.bos())
|
||||
assert lm_head.weight.shape == (len(dictionary), embedding_dim)
|
||||
|
||||
|
||||
def _lang_token(lang: str) -> str:
|
||||
return f"__{lang}__"
|
||||
|
||||
|
||||
def _lang_token_index(dictionary, lang: str) -> int:
|
||||
return dictionary.index(_lang_token(lang))
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assert_weights_have_changed(model: nn.Module):
|
||||
def checksum(model: nn.Module) -> float:
|
||||
return sum(p.sum().item() for p in model.parameters())
|
||||
|
||||
initial_checksum = checksum(model)
|
||||
yield model
|
||||
final_checksum = checksum(model)
|
||||
logger.info(
|
||||
f"initial_checksum={initial_checksum} -> final_checksum={final_checksum}"
|
||||
)
|
||||
assert initial_checksum != final_checksum, "Model hasn't changed !"
|
@@ -57,6 +57,7 @@ def load_langpair_dataset(
num_buckets=0,
shuffle=True,
pad_to_multiple=1,
prepend_bos_src=None,
):
def split_exists(split, src, tgt, lang, data_path):
filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
@@ -128,6 +129,9 @@ def load_langpair_dataset(
src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
if tgt_dataset is not None:
tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
elif prepend_bos_src is not None:
logger.info(f"prepending src bos: {prepend_bos_src}")
src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)

eos = None
if append_source_id:
@@ -38,7 +38,7 @@ class TranslationFromPretrainedBARTTask(TranslationTask):
"""Add task-specific arguments to the parser."""
# fmt: off
TranslationTask.add_args(parser)
parser.add_argument('--langs', required=True, metavar='LANG',
parser.add_argument('--langs', type=str, metavar='LANG',
help='comma-separated list of monolingual language, '
'for example, "en,de,fr". These should match the '
'langs from pretraining (and be in the same order). '
@@ -80,9 +80,6 @@ def main(cfg: FairseqConfig) -> None:

# Setup task, e.g., translation, language modeling, etc.
task = tasks.setup_task(cfg.task)
# Load valid dataset (we load training data below, based on the latest checkpoint)
for valid_sub_split in cfg.dataset.valid_subset.split(","):
task.load_dataset(valid_sub_split, combine=False, epoch=1)

assert cfg.criterion, "Please specify criterion to train a model"

@@ -111,6 +108,11 @@ def main(cfg: FairseqConfig) -> None:
)
)

# Load valid dataset (we load training data below, based on the latest checkpoint)
# We load the valid dataset AFTER building the model
for valid_sub_split in cfg.dataset.valid_subset.split(","):
task.load_dataset(valid_sub_split, combine=False, epoch=1)

# (optionally) Configure quantization
if cfg.common.quantization_config_path is not None:
quantizer = quantization_utils.Quantizer(
tests/test_online_backtranslation.py (new file, 206 lines)
@@ -0,0 +1,206 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Sequence
|
||||
|
||||
import fairseq.data.indexed_dataset as indexed_dataset
|
||||
import fairseq.options
|
||||
import fairseq.tasks.online_backtranslation as obt
|
||||
import torch
|
||||
from tests import utils
|
||||
|
||||
|
||||
def mk_sample(tokens: Sequence[int], batch_size: int = 2) -> Dict[str, Any]:
|
||||
batch = torch.stack([torch.tensor(tokens, dtype=torch.long)] * batch_size)
|
||||
sample = {
|
||||
"net_input": {
|
||||
"src_tokens": batch,
|
||||
"prev_output_tokens": batch,
|
||||
"src_lengths": torch.tensor([len(tokens)] * batch_size, dtype=torch.long),
|
||||
},
|
||||
"target": batch[:, 1:],
|
||||
}
|
||||
return sample
|
||||
|
||||
|
||||
def mk_dataset(num_samples: int, max_len: int, output: Path):
|
||||
output.parent.mkdir(exist_ok=True)
|
||||
idx = indexed_dataset.IndexedDatasetBuilder(str(output))
|
||||
data = torch.randint(5, 100, (num_samples, max_len))
|
||||
lengths = torch.randint(3, max_len, (num_samples,))
|
||||
for d, l in zip(data, lengths):
|
||||
d[0] = 0
|
||||
idx.add_item(d[:l])
|
||||
idx.finalize(output.with_suffix(".idx"))
|
||||
assert output.exists()
|
||||
assert output.with_suffix(".idx").exists()
|
||||
|
||||
|
||||
class OnlineBacktranslationTest(unittest.TestCase):
|
||||
|
||||
tmp_dir = Path(tempfile.mkdtemp(suffix="OnlineBacktranslationTest"))
|
||||
|
||||
@classmethod
|
||||
def obt_task(
|
||||
cls, languages: Sequence[str], data: Path = None, language_mapping: str = None
|
||||
):
|
||||
dict_path = cls.tmp_dir / "dict.txt"
|
||||
if not dict_path.exists():
|
||||
dictionary = utils.dummy_dictionary(100)
|
||||
dictionary.save(str(dict_path))
|
||||
|
||||
if data is not None:
|
||||
(data / "dict.txt").write_text(dict_path.read_text())
|
||||
else:
|
||||
data = cls.tmp_dir
|
||||
assert len(languages) >= 2
|
||||
|
||||
kwargs = {
|
||||
"arch": "transformer",
|
||||
# --max-sentences=1 for better predictability of batches
|
||||
"max_sentences": 1,
|
||||
# Use characteristic dimensions
|
||||
"encoder_layers": 3,
|
||||
"encoder_embed_dim": 12,
|
||||
"encoder_ffn_embed_dim": 14,
|
||||
"encoder_attention_heads": 4,
|
||||
"decoder_layers": 3,
|
||||
"decoder_embed_dim": 12,
|
||||
"decoder_output_dim": 12,
|
||||
"decoder_ffn_embed_dim": 14,
|
||||
"decoder_attention_heads": 4,
|
||||
# Disable dropout so we have comparable tests.
|
||||
"dropout": 0,
|
||||
"attention_dropout": 0,
|
||||
"activation_dropout": 0,
|
||||
"encoder_layerdrop": 0,
|
||||
}
|
||||
|
||||
args = fairseq.options.get_args(
|
||||
data,
|
||||
task="online_backtranslation",
|
||||
mono_langs=",".join(languages),
|
||||
valid_lang_pairs=f"{languages[0]}-{languages[1]}",
|
||||
tokens_per_sample=256,
|
||||
language_mapping=language_mapping,
|
||||
**kwargs,
|
||||
)
|
||||
task = obt.OnlineBackTranslationTask.setup_task(args)
|
||||
# we need to build the model to have the correct dictionary
|
||||
model = task.build_model(task.args)
|
||||
return task, model
|
||||
|
||||
def tmp_path(self, test_case: str) -> Path:
|
||||
return Path(tempfile.mkdtemp(test_case, dir=self.tmp_dir))
|
||||
|
||||
def test_lang_tokens(self):
|
||||
task, model = self.obt_task(["en", "ro", "zh"])
|
||||
assert obt._lang_token("en") in task.dictionary
|
||||
assert obt._lang_token("ro") in task.dictionary
|
||||
assert obt._lang_token("zh") in task.dictionary
|
||||
|
||||
en_bos = obt._lang_token_index(task.common_dict, "en")
|
||||
assert "en" == task.common_dict[en_bos].strip("_")
|
||||
zh_bos = obt._lang_token_index(task.common_dict, "zh")
|
||||
assert "zh" == task.common_dict[zh_bos].strip("_")
|
||||
zh_sample = mk_sample([zh_bos, 16, 14, 12, 10])
|
||||
|
||||
        # we expect to receive the bos token for translation
        assert task.get_bos_token_from_sample(zh_sample) == en_bos

    def test_backtranslate_sample(self):
        task, model = self.obt_task(["en", "ro", "zh"])

        en_bos = obt._lang_token_index(task.common_dict, "en")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")
        sample = mk_sample([zh_bos, 16, 14, 12, 10])

        task.backtranslate_sample(sample, "zh", "en")
        target_zh = list(sample["target"][0])
        assert target_zh == [16, 14, 12, 10]  # original zh sentence
        generated_en = sample["net_input"]["src_tokens"][0]
        assert generated_en[0] == en_bos

    def test_train_dataset(self):
        data = self.tmp_path("test_train_dataset")
        mk_dataset(20, 10, data / "en" / "train.bin")
        mk_dataset(10, 10, data / "zh" / "train.bin")
        task, model = self.obt_task(["en", "zh"], data)
        task.load_dataset("train")

        en_bos = obt._lang_token_index(task.common_dict, "en")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")

        train = task.datasets["train"]
        train.ordered_indices()
        train.prefetch([0, 19])
        sample_0 = train[0]
        sample_19 = train[19]
        self.assertEqual(
            set(sample_0.keys()), {"en-BT", "en-DENOISE", "zh-BT", "zh-DENOISE"}
        )
        for sample in (sample_0, sample_19):
            self.assertEqual(sample["en-BT"]["source"][0], en_bos)
            # bt target isn't ready to look at.
            self.assertEqual(sample["en-DENOISE"]["source"][0], en_bos)
            # TODO: what could we check on the target side?

        for i in range(10):
            # Zh dataset is shorter, and is wrapped around En dataset.
            train.prefetch([i, i + 10])
            self.assertEqual(
                list(train[i]["zh-DENOISE"]["source"]),
                list(train[i + 10]["zh-DENOISE"]["source"]),
            )
            self.assertEqual(train[i]["zh-DENOISE"]["source"][0].item(), zh_bos)

        # Sorted by increasing len
        self.assertLess(
            len(sample_0["en-BT"]["source"]), len(sample_19["en-BT"]["source"])
        )

    def test_valid_dataset(self):
        data = self.tmp_path("test_valid_dataset")
        mk_dataset(10, 21, data / "valid.en-zh.en.bin")
        mk_dataset(10, 21, data / "valid.en-zh.zh.bin")

        task, model = self.obt_task(["en", "zh"], data)
        valid = task.load_dataset("valid")
        en_bos = obt._lang_token_index(task.common_dict, "en")

        assert valid is not None
        valid.prefetch(range(10))
        sample_0 = valid[0]
        sample_9 = valid[9]
        self.assertEqual(sample_0["id"], 0)
        self.assertEqual(sample_9["id"], 9)
        self.assertEqual(sample_0["source"][0], en_bos)
        self.assertEqual(sample_9["source"][0], en_bos)
        # TODO: could we test the target side?

    def assertFnMatch(self, fn, values):
        for x, y in values.items():
            fn_x = fn(x)
            self.assertEqual(fn_x, y, f"Fn has wrong value: fn({x}) = {fn_x} != {y}")

    def test_piecewise_linear_fn(self):
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("1.0"), {0: 1, 100: 1, 500: 1, 1000: 1}
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:1,1000:0"),
            {0: 1, 500: 0.5, 1000: 0, 2000: 0},
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:0,1000:1"),
            {0: 0, 500: 0.5, 1000: 1, 2000: 1},
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:0,1000:1,2000:0"),
            {0: 0, 500: 0.5, 1000: 1, 1500: 0.5, 2000: 0, 3000: 0},
        )
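For reference, a minimal standalone sketch of the schedule behaviour these assertions describe (an illustration of the expected semantics, not the fairseq implementation): the string lists step:value breakpoints, the value is interpolated linearly between breakpoints and clamped to the boundary values outside the covered range.

# illustrative sketch only; obt.PiecewiseLinearFn is the real implementation
def piecewise_linear(pieces, x):
    # pieces: sorted (step, value) breakpoints, e.g. [(0, 1.0), (1000, 0.0)]
    if x <= pieces[0][0]:
        return pieces[0][1]
    if x >= pieces[-1][0]:
        return pieces[-1][1]
    for (x0, y0), (x1, y1) in zip(pieces, pieces[1:]):
        if x0 <= x <= x1:
            return y0 + (y1 - y0) * (x - x0) / (x1 - x0)

assert piecewise_linear([(0, 1.0), (1000, 0.0)], 500) == 0.5
assert piecewise_linear([(0, 1.0), (1000, 0.0)], 2000) == 0.0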
314
tests/test_roberta.py
Normal file
@ -0,0 +1,314 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import unittest
from typing import Any, Dict, Sequence

import fairseq
import fairseq.options
import fairseq.tasks
import torch
from tests.utils import dummy_dictionary

VOCAB_SIZE = 100


@fairseq.tasks.register_task("fake_task")
class FakeTask(fairseq.tasks.LegacyFairseqTask):
    def __init__(self, args):
        super().__init__(args)
        self.dictionary = dummy_dictionary(VOCAB_SIZE - 4)
        assert len(self.dictionary) == VOCAB_SIZE

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary


@functools.lru_cache()
def get_toy_model(
    device: str,
    architecture: str = "roberta_enc_dec",
    **extra_args: Any,
):
    assert device in ("gpu", "cpu")
    kwargs = {
        "arch": architecture,
        # Use characteristic dimensions
        "encoder_layers": 3,
        "encoder_embed_dim": 12,
        "encoder_ffn_embed_dim": 14,
        "encoder_attention_heads": 4,
        "decoder_layers": 3,
        "decoder_embed_dim": 12,
        "decoder_ffn_embed_dim": 14,
        "decoder_attention_heads": 4,
        # Disable dropout so we have comparable tests.
        "dropout": 0,
        "attention_dropout": 0,
        "activation_dropout": 0,
        "encoder_layerdrop": 0,
        # required args
        "tokens_per_sample": 256,
        "data": "/tmp/test_roberta",
    }
    kwargs.update(extra_args)
    fake_task = FakeTask(kwargs)
    args = fairseq.options.get_args(
        task="online_backtranslation",
        mono_langs="en,ro",
        valid_lang_pairs="en-ro",
        **kwargs,
    )
    torch.manual_seed(0)
    model = fake_task.build_model(args)
    if device == "gpu":
        model.cuda()
    return fake_task, model

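Since get_toy_model is wrapped in functools.lru_cache, repeated calls with identical arguments reuse a single task/model pair; a small hypothetical usage from a test in this module would be:

task_a, model_a = get_toy_model("cpu", architecture="roberta_enc_dec")
task_b, model_b = get_toy_model("cpu", architecture="roberta_enc_dec")
assert model_a is model_b  # memoized by functools.lru_cache, so the same object is reused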
def mk_sample(
    lang: str, device: str, tok: Sequence[int] = None, batch_size: int = 2
) -> Dict[str, Any]:
    assert device in ("gpu", "cpu")
    if not tok:
        if lang == "en":
            tok = [10, 11, 12, 13, 14, 15, 2]
        else:
            tok = [20, 21, 22, 23, 24, 25, 26, 27, 2]

    batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
    if device == "gpu":
        batch = batch.cuda()
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor(
                [len(tok)] * batch_size, dtype=torch.long, device=batch.device
            ),
        },
        "target": batch[:, 1:],
    }
    return sample

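For orientation, the default English sample has seven tokens per sentence, so a hypothetical call would produce the following shapes:

sample = mk_sample("en", "cpu", batch_size=2)          # illustrative usage, not part of the diff
assert sample["net_input"]["src_tokens"].shape == (2, 7)
assert sample["net_input"]["src_lengths"].tolist() == [7, 7]
assert sample["target"].shape == (2, 6)                 # the target drops the leading token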
def cpu_gpu(fn):
    def helper(self):
        fn(self, "cpu")
        if torch.cuda.is_available():
            fn(self, "gpu")

    return helper


def architectures(fn):
    def helper(self):
        for arch in ["roberta_enc_dec", "transformer"]:
            fn(self, arch)

    return helper

class RobertaTest(unittest.TestCase):
    def assertTensorEqual(self, t1, t2, delta: float = 1e-6):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        if delta == 0.0:
            self.assertEqual(t1.ne(t2).long().sum(), 0)
        else:
            self.assertEqual(((t2 - t1).abs() > delta).long().sum(), 0)

    def assertSharing(self, model, link_groups: Sequence[Sequence[str]]):
        ids = {}
        for group in link_groups:
            group_ids = {name: id(params(model, name)) for name in group}
            shared_id = group_ids[group[0]]
            self.assertEqual(group_ids, {name: shared_id for name in group})
            self.assertNotIn(shared_id, ids)
            ids[shared_id] = group

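assertSharing treats parameter names as tied when they resolve to the very same Parameter object; a minimal standalone illustration of that identity check, independent of fairseq:

# illustrative sketch of PyTorch weight tying, not part of the diff
import torch.nn as nn

emb = nn.Embedding(100, 16)
proj = nn.Linear(16, 100, bias=False)
proj.weight = emb.weight  # tie the output projection to the embedding table

# tied parameters are literally one object, so comparing id() detects the sharing
assert proj.weight is emb.weight
assert id(proj.weight) == id(emb.weight)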
    def test_roberta_shared_params(self):
        _, roberta = get_toy_model("cpu", architecture="roberta")
        self.assertSharing(
            roberta,
            [
                [
                    "encoder.sentence_encoder.embed_tokens.weight",
                    "encoder.lm_head.weight",
                ]
            ],
        )

        _, roberta = get_toy_model(
            "cpu", architecture="roberta", untie_weights_roberta=True
        )
        self.assertSharing(
            roberta,
            [
                ["encoder.sentence_encoder.embed_tokens.weight"],
                ["encoder.lm_head.weight"],
            ],
        )

    def test_roberta_enc_dec_shared_params(self):
        # 3 distinct embeddings
        _, enc_dec = get_toy_model("cpu", architecture="roberta_enc_dec")
        self.assertSharing(
            enc_dec,
            [
                ["encoder.embed_tokens.weight"],
                ["decoder.embed_tokens.weight"],
                ["decoder.output_projection.weight"],
            ],
        )

        # 2 distinct embeddings, one for encoder, one for decoder
        _, enc_dec = get_toy_model(
            "cpu", architecture="roberta_enc_dec", share_decoder_input_output_embed=True
        )
        self.assertSharing(
            enc_dec,
            [
                ["encoder.embed_tokens.weight"],
                [
                    "decoder.embed_tokens.weight",
                    "decoder.output_projection.weight",
                ],
            ],
        )

        # shared embeddings
        _, enc_dec = get_toy_model(
            "cpu", architecture="roberta_enc_dec", share_all_embeddings=True
        )
        self.assertSharing(
            enc_dec,
            [
                [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                    "decoder.output_projection.weight",
                ]
            ],
        )

    def test_roberta_max_positions_is_correctly_set(self):
        device = "cpu"
        task, model = get_toy_model(device)
        max_pos = model.max_decoder_positions()
        self.assertEqual(max_pos, 256)
        self.assertEqual(max_pos, model.decoder.max_positions())
        self.assertEqual(max_pos, model.encoder.max_positions())
        self.assertEqual(max_pos, model.encoder.embed_positions.max_positions)

        sentence = [31 for _ in range(max_pos)]
        sample = mk_sample("en", device, sentence, batch_size=1)
        self.assertEqual(list(sample["net_input"]["src_lengths"]), [max_pos])
        self.assertEqual(len(sample["net_input"]["src_tokens"][0]), max_pos)
        x, _ = model.forward(**sample["net_input"])
        self.assertEqual(x.shape, (1, max_pos, VOCAB_SIZE))

    @cpu_gpu
    def test_roberta_forward_backward(self, device: str):
        _, model = get_toy_model(device)
        sample = mk_sample("en", device)
        en_tokens = sample["net_input"]["src_tokens"]
        (bs, l) = en_tokens.shape
        # Forward
        logits, _ = model(**sample["net_input"])
        self.assertEqual(logits.shape, (bs, l, VOCAB_SIZE))

        # Backward
        loss = logits.sum()
        loss.backward()

    @cpu_gpu
    def test_roberta_forward_backward_bs1(self, device: str):
        _, model = get_toy_model(device)
        sample = mk_sample("en", device, batch_size=1)
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        sample2 = mk_sample("ro", device, batch_size=1)
        o, _ = model.forward(**sample2["net_input"])
        loss += o.sum()
        loss.backward()

    @cpu_gpu
    def test_roberta_batching(self, device: str):
        """
        Checks that a batch of size 2 yields the same results as a batch of size 1, duplicated across the batch.
        """
        _, model = get_toy_model(device)
        sample = mk_sample("en", device, batch_size=1)
        slen = sample["net_input"]["src_lengths"][0]
        sample2 = mk_sample("en", device, batch_size=2)
        with torch.no_grad():
            z = model.encoder.forward(
                sample["net_input"]["src_tokens"], sample["net_input"]["src_lengths"]
            )
            z = z["encoder_out"][-1]
            logits, _ = model.forward(**sample["net_input"])

            z2 = model.encoder.forward(
                sample2["net_input"]["src_tokens"], sample2["net_input"]["src_lengths"]
            )
            z2 = z2["encoder_out"][-1]
            logits2, _ = model.forward(**sample2["net_input"])

        self.assertEqual(z.shape, (slen, 1, 12))
        self.assertEqual(z2.shape, (slen, 2, 12))
        self.assertTensorEqual(logits2[0], logits2[1])
        self.assertTensorEqual(logits[0], logits2[0])

    @cpu_gpu
    def test_roberta_incremental_decoder(self, device: str):
        """
        Checks that incremental decoding yields the same results as non-incremental decoding.
        """
        task, model = get_toy_model(device)

        en_sample = mk_sample("en", device)
        en_tokens = en_sample["net_input"]["src_tokens"]
        ro_sample = mk_sample("ro", device)
        ro_tokens = ro_sample["net_input"]["src_tokens"]

        en_enc = model.encoder.forward(
            en_tokens, src_lengths=en_sample["net_input"]["src_lengths"]
        )
        (bs, tgt_len) = ro_tokens.shape

        # Decode without incremental state
        ro_dec, _ = model.decoder.forward(ro_tokens, encoder_out=en_enc)
        self.assertEqual(ro_dec.shape, (bs, tgt_len, VOCAB_SIZE))
        self.assertTensorEqual(ro_dec[0], ro_dec[1])

        # Decode with incremental state
        inc_state = {}
        ro_dec_inc = []
        for l in range(tgt_len):
            ro, _ = model.decoder.forward(
                ro_tokens[:, : l + 1], encoder_out=en_enc, incremental_state=inc_state
            )
            self.assertEqual(ro.shape, (bs, 1, VOCAB_SIZE))
            ro_dec_inc.append(ro)

        for l in range(tgt_len):
            # Intra-batch
            self.assertTensorEqual(ro_dec_inc[l][0], ro_dec_inc[l][1])
            # Incremental vs non-incremental
            self.assertTensorEqual(ro_dec_inc[l][:, 0], ro_dec[:, l])

def params(model, name):
    if "." not in name:
        return getattr(model, name)

    prefix, name = name.split(".", 1)
    return params(getattr(model, prefix), name)
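params resolves a dotted parameter name by walking attributes one level at a time; a hypothetical call would look like:

weight = params(model, "decoder.embed_tokens.weight")  # equivalent to model.decoder.embed_tokens.weight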