NormFormer: flags and docs (#2460)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2460

Reviewed By: myleott

Differential Revision: D31731798

Pulled By: sshleifer

fbshipit-source-id: 938456c17aa004cacffdcdd124aebe390da83d5f
This commit is contained in:
Sam Shleifer 2021-10-19 17:12:02 -07:00 committed by Facebook GitHub Bot
parent 29be3fe141
commit c5ff181125
6 changed files with 202 additions and 0 deletions

View File

@@ -55,6 +55,7 @@ We provide reference implementations of various sequence modeling papers:
+ [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition (Xu et al., 2021)](https://arxiv.org/abs/2109.11680)
+ [VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding (Xu et al., 2021)](https://arxiv.org/pdf/2109.14084.pdf)
+ [VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding (Xu et al., 2021)](https://aclanthology.org/2021.findings-acl.370.pdf)
+ [NormFormer: Improved Transformer Pretraining with Extra Normalization (Shleifer et al., 2021)](examples/normformer/README.md)
* **Non-autoregressive Transformers**
+ Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
+ Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)

View File

@@ -0,0 +1,70 @@
### NormFormer
This is the code for the paper ["NormFormer: Improved Transformer Pretraining with Extra Normalization"](https://arxiv.org/abs/2110.09456).
- 2021-10-19: Commands for CLM Experiments
- Coming soon: Commands for MLM experiments
If you have any issues or questions, please open a GitHub issue and tag `@sshleifer`.
### Data
- To preprocess language modeling data, see [here](https://github.com/pytorch/fairseq/blob/d0fbcb0baef6f6ff3425ded62d8daea0e8b12114/examples/language_model/README.md#1-preprocess-the-data); a rough sketch of that recipe is shown after this list.
- The replication commands below expect `$DATA` to be the path to the binarized data directory.
- Note that the NormFormer results in Table 2 were obtained on a much larger private dataset. To get good results, adapt the pre-processing instructions to your own dataset and compare against a baseline trained on the same data, rather than against Table 2.
- The code uses `FSDP`, which requires `fairscale>=0.4.0` (`pip install "fairscale>=0.4.0"`).
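For reference, the linked recipe looks roughly like the sketch below (illustrative only; `$TEXT`, the file names, and `data-bin/my_corpus` are placeholders for your own tokenized corpus):
```bash
TEXT=path/to/tokenized/corpus      # plain-text files, already tokenized
fairseq-preprocess \
    --only-source \
    --trainpref $TEXT/train.tokens \
    --validpref $TEXT/valid.tokens \
    --testpref $TEXT/test.tokens \
    --destdir data-bin/my_corpus \
    --workers 20
DATA=data-bin/my_corpus            # the binarized directory referenced as $DATA below
```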
### Modify an Existing Command
To modify an existing `fairseq-train` command to use NormFormer, simply add the following flags:
```bash
fairseq-train ... \
--scale-attn --scale-fc --scale-heads
```
- You probably also want to increase your learning rate (see the example below).
- If your model is small, you may want to add `--scale-resids`.
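For example (an illustrative sketch; the learning rates mirror the 125M commands in the next section):
```bash
# baseline
fairseq-train $DATA ... --lr 6e-4
# same command with the NormFormer flags and a higher learning rate
fairseq-train $DATA ... --lr 3e-3 --scale-attn --scale-fc --scale-heads --scale-resids
```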
### Exact Training Commands
- Note that the NormFormer results in Table 2 use a much larger private dataset; to get good results, adapt the pre-processing instructions to your own dataset.
- The full commands are shell functions defined in `examples/normformer/train_lm.sh`, so to run them you must first `source examples/normformer/train_lm.sh`.
- We default to `--distributed-world-size 8`. Adjust `--update-freq` and `--batch-size` such that the effective batch size is 1024x1024x0.5 tokens per update for the 125M and 355M models, and 1024x1024 tokens for the 1.3B-parameter models and above. For the small models, `--update-freq` = 256/`global_bs`; for the large models, `--update-freq` = 512/`global_bs`, where `global_bs` = `--batch-size` * `--distributed-world-size`. A quick sanity check of this arithmetic is sketched after this list.
- The small models will all train on as few as 8 GPUs.
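As a quick sanity check on that arithmetic (an illustrative sketch using the default values from `train_lm.sh` below; not one of the original commands):
```bash
# effective tokens per update = batch_size * world_size * update_freq * tokens_per_sample
WORLD_SIZE=8            # --distributed-world-size
BATCH_SIZE=16           # --batch-size
UPDATE_FREQ=2           # --update-freq
TOKENS_PER_SAMPLE=2048  # --tokens-per-sample
echo $(( BATCH_SIZE * WORLD_SIZE * UPDATE_FREQ * TOKENS_PER_SAMPLE ))  # 524288 = 1024*1024*0.5
```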
```bash
train_125M --lr 6e-4 # GPT-3 Replicated
train_125M --lr 1e-3 # stronger high-lr baseline
train_125M --lr 3e-3 --scale-attn --scale-fc --scale-heads # No scale-resids
train_125M --lr 3e-3 --scale-attn --scale-fc --scale-heads --scale-resids # Best command
```
```bash
train_355M --lr 6e-4 # GPT-3 Replicated
train_355M --lr 1e-3 # stronger high-lr baseline
train_355M --lr 1e-3 --scale-attn --scale-fc --scale-heads # No scale-resids
train_355M --lr 1e-3 --scale-attn --scale-fc --scale-heads --scale-resids # Slightly better
```
```bash
train_1.3B --lr 2e-4 # GPT-3 Replicated
train_1.3B --lr 6e-4 # stronger high-lr baseline
train_1.3B --lr 6e-4 --scale-attn --scale-fc --scale-heads # NormFormer
```
```bash
train_2.7B --lr 1.6e-4 # GPT-3 Replicated
train_2.7B --lr 1.6e-4 --activation-fn relu_squared # stronger Relu^2 baseline
train_2.7B --lr 6e-4 --activation-fn relu_squared --scale-attn --scale-fc --scale-heads # NormFormer 2.7B
```
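After training, perplexity can be measured with `fairseq-eval-lm`. A minimal sketch, assuming checkpoints land in the default `checkpoints/` directory (the evaluation flags below are illustrative, not taken from the paper):
```bash
fairseq-eval-lm $DATA \
    --path checkpoints/checkpoint_best.pt \
    --batch-size 2 \
    --tokens-per-sample 2048 \
    --context-window 1024
```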
### Citation
```bibtex
@misc{shleifer2021normformer,
title={NormFormer: Improved Transformer Pretraining with Extra Normalization},
author={Sam Shleifer and Jason Weston and Myle Ott},
year={2021},
eprint={2110.09456},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```

View File

@@ -0,0 +1,78 @@
#!/usr/bin/env bash
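# Usage: source this file, then call one of the train_* functions defined below, e.g.
#   source examples/normformer/train_lm.sh
#   train_125M --lr 3e-3 --scale-attn --scale-fc --scale-heads --scale-resids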
train_common () {
    fairseq-train "$DATA" \
        --combine-val \
        --train-subset train \
        --num-workers 2 \
        --validate-interval-updates 1000 \
        --save-interval-updates 1000 \
        --no-epoch-checkpoints \
        --ddp-backend fully_sharded \
        --memory-efficient-fp16 \
        --fp16-init-scale 4 \
        --checkpoint-activations \
        --arch transformer_lm_gpt \
        --activation-fn gelu \
        --share-decoder-input-output-embed \
        --task language_modeling \
        --sample-break-mode none \
        --tokens-per-sample 2048 \
        --optimizer adam --adam-betas "(0.9, 0.98)" \
        --adam-eps 1e-08 \
        --clip-norm 0.0 \
        --lr-scheduler polynomial_decay \
        --warmup-updates 750 \
        --dropout 0.1 \
        --attention-dropout 0.1 \
        --weight-decay 0.01 \
        --batch-size 16 \
        --update-freq 2 \
        --required-batch-size-multiple 1 \
        --total-num-update 572204 \
        --max-update 572204 \
        --seed 1 \
        --log-format json --log-interval 1 \
        --distributed-world-size 8 --distributed-port 13177 \
        "$@"
}
train_125M () {
    train_common --decoder-layers 12 \
        --decoder-embed-dim 768 \
        --decoder-ffn-embed-dim 3072 \
        --decoder-attention-heads 12 "$@"
}
train_355M () {
    train_common --decoder-layers 24 \
        --decoder-embed-dim 1024 \
        --decoder-ffn-embed-dim 4096 \
        --decoder-attention-heads 16 \
        --dropout 0.0 \
        --attention-dropout 0.0 \
        "$@"
}
train_1.3B () {
    train_common --decoder-layers 24 \
        --decoder-embed-dim 2048 \
        --decoder-ffn-embed-dim 8192 \
        --decoder-attention-heads 32 \
        --batch-size 4 \
        --update-freq 16 \
        --total-num-update 286102 \
        --max-update 286102 \
        "$@"
}
train_2.7B () {
    train_common --decoder-layers 32 \
        --decoder-embed-dim 2560 \
        --decoder-ffn-embed-dim 10240 \
        --decoder-attention-heads 32 \
        --batch-size 4 \
        --update-freq 16 \
        --total-num-update 286102 \
        --max-update 286102 \
        "$@"
}

View File

@@ -191,6 +191,11 @@ class TransformerLanguageModelConfig(FairseqDataclass):
    base_shuffle: Optional[int] = field(
        default=1, metadata={"help": "shuffle tokens between workers before computing assignment"}
    )
    # NormFormer
    scale_fc: Optional[bool] = field(default=False, metadata={"help": "Insert LayerNorm between fully connected layers"})
    scale_attn: Optional[bool] = field(default=False, metadata={"help": "Insert LayerNorm after attention"})
    scale_heads: Optional[bool] = field(default=False, metadata={"help": "Learn a scale coefficient for each attention head"})
    scale_resids: Optional[bool] = field(default=False, metadata={"help": "Learn a scale coefficient for each residual connection"})
    # options from other parts of the config
    add_bos_token: bool = II("task.add_bos_token")
    tokens_per_sample: int = II("task.tokens_per_sample")
@@ -357,6 +362,10 @@ def base_lm_architecture(args):
    args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", False)
    args.checkpoint_activations = safe_getattr(args, "checkpoint_activations", False)
    args.offload_activations = safe_getattr(args, "offload_activations", False)
    args.scale_fc = safe_getattr(args, "scale_fc", False)
    args.scale_attn = safe_getattr(args, "scale_attn", False)
    args.scale_heads = safe_getattr(args, "scale_heads", False)
    args.scale_resids = safe_getattr(args, "scale_resids", False)
    if args.offload_activations:
        args.checkpoint_activations = True

View File

@@ -213,6 +213,11 @@ class TransformerDecoderLayerBase(nn.Module):
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        self.attn_ln = LayerNorm(self.embed_dim) if utils.safe_getattr(cfg, "scale_attn", False) else None
        self.nh = self.self_attn.num_heads
        self.head_dim = self.self_attn.head_dim
        scale_heads = utils.safe_getattr(cfg, "scale_heads", False)
        self.c_attn = nn.Parameter(torch.ones((self.nh,)), requires_grad=True) if scale_heads else None
        self.activation_fn = utils.get_activation_fn(activation=cfg.activation_fn)
        activation_dropout_p = cfg.activation_dropout
@@ -233,6 +238,9 @@ class TransformerDecoderLayerBase(nn.Module):
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
        self.ffn_layernorm = LayerNorm(cfg.decoder.ffn_embed_dim) if utils.safe_getattr(cfg, "scale_fc", False) else None
        self.w_resid = nn.Parameter(torch.ones(self.embed_dim), requires_grad=True) if utils.safe_getattr(cfg, "scale_resids", False) else None
        self.fc1 = self.build_fc1(
            self.embed_dim,
            cfg.decoder.ffn_embed_dim,
@@ -289,6 +297,7 @@ class TransformerDecoderLayerBase(nn.Module):
    def residual_connection(self, x, residual):
        return residual + x

    def forward(
        self,
        x,
@@ -365,6 +374,13 @@ class TransformerDecoderLayerBase(nn.Module):
            need_weights=False,
            attn_mask=self_attn_mask,
        )
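        # --scale-heads: view the attention output as (tgt_len, bsz, num_heads, head_dim),
        # multiply each head slice by its learned coefficient in c_attn, then flatten back.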
        if self.c_attn is not None:
            tgt_len, bsz = x.size(0), x.size(1)
            x = x.view(tgt_len, bsz, self.nh, self.head_dim)
            x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
            x = x.reshape(tgt_len, bsz, self.embed_dim)
        if self.attn_ln is not None:
            x = self.attn_ln(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
@@ -406,8 +422,12 @@ class TransformerDecoderLayerBase(nn.Module):
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
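        # --scale-fc: extra LayerNorm between the fc1 activation and fc2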
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
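        # --scale-resids: rescale the residual branch with a learned per-dimension weight before adding it back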
        if self.w_resid is not None:
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)

View File

@@ -1178,6 +1178,30 @@ class TestLanguageModeling(unittest.TestCase):
                    ],
                )

    def test_normformer_lm(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
                create_dummy_data(data_dir)
                preprocess_lm_data(data_dir)
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--add-bos-token", "--nval", "1", "--scale-fc", "--scale-heads", "--scale-attn"],
                    run_validation=True,
                )
                eval_lm_main(data_dir)
                eval_lm_main(data_dir, extra_flags=["--context-window", "25"])
                generate_main(
                    data_dir,
                    [
                        "--task",
                        "language_modeling",
                        "--sample-break-mode",
                        "eos",
                        "--tokens-per-sample",
                        "500",
                    ],
                )

    def test_transformer_lm_with_adaptive_softmax(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory(