From ba415c99ca89b212ce9c8b79c565d828b89a8761 Mon Sep 17 00:00:00 2001
From: Alexander Jipa
Date: Wed, 29 Jun 2022 10:04:00 -0400
Subject: [PATCH] add span_masked_lm task (#4366)

Co-authored-by: Alexander Jipa
---
 fairseq/data/span_mask_tokens_dataset.py | 293 +++++++++++++++++++++++
 fairseq/tasks/span_masked_lm.py          | 243 +++++++++++++++++++
 tests/tasks/test_span_masked_lm.py       | 106 ++++++++
 3 files changed, 642 insertions(+)
 create mode 100644 fairseq/data/span_mask_tokens_dataset.py
 create mode 100644 fairseq/tasks/span_masked_lm.py
 create mode 100644 tests/tasks/test_span_masked_lm.py

diff --git a/fairseq/data/span_mask_tokens_dataset.py b/fairseq/data/span_mask_tokens_dataset.py
new file mode 100644
index 00000000..72189bd3
--- /dev/null
+++ b/fairseq/data/span_mask_tokens_dataset.py
@@ -0,0 +1,293 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+from . import Dictionary, FairseqDataset, data_utils
+
+
+def collate(
+    samples,
+    pad_idx,
+    eos_idx,
+    vocab,
+    left_pad_source=False,
+    left_pad_target=False,
+    input_feeding=True,
+    pad_to_length=None,
+):
+    assert input_feeding
+    if len(samples) == 0:
+        return {}
+
+    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
+        return data_utils.collate_tokens(
+            [s[key] for s in samples],
+            pad_idx,
+            eos_idx=None,  # use eos_idx of each sample instead of vocab.eos()
+            left_pad=left_pad,
+            move_eos_to_beginning=move_eos_to_beginning,
+            pad_to_length=pad_to_length,
+        )
+
+    id = torch.LongTensor([s["id"] for s in samples])
+    src_tokens = merge(
+        "source",
+        left_pad=left_pad_source,
+        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
+    )
+    # sort by descending source length
+    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
+    src_lengths, sort_order = src_lengths.sort(descending=True)
+    id = id.index_select(0, sort_order)
+    src_tokens = src_tokens.index_select(0, sort_order)
+
+    prev_output_tokens = None
+    target = None
+    if samples[0].get("target", None) is not None:
+        target = merge(
+            "target",
+            left_pad=left_pad_target,
+            pad_to_length=pad_to_length["target"]
+            if pad_to_length is not None
+            else None,
+        )
+        target = target.index_select(0, sort_order)
+        ntokens = sum(len(s["target"]) for s in samples)
+
+        if input_feeding:
+            # we create a shifted version of targets for feeding the
+            # previous output token(s) into the next decoder step
+            prev_output_tokens = merge(
+                "target",
+                left_pad=left_pad_target,
+                move_eos_to_beginning=True,
+                pad_to_length=pad_to_length["target"]
+                if pad_to_length is not None
+                else None,
+            )
+            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
+    else:
+        ntokens = sum(len(s["source"]) for s in samples)
+
+    batch = {
+        "id": id,
+        "ntokens": ntokens,
+        "net_input": {
+            "src_tokens": src_tokens,
+            "src_lengths": src_lengths,
+        },
+        "target": target,
+        # guard against the target-less branch above
+        "target_lengths": torch.LongTensor([len(t) for t in target])
+        if target is not None
+        else None,
+        "nsentences": len(samples),  # the batch size, not the first source's length
+        "sort_order": sort_order,
+    }
+    if prev_output_tokens is not None:
+        batch["net_input"]["prev_output_tokens"] = prev_output_tokens
+
+    return batch
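# A minimal sketch (illustration, not part of the patch) of what `collate`
# returns for a toy batch; the token values and pad/eos indices are made up.
import torch

from fairseq.data.span_mask_tokens_dataset import collate  # after this patch is applied

samples = [
    {"id": 0, "source": torch.tensor([4, 5, 6, 2]), "target": torch.tensor([7, 8, 2])},
    {"id": 1, "source": torch.tensor([4, 5, 2]), "target": torch.tensor([7, 2])},
]
batch = collate(samples, pad_idx=1, eos_idx=2, vocab=None)  # vocab is unused inside
# sources are sorted by descending length and right-padded with pad_idx:
#   batch["net_input"]["src_tokens"]         -> [[4, 5, 6, 2], [4, 5, 2, 1]]
# prev_output_tokens is each target rotated so eos leads (input feeding):
#   batch["net_input"]["prev_output_tokens"] -> [[2, 7, 8], [2, 7, 1]]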
+
+
+class SpanMaskedTokensDataset(FairseqDataset):
+    """
+    A wrapper dataset that applies T5-style span masking to each item
+    (typically a block produced by TokenBlockDataset).
+
+    Args:
+        dataset (~torch.utils.data.Dataset): dataset to wrap
+        vocab (~fairseq.data.Dictionary): vocabulary
+        noise_density (float): fraction of the tokens to select as noise.
+        mean_noise_span_length (float): mean noise span length.
+        shuffle (bool): shuffle the elements before batching.
+        seed (int, optional): seed for the random number generator, for
+            reproducibility. Default: ``1``
+    """
+
+    def __init__(
+        self,
+        dataset: torch.utils.data.Dataset,
+        vocab: Dictionary,
+        noise_density: float,
+        mean_noise_span_length: float,
+        shuffle: bool,
+        seed: int = 1,
+    ):
+        self.dataset = dataset
+        self.vocab = vocab
+        self.seed = seed
+        self.noise_density = noise_density
+        self.mean_noise_span_length = mean_noise_span_length
+        self.shuffle = shuffle
+        self.epoch = 0
+
+    @property
+    def can_reuse_epoch_itr_across_epochs(self):
+        return True  # only the noise changes, not item sizes
+
+    def set_epoch(self, epoch, **unused):
+        self.epoch = epoch
+
+    def __getitem__(self, index):
+        with data_utils.numpy_seed(self.seed, self.epoch, index):
+            item = self.dataset[index]
+            assert item[-1] == self.vocab.eos()
+
+            noise_mask = self.random_spans_noise_mask(len(item))
+
+            source_sentinel_ids = self.create_sentinel_ids(noise_mask.astype(np.int8))
+            source = self.filter_input_ids(item, source_sentinel_ids)
+
+            target_sentinel_ids = self.create_sentinel_ids(
+                (~noise_mask).astype(np.int8)
+            )
+            target = self.filter_input_ids(item, target_sentinel_ids)
+
+        return {
+            "id": index,
+            "source": torch.from_numpy(source),
+            "target": torch.from_numpy(target),
+        }
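# Illustration of the encoding `__getitem__` produces, assuming len(vocab) == 100
# and the toy sequence below; 99 and 98 are sentinel ids counting down from the
# top of the vocabulary.
import numpy as np

item = np.array([11, 12, 13, 14, 15, 2])                         # block ending in eos
noise_mask = np.array([False, True, True, False, False, False])  # tokens 12, 13 masked
# source replaces each masked span with a single sentinel:
#   [11, 99, 14, 15, 2]
# target holds each masked span prefixed by its sentinel; the last sentinel (98)
# stands in for the trailing unmasked stretch and closes the sequence:
#   [99, 12, 13, 98]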
+
+    def random_spans_noise_mask(self, length):
+        """
+        This function is a copy of ``random_spans_helper`` from Google's
+        text-to-text-transfer-transformer (T5) preprocessors.
+        Returns a noise mask consisting of random spans of noise tokens.
+        The number of noise tokens and the number of noise spans and non-noise spans
+        are determined deterministically as follows:
+            num_noise_tokens = round(length * noise_density)
+            num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
+        Spans alternate between non-noise and noise, beginning with non-noise.
+        Subject to the above restrictions, all masks are equally likely.
+        Args:
+            length: an int32 scalar (length of the incoming token sequence)
+        Returns:
+            a boolean tensor with shape [length]
+        """
+        orig_length = length
+
+        num_noise_tokens = int(np.round(length * self.noise_density))
+        # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens
+        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
+        num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))
+
+        # avoid degeneracy by ensuring a positive number of noise spans
+        num_noise_spans = max(num_noise_spans, 1)
+        num_nonnoise_tokens = length - num_noise_tokens
+
+        # pick the lengths of the noise spans and the non-noise spans
+        def _random_segmentation(num_items, num_segments):
+            """
+            Partition a sequence of items randomly into non-empty segments.
+            Args:
+                num_items: an integer scalar > 0
+                num_segments: an integer scalar in [1, num_items]
+            Returns:
+                a Tensor with shape [num_segments] containing positive integers that add up to num_items
+            """
+            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
+            np.random.shuffle(mask_indices)
+            first_in_segment = np.pad(mask_indices, [[1, 0]])
+            segment_id = np.cumsum(first_in_segment)
+            # count the length of each segment (segment_id is non-decreasing)
+            _, segment_length = np.unique(segment_id, return_counts=True)
+            return segment_length
+
+        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
+        nonnoise_span_lengths = _random_segmentation(
+            num_nonnoise_tokens, num_noise_spans
+        )
+
+        interleaved_span_lengths = np.reshape(
+            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
+            [num_noise_spans * 2],
+        )
+        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
+        span_start_indicator = np.zeros((length,), dtype=np.int8)
+        span_start_indicator[span_starts] = True
+        span_num = np.cumsum(span_start_indicator)
+        is_noise = np.equal(span_num % 2, 1)
+
+        return is_noise[:orig_length]
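# Worked example of the span accounting above, assuming length=20 and the
# default noise_density=0.15 and mean_noise_span_length=3.0:
import numpy as np

length, noise_density, mean_noise_span_length = 20, 0.15, 3.0
num_noise_tokens = int(np.round(length * noise_density))       # 3
num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)   # 3
num_noise_spans = max(
    int(np.round(num_noise_tokens / mean_noise_span_length)), 1
)                                                              # 1
num_nonnoise_tokens = length - num_noise_tokens                # 17
# One non-noise span (17 tokens) followed by one noise span (3 tokens); with
# more spans, only the random split points inside each group vary.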
+
+    def create_sentinel_ids(self, mask_indices):
+        """
+        Create sentinel ids given the indices that should be masked.
+        The start index of each masked span is replaced by a sentinel id,
+        assigned in decreasing order from the top of the vocabulary; the
+        remaining indices of each span are set to `-1`, marking them for
+        deletion.
+        """
+        start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
+        # np.roll wraps around, so a span starting at index 0 must be fixed up
+        start_indices[0] = mask_indices[0]
+
+        sentinel_ids = np.where(
+            start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
+        )
+        # making sure all sentinel tokens are unique over the example
+        sentinel_ids = np.where(sentinel_ids != 0, len(self.vocab) - sentinel_ids, 0)
+        sentinel_ids -= mask_indices - start_indices
+        return sentinel_ids
+
+    @staticmethod
+    def filter_input_ids(input_ids, sentinel_ids):
+        """
+        Put sentinel masks on `input_ids` and fuse consecutive mask tokens into
+        a single mask token by deleting the rest, which shrinks the sequence.
+        """
+        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
+
+        # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
+        # masked tokens coming after sentinel tokens and should be removed
+        return input_ids_full[input_ids_full >= 0]
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def collater(self, samples, pad_to_length=None):
+        """
+        Merge a list of samples to form a mini-batch.
+        Args:
+            samples (List[dict]): samples to collate
+        Returns:
+            dict: a mini-batch of data
+        """
+        return collate(
+            samples,
+            self.vocab.pad(),
+            self.vocab.eos(),
+            self.vocab,
+            pad_to_length=pad_to_length,
+        )
+
+    def num_tokens(self, index):
+        """Return the number of tokens in a sample. This value is used to
+        enforce ``--max-tokens`` during batching."""
+        return self.dataset.sizes[index]
+
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
+        return self.dataset.sizes[index]
+
+    def ordered_indices(self):
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
+        if self.shuffle:
+            indices = np.random.permutation(len(self))
+        else:
+            indices = np.arange(len(self))
+        return indices[np.argsort(self.dataset.sizes[indices], kind="mergesort")]
+
+    def prefetch(self, indices):
+        # this dataset wraps a single underlying dataset, not a src/tgt pair
+        self.dataset.prefetch(indices)
+
+    @property
+    def supports_prefetch(self):
+        return (
+            hasattr(self.dataset, "supports_prefetch")
+            and self.dataset.supports_prefetch
+        )
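# A minimal end-to-end sketch of the dataset above on a toy corpus; the
# dictionary symbols and tensor contents are made up, and a real setup would
# append dedicated sentinel symbols (e.g. <extra_id_0>, ...) as the test
# at the end of this patch does.
import numpy as np
import torch
from torch.utils.data import Dataset

from fairseq.data import Dictionary
from fairseq.data.span_mask_tokens_dataset import SpanMaskedTokensDataset


class ToyDataset(Dataset):
    """A list of pre-tokenized tensors plus the `sizes` array fairseq expects."""

    def __init__(self, items):
        self.items = items
        self.sizes = np.array([len(t) for t in items])

    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)


vocab = Dictionary()
for symbol in "abcdefgh":
    vocab.add_symbol(symbol)

toy = ToyDataset([torch.tensor([4, 5, 6, 7, 8, 9, vocab.eos()])])
dataset = SpanMaskedTokensDataset(
    toy, vocab, noise_density=0.3, mean_noise_span_length=2.0, shuffle=False
)
sample = dataset[0]
# sample["source"] holds the unmasked tokens with sentinels spliced in;
# sample["target"] holds the masked spans, each prefixed by its sentinel.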
+ }, + ) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + seed: int = II("common.seed") + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "dataset.dataset_impl" + ) + max_source_positions: int = field( + default=1024, metadata={"help": "max number of tokens in the source sequence"} + ) + max_target_positions: int = field( + default=1024, metadata={"help": "max number of tokens in the target sequence"} + ) + include_target_tokens: bool = field( + default=False, + metadata={ + "help": "include target tokens in model input. this is used for data2vec" + }, + ) + + +@register_task("span_masked_lm", dataclass=SpanMaskedLMConfig) +class SpanMaskedLMTask(FairseqTask): + """ + Span masked language modeling task. (ie. T5) + """ + + cfg: SpanMaskedLMConfig + + def __init__(self, cfg, dictionary): + super().__init__(cfg) + self.dictionary = dictionary + + @classmethod + def setup_task(cls, cfg: SpanMaskedLMConfig, **kwargs): + """Setup the task.""" + paths = utils.split_paths(cfg.data) + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + if not hasattr(cfg, "shuffle"): + cfg.shuffle = False + return cls(cfg, dictionary) + + def _load_dataset_split(self, split, epoch, combine): + paths = utils.split_paths(self.cfg.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + split_path = os.path.join(data_path, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + self.dictionary, + self.cfg.dataset_impl, + combine=combine, + ) + if dataset is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + dataset = StripTokenDataset(dataset, self.dictionary.eos()) + + dataset = maybe_shorten_dataset( + dataset, + split, + self.cfg.shorten_data_split_list, + self.cfg.shorten_method, + self.cfg.tokens_per_sample, + self.cfg.seed, + ) + + # create continuous blocks of tokens + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.cfg.tokens_per_sample - 2, # one less for and one for + pad=self.dictionary.pad(), + eos=self.dictionary.eos(), + break_mode=self.cfg.sample_break_mode, + document_sep_len=0, + ) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) + + # prepend beginning-of-sentence token (, equiv. to [CLS] in BERT) + dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) + dataset = AppendTokenDataset(dataset, self.source_dictionary.eos()) + return dataset + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. 
+
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
+        """Load a given dataset split.
+
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+        dataset = self._load_dataset_split(split, epoch, combine)
+
+        self.datasets[split] = SpanMaskedTokensDataset(
+            dataset,
+            self.dictionary,
+            noise_density=self.cfg.noise_density,
+            mean_noise_span_length=self.cfg.mean_noise_span_length,
+            shuffle=self.cfg.shuffle,
+            seed=self.cfg.seed,
+        )
+        logger.info(
+            "Split: {0}, Loaded {1} samples of span_masked_tokens_dataset".format(
+                split,
+                len(self.datasets[split]),
+            )
+        )
+
+    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
+        """
+        Generate batches for inference. We assume that the input begins with a
+        bos symbol (`<s>`) and ends with an eos symbol (`</s>`).
+        """
+        pad = self.source_dictionary.pad()
+        eos = self.source_dictionary.eos()
+        src_dataset = TokenBlockDataset(
+            src_tokens,
+            src_lengths,
+            block_size=self.cfg.tokens_per_sample - 2,  # for <s> and </s>
+            pad=pad,
+            eos=eos,
+            break_mode=self.cfg.sample_break_mode,
+            document_sep_len=0,
+        )
+        prev_output_tokens = PrependTokenDataset(
+            StripTokenDataset(src_dataset, eos), eos
+        )
+        src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False)
+        return NestedDictionaryDataset(
+            {
+                "id": IdDataset(),
+                "net_input": {
+                    "src_tokens": src_dataset,
+                    "src_lengths": NumelDataset(src_dataset, reduce=False),
+                    "prev_output_tokens": PadDataset(
+                        prev_output_tokens, pad_idx=pad, left_pad=False
+                    ),
+                },
+                "target": src_dataset,
+            },
+            sizes=[np.array(src_lengths)],
+        )
+
+    def max_positions(self):
+        """Return the max sentence length allowed by the task."""
+        return (self.cfg.max_source_positions, self.cfg.max_target_positions)
+
+    @property
+    def source_dictionary(self):
+        """Return the source :class:`~fairseq.data.Dictionary`."""
+        return self.dictionary
+
+    @property
+    def target_dictionary(self):
+        """Return the target :class:`~fairseq.data.Dictionary`."""
+        return self.dictionary
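# A sketch of wiring the task up through fairseq's usual CLI machinery,
# mirroring the test below; "data-bin" is a hypothetical directory that must
# already contain dict.txt and a binarized train split. The hyphenated flags
# are derived from the dataclass fields above.
from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.tasks.span_masked_lm import SpanMaskedLMTask

args = options.parse_args_and_arch(
    options.get_training_parser(),
    [
        "--task", "span_masked_lm",
        "--arch", "bart_base",
        "--noise-density", "0.15",
        "--mean-noise-span-length", "3.0",
        "data-bin",
    ],
)
cfg = convert_namespace_to_omegaconf(args)
task = SpanMaskedLMTask.setup_task(cfg.task)
task.load_dataset("train")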
diff --git a/tests/tasks/test_span_masked_lm.py b/tests/tasks/test_span_masked_lm.py
new file mode 100644
index 00000000..d289cf84
--- /dev/null
+++ b/tests/tasks/test_span_masked_lm.py
@@ -0,0 +1,106 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import unittest
+from tempfile import TemporaryDirectory
+
+from fairseq import options
+from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.tasks.span_masked_lm import SpanMaskedLMTask
+from tests.utils import build_vocab, make_data
+
+
+class TestSpanMaskedLM(unittest.TestCase):
+    def test_masks_token_spans(self):
+        with TemporaryDirectory() as dirname:
+
+            # prep input file
+            raw_file = os.path.join(dirname, "raw")
+            data = make_data(out_file=raw_file)
+            vocab = build_vocab(data)
+
+            # binarize
+            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
+            split = "train"
+            bin_file = os.path.join(dirname, split)
+            dataset_impl = "mmap"
+
+            FileBinarizer.multiprocess_dataset(
+                input_file=raw_file,
+                binarizer=binarizer,
+                dataset_impl=dataset_impl,
+                vocab_size=len(vocab),
+                output_prefix=bin_file,
+            )
+
+            # add T5-style sentinel tokens at the top of the vocabulary
+            for i in range(100):
+                vocab.add_symbol(f"<extra_id_{i}>")
+
+            # setup task
+            train_args = options.parse_args_and_arch(
+                options.get_training_parser(),
+                [
+                    "--task",
+                    "span_masked_lm",
+                    "--arch",
+                    "bart_base",
+                    "--seed",
+                    "42",
+                    dirname,
+                ],
+            )
+            cfg = convert_namespace_to_omegaconf(train_args)
+            task = SpanMaskedLMTask(cfg.task, binarizer.dict)
+
+            # load datasets
+            original_dataset = task._load_dataset_split(bin_file, 1, False)
+            task.load_dataset(split)
+            masked_dataset = task.dataset(split)
+
+            iterator = task.get_batch_iterator(
+                dataset=masked_dataset,
+                max_tokens=65_536,
+                max_positions=4_096,
+            ).next_epoch_itr(shuffle=False)
+            for batch in iterator:
+                # iterate over the samples in the batch, not the dict's keys
+                for sample in range(len(batch["id"])):
+                    sample_id = batch["id"][sample]
+                    original_tokens = original_dataset[sample_id]
+                    masked_src_tokens = batch["net_input"]["src_tokens"][sample]
+                    masked_src_length = batch["net_input"]["src_lengths"][sample]
+                    masked_tgt_tokens = batch["target"][sample]
+
+                    original_offset = 0
+                    masked_tgt_offset = 0
+                    extra_id_token = len(vocab) - 1
+                    for masked_src_token in masked_src_tokens[:masked_src_length]:
+                        if masked_src_token == extra_id_token:
+                            assert (
+                                masked_src_token == masked_tgt_tokens[masked_tgt_offset]
+                            )
+                            extra_id_token -= 1
+                            masked_tgt_offset += 1
+                            while (
+                                original_offset < len(original_tokens)
+                                and masked_tgt_tokens[masked_tgt_offset]
+                                != extra_id_token
+                            ):
+                                assert (
+                                    original_tokens[original_offset]
+                                    == masked_tgt_tokens[masked_tgt_offset]
+                                )
+                                original_offset += 1
+                                masked_tgt_offset += 1
+                        else:
+                            assert original_tokens[original_offset] == masked_src_token
+                            original_offset += 1
+
+
+if __name__ == "__main__":
+    unittest.main()
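# The invariant the test loop above verifies, written out as a reconstruction
# helper (a sketch, not part of the patch): splicing each target span back into
# its sentinel position in the source must reproduce the original sequence.
# `sentinel_floor` is an assumed threshold; in the test, the top 100 vocabulary
# ids are reserved for <extra_id_*> sentinels.
def reconstruct(source, target, sentinel_floor):
    spans, current = {}, None
    for token in target:
        if token >= sentinel_floor:
            current = token          # a sentinel opens a new masked span
            spans[current] = []
        else:
            spans[current].append(token)
    out = []
    for token in source:
        if token >= sentinel_floor:
            out.extend(spans[token])  # splice the masked span back in
        else:
            out.append(token)
    return out


assert reconstruct(
    [11, 99, 14, 15, 2], [99, 12, 13, 98], sentinel_floor=90
) == [11, 12, 13, 14, 15, 2]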