Mirror of https://github.com/facebookresearch/fairseq.git (synced 2024-07-14 18:50:22 +03:00)
add span_masked_lm task (#4366)
Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
This commit is contained in:
parent 5d8d0674c1
commit ba415c99ca
fairseq/data/span_mask_tokens_dataset.py (new file, 293 lines)
@@ -0,0 +1,293 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import Dictionary, FairseqDataset, data_utils


def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    assert input_feeding
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=None,  # use eos_idx of each sample instead of vocab.eos()
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": id,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        "target_lengths": torch.LongTensor([len(t) for t in target]),
        "nsentences": samples[0]["source"].size(0),
        "sort_order": sort_order,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens

    return batch


class SpanMaskedTokensDataset(FairseqDataset):
    """
    A wrapper around TokenBlockDataset for T5 dataset.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to wrap
        vocab (~fairseq.data.Dictionary): vocabulary
        noise_density (float): fraction of the tokens to select as noise.
        mean_noise_span_length (float): mean noise span length.
        shuffle (bool, optional): shuffle the elements before batching.
            Default: ``True``
        seed: Seed for random number generator for reproducibility.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        noise_density: float,
        mean_noise_span_length: float,
        shuffle: bool,
        seed: int = 1,
    ):
        self.dataset = dataset
        self.vocab = vocab
        self.seed = seed
        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length
        self.shuffle = shuffle
        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        self.epoch = epoch

    def __getitem__(self, index):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            assert item[-1] == self.vocab.eos()

            noise_mask = self.random_spans_noise_mask(len(item))

            source_sentinel_ids = self.create_sentinel_ids(noise_mask.astype(np.int8))
            source = self.filter_input_ids(item, source_sentinel_ids)

            target_sentinel_ids = self.create_sentinel_ids(
                (~noise_mask).astype(np.int8)
            )
            target = self.filter_input_ids(item, target_sentinel_ids)

        return {
            "id": index,
            "source": torch.from_numpy(source),
            "target": torch.from_numpy(target),
        }

    def random_spans_noise_mask(self, length):

        """
        This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
        Noise mask consisting of random spans of noise tokens.
        The number of noise tokens and the number of noise spans and non-noise spans
        are determined deterministically as follows:
        num_noise_tokens = round(length * noise_density)
        num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
        Spans alternate between non-noise and noise, beginning with non-noise.
        Subject to the above restrictions, all masks are equally likely.
        Args:
            length: an int32 scalar (length of the incoming token sequence)
        Returns:
            a boolean tensor with shape [length]
        """

        orig_length = length

        num_noise_tokens = int(np.round(length * self.noise_density))
        # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))

        # avoid degeneracy by ensuring positive number of noise spans
        num_noise_spans = max(num_noise_spans, 1)
        num_nonnoise_tokens = length - num_noise_tokens

        # pick the lengths of the noise spans and the non-noise spans
        def _random_segmentation(num_items, num_segments):
            """
            Partition a sequence of items randomly into non-empty segments.
            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                a Tensor with shape [num_segments] containing positive integers that add up to num_items
            """
            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
            np.random.shuffle(mask_indices)
            first_in_segment = np.pad(mask_indices, [[1, 0]])
            segment_id = np.cumsum(first_in_segment)
            # count length of subsegments assuming that list is sorted
            _, segment_length = np.unique(segment_id, return_counts=True)
            return segment_length

        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(
            num_nonnoise_tokens, num_noise_spans
        )

        interleaved_span_lengths = np.reshape(
            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
            [num_noise_spans * 2],
        )
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((length,), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        is_noise = np.equal(span_num % 2, 1)

        return is_noise[:orig_length]

    def create_sentinel_ids(self, mask_indices):
        """
        Sentinel ids creation given the indices that should be masked.
        The start indices of each mask are replaced by the sentinel ids in increasing
        order. Consecutive mask indices to be deleted are replaced with `-1`.
        """
        start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices

        sentinel_ids = np.where(
            start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
        )
        # making sure all sentinel tokens are unique over the example
        sentinel_ids = np.where(sentinel_ids != 0, len(self.vocab) - sentinel_ids, 0)
        sentinel_ids -= mask_indices - start_indices
        return sentinel_ids

    @staticmethod
    def filter_input_ids(input_ids, sentinel_ids):
        """
        Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
        This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
        """
        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)

        # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
        # masked tokens coming after sentinel tokens and should be removed
        return input_ids_full[input_ids_full >= 0]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples, pad_to_length=None):
        """
        Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch of data
        """
        return collate(
            samples,
            self.vocab.pad(),
            self.vocab.eos(),
            self.vocab,
            pad_to_length=pad_to_length,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.dataset.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.dataset.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        return indices[np.argsort(self.dataset.sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        self.src.prefetch(indices)
        self.tgt.prefetch(indices)

    @property
    def supports_prefetch(self):
        return (
            hasattr(self.src, "supports_prefetch")
            and self.src.supports_prefetch
            and hasattr(self.tgt, "supports_prefetch")
            and self.tgt.supports_prefetch
        )
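Editor's note: to make the sentinel arithmetic above concrete, here is a minimal standalone sketch. It uses plain numpy, toy token ids, and a hand-picked noise mask (none of this is part of the commit) and reproduces what create_sentinel_ids and filter_input_ids do to a single sequence:

import numpy as np

# With the task defaults (noise_density=0.15, mean_noise_span_length=3), a
# 512-token block gets round(512 * 0.15) = 77 noise tokens spread over
# round(77 / 3) = 26 spans. Below, a hand-picked mask on a 6-token toy sequence.
vocab_size = 10                                           # toy vocab; the highest ids act as sentinels
item = np.array([4, 5, 6, 7, 8, 9])                       # original token sequence
noise_mask = np.array([0, 1, 1, 0, 0, 1], dtype=np.int8)  # positions chosen as noise


def create_sentinel_ids(mask_indices, vocab_size):
    # the first position of each masked span gets a fresh sentinel id (vocab_size - k);
    # the remaining positions of the span are marked negative so they can be dropped
    start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
    sentinel_ids = np.where(
        start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
    )
    sentinel_ids = np.where(sentinel_ids != 0, vocab_size - sentinel_ids, 0)
    sentinel_ids -= mask_indices - start_indices
    return sentinel_ids


def filter_input_ids(input_ids, sentinel_ids):
    # replace each span start with its sentinel and drop the rest of the span
    full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
    return full[full >= 0]


source = filter_input_ids(item, create_sentinel_ids(noise_mask, vocab_size))
target = filter_input_ids(item, create_sentinel_ids(1 - noise_mask, vocab_size))
print(source)  # [4 9 7 8 8]: spans (5, 6) and (9,) replaced by sentinels 9 and 8
print(target)  # [9 5 6 8 9]: each sentinel followed by the tokens it masked out

The source keeps the unmasked tokens plus one fresh sentinel per masked span; the target is the mirror image, pairing each sentinel with the tokens that were dropped from the source. That pairing is the T5-style seq2seq objective the task below trains on.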
fairseq/tasks/span_masked_lm.py (new file, 243 lines)
@@ -0,0 +1,243 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
from omegaconf import II, MISSING

from fairseq import utils
from fairseq.data import (
    AppendTokenDataset,
    Dictionary,
    IdDataset,
    NestedDictionaryDataset,
    NumelDataset,
    PadDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TokenBlockDataset,
    data_utils,
)
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.data.span_mask_tokens_dataset import SpanMaskedTokensDataset
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task

from ..data.indexed_dataset import get_available_dataset_impl

logger = logging.getLogger(__name__)

SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])


@dataclass
class SpanMaskedLMConfig(FairseqDataclass):
    shuffle: bool = field(
        default=False,
    )
    noise_density: float = field(
        default=0.15,
        metadata={"help": "What fraction of the tokens to select as noise"},
    )
    mean_noise_span_length: float = field(
        default=3,
        metadata={"help": "Mean noise span length, must be >= 1"},
    )
    data: str = field(
        default=MISSING,
        metadata={
            "help": "colon separated path to data directories list, "
            "will be iterated upon during epochs in round-robin manner"
        },
    )
    sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
        default="none",
        metadata={
            "help": 'If omitted or "none", fills each sample with tokens-per-sample '
            'tokens. If set to "complete", splits samples only at the end '
            "of sentence, but may include multiple sentences per sample. "
            '"complete_doc" is similar but respects doc boundaries. '
            'If set to "eos", includes only one sentence per sample.'
        },
    )
    tokens_per_sample: int = field(
        default=1024,
        metadata={"help": "max number of tokens per sample for LM dataset"},
    )
    shorten_method: SHORTEN_METHOD_CHOICES = field(
        default="none",
        metadata={
            "help": "if not none, shorten sequences that exceed --tokens-per-sample"
        },
    )
    shorten_data_split_list: str = field(
        default="",
        metadata={
            "help": "comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)'
        },
    )
    seed: int = II("common.seed")
    dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
        "dataset.dataset_impl"
    )
    max_source_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the source sequence"}
    )
    max_target_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the target sequence"}
    )
    include_target_tokens: bool = field(
        default=False,
        metadata={
            "help": "include target tokens in model input. this is used for data2vec"
        },
    )


@register_task("span_masked_lm", dataclass=SpanMaskedLMConfig)
class SpanMaskedLMTask(FairseqTask):
    """
    Span masked language modeling task. (ie. T5)
    """

    cfg: SpanMaskedLMConfig

    def __init__(self, cfg, dictionary):
        super().__init__(cfg)
        self.dictionary = dictionary

    @classmethod
    def setup_task(cls, cfg: SpanMaskedLMConfig, **kwargs):
        """Setup the task."""
        paths = utils.split_paths(cfg.data)
        assert len(paths) > 0
        dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
        logger.info("dictionary: {} types".format(len(dictionary)))
        if not hasattr(cfg, "shuffle"):
            cfg.shuffle = False
        return cls(cfg, dictionary)

    def _load_dataset_split(self, split, epoch, combine):
        paths = utils.split_paths(self.cfg.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.dictionary,
            self.cfg.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        dataset = StripTokenDataset(dataset, self.dictionary.eos())

        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.cfg.shorten_data_split_list,
            self.cfg.shorten_method,
            self.cfg.tokens_per_sample,
            self.cfg.seed,
        )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.cfg.tokens_per_sample - 2,  # one less for <s> and one for </s>
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.cfg.sample_break_mode,
            document_sep_len=0,
        )
        logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
        dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())
        return dataset

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        dataset = self._load_dataset_split(split, epoch, combine)

        self.datasets[split] = SpanMaskedTokensDataset(
            dataset,
            self.dictionary,
            noise_density=self.cfg.noise_density,
            mean_noise_span_length=self.cfg.mean_noise_span_length,
            shuffle=self.cfg.shuffle,
            seed=self.cfg.seed,
        )
        logger.info(
            "Split: {0}, Loaded {1} samples of span_masked_tokens_dataset".format(
                split,
                len(self.datasets[split]),
            )
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        """
        Generate batches for inference. We assume that the input begins with a
        bos symbol (`<s>`) and ends with an eos symbol (`</s>`).
        """
        pad = self.source_dictionary.pad()
        eos = self.source_dictionary.eos()
        src_dataset = TokenBlockDataset(
            src_tokens,
            src_lengths,
            block_size=self.cfg.tokens_per_sample - 2,  # for <s> and </s>
            pad=pad,
            eos=eos,
            break_mode=self.cfg.sample_break_mode,
            document_sep_len=0,
        )
        prev_output_tokens = PrependTokenDataset(
            StripTokenDataset(src_dataset, eos), eos
        )
        src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False)
        return NestedDictionaryDataset(
            {
                "id": IdDataset(),
                "net_input": {
                    "src_tokens": src_dataset,
                    "src_lengths": NumelDataset(src_dataset, reduce=False),
                    "prev_output_tokens": PadDataset(
                        prev_output_tokens, pad_idx=pad, left_pad=False
                    ),
                },
                "target": src_dataset,
            },
            sizes=[np.array(src_lengths)],
        )

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.cfg.max_source_positions, self.cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.dictionary

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.dictionary
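Editor's note: the task is normally driven through fairseq-train with --task span_masked_lm, but for orientation here is a minimal sketch of wiring it up programmatically. The data_bin path, the explicit seed, and dataset_impl=None are assumptions standing in for values the training CLI would normally supply or interpolate from the global config; they are not part of the commit:

from fairseq.tasks.span_masked_lm import SpanMaskedLMConfig, SpanMaskedLMTask

cfg = SpanMaskedLMConfig(
    data="data_bin",       # assumed: a fairseq-preprocessed dir with dict.txt and the train split
    noise_density=0.15,
    mean_noise_span_length=3.0,
    tokens_per_sample=512,
    shuffle=True,
    seed=42,               # override the II("common.seed") interpolation for standalone use
    dataset_impl=None,     # let load_indexed_dataset infer the on-disk format
)
task = SpanMaskedLMTask.setup_task(cfg)               # loads data_bin/dict.txt
for i in range(100):                                  # sentinel symbols, as in the unit test below
    task.dictionary.add_symbol(f"<extra_id_{i}>")
task.load_dataset("train")                            # wraps token blocks in SpanMaskedTokensDataset
sample = task.dataset("train")[0]                     # {"id", "source", "target"} tensors

Adding the <extra_id_*> symbols mirrors the unit test below: create_sentinel_ids draws sentinel ids from the top of the dictionary (len(vocab) - 1 downwards), so the extra symbols reserve that range.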
tests/tasks/test_span_masked_lm.py (new file, 106 lines)
@@ -0,0 +1,106 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest
from tempfile import TemporaryDirectory

from fairseq import options
from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.tasks.span_masked_lm import SpanMaskedLMTask
from tests.utils import build_vocab, make_data


class TestSpanMaskedLM(unittest.TestCase):
    def test_masks_token_spans(self):
        with TemporaryDirectory() as dirname:

            # prep input file
            raw_file = os.path.join(dirname, "raw")
            data = make_data(out_file=raw_file)
            vocab = build_vocab(data)

            # binarize
            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
            split = "train"
            bin_file = os.path.join(dirname, split)
            dataset_impl = "mmap"

            FileBinarizer.multiprocess_dataset(
                input_file=raw_file,
                binarizer=binarizer,
                dataset_impl=dataset_impl,
                vocab_size=len(vocab),
                output_prefix=bin_file,
            )

            # adding sentinel tokens
            for i in range(100):
                vocab.add_symbol(f"<extra_id_{i}>")

            # setup task
            train_args = options.parse_args_and_arch(
                options.get_training_parser(),
                [
                    "--task",
                    "span_masked_lm",
                    "--arch",
                    "bart_base",
                    "--seed",
                    "42",
                    dirname,
                ],
            )
            cfg = convert_namespace_to_omegaconf(train_args)
            task = SpanMaskedLMTask(cfg.task, binarizer.dict)

            # load datasets
            original_dataset = task._load_dataset_split(bin_file, 1, False)
            task.load_dataset(split)
            masked_dataset = task.dataset(split)

            iterator = task.get_batch_iterator(
                dataset=masked_dataset,
                max_tokens=65_536,
                max_positions=4_096,
            ).next_epoch_itr(shuffle=False)
            num_tokens = len(vocab)
            for batch in iterator:
                for sample in range(len(batch)):
                    sample_id = batch["id"][sample]
                    original_tokens = original_dataset[sample_id]
                    masked_src_tokens = batch["net_input"]["src_tokens"][sample]
                    masked_src_length = batch["net_input"]["src_lengths"][sample]
                    masked_tgt_tokens = batch["target"][sample]

                    original_offset = 0
                    masked_tgt_offset = 0
                    extra_id_token = len(vocab) - 1
                    for masked_src_token in masked_src_tokens[:masked_src_length]:
                        if masked_src_token == extra_id_token:
                            assert (
                                masked_src_token == masked_tgt_tokens[masked_tgt_offset]
                            )
                            extra_id_token -= 1
                            masked_tgt_offset += 1
                            while (
                                original_offset < len(original_tokens)
                                and masked_tgt_tokens[masked_tgt_offset]
                                != extra_id_token
                            ):
                                assert (
                                    original_tokens[original_offset]
                                    == masked_tgt_tokens[masked_tgt_offset]
                                )
                                original_offset += 1
                                masked_tgt_offset += 1
                        else:
                            assert original_tokens[original_offset] == masked_src_token
                            original_offset += 1


if __name__ == "__main__":
    unittest.main()
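Editor's note: the test exercises the full binarization pipeline; to see the dataset end to end without that machinery, here is a hypothetical toy sketch. ToyBlockDataset and the toy vocabulary are illustrative stand-ins (not part of the commit) for the TokenBlockDataset pipeline that _load_dataset_split builds:

import numpy as np
import torch

from fairseq.data import Dictionary
from fairseq.data.span_mask_tokens_dataset import SpanMaskedTokensDataset


class ToyBlockDataset(torch.utils.data.Dataset):
    # minimal stand-in for TokenBlockDataset: eos-terminated items plus a .sizes array
    def __init__(self, items):
        self.items = items
        self.sizes = np.array([len(t) for t in items])

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]


vocab = Dictionary()
for word in ["the", "cat", "sat", "on", "mat", "dog"]:
    vocab.add_symbol(word)
for i in range(10):  # T5-style sentinel symbols, as in the unit test above
    vocab.add_symbol(f"<extra_id_{i}>")

ids = [vocab.index(w) for w in ["the", "cat", "sat", "on", "the", "mat"]]
items = [
    torch.tensor(ids + [vocab.eos()], dtype=torch.long),
    torch.tensor(ids[::-1] + [vocab.eos()], dtype=torch.long),
]

dataset = SpanMaskedTokensDataset(
    ToyBlockDataset(items),
    vocab,
    noise_density=0.3,
    mean_noise_span_length=2.0,
    shuffle=False,
    seed=42,
)
sample = dataset[0]  # {"id", "source", "target"}: sentinels replace the masked spans
batch = dataset.collater([dataset[i] for i in range(len(dataset))])
print(sample["source"], sample["target"])
print(batch["net_input"]["src_tokens"].shape, batch["target"].shape)

collater then produces the usual fairseq batch: padded src_tokens and src_lengths under net_input, the sentinel-plus-masked-tokens target, and prev_output_tokens for input feeding.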