Add fairseq-hydra-train and update docs (#1449)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1449

Test Plan: Imported from OSS

Reviewed By: alexeib

Differential Revision: D25094525

Pulled By: myleott

fbshipit-source-id: 430387d11196d3292933bb168cf09ea16ebc0d3b
Myle Ott 2020-11-20 05:59:25 -08:00 committed by Facebook GitHub Bot
parent 40fbb37443
commit 3b77a61600
7 changed files with 198 additions and 108 deletions

View File

@ -1,57 +1,70 @@
## Hydra

[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python framework that simplifies the development of research and other complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line. The name Hydra comes from its ability to run multiple similar jobs - much like a Hydra with multiple heads.
## Motivation

Until recently, all components in fairseq were configured through a shared `args` namespace that was created at application startup. Components declared their own `add_args` method to update the argparse parser, hoping that the names would not clash with arguments from other components. While this model works for smaller applications, as fairseq grew and became integrated into other applications, it became problematic. In order to determine how to configure each component, one needed to a) examine what args were added by this component, and b) read the code to figure out which shared arguments it uses that were added in other places. Reproducing models involved sharing commands that often contained dozens of command line switches.

The model described above is still supported by fairseq for backward compatibility, but will be deprecated some time in the future.
New components in fairseq should now create a dataclass that encapsulates all parameters required to configure this component. The dataclass is registered along with the component, and fairseq takes care of constructing and providing this configuration object to the component's constructor. Note that sharing parameters can optionally still work, but one has to explicitly point to the "source of truth" (see the inheritance example below). These changes make components in fairseq more independent and re-usable by other applications: all that is needed to create a component is to initialize its dataclass and overwrite some of the defaults.
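As a quick illustration of that last point, here is a minimal sketch (the `MyEncoder*` names are hypothetical and not part of fairseq):

```python
from dataclasses import dataclass, field


@dataclass
class MyEncoderConfig:
    embed_dim: int = field(default=512, metadata={"help": "embedding dimension"})
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})


class MyEncoder:
    def __init__(self, cfg: MyEncoderConfig):
        self.embed_dim = cfg.embed_dim
        self.dropout = cfg.dropout


# another application re-uses the component by building its config directly,
# overriding only the defaults it cares about
encoder = MyEncoder(MyEncoderConfig(dropout=0.3))
```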
While configuring fairseq through the command line (using either the legacy argparse-based or the new Hydra-based entry points) is still fully supported, you can now take advantage of configuring fairseq completely or piece-by-piece through hierarchical YAML configuration files. These files can also be shipped as examples that others can use to run an identically configured job.

Additionally, Hydra has a rich and growing [library of plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that provides functionality such as hyperparameter sweeping (including Bayesian optimization through the [Ax](https://github.com/facebook/Ax) library), job launching across various platforms, and more.
## Creating or migrating components

In general, each new (or updated) component should provide a companion [dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclasses are typically located in the same file as the component and are passed as arguments to the `register_*()` functions. Top-level configs that should be present in every fairseq application are placed in the [global](fairseq/dataclass/configs.py) config file and added to the `FairseqConfig` object.
Each dataclass is a plain-old-data object, similar to a `NamedTuple`. These classes are decorated with a `@dataclass` decorator, and typically inherit from `FairseqDataclass` (which adds some functionality for backward compatibility). Each field must have a type, and generally has metadata (such as a help string) and a default value. Only primitive types or other config objects are allowed as data types for each field.

#### Example:

```python
from dataclasses import dataclass, field
from fairseq.dataclass import FairseqDataclass
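
# illustration only, not part of the original example: a hypothetical config
# dataclass showing the pattern described above (typed fields with defaults
# and help metadata)
@dataclass
class ExampleConfig(FairseqDataclass):
    beam: int = field(default=5, metadata={"help": "beam size"})
    quiet: bool = field(default=False, metadata={"help": "suppress progress output"})
```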
@ -71,11 +84,12 @@ class InteractiveConfig(FairseqDataclass):
### Inheriting values

Some components require sharing a value. For example, a learning rate scheduler and an optimizer may both need to know the initial learning rate value. One can declare a field that, by default, will inherit its value from another config node in the same hierarchy:

```python
@dataclass
class FairseqAdamConfig(FairseqDataclass):
    ...
@ -83,18 +97,21 @@ FairseqAdamConfig(FairseqDataclass):
    ...
```
`II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"` , which is the value one can use in a YAML config file or through `II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"`, which is
command line to achieve the same effect. Note that this assumes that there is an "optimization" config object the value one can use in a YAML config file or through command line to achieve
in the root config and it has a field called "lr". the same effect. Note that this assumes that there is an "optimization" config
object in the root config and it has a field called "lr".
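As a self-contained illustration of that interpolation behaviour (a minimal OmegaConf-only sketch, not fairseq's actual config classes):

```python
from dataclasses import dataclass, field

from omegaconf import II, OmegaConf


@dataclass
class OptimizationConfig:
    lr: float = 0.25


@dataclass
class AdamConfig:
    # by default, inherit the value of optimization.lr from the root config
    lr: float = II("optimization.lr")


@dataclass
class RootConfig:
    optimization: OptimizationConfig = field(default_factory=OptimizationConfig)
    optimizer: AdamConfig = field(default_factory=AdamConfig)


cfg = OmegaConf.structured(RootConfig)
print(cfg.optimizer.lr)  # 0.25, resolved from optimization.lr

cfg.optimization.lr = 0.1
print(cfg.optimizer.lr)  # 0.1, the interpolation tracks the "source of truth"
```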
### Tasks and Models

Creating Tasks and Models works the same as before, except that legacy implementations now inherit from `LegacyFairseq*` base classes, while new components inherit from `FairseqTask` and `FairseqModel` and provide a dataclass to the `register_*()` functions.

#### Task example:

```python
@dataclass
class LanguageModelingConfig(FairseqDataclass):
    data: Optional[str] = field(
@ -110,9 +127,9 @@ class LanguageModelingTask(LegacyFairseqTask):
    ...
```
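The piece that ties the dataclass to the task is the registration call. A minimal sketch, assuming the `dataclass=` keyword described above and the usual import locations (constructor and method bodies omitted):

```python
from fairseq.tasks import LegacyFairseqTask, register_task


@register_task("language_modeling", dataclass=LanguageModelingConfig)
class LanguageModelingTask(LegacyFairseqTask):
    ...
```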
#### Model example:

```python
@dataclass
class TransformerLanguageModelConfig(FairseqDataclass):
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
@ -131,9 +148,10 @@ class TransformerLanguageModel(FairseqLanguageModel):
### Other components

Other components work as before, but they now take their configuration dataclass as the only constructor argument:

```python
@dataclass
class MosesTokenizerConfig(FairseqDataclass):
    source_lang: str = field(default="en", metadata={"help": "source language"})
@ -145,50 +163,61 @@ class MosesTokenizer(object):
    ...
```
Note that if you are adding a new registry for a new set of components, you need to add it to the `FairseqConfig` object in `fairseq/dataclass/configs.py`:

```python
@dataclass
class FairseqConfig(object):
    ...
    my_new_registry: Any = None
```
## Training with `fairseq-hydra-train`

To fully take advantage of the configuration flexibility offered by Hydra, you may want to train new models using the `fairseq-hydra-train` entry point. Legacy CLI tools such as `fairseq-train` will remain supported for the foreseeable future but will be deprecated eventually.

On startup, Hydra will create a configuration object that contains a hierarchy of all the necessary dataclasses populated with their default values in the code. The default values are overwritten by values found in YAML files in the `fairseq/config` directory (which currently sets minimal defaults) and then further overwritten by values provided through command line arguments.
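That precedence (dataclass defaults, then bundled or external YAML, then command line overrides) can be pictured with plain OmegaConf merges; this is only a sketch of the idea, not what Hydra literally executes:

```python
from omegaconf import OmegaConf

dataclass_defaults = OmegaConf.create(
    {"dataset": {"batch_size": 1}, "optimization": {"max_update": 0}}
)
yaml_overrides = OmegaConf.create({"optimization": {"max_update": 50000}})
cli_overrides = OmegaConf.from_dotlist(["dataset.batch_size=2"])

cfg = OmegaConf.merge(dataclass_defaults, yaml_overrides, cli_overrides)
print(cfg.dataset.batch_size, cfg.optimization.max_update)  # 2 50000
```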
Some of the most common use cases are shown below:

### 1. Override default values through command line:

```shell script
$ fairseq-hydra-train \
    distributed_training.distributed_world_size=1 \
    dataset.batch_size=2 \
    task.data=data-bin \
    model=transformer_lm/transformer_lm_gpt \
    task=language_modeling \
    optimization.max_update=5000
```
Note that along with explicitly providing values for parameters such as `dataset.batch_size`, this also tells Hydra to overlay the configuration found in `fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` over the default values in the dataclass. If you want to train a model without specifying a particular architecture, you can simply specify `model=transformer_lm`. This only works for migrated tasks and models.
### 2. Replace bundled configs with an external config:

```shell script
$ fairseq-hydra-train \
    --config-path /path/to/external/configs \
    --config-name wiki103
```

where `/path/to/external/configs/wiki103.yaml` contains:

```yaml
# @package _group_
model:
@ -211,24 +240,38 @@ lr_scheduler:
  _name: cosine
```
Note that here the bundled configs from the `fairseq/config` directory are not used; however, the defaults from each dataclass will still be used (unless overwritten by your external config).

Additionally, you can choose to break up your configs by creating a directory structure in the same location as your main config file, with the names of the top-level fields (such as "model", "dataset", etc), and placing config files with meaningful names that would populate that specific section of your top-level config file (for example, you might have `model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc). You can then specify the correct configuration via the command line, via defaults in the main config, or even launch all of them as a sweep (see the Hydra documentation on how to do this).
### 3. Add an external config directory to Hydra search path:

This allows combining the default configuration (including using any bundled config files) with your own config files for some parts of the configuration.

```shell script
$ fairseq-hydra-train \
    distributed_training.distributed_world_size=1 \
    dataset.batch_size=2 \
    task.data=/path/to/data/ \
    model=transformer_lm/2_layers \
    task=language_modeling \
    optimization.max_update=5000 \
    --config-dir /path/to/external/configs
```

where `/path/to/external/configs` has the following structure:

```
.
+-- model
@ -236,5 +279,6 @@ where /path/to/external/configs has the following structure:
|  |  +-- 2_layers.yaml
```
and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with `decoder_layers` set to 2. You can add other configs to configure other components as well.

View File

@ -56,8 +56,10 @@ This configuration was used for the base model trained on the Librispeech datase
Note that the input is expected to be single channel, sampled at 16 kHz

```shell script
$ fairseq-hydra-train \
    task.data=/path/to/data \
    --config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining \
    --config-name wav2vec2_base_librispeech
```

Note: you can simulate 64 GPUs by using k GPUs and adding command line parameters (before --config-path)
@ -68,8 +70,10 @@ Note: you can simulate 64 GPUs by using k GPUs and adding command line parameter
This configuration was used for the large model trained on the Libri-light dataset in the wav2vec 2.0 paper

```shell script
$ fairseq-hydra-train \
    task.data=/path/to/data \
    --config-path /path/to/fairseq-py/examples/wav2vec/config/pretraining \
    --config-name wav2vec2_large_librivox
```

Note: you can simulate 128 GPUs by using k GPUs and adding command line parameters (before --config-path)
@ -88,9 +92,12 @@ $ python libri_labels.py /path/to/tsv --output-dir /output/dir --output-name $sp
Fine-tuning on 100h of Librispeech with letter targets:

```shell script
$ fairseq-hydra-train \
    distributed_training.distributed_port=$PORT \
    task.data=/path/to/data \
    model.w2v_path=/path/to/model.pt \
    --config-path /path/to/fairseq-py/examples/wav2vec/config/finetuning \
    --config-name base_100h
```

There are other config files in the config/finetuning directory that can be used to fine-tune on other splits.

View File

@ -1,10 +1,10 @@
 # @package _group_
 defaults:
-  - task: language_modeling
+  - task: null
   - model: null
   - criterion: cross_entropy
-  - optimizer: adam
-  - lr_scheduler: cosine
+  - optimizer: null
+  - lr_scheduler: fixed
   - bpe: null
   - tokenizer: null
   - scoring: null

View File

@ -173,6 +173,12 @@ class CommonConfig(FairseqDataclass):
     profile: bool = field(
         default=False, metadata={"help": "enable autograd profiler emit_nvtx"}
     )
+    reset_logging: bool = field(
+        default=True,
+        metadata={
+            "help": "when using Hydra, reset the logging at the beginning of training"
+        },
+    )


 @dataclass

View File

@ -26,12 +26,14 @@ try:
     import xentropy_cuda
     from apex.contrib import xentropy

-    logger.info("using fused cross entropy")
-
     def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
         if logits.device == torch.device("cpu"):
             return _cross_entropy_pytorch(logits, target, ignore_index, reduction)
         else:
+            if not getattr(cross_entropy, "_has_logged_once", False):
+                logger.info("using fused cross entropy")
+                cross_entropy._has_logged_once = True
+
             half_to_float = logits.dtype == torch.half
             losses = xentropy.SoftmaxCrossEntropyLoss.apply(
                 logits,

View File

@ -4,29 +4,32 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-import hydra
-from omegaconf import OmegaConf
+import logging
 import os
+import sys

 from fairseq.dataclass.initialize import hydra_init
 from fairseq_cli.train import main as pre_main
 from fairseq import distributed_utils
 from fairseq.dataclass.configs import FairseqConfig

-import logging
+import hydra
 import torch
+from omegaconf import OmegaConf

-logger = logging.getLogger(__name__)
+logger = logging.getLogger("fairseq_cli.hydra_train")


 @hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
 def hydra_main(cfg: FairseqConfig) -> None:
     cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True))
     OmegaConf.set_struct(cfg, True)

+    if cfg.common.reset_logging:
+        reset_logging()  # Hydra hijacks logging, fix that
+
     if cfg.common.profile:
         with torch.cuda.profiler.profile():
             with torch.autograd.profiler.emit_nvtx():
@ -35,7 +38,22 @@ def hydra_main(cfg: FairseqConfig) -> None:
         distributed_utils.call_main(cfg, pre_main)


-if __name__ == "__main__":
+def reset_logging():
+    root = logging.getLogger()
+    for handler in root.handlers:
+        root.removeHandler(handler)
+    root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(
+        logging.Formatter(
+            fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
+        )
+    )
+    root.addHandler(handler)
+
+
+def cli_main():
     try:
         from hydra._internal.utils import get_args
@ -46,3 +64,7 @@ if __name__ == "__main__":
hydra_init(cfg_name) hydra_init(cfg_name)
hydra_main() hydra_main()
if __name__ == "__main__":
cli_main()

View File

@ -22,14 +22,18 @@ def write_version_py():
     # append latest commit hash to version string
     try:
-        sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
+        sha = (
+            subprocess.check_output(["git", "rev-parse", "HEAD"])
+            .decode("ascii")
+            .strip()
+        )
         version += "+" + sha[:7]
     except Exception:
         pass

     # write version info to fairseq/version.py
     with open(os.path.join("fairseq", "version.py"), "w") as f:
-        f.write("__version__ = \"{}\"\n".format(version))
+        f.write('__version__ = "{}"\n'.format(version))

     return version
@ -194,7 +198,8 @@ def do_setup(package_data):
"tests", "tests",
"tests.*", "tests.*",
] ]
) + extra_packages, )
+ extra_packages,
package_data=package_data, package_data=package_data,
ext_modules=extensions, ext_modules=extensions,
test_suite="tests", test_suite="tests",
@ -202,6 +207,7 @@ def do_setup(package_data):
"console_scripts": [ "console_scripts": [
"fairseq-eval-lm = fairseq_cli.eval_lm:cli_main", "fairseq-eval-lm = fairseq_cli.eval_lm:cli_main",
"fairseq-generate = fairseq_cli.generate:cli_main", "fairseq-generate = fairseq_cli.generate:cli_main",
"fairseq-hydra-train = fairseq_cli.hydra_train:cli_main",
"fairseq-interactive = fairseq_cli.interactive:cli_main", "fairseq-interactive = fairseq_cli.interactive:cli_main",
"fairseq-preprocess = fairseq_cli.preprocess:cli_main", "fairseq-preprocess = fairseq_cli.preprocess:cli_main",
"fairseq-score = fairseq_cli.score:cli_main", "fairseq-score = fairseq_cli.score:cli_main",
@ -230,8 +236,11 @@ try:
     fairseq_examples = os.path.join("fairseq", "examples")
     if "build_ext" not in sys.argv[1:] and not os.path.exists(fairseq_examples):
         os.symlink(os.path.join("..", "examples"), fairseq_examples)

     package_data = {
-        "fairseq": get_files("fairseq/examples"),
+        "fairseq": (
+            get_files(fairseq_examples) + get_files(os.path.join("fairseq", "config"))
+        )
     }
     do_setup(package_data)
 finally: