add span_masked_lm task (#4366)

Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
Alexander Jipa 2022-06-29 10:04:00 -04:00 committed by GitHub
parent 5d8d0674c1
commit ba415c99ca
3 changed files with 642 additions and 0 deletions

fairseq/data/span_mask_tokens_dataset.py

@@ -0,0 +1,293 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from . import Dictionary, FairseqDataset, data_utils
def collate(
samples,
pad_idx,
eos_idx,
vocab,
left_pad_source=False,
left_pad_target=False,
input_feeding=True,
pad_to_length=None,
):
assert input_feeding
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx=None, # use eos_idx of each sample instead of vocab.eos()
left_pad=left_pad,
move_eos_to_beginning=move_eos_to_beginning,
pad_to_length=pad_to_length,
)
id = torch.LongTensor([s["id"] for s in samples])
src_tokens = merge(
"source",
left_pad=left_pad_source,
pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
)
# sort by descending source length
src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
if samples[0].get("target", None) is not None:
target = merge(
"target",
left_pad=left_pad_target,
pad_to_length=pad_to_length["target"]
if pad_to_length is not None
else None,
)
target = target.index_select(0, sort_order)
ntokens = sum(len(s["target"]) for s in samples)
if input_feeding:
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
"target",
left_pad=left_pad_target,
move_eos_to_beginning=True,
pad_to_length=pad_to_length["target"]
if pad_to_length is not None
else None,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
else:
ntokens = sum(len(s["source"]) for s in samples)
batch = {
"id": id,
"ntokens": ntokens,
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
},
"target": target,
"target_lengths": torch.LongTensor([len(t) for t in target]),
"nsentences": samples[0]["source"].size(0),
"sort_order": sort_order,
}
if prev_output_tokens is not None:
batch["net_input"]["prev_output_tokens"] = prev_output_tokens
return batch
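# Illustrative example (hypothetical token ids; pad_idx=1 and eos_idx=2 mirror
# fairseq's default Dictionary): `collate` sorts sources by descending length,
# right-pads, and builds the shifted decoder inputs. `vocab` is not consulted.
#
#   samples = [
#       {"id": 0, "source": torch.tensor([5, 6, 7, 2]), "target": torch.tensor([8, 9, 2])},
#       {"id": 1, "source": torch.tensor([5, 2]), "target": torch.tensor([8, 2])},
#   ]
#   batch = collate(samples, pad_idx=1, eos_idx=2, vocab=None)
#   batch["net_input"]["src_tokens"]         -> [[5, 6, 7, 2], [5, 2, 1, 1]]
#   batch["target"]                          -> [[8, 9, 2], [8, 2, 1]]
#   batch["net_input"]["prev_output_tokens"] -> [[2, 8, 9], [2, 8, 1]]  (eos moved to front)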
class SpanMaskedTokensDataset(FairseqDataset):
"""
A wrapper dataset that applies T5-style span masking to a TokenBlockDataset.
Args:
dataset (~torch.utils.data.Dataset): dataset to wrap
vocab (~fairseq.data.Dictionary): vocabulary
noise_density (float): fraction of the tokens to select as noise.
mean_noise_span_length (float): mean noise span length.
shuffle (bool): shuffle the elements before batching.
seed: Seed for random number generator for reproducibility.
"""
def __init__(
self,
dataset: torch.utils.data.Dataset,
vocab: Dictionary,
noise_density: float,
mean_noise_span_length: float,
shuffle: bool,
seed: int = 1,
):
self.dataset = dataset
self.vocab = vocab
self.seed = seed
self.noise_density = noise_density
self.mean_noise_span_length = mean_noise_span_length
self.shuffle = shuffle
self.epoch = 0
@property
def can_reuse_epoch_itr_across_epochs(self):
return True # only the noise changes, not item sizes
def set_epoch(self, epoch, **unused):
self.epoch = epoch
def __getitem__(self, index):
with data_utils.numpy_seed(self.seed, self.epoch, index):
item = self.dataset[index]
assert item[-1] == self.vocab.eos()
noise_mask = self.random_spans_noise_mask(len(item))
source_sentinel_ids = self.create_sentinel_ids(noise_mask.astype(np.int8))
source = self.filter_input_ids(item, source_sentinel_ids)
target_sentinel_ids = self.create_sentinel_ids(
(~noise_mask).astype(np.int8)
)
target = self.filter_input_ids(item, target_sentinel_ids)
return {
"id": index,
"source": torch.from_numpy(source),
"target": torch.from_numpy(target),
}
def random_spans_noise_mask(self, length):
"""
This function is a copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__.
Noise mask consisting of random spans of noise tokens.
The number of noise tokens and the number of noise spans and non-noise spans
are determined deterministically as follows:
num_noise_tokens = round(length * noise_density)
num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
Spans alternate between non-noise and noise, beginning with non-noise.
Subject to the above restrictions, all masks are equally likely.
Args:
length: an int32 scalar (length of the incoming token sequence)
Returns:
a boolean tensor with shape [length]
"""
orig_length = length
num_noise_tokens = int(np.round(length * self.noise_density))
# avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))
# avoid degeneracy by ensuring positive number of noise spans
num_noise_spans = max(num_noise_spans, 1)
num_nonnoise_tokens = length - num_noise_tokens
# pick the lengths of the noise spans and the non-noise spans
def _random_segmentation(num_items, num_segments):
"""
Partition a sequence of items randomly into non-empty segments.
Args:
num_items: an integer scalar > 0
num_segments: an integer scalar in [1, num_items]
Returns:
a Tensor with shape [num_segments] containing positive integers that add up to num_items
"""
mask_indices = np.arange(num_items - 1) < (num_segments - 1)
np.random.shuffle(mask_indices)
first_in_segment = np.pad(mask_indices, [[1, 0]])
segment_id = np.cumsum(first_in_segment)
# count length of subsegments assuming that list is sorted
_, segment_length = np.unique(segment_id, return_counts=True)
return segment_length
noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
nonnoise_span_lengths = _random_segmentation(
num_nonnoise_tokens, num_noise_spans
)
interleaved_span_lengths = np.reshape(
np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
[num_noise_spans * 2],
)
span_starts = np.cumsum(interleaved_span_lengths)[:-1]
span_start_indicator = np.zeros((length,), dtype=np.int8)
span_start_indicator[span_starts] = True
span_num = np.cumsum(span_start_indicator)
is_noise = np.equal(span_num % 2, 1)
return is_noise[:orig_length]
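# Worked example (hypothetical numbers): with length=20, noise_density=0.15 and
# mean_noise_span_length=3.0, num_noise_tokens = round(20 * 0.15) = 3 and
# num_noise_spans = round(3 / 3.0) = 1, so the interleaved span lengths are
# [17, 3] and the mask is deterministically 17 zeros followed by 3 ones;
# random placement only kicks in once there are two or more spans. Because
# spans are laid out as (non-noise, noise) pairs, a block always *ends* in a
# noise span, so the trailing </s> is always masked.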
def create_sentinel_ids(self, mask_indices):
"""
Create sentinel ids given the indices that should be masked. The start
index of each masked span is replaced by a sentinel id (counting down
from `len(vocab) - 1`), and the remaining indices of the span are set
to `-1` so they can be deleted downstream.
"""
start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
sentinel_ids = np.where(
start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
)
# making sure all sentinel tokens are unique over the example
sentinel_ids = np.where(sentinel_ids != 0, len(self.vocab) - sentinel_ids, 0)
sentinel_ids -= mask_indices - start_indices
return sentinel_ids
@staticmethod
def filter_input_ids(input_ids, sentinel_ids):
"""
Apply the sentinel mask to `input_ids`, fusing each run of consecutive
masked tokens into a single sentinel token and deleting the rest, which
shortens the sequence accordingly.
"""
input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
# input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
# masked tokens coming after sentinel tokens and should be removed
return input_ids_full[input_ids_full >= 0]
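# Worked example (hypothetical ids; vocab size 100, so sentinels count down
# from 99): item = [11, 12, 13, 14, 15, 2] (2 = </s>), noise mask [0, 1, 1, 0, 1, 1].
#   source sentinel ids: [0, 99, -1, 0, 98, -1]  ->  source: [11, 99, 14, 98]
#   target sentinel ids (inverted mask): [99, 0, 0, 98, 0, 0]
#                                                 ->  target: [99, 12, 13, 98, 15, 2]
# i.e. the source keeps the unmasked text with one sentinel per span, and the
# target lists each sentinel followed by the tokens it hides.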
def __len__(self):
return len(self.dataset)
def collater(self, samples, pad_to_length=None):
"""
Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch of data
"""
return collate(
samples,
self.vocab.pad(),
self.vocab.eos(),
self.vocab,
pad_to_length=pad_to_length,
)
def num_tokens(self, index):
"""Return the number of tokens in a sample. This value is used to
enforce ``--max-tokens`` during batching."""
return self.dataset.sizes[index]
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return self.dataset.sizes[index]
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
indices = np.random.permutation(len(self))
else:
indices = np.arange(len(self))
return indices[np.argsort(self.dataset.sizes[indices], kind="mergesort")]
def prefetch(self, indices):
# delegate to the wrapped dataset; this class wraps a single dataset,
# not separate src/tgt datasets
self.dataset.prefetch(indices)
@property
def supports_prefetch(self):
return getattr(self.dataset, "supports_prefetch", False)
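The masking helpers above can be exercised in isolation. A minimal runnable sketch (hypothetical token ids; a bare namespace stands in for the dataset object, since `create_sentinel_ids` only reads `len(self.vocab)`):

import numpy as np
from types import SimpleNamespace
from fairseq.data.span_mask_tokens_dataset import SpanMaskedTokensDataset

ds = SimpleNamespace(vocab=[0] * 100)            # only len(vocab) is consulted
item = np.array([11, 12, 13, 14, 15, 2])         # 2 = </s>
mask = np.array([0, 1, 1, 0, 1, 1], dtype=bool)  # two noise spans, ending in noise

src_ids = SpanMaskedTokensDataset.create_sentinel_ids(ds, mask.astype(np.int8))
tgt_ids = SpanMaskedTokensDataset.create_sentinel_ids(ds, (~mask).astype(np.int8))
print(SpanMaskedTokensDataset.filter_input_ids(item, src_ids))  # [11 99 14 98]
print(SpanMaskedTokensDataset.filter_input_ids(item, tgt_ids))  # [99 12 13 98 15  2]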

fairseq/tasks/span_masked_lm.py

@@ -0,0 +1,243 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
from omegaconf import II, MISSING
from fairseq import utils
from fairseq.data import (
AppendTokenDataset,
Dictionary,
IdDataset,
NestedDictionaryDataset,
NumelDataset,
PadDataset,
PrependTokenDataset,
StripTokenDataset,
TokenBlockDataset,
data_utils,
)
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.data.span_mask_tokens_dataset import SpanMaskedTokensDataset
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from ..data.indexed_dataset import get_available_dataset_impl
logger = logging.getLogger(__name__)
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
@dataclass
class SpanMaskedLMConfig(FairseqDataclass):
shuffle: bool = field(
default=False,
)
noise_density: float = field(
default=0.15,
metadata={"help": "What fraction of the tokens to select as noise"},
)
mean_noise_span_length: float = field(
default=3,
metadata={"help": "Mean noise span length, must be >= 1"},
)
data: str = field(
default=MISSING,
metadata={
"help": "colon separated path to data directories list, "
"will be iterated upon during epochs in round-robin manner"
},
)
sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
default="none",
metadata={
"help": 'If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
"of sentence, but may include multiple sentences per sample. "
'"complete_doc" is similar but respects doc boundaries. '
'If set to "eos", includes only one sentence per sample.'
},
)
tokens_per_sample: int = field(
default=1024,
metadata={"help": "max number of tokens per sample for LM dataset"},
)
shorten_method: SHORTEN_METHOD_CHOICES = field(
default="none",
metadata={
"help": "if not none, shorten sequences that exceed --tokens-per-sample"
},
)
shorten_data_split_list: str = field(
default="",
metadata={
"help": "comma-separated list of dataset splits to apply shortening to, "
'e.g., "train,valid" (default: all dataset splits)'
},
)
seed: int = II("common.seed")
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
"dataset.dataset_impl"
)
max_source_positions: int = field(
default=1024, metadata={"help": "max number of tokens in the source sequence"}
)
max_target_positions: int = field(
default=1024, metadata={"help": "max number of tokens in the target sequence"}
)
include_target_tokens: bool = field(
default=False,
metadata={
"help": "include target tokens in model input. this is used for data2vec"
},
)
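# Illustrative usage (hypothetical data path): these fields surface as CLI
# flags via fairseq's dataclass-to-argparse bridge, e.g.
#   fairseq-train /path/to/data-bin \
#       --task span_masked_lm --arch bart_base \
#       --noise-density 0.15 --mean-noise-span-length 3 \
#       --tokens-per-sample 1024 --sample-break-mode complete
# (see the unit test file below for the parser-driven equivalent)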
@register_task("span_masked_lm", dataclass=SpanMaskedLMConfig)
class SpanMaskedLMTask(FairseqTask):
"""
Span masked language modeling task (i.e., the denoising objective used by T5).
"""
cfg: SpanMaskedLMConfig
def __init__(self, cfg, dictionary):
super().__init__(cfg)
self.dictionary = dictionary
@classmethod
def setup_task(cls, cfg: SpanMaskedLMConfig, **kwargs):
"""Setup the task."""
paths = utils.split_paths(cfg.data)
assert len(paths) > 0
dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
logger.info("dictionary: {} types".format(len(dictionary)))
if not hasattr(cfg, "shuffle"):
cfg.shuffle = False
return cls(cfg, dictionary)
def _load_dataset_split(self, split, epoch, combine):
paths = utils.split_paths(self.cfg.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
dataset = data_utils.load_indexed_dataset(
split_path,
self.dictionary,
self.cfg.dataset_impl,
combine=combine,
)
if dataset is None:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, split_path)
)
dataset = StripTokenDataset(dataset, self.dictionary.eos())
dataset = maybe_shorten_dataset(
dataset,
split,
self.cfg.shorten_data_split_list,
self.cfg.shorten_method,
self.cfg.tokens_per_sample,
self.cfg.seed,
)
# create continuous blocks of tokens
dataset = TokenBlockDataset(
dataset,
dataset.sizes,
self.cfg.tokens_per_sample - 2,  # minus 2: one for <s>, one for </s>
pad=self.dictionary.pad(),
eos=self.dictionary.eos(),
break_mode=self.cfg.sample_break_mode,
document_sep_len=0,
)
logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))
# prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())
return dataset
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
dataset = self._load_dataset_split(split, epoch, combine)
self.datasets[split] = SpanMaskedTokensDataset(
dataset,
self.dictionary,
noise_density=self.cfg.noise_density,
mean_noise_span_length=self.cfg.mean_noise_span_length,
shuffle=self.cfg.shuffle,
seed=self.cfg.seed,
)
logger.info(
"Split: {0}, Loaded {1} samples of span_masked_tokens_dataset".format(
split,
len(self.datasets[split]),
)
)
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
"""
Generate batches for inference. We assume that the input begins with a
bos symbol (`<s>`) and ends with an eos symbol (`</s>`).
"""
pad = self.source_dictionary.pad()
eos = self.source_dictionary.eos()
src_dataset = TokenBlockDataset(
src_tokens,
src_lengths,
block_size=self.cfg.tokens_per_sample - 2, # for <s> and </s>
pad=pad,
eos=eos,
break_mode=self.cfg.sample_break_mode,
document_sep_len=0,
)
prev_output_tokens = PrependTokenDataset(
StripTokenDataset(src_dataset, eos), eos
)
src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False)
return NestedDictionaryDataset(
{
"id": IdDataset(),
"net_input": {
"src_tokens": src_dataset,
"src_lengths": NumelDataset(src_dataset, reduce=False),
"prev_output_tokens": PadDataset(
prev_output_tokens, pad_idx=pad, left_pad=False
),
},
"target": src_dataset,
},
sizes=[np.array(src_lengths)],
)
def max_positions(self):
"""Return the max sentence length allowed by the task."""
return (self.cfg.max_source_positions, self.cfg.max_target_positions)
@property
def source_dictionary(self):
"""Return the source :class:`~fairseq.data.Dictionary`."""
return self.dictionary
@property
def target_dictionary(self):
"""Return the target :class:`~fairseq.data.Dictionary`."""
return self.dictionary
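A hedged end-to-end sketch of standing the task up the same way the test below does; the data directory is a placeholder and must contain `dict.txt` plus a binarized `train` split:

from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.tasks.span_masked_lm import SpanMaskedLMTask

args = options.parse_args_and_arch(
    options.get_training_parser(),
    [
        "--task", "span_masked_lm",
        "--arch", "bart_base",
        "--noise-density", "0.2",
        "--mean-noise-span-length", "3",
        "/path/to/data-bin",  # placeholder data directory
    ],
)
cfg = convert_namespace_to_omegaconf(args)
task = SpanMaskedLMTask.setup_task(cfg.task)  # loads dict.txt from the data path
task.load_dataset("train")
sample = task.dataset("train")[0]  # {"id": ..., "source": ..., "target": ...}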

tests/test_span_masked_lm.py

@@ -0,0 +1,106 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from tempfile import TemporaryDirectory
from fairseq import options
from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.tasks.span_masked_lm import SpanMaskedLMTask
from tests.utils import build_vocab, make_data
class TestSpanMaskedLM(unittest.TestCase):
def test_masks_token_spans(self):
with TemporaryDirectory() as dirname:
# prep input file
raw_file = os.path.join(dirname, "raw")
data = make_data(out_file=raw_file)
vocab = build_vocab(data)
# binarize
binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
split = "train"
bin_file = os.path.join(dirname, split)
dataset_impl = "mmap"
FileBinarizer.multiprocess_dataset(
input_file=raw_file,
binarizer=binarizer,
dataset_impl=dataset_impl,
vocab_size=len(vocab),
output_prefix=bin_file,
)
# adding sentinel tokens
for i in range(100):
vocab.add_symbol(f"<extra_id_{i}>")
# setup task
train_args = options.parse_args_and_arch(
options.get_training_parser(),
[
"--task",
"span_masked_lm",
"--arch",
"bart_base",
"--seed",
"42",
dirname,
],
)
cfg = convert_namespace_to_omegaconf(train_args)
task = SpanMaskedLMTask(cfg.task, binarizer.dict)
# load datasets
# bin_file is an absolute path, so os.path.join inside
# _load_dataset_split resolves to it directly
original_dataset = task._load_dataset_split(bin_file, 1, False)
task.load_dataset(split)
masked_dataset = task.dataset(split)
iterator = task.get_batch_iterator(
dataset=masked_dataset,
max_tokens=65_536,
max_positions=4_096,
).next_epoch_itr(shuffle=False)
for batch in iterator:
for sample in range(len(batch["id"])):  # batch size, not the number of dict keys
sample_id = batch["id"][sample]
original_tokens = original_dataset[sample_id]
masked_src_tokens = batch["net_input"]["src_tokens"][sample]
masked_src_length = batch["net_input"]["src_lengths"][sample]
masked_tgt_tokens = batch["target"][sample]
original_offset = 0
masked_tgt_offset = 0
extra_id_token = len(vocab) - 1
for masked_src_token in masked_src_tokens[:masked_src_length]:
if masked_src_token == extra_id_token:
assert (
masked_src_token == masked_tgt_tokens[masked_tgt_offset]
)
extra_id_token -= 1
masked_tgt_offset += 1
while (
original_offset < len(original_tokens)
and masked_tgt_tokens[masked_tgt_offset]
!= extra_id_token
):
assert (
original_tokens[original_offset]
== masked_tgt_tokens[masked_tgt_offset]
)
original_offset += 1
masked_tgt_offset += 1
else:
assert original_tokens[original_offset] == masked_src_token
original_offset += 1
if __name__ == "__main__":
unittest.main()