CPLTask for training with continuous pseudo labeling

Summary:
CPLTaskImpl provides an implementation that augments existing tasks to take an additional ema_model input in their train_step and valid_step for continuous pseudo-labeling (CPL) during training. It passes this ema_model on to the criterion.
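For illustration, a minimal sketch of the wrapping described above, assuming fairseq's standard train_step/valid_step signatures; the set_ema_model hook on the criterion is a hypothetical name, not necessarily the actual interface:

class CPLTaskImpl:
    def __init__(self, base_task):
        self.base_task = base_task  # the existing task being augmented for CPL

    def train_step(self, sample, model, criterion, optimizer, update_num,
                   ema_model=None, ignore_grad=False):
        # Hand the EMA model to the criterion so it can generate pseudo
        # labels for unsupervised samples (hypothetical hook name).
        criterion.set_ema_model(ema_model)
        return self.base_task.train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad
        )

    def valid_step(self, sample, model, criterion, ema_model=None):
        criterion.set_ema_model(ema_model)  # hypothetical hook name
        return self.base_task.valid_step(sample, model, criterion)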

See the Kaizen semi-supervised training paper for more details: https://arxiv.org/abs/2106.07759.

This implementation also supports CPLDataset, which enables using unsupervised data only for `cpl_finetune_epoch > epoch >= cpl_start_epoch`. CPLDataset is like MultiCorpusDataset but ignores the unsupervised datasets while sampling outside this window.
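As a rough sketch of that epoch window (the helper below is illustrative only; the names cpl_start_epoch and cpl_finetune_epoch come from the description above):

def use_unsupervised_data(epoch: int, cpl_start_epoch: int, cpl_finetune_epoch: int) -> bool:
    # Unsupervised (pseudo-labeled) data participates only inside the CPL
    # window; outside it, CPLDataset samples from supervised data alone.
    return cpl_start_epoch <= epoch < cpl_finetune_epoch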

Another addition in this diff is to skip a dataset in MultiCorpusDataset if its sampling probability is 0.
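For example, a hypothetical usage sketch based on the constructor shown in the diff below (dataset_a and dataset_b stand in for real FairseqDataset instances):

from collections import OrderedDict
from fairseq.data.multi_corpus_dataset import MultiCorpusDataset

dataset = MultiCorpusDataset(
    datasets=OrderedDict([("supervised", dataset_a), ("unsupervised", dataset_b)]),
    distribution=[1.0, 0.0],  # the zero-probability dataset is skipped
    seed=1,
)
# ordered_indices() now draws only from "supervised"; the skipped dataset
# contributes nothing to total_num_instances.
indices = dataset.ordered_indices()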

Reviewed By: cruvadom

Differential Revision: D30701536

fbshipit-source-id: 1d840eacfd538ed7aed3baaefc8b254390642b45
Vimal Manohar, 2021-10-14 22:07:55 -07:00, committed by Facebook GitHub Bot
parent f670d9f1f2, commit 1ef3d6a1a2
2 changed files with 20 additions and 11 deletions

fairseq/data/multi_corpus_dataset.py

@@ -6,7 +6,7 @@
 import logging
 import time
 from collections import OrderedDict
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import numpy as np
 
 from fairseq.data import data_utils
@@ -18,13 +18,15 @@ logger = logging.getLogger(__name__)
 class MultiCorpusDataset(FairseqDataset):
     """
-    Stores multiple instances of FairseqDataset together. Requires each instance
+    Stores multiple instances of FairseqDataset together.
+    Unless batch_sample=True, requires each instance
     to be the same dataset, as the collate method needs to work on batches with
     samples from each dataset.
 
     Allows specifying a distribution over the datasets to use. Note that unlike
     MultiCorpusSampledDataset, this distribution allows sampling for each item,
-    rather than on a batch level.
+    rather than on a batch level. Note that datasets with sampling probability
+    of 0 will be skipped.
 
     Each time ordered_indices() is called, a new sample is generated with
     the specified distribution.
@@ -45,7 +47,7 @@ class MultiCorpusDataset(FairseqDataset):
         seed: int,
         sort_indices: bool = False,
         batch_sample: bool = False,
-        distributed_rank=None,
+        distributed_rank: Optional[int] = None,
     ):
         super().__init__()
         assert isinstance(datasets, OrderedDict)
@@ -62,14 +64,18 @@ class MultiCorpusDataset(FairseqDataset):
         self.dataset_list = list(datasets.values())
         self.total_num_instances = 0
 
-        first_dataset = list(self.datasets.values())[0]
+        first_dataset = self.dataset_list[0]
+        self.num_instances_per_dataset = []
         self.dataset_offsets = []
-        for dataset in datasets.values():
+        for i, dataset in enumerate(self.dataset_list):
             assert isinstance(dataset, FairseqDataset)
             assert type(dataset) is type(first_dataset)
+            self.num_instances_per_dataset.append(
+                0 if self.distribution[i] == 0 else len(dataset)
+            )
             self.dataset_offsets.append(self.total_num_instances)
-            self.total_num_instances += len(dataset)
+            self.total_num_instances += self.num_instances_per_dataset[i]
 
     def ordered_indices(self):
         start = time.time()
@@ -80,6 +86,9 @@ class MultiCorpusDataset(FairseqDataset):
             # For each dataset i, sample self.distribution[i] * self.total_num_instances
             for i, key in enumerate(self.datasets):
+                if self.distribution[i] == 0:
+                    # skip dataset if sampling probability is 0
+                    continue
                 if i < len(self.datasets) - 1:
                     num_instances = int(self.distribution[i] * self.total_num_instances)
@@ -136,10 +145,10 @@ class MultiCorpusDataset(FairseqDataset):
         maps to index 1 of B.
         """
         counter = 0
-        for key, dataset in self.datasets.items():
-            if index < counter + len(dataset):
+        for num_instances, key in zip(self.num_instances_per_dataset, self.datasets):
+            if index < counter + num_instances:
                 return index - counter, key
-            counter += len(dataset)
+            counter += num_instances
         raise ValueError(
             "Invalid index: {}, max: {}".format(index, self.total_num_instances)
         )
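A worked example of the new mapping in plain Python, mirroring the zip logic above: with dataset A skipped (sampling probability 0) and dataset B holding 2 items, num_instances_per_dataset is [0, 2], so global index 1 maps to index 1 of B:

num_instances_per_dataset = [0, 2]  # A skipped, B has 2 items
keys = ["A", "B"]

def map_index(index):
    counter = 0
    for num_instances, key in zip(num_instances_per_dataset, keys):
        if index < counter + num_instances:
            return index - counter, key
        counter += num_instances
    raise ValueError("Invalid index: {}".format(index))

print(map_index(1))  # -> (1, 'B')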

tests/test_multi_corpus_dataset.py

@@ -75,5 +75,5 @@ class TestMultiCorpusDataset(unittest.TestCase):
         print(distribution)
 
     def test_multi_corpus_dataset(self):
-        for distribution in [[0.5, 0.5], [0.1, 0.9], [0.9, 0.1]]:
+        for distribution in [[0.5, 0.5], [0.1, 0.9], [0.9, 0.1], [0.0, 1.0]]:
             self._test_sample_helper(distribution=distribution)