migrate translation task (#1569)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1569

Test Plan:
Imported from OSS

tests, plus ran:

```
python fairseq_cli/train.py \
    ~/data/iwslt14.de-en \
    --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --eval-bleu \
    --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
    --eval-bleu-detok moses \
    --eval-bleu-remove-bpe \
    --eval-bleu-print-samples \
    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric
```
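
After the migration, these flags map onto fields of the new `TranslationConfig` dataclass introduced below. A rough sketch of driving the same options programmatically (illustrative only — it assumes the binarized data and `dict.*.txt` files already exist under the data directory):

```
from fairseq.tasks.translation import TranslationConfig, TranslationTask

# Mirror the eval-BLEU flags from the command line above.
cfg = TranslationConfig(
    data="~/data/iwslt14.de-en",
    eval_bleu=True,
    eval_bleu_args='{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}',
    eval_bleu_detok="moses",
    eval_bleu_remove_bpe="@@ ",
    eval_bleu_print_samples=True,
)
# setup_task() infers the language pair from the data directory and
# loads the source/target dictionaries from its first path.
task = TranslationTask.setup_task(cfg)
```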

Reviewed By: myleott

Differential Revision: D25967217

Pulled By: alexeib

fbshipit-source-id: 808f3cb0939fa13e1e05f39bfa02a7fb0b152940
Authored by alexeib on 2021-01-20 17:59:39 -08:00; committed by Facebook GitHub Bot
parent 9fc53d6217
commit 15867e1284
5 changed files with 246 additions and 183 deletions

File: translation_moe.py (examples/translation_moe)

@@ -3,16 +3,52 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import torch
from omegaconf import II
from fairseq import metrics, utils
from fairseq.dataclass import ChoiceEnum
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask
from fairseq.tasks.translation import TranslationConfig, TranslationTask
from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork
@register_task("translation_moe")
METHOD_CHOICES = ChoiceEnum(["sMoElp", "sMoEup", "hMoElp", "hMoEup"])
@dataclass
class TranslationMoEConfig(TranslationConfig):
method: METHOD_CHOICES = field(
default="hMoEup",
metadata={"help": "MoE method"},
)
num_experts: int = field(
default=3,
metadata={"help": "number of experts"},
)
mean_pool_gating_network: bool = field(
default=False,
metadata={"help": "use a simple mean-pooling gating network"},
)
mean_pool_gating_network_dropout: float = field(
default=0,
metadata={"help": "dropout for mean-pooling gating network"},
)
mean_pool_gating_network_encoder_dim: int = field(
default=0,
metadata={"help": "encoder output dim for mean-pooling gating network"},
)
gen_expert: int = field(
default=0,
metadata={"help": "which expert to use for generation"},
)
sentence_avg: bool = II("optimization.sentence_avg")
@register_task("translation_moe", dataclass=TranslationMoEConfig)
class TranslationMoETask(TranslationTask):
"""
Translation task for Mixture of Experts (MoE) models.
@@ -37,77 +73,60 @@ class TranslationMoETask(TranslationTask):
:prog:
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
TranslationTask.add_args(parser)
parser.add_argument('--method', default='hMoEup',
choices=['sMoElp', 'sMoEup', 'hMoElp', 'hMoEup'])
parser.add_argument('--num-experts', default=3, type=int, metavar='N',
help='number of experts')
parser.add_argument('--mean-pool-gating-network', action='store_true',
help='use a simple mean-pooling gating network')
parser.add_argument('--mean-pool-gating-network-dropout', type=float,
help='dropout for mean-pooling gating network')
parser.add_argument('--mean-pool-gating-network-encoder-dim', type=float,
help='encoder output dim for mean-pooling gating network')
parser.add_argument('--gen-expert', type=int, default=0,
help='which expert to use for generation')
# fmt: on
cfg: TranslationMoEConfig
def __init__(self, args, src_dict, tgt_dict):
if args.method == "sMoElp":
def __init__(self, cfg: TranslationMoEConfig, src_dict, tgt_dict):
if cfg.method == "sMoElp":
# soft MoE with learned prior
self.uniform_prior = False
self.hard_selection = False
elif args.method == "sMoEup":
elif cfg.method == "sMoEup":
# soft MoE with uniform prior
self.uniform_prior = True
self.hard_selection = False
elif args.method == "hMoElp":
elif cfg.method == "hMoElp":
# hard MoE with learned prior
self.uniform_prior = False
self.hard_selection = True
elif args.method == "hMoEup":
elif cfg.method == "hMoEup":
# hard MoE with uniform prior
self.uniform_prior = True
self.hard_selection = True
# add indicator tokens for each expert
for i in range(args.num_experts):
for i in range(cfg.num_experts):
# add to both dictionaries in case we're sharing embeddings
src_dict.add_symbol("<expert_{}>".format(i))
tgt_dict.add_symbol("<expert_{}>".format(i))
super().__init__(args, src_dict, tgt_dict)
super().__init__(cfg, src_dict, tgt_dict)
def build_model(self, args):
def build_model(self, cfg):
from fairseq import models
model = models.build_model(args, self)
model = models.build_model(cfg, self)
if not self.uniform_prior and not hasattr(model, "gating_network"):
if self.args.mean_pool_gating_network:
if getattr(args, "mean_pool_gating_network_encoder_dim", None):
encoder_dim = args.mean_pool_gating_network_encoder_dim
elif getattr(args, "encoder_embed_dim", None):
if self.cfg.mean_pool_gating_network:
if self.cfg.mean_pool_gating_network_encoder_dim > 0:
encoder_dim = self.cfg.mean_pool_gating_network_encoder_dim
elif getattr(cfg, "encoder_embed_dim", None):
# assume that encoder_embed_dim is the encoder's output dimension
encoder_dim = args.encoder_embed_dim
encoder_dim = cfg.encoder_embed_dim
else:
raise ValueError(
"Must specify --mean-pool-gating-network-encoder-dim"
)
if getattr(args, "mean_pool_gating_network_dropout", None):
dropout = args.mean_pool_gating_network_dropout
elif getattr(args, "dropout", None):
dropout = args.dropout
if self.cfg.mean_pool_gating_network_dropout > 0:
dropout = self.cfg.mean_pool_gating_network_dropout
elif getattr(cfg, "dropout", None):
dropout = cfg.dropout
else:
raise ValueError("Must specify --mean-pool-gating-network-dropout")
raise ValueError("Must specify task.mean_pool_gating_network_dropout")
model.gating_network = MeanPoolGatingNetwork(
encoder_dim,
args.num_experts,
self.cfg.num_experts,
dropout,
)
else:
@@ -125,7 +144,7 @@ class TranslationMoETask(TranslationTask):
criterion, "compute_loss"
), "translation_moe task requires the criterion to implement the compute_loss() method"
k = self.args.num_experts
k = self.cfg.num_experts
bsz = sample["target"].size(0)
def get_lprob_y(encoder_out, prev_output_tokens_k):
@@ -185,7 +204,7 @@ class TranslationMoETask(TranslationTask):
loss = loss.sum()
sample_size = (
sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
sample["target"].size(0) if self.cfg.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data),
@@ -221,7 +240,7 @@ class TranslationMoETask(TranslationTask):
expert=None,
constraints=None,
):
expert = expert or self.args.gen_expert
expert = expert or self.cfg.gen_expert
with torch.no_grad():
return generator.generate(
models,
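
One detail worth noting in this file's migration: `ChoiceEnum` replaces the old argparse `choices=` lists with a string-valued enum, which is why comparisons like `cfg.method == "sMoElp"` keep working against plain strings. A minimal sketch of the behavior being relied on:

```
from fairseq.dataclass import ChoiceEnum

METHOD_CHOICES = ChoiceEnum(["sMoElp", "sMoEup", "hMoElp", "hMoEup"])

# Members stringify to (and compare equal with) their names, so values
# round-trip cleanly between YAML/argparse and the dataclass.
m = METHOD_CHOICES.hMoEup
assert str(m) == "hMoEup"
assert m == "hMoEup"
```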

File: fairseq/tasks/translation.py

@@ -3,14 +3,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import itertools
import json
import logging
import os
from typing import Optional
from argparse import Namespace
from omegaconf import II
import numpy as np
from fairseq import metrics, options, utils
from fairseq import metrics, utils
from fairseq.data import (
AppendTokenDataset,
ConcatDataset,
@@ -22,7 +25,9 @@ from fairseq.data import (
encoders,
indexed_dataset,
)
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
EVAL_BLEU_ORDER = 4
@@ -161,8 +166,102 @@ def load_langpair_dataset(
)
@register_task("translation")
class TranslationTask(LegacyFairseqTask):
@dataclass
class TranslationConfig(FairseqDataclass):
data: Optional[str] = field(
default=None,
metadata={
"help": "colon separated path to data directories list, will be iterated upon during epochs "
"in round-robin manner; however, valid and test data are always in the first directory "
"to avoid the need for repeating them in all directories"
},
)
source_lang: Optional[str] = field(
default=None,
metadata={
"help": "source language",
"argparse_alias": "-s",
},
)
target_lang: Optional[str] = field(
default=None,
metadata={
"help": "target language",
"argparse_alias": "-t",
},
)
load_alignments: bool = field(
default=False, metadata={"help": "load the binarized alignments"}
)
left_pad_source: bool = field(
default=True, metadata={"help": "pad the source on the left"}
)
left_pad_target: bool = field(
default=False, metadata={"help": "pad the target on the left"}
)
max_source_positions: int = field(
default=1024, metadata={"help": "max number of tokens in the source sequence"}
)
max_target_positions: int = field(
default=1024, metadata={"help": "max number of tokens in the target sequence"}
)
upsample_primary: int = field(
default=-1, metadata={"help": "amount to upsample primary dataset"}
)
truncate_source: bool = field(
default=False, metadata={"help": "truncate source to max-source-positions"}
)
num_batch_buckets: int = field(
default=0,
metadata={
"help": "if >0, then bucket source and target lengths into "
"N buckets and pad accordingly; this is useful on TPUs to minimize the number of compilations"
},
)
train_subset: str = II("dataset.train_subset")
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
"dataset.dataset_impl"
)
required_seq_len_multiple: int = II("dataset.required_seq_len_multiple")
# options for reporting BLEU during validation
eval_bleu: bool = field(
default=False, metadata={"help": "evaluation with BLEU scores"}
)
eval_bleu_args: str = field(
default="{}",
metadata={
"help": 'generation args for BLUE scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
},
)
eval_bleu_detok: str = field(
default="space",
metadata={
"help": "detokenize before computing BLEU (e.g., 'moses'); required if using --eval-bleu; "
"use 'space' to disable detokenization; see fairseq.data.encoders for other options"
},
)
eval_bleu_detok_args: str = field(
default="{}",
metadata={"help": "args for building the tokenizer, if needed, as JSON string"},
)
eval_tokenized_bleu: bool = field(
default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
)
eval_bleu_remove_bpe: Optional[str] = field(
default=None,
metadata={
"help": "remove BPE before computing BLEU",
"argparse_const": "@@ ",
},
)
eval_bleu_print_samples: bool = field(
default=False, metadata={"help": "print sample generations during validation"}
)
@register_task("translation", dataclass=TranslationConfig)
class TranslationTask(FairseqTask):
"""
Translate from one (source) language to another (target) language.
@@ -174,108 +273,47 @@ class TranslationTask(LegacyFairseqTask):
The translation task is compatible with :mod:`fairseq-train`,
:mod:`fairseq-generate` and :mod:`fairseq-interactive`.
The translation task provides the following additional command-line
arguments:
.. argparse::
:ref: fairseq.tasks.translation_parser
:prog:
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
parser.add_argument('data', help='colon separated path to data directories list, \
will be iterated upon during epochs in round-robin manner; \
however, valid and test data are always in the first directory to \
avoid the need for repeating them in all directories')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language')
parser.add_argument('--load-alignments', action='store_true',
help='load the binarized alignments')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
help='pad the source on the left')
parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
help='pad the target on the left')
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
parser.add_argument('--upsample-primary', default=1, type=int,
help='amount to upsample primary dataset')
parser.add_argument('--truncate-source', action='store_true', default=False,
help='truncate source to max-source-positions')
parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N',
help='if >0, then bucket source and target lengths into N '
'buckets and pad accordingly; this is useful on TPUs '
'to minimize the number of compilations')
cfg: TranslationConfig
# options for reporting BLEU during validation
parser.add_argument('--eval-bleu', action='store_true',
help='evaluation with BLEU scores')
parser.add_argument('--eval-bleu-detok', type=str, default="space",
help='detokenize before computing BLEU (e.g., "moses"); '
'required if using --eval-bleu; use "space" to '
'disable detokenization; see fairseq.data.encoders '
'for other options')
parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
help='args for building the tokenizer, if needed')
parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
help='compute tokenized BLEU instead of sacrebleu')
parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE before computing BLEU')
parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
help='generation args for BLEU scoring, '
'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
parser.add_argument('--eval-bleu-print-samples', action='store_true',
help='print sample generations during validation')
# fmt: on
def __init__(self, args, src_dict, tgt_dict):
super().__init__(args)
def __init__(self, cfg: TranslationConfig, src_dict, tgt_dict):
super().__init__(cfg)
self.src_dict = src_dict
self.tgt_dict = tgt_dict
@classmethod
def setup_task(cls, args, **kwargs):
def setup_task(cls, cfg: TranslationConfig, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
"""
args.left_pad_source = utils.eval_bool(args.left_pad_source)
args.left_pad_target = utils.eval_bool(args.left_pad_target)
paths = utils.split_paths(args.data)
paths = utils.split_paths(cfg.data)
assert len(paths) > 0
# find language pair automatically
if args.source_lang is None or args.target_lang is None:
args.source_lang, args.target_lang = data_utils.infer_language_pair(
paths[0]
)
if args.source_lang is None or args.target_lang is None:
if cfg.source_lang is None or cfg.target_lang is None:
cfg.source_lang, cfg.target_lang = data_utils.infer_language_pair(paths[0])
if cfg.source_lang is None or cfg.target_lang is None:
raise Exception(
"Could not infer language pair, please provide it explicitly"
)
# load dictionaries
src_dict = cls.load_dictionary(
os.path.join(paths[0], "dict.{}.txt".format(args.source_lang))
os.path.join(paths[0], "dict.{}.txt".format(cfg.source_lang))
)
tgt_dict = cls.load_dictionary(
os.path.join(paths[0], "dict.{}.txt".format(args.target_lang))
os.path.join(paths[0], "dict.{}.txt".format(cfg.target_lang))
)
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
logger.info("[{}] dictionary: {} types".format(args.source_lang, len(src_dict)))
logger.info("[{}] dictionary: {} types".format(args.target_lang, len(tgt_dict)))
logger.info("[{}] dictionary: {} types".format(cfg.source_lang, len(src_dict)))
logger.info("[{}] dictionary: {} types".format(cfg.target_lang, len(tgt_dict)))
return cls(args, src_dict, tgt_dict)
return cls(cfg, src_dict, tgt_dict)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
@@ -283,15 +321,15 @@ class TranslationTask(LegacyFairseqTask):
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = utils.split_paths(self.args.data)
paths = utils.split_paths(self.cfg.data)
assert len(paths) > 0
if split != getattr(self.args, "train_subset", None):
if split != self.cfg.train_subset:
# if not training data set, use the first shard for valid and test
paths = paths[:1]
data_path = paths[(epoch - 1) % len(paths)]
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
src, tgt = self.cfg.source_lang, self.cfg.target_lang
self.datasets[split] = load_langpair_dataset(
data_path,
@@ -301,17 +339,17 @@ class TranslationTask(LegacyFairseqTask):
tgt,
self.tgt_dict,
combine=combine,
dataset_impl=self.args.dataset_impl,
upsample_primary=self.args.upsample_primary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
load_alignments=self.args.load_alignments,
truncate_source=self.args.truncate_source,
num_buckets=self.args.num_batch_buckets,
dataset_impl=self.cfg.dataset_impl,
upsample_primary=self.cfg.upsample_primary,
left_pad_source=self.cfg.left_pad_source,
left_pad_target=self.cfg.left_pad_target,
max_source_positions=self.cfg.max_source_positions,
max_target_positions=self.cfg.max_target_positions,
load_alignments=self.cfg.load_alignments,
truncate_source=self.cfg.truncate_source,
num_buckets=self.cfg.num_batch_buckets,
shuffle=(split != "test"),
pad_to_multiple=self.args.required_seq_len_multiple,
pad_to_multiple=self.cfg.required_seq_len_multiple,
)
def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
@@ -323,22 +361,15 @@ class TranslationTask(LegacyFairseqTask):
constraints=constraints,
)
def build_model(self, args):
model = super().build_model(args)
if getattr(args, "eval_bleu", False):
assert getattr(args, "eval_bleu_detok", None) is not None, (
"--eval-bleu-detok is required if using --eval-bleu; "
"try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
"to disable detokenization, e.g., when using sentencepiece)"
)
detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}")
def build_model(self, cfg):
model = super().build_model(cfg)
if self.cfg.eval_bleu:
detok_args = json.loads(self.cfg.eval_bleu_detok_args)
self.tokenizer = encoders.build_tokenizer(
Namespace(
tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args
)
Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args)
)
gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}")
gen_args = json.loads(self.cfg.eval_bleu_args)
self.sequence_generator = self.build_generator(
[model], Namespace(**gen_args)
)
@@ -346,7 +377,7 @@ class TranslationTask(LegacyFairseqTask):
def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
if self.args.eval_bleu:
if self.cfg.eval_bleu:
bleu = self._inference_with_bleu(self.sequence_generator, sample, model)
logging_output["_bleu_sys_len"] = bleu.sys_len
logging_output["_bleu_ref_len"] = bleu.ref_len
@@ -360,7 +391,7 @@ class TranslationTask(LegacyFairseqTask):
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
if self.args.eval_bleu:
if self.cfg.eval_bleu:
def sum_logs(key):
return sum(log.get(key, 0) for log in logging_outputs)
@@ -399,7 +430,7 @@ class TranslationTask(LegacyFairseqTask):
def max_positions(self):
"""Return the max sentence length allowed by the task."""
return (self.args.max_source_positions, self.args.max_target_positions)
return (self.cfg.max_source_positions, self.cfg.max_target_positions)
@property
def source_dictionary(self):
@@ -417,7 +448,7 @@ class TranslationTask(LegacyFairseqTask):
def decode(toks, escape_unk=False):
s = self.tgt_dict.string(
toks.int().cpu(),
self.args.eval_bleu_remove_bpe,
self.cfg.eval_bleu_remove_bpe,
# The default unknown string in fairseq is `<unk>`, but
# this is tokenized by sacrebleu as `< unk >`, inflating
# BLEU scores. Instead, we use a somewhat more verbose
@@ -439,10 +470,10 @@ class TranslationTask(LegacyFairseqTask):
escape_unk=True, # don't count <unk> as matches to the hypo
)
)
if self.args.eval_bleu_print_samples:
if self.cfg.eval_bleu_print_samples:
logger.info("example hypothesis: " + hyps[0])
logger.info("example reference: " + refs[0])
if self.args.eval_tokenized_bleu:
if self.cfg.eval_tokenized_bleu:
return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none")
else:
return sacrebleu.corpus_bleu(hyps, [refs])
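
Fields like `train_subset: str = II("dataset.train_subset")` use OmegaConf interpolation so the task config tracks the top-level `dataset` group instead of duplicating it. A self-contained sketch of the mechanism (the config classes here are illustrative, not fairseq's):

```
from dataclasses import dataclass, field
from omegaconf import II, OmegaConf

@dataclass
class DatasetConfig:
    train_subset: str = "train"

@dataclass
class TaskConfig:
    # II("...") stores the interpolation string "${dataset.train_subset}"
    train_subset: str = II("dataset.train_subset")

@dataclass
class RootConfig:
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    task: TaskConfig = field(default_factory=TaskConfig)

cfg = OmegaConf.structured(RootConfig)
cfg.dataset.train_subset = "train_small"
assert cfg.task.train_subset == "train_small"  # resolved at access time
```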

File: fairseq/tasks/translation_from_pretrained_xlm.py

@@ -3,13 +3,21 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
from fairseq.tasks.translation import TranslationTask
from fairseq.tasks.translation import TranslationConfig, TranslationTask
from . import register_task
@register_task("translation_from_pretrained_xlm")
@dataclass
class TranslationFromPretrainedXLMConfig(TranslationConfig):
pass
@register_task(
"translation_from_pretrained_xlm", dataclass=TranslationFromPretrainedXLMConfig
)
class TranslationFromPretrainedXLMTask(TranslationTask):
"""
Same as TranslationTask except use the MaskedLMDictionary class so that

File: fairseq/tasks/translation_lev.py

@@ -3,33 +3,35 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from dataclasses import dataclass, field
import torch
from fairseq import utils
from fairseq.data import LanguagePairDataset
from fairseq.dataclass import ChoiceEnum
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask, load_langpair_dataset
from fairseq.tasks.translation import TranslationConfig, TranslationTask, load_langpair_dataset
from fairseq.utils import new_arange
@register_task("translation_lev")
NOISE_CHOICES = ChoiceEnum(["random_delete", "random_mask", "no_noise", "full_mask"])
@dataclass
class TranslationLevenshteinConfig(TranslationConfig):
noise: NOISE_CHOICES = field(
default="random_delete",
metadata={
"help": "type of noise"
},
)
@register_task("translation_lev", dataclass=TranslationLevenshteinConfig)
class TranslationLevenshteinTask(TranslationTask):
"""
Translation (Sequence Generation) task for Levenshtein Transformer
See `"Levenshtein Transformer" <https://arxiv.org/abs/1905.11006>`_.
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
TranslationTask.add_args(parser)
parser.add_argument(
'--noise',
default='random_delete',
choices=['random_delete', 'random_mask', 'no_noise', 'full_mask'])
# fmt: on
cfg: TranslationLevenshteinConfig
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
@@ -37,12 +39,12 @@ class TranslationLevenshteinTask(TranslationTask):
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = utils.split_paths(self.args.data)
paths = utils.split_paths(self.cfg.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
src, tgt = self.cfg.source_lang, self.cfg.target_lang
self.datasets[split] = load_langpair_dataset(
data_path,
@@ -52,12 +54,12 @@ class TranslationLevenshteinTask(TranslationTask):
tgt,
self.tgt_dict,
combine=combine,
dataset_impl=self.args.dataset_impl,
upsample_primary=self.args.upsample_primary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
dataset_impl=self.cfg.dataset_impl,
upsample_primary=self.cfg.upsample_primary,
left_pad_source=self.cfg.left_pad_source,
left_pad_target=self.cfg.left_pad_target,
max_source_positions=self.cfg.max_source_positions,
max_target_positions=self.cfg.max_target_positions,
prepend_bos=True,
)
@@ -133,13 +135,13 @@ class TranslationLevenshteinTask(TranslationTask):
)
return target_tokens.masked_fill(~target_mask, unk)
if self.args.noise == "random_delete":
if self.cfg.noise == "random_delete":
return _random_delete(target_tokens)
elif self.args.noise == "random_mask":
elif self.cfg.noise == "random_mask":
return _random_mask(target_tokens)
elif self.args.noise == "full_mask":
elif self.cfg.noise == "full_mask":
return _full_mask(target_tokens)
elif self.args.noise == "no_noise":
elif self.cfg.noise == "no_noise":
return target_tokens
else:
raise NotImplementedError
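
The Levenshtein task illustrates the general recipe this commit establishes for downstream tasks: subclass `TranslationConfig`, declare any extra fields, and hand the dataclass to `register_task`. A hypothetical example following the same conventions (the task name and `prepend_lang_tag` field are invented for illustration):

```
from dataclasses import dataclass, field

from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationConfig, TranslationTask

@dataclass
class MyTranslationConfig(TranslationConfig):
    # hypothetical knob, present only to show the pattern
    prepend_lang_tag: bool = field(
        default=False,
        metadata={"help": "prepend a language-ID token to the source"},
    )

@register_task("my_translation", dataclass=MyTranslationConfig)
class MyTranslationTask(TranslationTask):
    cfg: MyTranslationConfig
```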

File: tests/test_checkpoint_utils.py

@@ -53,7 +53,7 @@ class TestCheckpointUtils(unittest.TestCase):
yield os.path.join(data_dir, "checkpoint_last.pt")
def test_load_model_ensemble_and_task(self):
with contextlib.redirect_stdout(StringIO()):
# with contextlib.redirect_stdout(StringIO()):
with self._train_transformer(seed=123) as model1:
with self._train_transformer(seed=456) as model2:
ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
@@ -67,7 +67,10 @@ class TestCheckpointUtils(unittest.TestCase):
self.assertEqual(ensemble[1].args.seed, 456)
# the task from the first model should be returned
self.assertEqual(task.args.seed, 123)
self.assertTrue("seed123" in task.cfg.data)
# last cfg is saved
self.assertEqual(cfg.common.seed, 456)
def test_prune_state_dict(self):
with contextlib.redirect_stdout(StringIO()):
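
The new assertions above reflect that checkpoints now round-trip the full structured config rather than a flat `args` namespace; roughly, callers see it like this (a sketch — the checkpoint path is a placeholder):

```
from fairseq import checkpoint_utils

# Returns the models, the last-saved structured config, and the task
# rebuilt from the first checkpoint's config.
ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    ["/path/to/checkpoint_last.pt"]
)
print(cfg.common.seed, cfg.task.data)
```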