# Mirror of https://github.com/facebookresearch/fairseq.git
# Synced 2024-10-26 17:32:57 +03:00, commit 656d7e5779.
#
# Commit summary (PR https://github.com/fairinternal/fairseq-py/pull/1667):
# Add support for FullyShardedDataParallel (--ddp-backend=fully_sharded),
# using FSDP from fairscale to fully shard parameters and optimizer state.
# Drop-in replacement for the "c10d" backend; works with --fp16,
# --memory-efficient-fp16, --update-freq, etc. Also adds --cpu-offload
# (offloads optimizer state and the FP32 model copy to CPU, useful with
# --optimizer=cpu_adam). Notes: each GPU saves its own checkpoint shard
# (rank 0 holds the full model weights), and resuming training with a
# different world_size is not yet supported.
# Reviewed By: sshleifer. Differential Revision: D26771144. Pulled By: myleott.
# fbshipit-source-id: 74c2f46f57719e24e2dcfc9d9ee7c2fc0aeedb46
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Standard library
import logging
import unittest
from typing import Sequence

# Project imports
from fairseq.data import LanguagePairDataset, ListDataset, RoundRobinZipDatasets
from tests.test_train import mock_dict
def lang_pair_dataset(lengths: Sequence[int]) -> LanguagePairDataset:
    """Build a toy LanguagePairDataset from sentence lengths.

    Sentence ``i`` consists of the token ``i`` repeated ``lengths[i]``
    times, so each sentence is identifiable by both content and length.
    """
    sentences = [[index] * size for index, size in enumerate(lengths)]
    return LanguagePairDataset(ListDataset(sentences), lengths, mock_dict())
def sample(id: int, length: int):
    """Return the expected item dict for sentence ``id``.

    Mirrors the layout produced by ``lang_pair_dataset``: the source is
    the token ``id`` repeated ``length`` times, and there is no target.
    """
    source_tokens = [id for _ in range(length)]
    return dict(id=id, source=source_tokens, target=None)
class TestDataset(unittest.TestCase):
    """Tests for RoundRobinZipDatasets indexing and size-based filtering."""

    def setUp(self):
        # Suppress all log output while the tests run.
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        # Re-enable logging for subsequent tests.
        logging.disable(logging.NOTSET)

    def test_round_robin_zip_datasets(self):
        long_dataset = lang_pair_dataset([10, 9, 8, 11])
        short_dataset = lang_pair_dataset([11, 9])
        dataset = RoundRobinZipDatasets({"a": long_dataset, "b": short_dataset})

        # ordered_indices() sorts each sub-dataset by sentence length.
        dataset.ordered_indices()
        assert dataset.longest_dataset is long_dataset

        self.assertEqual(dict(dataset[0]), {"a": sample(2, 8), "b": sample(1, 9)})
        # The item 2 of dataset 'a' is with item (2 % 2 = 0) of dataset 'b'
        self.assertEqual(dict(dataset[2]), {"a": sample(0, 10), "b": sample(1, 9)})

    def test_round_robin_zip_datasets_filtered(self):
        long_dataset = lang_pair_dataset([10, 20, 8, 11, 1000, 7, 12])
        short_dataset = lang_pair_dataset([11, 20, 9, 1000])
        dataset = RoundRobinZipDatasets({"a": long_dataset, "b": short_dataset})

        # ordered_indices() sorts each sub-dataset by sentence length;
        # then filter with a per-dataset max size given as a dict.
        ordered = dataset.ordered_indices()
        kept, _ = dataset.filter_indices_by_size(ordered, {"a": 19, "b": 900})

        self.assertEqual(list(kept), [0, 1, 2, 3, 4])
        self.assertEqual(dict(dataset[0]), {"a": sample(5, 7), "b": sample(2, 9)})
        self.assertEqual(dict(dataset[2]), {"a": sample(0, 10), "b": sample(1, 20)})
        self.assertEqual(dict(dataset[4]), {"a": sample(6, 12), "b": sample(0, 11)})

    def test_round_robin_zip_datasets_filtered_with_tuple(self):
        long_dataset = lang_pair_dataset([10, 20, 8, 11, 1000, 7, 12])
        short_dataset = lang_pair_dataset([11, 20, 9, 1000])
        dataset = RoundRobinZipDatasets({"a": long_dataset, "b": short_dataset})

        # Same as above, but the max size is a single scalar applied to
        # every sub-dataset.
        ordered = dataset.ordered_indices()
        kept, _ = dataset.filter_indices_by_size(ordered, 19)

        self.assertEqual(list(kept), [0, 1, 2, 3, 4])
        self.assertEqual(dict(dataset[0]), {"a": sample(5, 7), "b": sample(2, 9)})
        self.assertEqual(dict(dataset[2]), {"a": sample(0, 10), "b": sample(2, 9)})
        self.assertEqual(dict(dataset[4]), {"a": sample(6, 12), "b": sample(2, 9)})