Migrate DummyMaskedLMTask to FairseqTask (#3593)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes # (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding �

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3593

Reviewed By: msbaines

Differential Revision: D28992614

Pulled By: dianaml0

fbshipit-source-id: b2dfcab472a65c41536e78600a0e6b3745dc3a08
Diana Liskovich 2021-06-10 09:42:18 -07:00 committed by Facebook GitHub Bot
parent 2fd9d8a972
commit 50158da3a7
5 changed files with 84 additions and 112 deletions
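
For context on the title: the change swaps the argparse-based `LegacyFairseqTask` pattern (`add_args` + `setup_task`) for a `FairseqDataclass` config registered alongside the task, so command-line options are generated from dataclass fields. A minimal sketch of the target pattern, with illustrative names (`ExampleConfig`/`ExampleTask` are not part of this PR):

```python
from dataclasses import dataclass, field

from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task


@dataclass
class ExampleConfig(FairseqDataclass):
    # each field replaces a parser.add_argument(...) call; help text lives in metadata
    dict_size: int = field(default=49996, metadata={"help": "dummy dictionary size"})


@register_task("example_task", dataclass=ExampleConfig)
class ExampleTask(FairseqTask):
    def __init__(self, cfg: ExampleConfig):
        # the base FairseqTask.setup_task(cfg) builds the task from cfg by default,
        # so no hand-written add_args()/setup_task() is required
        super().__init__(cfg)
```

The dummy_masked_lm diff below applies exactly this pattern with `DummyMaskedLMConfig`.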


@@ -4,4 +4,4 @@
 # LICENSE file in the root directory of this source tree.
 # import models/tasks to register them
-from . import dummy_lm, dummy_masked_lm, dummy_model, dummy_mt  # noqa
+from . import dummy_dataset, dummy_lm, dummy_masked_lm, dummy_model, dummy_mt  # noqa


@@ -0,0 +1,36 @@
+import numpy as np
+from fairseq.data import FairseqDataset
+
+
+class DummyDataset(FairseqDataset):
+    def __init__(self, batch, num_items, item_size):
+        super().__init__()
+        self.batch = batch
+        self.num_items = num_items
+        self.item_size = item_size
+
+    def __getitem__(self, index):
+        return index
+
+    def __len__(self):
+        return self.num_items
+
+    def collater(self, samples):
+        return self.batch
+
+    @property
+    def sizes(self):
+        return np.array([self.item_size] * self.num_items)
+
+    def num_tokens(self, index):
+        return self.item_size
+
+    def size(self, index):
+        return self.item_size
+
+    def ordered_indices(self):
+        return np.arange(self.num_items)
+
+    @property
+    def supports_prefetch(self):
+        return False
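
The new `DummyDataset` exists to keep data loading off the benchmark's critical path: `collater` ignores the sampled indices and always hands back the same pre-built batch. A small behavioral sketch (hypothetical, not part of the PR; the import path is an assumption based on the package the `__init__.py` above belongs to):

```python
# hypothetical usage check; the import path assumes the benchmark package layout
from fairseq.benchmark.dummy_dataset import DummyDataset

batch = {"id": 1, "ntokens": 512}  # any pre-collated batch object
ds = DummyDataset(batch, num_items=100, item_size=512)

assert len(ds) == 100
assert ds.num_tokens(0) == 512
assert ds.collater([0, 1, 2]) is batch  # same batch no matter which indices are drawn
```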


@@ -7,9 +7,9 @@ import logging
 from dataclasses import dataclass, field
 from typing import Optional
-import numpy as np
 import torch
-from fairseq.data import Dictionary, FairseqDataset
+from .dummy_dataset import DummyDataset
+from fairseq.data import Dictionary
 from fairseq.dataclass import FairseqDataclass
 from fairseq.tasks import FairseqTask, register_task
 from omegaconf import II
@@ -33,7 +33,6 @@ class DummyLMConfig(FairseqDataclass):
 @register_task("dummy_lm", dataclass=DummyLMConfig)
 class DummyLMTask(FairseqTask):
     def __init__(self, cfg: DummyLMConfig):
         super().__init__(cfg)
@@ -82,37 +81,3 @@ class DummyLMTask(FairseqTask):
     @property
     def target_dictionary(self):
         return self.dictionary
-
-
-class DummyDataset(FairseqDataset):
-    def __init__(self, batch, num_items, item_size):
-        super().__init__()
-        self.batch = batch
-        self.num_items = num_items
-        self.item_size = item_size
-
-    def __getitem__(self, index):
-        return index
-
-    def __len__(self):
-        return self.num_items
-
-    def collater(self, samples):
-        return self.batch
-
-    @property
-    def sizes(self):
-        return np.array([self.item_size] * self.num_items)
-
-    def num_tokens(self, index):
-        return self.item_size
-
-    def size(self, index):
-        return self.item_size
-
-    def ordered_indices(self):
-        return np.arange(self.num_items)
-
-    @property
-    def supports_prefetch(self):
-        return False


@@ -4,43 +4,53 @@
 # LICENSE file in the root directory of this source tree.
 import logging
+from dataclasses import dataclass, field
+from typing import Optional
-import numpy as np
 import torch
-from fairseq.data import Dictionary, FairseqDataset
-from fairseq.tasks import LegacyFairseqTask, register_task
+from omegaconf import II
+from .dummy_dataset import DummyDataset
+from fairseq.data import Dictionary
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
 logger = logging.getLogger(__name__)
-@register_task("dummy_masked_lm")
-class DummyMaskedLMTask(LegacyFairseqTask):
-    @staticmethod
-    def add_args(parser):
-        """Add task-specific arguments to the parser."""
-        parser.add_argument("--dict-size", default=49995, type=int)
-        parser.add_argument("--dataset-size", default=100000, type=int)
-        parser.add_argument(
-            "--tokens-per-sample",
-            default=512,
-            type=int,
-            help="max number of total tokens over all segments "
-            "per sample for BERT dataset",
-        )
+@dataclass
+class DummyMaskedLMConfig(FairseqDataclass):
+    dict_size: int = 49996
+    dataset_size: int = 100000
+    tokens_per_sample: int = field(
+        default=512,
+        metadata={
+            "help": "max number of total tokens over all"
+            " segments per sample for BERT dataset"
+        },
+    )
+    batch_size: Optional[int] = II("dataset.batch_size")
+    max_tokens: Optional[int] = II("dataset.max_tokens")
+    max_target_positions: int = II("task.tokens_per_sample")
-    def __init__(self, args, dictionary):
-        super().__init__(args)
-        self.dictionary = dictionary
+@register_task("dummy_masked_lm", dataclass=DummyMaskedLMConfig)
+class DummyMaskedLMTask(FairseqTask):
+    def __init__(self, cfg: DummyMaskedLMConfig):
+        super().__init__(cfg)
+        self.dictionary = Dictionary()
+        for i in range(cfg.dict_size):
+            self.dictionary.add_symbol("word{}".format(i))
+        logger.info("dictionary: {} types".format(len(self.dictionary)))
         # add mask token
-        self.mask_idx = dictionary.add_symbol("<mask>")
-        dictionary.pad_to_multiple_(8)  # often faster if divisible by 8
+        self.mask_idx = self.dictionary.add_symbol("<mask>")
+        self.dictionary.pad_to_multiple_(8)  # often faster if divisible by 8
         mask_idx = 0
         pad_idx = 1
-        seq = torch.arange(args.tokens_per_sample) + pad_idx + 1
-        mask = torch.arange(2, args.tokens_per_sample, 7)  # ~15%
+        seq = torch.arange(cfg.tokens_per_sample) + pad_idx + 1
+        mask = torch.arange(2, cfg.tokens_per_sample, 7)  # ~15%
         src = seq.clone()
         src[mask] = mask_idx
         tgt = torch.full_like(seq, pad_idx)
@@ -49,39 +59,30 @@ class DummyMaskedLMTask(LegacyFairseqTask):
         self.dummy_src = src
         self.dummy_tgt = tgt
-    @classmethod
-    def setup_task(cls, args, **kwargs):
-        """Setup the task. """
-        dictionary = Dictionary()
-        for i in range(args.dict_size):
-            dictionary.add_symbol("word{}".format(i))
-        logger.info("dictionary: {} types".format(len(dictionary)))
-        return cls(args, dictionary)
     def load_dataset(self, split, epoch=1, combine=False, **kwargs):
         """Load a given dataset split.
         Args:
             split (str): name of the split (e.g., train, valid, test)
         """
-        if self.args.batch_size is not None:
-            bsz = self.args.batch_size
+        if self.cfg.batch_size is not None:
+            bsz = self.cfg.batch_size
         else:
-            bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
+            bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
         self.datasets[split] = DummyDataset(
             {
                 "id": 1,
                 "net_input": {
                     "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
                     "src_lengths": torch.full(
-                        (bsz,), self.args.tokens_per_sample, dtype=torch.long
+                        (bsz,), self.cfg.tokens_per_sample, dtype=torch.long
                     ),
                 },
                 "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
                 "nsentences": bsz,
-                "ntokens": bsz * self.args.tokens_per_sample,
+                "ntokens": bsz * self.cfg.tokens_per_sample,
             },
-            num_items=self.args.dataset_size,
-            item_size=self.args.tokens_per_sample,
+            num_items=self.cfg.dataset_size,
+            item_size=self.cfg.tokens_per_sample,
         )
     @property
@@ -91,37 +92,3 @@ class DummyMaskedLMTask(LegacyFairseqTask):
     @property
     def target_dictionary(self):
         return self.dictionary
-
-
-class DummyDataset(FairseqDataset):
-    def __init__(self, batch, num_items, item_size):
-        super().__init__()
-        self.batch = batch
-        self.num_items = num_items
-        self.item_size = item_size
-
-    def __getitem__(self, index):
-        return index
-
-    def __len__(self):
-        return self.num_items
-
-    def collater(self, samples):
-        return self.batch
-
-    @property
-    def sizes(self):
-        return np.array([self.item_size] * self.num_items)
-
-    def num_tokens(self, index):
-        return self.item_size
-
-    def size(self, index):
-        return self.item_size
-
-    def ordered_indices(self):
-        return np.arange(self.num_items)
-
-    @property
-    def supports_prefetch(self):
-        return False
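
Note that `batch_size`, `max_tokens`, and `max_target_positions` in the new config are not literal defaults but omegaconf interpolations created with `II`, resolving to the corresponding `dataset.*`/`task.*` values at runtime. A standalone sketch of how `II` resolves (the config keys below are illustrative, not fairseq's full config):

```python
from omegaconf import II, OmegaConf

# II("dataset.batch_size") is shorthand for the interpolation string "${dataset.batch_size}"
cfg = OmegaConf.create(
    {
        "dataset": {"batch_size": 8},
        "task": {"batch_size": II("dataset.batch_size")},
    }
)
assert cfg.task.batch_size == 8  # resolved against dataset.batch_size on access
```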


@@ -103,6 +103,10 @@ class TestValidSubsetsErrors(unittest.TestCase):
         cfg = make_lm_config(task="dummy_lm")
         raise_if_valid_subsets_unintentionally_ignored(cfg)
+    def test_masked_dummy_task(self):
+        cfg = make_lm_config(task="dummy_masked_lm")
+        raise_if_valid_subsets_unintentionally_ignored(cfg)
 class TestCombineValidSubsets(unittest.TestCase):
     def _train(self, extra_flags):