Multilingual v1: multilingual training with multiple bitext and monolingual datasets; new multilingual task

Summary:
A first version of the XLNMT multilingual project code release: multilingual training with multiple bitext datasets.

- A new task to glue everything together: fairseq/tasks/translation_multi_simple_epoch.py (an example setup is sketched after this list)
- Minor changes to
    - fairseq/data/iterators.py to allow a dynamic (per-epoch) batch sampler
    - fairseq/checkpoint_utils.py to add a finetuning option instead of using restore_file, which would restore from the original model when the job is requeued
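
An example of setting up the new task programmatically (a rough, untested sketch: the data path, language pairs, and flag values are placeholders mirroring the new unit test in this commit):

    from fairseq import options, tasks

    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser, [
        'data-bin',  # placeholder path to preprocessed data
        '--task', 'translation_multi_simple_epoch',
        '--lang-pairs', 'en-de,de-en',
        '--arch', 'transformer',
        '--sampling-method', 'temperature',
        '--sampling-temperature', '1.5',
    ])
    task = tasks.setup_task(args)  # loads dicts via MultilingualDatasetManager.prepare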

Reviewed By: pipibjc

Differential Revision: D22483484

fbshipit-source-id: 283b67e538508f330b0968609b7dae64d26bea05
Author: Yuqing Tang, 2020-07-16 09:32:44 -07:00 (committed by Facebook GitHub Bot)
parent 033daef0fc
commit e52d071ee8
7 changed files with 380 additions and 9 deletions

fairseq/data/iterators.py

@@ -188,8 +188,10 @@ class EpochBatchIterator(EpochBatchIterating):
    Args:
        dataset (~torch.utils.data.Dataset): dataset from which to load the data
        collate_fn (callable): merges a list of samples to form a mini-batch
        batch_sampler (~torch.utils.data.Sampler): an iterator over batches of
            indices
        batch_sampler (~torch.utils.data.Sampler or a callable): an iterator over
            batches of indices, or a callable that creates such an iterator
            (~torch.utils.data.Sampler). A callable batch_sampler is invoked once
            per epoch, enabling dynamic per-epoch batch iterators.
        seed (int, optional): seed for random number generator for
            reproducibility (default: 1).
        num_shards (int, optional): shard the data iterator into N
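
For reference, a minimal sketch of such a callable (illustrative only; the (dataset, epoch) signature follows the frozen_batches property added below, and the max_tokens budget is arbitrary):

    from fairseq.data import data_utils

    def make_batch_sampler(dataset, epoch):
        # deterministically reorder per epoch, then rebatch;
        # EpochBatchIterator invokes this once at the start of each epoch
        with data_utils.numpy_seed(epoch):
            indices = dataset.ordered_indices()
        return data_utils.batch_by_size(indices, dataset.num_tokens, max_tokens=4096)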
@@ -215,7 +217,8 @@ class EpochBatchIterator(EpochBatchIterating):
        assert isinstance(dataset, torch.utils.data.Dataset)
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.frozen_batches = tuple(batch_sampler)
        self.batch_sampler = batch_sampler
        self._frozen_batches = tuple(batch_sampler) if not callable(batch_sampler) else None
        self.seed = seed
        self.num_shards = num_shards
        self.shard_id = shard_id
@@ -231,6 +234,12 @@ class EpochBatchIterator(EpochBatchIterating):
        self._next_epoch_itr = None
        self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)

    @property
    def frozen_batches(self):
        if self._frozen_batches is None:
            self._frozen_batches = tuple(self.batch_sampler(self.dataset, self.epoch))
        return self._frozen_batches

    def __len__(self):
        return int(math.ceil(len(self.frozen_batches) / float(self.num_shards)))
@@ -259,14 +268,17 @@ class EpochBatchIterator(EpochBatchIterating):
                that :attr:`dataset` supports prefetching (default: False).
        """
        self.epoch = self.next_epoch_idx
        self.dataset.set_epoch(self.epoch)
        if self._next_epoch_itr is not None:
            self._cur_epoch_itr = self._next_epoch_itr
            self._next_epoch_itr = None
        else:
            if callable(self.batch_sampler):
                # reset _frozen_batches to refresh the next epoch
                self._frozen_batches = None
            self._cur_epoch_itr = self._get_iterator_for_epoch(
                self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
            )
        self.dataset.set_epoch(self.epoch)
        self.shuffle = shuffle
        return self._cur_epoch_itr
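
Downstream, each call to next_epoch_itr() on an iterator built with a callable batch_sampler re-invokes that callable; a usage sketch (assuming the make_batch_sampler and dataset from the sketch above):

    from fairseq.data import iterators

    epoch_iter = iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=make_batch_sampler,  # callable: batches rebuilt every epoch
        seed=1,
    )
    for _ in range(3):
        itr = epoch_iter.next_epoch_itr(shuffle=True)  # _frozen_batches is reset first
        for batch in itr:
            pass  # train step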

fairseq/data/multilingual/multilingual_data_manager.py

@@ -12,7 +12,7 @@ from collections import OrderedDict
import json
from fairseq import options
from fairseq.options import eval_str_dict, eval_str_list
from fairseq.options import eval_str_dict, csv_str_list
from fairseq.data import (
    Dictionary,
@@ -123,9 +123,13 @@ class MultilingualDatasetManager(object):
                                 e.g. {"mined": comma-separated-lang-pairs, "denoised": comma-separated-lang-pairs}',
                            type=lambda uf: eval_str_dict(uf, type=str),
                            default=None)
        parser.add_argument('--langtoks-specs', help='a list of comma separated language tokens specifictions',
        parser.add_argument('--langtoks-specs',
                            help='a comma-separated list of data types for which specialized language tokens \
                            should be added, e.g. "main,dae,mined". A set of language tokens is added to the \
                            vocab to distinguish languages across training data types. If not specified, \
                            default per-language tokens are added',
                            default='main',
                            type=lambda uf: eval_str_list(uf, type=str),
                            type=csv_str_list,
                            )
        parser.add_argument('--langtoks', help='a dictionary of how to add language tokens, \
                            e.g. {"mined": (None, "tgt"), "mono_dae": ("src.dae", "tgt"), "main": \

fairseq/options.py

@@ -60,6 +60,10 @@ def get_validation_parser(default_task=None):
    return parser


def csv_str_list(x):
    return x.split(',')


def eval_str_list(x, type=float):
    if x is None:
        return None
@@ -71,6 +75,14 @@ def eval_str_list(x, type=float):
        return [type(x)]


def eval_str_dict(x, type=dict):
    if x is None:
        return None
    if isinstance(x, str):
        x = eval(x)
    return x


def eval_bool(x, default=False):
    if x is None:
        return default
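
A few worked examples of these helpers, reusing values from the --langtoks-specs and --langtoks help strings above (the spec names are just those examples):

    from fairseq.options import csv_str_list, eval_str_dict, eval_str_list

    csv_str_list('main,dae,mined')
    # -> ['main', 'dae', 'mined']
    eval_str_dict('{"mined": (None, "tgt"), "main": ("src", "tgt")}', type=str)
    # -> {'mined': (None, 'tgt'), 'main': ('src', 'tgt')}  (note: evaluated with eval)
    eval_str_list('0.1,0.9', type=float)
    # -> [0.1, 0.9]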

fairseq/tasks/fairseq_task.py

@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import warnings
import os

import torch
@@ -78,6 +79,9 @@ class FairseqTask(object):
        """
        return cls(args, **kwargs)

    def has_sharded_data(self, split):
        return (os.pathsep in getattr(self.args, 'data', ''))

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

fairseq/tasks/translation_multi_simple_epoch.py (new file)

@@ -0,0 +1,300 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import datetime
import time

import torch

from fairseq.data import (
    data_utils,
    FairseqDataset,
    iterators,
    LanguagePairDataset,
    ListDataset,
)
from fairseq.tasks import FairseqTask, register_task
from fairseq.data.multilingual.sampling_method import SamplingMethod
from fairseq.data.multilingual.multilingual_data_manager import MultilingualDatasetManager


###
def get_time_gap(s, e):
    return (datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s)).__str__()
###


logger = logging.getLogger(__name__)
@register_task('translation_multi_simple_epoch')
class TranslationMultiSimpleEpochTask(FairseqTask):
    """
    Translate from one (source) language to another (target) language.

    Args:
        langs (List[str]): a list of languages that are being supported
        dicts (Dict[str, fairseq.data.Dictionary]): mapping from supported languages to their dictionaries
        training (bool): whether the task should be configured for training or not

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='inference source language')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='inference target language')
        parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
                            help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr')
        SamplingMethod.add_arguments(parser)
        MultilingualDatasetManager.add_args(parser)
        # fmt: on
    def __init__(self, args, langs, dicts, training):
        super().__init__(args)
        self.langs = langs
        self.dicts = dicts
        self.training = training
        if training:
            self.lang_pairs = args.lang_pairs
        else:
            self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
        # eval_lang_pairs for multilingual translation is usually all of the
        # lang_pairs. However, for other multitask settings, or when we want to
        # optimize for certain languages, we want to use a different subset. Thus
        # the eval_lang_pairs class variable is provided for classes that extend
        # this class.
        self.eval_lang_pairs = self.lang_pairs
        # model_lang_pairs will be used to build encoder-decoder model pairs in
        # models.build_model(). This allows multitask sub-classes to build
        # models for lang pairs other than the input lang_pairs.
        self.model_lang_pairs = self.lang_pairs
        self.sampling_method = SamplingMethod.build_sampler(args, self)
        self.data_manager = MultilingualDatasetManager.setup_data_manager(
            args, self.lang_pairs, langs, dicts, self.sampling_method)
    @classmethod
    def setup_task(cls, args, **kwargs):
        langs, dicts, training = MultilingualDatasetManager.prepare(args, **kwargs)
        return cls(args, langs, dicts, training)

    def has_sharded_data(self, split):
        return self.data_manager.has_sharded_data(split)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if split in self.datasets:
            dataset = self.datasets[split]
            if self.has_sharded_data(split) and dataset.load_next_shard:
                shard_epoch = dataset.shard_epoch
            else:
                # no need to load the next shard, so skip loading;
                # this also avoids always reloading from the beginning of the data
                return
        else:
            shard_epoch = None
        self.datasets[split] = self.data_manager.load_sampled_multi_epoch_dataset(
            split,
            self.training,
            epoch=epoch, combine=combine, shard_epoch=shard_epoch, **kwargs
        )
    def build_dataset_for_inference(self, src_tokens, src_lengths):
        src_data = ListDataset(src_tokens, src_lengths)
        dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary)
        src_langtok_spec, tgt_langtok_spec = self.args.langtoks['main']
        if self.args.lang_tok_replacing_bos_eos:
            dataset = self.data_manager.alter_dataset_langtok(
                dataset,
                src_eos=self.source_dictionary.eos(),
                src_lang=self.args.source_lang,
                tgt_eos=self.target_dictionary.eos(),
                tgt_lang=self.args.target_lang,
                src_langtok_spec=src_langtok_spec,
                tgt_langtok_spec=tgt_langtok_spec,
            )
        else:
            dataset.src = self.data_manager.src_dataset_tranform_func(
                self.args.source_lang,
                self.args.target_lang,
                dataset=dataset.src,
                spec=src_langtok_spec,
            )
        return dataset
    def build_model(self, args):
        return super().build_model(args)

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        return loss, sample_size, logging_output

    def inference_step(self, generator, models, sample, prefix_tokens=None):
        with torch.no_grad():
            _, tgt_langtok_spec = self.args.langtoks['main']
            if not self.args.lang_tok_replacing_bos_eos:
                if prefix_tokens is None and tgt_langtok_spec:
                    tgt_lang_tok = self.data_manager.get_decoder_langtok(self.args.target_lang, tgt_langtok_spec)
                    src_tokens = sample['net_input']['src_tokens']
                    bsz = src_tokens.size(0)
                    prefix_tokens = torch.LongTensor(
                        [[tgt_lang_tok]]
                    ).expand(bsz, 1).to(src_tokens)
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                )
            else:
                return generator.generate(
                    models,
                    sample,
                    prefix_tokens=prefix_tokens,
                    bos_token=self.data_manager.get_decoder_langtok(self.args.target_lang, tgt_langtok_spec)
                    if tgt_langtok_spec else self.target_dictionary.eos(),
                )

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)
    @property
    def source_dictionary(self):
        if self.training:
            return next(iter(self.dicts.values()))
        else:
            return self.dicts[self.args.source_lang]

    @property
    def target_dictionary(self):
        if self.training:
            return next(iter(self.dicts.values()))
        else:
            return self.dicts[self.args.target_lang]
    def create_batch_sampler_func(
        self, max_positions, ignore_invalid_inputs,
        max_tokens, max_sentences,
    ):
        def construct_batch_sampler(dataset, epoch):
            splits = [s for s, _ in self.datasets.items() if self.datasets[s] == dataset]
            split = splits[0] if len(splits) > 0 else None
            if epoch is not None:
                dataset.set_epoch(epoch)
            start_time = time.time()
            # get indices ordered by example size
            indices = dataset.ordered_indices()
            logger.debug(f'[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}')
            # filter examples that are too large
            if max_positions is not None:
                my_time = time.time()
                indices = data_utils.filter_by_size(
                    indices, dataset, max_positions, raise_exception=(not ignore_invalid_inputs),
                )
                logger.debug(f'[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}')
            # create mini-batches with given size constraints
            my_time = time.time()
            batch_sampler = data_utils.batch_by_size(
                indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
            )
            logger.debug(f'[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}')
            logger.debug(f'[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}')
            return batch_sampler

        return construct_batch_sampler
    # we need to override get_batch_iterator because we want to reset the epoch iterator each time
    def get_batch_iterator(
        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
        ignore_invalid_inputs=False, required_batch_size_multiple=1,
        seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 1).

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
            given dataset split
        """
        # initialize the dataset with the correct starting epoch
        assert isinstance(dataset, FairseqDataset)
        if dataset in self.dataset_to_epoch_iter:
            return self.dataset_to_epoch_iter[dataset]

        if self.args.sampling_method == 'RoundRobin':
            batch_iter = super().get_batch_iterator(
                dataset, max_tokens=max_tokens, max_sentences=max_sentences, max_positions=max_positions,
                ignore_invalid_inputs=ignore_invalid_inputs, required_batch_size_multiple=required_batch_size_multiple,
                seed=seed, num_shards=num_shards, shard_id=shard_id, num_workers=num_workers, epoch=epoch,
            )
            self.dataset_to_epoch_iter[dataset] = batch_iter
            return batch_iter

        construct_batch_sampler = self.create_batch_sampler_func(
            max_positions, ignore_invalid_inputs,
            max_tokens, max_sentences)

        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=construct_batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )
        return epoch_iter
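
One detail of the new task worth illustrating: when a target-language token spec is active and no prefix is given, inference_step prepends the decoder language token to every sentence in the batch. A standalone sketch of that tensor expansion (the token index and batch contents are made up):

    import torch

    tgt_lang_tok = 250004                          # hypothetical dictionary index of a language token
    src_tokens = torch.randint(4, 100, (8, 12))    # dummy batch: bsz=8, src len=12
    prefix_tokens = torch.LongTensor([[tgt_lang_tok]]).expand(8, 1).to(src_tokens)
    # -> shape (8, 1), same device/dtype as src_tokens: one language token per sentence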

fairseq_cli/train.py

@@ -10,7 +10,6 @@ Train a new model on one or across multiple GPUs.
import argparse
import logging
import math
import os
import random
import sys
from typing import Callable, Optional
@@ -130,6 +129,7 @@ def main(
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
@@ -142,7 +142,7 @@ def main(
        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in getattr(args, "data", "")),
            load_dataset=task.has_sharded_data('train'),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))

tests/test_binaries.py

@@ -206,6 +206,45 @@ class TestTranslation(unittest.TestCase):
                ] + enc_ltok_flag + dec_ltok_flag,
            )

    def test_translation_multi_simple_epoch(self):
        # test with all combinations of encoder/decoder lang tokens
        encoder_langtok_flags = [[], ['--encoder-langtok', 'src'], ['--encoder-langtok', 'tgt']]
        decoder_langtok_flags = [[], ['--decoder-langtok']]
        with contextlib.redirect_stdout(StringIO()):
            for i in range(len(encoder_langtok_flags)):
                for j in range(len(decoder_langtok_flags)):
                    enc_ltok_flag = encoder_langtok_flags[i]
                    dec_ltok_flag = decoder_langtok_flags[j]
                    with tempfile.TemporaryDirectory(f'test_translation_multi_simple_epoch_{i}_{j}') as data_dir:
                        create_dummy_data(data_dir)
                        preprocess_translation_data(data_dir)
                        train_translation_model(
                            data_dir,
                            arch='transformer',
                            task='translation_multi_simple_epoch',
                            extra_flags=[
                                '--encoder-layers', '2',
                                '--decoder-layers', '2',
                                '--encoder-embed-dim', '8',
                                '--decoder-embed-dim', '8',
                                '--sampling-method', 'temperature',
                                '--sampling-temperature', '1.5',
                                '--virtual-epoch-size', '1000',
                            ] + enc_ltok_flag + dec_ltok_flag,
                            lang_flags=['--lang-pairs', 'in-out,out-in'],
                            run_validation=True,
                            extra_valid_flags=enc_ltok_flag + dec_ltok_flag,
                        )
                        generate_main(
                            data_dir,
                            extra_flags=[
                                '--task', 'translation_multi_simple_epoch',
                                '--lang-pairs', 'in-out,out-in',
                                '--source-lang', 'in',
                                '--target-lang', 'out',
                            ] + enc_ltok_flag + dec_ltok_flag,
                        )

    def test_transformer_cross_self_attention(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_transformer_cross_self_attention') as data_dir: