From bd2e804b9c2ff1fae202c00e227f1afece12420b Mon Sep 17 00:00:00 2001 From: alexeib Date: Sat, 7 Nov 2020 16:50:15 -0800 Subject: [PATCH] add and link hydra docs (#1405) Summary: updates hydra integration doc and links to it Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1405 Reviewed By: myleott Differential Revision: D24808779 Pulled By: alexeib fbshipit-source-id: a50160e196e469e30e39d6ee47440a569c0154bd --- README.md | 209 ++++++++-------- docs/hydra_integration.md | 289 ++++++++++++++++------- fairseq/data/encoders/moses_tokenizer.py | 2 +- 3 files changed, 320 insertions(+), 180 deletions(-) diff --git a/README.md b/README.md index 13b822395..0648da15f 100644 --- a/README.md +++ b/README.md @@ -13,100 +13,107 @@ Fairseq(-py) is a sequence modeling toolkit that allows researchers and developers to train custom models for translation, summarization, language modeling and other text generation tasks. + We provide reference implementations of various sequence modeling papers:
List of implemented papers

-- **Convolutional Neural Networks (CNN)** - - [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md) - - [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) - - [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) - - [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) - - [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) -- **LightConv and DynamicConv models** - - [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) -- **Long Short-Term Memory (LSTM) networks** - - Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015) -- **Transformer (self-attention) networks** - - Attention Is All You Need (Vaswani et al., 2017) - - [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) - - [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) - - [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) - - [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) - - [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) - - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) - - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) - - [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) - - [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) - - [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) - - [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) - - [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) - - [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md) - - [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) - - [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) - - [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) -- **Non-autoregressive Transformers** - - Non-Autoregressive Neural Machine Translation (Gu et al., 2017) - - Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) - - Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019) - - Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019) - - [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) -- **Finetuning** - - [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md) +* **Convolutional Neural Networks (CNN)** + + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md) + + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) + + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) + + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) + + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) +* **LightConv and DynamicConv models** + + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) +* **Long Short-Term Memory (LSTM) networks** + + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015) +* **Transformer (self-attention) networks** + + Attention Is All You Need (Vaswani et al., 2017) + + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) + + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) + + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) + + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) + + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) + + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) + + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) + + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) + + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) + + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) + + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) + + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) + + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md) + + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) + + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) + + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) +* **Non-autoregressive Transformers** + + Non-Autoregressive Neural Machine Translation (Gu et al., 2017) + + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) + + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019) + + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019) + + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) +* **Finetuning** + + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)

### What's New: -- October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) -- October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) -- October 2020: [Added CRISS models and code](examples/criss/README.md) -- September 2020: [Added Linformer code](examples/linformer/README.md) -- September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) -- August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) -- August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) -- July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) -- May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) -- April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) -- April 2020: [Quant-Noise code released](examples/quant_noise/README.md) -- April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) +* November 2020: Adopted [Hydra](https://github.com/facebookresearch/hydra) as a configuration framework; +[added documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) +* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) +* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) +* October 2020: [Added CRISS models and code](examples/criss/README.md) +* September 2020: [Added Linformer code](examples/linformer/README.md) +* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) +* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) +* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) +* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) +* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) +* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) +* April 2020: [Quant-Noise code released](examples/quant_noise/README.md) +* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) +
Previous updates

-- March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) -- February 2020: [mBART model and code released](examples/mbart/README.md) -- February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german) -- December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) -- November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example) -- November 2019: [CamemBERT model and code released](examples/camembert/README.md) -- November 2019: [BART model and code released](examples/bart/README.md) -- November 2019: [XLM-R models and code released](examples/xlmr/README.md) -- September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) -- August 2019: [WMT'19 models released](examples/wmt19/README.md) -- July 2019: fairseq relicensed under MIT license -- July 2019: [RoBERTa models and code released](examples/roberta/README.md) -- June 2019: [wav2vec models and code released](examples/wav2vec/README.md) +* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) +* February 2020: [mBART model and code released](examples/mbart/README.md) +* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german) +* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) +* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example) +* November 2019: [CamemBERT model and code released](examples/camembert/README.md) +* November 2019: [BART model and code released](examples/bart/README.md) +* November 2019: [XLM-R models and code released](examples/xlmr/README.md) +* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) +* August 2019: [WMT'19 models released](examples/wmt19/README.md) +* July 2019: fairseq relicensed under MIT license +* July 2019: [RoBERTa models and code released](examples/roberta/README.md) +* June 2019: [wav2vec models and code released](examples/wav2vec/README.md)

### Features: -- multi-GPU training on one machine or across multiple machines (data and model parallel) -- fast generation on both CPU and GPU with multiple search algorithms implemented: - - beam search - - Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) - - sampling (unconstrained, top-k and top-p/nucleus) - - lexically constrained decoding ([Post & Vilar, 2018](examples/constrained_decoding/README.md)) -- large mini-batch training even on a single GPU via delayed updates -- mixed precision training (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) -- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers +* multi-GPU training on one machine or across multiple machines (data and model parallel) +* fast generation on both CPU and GPU with multiple search algorithms implemented: + + beam search + + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) + + sampling (unconstrained, top-k and top-p/nucleus) + + lexically constrained decoding ([Post & Vilar, 2018](examples/constrained_decoding/README.md)) +* large mini-batch training even on a single GPU via delayed updates +* mixed precision training (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) +* extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers +* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples) with a convenient `torch.hub` interface: -```python + +``` python en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model') en2de.translate('Hello world', beam=5) # 'Hallo Welt' ``` + See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/) and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples. @@ -116,7 +123,8 @@ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more example * Python version >= 3.6 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl) * **To install fairseq** and develop locally: -```bash + +``` bash git clone https://github.com/pytorch/fairseq cd fairseq pip install --editable ./ @@ -124,18 +132,20 @@ pip install --editable ./ # on MacOS: # CFLAGS="-stdlib=libc++" pip install --editable ./ ``` + * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library: -```bash + +``` bash git clone https://github.com/NVIDIA/apex cd apex pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \ --global-option="--deprecated_fused_adam" --global-option="--xentropy" \ --global-option="--fast_multihead_attn" ./ ``` -* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow` -* If you use Docker make sure to increase the shared memory size either with -`--ipc=host` or `--shm-size` as command line options to `nvidia-docker run`. +* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow` +* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size` + as command line options to `nvidia-docker run` . # Getting Started @@ -148,30 +158,31 @@ types and tasks. We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below, as well as example training and evaluation commands. -- [Translation](examples/translation/README.md): convolutional and transformer models are available -- [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available +* [Translation](examples/translation/README.md): convolutional and transformer models are available +* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available We also have more detailed READMEs to reproduce results from specific papers: -- [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) -- [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) -- [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) -- [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md) -- [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) -- [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) -- [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md) -- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md) -- [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) -- [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) -- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) -- [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) -- [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) -- [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) -- [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) -- [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) -- [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) -- [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) -- [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) -- [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md) + +* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) +* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) +* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) +* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md) +* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) +* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) +* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md) +* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md) +* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) +* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) +* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) +* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) +* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) +* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) +* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) +* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) +* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) +* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) +* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) +* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md) # Join the fairseq community @@ -188,7 +199,7 @@ The license applies to the pre-trained models as well. Please cite as: -```bibtex +``` bibtex @inproceedings{ott2019fairseq, title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, diff --git a/docs/hydra_integration.md b/docs/hydra_integration.md index 0973cd279..f924de961 100644 --- a/docs/hydra_integration.md +++ b/docs/hydra_integration.md @@ -1,111 +1,240 @@ - - ## Hydra -Hydra is an open-source Python framework that simplifies the development of research and other complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line. The name Hydra comes from its ability to run multiple similar jobs - much like a Hydra with multiple heads. +[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python framework that simplifies the development of +research and other complex applications. The key feature is the ability to dynamically create a hierarchical +configuration by composition and override it through config files and the command line. The name Hydra comes from its +ability to run multiple similar jobs - much like a Hydra with multiple heads. -## Train models with hydra interface +## Motivation -#### Provide parameters in `.yaml` files -For example, if we'd like to train a language model with transformer, we could provide parameters in yaml files. Note that the modules used (task, model, criterion, optimizer, lr scheduler) in training must be migrated with hydra interface already (See session below). +Until recently, all components in fairseq were configured through a shared "args" namespace that was created at +application startup. Components declared their own "add_args" method to update the argparse parser, hoping that +the names would not clash with arguments from other components. While this model works for smaller applications, +as fairseq grew and became integrated into other applications, this became problematic. +In order to determine how to configure each component, one needed to a) examine what args were added by this component, and +b) read the code to figure out what shared arguments it is using that were added in other places. Reproducing +models involved sharing commands that often contained dozens of command line switches. -- Provide top level choices on which generic parameter file, and which modules to use: `config/config.yaml`, this will look like for example: +The model described above is still supported by fairseq for backward compatibility, but will be deprecated some time +in the future. -``` -defaults: - - task: language_modeling - - model: transformer_lm - - criterion: cross_entropy - - optimizer: adam - - lr_scheduler: inverse_sqrt +New components in fairseq should now create a dataclass that encapsulates all parameters required to configure this +component. The dataclass is registered along with the component, and fairseq takes care of constructing and +providing this configuration object to the component's constructor. Note that sharing parameters can optionally +still work, but one has to explicitly point to the "source of truth" (see inheritance example below). +These changes make components in fairseq +more independent and re-usable by other applications: all that is needed to create a component is to initialize its +dataclass and overwrite some of the defaults. + +While configuring fairseq through command line (using either the legacy argparse based or the new Hydra based entry points) is still +fully supported, you can now take advantage of configuring fairseq completely or piece-by-piece through +hierarchical YAML configuration files. These files can also be shipped as examples that others can use to run +an identically configured job. + +Additionally, Hydra has a rich and growing +[library of plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that provide functionality such as +hyperparameter sweeping (including using bayesian optimization through the [Ax](https://github.com/facebook/Ax) library), +job launching across various platforms, and more. + +## Creating or migrating components + +In general, each new (or updated) component should provide a companion [dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclass are typically located in the same +file as the component and are passed as arguments to the register_*() functions. Top-level configs that should be +present in every fairseq application are placed in the [global](fairseq/dataclass/configs.py) config file and added +to the FairseqConfig object. + +Each dataclass is a plain-old-data object, similar to a NamedTuple. These classes are decorated with a @dataclass +decorator, and typically inherit from `FairseqDataclass` (which adds some functionality for backward compatibility). +Each field must have a type, and generally has metadata (such as a help string) and a default value. Only primitive types or other config objects are allowed as +data types for each field. + + Example: + + +``` python +from dataclasses import dataclass, field +from fairseq.dataclass import FairseqDataclass + +@dataclass +class InteractiveConfig(FairseqDataclass): + buffer_size: int = field( + default=0, + metadata={ + "help": "read this many sentences into a buffer before processing them" + }, + ) + input: str = field( + default="-", + metadata={"help": "file to read from; use - for stdin"}, + ) ``` -- Provide generic parameters common across different jobs: `config.yaml` -- Provide task parameters: `config/task/language_modeling.yaml` -- Provide model parameters: `config/model/transformer_lm.yaml` -- Provide criterion parameters: `config/criterion/cross_entropy.yaml` -- Provide optimizer parameters: `config/optimizer/adam.yaml` -- Provide lr_scheduler parameters `config/lr_scheduler/inverse_sqrt.yaml` +### Inherting values -#### Command line overriding -`train_hydra.py` is the main entry point for training with hydra interface. If we specify all parameters we want in `.yaml` files, then we could simply use command: +Some components require sharing a value. For example, a learning rate scheduler and an optimizer may both need to +know the initial learning rate value. One can declare a field that, by default, will +inherit its value from another config node in the same hierarchy: -``` -# task.data is requested field marked by `???` in yaml -python fairseq_cli/train_hydra.py \ -task.data=/private/home/abaevski/data/wiki103 \ +``` python +@dataclass +FairseqAdamConfig(FairseqDataclass): + ... + lr: List[float] = II("optimization.lr") + ... ``` -Alternatively, if we need to override certain params from the command line, we could do so as below (note the structure of where each parameter sits) +`II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"` , which is the value one can use in a YAML config file or through +command line to achieve the same effect. Note that this assumes that there is an "optimization" config object +in the root config and it has a field called "lr". -``` -python fairseq_cli/train_hydra.py -task=language_modeling \ -task.data=/private/home/abaevski/data/wiki103 \ -task.tokens_per_sample=512 \ -task.sample_break_mode=none \ -model=transformer_lm \ -model.share_decoder_input_output_embed=true \ -model.dropout=0.1 \ -optimizer=adam \ -optimizer.adam_betas="'(0.9, 0.98)'" \ -optimizer.weight_decay=0.01 \ -lr_scheduler=inverse_sqrt \ -lr_scheduler.warmup_updates=4000 \ -lr_scheduler.warmup_init_lr=1e-07 \ -criterion=cross_entropy \ -common.fp16=true \ -common.log_format=json \ -common.log_interval=1 \ -dataset.max_tokens=1024 \ -dataset.num_workers=4 \ -optimization.update_freq=[16] \ -optimization.max_update=50000 \ -optimization.clip_norm=0.0 \ -optimization.lr=[0.0005] \ -checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \ -checkpoint.save_interval_updates=10 +### Tasks and Models + +Creating Tasks and Models works same as before, except that legacy implementations now inherit from Legacy* base classes, +while new components inherit from FairseqTask and FairseqModel and provide a dataclass to the register_*() functions. + +Task example: + +``` python +@dataclass +class LanguageModelingConfig(FairseqDataclass): + data: Optional[str] = field( + default=None, metadata={"help": "path to data directory"} + ) + ... + +@register_task("language_modeling", dataclass=LanguageModelingConfig) +class LanguageModelingTask(LegacyFairseqTask): + ... + @classmethod + def setup_task(cls, cfg: LanguageModelingConfig): + ... ``` -## Migrate existing/Creating new modules to hydra interface +Model example: -In each of the modules we want to migrated/create with hydra interface, fundamentally we need to +``` python +@dataclass +class TransformerLanguageModelConfig(FairseqDataclass): + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="relu", metadata={"help": "activation function to use"} + ) + dropout: float = field(default=0.1, metadata={"help": "dropout probability"}) + ... -- Provide a dataclass that layouts the parameters used in the module. +@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig) +class TransformerLanguageModel(FairseqLanguageModel): + ... + @classmethod + def build_model(cls, cfg: TransformerLanguageModelConfig, task: FairseqTask): + ... +``` -- Modify the builder and/or constructor that previously takes `argparse.Namespace` argument `args`, into taking `omegaconf.DictConfig` config objects. At this moment we allow `Union[omegaconf.DictConfig, argparse.Namespace]` to support compatibility. +### Other components -- For `add_args()`, we need to extract argument from the dataclass defined in the same file, and append them into `parser`. This is also to support compatibility. This is simply supported with `gen_parser_from_dataclass` API, see examples files below. +Other components work as before, but they now take their configuration dataclass as the only constructor argument: -#### Migrated examples: +``` python +@dataclass +class MosesTokenizerConfig(FairseqDataclass): + source_lang: str = field(default="en", metadata={"help": "source language"}) + ... -- Task: `fairseq/tasks/language_modeling.py` +@register_tokenizer("moses", dataclass=MosesTokenizerConfig) +class MosesTokenizer(object): + def __init__(self, cfg: MosesTokenizerConfig): + ... +``` -- Model: `fairseq/models/transformer_lm.py` +Note that if you are adding a new registry for a new set of components, you need to add it to the FairseqConfig object in +fairseq/dataclass/configs.py: -- Criterion: `fairseq/criterions/adaptive_loss.py` and `fairseq/criterions/cross_entropy.py` +``` python +@dataclass +class FairseqConfig(object): + ... + my_new_registry: Any = None +``` -- Optimizer: `fairseq/optim/adam.py` and `fairseq/optim/nag.py` +## Training with hydra_train.py -- LR scheduler: `fairseq/optim/lr_scheduler/cosine_lr_scheduler.py` and `fairseq/optim/lr_scheduler/inverse_square_root_schedule.py` +To fully take advantage of configuration flexibility offered by Hydra, you may want to train new models using the +hydra_train.py entry point located in the fairseq_cli directory. Legacy CLI tools such as train.py, +will remain supported for the foreseeable future but will be deprecated eventually. +On startup, Hydra will create a configuration object that contains a hierarchy of all the necessary dataclasses +populated with their default values in the code. The default values are overwritten by values found in YAML files in +fairseq/config directory (which currently just set default task, optimizer, etc) and then further overwritten by values +provided through command line arguments. -## Interpolate parameters across different places +Some of the most common use cases are shown below: -## Support of legacy interface -If you still like to pass legacy style arguments in command line, `fairseq_cli/train.py` can support this. Internally it coverted `args` into hydra config objects whenever there are migrated modules aligned. +### 1. Overwrite default values through command line: + +```shell script +python fairseq_cli/hydra_train.py distributed_training.distributed_world_size=1 dataset.batch_size=2 task.data=data-bin \ +model=transformer_lm/transformer_lm_gpt task=language_modeling optimization.max_update=5000 ``` -python fairseq_cli/train.py --task language_modeling \ -/private/home/abaevski/data/wiki103 \ ---save-dir /checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \ ---arch transformer_lm --share-decoder-input-output-embed \ ---dropout 0.1 \ ---optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \ ---lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \ ---tokens-per-sample 512 --sample-break-mode none \ ---max-tokens 1024 --update-freq 16 \ ---fp16 \ ---max-update 50000 --log-format json --log-interval 1 --num-workers 4 \ ---save-interval-updates 10 + +Note that along with explicitly providing values for parameters such as dataset.batch_size, this also tells Hydra to overlay configuration found in `fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` +over the default values in the dataclass. If you want to train a model without specifying a particular architecture +you can simply specify model=transformer_lm. This only works for migrated tasks and models. + +### 2. Replace bundled configs with an external config: + +```shell script +python fairseq_cli/hydra_train.py --config-path /path/to/external/configs --config-name wiki103 ``` + +where /path/to/external/configs/wiki103.yaml contains: + +``` yaml +# @package _group_ + +model: + _name: transformer_lm +distributed_training: + distributed_world_size: 1 +dataset: + batch_size: 2 +task: + _name: language_modeling + data: /path/to/data + add_bos_token: false + max_target_positions: 1024 +optimization: + max_update: 50000 + lr: [ 0.25 ] +criterion: cross_entropy +optimizer: adam +lr_scheduler: + _name: cosine +``` + +Note that here bundled configs from `fairseq/config` directory are not used, however the defaults from each dataclass will still be used (unless overwritten by your external config). + +Additionally you can choose to break up your configs by creating a directory structure in the same location as your main config file, with the names of the top-level fields +(such as "model", "dataset", etc), and placing config files with meaningful names that would populate that specific section of your +top-level config file (for example, you might have model/small_transformer_lm.yaml, model/big_transformer_lm.yaml, etc). You can then specify the correct configuration via command line, defaults in the main config, or even launch all of them as a sweep (see Hydra documentation on how to do this). + +### 3. Add an external config directory to Hydra search path: + +This allows combining default configuration (including using any bundled config files), while specifying your own config files for some parts of the configuration. + +```shell script +python fairseq_cli/hydra_train.py distributed_training.distributed_world_size=1 dataset.batch_size=2 \ +task.data=/path/to/data/ model=transformer_lm/2_layers task=language_modeling optimization.max_update=5000 \ +--config-dir /path/to/external/configs + +``` + +where /path/to/external/configs has the following structure: +``` +. ++-- model +| +-- transformer_lm +| | +-- 2_layers.yaml +``` + +and 2_layers.yaml contains a copy of transformer_lm_gpt.yaml but with decoder_layers set to 2. You can add +other configs to configure other components as well. diff --git a/fairseq/data/encoders/moses_tokenizer.py b/fairseq/data/encoders/moses_tokenizer.py index fa004dd4a..e236dad16 100644 --- a/fairseq/data/encoders/moses_tokenizer.py +++ b/fairseq/data/encoders/moses_tokenizer.py @@ -24,7 +24,7 @@ class MosesTokenizerConfig(FairseqDataclass): @register_tokenizer("moses", dataclass=MosesTokenizerConfig) class MosesTokenizer(object): - def __init__(self, cfg): + def __init__(self, cfg: MosesTokenizerConfig): self.cfg = cfg try: