mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-26 17:32:57 +03:00
add span_masked_lm task (#4366)
Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
This commit is contained in:
parent
5d8d0674c1
commit
ba415c99ca
293
fairseq/data/span_mask_tokens_dataset.py
Normal file
293
fairseq/data/span_mask_tokens_dataset.py
Normal file
@ -0,0 +1,293 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from . import Dictionary, FairseqDataset, data_utils
|
||||||
|
|
||||||
|
|
||||||
|
def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    """Collate a list of span-masked samples into a padded mini-batch.

    Args:
        samples (List[dict]): items with keys ``"id"``, ``"source"`` and
            optionally ``"target"``; source and target are token tensors.
        pad_idx (int): index used to pad shorter sequences.
        eos_idx (int): not used by this function (the last token of each
            sample is used when building ``prev_output_tokens``); kept for
            interface compatibility.
        vocab: not used by this function; kept for interface compatibility.
        left_pad_source (bool, optional): pad sources on the left.
        left_pad_target (bool, optional): pad targets on the left.
        input_feeding (bool, optional): must be ``True``; a shifted copy of
            the target is emitted as ``prev_output_tokens``.
        pad_to_length (dict, optional): mapping with ``"source"`` and
            ``"target"`` entries giving the lengths to pad to.

    Returns:
        dict: an empty dict for an empty ``samples`` list, otherwise a batch
        with ``id``, ``ntokens``, ``net_input`` (``src_tokens``,
        ``src_lengths`` and, when targets exist, ``prev_output_tokens``),
        ``target``, ``target_lengths``, ``nsentences`` and ``sort_order``.
        All per-sample entries are ordered by descending source length.
    """
    assert input_feeding
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=None,  # use eos_idx of each sample instead of vocab.eos()
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

    # named ``ids`` to avoid shadowing the ``id`` builtin
    ids = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    ids = ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    target_lengths = None
    if samples[0].get("target", None) is not None:
        target_pad_to = pad_to_length["target"] if pad_to_length is not None else None
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=target_pad_to,
        )
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)

        # True (unpadded) length of every target, in the same order as the
        # sorted batch. Measuring the rows of the padded ``target`` tensor
        # would report the padded width for every sample.
        target_lengths = torch.LongTensor(
            [s["target"].numel() for s in samples]
        ).index_select(0, sort_order)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=target_pad_to,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": ids,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        "target_lengths": target_lengths,
        # number of sentences in the batch, not the length of the first one
        "nsentences": len(samples),
        "sort_order": sort_order,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens

    return batch
|
||||||
|
|
||||||
|
|
||||||
|
class SpanMaskedTokensDataset(FairseqDataset):
    """
    A wrapper around TokenBlockDataset for T5 dataset.

    Each item of the wrapped dataset is split into a ``source`` in which
    every noise span is replaced by a single sentinel token, and a
    ``target`` in which every non-noise span is replaced by a single
    sentinel token. Sentinel ids are taken from the end of the vocabulary
    (``len(vocab) - 1``, ``len(vocab) - 2``, ...).

    Args:
        dataset (~torch.utils.data.Dataset): dataset to wrap
        vocab (~fairseq.data.Dictionary): vocabulary
        noise_density (float): fraction of the tokens to select as noise.
        mean_noise_span_length (float): mean noise span length.
        shuffle (bool, optional): shuffle the elements before batching.
            Default: ``True``
        seed: Seed for random number generator for reproducibility.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        noise_density: float,
        mean_noise_span_length: float,
        shuffle: bool,
        seed: int = 1,
    ):
        self.dataset = dataset
        self.vocab = vocab
        self.seed = seed
        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length
        self.shuffle = shuffle
        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        """Record the current epoch; it is part of the per-item noise seed."""
        self.epoch = epoch

    def __getitem__(self, index):
        """Return ``{"id", "source", "target"}`` for the item at ``index``.

        The noise mask is drawn under a seed derived from
        ``(seed, epoch, index)``.
        """
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            assert item[-1] == self.vocab.eos()

            noise_mask = self.random_spans_noise_mask(len(item))

            # source: noise spans collapse to one sentinel each
            source_sentinel_ids = self.create_sentinel_ids(noise_mask.astype(np.int8))
            source = self.filter_input_ids(item, source_sentinel_ids)

            # target: the complementary (non-noise) spans collapse instead
            target_sentinel_ids = self.create_sentinel_ids(
                (~noise_mask).astype(np.int8)
            )
            target = self.filter_input_ids(item, target_sentinel_ids)

        return {
            "id": index,
            "source": torch.from_numpy(source),
            "target": torch.from_numpy(target),
        }

    def random_spans_noise_mask(self, length):
        """
        This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
        Noise mask consisting of random spans of noise tokens.
        The number of noise tokens and the number of noise spans and non-noise spans
        are determined deterministically as follows:
        num_noise_tokens = round(length * noise_density)
        num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
        Spans alternate between non-noise and noise, beginning with non-noise.
        Subject to the above restrictions, all masks are equally likely.
        Args:
            length: an int32 scalar (length of the incoming token sequence)
        Returns:
            a boolean tensor with shape [length]
        """

        orig_length = length

        num_noise_tokens = int(np.round(length * self.noise_density))
        # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))

        # avoid degeneracy by ensuring positive number of noise spans
        num_noise_spans = max(num_noise_spans, 1)
        num_nonnoise_tokens = length - num_noise_tokens

        # pick the lengths of the noise spans and the non-noise spans
        def _random_segmentation(num_items, num_segments):
            """
            Partition a sequence of items randomly into non-empty segments.
            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                a Tensor with shape [num_segments] containing positive integers that add up to num_items
            """
            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
            np.random.shuffle(mask_indices)
            first_in_segment = np.pad(mask_indices, [[1, 0]])
            segment_id = np.cumsum(first_in_segment)
            # count length of subsegments assuming that list is sorted
            _, segment_length = np.unique(segment_id, return_counts=True)
            return segment_length

        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(
            num_nonnoise_tokens, num_noise_spans
        )

        # interleave as [nonnoise_0, noise_0, nonnoise_1, noise_1, ...]
        interleaved_span_lengths = np.reshape(
            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
            [num_noise_spans * 2],
        )
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((length,), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        # odd-numbered spans are the noise spans
        is_noise = np.equal(span_num % 2, 1)

        return is_noise[:orig_length]

    def create_sentinel_ids(self, mask_indices):
        """
        Sentinel ids creation given the indices that should be masked.
        The start indices of each mask are replaced by the sentinel ids in increasing
        order. Consecutive mask indices to be deleted are replaced with `-1`.
        """
        start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
        # np.roll wraps the last element around to position 0; a masked
        # first position always begins a span regardless of the last one.
        start_indices[..., 0] = mask_indices[..., 0]

        sentinel_ids = np.where(
            start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
        )
        # making sure all sentinel tokens are unique over the example
        sentinel_ids = np.where(sentinel_ids != 0, len(self.vocab) - sentinel_ids, 0)
        # masked positions that are not span starts become -1
        sentinel_ids -= mask_indices - start_indices
        return sentinel_ids

    @staticmethod
    def filter_input_ids(input_ids, sentinel_ids):
        """
        Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
        The result is shorter than `input_ids` by the number of deleted (negative) positions.
        """
        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)

        # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
        # masked tokens coming after sentinel tokens and should be removed
        return input_ids_full[input_ids_full >= 0]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples, pad_to_length=None):
        """
        Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): forwarded to :func:`collate`
        Returns:
            dict: a mini-batch of data
        """
        return collate(
            samples,
            self.vocab.pad(),
            self.vocab.eos(),
            self.vocab,
            pad_to_length=pad_to_length,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.dataset.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.dataset.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        # stable sort keeps the (possibly shuffled) order among equal sizes
        return indices[np.argsort(self.dataset.sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        """Forward the prefetch request to the wrapped dataset."""
        # This class only holds ``self.dataset``; it has no ``src``/``tgt``.
        self.dataset.prefetch(indices)

    @property
    def supports_prefetch(self):
        """Whether the wrapped dataset reports prefetch support."""
        return getattr(self.dataset, "supports_prefetch", False)
|
243
fairseq/tasks/span_masked_lm.py
Normal file
243
fairseq/tasks/span_masked_lm.py
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from omegaconf import II, MISSING
|
||||||
|
|
||||||
|
from fairseq import utils
|
||||||
|
from fairseq.data import (
|
||||||
|
AppendTokenDataset,
|
||||||
|
Dictionary,
|
||||||
|
IdDataset,
|
||||||
|
NestedDictionaryDataset,
|
||||||
|
NumelDataset,
|
||||||
|
PadDataset,
|
||||||
|
PrependTokenDataset,
|
||||||
|
StripTokenDataset,
|
||||||
|
TokenBlockDataset,
|
||||||
|
data_utils,
|
||||||
|
)
|
||||||
|
from fairseq.data.shorten_dataset import maybe_shorten_dataset
|
||||||
|
from fairseq.data.span_mask_tokens_dataset import SpanMaskedTokensDataset
|
||||||
|
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
||||||
|
from fairseq.tasks import FairseqTask, register_task
|
||||||
|
|
||||||
|
from ..data.indexed_dataset import get_available_dataset_impl
|
||||||
|
|
||||||
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Allowed values for the ``sample_break_mode`` config field.
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
|
||||||
|
# Allowed values for the ``shorten_method`` config field.
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SpanMaskedLMConfig(FairseqDataclass):
    """Configuration for the ``span_masked_lm`` task."""

    shuffle: bool = field(
        default=False,
        metadata={"help": "shuffle the examples before batching"},
    )
    noise_density: float = field(
        default=0.15,
        metadata={"help": "What fraction of the tokens to select as noise"},
    )
    # float default so the value matches the declared field type
    mean_noise_span_length: float = field(
        default=3.0,
        metadata={"help": "Mean noise span length, must be >= 1"},
    )
    data: str = field(
        default=MISSING,
        metadata={
            "help": "colon separated path to data directories list, "
            "will be iterated upon during epochs in round-robin manner"
        },
    )
    sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
        default="none",
        metadata={
            "help": 'If omitted or "none", fills each sample with tokens-per-sample '
            'tokens. If set to "complete", splits samples only at the end '
            "of sentence, but may include multiple sentences per sample. "
            '"complete_doc" is similar but respects doc boundaries. '
            'If set to "eos", includes only one sentence per sample.'
        },
    )
    tokens_per_sample: int = field(
        default=1024,
        metadata={"help": "max number of tokens per sample for LM dataset"},
    )
    shorten_method: SHORTEN_METHOD_CHOICES = field(
        default="none",
        metadata={
            "help": "if not none, shorten sequences that exceed --tokens-per-sample"
        },
    )
    shorten_data_split_list: str = field(
        default="",
        metadata={
            "help": "comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)'
        },
    )
    # the two fields below are interpolated from other config groups
    seed: int = II("common.seed")
    dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
        "dataset.dataset_impl"
    )
    max_source_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the source sequence"}
    )
    max_target_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the target sequence"}
    )
    # NOTE(review): not read by the task code visible in this file, and the
    # help text mentions data2vec -- confirm whether this field is needed.
    include_target_tokens: bool = field(
        default=False,
        metadata={
            "help": "include target tokens in model input. this is used for data2vec"
        },
    )
|
||||||
|
|
||||||
|
|
||||||
|
@register_task("span_masked_lm", dataclass=SpanMaskedLMConfig)
class SpanMaskedLMTask(FairseqTask):
    """
    Task for span-masked language modeling (T5-style denoising): spans of the
    input are replaced by sentinel tokens and the model predicts them.
    """

    cfg: SpanMaskedLMConfig

    def __init__(self, cfg, dictionary):
        super().__init__(cfg)
        self.dictionary = dictionary

    @classmethod
    def setup_task(cls, cfg: SpanMaskedLMConfig, **kwargs):
        """Load ``dict.txt`` from the first data directory and build the task."""
        data_dirs = utils.split_paths(cfg.data)
        assert len(data_dirs) > 0
        dict_file = os.path.join(data_dirs[0], "dict.txt")
        dictionary = Dictionary.load(dict_file)
        logger.info("dictionary: {} types".format(len(dictionary)))
        if not hasattr(cfg, "shuffle"):
            cfg.shuffle = False
        return cls(cfg, dictionary)

    def _load_dataset_split(self, split, epoch, combine):
        """Load ``split`` as fixed-size token blocks wrapped in bos ... eos.

        The data directory is picked round-robin by ``epoch``.

        Raises:
            FileNotFoundError: if no indexed dataset exists for the split.
        """
        data_dirs = utils.split_paths(self.cfg.data)
        assert len(data_dirs) > 0
        chosen_dir = data_dirs[(epoch - 1) % len(data_dirs)]
        split_path = os.path.join(chosen_dir, split)

        indexed = data_utils.load_indexed_dataset(
            split_path,
            self.dictionary,
            self.cfg.dataset_impl,
            combine=combine,
        )
        if indexed is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        # drop eos tokens stored in the raw data
        stripped = StripTokenDataset(indexed, self.dictionary.eos())

        shortened = maybe_shorten_dataset(
            stripped,
            split,
            self.cfg.shorten_data_split_list,
            self.cfg.shorten_method,
            self.cfg.tokens_per_sample,
            self.cfg.seed,
        )

        # contiguous token blocks; two positions are reserved for <s> and </s>
        blocks = TokenBlockDataset(
            shortened,
            shortened.sizes,
            self.cfg.tokens_per_sample - 2,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.cfg.sample_break_mode,
            document_sep_len=0,
        )
        logger.info("loaded {} blocks from: {}".format(len(blocks), split_path))

        # wrap each block as <s> ... </s>
        with_bos = PrependTokenDataset(blocks, self.source_dictionary.bos())
        return AppendTokenDataset(with_bos, self.source_dictionary.eos())

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split and register it under ``self.datasets``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        blocks = self._load_dataset_split(split, epoch, combine)

        masked = SpanMaskedTokensDataset(
            blocks,
            self.dictionary,
            noise_density=self.cfg.noise_density,
            mean_noise_span_length=self.cfg.mean_noise_span_length,
            shuffle=self.cfg.shuffle,
            seed=self.cfg.seed,
        )
        self.datasets[split] = masked
        logger.info(
            "Split: {0}, Loaded {1} samples of span_masked_tokens_dataset".format(
                split,
                len(self.datasets[split]),
            )
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        """
        Build a dataset for inference. The input is assumed to start with a
        bos symbol (`<s>`) and to end with an eos symbol (`</s>`).
        """
        pad_idx = self.source_dictionary.pad()
        eos_idx = self.source_dictionary.eos()

        blocks = TokenBlockDataset(
            src_tokens,
            src_lengths,
            block_size=self.cfg.tokens_per_sample - 2,  # for <s> and </s>
            pad=pad_idx,
            eos=eos_idx,
            break_mode=self.cfg.sample_break_mode,
            document_sep_len=0,
        )
        # decoder input: strip eos tokens from the block, then prepend eos
        shifted = PrependTokenDataset(StripTokenDataset(blocks, eos_idx), eos_idx)
        padded_src = PadDataset(blocks, pad_idx=pad_idx, left_pad=False)

        net_input = {
            "src_tokens": padded_src,
            "src_lengths": NumelDataset(padded_src, reduce=False),
            "prev_output_tokens": PadDataset(
                shifted, pad_idx=pad_idx, left_pad=False
            ),
        }
        return NestedDictionaryDataset(
            {
                "id": IdDataset(),
                "net_input": net_input,
                "target": padded_src,
            },
            sizes=[np.array(src_lengths)],
        )

    def max_positions(self):
        """Return the (source, target) max sentence lengths allowed by the task."""
        return (self.cfg.max_source_positions, self.cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.dictionary

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.dictionary
|
106
tests/tasks/test_span_masked_lm.py
Normal file
106
tests/tasks/test_span_masked_lm.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
|
||||||
|
from fairseq import options
|
||||||
|
from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
|
||||||
|
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||||
|
from fairseq.tasks.span_masked_lm import SpanMaskedLMTask
|
||||||
|
from tests.utils import build_vocab, make_data
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpanMaskedLM(unittest.TestCase):
    """Checks that span masking yields consistent source/target pairs."""

    def test_masks_token_spans(self):
        """Rebuild every original sample from its masked source and target.

        Source tokens must match the original sequence except at sentinel
        positions; at each sentinel the target must carry the same sentinel
        followed by the original tokens the sentinel replaced.
        """
        with TemporaryDirectory() as dirname:

            # prep input file
            raw_file = os.path.join(dirname, "raw")
            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)

            # binarize
            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
            split = "train"
            bin_file = os.path.join(dirname, split)
            dataset_impl = "mmap"

            FileBinarizer.multiprocess_dataset(
                input_file=raw_file,
                binarizer=binarizer,
                dataset_impl=dataset_impl,
                vocab_size=len(vocab),
                output_prefix=bin_file,
            )

            # adding sentinel tokens
            for i in range(100):
                vocab.add_symbol(f"<extra_id_{i}>")

            # setup task
            train_args = options.parse_args_and_arch(
                options.get_training_parser(),
                [
                    "--task",
                    "span_masked_lm",
                    "--arch",
                    "bart_base",
                    "--seed",
                    "42",
                    dirname,
                ],
            )
            cfg = convert_namespace_to_omegaconf(train_args)
            task = SpanMaskedLMTask(cfg.task, binarizer.dict)

            # load datasets
            original_dataset = task._load_dataset_split(bin_file, 1, False)
            task.load_dataset(split)
            masked_dataset = task.dataset(split)

            iterator = task.get_batch_iterator(
                dataset=masked_dataset,
                max_tokens=65_536,
                max_positions=4_096,
            ).next_epoch_itr(shuffle=False)
            for batch in iterator:
                # ``batch`` is a dict, so ``len(batch)`` would count its keys;
                # iterate over the samples actually present in the batch.
                for sample in range(batch["id"].size(0)):
                    sample_id = batch["id"][sample]
                    original_tokens = original_dataset[sample_id]
                    masked_src_tokens = batch["net_input"]["src_tokens"][sample]
                    masked_src_length = batch["net_input"]["src_lengths"][sample]
                    masked_tgt_tokens = batch["target"][sample]

                    original_offset = 0
                    masked_tgt_offset = 0
                    # sentinels are handed out from the end of the vocabulary
                    extra_id_token = len(vocab) - 1
                    for masked_src_token in masked_src_tokens[:masked_src_length]:
                        if masked_src_token == extra_id_token:
                            assert (
                                masked_src_token == masked_tgt_tokens[masked_tgt_offset]
                            )
                            extra_id_token -= 1
                            masked_tgt_offset += 1
                            # consume the masked span from the target until
                            # the next sentinel (or the end of the sample)
                            while (
                                original_offset < len(original_tokens)
                                and masked_tgt_offset < len(masked_tgt_tokens)
                                and masked_tgt_tokens[masked_tgt_offset]
                                != extra_id_token
                            ):
                                assert (
                                    original_tokens[original_offset]
                                    == masked_tgt_tokens[masked_tgt_offset]
                                )
                                original_offset += 1
                                masked_tgt_offset += 1
                        else:
                            assert original_tokens[original_offset] == masked_src_token
                            original_offset += 1


if __name__ == "__main__":
    unittest.main()
|
Loading…
Reference in New Issue
Block a user