add and link hydra docs (#1405)

Summary:
updates hydra integration doc and links to it

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1405

Reviewed By: myleott

Differential Revision: D24808779

Pulled By: alexeib

fbshipit-source-id: a50160e196e469e30e39d6ee47440a569c0154bd
This commit is contained in:
alexeib 2020-11-07 16:50:15 -08:00 committed by Facebook GitHub Bot
parent 50422496ac
commit bd2e804b9c
3 changed files with 320 additions and 180 deletions

README.md

@@ -13,100 +13,107 @@
Fairseq(-py) is a sequence modeling toolkit that allows researchers and
developers to train custom models for translation, summarization, language
modeling and other text generation tasks.
We provide reference implementations of various sequence modeling papers:
<details><summary>List of implemented papers</summary><p>
* **Convolutional Neural Networks (CNN)**
+ [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
+ [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
+ [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
+ [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
+ [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
* **LightConv and DynamicConv models**
+ [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
* **Long Short-Term Memory (LSTM) networks**
+ Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
* **Transformer (self-attention) networks**
+ Attention Is All You Need (Vaswani et al., 2017)
+ [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
+ [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
+ [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
+ [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
+ [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
+ [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
+ [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+ [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
+ [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
+ [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
+ [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
+ [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
+ [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
+ [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
+ [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
+ [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
* **Non-autoregressive Transformers**
+ Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
+ Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al., 2018)
+ Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al., 2019)
+ Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
+ [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
* **Finetuning**
+ [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al., 2020)](examples/rxf/README.md)
</p></details>
### What's New:
* November 2020: Adopted [Hydra](https://github.com/facebookresearch/hydra) as a configuration framework;
[added documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
* October 2020: [Added CRISS models and code](examples/criss/README.md)
* September 2020: [Added Linformer code](examples/linformer/README.md)
* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
* April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
<details><summary>Previous updates</summary><p>
* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
* February 2020: [mBART model and code released](examples/mbart/README.md)
* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german)
* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
* November 2019: [CamemBERT model and code released](examples/camembert/README.md)
* November 2019: [BART model and code released](examples/bart/README.md)
* November 2019: [XLM-R models and code released](examples/xlmr/README.md)
* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
* August 2019: [WMT'19 models released](examples/wmt19/README.md)
* July 2019: fairseq relicensed under MIT license
* July 2019: [RoBERTa models and code released](examples/roberta/README.md)
* June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
</p></details>
### Features:
* multi-GPU training on one machine or across multiple machines (data and model parallel)
* fast generation on both CPU and GPU with multiple search algorithms implemented:
+ beam search
+ Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
+ sampling (unconstrained, top-k and top-p/nucleus)
+ lexically constrained decoding ([Post & Vilar, 2018](examples/constrained_decoding/README.md))
* large mini-batch training even on a single GPU via delayed updates
* mixed precision training (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
* extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers
* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
with a convenient `torch.hub` interface:
``` python
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
en2de.translate('Hello world', beam=5)
# 'Hallo Welt'
```
See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
@@ -116,7 +123,8 @@ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more example
* Python version >= 3.6
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* **To install fairseq** and develop locally:
``` bash
git clone https://github.com/pytorch/fairseq
cd fairseq
pip install --editable ./
@@ -124,18 +132,20 @@ pip install --editable ./
# on MacOS:
# CFLAGS="-stdlib=libc++" pip install --editable ./
```
* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
``` bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
--global-option="--deprecated_fused_adam" --global-option="--xentropy" \
--global-option="--fast_multihead_attn" ./
```
* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
as command line options to `nvidia-docker run`.
# Getting Started
@@ -148,30 +158,31 @@ types and tasks.
We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
as well as example training and evaluation commands.
* [Translation](examples/translation/README.md): convolutional and transformer models are available
* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
We also have more detailed READMEs to reproduce results from specific papers:
* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
# Join the fairseq community
@@ -188,7 +199,7 @@ The license applies to the pre-trained models as well.
Please cite as:
``` bibtex
@inproceedings{ott2019fairseq,
title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},

docs/hydra_integration.md

@@ -1,111 +1,240 @@
## Hydra
[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python framework that simplifies the development of
research and other complex applications. The key feature is the ability to dynamically create a hierarchical
configuration by composition and override it through config files and the command line. The name Hydra comes from its
ability to run multiple similar jobs - much like a Hydra with multiple heads.
## Motivation
Until recently, all components in fairseq were configured through a shared "args" namespace that was created at
application startup. Components declared their own "add_args" method to update the argparse parser, hoping that
the names would not clash with arguments added by other components. While this model works for smaller applications,
it became problematic as fairseq grew and was integrated into other applications.
In order to determine how to configure each component, one needed to a) examine which args were added by that
component, and b) read the code to figure out which shared arguments it relied on that were added elsewhere.
Reproducing models involved sharing commands that often contained dozens of command line switches.
The model described above is still supported by fairseq for backward compatibility, but will be deprecated some time
in the future.
New components in fairseq should now create a dataclass that encapsulates all parameters required to configure this
component. The dataclass is registered along with the component, and fairseq takes care of constructing and
providing this configuration object to the component's constructor. Note that sharing parameters can optionally
still work, but one has to explicitly point to the "source of truth" (see inheritance example below).
These changes make components in fairseq
more independent and re-usable by other applications: all that is needed to create a component is to initialize its
dataclass and overwrite some of the defaults.
While configuring fairseq through the command line (using either the legacy argparse-based or the new Hydra-based entry
points) is still fully supported, you can now take advantage of configuring fairseq completely or piece-by-piece through
hierarchical YAML configuration files. These files can also be shipped as examples that others can use to run
an identically configured job.
Additionally, Hydra has a rich and growing
[library of plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that provides functionality such as
hyperparameter sweeping (including using Bayesian optimization through the [Ax](https://github.com/facebook/Ax) library),
job launching across various platforms, and more.
## Creating or migrating components
In general, each new (or updated) component should provide a companion [dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclasses are typically located in the same
file as the component and are passed as arguments to the register_*() functions. Top-level configs that should be
present in every fairseq application are placed in the [global](fairseq/dataclass/configs.py) config file and added
to the FairseqConfig object.
Each dataclass is a plain-old-data object, similar to a NamedTuple. These classes are decorated with a @dataclass
decorator, and typically inherit from `FairseqDataclass` (which adds some functionality for backward compatibility).
Each field must have a type, and generally has metadata (such as a help string) and a default value. Only primitive types or other config objects are allowed as
data types for each field.
Example:
``` python
from dataclasses import dataclass, field
from fairseq.dataclass import FairseqDataclass

@dataclass
class InteractiveConfig(FairseqDataclass):
    buffer_size: int = field(
        default=0,
        metadata={
            "help": "read this many sentences into a buffer before processing them"
        },
    )
    input: str = field(
        default="-",
        metadata={"help": "file to read from; use - for stdin"},
    )
```
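Because such a config is an ordinary dataclass, another application can construct it directly and override selected defaults; a minimal usage sketch, assuming the `InteractiveConfig` above is importable:
``` python
# Hypothetical usage: build the config programmatically and override one default.
cfg = InteractiveConfig(buffer_size=10)
assert cfg.input == "-"  # fields left unspecified keep their declared defaults
```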
### Inheriting values
Some components require sharing a value. For example, a learning rate scheduler and an optimizer may both need to
know the initial learning rate value. One can declare a field that, by default, will
inherit its value from another config node in the same hierarchy:
``` python
@dataclass
class FairseqAdamConfig(FairseqDataclass):
    ...
    lr: List[float] = II("optimization.lr")
    ...
```
`II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"` , which is the value one can use in a YAML config file or through
command line to achieve the same effect. Note that this assumes that there is an "optimization" config object
in the root config and it has a field called "lr".
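As an illustration of how such an interpolation resolves, here is a small, self-contained OmegaConf sketch (outside of fairseq; the keys simply mirror the example above):
``` python
from omegaconf import II, OmegaConf

# II("optimization.lr") expands to the interpolation string "${optimization.lr}".
cfg = OmegaConf.create(
    {
        "optimization": {"lr": [0.0005]},
        "optimizer": {"lr": II("optimization.lr")},
    }
)
print(cfg.optimizer.lr)  # [0.0005], resolved from optimization.lr
```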
### Tasks and Models
Creating Tasks and Models works the same as before, except that legacy implementations now inherit from Legacy* base classes,
while new components inherit from FairseqTask and FairseqModel and provide a dataclass to the register_*() functions.
Task example:
``` python
@dataclass
class LanguageModelingConfig(FairseqDataclass):
    data: Optional[str] = field(
        default=None, metadata={"help": "path to data directory"}
    )
    ...

@register_task("language_modeling", dataclass=LanguageModelingConfig)
class LanguageModelingTask(LegacyFairseqTask):
    ...
    @classmethod
    def setup_task(cls, cfg: LanguageModelingConfig):
        ...
```
Model example:
``` python
@dataclass
class TransformerLanguageModelConfig(FairseqDataclass):
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="relu", metadata={"help": "activation function to use"}
    )
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
    ...
@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
class TransformerLanguageModel(FairseqLanguageModel):
    ...
    @classmethod
    def build_model(cls, cfg: TransformerLanguageModelConfig, task: FairseqTask):
        ...
```
### Other components
Other components work as before, but they now take their configuration dataclass as the only constructor argument:
``` python
@dataclass
class MosesTokenizerConfig(FairseqDataclass):
    source_lang: str = field(default="en", metadata={"help": "source language"})
    ...
@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
class MosesTokenizer(object):
    def __init__(self, cfg: MosesTokenizerConfig):
        ...
```
Note that if you are adding a new registry for a new set of components, you need to add it to the FairseqConfig object in
`fairseq/dataclass/configs.py`:
``` python
@dataclass
class FairseqConfig(object):
    ...
    my_new_registry: Any = None
```
## Training with hydra_train.py
To fully take advantage of the configuration flexibility offered by Hydra, you may want to train new models using the
`hydra_train.py` entry point located in the `fairseq_cli` directory. Legacy CLI tools such as `train.py` will remain
supported for the foreseeable future but will be deprecated eventually.
On startup, Hydra will create a configuration object that contains a hierarchy of all the necessary dataclasses
populated with their default values in the code. The default values are overwritten by values found in YAML files in
the `fairseq/config` directory (which currently just sets the default task, optimizer, etc.) and then further
overwritten by values provided through command line arguments.
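In other words, the precedence is: dataclass defaults, then bundled YAML, then command line overrides. The following rough OmegaConf sketch illustrates that layering (illustrative only, not fairseq's actual startup code; the config below is made up):
``` python
from dataclasses import dataclass, field
from typing import List

from omegaconf import OmegaConf

@dataclass
class OptimizationConfig:
    max_update: int = 0
    lr: List[float] = field(default_factory=lambda: [0.25])

defaults = OmegaConf.structured(OptimizationConfig)   # defaults declared in code
from_yaml = OmegaConf.create({"max_update": 50000})   # stand-in for a bundled YAML file
from_cli = OmegaConf.from_dotlist(["lr=[0.0005]"])    # stand-in for command line overrides
cfg = OmegaConf.merge(defaults, from_yaml, from_cli)  # later sources win
print(cfg.max_update, cfg.lr)                         # 50000 [0.0005]
```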
Some of the most common use cases are shown below:
### 1. Overwrite default values through command line:
```shell script
python fairseq_cli/hydra_train.py distributed_training.distributed_world_size=1 dataset.batch_size=2 task.data=data-bin \
model=transformer_lm/transformer_lm_gpt task=language_modeling optimization.max_update=5000
```
Note that along with explicitly providing values for parameters such as `dataset.batch_size`, this also tells Hydra to overlay configuration found in `fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml`
over the default values in the dataclass. If you want to train a model without specifying a particular architecture
you can simply specify `model=transformer_lm`. This only works for migrated tasks and models.
### 2. Replace bundled configs with an external config:
```shell script
python fairseq_cli/hydra_train.py --config-path /path/to/external/configs --config-name wiki103
```
where /path/to/external/configs/wiki103.yaml contains:
``` yaml
# @package _group_
model:
  _name: transformer_lm
distributed_training:
  distributed_world_size: 1
dataset:
  batch_size: 2
task:
  _name: language_modeling
  data: /path/to/data
  add_bos_token: false
  max_target_positions: 1024
optimization:
  max_update: 50000
  lr: [ 0.25 ]
criterion: cross_entropy
optimizer: adam
lr_scheduler:
  _name: cosine
```
Note that here the bundled configs from the `fairseq/config` directory are not used; however, the defaults from each dataclass will still be used (unless overwritten by your external config).
Additionally you can choose to break up your configs by creating a directory structure in the same location as your main
config file, with the names of the top-level fields (such as "model", "dataset", etc.), and placing config files with
meaningful names that would populate that specific section of your top-level config file (for example, you might have
`model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc.). You can then specify the correct configuration
via the command line, defaults in the main config, or even launch all of them as a sweep (see the Hydra documentation on
how to do this).
### 3. Add an external config directory to Hydra search path:
This allows combining default configuration (including using any bundled config files), while specifying your own config files for some parts of the configuration.
```shell script
python fairseq_cli/hydra_train.py distributed_training.distributed_world_size=1 dataset.batch_size=2 \
task.data=/path/to/data/ model=transformer_lm/2_layers task=language_modeling optimization.max_update=5000 \
--config-dir /path/to/external/configs
```
where /path/to/external/configs has the following structure:
```
.
+-- model
|   +-- transformer_lm
|   |   +-- 2_layers.yaml
```
and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with `decoder_layers` set to 2. You can add
other configs to configure other components as well.

fairseq/data/encoders/moses_tokenizer.py

@@ -24,7 +24,7 @@ class MosesTokenizerConfig(FairseqDataclass):
@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
class MosesTokenizer(object):
    def __init__(self, cfg: MosesTokenizerConfig):
        self.cfg = cfg
        try: