Allow loading sharded datasets (#696)

Summary:
Co-authored-by: myleott <myleott@fb.com>

Changing `data` to be a `str` containing a colon-separated list of paths, for loading sharded datasets. This change is useful for large datasets that cannot fit into memory: the dataset can be sharded, and each shard is then loaded in one epoch in a round-robin manner.

For example, if there are `5` shards of data and `10` epochs, then the shards will be iterated over in the order `[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]`.
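Concretely, the shard used for an epoch is `paths[epoch % len(paths)]`, where `paths = args.data.split(':')`. A minimal sketch of this round-robin selection (the standalone `select_shard` helper is illustrative only; in the diff below the same expression appears inline in each task's `load_dataset`):

```python
def select_shard(data, epoch):
    """Return the shard directory to load for a given epoch (round-robin)."""
    paths = data.split(':')
    assert len(paths) > 0
    return paths[epoch % len(paths)]


# 5 shards over 10 epochs -> shard order 0, 1, 2, 3, 4, 0, 1, 2, 3, 4
shards = ':'.join('shard{}'.format(i) for i in range(5))
print([select_shard(shards, epoch) for epoch in range(10)])
# ['shard0', 'shard1', 'shard2', 'shard3', 'shard4', 'shard0', ..., 'shard4']
```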

@myleott: we need to look into `translation.py`, as it currently already expects a list and then concatenates the datasets.
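For context on that note, a rough sketch of how the reworked loading can satisfy both needs: pick one shard per epoch, then still concatenate the `split`, `split1`, `split2`, ... pieces found inside that shard when `combine=True`. The `load_split` callable and the `os.path.exists` check are placeholders for the indexed-dataset helpers used in the diff below:

```python
import itertools
import os


def load_sharded_split(data, split, epoch, combine, load_split):
    """Pick one shard for this epoch, then gather split, split1, ... within it."""
    paths = data.split(':')
    assert len(paths) > 0
    data_path = paths[epoch % len(paths)]     # round-robin shard for this epoch

    datasets = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(data_path, split_k)
        if not os.path.exists(path):          # stand-in for IndexedDataset.exists(path)
            if k > 0:
                break                         # no more numbered pieces in this shard
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
        datasets.append(load_split(path))     # load_split is a placeholder loader
        if not combine:
            break
    return datasets
```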
Pull Request resolved: https://github.com/pytorch/fairseq/pull/696

Differential Revision: D15214049

fbshipit-source-id: 03e43a7b69c7aefada2ca668abf1eac1969fe013
Authored by Naman Goyal on 2019-05-06 14:56:13 -07:00; committed by Facebook Github Bot
parent 57da383c9a
commit 0add50c2e0
9 changed files with 140 additions and 91 deletions


@@ -79,11 +79,12 @@ class EpochBatchIterator(object):
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
epoch (int, optional): The epoch to start the iterator from.
"""
def __init__(
self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
num_workers=0,
num_workers=0, epoch=0,
):
assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
@@ -94,7 +95,7 @@ class EpochBatchIterator(object):
self.shard_id = shard_id
self.num_workers = num_workers
self.epoch = 0
self.epoch = epoch
self._cur_epoch_itr = None
self._next_epoch_itr = None
self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)


@@ -42,7 +42,8 @@ class CrossLingualLMTask(FairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', help='path to data directory')
parser.add_argument('data', help='colon separated path to data directories list, \
will be iterated upon during epochs in round-robin manner')
parser.add_argument('--tokens-per-sample', default=512, type=int,
help='max number of total tokens over all segments'
' per sample')
@@ -106,12 +107,16 @@ class CrossLingualLMTask(FairseqTask):
return cls(args, dictionary)
def _load_single_lang_dataset(self, split):
def _load_single_lang_dataset(self, split, epoch):
loaded_datasets = []
paths = self.args.data.split(':')
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
path = os.path.join(self.args.data, split_k)
path = os.path.join(data_path, split_k)
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
@@ -124,7 +129,7 @@ class CrossLingualLMTask(FairseqTask):
if k > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
# Since we append each block with the classification_token,
# we need to effectively create blocks of length
@@ -136,7 +141,7 @@ class CrossLingualLMTask(FairseqTask):
)
)
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))
if len(loaded_datasets) == 1:
dataset = loaded_datasets[0]
@@ -147,7 +152,7 @@ class CrossLingualLMTask(FairseqTask):
return dataset, sizes
def load_dataset(self, split, combine=False, **kwargs):
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
@@ -162,7 +167,7 @@ class CrossLingualLMTask(FairseqTask):
# Datasets are expected to be in "split.lang" format (Eg: train.en)
language_split = '{}.{}'.format(split, lang)
block_dataset, sizes = self._load_single_lang_dataset(split=language_split)
block_dataset, sizes = self._load_single_lang_dataset(split=language_split, epoch=epoch)
dataset_map[lang] = MaskedLMDataset(
dataset=block_dataset,
@@ -182,6 +187,6 @@ class CrossLingualLMTask(FairseqTask):
dataset_map, default_key=self.default_key
)
print('| {} {} {} examples'.format(
self.args.data, split, len(self.datasets[split])
self.args.data.split(':')[epoch], split, len(self.datasets[split])
)
)


@@ -92,7 +92,7 @@ class FairseqTask(object):
def get_batch_iterator(
self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
ignore_invalid_inputs=False, required_batch_size_multiple=1,
seed=1, num_shards=1, shard_id=0, num_workers=0,
seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0,
):
"""
Get an iterator that yields batches of data from the given dataset.
@@ -118,6 +118,7 @@ class FairseqTask(object):
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
epoch (int, optional): The epoch to start the iterator from.
Returns:
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
@@ -149,6 +150,7 @@ class FairseqTask(object):
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
)
def build_model(self, args):


@@ -104,7 +104,9 @@ class LanguageModelingTask(FairseqTask):
dictionary = None
output_dictionary = None
if args.data:
dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
paths = args.data.split(':')
assert len(paths) > 0
dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
print('| dictionary: {} types'.format(len(dictionary)))
output_dictionary = dictionary
if args.output_dictionary_size >= 0:
@@ -136,7 +138,7 @@ class LanguageModelingTask(FairseqTask):
return model
def load_dataset(self, split, combine=False, **kwargs):
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
"""Load a given dataset split.
Args:
@@ -145,9 +147,13 @@ class LanguageModelingTask(FairseqTask):
loaded_datasets = []
paths = self.args.data.split(':')
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
path = os.path.join(self.args.data, split_k)
path = os.path.join(data_path, split_k)
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
@@ -160,7 +166,7 @@ class LanguageModelingTask(FairseqTask):
if k > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
loaded_datasets.append(
TokenBlockDataset(
@@ -170,7 +176,7 @@ class LanguageModelingTask(FairseqTask):
)
)
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))
if not combine:
break


@@ -135,7 +135,9 @@ class MultilingualTranslationTask(FairseqTask):
# load dictionaries
dicts = OrderedDict()
for lang in sorted_langs:
dicts[lang] = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(lang)))
paths = args.data.split(':')
assert len(paths) > 0
dicts[lang] = Dictionary.load(os.path.join(paths[0], 'dict.{}.txt'.format(lang)))
if len(dicts) > 0:
assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
@@ -185,11 +187,15 @@ class MultilingualTranslationTask(FairseqTask):
new_tgt_bos=new_tgt_bos,
)
def load_dataset(self, split, **kwargs):
def load_dataset(self, split, epoch=0, **kwargs):
"""Load a dataset split."""
paths = self.args.data.split(':')
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
def split_exists(split, src, tgt, lang):
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
elif not self.args.raw_text and IndexedDataset.exists(filename):
@@ -210,17 +216,17 @@ class MultilingualTranslationTask(FairseqTask):
for lang_pair in self.args.lang_pairs:
src, tgt = lang_pair.split('-')
if split_exists(split, src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
elif split_exists(split, tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
else:
continue
src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
print('| {} {} {} examples'.format(self.args.data, split, len(src_datasets[lang_pair])))
print('| {} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))
if len(src_datasets) == 0:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
def language_pair_dataset(lang_pair):
src, tgt = lang_pair.split('-')


@@ -132,14 +132,18 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
return cls(args, dicts, training)
def load_dataset(self, split, **kwargs):
def load_dataset(self, split, epoch=0, **kwargs):
"""Load a dataset split."""
paths = self.args.data.split(':')
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
def split_exists(split, src, tgt, lang):
if src is not None:
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
else:
filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, src, tgt))
filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, src, tgt))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
elif not self.args.raw_text and IndexedDataset.exists(filename):
@@ -162,16 +166,16 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
for lang_pair in self.args.lang_pairs:
src, tgt = lang_pair.split('-')
if split_exists(split, src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
elif split_exists(split, tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
else:
continue
src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
print('| parallel-{} {} {} examples'.format(self.args.data, split, len(src_datasets[lang_pair])))
print('| parallel-{} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))
if len(src_datasets) == 0:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
# back translation datasets
backtranslate_datasets = {}
@@ -179,8 +183,8 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
for lang_pair in self.args.lang_pairs:
src, tgt = lang_pair.split('-')
if not split_exists(split, tgt, None, tgt):
raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, self.args.data))
filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, tgt, tgt))
raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, data_path))
filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
dataset = indexed_dataset(filename, self.dicts[tgt])
lang_pair_dataset_tgt = LanguagePairDataset(
dataset,
@@ -216,7 +220,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
).collater,
)
print('| backtranslate-{}: {} {} {} examples'.format(
tgt, self.args.data, split, len(backtranslate_datasets[lang_pair]),
tgt, data_path, split, len(backtranslate_datasets[lang_pair]),
))
self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair]
@@ -227,7 +231,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
_, tgt = lang_pair.split('-')
if not split_exists(split, tgt, None, tgt):
continue
filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, tgt, tgt))
filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
tgt_dataset1 = indexed_dataset(filename, self.dicts[tgt])
tgt_dataset2 = indexed_dataset(filename, self.dicts[tgt])
noising_dataset = NoisingDataset(
@@ -255,7 +259,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
tgt_lang=tgt,
)
print('| denoising-{}: {} {} {} examples'.format(
tgt, self.args.data, split, len(noising_datasets[lang_pair]),
tgt, data_path, split, len(noising_datasets[lang_pair]),
))
def language_pair_dataset(lang_pair):


@@ -48,7 +48,8 @@ class TranslationTask(FairseqTask):
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
parser.add_argument('data', nargs='+', help='path(s) to data directorie(s)')
parser.add_argument('data', help='colon separated path to data directories list, \
will be iterated upon during epochs in round-robin manner')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
@@ -84,19 +85,17 @@ class TranslationTask(FairseqTask):
args.left_pad_source = options.eval_bool(args.left_pad_source)
args.left_pad_target = options.eval_bool(args.left_pad_target)
# upgrade old checkpoints
if isinstance(args.data, str):
args.data = [args.data]
paths = args.data.split(':')
assert len(paths) > 0
# find language pair automatically
if args.source_lang is None or args.target_lang is None:
args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])
args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0])
if args.source_lang is None or args.target_lang is None:
raise Exception('Could not infer language pair, please provide it explicitly')
# load dictionaries
src_dict = cls.load_dictionary(os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang)))
tgt_dict = cls.load_dictionary(os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.source_lang)))
tgt_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang)))
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
@@ -105,12 +104,15 @@ class TranslationTask(FairseqTask):
return cls(args, src_dict, tgt_dict)
def load_dataset(self, split, combine=False, **kwargs):
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = self.args.data.split(':')
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
def split_exists(split, src, tgt, lang, data_path):
filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
@@ -133,29 +135,28 @@ class TranslationTask(FairseqTask):
src_datasets = []
tgt_datasets = []
for dk, data_path in enumerate(self.args.data):
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
if split_exists(split_k, src, tgt, src, data_path):
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
elif split_exists(split_k, tgt, src, src, data_path):
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
else:
if k > 0 or dk > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))
if not combine:
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
if split_exists(split_k, src, tgt, src, data_path):
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
elif split_exists(split_k, tgt, src, src, data_path):
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
else:
if k > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))
if not combine:
break
assert len(src_datasets) == len(tgt_datasets)


@@ -68,10 +68,13 @@ class TestLoadCheckpoint(unittest.TestCase):
[p.start() for p in self.applied_patches]
def test_load_partial_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
train.load_checkpoint(self.args_mock, trainer, epoch_itr)
with patch('train.reload_train', return_value=epoch_itr):
train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
self.assertEqual(epoch_itr.epoch, 2)
self.assertEqual(epoch_itr.iterations_in_epoch, 50)
@@ -86,7 +89,8 @@ class TestLoadCheckpoint(unittest.TestCase):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
train.load_checkpoint(self.args_mock, trainer, epoch_itr)
with patch('train.reload_train', return_value=epoch_itr):
train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 3)
@@ -98,7 +102,7 @@ class TestLoadCheckpoint(unittest.TestCase):
trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
self.patches['os.path.isfile'].return_value = False
train.load_checkpoint(self.args_mock, trainer, epoch_itr)
train.load_checkpoint(self.args_mock, trainer, epoch_itr, 512, None)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 1)


@@ -44,7 +44,9 @@ def main(args, init_distributed=False):
task = tasks.setup_task(args)
# Load dataset splits
load_dataset_splits(args, task)
task.load_dataset(args.train_subset, combine=True, epoch=0)
for valid_sub_split in args.valid_subset.split(','):
task.load_dataset(valid_sub_split, combine=True, epoch=0)
# Build model and criterion
model = task.build_model(args)
@@ -64,15 +66,16 @@ def main(args, init_distributed=False):
args.max_sentences,
))
max_positions = utils.resolve_max_positions(
task.max_positions(),
model.max_positions(),
)
# Initialize dataloader
epoch_itr = task.get_batch_iterator(
dataset=task.dataset(args.train_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=utils.resolve_max_positions(
task.max_positions(),
model.max_positions(),
),
max_positions=max_positions,
ignore_invalid_inputs=True,
required_batch_size_multiple=args.required_batch_size_multiple,
seed=args.seed,
@@ -82,7 +85,7 @@ def main(args, init_distributed=False):
)
# Load the latest checkpoint if one is available
load_checkpoint(args, trainer, epoch_itr)
load_checkpoint(args, trainer, epoch_itr, max_positions, task)
# Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf
@@ -105,10 +108,34 @@ def main(args, init_distributed=False):
# save checkpoint
if epoch_itr.epoch % args.save_interval == 0:
save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
epoch_itr = reload_train(args, epoch_itr, max_positions, task)
train_meter.stop()
print('| done training in {:.1f} seconds'.format(train_meter.sum))
def reload_train(args, epoch_itr, max_positions, task):
# nothing needs to be done when the dataset is not sharded.
if len(args.data.split(":")) == 1:
return epoch_itr
print("| Reloading shard of train data at epoch: ", epoch_itr.epoch)
task.load_dataset(args.train_subset, combine=True, epoch=epoch_itr.epoch)
epoch_itr = task.get_batch_iterator(
dataset=task.dataset(args.train_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=True,
required_batch_size_multiple=args.required_batch_size_multiple,
seed=args.seed,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
num_workers=args.num_workers,
epoch=epoch_itr.epoch,
)
return epoch_itr
def train(args, trainer, task, epoch_itr):
"""Train the model for one epoch."""
# Update parameters every N batches
@@ -335,9 +362,8 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
os.remove(old_chk)
def load_checkpoint(args, trainer, epoch_itr):
def load_checkpoint(args, trainer, epoch_itr, max_positions, task):
"""Load a checkpoint and replay dataloader to match."""
# Only rank 0 should attempt to create the required dir
if args.distributed_rank == 0:
os.makedirs(args.save_dir, exist_ok=True)
@@ -351,7 +377,14 @@ def load_checkpoint(args, trainer, epoch_itr):
eval(args.optimizer_overrides))
if extra_state is not None:
# replay train iterator to match checkpoint
epoch_itr.load_state_dict(extra_state['train_iterator'])
epoch_itr_state = extra_state['train_iterator']
# If the loaded checkpoint is not at epoch 0, reload train dataset,
# as it could be potentially sharded.
if epoch_itr_state['epoch'] != 0:
epoch_itr = reload_train(args, epoch_itr, max_positions, task)
epoch_itr.load_state_dict(epoch_itr_state)
print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
@@ -366,19 +399,6 @@ def load_checkpoint(args, trainer, epoch_itr):
return False
def load_dataset_splits(args, task):
task.load_dataset(args.train_subset, combine=True)
for split in args.valid_subset.split(','):
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
try:
task.load_dataset(split_k, combine=False)
except FileNotFoundError as e:
if k > 0:
break
raise e
def distributed_main(i, args, start_rank=0):
args.device_id = i
if args.distributed_rank is None: # torch.multiprocessing.spawn