diff --git a/fairseq/data/denoising_dataset.py b/fairseq/data/denoising_dataset.py
index bdb62c8d5..a900fc6f9 100644
--- a/fairseq/data/denoising_dataset.py
+++ b/fairseq/data/denoising_dataset.py
@@ -107,7 +107,6 @@ class DenoisingDataset(FairseqDataset):
         shuffle (bool, optional): shuffle the elements before batching.
             Default: ``True``
         seed: Seed for random number generator for reproducibility.
-        args: argparse arguments.
     """
 
     def __init__(
@@ -119,7 +118,15 @@ class DenoisingDataset(FairseqDataset):
         mask_whole_words,
         shuffle,
         seed,
-        args,
+        mask,
+        mask_random,
+        insert,
+        rotate,
+        permute_sentences,
+        bpe,
+        replace_length,
+        mask_length,
+        poisson_lambda,
         eos=None,
         item_transform_func=None,
     ):
@@ -132,31 +139,31 @@ class DenoisingDataset(FairseqDataset):
         self.seed = seed
         self.mask_idx = mask_idx
         self.mask_whole_word = mask_whole_words
-        self.mask_ratio = args.mask
-        self.random_ratio = args.mask_random
-        self.insert_ratio = args.insert
-        self.rotate_ratio = args.rotate
-        self.permute_sentence_ratio = args.permute_sentences
+        self.mask_ratio = mask
+        self.random_ratio = mask_random
+        self.insert_ratio = insert
+        self.rotate_ratio = rotate
+        self.permute_sentence_ratio = permute_sentences
         self.eos = eos if eos is not None else vocab.eos()
         self.item_transform_func = item_transform_func
 
-        if args.bpe != "gpt2":
+        if bpe != "gpt2":
             self.full_stop_index = self.vocab.eos()
         else:
-            assert args.bpe == "gpt2"
+            assert bpe == "gpt2"
             self.full_stop_index = self.vocab.index("13")
 
-        self.replace_length = args.replace_length
+        self.replace_length = replace_length
         if self.replace_length not in [-1, 0, 1]:
             raise ValueError(f"invalid arg: replace_length={self.replace_length}")
-        if args.mask_length not in ["subword", "word", "span-poisson"]:
-            raise ValueError(f"invalid arg: mask-length={args.mask_length}")
-        if args.mask_length == "subword" and args.replace_length not in [0, 1]:
+        if mask_length not in ["subword", "word", "span-poisson"]:
+            raise ValueError(f"invalid arg: mask-length={mask_length}")
+        if mask_length == "subword" and replace_length not in [0, 1]:
             raise ValueError(f"if using subwords, use replace-length=1 or 0")
 
         self.mask_span_distribution = None
-        if args.mask_length == "span-poisson":
-            _lambda = args.poisson_lambda
+        if mask_length == "span-poisson":
+            _lambda = poisson_lambda
 
             lambda_to_the_k = 1
             e_to_the_minus_lambda = math.exp(-_lambda)
diff --git a/fairseq/tasks/denoising.py b/fairseq/tasks/denoising.py
index 1d4f84c08..57b824d58 100644
--- a/fairseq/tasks/denoising.py
+++ b/fairseq/tasks/denoising.py
@@ -5,6 +5,11 @@
 
 import logging
 import os
+from dataclasses import dataclass, field
+from typing import Any, Optional
+
+import numpy as np
+from omegaconf import II, MISSING
 
 from fairseq import utils
 from fairseq.data import (
@@ -22,145 +27,142 @@ from fairseq.data import (
 )
 from fairseq.data.encoders.utils import get_whole_word_mask
 from fairseq.data.shorten_dataset import maybe_shorten_dataset
-from fairseq.tasks import LegacyFairseqTask, register_task
-import numpy as np
+from fairseq.dataclass import ChoiceEnum, FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+from ..data.indexed_dataset import get_available_dataset_impl
 
 logger = logging.getLogger(__name__)
 
+SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
+SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
+MASK_LENGTH_CHOICES = ChoiceEnum(["subword", "word", "span-poisson"])
 
-@register_task("denoising")
-class DenoisingTask(LegacyFairseqTask):
+
+@dataclass
+class DenoisingConfig(FairseqDataclass):
+    data: str = field(
+        default=MISSING,
+        metadata={"help": "path to data directory"},
+    )
+    bpe: Optional[str] = field(
+        default=None,
+        metadata={"help": "TODO"},
+    )
+    tokens_per_sample: int = field(
+        default=512,
+        metadata={
+            "help": "max number of total tokens over all segments "
+            "per sample for dataset"
+        },
+    )
+    sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
+        default="complete_doc",
+        metadata={
+            "help": 'If omitted or "none", fills each sample with tokens-per-sample '
+            'tokens. If set to "complete", splits samples only at the end '
+            "of sentence, but may include multiple sentences per sample. "
+            '"complete_doc" is similar but respects doc boundaries. '
+            'If set to "eos", includes only one sentence per sample.'
+        },
+    )
+    replace_length: int = field(
+        default=0,
+        metadata={"help": "TODO, should only allow -1, 0 and 1"},
+    )
+    mask: float = field(
+        default=0.0,
+        metadata={"help": "fraction of words/subwords that will be masked"},
+    )
+    mask_random: float = field(
+        default=0.0,
+        metadata={"help": "instead of using [MASK], use random token this often"},
+    )
+    insert: float = field(
+        default=0.0,
+        metadata={"help": "insert this percentage of additional random tokens"},
+    )
+    permute: float = field(
+        default=0.0,
+        metadata={"help": "take this proportion of subwords and permute them"},
+    )
+    rotate: float = field(
+        default=0.5,
+        metadata={"help": "rotate this proportion of inputs"},
+    )
+    poisson_lambda: float = field(
+        default=3.0,
+        metadata={"help": "randomly shuffle sentences for this proportion of inputs"},
+    )
+    shuffle_instance: float = field(
+        default=0.0,
+        metadata={"help": "shuffle this proportion of sentences in all inputs"},
+    )
+    mask_length: MASK_LENGTH_CHOICES = field(
+        default="subword",
+        metadata={"help": "mask length to choose"},
+    )
+    permute_sentences: int = field(
+        default=-1,
+        metadata={
+            "help": "when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)"
+        },
+    )
+    seed: int = II("common.seed")
+    shorten_method: SHORTEN_METHOD_CHOICES = field(
+        default="none",
+        metadata={
+            "help": "if not none, shorten sequences that exceed --tokens-per-sample"
+        },
+    )
+    shorten_data_split_list: str = field(
+        default="",
+        metadata={
+            "help": "comma-separated list of dataset splits to apply shortening to, "
+            'e.g., "train,valid" (default: all dataset splits)'
+        },
+    )
+    max_source_positions: int = field(
+        default=1024,
+        metadata={"help": "max number of tokens in the source sequence"},
+    )
+    max_target_positions: int = field(
+        default=1024,
+        metadata={"help": "max number of tokens in the target sequence"},
+    )
+    dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
+        "dataset.dataset_impl"
+    )
+
+
+@register_task("denoising", dataclass=DenoisingConfig)
+class DenoisingTask(FairseqTask):
     """
     Denoising task for applying sequence to sequence denoising. (ie. BART)
     """
 
-    @staticmethod
-    def add_args(parser):
-        """Add task-specific arguments to the parser."""
-        parser.add_argument("data", help="path to data directory")
-        parser.add_argument(
-            "--tokens-per-sample",
-            default=512,
-            type=int,
-            help="max number of total tokens over all segments"
-            " per sample for dataset",
-        )
-        parser.add_argument(
-            "--sample-break-mode",
-            default="complete_doc",
-            type=str,
-            help="mode for breaking sentence",
-        )
-        parser.add_argument(
-            "--mask",
-            default=0.0,
-            type=float,
-            help="fraction of words/subwords that will be masked",
-        )
-        parser.add_argument(
-            "--mask-random",
-            default=0.0,
-            type=float,
-            help="instead of using [MASK], use random token this often",
-        )
-        parser.add_argument(
-            "--insert",
-            default=0.0,
-            type=float,
-            help="insert this percentage of additional random tokens",
-        )
-        parser.add_argument(
-            "--permute",
-            default=0.0,
-            type=float,
-            help="take this proportion of subwords and permute them",
-        )
-        parser.add_argument(
-            "--rotate",
-            default=0.5,
-            type=float,
-            help="rotate this proportion of inputs",
-        )
-        parser.add_argument(
-            "--poisson-lambda",
-            default=3.0,
-            type=float,
-            help="randomly shuffle sentences for this proportion of inputs",
-        )
-        parser.add_argument(
-            "--permute-sentences",
-            default=0.0,
-            type=float,
-            help="shuffle this proportion of sentences in all inputs",
-        )
-        parser.add_argument(
-            "--mask-length",
-            default="subword",
-            type=str,
-            choices=["subword", "word", "span-poisson"],
-            help="mask length to choose",
-        )
-        parser.add_argument(
-            "--replace-length",
-            default=-1,
-            type=int,
-            help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)",
-        )
-        parser.add_argument(
-            "--max-source-positions",
-            default=1024,
-            type=int,
-            metavar="N",
-            help="max number of tokens in the source sequence",
-        )
-        parser.add_argument(
-            "--max-target-positions",
-            default=1024,
-            type=int,
-            metavar="N",
-            help="max number of tokens in the target sequence",
-        )
+    cfg: DenoisingConfig
 
-        parser.add_argument(
-            "--shorten-method",
-            default="none",
-            choices=["none", "truncate", "random_crop"],
-            help="if not none, shorten sequences that exceed --tokens-per-sample",
-        )
-        parser.add_argument(
-            "--shorten-data-split-list",
-            default="",
-            help="comma-separated list of dataset splits to apply shortening to, "
-            'e.g., "train,valid" (default: all dataset splits)',
-        )
-
-    def __init__(self, args, dictionary):
-        super().__init__(args)
+    def __init__(self, cfg, dictionary):
+        super().__init__(cfg)
         self.dictionary = dictionary
-        self.seed = args.seed
 
         # add mask token
         self.mask_idx = self.dictionary.add_symbol("<mask>")
 
     @classmethod
-    def setup_task(cls, args, **kwargs):
+    def setup_task(cls, cfg: DenoisingConfig, **kwargs):
         """Setup the task."""
-        paths = utils.split_paths(args.data)
+        paths = utils.split_paths(cfg.data)
         assert len(paths) > 0
         dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
         logger.info("dictionary: {} types".format(len(dictionary)))
-        if not hasattr(args, "shuffle_instance"):
-            args.shuffle_instance = False
-        return cls(args, dictionary)
+        if not hasattr(cfg, "shuffle_instance"):
+            cfg.shuffle_instance = False
+        return cls(cfg, dictionary)
 
-    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
-        """Load a given dataset split.
-
-        Args:
-            split (str): name of the split (e.g., train, valid, test)
-        """
-        paths = utils.split_paths(self.args.data)
+    def _load_dataset_split(self, split, epoch, combine):
+        paths = utils.split_paths(self.cfg.data)
         assert len(paths) > 0
         data_path = paths[(epoch - 1) % len(paths)]
         split_path = os.path.join(data_path, split)
@@ -168,7 +170,7 @@ class DenoisingTask(LegacyFairseqTask):
         dataset = data_utils.load_indexed_dataset(
             split_path,
             self.dictionary,
-            self.args.dataset_impl,
+            self.cfg.dataset_impl,
             combine=combine,
         )
         if dataset is None:
@@ -181,20 +183,21 @@ class DenoisingTask(LegacyFairseqTask):
         dataset = maybe_shorten_dataset(
             dataset,
             split,
-            self.args.shorten_data_split_list,
-            self.args.shorten_method,
-            self.args.tokens_per_sample,
-            self.args.seed,
+            self.cfg.shorten_data_split_list,
+            self.cfg.shorten_method,
+            self.cfg.tokens_per_sample,
+            self.cfg.seed,
         )
 
         # create continuous blocks of tokens
         dataset = TokenBlockDataset(
             dataset,
             dataset.sizes,
-            self.args.tokens_per_sample - 2,  # one less for <s> and one for </s>
+            self.cfg.tokens_per_sample - 2,
+            # one less for <s> and one for </s>
             pad=self.dictionary.pad(),
             eos=self.dictionary.eos(),
-            break_mode=self.args.sample_break_mode,
+            break_mode=self.cfg.sample_break_mode,
             document_sep_len=0,
         )
         logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))
@@ -202,10 +205,19 @@ class DenoisingTask(LegacyFairseqTask):
         # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
         dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
         dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())
+        return dataset
+
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+        """Load a given dataset split.
+
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+        dataset = self._load_dataset_split(split, epoch, combine)
 
         mask_whole_words = (
-            get_whole_word_mask(self.args, self.source_dictionary)
-            if self.args.mask_length != "subword"
+            get_whole_word_mask(self.cfg.bpe, self.source_dictionary)
+            if self.cfg.mask_length != "subword"
             else None
         )
 
@@ -215,9 +227,17 @@ class DenoisingTask(LegacyFairseqTask):
             self.dictionary,
             self.mask_idx,
             mask_whole_words,
-            shuffle=self.args.shuffle_instance,
-            seed=self.seed,
-            args=self.args,
+            shuffle=self.cfg.shuffle_instance,
+            seed=self.cfg.seed,
+            mask=self.cfg.mask,
+            mask_random=self.cfg.mask_random,
+            insert=self.cfg.insert,
+            rotate=self.cfg.rotate,
+            permute_sentences=self.cfg.permute_sentences,
+            bpe=self.cfg.bpe,
+            replace_length=self.cfg.replace_length,
+            mask_length=self.cfg.mask_length,
+            poisson_lambda=self.cfg.poisson_lambda,
         )
         logger.info(
             "Split: {0}, Loaded {1} samples of denoising_dataset".format(
@@ -236,10 +256,10 @@ class DenoisingTask(LegacyFairseqTask):
         src_dataset = TokenBlockDataset(
             src_tokens,
             src_lengths,
-            block_size=self.args.tokens_per_sample - 2,  # for <s> and </s>
+            block_size=self.cfg.tokens_per_sample - 2,  # for <s> and </s>
             pad=pad,
             eos=eos,
-            break_mode=self.args.sample_break_mode,
+            break_mode=self.cfg.sample_break_mode,
             document_sep_len=0,
         )
         prev_output_tokens = PrependTokenDataset(
@@ -263,7 +283,7 @@ class DenoisingTask(LegacyFairseqTask):
 
     def max_positions(self):
         """Return the max sentence length allowed by the task."""
-        return (self.args.max_source_positions, self.args.max_target_positions)
+        return (self.cfg.max_source_positions, self.cfg.max_target_positions)
 
     @property
     def source_dictionary(self):
diff --git a/fairseq/tasks/multilingual_denoising.py b/fairseq/tasks/multilingual_denoising.py
index 8226d9503..cb5ee3455 100644
--- a/fairseq/tasks/multilingual_denoising.py
+++ b/fairseq/tasks/multilingual_denoising.py
@@ -2,11 +2,13 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
 import logging
 import os
+from dataclasses import dataclass, field
+from typing import Optional
 
 import numpy as np
+from omegaconf import II
 
 from fairseq.data import (
     AppendTokenDataset,
@@ -22,43 +24,49 @@ from fairseq.data import (
 from fairseq.data.encoders.utils import get_whole_word_mask
 from fairseq.tasks import register_task
 
-from .denoising import DenoisingTask
+from .denoising import DenoisingConfig, DenoisingTask
 
 logger = logging.getLogger(__name__)
 
 
-@register_task("multilingual_denoising")
+@dataclass
+class MultilingualDenoisingConfig(DenoisingConfig):
+    multilang_sampling_alpha: float = field(
+        default=1.0,
+        metadata={"help": "smoothing alpha for sample ratios across multiple datasets"},
+    )
+    add_lang_token: bool = field(
+        default=False,
+        metadata={"help": ""},
+    )
+    langs: Optional[str] = field(
+        default=None,
+        metadata={"help": "language ids we are considering"},
+    )
+    no_whole_word_mask_langs: str = field(
+        default="",
+        metadata={
+            "help": "languages without spacing between words don't support whole word masking"
+        },
+    )
+    train_subset: str = II("common.train_subset")
+    valid_subset: str = II("common.valid_subset")
+
+
+@register_task("multilingual_denoising", dataclass=MultilingualDenoisingConfig)
 class MultilingualDenoisingTask(DenoisingTask):
-    @staticmethod
-    def add_args(parser):
-        DenoisingTask.add_args(parser)
-        parser.add_argument(
-            "--multilang-sampling-alpha",
-            type=float,
-            default=1.0,
-            help="smoothing alpha for sample ratios across multiple datasets",
-        )
-        parser.add_argument("--add-lang-token", default=False, action="store_true")
-        parser.add_argument(
-            "--langs", type=str, help="language ids we are considering", default=None
-        )
-        parser.add_argument(
-            "--no-whole-word-mask-langs",
-            type=str,
-            default="",
-            metavar="N",
-            help="languages without spacing between words dont support whole word masking",
-        )
+
+    cfg: MultilingualDenoisingConfig
 
     @classmethod
-    def setup_task(cls, args, **kwargs):
+    def setup_task(cls, cfg: MultilingualDenoisingConfig, **kwargs):
         """Setup the task."""
-        paths = args.data.split(":")
+        paths = cfg.data.split(":")
         assert len(paths) > 0
         dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
         data_path = paths[0]
-        if args.langs is None:
+        if cfg.langs is None:
             languages = sorted(
                 [
                     name
@@ -67,34 +75,32 @@ class MultilingualDenoisingTask(DenoisingTask):
                 ]
             )
         else:
-            languages = args.langs.split(",")
+            languages = cfg.langs.split(",")
 
-        if args.add_lang_token:
+        if cfg.add_lang_token:
             for lang in languages:
                 dictionary.add_symbol("[{}]".format(lang))
 
         logger.info("dictionary: {} types".format(len(dictionary)))
-        if not hasattr(args, "shuffle_instance"):
-            args.shuffle_instance = False
-        return cls(args, dictionary)
+        if not hasattr(cfg, "shuffle_instance"):
+            cfg.shuffle_instance = False
+        return cls(cfg, dictionary)
 
-    def __init__(self, args, dictionary):
-        super().__init__(args, dictionary)
+    def __init__(self, cfg: MultilingualDenoisingConfig, dictionary):
+        super().__init__(cfg, dictionary)
         self.dictionary = dictionary
-        self.seed = args.seed
 
         # add mask token
         self.mask_idx = self.dictionary.add_symbol("<mask>")
-        self.langs = args.langs
-        self.args = args
+        self.cfg = cfg
 
     def _get_sample_prob(self, dataset_lens):
         """
-        Get smoothed sampling porbability by languages. This helps low resource
+        Get smoothed sampling probability by languages. This helps low resource
         languages by upsampling them.
         """
         prob = dataset_lens / dataset_lens.sum()
-        smoothed_prob = prob**self.args.multilang_sampling_alpha
+        smoothed_prob = prob**self.cfg.multilang_sampling_alpha
         smoothed_prob = smoothed_prob / smoothed_prob.sum()
         return smoothed_prob
 
@@ -104,12 +110,12 @@ class MultilingualDenoisingTask(DenoisingTask):
         Args:
             split (str): name of the split (e.g., train, valid, test)
         """
-        paths = self.args.data.split(":")
+        paths = self.cfg.data.split(":")
         assert len(paths) > 0
         data_path = paths[(epoch - 1) % len(paths)]
         split_path = os.path.join(data_path, split)
 
-        if self.langs is None:
+        if self.cfg.langs is None:
             languages = sorted(
                 [
                     name
@@ -118,7 +124,7 @@ class MultilingualDenoisingTask(DenoisingTask):
                 ]
             )
         else:
-            languages = self.langs.split(",")
+            languages = self.cfg.langs.split(",")
             for name in languages:
                 p = os.path.join(data_path, name)
                 assert os.path.exists(p), "data not found: {}".format(p)
@@ -128,8 +134,8 @@ class MultilingualDenoisingTask(DenoisingTask):
             "Language to id mapping: ", {lang: id for id, lang in enumerate(languages)}
         )
 
-        mask_whole_words = get_whole_word_mask(self.args, self.dictionary)
-        language_without_segmentations = self.args.no_whole_word_mask_langs.split(",")
+        mask_whole_words = get_whole_word_mask(self.cfg.bpe, self.dictionary)
+        language_without_segmentations = self.cfg.no_whole_word_mask_langs.split(",")
         lang_datasets = []
         for language in languages:
             split_path = os.path.join(data_path, language, split)
@@ -137,7 +143,7 @@ class MultilingualDenoisingTask(DenoisingTask):
             dataset = data_utils.load_indexed_dataset(
                 split_path,
                 self.source_dictionary,
-                self.args.dataset_impl,
+                self.cfg.dataset_impl,
                 combine=combine,
             )
             if dataset is None:
@@ -147,7 +153,7 @@ class MultilingualDenoisingTask(DenoisingTask):
 
             end_token = (
                 self.source_dictionary.index("[{}]".format(language))
-                if self.args.add_lang_token
+                if self.cfg.add_lang_token
                 else self.source_dictionary.eos()
             )
 
@@ -155,10 +161,10 @@ class MultilingualDenoisingTask(DenoisingTask):
             dataset = TokenBlockDataset(
                 dataset,
                 dataset.sizes,
-                self.args.tokens_per_sample - 2,  # one less for <s>
+                self.cfg.tokens_per_sample - 2,  # one less for <s>
                 pad=self.source_dictionary.pad(),
                 eos=end_token,
-                break_mode=self.args.sample_break_mode,
+                break_mode=self.cfg.sample_break_mode,
             )
             logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))
 
@@ -177,11 +183,19 @@ class MultilingualDenoisingTask(DenoisingTask):
                 self.dictionary,
                 self.mask_idx,
                 lang_mask_whole_words,
-                shuffle=self.args.shuffle_instance,
-                seed=self.seed,
-                args=self.args,
+                shuffle=self.cfg.shuffle_instance,
+                seed=self.cfg.seed,
+                mask=self.cfg.mask,
+                mask_random=self.cfg.mask_random,
+                insert=self.cfg.insert,
+                rotate=self.cfg.rotate,
+                permute_sentences=self.cfg.permute_sentences,
+                bpe=self.cfg.bpe,
+                replace_length=self.cfg.replace_length,
+                mask_length=self.cfg.mask_length,
+                poisson_lambda=self.cfg.poisson_lambda,
                 eos=None
-                if not self.args.add_lang_token
+                if not self.cfg.add_lang_token
                 else self.source_dictionary.index("[{}]".format(language)),
             )
             lang_datasets.append(lang_dataset)
@@ -195,7 +209,7 @@ class MultilingualDenoisingTask(DenoisingTask):
                 int(dataset_lengths.sum()),
             )
         )
-        if split == self.args.train_subset:
+        if split == self.cfg.train_subset:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            logger.info(
@@ -220,7 +234,7 @@ class MultilingualDenoisingTask(DenoisingTask):
                 ResamplingDataset(
                     lang_datasets[i],
                     size_ratio=size_ratio[i],
-                    seed=self.args.seed,
+                    seed=self.cfg.seed,
                     epoch=epoch,
                     replace=size_ratio[i] >= 1.0,
                 )
@@ -237,12 +251,12 @@ class MultilingualDenoisingTask(DenoisingTask):
                 lang_splits.append(split_name)
                 self.datasets[split_name] = lang_dataset
 
-            if split in self.args.valid_subset:
-                self.args.valid_subset = self.args.valid_subset.replace(
+            if split in self.cfg.valid_subset:
+                self.cfg.valid_subset = self.cfg.valid_subset.replace(
                     split, ",".join(lang_splits)
                 )
 
-        with data_utils.numpy_seed(self.args.seed + epoch):
+        with data_utils.numpy_seed(self.cfg.seed + epoch):
             shuffle = np.random.permutation(len(dataset))
 
         self.datasets[split] = SortDataset(
diff --git a/tests/tasks/test_denoising.py b/tests/tasks/test_denoising.py
new file mode 100644
index 000000000..5c2216835
--- /dev/null
+++ b/tests/tasks/test_denoising.py
@@ -0,0 +1,96 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import unittest
+from tempfile import TemporaryDirectory
+
+from fairseq import options
+from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.tasks.denoising import DenoisingTask
+from tests.utils import build_vocab, make_data
+
+
+class TestDenoising(unittest.TestCase):
+    def test_denoising(self):
+        with TemporaryDirectory() as dirname:
+
+            # prep input file
+            raw_file = os.path.join(dirname, "raw")
+            data = make_data(out_file=raw_file)
+            vocab = build_vocab(data)
+
+            # binarize
+            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
+            split = "train"
+            bin_file = os.path.join(dirname, split)
+            dataset_impl = "mmap"
+            FileBinarizer.multiprocess_dataset(
+                input_file=raw_file,
+                binarizer=binarizer,
+                dataset_impl=dataset_impl,
+                vocab_size=len(vocab),
+                output_prefix=bin_file,
+            )
+
+            # setup task
+            train_args = options.parse_args_and_arch(
+                options.get_training_parser(),
+                [
+                    "--task",
+                    "denoising",
+                    "--arch",
+                    "bart_base",
+                    "--seed",
+                    "42",
+                    "--mask-length",
+                    "word",
+                    "--permute-sentences",
+                    "1",
+                    "--rotate",
+                    "0",
+                    "--replace-length",
+                    "-1",
+                    "--mask",
+                    "0.2",
+                    dirname,
+                ],
+            )
+            cfg = convert_namespace_to_omegaconf(train_args)
+            task = DenoisingTask(cfg.task, binarizer.dict)
+
+            # load datasets
+            original_dataset = task._load_dataset_split(bin_file, 1, False)
+            task.load_dataset(split)
+            masked_dataset = task.dataset(split)
+
+            iterator = task.get_batch_iterator(
+                dataset=masked_dataset,
+                max_tokens=65_536,
+                max_positions=4_096,
+            ).next_epoch_itr(shuffle=False)
+            mask_index = task.source_dictionary.index("<mask>")
+            for batch in iterator:
+                for sample in range(len(batch)):
+                    net_input = batch["net_input"]
+                    masked_src_tokens = net_input["src_tokens"][sample]
+                    masked_src_length = net_input["src_lengths"][sample]
+                    masked_tgt_tokens = batch["target"][sample]
+
+                    sample_id = batch["id"][sample]
+                    original_tokens = original_dataset[sample_id]
+                    original_tokens = original_tokens.masked_select(
+                        masked_src_tokens[:masked_src_length] == mask_index
+                    )
+                    masked_tokens = masked_tgt_tokens.masked_select(
+                        masked_src_tokens == mask_index
+                    )
+
+                    assert masked_tokens.equal(original_tokens)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/tasks/test_multilingual_denoising.py b/tests/tasks/test_multilingual_denoising.py
new file mode 100644
index 000000000..a0227f69b
--- /dev/null
+++ b/tests/tasks/test_multilingual_denoising.py
@@ -0,0 +1,98 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import unittest
+from tempfile import TemporaryDirectory
+
+from fairseq import options
+from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.tasks.multilingual_denoising import MultilingualDenoisingTask
+from tests.utils import build_vocab, make_data
+
+
+class TestMultilingualDenoising(unittest.TestCase):
+    def test_multilingual_denoising(self):
+        with TemporaryDirectory() as dirname:
+
+            # prep input file
+            lang_dir = os.path.join(dirname, "en")
+            os.mkdir(lang_dir)
+            raw_file = os.path.join(lang_dir, "raw")
+            data = make_data(out_file=raw_file)
+            vocab = build_vocab(data)
+
+            # binarize
+            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
+            split = "train"
+            bin_file = os.path.join(lang_dir, split)
+            dataset_impl = "mmap"
+            FileBinarizer.multiprocess_dataset(
+                input_file=raw_file,
+                binarizer=binarizer,
+                dataset_impl=dataset_impl,
+                vocab_size=len(vocab),
+                output_prefix=bin_file,
+            )
+
+            # setup task
+            train_args = options.parse_args_and_arch(
+                options.get_training_parser(),
+                [
+                    "--task",
+                    "multilingual_denoising",
+                    "--arch",
+                    "bart_base",
+                    "--seed",
+                    "42",
+                    "--mask-length",
+                    "word",
+                    "--permute-sentences",
+                    "1",
+                    "--rotate",
+                    "0",
+                    "--replace-length",
+                    "-1",
+                    "--mask",
+                    "0.2",
+                    dirname,
+                ],
+            )
+            cfg = convert_namespace_to_omegaconf(train_args)
+            task = MultilingualDenoisingTask(cfg.task, binarizer.dict)
+
+            # load datasets
+            original_dataset = task._load_dataset_split(bin_file, 1, False)
+            task.load_dataset(split)
+            masked_dataset = task.dataset(split)
+
+            iterator = task.get_batch_iterator(
+                dataset=masked_dataset,
+                max_tokens=65_536,
+                max_positions=4_096,
+            ).next_epoch_itr(shuffle=False)
+            mask_index = task.source_dictionary.index("<mask>")
+            for batch in iterator:
+                for sample in range(len(batch)):
+                    net_input = batch["net_input"]
+                    masked_src_tokens = net_input["src_tokens"][sample]
+                    masked_src_length = net_input["src_lengths"][sample]
+                    masked_tgt_tokens = batch["target"][sample]
+
+                    sample_id = batch["id"][sample]
+                    original_tokens = original_dataset[sample_id]
+                    original_tokens = original_tokens.masked_select(
+                        masked_src_tokens[:masked_src_length] == mask_index
+                    )
+                    masked_tokens = masked_tgt_tokens.masked_select(
+                        masked_src_tokens == mask_index
+                    )
+
+                    assert masked_tokens.equal(original_tokens)
+
+
+if __name__ == "__main__":
+    unittest.main()
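
Note (not part of the patch): a minimal sketch of how the refactored task can be driven straight from the new dataclass instead of argparse. The data directory is a hypothetical placeholder, and the II-interpolated fields (seed, dataset_impl) are given concrete values here because nothing resolves the Hydra/omegaconf interpolations when the config is built by hand; the new tests above go through convert_namespace_to_omegaconf instead.

# Hypothetical usage sketch of the refactored DenoisingTask (paths are placeholders).
from fairseq.tasks.denoising import DenoisingConfig, DenoisingTask

cfg = DenoisingConfig(
    data="/path/to/data-bin",   # must contain dict.txt and a binarized "train" split
    mask=0.2,
    mask_length="word",
    replace_length=-1,
    rotate=0.0,
    permute_sentences=1,
    seed=42,                    # override the II("common.seed") interpolation
    dataset_impl="mmap",        # override the II("dataset.dataset_impl") interpolation
)
task = DenoisingTask.setup_task(cfg)   # loads dict.txt and adds the <mask> symbol
task.load_dataset("train")             # builds DenoisingDataset from explicit kwargs

The multilingual task's _get_sample_prob smooths per-language sampling ratios with multilang_sampling_alpha; a small self-contained sketch of that computation (mirroring the method above, with made-up dataset sizes) shows how alpha below 1.0 upsamples low-resource languages:

import numpy as np

def smoothed_sample_prob(dataset_lens: np.ndarray, alpha: float) -> np.ndarray:
    # prob_i is proportional to n_i ** alpha; alpha = 1.0 keeps proportional
    # sampling, smaller alpha flattens the distribution toward uniform.
    prob = dataset_lens / dataset_lens.sum()
    smoothed = prob**alpha
    return smoothed / smoothed.sum()

sizes = np.array([1_000_000, 10_000])    # e.g. one high- and one low-resource language
print(smoothed_sample_prob(sizes, 1.0))  # -> approx. [0.990, 0.010]
print(smoothed_sample_prob(sizes, 0.3))  # -> approx. [0.800, 0.200], low-resource upsampled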