switch denoising and multilingual_denoising tasks to OmegaConf (#4447)

Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
Alexander Jipa 2022-06-28 15:44:18 -04:00 committed by GitHub
parent fe56de410c
commit a6a6327942
5 changed files with 439 additions and 204 deletions


@@ -107,7 +107,6 @@ class DenoisingDataset(FairseqDataset):
shuffle (bool, optional): shuffle the elements before batching.
Default: ``True``
seed: Seed for random number generator for reproducibility.
args: argparse arguments.
"""
def __init__(
@@ -119,7 +118,15 @@ class DenoisingDataset(FairseqDataset):
mask_whole_words,
shuffle,
seed,
args,
mask,
mask_random,
insert,
rotate,
permute_sentences,
bpe,
replace_length,
mask_length,
poisson_lambda,
eos=None,
item_transform_func=None,
):
@@ -132,31 +139,31 @@ class DenoisingDataset(FairseqDataset):
self.seed = seed
self.mask_idx = mask_idx
self.mask_whole_word = mask_whole_words
self.mask_ratio = args.mask
self.random_ratio = args.mask_random
self.insert_ratio = args.insert
self.rotate_ratio = args.rotate
self.permute_sentence_ratio = args.permute_sentences
self.mask_ratio = mask
self.random_ratio = mask_random
self.insert_ratio = insert
self.rotate_ratio = rotate
self.permute_sentence_ratio = permute_sentences
self.eos = eos if eos is not None else vocab.eos()
self.item_transform_func = item_transform_func
if args.bpe != "gpt2":
if bpe != "gpt2":
self.full_stop_index = self.vocab.eos()
else:
assert args.bpe == "gpt2"
assert bpe == "gpt2"
self.full_stop_index = self.vocab.index("13")
self.replace_length = args.replace_length
self.replace_length = replace_length
if self.replace_length not in [-1, 0, 1]:
raise ValueError(f"invalid arg: replace_length={self.replace_length}")
if args.mask_length not in ["subword", "word", "span-poisson"]:
raise ValueError(f"invalid arg: mask-length={args.mask_length}")
if args.mask_length == "subword" and args.replace_length not in [0, 1]:
if mask_length not in ["subword", "word", "span-poisson"]:
raise ValueError(f"invalid arg: mask-length={mask_length}")
if mask_length == "subword" and replace_length not in [0, 1]:
raise ValueError(f"if using subwords, use replace-length=1 or 0")
self.mask_span_distribution = None
if args.mask_length == "span-poisson":
_lambda = args.poisson_lambda
if mask_length == "span-poisson":
_lambda = poisson_lambda
lambda_to_the_k = 1
e_to_the_minus_lambda = math.exp(-_lambda)

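The constructor change above also surfaces poisson_lambda as an explicit argument: when mask_length == "span-poisson", the dataset builds a truncated Poisson distribution over mask span lengths, which is what the lambda_to_the_k and e_to_the_minus_lambda terms at the end of the hunk begin to compute. Below is a minimal standalone sketch of that pmf computation; the helper name and the cutoff of 16 span lengths are illustrative assumptions, not values taken from fairseq.

import math


def poisson_span_pmf(_lambda: float, max_span_len: int = 16):
    """Truncated, renormalized Poisson pmf over span lengths 0..max_span_len-1.

    Illustrates the span-length distribution DenoisingDataset starts to build
    when mask_length == "span-poisson".
    """
    e_to_the_minus_lambda = math.exp(-_lambda)
    lambda_to_the_k = 1.0
    k_factorial = 1.0
    ps = []
    for k in range(max_span_len):
        # Poisson pmf at k: exp(-lambda) * lambda**k / k!
        ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
        lambda_to_the_k *= _lambda
        k_factorial *= k + 1
    total = sum(ps)
    return [p / total for p in ps]


if __name__ == "__main__":
    # poisson_lambda defaults to 3.0 in DenoisingConfig
    print([round(p, 3) for p in poisson_span_pmf(3.0)[:6]])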

@@ -5,6 +5,11 @@
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Optional
import numpy as np
from omegaconf import II, MISSING
from fairseq import utils
from fairseq.data import (
@@ -22,145 +27,142 @@ from fairseq.data import (
)
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.tasks import LegacyFairseqTask, register_task
import numpy as np
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from ..data.indexed_dataset import get_available_dataset_impl
logger = logging.getLogger(__name__)
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
MASK_LENGTH_CHOICES = ChoiceEnum(["subword", "word", "span-poisson"])
@register_task("denoising")
class DenoisingTask(LegacyFairseqTask):
@dataclass
class DenoisingConfig(FairseqDataclass):
data: str = field(
default=MISSING,
metadata={"help": "path to data directory"},
)
bpe: Optional[str] = field(
default=None,
metadata={"help": "TODO"},
)
tokens_per_sample: int = field(
default=512,
metadata={
"help": "max number of total tokens over all segments "
"per sample for dataset"
},
)
sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
default="complete_doc",
metadata={
"help": 'If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
"of sentence, but may include multiple sentences per sample. "
'"complete_doc" is similar but respects doc boundaries. '
'If set to "eos", includes only one sentence per sample.'
},
)
replace_length: int = field(
default=-1,
metadata={
"help": "when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)"
},
)
mask: float = field(
default=0.0,
metadata={"help": "fraction of words/subwords that will be masked"},
)
mask_random: float = field(
default=0.0,
metadata={"help": "instead of using [MASK], use random token this often"},
)
insert: float = field(
default=0.0,
metadata={"help": "insert this percentage of additional random tokens"},
)
permute: float = field(
default=0.0,
metadata={"help": "take this proportion of subwords and permute them"},
)
rotate: float = field(
default=0.5,
metadata={"help": "rotate this proportion of inputs"},
)
poisson_lambda: float = field(
default=3.0,
metadata={"help": "lambda for the Poisson span-length distribution (used with span-poisson mask-length)"},
)
shuffle_instance: float = field(
default=0.0,
metadata={"help": "shuffle the elements of each dataset before batching"},
)
mask_length: MASK_LENGTH_CHOICES = field(
default="subword",
metadata={"help": "mask length to choose"},
)
permute_sentences: float = field(
default=0.0,
metadata={"help": "shuffle this proportion of sentences in all inputs"},
)
seed: int = II("common.seed")
shorten_method: SHORTEN_METHOD_CHOICES = field(
default="none",
metadata={
"help": "if not none, shorten sequences that exceed --tokens-per-sample"
},
)
shorten_data_split_list: str = field(
default="",
metadata={
"help": "comma-separated list of dataset splits to apply shortening to, "
'e.g., "train,valid" (default: all dataset splits)'
},
)
max_source_positions: int = field(
default=1024,
metadata={"help": "max number of tokens in the source sequence"},
)
max_target_positions: int = field(
default=1024,
metadata={"help": "max number of tokens in the target sequence"},
)
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
"dataset.dataset_impl"
)
@register_task("denoising", dataclass=DenoisingConfig)
class DenoisingTask(FairseqTask):
"""
Denoising task for applying sequence-to-sequence denoising (i.e., BART).
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument("data", help="path to data directory")
parser.add_argument(
"--tokens-per-sample",
default=512,
type=int,
help="max number of total tokens over all segments"
" per sample for dataset",
)
parser.add_argument(
"--sample-break-mode",
default="complete_doc",
type=str,
help="mode for breaking sentence",
)
parser.add_argument(
"--mask",
default=0.0,
type=float,
help="fraction of words/subwords that will be masked",
)
parser.add_argument(
"--mask-random",
default=0.0,
type=float,
help="instead of using [MASK], use random token this often",
)
parser.add_argument(
"--insert",
default=0.0,
type=float,
help="insert this percentage of additional random tokens",
)
parser.add_argument(
"--permute",
default=0.0,
type=float,
help="take this proportion of subwords and permute them",
)
parser.add_argument(
"--rotate",
default=0.5,
type=float,
help="rotate this proportion of inputs",
)
parser.add_argument(
"--poisson-lambda",
default=3.0,
type=float,
help="randomly shuffle sentences for this proportion of inputs",
)
parser.add_argument(
"--permute-sentences",
default=0.0,
type=float,
help="shuffle this proportion of sentences in all inputs",
)
parser.add_argument(
"--mask-length",
default="subword",
type=str,
choices=["subword", "word", "span-poisson"],
help="mask length to choose",
)
parser.add_argument(
"--replace-length",
default=-1,
type=int,
help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)",
)
parser.add_argument(
"--max-source-positions",
default=1024,
type=int,
metavar="N",
help="max number of tokens in the source sequence",
)
parser.add_argument(
"--max-target-positions",
default=1024,
type=int,
metavar="N",
help="max number of tokens in the target sequence",
)
cfg: DenoisingConfig
parser.add_argument(
"--shorten-method",
default="none",
choices=["none", "truncate", "random_crop"],
help="if not none, shorten sequences that exceed --tokens-per-sample",
)
parser.add_argument(
"--shorten-data-split-list",
default="",
help="comma-separated list of dataset splits to apply shortening to, "
'e.g., "train,valid" (default: all dataset splits)',
)
def __init__(self, args, dictionary):
super().__init__(args)
def __init__(self, cfg, dictionary):
super().__init__(cfg)
self.dictionary = dictionary
self.seed = args.seed
# add mask token
self.mask_idx = self.dictionary.add_symbol("<mask>")
@classmethod
def setup_task(cls, args, **kwargs):
def setup_task(cls, cfg: DenoisingConfig, **kwargs):
"""Setup the task."""
paths = utils.split_paths(args.data)
paths = utils.split_paths(cfg.data)
assert len(paths) > 0
dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
logger.info("dictionary: {} types".format(len(dictionary)))
if not hasattr(args, "shuffle_instance"):
args.shuffle_instance = False
return cls(args, dictionary)
if not hasattr(cfg, "shuffle_instance"):
cfg.shuffle_instance = False
return cls(cfg, dictionary)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = utils.split_paths(self.args.data)
def _load_dataset_split(self, split, epoch, combine):
paths = utils.split_paths(self.cfg.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
@@ -168,7 +170,7 @@ class DenoisingTask(LegacyFairseqTask):
dataset = data_utils.load_indexed_dataset(
split_path,
self.dictionary,
self.args.dataset_impl,
self.cfg.dataset_impl,
combine=combine,
)
if dataset is None:
@@ -181,20 +183,21 @@ class DenoisingTask(LegacyFairseqTask):
dataset = maybe_shorten_dataset(
dataset,
split,
self.args.shorten_data_split_list,
self.args.shorten_method,
self.args.tokens_per_sample,
self.args.seed,
self.cfg.shorten_data_split_list,
self.cfg.shorten_method,
self.cfg.tokens_per_sample,
self.cfg.seed,
)
# create continuous blocks of tokens
dataset = TokenBlockDataset(
dataset,
dataset.sizes,
self.args.tokens_per_sample - 2, # one less for <s> and one for </s>
self.cfg.tokens_per_sample - 2,
# one less for <s> and one for </s>
pad=self.dictionary.pad(),
eos=self.dictionary.eos(),
break_mode=self.args.sample_break_mode,
break_mode=self.cfg.sample_break_mode,
document_sep_len=0,
)
logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))
@@ -202,10 +205,19 @@ class DenoisingTask(LegacyFairseqTask):
# prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())
return dataset
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
dataset = self._load_dataset_split(split, epoch, combine)
mask_whole_words = (
get_whole_word_mask(self.args, self.source_dictionary)
if self.args.mask_length != "subword"
get_whole_word_mask(self.cfg.bpe, self.source_dictionary)
if self.cfg.mask_length != "subword"
else None
)
@@ -215,9 +227,17 @@ class DenoisingTask(LegacyFairseqTask):
self.dictionary,
self.mask_idx,
mask_whole_words,
shuffle=self.args.shuffle_instance,
seed=self.seed,
args=self.args,
shuffle=self.cfg.shuffle_instance,
seed=self.cfg.seed,
mask=self.cfg.mask,
mask_random=self.cfg.mask_random,
insert=self.cfg.insert,
rotate=self.cfg.rotate,
permute_sentences=self.cfg.permute_sentences,
bpe=self.cfg.bpe,
replace_length=self.cfg.replace_length,
mask_length=self.cfg.mask_length,
poisson_lambda=self.cfg.poisson_lambda,
)
logger.info(
"Split: {0}, Loaded {1} samples of denoising_dataset".format(
@@ -236,10 +256,10 @@ class DenoisingTask(LegacyFairseqTask):
src_dataset = TokenBlockDataset(
src_tokens,
src_lengths,
block_size=self.args.tokens_per_sample - 2, # for <s> and </s>
block_size=self.cfg.tokens_per_sample - 2, # for <s> and </s>
pad=pad,
eos=eos,
break_mode=self.args.sample_break_mode,
break_mode=self.cfg.sample_break_mode,
document_sep_len=0,
)
prev_output_tokens = PrependTokenDataset(
@@ -263,7 +283,7 @@ class DenoisingTask(LegacyFairseqTask):
def max_positions(self):
"""Return the max sentence length allowed by the task."""
return (self.args.max_source_positions, self.args.max_target_positions)
return (self.cfg.max_source_positions, self.cfg.max_target_positions)
@property
def source_dictionary(self):

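The DenoisingConfig dataclass above relies on OmegaConf structured configs: MISSING marks a field that must be supplied before use, and II("common.seed") / II("dataset.dataset_impl") are interpolations that resolve against other groups of the full fairseq config at access time. Here is a small self-contained sketch of that mechanism, assuming omegaconf is installed; the toy class and group names are illustrative, not the fairseq ones.

from dataclasses import dataclass, field
from omegaconf import II, MISSING, OmegaConf


@dataclass
class ToyCommonConfig:
    seed: int = 1


@dataclass
class ToyTaskConfig:
    data: str = MISSING              # must be set before the task can run
    mask: float = 0.0
    seed: int = II("common.seed")    # resolved against the "common" group


@dataclass
class ToyRootConfig:
    common: ToyCommonConfig = field(default_factory=ToyCommonConfig)
    task: ToyTaskConfig = field(default_factory=ToyTaskConfig)


cfg = OmegaConf.structured(ToyRootConfig)
cfg.common.seed = 42
cfg.task.data = "/toy/data"
assert cfg.task.seed == 42           # the II interpolation follows common.seed
print(OmegaConf.to_yaml(cfg))

With this in place the task reads cfg.seed and cfg.dataset_impl directly, instead of copying values between argparse namespaces as the removed add_args/LegacyFairseqTask path did.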

@@ -2,11 +2,13 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
from omegaconf import II
from fairseq.data import (
AppendTokenDataset,
@@ -22,43 +24,49 @@ from fairseq.data import (
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq.tasks import register_task
from .denoising import DenoisingTask
from .denoising import DenoisingConfig, DenoisingTask
logger = logging.getLogger(__name__)
@register_task("multilingual_denoising")
@dataclass
class MultilingualDenoisingConfig(DenoisingConfig):
multilang_sampling_alpha: float = field(
default=1.0,
metadata={"help": "smoothing alpha for sample ratios across multiple datasets"},
)
add_lang_token: bool = field(
default=False,
metadata={"help": ""},
)
langs: Optional[str] = field(
default=None,
metadata={"help": "language ids we are considering"},
)
no_whole_word_mask_langs: str = field(
default="",
metadata={
"help": "languages without spacing between words don't support whole word masking"
},
)
train_subset: str = II("common.train_subset")
valid_subset: str = II("common.valid_subset")
@register_task("multilingual_denoising", dataclass=MultilingualDenoisingConfig)
class MultilingualDenoisingTask(DenoisingTask):
@staticmethod
def add_args(parser):
DenoisingTask.add_args(parser)
parser.add_argument(
"--multilang-sampling-alpha",
type=float,
default=1.0,
help="smoothing alpha for sample ratios across multiple datasets",
)
parser.add_argument("--add-lang-token", default=False, action="store_true")
parser.add_argument(
"--langs", type=str, help="language ids we are considering", default=None
)
parser.add_argument(
"--no-whole-word-mask-langs",
type=str,
default="",
metavar="N",
help="languages without spacing between words dont support whole word masking",
)
cfg: MultilingualDenoisingConfig
@classmethod
def setup_task(cls, args, **kwargs):
def setup_task(cls, cfg: MultilingualDenoisingConfig, **kwargs):
"""Setup the task."""
paths = args.data.split(":")
paths = cfg.data.split(":")
assert len(paths) > 0
dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
data_path = paths[0]
if args.langs is None:
if cfg.langs is None:
languages = sorted(
[
name
@@ -67,34 +75,32 @@ class MultilingualDenoisingTask(DenoisingTask):
]
)
else:
languages = args.langs.split(",")
languages = cfg.langs.split(",")
if args.add_lang_token:
if cfg.add_lang_token:
for lang in languages:
dictionary.add_symbol("[{}]".format(lang))
logger.info("dictionary: {} types".format(len(dictionary)))
if not hasattr(args, "shuffle_instance"):
args.shuffle_instance = False
return cls(args, dictionary)
if not hasattr(cfg, "shuffle_instance"):
cfg.shuffle_instance = False
return cls(cfg, dictionary)
def __init__(self, args, dictionary):
super().__init__(args, dictionary)
def __init__(self, cfg: MultilingualDenoisingConfig, dictionary):
super().__init__(cfg, dictionary)
self.dictionary = dictionary
self.seed = args.seed
# add mask token
self.mask_idx = self.dictionary.add_symbol("<mask>")
self.langs = args.langs
self.args = args
self.cfg = cfg
def _get_sample_prob(self, dataset_lens):
"""
Get smoothed sampling porbability by languages. This helps low resource
Get smoothed sampling probability by languages. This helps low resource
languages by upsampling them.
"""
prob = dataset_lens / dataset_lens.sum()
smoothed_prob = prob**self.args.multilang_sampling_alpha
smoothed_prob = prob**self.cfg.multilang_sampling_alpha
smoothed_prob = smoothed_prob / smoothed_prob.sum()
return smoothed_prob
@@ -104,12 +110,12 @@ class MultilingualDenoisingTask(DenoisingTask):
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = self.args.data.split(":")
paths = self.cfg.data.split(":")
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
if self.langs is None:
if self.cfg.langs is None:
languages = sorted(
[
name
@@ -118,7 +124,7 @@ class MultilingualDenoisingTask(DenoisingTask):
]
)
else:
languages = self.langs.split(",")
languages = self.cfg.langs.split(",")
for name in languages:
p = os.path.join(data_path, name)
assert os.path.exists(p), "data not found: {}".format(p)
@@ -128,8 +134,8 @@ class MultilingualDenoisingTask(DenoisingTask):
"Language to id mapping: ", {lang: id for id, lang in enumerate(languages)}
)
mask_whole_words = get_whole_word_mask(self.args, self.dictionary)
language_without_segmentations = self.args.no_whole_word_mask_langs.split(",")
mask_whole_words = get_whole_word_mask(self.cfg.bpe, self.dictionary)
language_without_segmentations = self.cfg.no_whole_word_mask_langs.split(",")
lang_datasets = []
for language in languages:
split_path = os.path.join(data_path, language, split)
@@ -137,7 +143,7 @@ class MultilingualDenoisingTask(DenoisingTask):
dataset = data_utils.load_indexed_dataset(
split_path,
self.source_dictionary,
self.args.dataset_impl,
self.cfg.dataset_impl,
combine=combine,
)
if dataset is None:
@@ -147,7 +153,7 @@ class MultilingualDenoisingTask(DenoisingTask):
end_token = (
self.source_dictionary.index("[{}]".format(language))
if self.args.add_lang_token
if self.cfg.add_lang_token
else self.source_dictionary.eos()
)
@@ -155,10 +161,10 @@ class MultilingualDenoisingTask(DenoisingTask):
dataset = TokenBlockDataset(
dataset,
dataset.sizes,
self.args.tokens_per_sample - 2, # one less for <s>
self.cfg.tokens_per_sample - 2, # one less for <s>
pad=self.source_dictionary.pad(),
eos=end_token,
break_mode=self.args.sample_break_mode,
break_mode=self.cfg.sample_break_mode,
)
logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))
@@ -177,11 +183,19 @@ class MultilingualDenoisingTask(DenoisingTask):
self.dictionary,
self.mask_idx,
lang_mask_whole_words,
shuffle=self.args.shuffle_instance,
seed=self.seed,
args=self.args,
shuffle=self.cfg.shuffle_instance,
seed=self.cfg.seed,
mask=self.cfg.mask,
mask_random=self.cfg.mask_random,
insert=self.cfg.insert,
rotate=self.cfg.rotate,
permute_sentences=self.cfg.permute_sentences,
bpe=self.cfg.bpe,
replace_length=self.cfg.replace_length,
mask_length=self.cfg.mask_length,
poisson_lambda=self.cfg.poisson_lambda,
eos=None
if not self.args.add_lang_token
if not self.cfg.add_lang_token
else self.source_dictionary.index("[{}]".format(language)),
)
lang_datasets.append(lang_dataset)
@@ -195,7 +209,7 @@ class MultilingualDenoisingTask(DenoisingTask):
int(dataset_lengths.sum()),
)
)
if split == self.args.train_subset:
if split == self.cfg.train_subset:
# For train subset, additionally up or down sample languages.
sample_probs = self._get_sample_prob(dataset_lengths)
logger.info(
@@ -220,7 +234,7 @@ class MultilingualDenoisingTask(DenoisingTask):
ResamplingDataset(
lang_datasets[i],
size_ratio=size_ratio[i],
seed=self.args.seed,
seed=self.cfg.seed,
epoch=epoch,
replace=size_ratio[i] >= 1.0,
)
@@ -237,12 +251,12 @@ class MultilingualDenoisingTask(DenoisingTask):
lang_splits.append(split_name)
self.datasets[split_name] = lang_dataset
if split in self.args.valid_subset:
self.args.valid_subset = self.args.valid_subset.replace(
if split in self.cfg.valid_subset:
self.cfg.valid_subset = self.cfg.valid_subset.replace(
split, ",".join(lang_splits)
)
with data_utils.numpy_seed(self.args.seed + epoch):
with data_utils.numpy_seed(self.cfg.seed + epoch):
shuffle = np.random.permutation(len(dataset))
self.datasets[split] = SortDataset(

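_get_sample_prob, shown in the hunks above, applies temperature-style smoothing when resampling languages for the train subset: the empirical language proportions are raised to the power multilang_sampling_alpha and renormalized, so an alpha below 1.0 shifts probability mass toward low-resource languages. A small numeric sketch follows; the dataset sizes are made up for illustration and the helper simply mirrors the method.

import numpy as np


def get_sample_prob(dataset_lens: np.ndarray, alpha: float) -> np.ndarray:
    """Smoothed per-language sampling probability (mirrors _get_sample_prob)."""
    prob = dataset_lens / dataset_lens.sum()
    smoothed_prob = prob ** alpha
    return smoothed_prob / smoothed_prob.sum()


# hypothetical sizes: one high-resource and one low-resource language
lens = np.array([1_000_000.0, 10_000.0])
print(get_sample_prob(lens, alpha=1.0))  # ~[0.99, 0.01]: proportional sampling
print(get_sample_prob(lens, alpha=0.3))  # ~[0.80, 0.20]: low-resource upsampled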

@@ -0,0 +1,96 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from tempfile import TemporaryDirectory
from fairseq import options
from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.tasks.denoising import DenoisingTask
from tests.utils import build_vocab, make_data
class TestDenoising(unittest.TestCase):
def test_denoising(self):
with TemporaryDirectory() as dirname:
# prep input file
raw_file = os.path.join(dirname, "raw")
data = make_data(out_file=raw_file)
vocab = build_vocab(data)
# binarize
binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
split = "train"
bin_file = os.path.join(dirname, split)
dataset_impl = "mmap"
FileBinarizer.multiprocess_dataset(
input_file=raw_file,
binarizer=binarizer,
dataset_impl=dataset_impl,
vocab_size=len(vocab),
output_prefix=bin_file,
)
# setup task
train_args = options.parse_args_and_arch(
options.get_training_parser(),
[
"--task",
"denoising",
"--arch",
"bart_base",
"--seed",
"42",
"--mask-length",
"word",
"--permute-sentences",
"1",
"--rotate",
"0",
"--replace-length",
"-1",
"--mask",
"0.2",
dirname,
],
)
cfg = convert_namespace_to_omegaconf(train_args)
task = DenoisingTask(cfg.task, binarizer.dict)
# load datasets
original_dataset = task._load_dataset_split(bin_file, 1, False)
task.load_dataset(split)
masked_dataset = task.dataset(split)
iterator = task.get_batch_iterator(
dataset=masked_dataset,
max_tokens=65_536,
max_positions=4_096,
).next_epoch_itr(shuffle=False)
mask_index = task.source_dictionary.index("<mask>")
for batch in iterator:
for sample in range(len(batch)):
net_input = batch["net_input"]
masked_src_tokens = net_input["src_tokens"][sample]
masked_src_length = net_input["src_lengths"][sample]
masked_tgt_tokens = batch["target"][sample]
sample_id = batch["id"][sample]
original_tokens = original_dataset[sample_id]
original_tokens = original_tokens.masked_select(
masked_src_tokens[:masked_src_length] == mask_index
)
masked_tokens = masked_tgt_tokens.masked_select(
masked_src_tokens == mask_index
)
assert masked_tokens.equal(original_tokens)
if __name__ == "__main__":
unittest.main()

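The assertion at the end of the test checks the core contract of the denoising dataset under --replace-length -1: every position that was replaced by <mask> in the source must still carry the original token in the target. A tensor-only sketch of that masked_select comparison is below; the token values and the mask_index of 99 are arbitrary stand-ins.

import torch

mask_index = 99  # stands in for task.source_dictionary.index("<mask>")
original_tokens = torch.tensor([5, 6, 7, 8])
masked_src_tokens = torch.tensor([5, 99, 7, 99])  # positions 1 and 3 were masked
masked_tgt_tokens = torch.tensor([5, 6, 7, 8])    # target is the uncorrupted sequence

# tokens hidden from the model ...
hidden = original_tokens.masked_select(masked_src_tokens == mask_index)      # tensor([6, 8])
# ... must reappear at the same positions in the target
recovered = masked_tgt_tokens.masked_select(masked_src_tokens == mask_index)
assert hidden.equal(recovered)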

@@ -0,0 +1,98 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from tempfile import TemporaryDirectory
from fairseq import options
from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.tasks.multilingual_denoising import MultilingualDenoisingTask
from tests.utils import build_vocab, make_data
class TestMultilingualDenoising(unittest.TestCase):
def test_multilingual_denoising(self):
with TemporaryDirectory() as dirname:
# prep input file
lang_dir = os.path.join(dirname, "en")
os.mkdir(lang_dir)
raw_file = os.path.join(lang_dir, "raw")
data = make_data(out_file=raw_file)
vocab = build_vocab(data)
# binarize
binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
split = "train"
bin_file = os.path.join(lang_dir, split)
dataset_impl = "mmap"
FileBinarizer.multiprocess_dataset(
input_file=raw_file,
binarizer=binarizer,
dataset_impl=dataset_impl,
vocab_size=len(vocab),
output_prefix=bin_file,
)
# setup task
train_args = options.parse_args_and_arch(
options.get_training_parser(),
[
"--task",
"multilingual_denoising",
"--arch",
"bart_base",
"--seed",
"42",
"--mask-length",
"word",
"--permute-sentences",
"1",
"--rotate",
"0",
"--replace-length",
"-1",
"--mask",
"0.2",
dirname,
],
)
cfg = convert_namespace_to_omegaconf(train_args)
task = MultilingualDenoisingTask(cfg.task, binarizer.dict)
# load datasets
original_dataset = task._load_dataset_split(bin_file, 1, False)
task.load_dataset(split)
masked_dataset = task.dataset(split)
iterator = task.get_batch_iterator(
dataset=masked_dataset,
max_tokens=65_536,
max_positions=4_096,
).next_epoch_itr(shuffle=False)
mask_index = task.source_dictionary.index("<mask>")
for batch in iterator:
for sample in range(len(batch)):
net_input = batch["net_input"]
masked_src_tokens = net_input["src_tokens"][sample]
masked_src_length = net_input["src_lengths"][sample]
masked_tgt_tokens = batch["target"][sample]
sample_id = batch["id"][sample]
original_tokens = original_dataset[sample_id]
original_tokens = original_tokens.masked_select(
masked_src_tokens[:masked_src_length] == mask_index
)
masked_tokens = masked_tgt_tokens.masked_select(
masked_src_tokens == mask_index
)
assert masked_tokens.equal(original_tokens)
if __name__ == "__main__":
unittest.main()