Add README/tutorial for Fully Sharded Data Parallel (#3327)

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/3327

Reviewed By: sshleifer

Differential Revision: D26899416

Pulled By: myleott

fbshipit-source-id: bbb493a5c4e0a51f3b26fe8f94e3962b6206d6f6
Myle Ott, 2021-03-09 06:28:23 -08:00 (committed by Facebook GitHub Bot)
parent 16c1a200f8
commit 00d5b7adbe
9 changed files with 363 additions and 41 deletions


@@ -39,7 +39,8 @@ jobs:
      - name: Install optional test requirements
        run: |
-          python -m pip install fairscale iopath transformers pyarrow
+          python -m pip install iopath transformers pyarrow
+          python -m pip install git+https://github.com/facebookresearch/fairscale.git@master
      - name: Lint with flake8
        run: |


@@ -61,6 +61,9 @@ We provide reference implementations of various sequence modeling papers:
### What's New:
+* March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
* February 2021 [Added LASER training code](examples/laser/README.md)
* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
* December 2020: [GottBERT model and code released](examples/gottbert/README.md)
* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
* [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
@@ -68,14 +71,14 @@ We provide reference implementations of various sequence modeling papers:
* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
* October 2020: [Added CRISS models and code](examples/criss/README.md)
-<details><summary>Previous updates</summary><p>
* September 2020: [Added Linformer code](examples/linformer/README.md)
* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
+<details><summary>Previous updates</summary><p>
* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
* April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
@@ -108,6 +111,8 @@ We provide reference implementations of various sequence modeling papers:
* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
+* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
+* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
with a convenient `torch.hub` interface:
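
For reference, the `torch.hub` interface mentioned in that context line looks like this (example adapted from the fairseq README; the model name is one of the released WMT'19 checkpoints):

```python
import torch

# list available pre-trained models, then load one (downloads on first use)
torch.hub.list('pytorch/fairseq')
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model',
                       tokenizer='moses', bpe='fastbpe')
en2de.translate('Hello world!')  # 'Hallo Welt!'
```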


@@ -0,0 +1,164 @@
# Fully Sharded Data Parallel (FSDP)

## Overview

Recent work by [Microsoft](https://arxiv.org/abs/1910.02054) and
[Google](https://arxiv.org/abs/2004.13336) has shown that data parallel
training can be made significantly more efficient by sharding the model
parameters and optimizer state across data parallel workers. These ideas are
encapsulated in the new **`FullyShardedDataParallel` (FSDP)** wrapper provided
by [fairscale](https://github.com/facebookresearch/fairscale/).

Compared to PyTorch DDP:

* FSDP produces identical results to PyTorch DDP (it's still synchronous data parallel training)
* FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs
* FSDP is faster than PyTorch DDP because the optimizer step is sharded, and the communication can be overlapped with the forward pass
* FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs
FSDP is fully supported in fairseq via the following new arguments:

* `--ddp-backend=fully_sharded`: enables full sharding via FSDP
* `--cpu-offload`: offloads the optimizer state and FP32 model copy to CPU (combine with `--optimizer=cpu_adam`)
* `--no-reshard-after-forward`: increases training speed for some models and is similar to ZeRO stage 2
* other popular options (`--fp16`, `--update-freq`, `--checkpoint-activations`, `--offload-activations`, etc.) continue to work as normal
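
To make these flags concrete, here is a minimal sketch (ours, not fairseq code) of wrapping a toy module with fairscale's FSDP directly. The keyword arguments shown are assumptions that roughly mirror what the flags above configure; check them against your installed fairscale version:

```python
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# assumes one process per GPU, with torch.distributed already initialized
# (e.g., launched via python -m torch.distributed.launch)
assert dist.is_initialized()

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 1024),
).cuda()

model = FSDP(
    model,
    flatten_parameters=True,     # fairseq's backend relies on flattened params
    mixed_precision=True,        # roughly what --fp16 configures
    reshard_after_forward=True,  # False gives ZeRO-2-like behavior
)                                # (cf. --no-reshard-after-forward)
```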
<details><summary>Limitations</summary><p>

FSDP currently has several limitations compared to fairseq's default DDP backend (PyTorch DDP):

* while FSDP is fully compatible with pointwise optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.), it is not currently compatible with non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB, etc.)
* FSDP depends on flattening the parameters, so models that currently require `--fp16-no-flatten-grads` may not be supported

See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed
explanation of these and other limitations.

</p></details>
<details><summary>How it works</summary><p>

<img width="800" alt="Fully Sharded Data Parallel" src="https://user-images.githubusercontent.com/231798/110406775-c2de0000-8050-11eb-9718-fbfc4510a76a.png">

See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed
explanation of how FSDP works.

</p></details>
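
As a rough mental model, one training step under full sharding behaves like the following single-process simulation (plain tensors standing in for the `all_gather`/`reduce_scatter` collectives the real implementation uses):

```python
import torch

world_size = 4
full_param = torch.randn(8)                   # one layer's flattened parameter
shards = list(full_param.chunk(world_size))   # each rank owns a single shard

# forward/backward: ranks all-gather the full parameter just-in-time...
gathered = torch.cat(shards)                  # ~ all_gather
assert torch.equal(gathered, full_param)

# ...then reduce-scatter gradients, so each rank keeps only the gradient
# shard matching its parameter shard
local_grads = [torch.randn(8) for _ in range(world_size)]  # one per rank
grad_shards = torch.stack(local_grads).sum(0).chunk(world_size)

# the optimizer step is sharded too: each rank updates only its own shard
lr = 0.1
new_shards = [p - lr * g for p, g in zip(shards, grad_shards)]
```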
## Example usage

The following examples illustrate how to train a very large language model with
13 billion parameters on 1 GPU by offloading parameters and optimizer states to
CPU, or on 8 GPUs by fully sharding the params and optimizer states across GPUs.

These examples use the WikiText-103 dataset for demonstration purposes, but
in practice a much larger dataset will be needed to achieve good results.

Follow the [instructions here](https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.pretraining.md#1-preprocess-the-data)
to preprocess the WikiText-103 dataset using the GPT-2/RoBERTa vocabulary.
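
For convenience, those instructions boil down to roughly the following (commands abbreviated from the RoBERTa pretraining README; only `--destdir` is changed here to match the training commands below):

```bash
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip
unzip wikitext-103-raw-v1.zip

mkdir -p gpt2_bpe
wget -O gpt2_bpe/encoder.json https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
wget -O gpt2_bpe/vocab.bpe https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
wget -O gpt2_bpe/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt

for SPLIT in train valid test; do
    python -m examples.roberta.multiprocessing_bpe_encoder \
        --encoder-json gpt2_bpe/encoder.json \
        --vocab-bpe gpt2_bpe/vocab.bpe \
        --inputs wikitext-103-raw/wiki.${SPLIT}.raw \
        --outputs wikitext-103-raw/wiki.${SPLIT}.bpe \
        --keep-empty --workers 60
done

fairseq-preprocess --only-source --srcdict gpt2_bpe/dict.txt \
    --trainpref wikitext-103-raw/wiki.train.bpe \
    --validpref wikitext-103-raw/wiki.valid.bpe \
    --testpref wikitext-103-raw/wiki.test.bpe \
    --destdir data-bin/wikitext-103-roberta-bpe-bin \
    --workers 60
```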
### 13B params on 1 V100 GPU (with CPU offloading)

The following command trains a 13B parameter GPT-3 model on a single V100 GPU
using the `--cpu-offload` feature to offload parameters and optimizer states to
CPU. In this setting, the optimizer step (Adam) happens on CPU. We also use the
`--checkpoint-activations` feature (sometimes called [gradient checkpointing](https://pytorch.org/docs/stable/checkpoint.html)),
which further saves memory in exchange for a small increase in computation.
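
For intuition, activation checkpointing in plain PyTorch looks like this (a generic `torch.utils.checkpoint` sketch, not fairseq's internal `checkpoint_wrapper`):

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU())
x = torch.randn(2, 1024, requires_grad=True)

# intermediate activations inside `layer` are not kept after the forward
# pass; they are recomputed during backward, trading compute for memory
y = checkpoint(layer, x)
y.sum().backward()
```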
Requirements:
- You'll need 32GB of GPU memory and 256GB of system memory.
- We use the CPU Adam optimizer from [DeepSpeed](https://github.com/microsoft/DeepSpeed), so you'll need to `pip install deepspeed` before running the command.
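
To see why this much memory is needed, here is a rough accounting (our assumptions: fp16 working weights, plus an fp32 master copy and two fp32 Adam buffers held on CPU by `--cpu-offload`):

```python
params = 13_110_865_920              # from the training log below
fp16_weights = 2 * params / 1e9      # ~26 GB, streamed through the 32 GB GPU
fp32_master  = 4 * params / 1e9      # ~52 GB on CPU
adam_moments = 2 * 4 * params / 1e9  # ~105 GB on CPU (exp_avg, exp_avg_sq)
print(f"CPU-side state: ~{fp32_master + adam_moments:.0f} GB")  # ~157 GB,
                                                                # hence 256 GB RAM
```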
Some notes:
- The command will take ~5 minutes to start training, during which time it will appear to be hung, since randomly initializing 13B weights can be slow.
- The `--cpu-offload` feature requires training in mixed precision (`--fp16`).
- Tune the `OMP_NUM_THREADS` env variable for best performance with CPU offloading.
- The example command below stops training after 10 steps (`--max-update 10`) and does not save checkpoints (`--no-save`).
```bash
OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0 \
fairseq-train data-bin/wikitext-103-roberta-bpe-bin \
--ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
--cpu-offload --checkpoint-activations \
--task language_modeling --tokens-per-sample 2048 --batch-size 8 \
--arch transformer_lm_gpt3_13 \
--optimizer cpu_adam --adam-betas "(0.9,0.98)" \
--lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
--max-update 10 --no-save --log-format json --log-interval 1
# Example output:
# (...)
# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920)
# (...)
# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | training on 1 devices (GPUs/TPUs)
# 2021-03-08 12:29:51 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8
# (...)
# Adam Optimizer #0 is created with AVX2 arithmetic capability.
# Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1
# (...)
# 2021-03-08 12:31:36 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.475", "ppl": "91120.8", "wps": "0", "ups": "0", "wpb": "16384", "bsz": "8", "num_updates": "1", "lr": "2e-05", "gnorm": "20.751", "loss_scale": "4", "train_wall": "99", "gb_free": "9.3", "wall": "105"}
# 2021-03-08 12:32:33 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.446", "ppl": "89281.6", "wps": "288.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "2", "lr": "4e-05", "gnorm": "19.777", "loss_scale": "4", "train_wall": "57", "gb_free": "9.3", "wall": "161"}
# 2021-03-08 12:33:12 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0
# 2021-03-08 12:33:51 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0
# 2021-03-08 12:34:45 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "25.22", "ppl": "3.90691e+07", "wps": "123.4", "ups": "0.01", "wpb": "16384", "bsz": "8", "num_updates": "3", "lr": "6e-05", "gnorm": "131.281", "loss_scale": "1", "train_wall": "133", "gb_free": "9.3", "wall": "294"}
# 2021-03-08 12:35:43 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.079", "ppl": "276809", "wps": "285.5", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "4", "lr": "8e-05", "gnorm": "13.776", "loss_scale": "1", "train_wall": "57", "gb_free": "9.3", "wall": "351"}
# 2021-03-08 12:36:35 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "23.729", "ppl": "1.39088e+07", "wps": "316.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "72.774", "loss_scale": "1", "train_wall": "52", "gb_free": "9.3", "wall": "403"}
# 2021-03-08 12:37:28 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "20.429", "ppl": "1.41203e+06", "wps": "307.6", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "6", "lr": "8e-05", "gnorm": "60.846", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "456"}
# 2021-03-08 12:38:27 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.965", "ppl": "511684", "wps": "279.4", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "7", "lr": "6e-05", "gnorm": "22.687", "loss_scale": "1", "train_wall": "59", "gb_free": "9.3", "wall": "515"}
# 2021-03-08 12:39:18 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.345", "ppl": "332887", "wps": "319.1", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "8", "lr": "4e-05", "gnorm": "8.451", "loss_scale": "1", "train_wall": "51", "gb_free": "9.3", "wall": "566"}
# 2021-03-08 12:40:11 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "18.262", "ppl": "314336", "wps": "305.9", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "9", "lr": "2e-05", "gnorm": "6.457", "loss_scale": "1", "train_wall": "54", "gb_free": "9.3", "wall": "620"}
# 2021-03-08 12:41:04 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "17.556", "ppl": "192686", "wps": "311.8", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "10", "lr": "0", "gnorm": "5.796", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "673"}
# 2021-03-08 12:41:04 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10
# 2021-03-08 12:41:04 | INFO | fairseq_cli.train | begin validation on "valid" subset
# 2021-03-08 12:43:15 | INFO | valid | {"epoch": 1, "valid_loss": "17.953", "valid_ppl": "253807", "valid_wps": "1868.4", "valid_wpb": "15400.2", "valid_bsz": "7.6", "valid_num_updates": "10"}
# 2021-03-08 12:43:15 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below)
# 2021-03-08 12:43:15 | INFO | train | {"epoch": 1, "train_loss": "19.351", "train_ppl": "668509", "train_wps": "210.9", "train_ups": "0.01", "train_wpb": "16384", "train_bsz": "8", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "36.26", "train_loss_scale": "1", "train_train_wall": "667", "train_gb_free": "9.3", "train_wall": "804"}
# 2021-03-08 12:43:15 | INFO | fairseq_cli.train | done training in 798.6 seconds
```
### 13B params on 8 V100 GPUs (with full parameter + optimizer state sharding)

FSDP can also shard the parameters and optimizer states across multiple GPUs,
reducing memory requirements significantly. On 8 GPUs, sharding enables
training the same 13B parameter model *without offloading the parameters to
CPU*. However, without CPU offloading we'd only be able to fit a batch size of
1 per GPU, which would cause training speed to suffer.

We obtain the best performance on 8 GPUs by combining full sharding and CPU
offloading. The following command trains the same 13B parameter GPT-3 model as
before on 8 GPUs; training speed increases from ~310 to ~3,200 words per second.
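
A back-of-envelope calculation (same assumptions as above, with state sharded evenly) shows why the fully-sharded-but-not-offloaded setting is so tight on 32GB cards:

```python
params, world_size = 13_110_865_920, 8
per_gpu = params / world_size        # ~1.64B params per GPU
gb = per_gpu * (2 + 4 + 4 + 4) / 1e9 # fp16 + fp32 master + two Adam buffers
print(f"~{gb:.0f} GB of state per GPU")  # ~23 GB, leaving little room
                                         # for activations on a 32 GB card
```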
```bash
OMP_NUM_THREADS=20 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
fairseq-train data-bin/wikitext-103-roberta-bpe-bin \
--ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
--cpu-offload --checkpoint-activations \
--task language_modeling --tokens-per-sample 2048 --batch-size 8 \
--arch transformer_lm_gpt3_13 \
--optimizer cpu_adam --adam-betas "(0.9,0.98)" \
--lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
--max-update 10 --no-save --log-format json --log-interval 1
# Example output:
# (...)
# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920)
# (...)
# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs)
# 2021-03-08 18:04:09 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8
# (...)
# Adam Optimizer #0 is created with AVX2 arithmetic capability.
# Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1
# (...)
# 2021-03-08 18:05:06 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "16.408", "ppl": "86945.6", "wps": "0", "ups": "0", "wpb": "131072", "bsz": "64", "num_updates": "1", "lr": "2e-05", "gnorm": "18.27", "loss_scale": "4", "train_wall": "47", "gb_free": "9.3", "wall": "56"}
# 2021-03-08 18:05:45 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "16.352", "ppl": "83644.3", "wps": "3283.4", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "2", "lr": "4e-05", "gnorm": "18.411", "loss_scale": "4", "train_wall": "40", "gb_free": "9.3", "wall": "96"}
# 2021-03-08 18:06:21 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0
# 2021-03-08 18:06:56 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0
# 2021-03-08 18:07:37 | INFO | train_inner | {"epoch": 1, "update": 0.006, "loss": "23.682", "ppl": "1.34537e+07", "wps": "1176.6", "ups": "0.01", "wpb": "131072", "bsz": "64", "num_updates": "3", "lr": "6e-05", "gnorm": "119.682", "loss_scale": "1", "train_wall": "111", "gb_free": "9.3", "wall": "208"}
# 2021-03-08 18:08:18 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "18.988", "ppl": "519921", "wps": "3189.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "4", "lr": "8e-05", "gnorm": "14.934", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "249"}
# 2021-03-08 18:08:59 | INFO | train_inner | {"epoch": 1, "update": 0.008, "loss": "20.08", "ppl": "1.10798e+06", "wps": "3223.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "5", "lr": "0.0001", "gnorm": "59.92", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "289"}
# 2021-03-08 18:09:39 | INFO | train_inner | {"epoch": 1, "update": 0.009, "loss": "18.323", "ppl": "327980", "wps": "3256.6", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "6", "lr": "8e-05", "gnorm": "37.425", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "330"}
# 2021-03-08 18:10:20 | INFO | train_inner | {"epoch": 1, "update": 0.01, "loss": "17.264", "ppl": "157354", "wps": "3188.7", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "7", "lr": "6e-05", "gnorm": "10.824", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "371"}
# 2021-03-08 18:11:01 | INFO | train_inner | {"epoch": 1, "update": 0.011, "loss": "16.794", "ppl": "113647", "wps": "3230", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "8", "lr": "4e-05", "gnorm": "5.616", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "411"}
# 2021-03-08 18:11:39 | INFO | train_inner | {"epoch": 1, "update": 0.012, "loss": "16.706", "ppl": "106938", "wps": "3384", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "9", "lr": "2e-05", "gnorm": "5.318", "loss_scale": "1", "train_wall": "39", "gb_free": "9.3", "wall": "450"}
# 2021-03-08 18:12:19 | INFO | train_inner | {"epoch": 1, "update": 0.013, "loss": "16.548", "ppl": "95796.2", "wps": "3274.4", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "10", "lr": "0", "gnorm": "5.22", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "490"}
# 2021-03-08 18:12:19 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10
# 2021-03-08 18:12:19 | INFO | fairseq_cli.train | begin validation on "valid" subset
# 2021-03-08 18:12:45 | INFO | valid | {"epoch": 1, "valid_loss": "16.624", "valid_ppl": "101000", "valid_wps": "10855.9", "valid_wpb": "123202", "valid_bsz": "60.5", "valid_num_updates": "10"}
# 2021-03-08 18:12:45 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below)
# 2021-03-08 18:12:45 | INFO | train | {"epoch": 1, "train_loss": "18.114", "train_ppl": "283776", "train_wps": "2567.8", "train_ups": "0.02", "train_wpb": "131072", "train_bsz": "64", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "29.562", "train_loss_scale": "1", "train_train_wall": "480", "train_gb_free": "9.3", "train_wall": "516"}
# 2021-03-08 18:12:45 | INFO | fairseq_cli.train | done training in 509.9 seconds
```


@@ -43,6 +43,13 @@ class FullyShardedDataParallel(FSDP):
        super().__init__(*args, **kwargs)
        self.use_sharded_state = use_sharded_state

+    @property
+    def unwrapped_module(self) -> torch.nn.Module:
+        if self.flatten_parameters:
+            return self.module.module
+        else:
+            return self.module
+
    def state_dict(self, destination=None, prefix='', keep_vars=False):
        if self.use_sharded_state:
            return super().local_state_dict(
@@ -94,7 +101,11 @@ def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False):
        "compute_dtype": torch.float16 if cfg.fp16 else torch.float32,
        "bucket_cap_mb": cfg.bucket_cap_mb,
    }
-    with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config):
+    with enable_wrap(
+        wrapper_cls=FullyShardedDataParallel,
+        use_sharded_state=use_sharded_state,
+        **fsdp_config,
+    ):
        yield
@@ -109,14 +120,13 @@ def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
    """
    try:
        from fairscale.nn import wrap
-        cls = FullyShardedDataParallel
        if min_num_params is not None:
            num_params = sum(p.numel() for p in module.parameters())
            if num_params >= min_num_params:
-                return wrap(module, cls=cls, **kwargs)
+                return wrap(module, **kwargs)
            else:
                return module
        else:
-            return wrap(module, cls=cls, **kwargs)
+            return wrap(module, **kwargs)
    except ImportError:
        return module
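
As a usage sketch of the helper above (the exact import path is an assumption; outside the `fsdp_enable_wrap(...)` context shown earlier, fairscale's `wrap()` is a no-op and the module is returned unchanged, so the call is safe either way):

```python
import torch
from fairseq.distributed import fsdp_wrap  # import path is an assumption

layer = torch.nn.Linear(4096, 4096)
# wrapped with FSDP only if the layer has at least min_num_params parameters;
# smaller modules (and installs without fairscale) are returned unchanged
layer = fsdp_wrap(layer, min_num_params=int(1e8))
```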


@@ -29,9 +29,10 @@ logger = logging.getLogger(__name__)
def check_type(module, expected_type):
    if hasattr(module, "unwrapped_module"):
-        assert isinstance(module.unwrapped_module, expected_type)
+        assert isinstance(module.unwrapped_module, expected_type), \
+            f"{type(module.unwrapped_module)} != {expected_type}"
    else:
-        assert isinstance(module, expected_type)
+        assert isinstance(module, expected_type), f"{type(module)} != {expected_type}"
class BaseFairseqModel(nn.Module):


@@ -18,7 +18,7 @@ from fairseq.models import (
    register_model,
    register_model_architecture,
)
-from fairseq.models.transformer import TransformerEncoder
+from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, TransformerEncoder
from fairseq.modules import LayerNorm
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from fairseq.modules.transformer_sentence_encoder import init_bert_params
@@ -122,6 +122,11 @@ class RobertaModel(FairseqEncoderModel):
            action="store_true",
            help="(re-)register and load heads when loading checkpoints",
        )
+        parser.add_argument(
+            "--untie-weights-roberta",
+            action="store_true",
+            help="Untie weights between embeddings and classifiers in RoBERTa",
+        )
        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
        parser.add_argument(
            "--encoder-layerdrop",
@@ -157,17 +162,26 @@ class RobertaModel(FairseqEncoderModel):
            default=0,
            help="scalar quantization noise and scalar quantization at training time",
        )
-        parser.add_argument(
-            "--untie-weights-roberta",
-            action="store_true",
-            help="Untie weights between embeddings and classifiers in RoBERTa",
-        )
        # args for "Better Fine-Tuning by Reducing Representational Collapse" (Aghajanyan et al. 2020)
        parser.add_argument(
            "--spectral-norm-classification-head",
            action="store_true",
            default=False,
            help="Apply spectral normalization on the classification head",
        )
+        # args for Fully Sharded Data Parallel (FSDP) training
+        parser.add_argument(
+            "--min-params-to-wrap",
+            type=int,
+            metavar="D",
+            default=DEFAULT_MIN_PARAMS_TO_WRAP,
+            help=(
+                "minimum number of params for a layer to be wrapped with FSDP() when "
+                "training with --ddp-backend=fully_sharded. Smaller values will "
+                "improve memory efficiency, but may make torch.distributed "
+                "communication less efficient due to smaller input sizes."
+            )
+        )

    @classmethod
    def build_model(cls, args, task):
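
The new flag then composes with the FSDP arguments described in the README above, e.g. (a hypothetical invocation we constructed for illustration, not one from this commit):

```bash
# hypothetical: lower the FSDP wrapping threshold from 1e8 to 5e7 params
fairseq-train data-bin/wikitext-103-roberta-bpe-bin \
    --ddp-backend fully_sharded --fp16 --min-params-to-wrap 50000000 \
    --task language_modeling --tokens-per-sample 2048 --batch-size 8 \
    --arch transformer_lm_gpt3_13 \
    --optimizer adam --lr 0.0001 --max-update 10 --no-save
```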


@@ -36,6 +36,9 @@ DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024

+DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
+

@register_model("transformer")
class TransformerModel(FairseqEncoderDecoderModel):
    """
@@ -191,6 +194,16 @@ class TransformerModel(FairseqEncoderDecoderModel):
                            help='block size of quantization noise at training time')
        parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
                            help='scalar quantization noise and scalar quantization at training time')
+        # args for Fully Sharded Data Parallel (FSDP) training
+        parser.add_argument(
+            '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP,
+            help=(
+                'minimum number of params for a layer to be wrapped with FSDP() when '
+                'training with --ddp-backend=fully_sharded. Smaller values will '
+                'improve memory efficiency, but may make torch.distributed '
+                'communication less efficient due to smaller input sizes.'
+            )
+        )
        # fmt: on

    @classmethod
@@ -242,8 +255,11 @@ class TransformerModel(FairseqEncoderDecoderModel):
        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        if not args.share_all_embeddings:
-            encoder = fsdp_wrap(encoder, min_num_params=1e8)
-            decoder = fsdp_wrap(decoder, min_num_params=1e8)
+            min_params_to_wrap = getattr(
+                args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
+            )
+            encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
+            decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
        return cls(args, encoder, decoder)

    @classmethod
@@ -387,10 +403,16 @@ class TransformerEncoder(FairseqEncoder):
    def build_encoder_layer(self, args):
        layer = TransformerEncoderLayer(args)
-        if getattr(args, "checkpoint_activations", False):
+        checkpoint = getattr(args, "checkpoint_activations", False)
+        if checkpoint:
            offload_to_cpu = getattr(args, "offload_activations", False)
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
-        layer = fsdp_wrap(layer, min_num_params=1e8)
+        # checkpointing requires alignment to FSDP wrap boundaries
+        min_params_to_wrap = (
+            getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
+            if not checkpoint else 0
+        )
+        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer

    def forward_embedding(
@@ -728,10 +750,16 @@ class TransformerDecoder(FairseqIncrementalDecoder):
    def build_decoder_layer(self, args, no_encoder_attn=False):
        layer = TransformerDecoderLayer(args, no_encoder_attn)
-        if getattr(args, "checkpoint_activations", False):
+        checkpoint = getattr(args, "checkpoint_activations", False)
+        if checkpoint:
            offload_to_cpu = getattr(args, "offload_activations", False)
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
-        layer = fsdp_wrap(layer, min_num_params=1e8)
+        # checkpointing requires alignment to FSDP wrap boundaries
+        min_params_to_wrap = (
+            getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
+            if not checkpoint else 0
+        )
+        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer

    def forward(


@@ -14,7 +14,9 @@ from fairseq.models import (
    register_model,
    register_model_architecture,
)
-from fairseq.models.transformer import Embedding, TransformerDecoder
+from fairseq.models.transformer import (
+    DEFAULT_MIN_PARAMS_TO_WRAP, Embedding, TransformerDecoder
+)
from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
from omegaconf import II
@@ -126,15 +128,6 @@ class TransformerLanguageModelConfig(FairseqDataclass):
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
-    decoder_layerdrop: float = field(
-        default=0.0, metadata={"help": "LayerDrop probability for decoder"}
-    )
-    decoder_layers_to_keep: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "which layers to *keep* when pruning as a comma-separated list"
-        },
-    )
    layernorm_embedding: bool = field(
        default=False, metadata={"help": "add layernorm to embedding"}
    )
@@ -148,6 +141,17 @@ class TransformerLanguageModelConfig(FairseqDataclass):
        default=False,
        metadata={"help": "move checkpointed activations to CPU after they are used."},
    )
+    # config for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
+    decoder_layerdrop: float = field(
+        default=0.0, metadata={"help": "LayerDrop probability for decoder"}
+    )
+    decoder_layers_to_keep: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "which layers to *keep* when pruning as a comma-separated list"
+        },
+    )
+    # config for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
    quant_noise_pq: float = field(
        default=0.0,
        metadata={"help": "iterative PQ quantization noise at training time"},
@@ -156,13 +160,25 @@ class TransformerLanguageModelConfig(FairseqDataclass):
        default=8,
        metadata={"help": "block size of quantization noise at training time"},
    )
-    # TODO common var add to parent
    quant_noise_scalar: float = field(
        default=0.0,
        metadata={
            "help": "scalar quantization noise and scalar quantization at training time"
        },
    )
+    # config for Fully Sharded Data Parallel (FSDP) training
+    min_params_to_wrap: int = field(
+        default=DEFAULT_MIN_PARAMS_TO_WRAP,
+        metadata={
+            "help": (
+                "minimum number of params for a layer to be wrapped with FSDP() when "
+                "training with --ddp-backend=fully_sharded. Smaller values will "
+                "improve memory efficiency, but may make torch.distributed "
+                "communication less efficient due to smaller input sizes."
+            )
+        }
+    )
+    # options from other parts of the config
    add_bos_token: bool = II("task.add_bos_token")
    tokens_per_sample: int = II("task.tokens_per_sample")
    max_target_positions: Optional[int] = II("task.max_target_positions")
@@ -289,7 +305,7 @@ def base_lm_architecture(args):
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
-    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
+    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
@@ -428,3 +444,84 @@ def transformer_lm_gpt2_big(args):
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


def base_gpt3_architecture(args):
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
    args.dropout = getattr(args, "dropout", 0.0)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_small")
def transformer_lm_gpt3_small(args):
    # 125M params
    args.decoder_layers = getattr(args, "decoder_layers", 12)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
    base_gpt3_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_medium")
def transformer_lm_gpt3_medium(args):
    # 350M params
    args.decoder_layers = getattr(args, "decoder_layers", 24)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    base_gpt3_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_large")
def transformer_lm_gpt3_large(args):
    # 760M params
    args.decoder_layers = getattr(args, "decoder_layers", 24)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1536)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    base_gpt3_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_xl")
def transformer_lm_gpt3_xl(args):
    # 1.3B params
    args.decoder_layers = getattr(args, "decoder_layers", 24)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 24)
    base_gpt3_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_2_7")
def transformer_lm_gpt3_2_7(args):
    # 2.7B params
    args.decoder_layers = getattr(args, "decoder_layers", 32)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
    base_gpt3_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_6_7")
def transformer_lm_gpt3_6_7(args):
    # 6.7B params
    args.decoder_layers = getattr(args, "decoder_layers", 32)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
    base_gpt3_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_13")
def transformer_lm_gpt3_13(args):
    # 13B params
    args.decoder_layers = getattr(args, "decoder_layers", 40)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 40)
    base_gpt3_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_175")
def transformer_lm_gpt3_175(args):
    # 175B params
    args.decoder_layers = getattr(args, "decoder_layers", 96)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 12288)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 96)
    base_gpt3_architecture(args)


@@ -1017,15 +1017,17 @@ class Trainer(object):
    def clip_grad_norm(self, clip_norm):
        def agg_norm_fn(total_norm):
-            if self.cfg.distributed_training.ddp_backend == "fully_sharded":
-                total_norm = total_norm ** 2
-                if (
+            if (
+                self.cfg.distributed_training.ddp_backend == "fully_sharded"
+                and (
                    self.data_parallel_process_group is not None
                    or torch.distributed.is_initialized()
-                ):
-                    total_norm = distributed_utils.all_reduce(
-                        total_norm.cuda(), group=self.data_parallel_process_group
-                    )
+                )
+            ):
+                total_norm = total_norm.cuda().float() ** 2
+                total_norm = distributed_utils.all_reduce(
+                    total_norm, group=self.data_parallel_process_group
+                )
                total_norm = total_norm ** 0.5
            return total_norm
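
For reference, the aggregation this implements: with gradients sharded across $R$ workers, each worker computes the squared norm of its local shard $g_r$, and the global norm is recovered by an all-reduce,

$$ \lVert g \rVert_2 = \sqrt{\textstyle\sum_{r=1}^{R} \lVert g_r \rVert_2^2} $$

which is why the squared norm, not the norm itself, is summed across the process group before the final square root.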