From c5ff181125c7e6126b49a85e5ebdd5f5b6a07914 Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Tue, 19 Oct 2021 17:12:02 -0700
Subject: [PATCH] NormFormer: flags and docs (#2460)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2460

Reviewed By: myleott

Differential Revision: D31731798

Pulled By: sshleifer

fbshipit-source-id: 938456c17aa004cacffdcdd124aebe390da83d5f
---
 README.md                            |  1 +
 examples/normformer/README.md        | 70 +++++++++++++++++++++++++
 examples/normformer/train_lm.sh      | 78 ++++++++++++++++++++++++++++
 fairseq/models/transformer_lm.py     |  9 ++++
 fairseq/modules/transformer_layer.py | 20 +++++++
 tests/test_binaries.py               | 24 +++++++++
 6 files changed, 202 insertions(+)
 create mode 100644 examples/normformer/README.md
 create mode 100644 examples/normformer/train_lm.sh

diff --git a/README.md b/README.md
index c8e0b969d..4a3c6b29d 100644
--- a/README.md
+++ b/README.md
@@ -55,6 +55,7 @@ We provide reference implementations of various sequence modeling papers:
   + [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition (Xu et al., 2021)](https://arxiv.org/abs/2109.11680)
   + [VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding (Xu et. al., 2021)](https://arxiv.org/pdf/2109.14084.pdf)
   + [VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding (Xu et. al., 2021)](https://aclanthology.org/2021.findings-acl.370.pdf)
+  + [NormFormer: Improved Transformer Pretraining with Extra Normalization (Shleifer et al., 2021)](examples/normformer/README.md)
 * **Non-autoregressive Transformers**
   + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
   + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
diff --git a/examples/normformer/README.md b/examples/normformer/README.md
new file mode 100644
index 000000000..037b453ff
--- /dev/null
+++ b/examples/normformer/README.md
@@ -0,0 +1,70 @@
+### NormFormer
+This is the code for the paper ["NormFormer: Improved Transformer Pretraining with Extra Normalization"](https://arxiv.org/abs/2110.09456).
+- 2021-10-19: Commands for CLM experiments
+- Coming soon: Commands for MLM experiments
+
+If you have any issues or questions, please post a GitHub issue and tag `@sshleifer`.
+
+
+### Data
+- To preprocess language modeling data, see [here](https://github.com/pytorch/fairseq/blob/d0fbcb0baef6f6ff3425ded62d8daea0e8b12114/examples/language_model/README.md#1-preprocess-the-data).
+- The replication commands below expect `$DATA` to be the path to the binarized data directory.
+- Note that the NormFormer results in Table 2 use a much larger private dataset; to get good results you should adapt the pre-processing instructions to your dataset and compare against a baseline trained on the same data, rather than against Table 2.
+- The code uses `FSDP`, which requires `pip install "fairscale>=0.4.0"`.
+
+
+### Modify existing Command
+To modify an existing `fairseq-train` command to use NormFormer, simply add the following flags:
+```bash
+fairseq-train ... \
+    --scale-attn --scale-fc --scale-heads
+```
+- You probably also want to increase your learning rate.
+- If your model is small, you may also want to add `--scale-resids`.
+
+### Exact Training Commands
+
+- Note that the NormFormer results in Table 2 use a much larger private dataset, so to get good results you should adapt the pre-processing instructions to your dataset.
+The full commands are shell functions defined in `examples/normformer/train_lm.sh`, so to run them you must first `source examples/normformer/train_lm.sh`.
+- We default to `--distributed-world-size 8`. You should adjust `--update-freq` and `--batch-size` so that the effective batch size is (1024x1024x0.5) tokens for the 125M and 355M models,
+  and (1024x1024) tokens for 1.3B parameters and above. For small models, `--update-freq` = 256/`global_bs`; for large models, `--update-freq` = 512/`global_bs`, where `global_bs` = `--batch-size` * `--distributed-world-size`.
+- The small models will all train on as few as 8 GPUs.
+
+```bash
+train_125M --lr 6e-4 # GPT-3 Replicated
+train_125M --lr 1e-3 # stronger high-lr baseline
+train_125M --lr 3e-3 --scale-attn --scale-fc --scale-heads # No scale-resids
+train_125M --lr 3e-3 --scale-attn --scale-fc --scale-heads --scale-resids # Best command
+```
+
+```bash
+train_355M --lr 6e-4 # GPT-3 Replicated
+train_355M --lr 1e-3 # stronger high-lr baseline
+train_355M --lr 1e-3 --scale-attn --scale-fc --scale-heads # No scale-resids
+train_355M --lr 1e-3 --scale-attn --scale-fc --scale-heads --scale-resids # Slightly better
+```
+
+```bash
+train_1.3B --lr 2e-4 # GPT-3 Replicated
+train_1.3B --lr 6e-4 # stronger high-lr baseline
+train_1.3B --lr 6e-4 --scale-attn --scale-fc --scale-heads # NormFormer
+```
+
+```bash
+train_2.7B --lr 1.6e-4 # GPT-3 Replicated
+train_2.7B --lr 1.6e-4 --activation-fn relu_squared # stronger Relu^2 baseline
+train_2.7B --lr 6e-4 --activation-fn relu_squared --scale-attn --scale-fc --scale-heads # NormFormer 2.7B
+```
+
+
+### Citation
+```bibtex
+@misc{shleifer2021normformer,
+      title={NormFormer: Improved Transformer Pretraining with Extra Normalization},
+      author={Sam Shleifer and Jason Weston and Myle Ott},
+      year={2021},
+      eprint={2110.09456},
+      archivePrefix={arXiv},
+      primaryClass={cs.CL}
+}
+```
diff --git a/examples/normformer/train_lm.sh b/examples/normformer/train_lm.sh
new file mode 100644
index 000000000..b081f2ddd
--- /dev/null
+++ b/examples/normformer/train_lm.sh
@@ -0,0 +1,78 @@
+#!/usr/bin/env bash
+train_common () {
+  fairseq-train "$DATA" \
+    --combine-val \
+    --train-subset train \
+    --num-workers 2 \
+    --validate-interval-updates 1000 \
+    --save-interval-updates 1000 \
+    --no-epoch-checkpoints \
+    --ddp-backend fully_sharded \
+    --memory-efficient-fp16 \
+    --fp16-init-scale 4 \
+    --checkpoint-activations \
+    --arch transformer_lm_gpt \
+    --activation-fn gelu \
+    --share-decoder-input-output-embed \
+    --task language_modeling \
+    --sample-break-mode none \
+    --tokens-per-sample 2048 \
+    --optimizer adam --adam-betas "(0.9, 0.98)" \
+    --adam-eps 1e-08 \
+    --clip-norm 0.0 \
+    --lr-scheduler polynomial_decay \
+    --warmup-updates 750 \
+    --dropout 0.1 \
+    --attention-dropout 0.1 \
+    --weight-decay 0.01 \
+    --batch-size 16 \
+    --update-freq 2 \
+    --required-batch-size-multiple 1 \
+    --total-num-update 572204 \
+    --max-update 572204 \
+    --seed 1 \
+    --log-format json --log-interval 1 \
+    --distributed-world-size 8 --distributed-port 13177 \
+    "$@"
+}
+
+train_125M () {
+  train_common --decoder-layers 12 \
+    --decoder-embed-dim 768 \
+    --decoder-ffn-embed-dim 3072 \
+    --decoder-attention-heads 12 "$@"
+}
+
+train_355M () {
+  train_common --decoder-layers 24 \
+    --decoder-embed-dim 1024 \
+    --decoder-ffn-embed-dim 4096 \
+    --decoder-attention-heads 16 \
+    --dropout 0.0 \
+    --attention-dropout 0.0 \
+    "$@"
+}
+
+train_1.3B () {
+  train_common --decoder-layers 24 \
+    --decoder-embed-dim 2048 \
+    --decoder-ffn-embed-dim 8192 \
+    --decoder-attention-heads 32 \
+    --batch-size 4 \
+    --update-freq 16 \
+    --total-num-update 286102 \
+    --max-update 286102 \
+    "$@"
+}
+
+train_2.7B () {
+  train_common --decoder-layers 32 \
+    --decoder-embed-dim 2560 \
+    --decoder-ffn-embed-dim 10240 \
+    --decoder-attention-heads 32 \
+    --batch-size 4 \
+    --update-freq 16 \
+    --total-num-update 286102 \
+    --max-update 286102 \
+    "$@"
+}
diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py
index eedd5151b..14cbb089f 100644
--- a/fairseq/models/transformer_lm.py
+++ b/fairseq/models/transformer_lm.py
@@ -191,6 +191,11 @@ class TransformerLanguageModelConfig(FairseqDataclass):
     base_shuffle: Optional[int] = field(
         default=1, metadata={"help": "shuffle tokens between workers before computing assignment"}
     )
+    # NormFormer
+    scale_fc: Optional[bool] = field(default=False, metadata={"help": 'Insert LayerNorm between fully connected layers'})
+    scale_attn: Optional[bool] = field(default=False, metadata={"help": 'Insert LayerNorm after attention'})
+    scale_heads: Optional[bool] = field(default=False, metadata={"help": 'Learn a scale coefficient for each attention head'})
+    scale_resids: Optional[bool] = field(default=False, metadata={"help": 'Learn a scale coefficient for each residual connection'})
     # options from other parts of the config
     add_bos_token: bool = II("task.add_bos_token")
     tokens_per_sample: int = II("task.tokens_per_sample")
@@ -357,6 +362,10 @@ def base_lm_architecture(args):
     args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", False)
     args.checkpoint_activations = safe_getattr(args, "checkpoint_activations", False)
     args.offload_activations = safe_getattr(args, "offload_activations", False)
+    args.scale_fc = safe_getattr(args, 'scale_fc', False)
+    args.scale_attn = safe_getattr(args, 'scale_attn', False)
+    args.scale_heads = safe_getattr(args, 'scale_heads', False)
+    args.scale_resids = safe_getattr(args, 'scale_resids', False)
     if args.offload_activations:
         args.checkpoint_activations = True
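The `fairseq/modules/transformer_layer.py` changes below are where these four flags actually take effect inside each decoder layer. For intuition, here is a rough editorial sketch of a pre-LayerNorm decoder layer with the same four switches; the class name and signature are illustrative only, and it omits the dropout, cross-attention, causal masking and incremental decoding that the real `TransformerDecoderLayerBase` handles:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NormFormerDecoderLayerSketch(nn.Module):
    # Simplified illustration of where the NormFormer pieces sit in a pre-LN layer.
    def __init__(self, embed_dim, ffn_dim, num_heads,
                 scale_attn=True, scale_fc=True, scale_heads=True, scale_resids=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        self.final_layer_norm = nn.LayerNorm(embed_dim)
        self.fc1 = nn.Linear(embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, embed_dim)
        self.nh, self.head_dim = num_heads, embed_dim // num_heads
        # --scale-attn: LayerNorm applied to the self-attention output
        self.attn_ln = nn.LayerNorm(embed_dim) if scale_attn else None
        # --scale-heads: one learned coefficient per attention head
        self.c_attn = nn.Parameter(torch.ones(self.nh)) if scale_heads else None
        # --scale-fc: LayerNorm between the two fully connected layers
        self.ffn_layernorm = nn.LayerNorm(ffn_dim) if scale_fc else None
        # --scale-resids: learned scale on the feed-forward residual branch
        self.w_resid = nn.Parameter(torch.ones(embed_dim)) if scale_resids else None

    def forward(self, x):  # x: (tgt_len, batch, embed_dim)
        residual = x
        q = self.self_attn_layer_norm(x)               # pre-LN before attention
        h, _ = self.self_attn(q, q, q, need_weights=False)
        if self.c_attn is not None:                    # scale each head's output
            t, b = h.shape[:2]
            h = h.view(t, b, self.nh, self.head_dim)
            h = torch.einsum("tbhd,h->tbhd", h, self.c_attn).reshape(t, b, -1)
        if self.attn_ln is not None:                   # extra LayerNorm after attention
            h = self.attn_ln(h)
        x = residual + h

        residual = x
        h = F.gelu(self.fc1(self.final_layer_norm(x)))  # pre-LN before the FFN
        if self.ffn_layernorm is not None:             # extra LayerNorm between fc1 and fc2
            h = self.ffn_layernorm(h)
        h = self.fc2(h)
        if self.w_resid is not None:                   # rescale only the FFN residual
            residual = self.w_resid * residual
        return residual + h
```

Note that `--scale-resids` rescales only the feed-forward residual branch, matching how `w_resid` is used in the diff below.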
diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py
index 347b8118d..d685f72bb 100644
--- a/fairseq/modules/transformer_layer.py
+++ b/fairseq/modules/transformer_layer.py
@@ -213,6 +213,11 @@ class TransformerDecoderLayerBase(nn.Module):
             add_bias_kv=add_bias_kv,
             add_zero_attn=add_zero_attn,
         )
+        self.attn_ln = LayerNorm(self.embed_dim) if utils.safe_getattr(cfg, 'scale_attn', False) else None
+        self.nh = self.self_attn.num_heads
+        self.head_dim = self.self_attn.head_dim
+        scale_heads = utils.safe_getattr(cfg, 'scale_heads', False)
+        self.c_attn = nn.Parameter(torch.ones((self.nh,)), requires_grad=True) if scale_heads else None
         self.activation_fn = utils.get_activation_fn(activation=cfg.activation_fn)
 
         activation_dropout_p = cfg.activation_dropout
@@ -233,6 +238,9 @@ class TransformerDecoderLayerBase(nn.Module):
             self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg)
             self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
 
+        self.ffn_layernorm = LayerNorm(cfg.decoder.ffn_embed_dim) if utils.safe_getattr(cfg, 'scale_fc', False) else None
+        self.w_resid = nn.Parameter(torch.ones(self.embed_dim, ), requires_grad=True) if utils.safe_getattr(cfg, 'scale_resids', False) else None
+
         self.fc1 = self.build_fc1(
             self.embed_dim,
             cfg.decoder.ffn_embed_dim,
@@ -289,6 +297,7 @@ class TransformerDecoderLayerBase(nn.Module):
     def residual_connection(self, x, residual):
         return residual + x
 
+
     def forward(
         self,
         x,
@@ -365,6 +374,13 @@ class TransformerDecoderLayerBase(nn.Module):
             need_weights=False,
             attn_mask=self_attn_mask,
         )
+        if self.c_attn is not None:
+            tgt_len, bsz = x.size(0), x.size(1)
+            x = x.view(tgt_len, bsz, self.nh, self.head_dim)
+            x = torch.einsum('tbhd,h->tbhd', x, self.c_attn)
+            x = x.reshape(tgt_len, bsz, self.embed_dim)
+        if self.attn_ln is not None:
+            x = self.attn_ln(x)
         x = self.dropout_module(x)
         x = self.residual_connection(x, residual)
         if not self.normalize_before:
@@ -406,8 +422,12 @@ class TransformerDecoderLayerBase(nn.Module):
 
         x = self.activation_fn(self.fc1(x))
         x = self.activation_dropout_module(x)
+        if self.ffn_layernorm is not None:
+            x = self.ffn_layernorm(x)
         x = self.fc2(x)
         x = self.dropout_module(x)
+        if self.w_resid is not None:
+            residual = torch.mul(self.w_resid, residual)
         x = self.residual_connection(x, residual)
         if not self.normalize_before:
             x = self.final_layer_norm(x)
diff --git a/tests/test_binaries.py b/tests/test_binaries.py
index 4e2077426..bc233192f 100644
--- a/tests/test_binaries.py
+++ b/tests/test_binaries.py
@@ -1178,6 +1178,30 @@ class TestLanguageModeling(unittest.TestCase):
                 ],
             )
 
+    def test_normformer_lm(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_lm_data(data_dir)
+                train_language_model(
+                    data_dir,
+                    "transformer_lm",
+                    ["--add-bos-token", '--nval', '1', '--scale-fc', '--scale-heads', '--scale-attn', '--scale-fc'],
+                    run_validation=True,
+                )
+                eval_lm_main(data_dir)
+                eval_lm_main(data_dir, extra_flags=["--context-window", "25"])
+                generate_main(
+                    data_dir,
+                    [
+                        "--task",
+                        "language_modeling",
+                        "--sample-break-mode",
+                        "eos",
+                        "--tokens-per-sample",
+                        "500",
+                    ],
+                )
+
     def test_transformer_lm_with_adaptive_softmax(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory(
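Finally, a quick editorial sanity check of the `--update-freq` guidance in `examples/normformer/README.md` above; the numbers below are simply the defaults from `train_lm.sh` on 8 GPUs, so substitute your own `--batch-size` and world size:

```python
# Effective-batch-size arithmetic from the README, using train_lm.sh defaults.
tokens_per_sample = 2048
small_global_bs = 16 * 8   # 125M/355M: --batch-size 16 * --distributed-world-size 8
large_global_bs = 4 * 8    # 1.3B/2.7B: --batch-size 4  * --distributed-world-size 8

# --update-freq = 256 / global_bs for small models and 512 / global_bs for larger ones,
# which is exactly what the script sets (2 and 16 respectively).
assert 256 // small_global_bs == 2 and 512 // large_global_bs == 16

# That yields the target effective batch sizes quoted in the README.
assert tokens_per_sample * small_global_bs * 2 == int(1024 * 1024 * 0.5)   # ~0.5M tokens/update
assert tokens_per_sample * large_global_bs * 16 == 1024 * 1024             # ~1M tokens/update
```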