Commit Graph

1943 Commits

Author SHA1 Message Date
Omry Yadan
dd106d9534 fixes tests/test_train.py to mock checkpoint.save_dir config node (#3675)
Summary:
## What does this PR do?
Some downstream users reported errors when passing a Namespace to load_checkpoint().

A recent change assumed that the passed object is dict-like (a dict or DictConfig) with a get function.
This fixes that assumption and makes sure the mocked config has a checkpoint.save_dir node so the test can run.
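A defensive accessor along these lines handles both shapes (a sketch; `get_cfg_value` is an illustrative name, not fairseq's API):

```python
from argparse import Namespace

def get_cfg_value(cfg, key, default=None):
    # dict and DictConfig expose .get(); a plain Namespace does not,
    # so fall back to getattr for attribute-style configs.
    if hasattr(cfg, "get"):
        return cfg.get(key, default)
    return getattr(cfg, key, default)
```

The same call then works whether the caller passes a dict, a DictConfig, or a legacy Namespace.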

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3675

Reviewed By: omry

Differential Revision: D29564805

Pulled By: lematt1991

fbshipit-source-id: 89308811da382667f6c5d3152ee2d6480416ee62
2021-07-06 15:07:31 -07:00
Wei-Ning Hsu
cdc1a553eb query tgt_dict after loading task_state (#2019)
Summary:
# Before submitting
`self.task.target_dictionary` is queried before `task_state` is loaded (in `self.load_model_ensemble()`).

## What does this PR do?
Fix the bug above

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2019

Reviewed By: alexeib

Differential Revision: D29523921

Pulled By: wnhsu

fbshipit-source-id: 763b504dc1b4899e623eaa5c19972cec9d0a8985
2021-07-01 13:12:28 -07:00
alexeib
096f492a22 fix xlsr checkpoint finetuning saving issues (#2013)
Summary:
fixes an issue with some old checkpoints that had deeply nested namespaces containing a choices enum - most prominently the xlsr 53 checkpoint

fixes #3634

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2013

Reviewed By: xuqiantong

Differential Revision: D29511325

Pulled By: alexeib

fbshipit-source-id: 79df978afa7482b4ce3aaf7396e193626181aa17
2021-07-01 08:36:57 -07:00
Omry Yadan
9bee82e4a7 Hydra 1.1 compatibility: Use an explicit schema for the primary config (#3659)
Summary:
## What does this PR do?
Fixes compatibility with Hydra 1.1.
The result is compatible with both Hydra 1.0 and Hydra 1.1, and will allow a smoother migration to Hydra 1.1.

At this point I am not yet removing the restriction on the Hydra version from setup.py:
1. It depends on some Hydra 1.1 changes that are not yet released (It will be compatible with 1.1.1).
2. Upgrading will result in deprecation warnings, and fixing them will break compatibility with Hydra 1.0.

There will be some followup to make the code fully compatible with 1.1 once Hydra 1.1 is the default version in fbcode.

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3659

Reviewed By: omry

Differential Revision: D29498036

Pulled By: lematt1991

fbshipit-source-id: 96999cde5daad6749ef4d3ddf6a36a1e984ff201
2021-07-01 06:37:47 -07:00
Edan Tessel Sneh
0794f9ae21 Back out "Adding FBSequenceGenerator"
Summary:
Original commit changeset: b7a83bbc719d

Reverts commit D26228721 (6381aa2bb2)

Reviewed By: theweiho

Differential Revision: D29369494

fbshipit-source-id: 9e745b11bc532ca8ced2816326aa94afbb46ba2d
2021-06-29 15:10:09 -07:00
Liang Luo
0972dde844 apply nonblocking H/D transfer optimizations
Summary:
merge D27701492 + D27701493
* make checkpoint activation cpu offloading nonblocking
* make gradient cpu offloading nonblocking
* synchronize cpu/gpu stream before applying optimizer update

Reviewed By: myleott

Differential Revision: D28047171

fbshipit-source-id: f862eca64049acc045026aa4f5e6dbe8d0f03244
2021-06-29 00:03:00 -07:00
Pierre Andrews
53bf2b1293 Extract File Chunking to its own utils (#1955)
Summary:
## What does this PR do?

There are a few places where we do file chunking to multiprocess a single file. However, the code is partly in Binarizer and partly duplicated here and there.

This PR extracts the file chunking/reading logic. The multiprocessing logic could probably be extracted too, but I haven't found a good abstraction yet.

# Testing

Added tests for this reading logic and fixed a possible bug where the last part of a file could get dropped (although this is unclear with the current stopping logic).

Tested by running the preprocessing script as follows:
```
python -m fairseq_cli.preprocess --source-lang de --target-lang en --trainpref ...train.spm.clean.de_en --srcdict ...fairseq.dict --tgtdict .../fairseq.dict --destdir ... --workers 60
```
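The boundary arithmetic behind such chunking can be sketched as follows (illustrative code, not the extracted utility itself); the key detail is that the final boundary is pinned to the end of the file so no trailing bytes are dropped:

```python
def chunk_ranges(file_size, num_workers):
    # Split [0, file_size) into num_workers contiguous byte ranges.
    # Integer division can leave a remainder; assigning file_size as the
    # final boundary ensures the tail of the file is never dropped.
    chunk = file_size // num_workers
    bounds = [i * chunk for i in range(num_workers)] + [file_size]
    return list(zip(bounds[:-1], bounds[1:]))
```

In the real reader each start offset would additionally be advanced to the next line boundary before binarization begins.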

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1955

Reviewed By: myleott

Differential Revision: D29065473

Pulled By: Mortimerp9

fbshipit-source-id: c60843de8cfd45a63b3dbb8290f57ef3df3bf983
2021-06-28 01:46:32 -07:00
Kushal Lakhotia
f8871521f7 Load dict from pretrained hubert model in HubertEncoder (#1999)
Summary:
## What does this PR do?
Load dict from pretrained hubert model in HubertEncoder so that the dictionary is not constructed for the labels.

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1999

Test Plan:
Tested with the cmdline below. ASR training progresses as expected without any exception.

```
PYTHONPATH=. HYDRA_FULL_ERROR=1 python fairseq_cli/hydra_train.py -m \
--config-dir examples/hubert/config/finetune \
--config-name base_10h \
dataset.num_workers=0 \
task.data=/checkpoint/kushall/data/librispeech/10h/raw \
task.label_dir=/checkpoint/kushall/data/librispeech/10h/raw \
model.w2v_path=/checkpoint/kushall/final_model_checkpoints/hubert/hubert_base_ls960_updated.pt \
hydra.sweep.dir=/checkpoint/kushall/experiments/hubert_test/base_asr_10h
```

Reviewed By: Abdel-rahmanMohamed

Differential Revision: D29405491

Pulled By: hikushalhere

fbshipit-source-id: be168a0ce27f8fcfea3dc980a192ba43fdf23871
2021-06-26 08:59:56 -07:00
Shiyan Deng
81046fc13e Add decoder and decoding wrapper for nmt
Summary:
Add a decoder class `FairSeqNVFasterTransformerDecoder` that can replace `TransformerDecoder` in nmt.
Add a decoding class `FairSeqNVFasterTransformerDecoding` that does `decoding + beam search`.

We can't use `FairSeqNVFasterTransformerDecoding` in nmt right now because nmt ensembles decoders and calculates average probabilities across those decoders.

Follow-ups:
1. Currently `FairSeqNVFasterTransformerDecoder` doesn't produce "attn" https://fburl.com/code/pom5vhr5.
2. Move mem_cache and cache to incremental_state.
3. Benchmark the fairseq FT encoder/decoder.
4. E2E tests get stuck somewhere.

Differential Revision: D29166310

fbshipit-source-id: 36360cfff1d22ed4f12f89068ee30dec835d2141
2021-06-24 14:04:19 -07:00
Alex Liu
520d9d3ba6 remove debug code from w2vu gen (#1997)
Summary:
see title

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1997

Reviewed By: wnhsu, Alexander-H-Liu

Differential Revision: D29371459

Pulled By: alexeib

fbshipit-source-id: 874e36462f919aa4ba698a0dd49531c89f7e27cf
2021-06-24 14:01:15 -07:00
Ashwyn Sharma
7818f6148d Tuna integration and model packaging
Reviewed By: sravyapopuri388

Differential Revision: D29118016

fbshipit-source-id: d183c821e5d8eb1b37dda48ded9e24e5efc65dc7
2021-06-23 11:14:45 -07:00
Eduardo Romero
7ca8bc12c0 KMeans Attention
Summary: KMeans attention main file

Reviewed By: yiq-liu

Differential Revision: D28478149

fbshipit-source-id: 97ef1408cfa239bdf13ee5d54d5d31b61a7f2236
2021-06-22 09:14:59 -07:00
Nithin-Holla
3c4a8e4155 Enabling word-level timestamps for Wav2Vec 2.0 (#3627)
Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes https://github.com/pytorch/fairseq/issues/3371.

Currently, the output from Wav2Vec 2.0 decoding does not contain word-level start/end times, which can be useful for certain applications of ASR. Based on the discussion [here](https://github.com/flashlight/flashlight/issues/618), they could be computed based on the output from the Flashlight decoder. For the KenLM decoder, we could first obtain the frame number corresponding to each non-blank token. Next, the timestamp of each character could be computed as `segment_start + frame_no/total_frames * segment_duration`. Finally, the start and end time of each word could be calculated based on the timestamp of the word boundary characters. In order to enable this, the frame number of each non-blank character is returned as a result of KenLM decoding. This is similar to the `timesteps` output from the [ctcdecode](https://github.com/parlance/ctcdecode#outputs-from-the-decode-method) library.
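The timestamp arithmetic described above can be sketched in isolation (hypothetical helper names; `'|'` is assumed as the word-boundary character, as in wav2vec letter dictionaries):

```python
def char_time(segment_start, segment_duration, frame_no, total_frames):
    # Timestamp of one emitted character, interpolated from its frame index.
    return segment_start + frame_no / total_frames * segment_duration

def word_times(chars, frames, segment_start, segment_duration, total_frames):
    # Group characters into words at '|' boundaries and return
    # (word, start_time, end_time) triples.
    words, buf, start, t = [], [], None, segment_start
    for ch, fr in zip(chars, frames):
        t = char_time(segment_start, segment_duration, fr, total_frames)
        if ch == "|":
            if buf:
                words.append(("".join(buf), start, t))
            buf, start = [], None
        else:
            if start is None:
                start = t
            buf.append(ch)
    if buf:
        words.append(("".join(buf), start, t))
    return words
```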

## PR review
alexeib

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3627

Reviewed By: michaelauli

Differential Revision: D29282488

Pulled By: alexeib

fbshipit-source-id: b5fe64bf50abd7ef8e9539f4e338937c866eb0ca
2021-06-21 20:16:59 -07:00
Wei-Ning Hsu
900a607ea3 add timit w2vu recipe (#1991)
Summary:
## What does this PR do?
Add TIMIT data preparation scripts for wav2vec-U

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1991

Reviewed By: alexeib

Differential Revision: D29284481

Pulled By: wnhsu

fbshipit-source-id: dccd75159a9de4f3cd95f9e4a90ce4bdf9264f2b
2021-06-21 19:41:13 -07:00
Neeyanth Kopparapu
e47a4c84da hotfix to change factory creation for dictionaries (#1987)
Summary:
## What does this PR do?
Fixes an issue where factory creation caused errors because the lambda function was not properly defined.

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1987

Test Plan:
Completed a small pretraining+finetuning procedure:

Pretraining:
```
PYTHONPATH=. python /private/home/neeyanth/project/fairseq/fairseq_cli/hydra_train.py -m \
	--config-dir ${fairseq_dir}/examples/hubert/config/pretrain \
	--config-name hubert_base_librispeech \
	hydra/launcher=submitit_local \
	hydra.launcher.gpus_per_node=1 \
	hydra.launcher.cpus_per_task=8 \
	hydra.launcher.mem_gb=384 \
	task.data=${tsv_dir} \
	task.label_dir=${km_dir} \
	task.labels=["km"] \
	+data=iter1 \
	optimization.max_update=250 \
	hydra.sweep.dir=${exp_dir} \
	hydra.run.dir=${exp_dir} > ${exp_dir}/log.out 2> ${exp_dir}/log.err &
```

Finetuning:
```
PYTHONPATH=. python /private/home/neeyanth/project/fairseq/fairseq_cli/hydra_train.py -m \
	--config-dir ${fairseq_dir}/examples/hubert/config/finetune \
	--config-name base_10h \
	hydra/launcher=submitit_local \
	hydra.launcher.gpus_per_node=1 \
	hydra.launcher.cpus_per_task=8 \
	hydra.launcher.mem_gb=384 \
	task.data=${tsv_dir} \
	task.label_dir=${tsv_dir} \
	model.w2v_path=${model_dir} \
	+data=iter1 \
	optimization.max_update=250 \
	hydra.sweep.dir=${exp_dir} \
	hydra.run.dir=${exp_dir} > ${exp_dir}/log.out 2> ${exp_dir}/log.err &
```

Reviewed By: hikushalhere

Differential Revision: D29266136

Pulled By: neeyanthkvk

fbshipit-source-id: d36c668ae38a7761b4c44f4dcb0c4cc8e15e42ce
2021-06-21 12:33:06 -07:00
alexeib
822442e42a fix task name in w2v-u generate (#1989)
Summary:
see title

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1989

Reviewed By: arbabu123

Differential Revision: D29267899

Pulled By: alexeib

fbshipit-source-id: b89b804c14dbf8779b5cb56657d33bb03530f303
2021-06-21 11:06:18 -07:00
Sravya Popuri
fc77eeb550 Change char_inputs to export as recommended in Fairseq
Summary: TSIA

Reviewed By: jmp84, henryhu6

Differential Revision: D29232406

fbshipit-source-id: 557006705faf28d723dc9f0ed9e92b0abe68e895
2021-06-18 13:48:10 -07:00
Sravya Popuri
b3491ae9d4 Add latency metrics to simulate tuna inference script and some other minor updates
Summary:
- Add average lagging latency metrics for online model. Offline models by default return 0
- Pad smaller input chunks with 0.
- Enable export option in layer norm in transformer.py to avoid errors in scripted model inference.
- Warm up prediction for scripted online model
- Add additional args like force_read_cnt, data_split

Reviewed By: xutaima

Differential Revision: D28881594

fbshipit-source-id: fd4cce017539b5d8f6e39f9af9651341e47d6db0
2021-06-17 14:00:39 -07:00
Vimal Manohar
67138ceb08 Fix lr for reduce_lr_on_plateau when there is no warmup
Summary:
warmup_init_lr should not be used if there is no warmup i.e. warmup_updates = 0

Created from Diffusion's 'Open in Editor' feature.
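A minimal sketch of the intended behaviour (illustrative function, not the scheduler's actual code):

```python
def starting_lr(lr, warmup_init_lr, warmup_updates):
    # warmup_init_lr is only meaningful while a warmup phase exists;
    # with warmup_updates == 0 we must start directly at the target lr.
    return warmup_init_lr if warmup_updates > 0 else lr
```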

Reviewed By: myleott

Differential Revision: D29174059

fbshipit-source-id: c2e4cf998aebcff090584e689f692a0abe082e65
2021-06-17 13:20:10 -07:00
Neeyanth Kopparapu
afc77bdf4b Enabled storing of dictionaries (#3601)
Summary:
## What does this PR do?
For HubertPretrainingTask, added dictionaries to the task state to enable the serialization of the dictionaries (thus removing the need to load from the disk after training)

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3601

Test Plan:
To verify the success, run the Hubert Pretraining pipeline, load a checkpoint model, and verify that the "dictionaries" key is present in the state within the model.

Specifically,
```
PYTHONPATH=. python /path/to/fairseq/fairseq_cli/hydra_train.py -m \
        --config-dir ${fairseq_dir}/examples/hubert/config/pretrain \
        --config-name hubert_base_librispeech \
        hydra/launcher=submitit_local \
        hydra.launcher.gpus_per_node=2 \
        hydra.launcher.cpus_per_task=8 \
        hydra.launcher.mem_gb=384 \
        task.data=${tsv_dir} \
        task.label_dir=${km_dir} \
        task.labels=["km"] \
        +data=iter1 \
        optimization.max_update=250 \
        hydra.sweep.dir=${exp_dir} \
        hydra.run.dir=${exp_dir} > ${exp_dir}/log.out 2> ${exp_dir}/log.err &
```
Then, at the location of the model, load it using `torch.load` and verify that "dictionaries" is a key under the `task_state` key of the checkpoint.

Reviewed By: wnhsu

Differential Revision: D28995537

Pulled By: neeyanthkvk

fbshipit-source-id: e10c5163c367285518961b3ce1e719a29da06aa6
2021-06-15 22:02:56 -07:00
Kushal Lakhotia
128b4fc378 Check attributes in trainer and checkpoint loading before using them (#1970)
Summary:
## What does this PR do?
Fixes a None exception when some attributes don't exist in cfg.

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1970

Reviewed By: alexeib

Differential Revision: D29140036

Pulled By: hikushalhere

fbshipit-source-id: 7d941bcae6bb000c281a43ca2cd0876a49912ab9
2021-06-15 14:09:52 -07:00
Kushal Lakhotia
8320f6708f Instructions for loading HuBERT model (#1966)
Summary:
## What does this PR do?
Fixes the HuBERT README to contain instructions to load pretrained checkpoints.

## PR review
Tested in a fresh environment that doesn't have access to FAIR's dev env.

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1966

Reviewed By: wnhsu

Differential Revision: D29117906

Pulled By: hikushalhere

fbshipit-source-id: 89b0407ecf8cdbeddcab80f55e6b2f1fed24c967
2021-06-15 10:50:22 -07:00
msbaines
cd5775f301 avoid freezing batches unnecessarily (#3610)
Summary:
In EpochBatchIterator, first_batch() freezes batches in order to
generate the dummy_batch. We then freeze batches again in the call
to next_epoch_itr(). We can avoid the second freeze and reduce
time to first iteration by about 50% in cases where we have a
callable batch_sampler.
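The idea can be sketched as caching the frozen batches the first time they are materialized (illustrative class, not the real EpochBatchIterator):

```python
class LazyBatchIterator:
    def __init__(self, batch_sampler):
        self._batch_sampler = batch_sampler  # callable returning batches
        self._frozen = None
        self.freeze_count = 0  # instrumentation for the sketch

    @property
    def frozen_batches(self):
        # Materialize ("freeze") the batches once and reuse the result,
        # so first_batch() and next_epoch_itr() share a single freeze.
        if self._frozen is None:
            self.freeze_count += 1
            self._frozen = tuple(self._batch_sampler())
        return self._frozen
```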

Before:

![Screen Shot 2021-06-10 at 5 08 22 PM](https://user-images.githubusercontent.com/35972327/121613200-d2366600-ca10-11eb-9d1d-bafc2403766a.png)

After:

![Screen Shot 2021-06-10 at 5 07 54 PM](https://user-images.githubusercontent.com/35972327/121613224-dfebeb80-ca10-11eb-9d5a-07be9440db77.png)

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3610

Reviewed By: myleott

Differential Revision: D29105845

Pulled By: msbaines

fbshipit-source-id: 9795d46d70a99ad1218ce225092cc22ee3192bbc
2021-06-14 16:38:31 -07:00
Henry Hu
176b2e4e76 Fix warning for empty tensor without type
Summary:
Fairseq creates an empty tensor without a dtype.
This triggers a warning for TorchScript models:
Warning: Creating a tensor from an empty intlist will create a tensor of default floating point type (currently Float) in python but a tensor of type int in torchscript.

This diff adds an explicit definition of the type.

Reviewed By: myleott

Differential Revision: D29081170

fbshipit-source-id: 5c32aae65c9998b245eac43bfedc820bea509338
2021-06-14 09:10:15 -07:00
Valentin Andrei
c36294ea4f Do FP16/BF16 conversions on the host to transfer less through PCIe
Summary:
If we do the FP16/BF16 conversion on the host, we do it at DRAM speed but transfer a 2X smaller buffer to the GPU through PCIe. PCIe bandwidth is an order of magnitude lower, so we actually gain about 50% of execution time compared to performing the quantization on the GPU.

Also, by transferring an already-FP16 buffer, we save memory capacity.
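The size effect is easy to see with the standard library's half-precision packing (a sketch of the principle only; the actual diff converts tensors before the host-to-device copy):

```python
import struct

def fp32_buffer(values):
    # Pack floats as 32-bit little-endian: 4 bytes per value.
    return struct.pack(f"<{len(values)}f", *values)

def fp16_buffer(values):
    # Convert on the host: the buffer that crosses PCIe is half the size.
    return struct.pack(f"<{len(values)}e", *values)
```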

Reviewed By: zhengwy888

Differential Revision: D24146486

fbshipit-source-id: b897e7a32835aa1b571b0fae5f3d72a131ad16a1
2021-06-11 18:00:39 -07:00
alexeib
f8a7c93440 W2v u update (#1954)
Summary:
updating the scripts and examples to be easier to follow

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1954

Reviewed By: wnhsu

Differential Revision: D29041166

Pulled By: alexeib

fbshipit-source-id: d9410c6e925337b810e92b393e226869ef9e1733
2021-06-10 21:58:41 -07:00
Diana Liskovich
50158da3a7 Migrate DummyMaskedLMTask to FairseqTask (#3593)
Summary:

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3593

Reviewed By: msbaines

Differential Revision: D28992614

Pulled By: dianaml0

fbshipit-source-id: b2dfcab472a65c41536e78600a0e6b3745dc3a08
2021-06-10 09:43:08 -07:00
Naman Goyal
2fd9d8a972 released xlmr xl and xxl model weights (#1944)
Summary:

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1944

Reviewed By: jingfeidu

Differential Revision: D28944206

fbshipit-source-id: 583837f7dd387341574d27dd9acc145455d640a8
2021-06-07 15:05:53 -07:00
Yun Wang
fc391ff697 Fix loading some TALNet models
Summary:
D28728718 cleaned up the "kd_binary_cross_entropy" criterion, but this caused loading old models trained with this criterion to fail.
This diff replaces the "kd_binary_cross_entropy" criterion with the "wav2vec" criterion when loading models, and fixes this error.

It also removes the "log_keys" argument if it's `None`. Some criteria (e.g. wav2vec) require this argument to be a list, and will supply a default value of `[]` when it's absent. The presence of the `None` value prevents the use of this default value and causes an error.
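A sketch of that normalization (hypothetical helper; the real change happens while rewriting the loaded model args):

```python
def drop_none_log_keys(criterion_args):
    # A None value blocks the criterion's own default of []; removing
    # the key entirely lets that default take effect.
    if criterion_args.get("log_keys") is None:
        criterion_args.pop("log_keys", None)
    return criterion_args
```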

Differential Revision: D28901263

fbshipit-source-id: 9b33aed35e76d2c734d1d4e2cbca1ff193a8c920
2021-06-04 16:19:56 -07:00
Mandeep Singh Baines
45d8fefaa6 fix logging when running single-process (#3592)
Summary:
In file_io.py, there is a logging message that happens in the global
scope. This logging message can be invoked before the call to
logging.basicConfig() in fairseq_cli/train.py, making that call a
no-op and leaving the log level at WARNING.

The fix is to call logging.basicConfig() before importing any fairseq
libraries that may log in global scope.

Verified that logging.info messages are now visible after applying
this PR.
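The ordering constraint, in miniature (illustrative; `force=True` is used here only to make the sketch deterministic):

```python
import logging

# Configure the root logger FIRST; a library that logs at import time
# would otherwise attach a default handler and make this call a no-op,
# leaving the effective level at WARNING.
logging.basicConfig(level=logging.INFO, force=True)

# import fairseq  # any logging-at-import libraries come after basicConfig
logging.getLogger(__name__).info("visible at INFO level")
```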

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3592

Reviewed By: sujitoc

Differential Revision: D28900871

Pulled By: msbaines

fbshipit-source-id: ff5393aa7c5e4cbec168ff0b846da048de76cdbc
2021-06-04 11:25:12 -07:00
Yun Wang
3084b812be Teacher-student learning for TALNet
Summary:
This diff implements teacher-student learning for TALNet.

Three classes take part in the teacher-student learning:
* The task loads the teacher models;
* The model generates predictions using the teacher models, and mixes them with the original targets;
* The `Wav2VecCriterion` reads the mixed targets to compute the loss. However, it still uses the original targets to compute the MAP and MAUC metrics.

There are two types of teachers:
* Static teachers: a file that stores predictions on training data which have been produced by running a model offline;
* Dynamic teachers: model files that are loaded at the beginning of training and executed on the fly to produce predictions.

We actually no longer use static teachers. The code for static teachers was copied over from the `KnowledgeDistillationBinaryCrossEntropyCriterion` class. That class will be cleaned up in D28728718.

The teacher models are stored in the task object, and will not be saved into checkpoints.

Reviewed By: alexeib

Differential Revision: D28728707

fbshipit-source-id: 0fcfc00db2e7194a6f7ee687cad9fa72e82a028b
2021-06-03 17:49:12 -07:00
Yun Wang
50f3766a9d TALNet: Use batch size as sample_size
Summary:
`Wav2VecCriterion` uses a sample_size for two purposes:
1. It weights the extra loss by multiplying it by sample_size;
2. It divides the total loss by sample_size before reporting them in the learning curves.

By default, when using the binary cross-entropy loss (`infonce = False`), `Wav2VecCriterion` uses the number of 1's in the label matrix as sample_size.
For TALNet, because each recording may have multiple labels, this sample_size is not a constant across batches.
TALNet also uses a consistency loss between the predictions on two different copies of augmented data as an extra loss, and it is undesirable for the weight of the extra loss to vary from batch to batch.

This diff adds a field "sample_size" to the batch in the `AcousticEventCollater`, and makes it equal to the batch size (the number of recordings in a batch).
Because the extra loss is multiplied by sample_size in `Wav2VecCriterion`, this diff also divides the consistency loss by the batch size in the `forward` method of `TALNetModel`.

This diff also adds a unit test for the consistency loss.
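The contrast between the two sample_size definitions can be sketched as follows (illustrative functions, not the collater's code):

```python
def sample_size_from_labels(label_matrix):
    # Old behaviour: count of positive labels; varies across batches
    # when recordings carry different numbers of labels.
    return sum(sum(row) for row in label_matrix)

def sample_size_from_batch(label_matrix):
    # New behaviour: one unit per recording, constant for a fixed batch size.
    return len(label_matrix)
```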

Reviewed By: alexeib

Differential Revision: D28728699

fbshipit-source-id: dda1f2a1b02e49b894842c8990218b5fe92d0330
2021-06-03 17:49:11 -07:00
Henry Hu
4950c56f46 Add export flag to transform, so LayerNorm can be TorchScripted.
Summary:
Previously on cuda, LayerNorm would always default to FusedLayerNorm, which could not be exported.

Add export flag, so torch.nn.LayerNorm would be used.

Reviewed By: myleott, mikekgfb, kpuatfb

Differential Revision: D28858633

fbshipit-source-id: 58dd4945f596b2bcc94a6b74356bd9fd3c73ca1a
2021-06-03 16:20:34 -07:00
alexeib
c47a9b2eef fix #3574 (#1921)
Summary:
support pre-hydra w2v models

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1921

Reviewed By: arbabu123

Differential Revision: D28807630

Pulled By: alexeib

fbshipit-source-id: 0fc8bcda12cf677e909d88678f235bfdeb50e726
2021-06-01 16:43:56 -07:00
alexeib
62ccebaf70 fix '_pickle.PicklingError: Can't pickle <enum 'Choices'>: attribute … (#1915)
Summary:
for whatever reason, checkpoints are failing to save because ChoiceEnum can't be pickled again (could be env specific). this should permanently resolve it by converting ChoiceEnum values to strings in the config before saving
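The conversion can be sketched as a recursive walk that replaces enum members with plain strings before the config is pickled (illustrative code; fairseq's ChoiceEnum members carry string values):

```python
from enum import Enum

def de_enum(obj):
    # Plain strings always pickle; enum classes defined in odd scopes
    # (e.g. created dynamically) may not.
    if isinstance(obj, Enum):
        return str(obj.value)
    if isinstance(obj, dict):
        return {k: de_enum(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [de_enum(v) for v in obj]
    return obj
```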

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1915

Reviewed By: arbabu123

Differential Revision: D28784506

Pulled By: alexeib

fbshipit-source-id: 17843cfa00e8e624eb06262df8e1b71b062a237b
2021-05-31 01:14:37 -07:00
Yun Wang
19793a78e5 Remove duplicate registration of ManifoldPathHandler
Summary: `ManifoldPathHandler` is automatically registered with `IOPathManager` upon importing the latter (see D27960781). Therefore it is no longer necessary to register `ManifoldPathHandler` in fairseq, as introduced by D27809504 (3a90a859d4).

Reviewed By: sujitoc

Differential Revision: D28735316

fbshipit-source-id: 03e246dd17ba9f2a9a81dd4e741cce88f26feedd
2021-05-27 14:29:02 -07:00
Mandeep Singh Baines
9497ae3cfb disable raise_if_valid_subsets_unintentionally_ignored check for dummy tasks (#3552)
Summary:
Fixes the following crash:
```python
Traceback (most recent call last):
  File "/private/home/msb/.conda/envs/fairseq-20210102-pt181/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/private/home/msb/code/fairseq/fairseq/distributed/utils.py", line 328, in distributed_main
    main(cfg, **kwargs)
  File "/private/home/msb/code/fairseq/fairseq_cli/train.py", line 117, in main
    data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
  File "/private/home/msb/code/fairseq/fairseq/data/data_utils.py", line 584, in raise_if_valid_subsets_unintentionally_ignored
    other_paths = _find_extra_valid_paths(train_cfg.task.data)
AttributeError: 'Namespace' object has no attribute 'data'
```

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3552

Reviewed By: sshleifer

Differential Revision: D28667773

Pulled By: msbaines

fbshipit-source-id: bc9a633184105dbae0cce58756bb1d379b03980a
2021-05-27 12:15:31 -07:00
Nicola De Cao
c8223e350c fixing prefix_allowed_tokens_fn (#3276)
Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?
Fixes the use of `prefix_allowed_tokens_fn` in generation. It was working for `fairseq==0.9.0` (see https://github.com/facebookresearch/GENRE) but is broken in the current version.

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3276

Reviewed By: alexeib

Differential Revision: D26725494

Pulled By: myleott

fbshipit-source-id: ce3da725f36352687e5cb5d62a59b4c89ce0b0bc
2021-05-26 18:21:49 -07:00
alexeib
e6eddd805e make hydra/infer.py work; also dont break if something is removed fro… (#1903)
Summary:
previously hydra/infer.py did not always work for several reasons which are addressed here

new example usage:

PYTHONPATH=. python examples/speech_recognition/new/infer.py --config-dir examples/speech_recognition/hydra/conf --config-name infer task=audio_pretraining task.data=/path/to/data task.labels=ltr decoding.type=kenlm decoding.lexicon=/path/to/lexicon decoding.lmpath=/path/to/lm dataset.gen_subset=dev_other common_eval.path=/path/to/model.pt decoding.beam=5 decoding.lmweight=2 decoding.wordscore=-1

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1903

Reviewed By: arbabu123

Differential Revision: D28700795

Pulled By: alexeib

fbshipit-source-id: 66fe454de49c1bf511b3529ac683f1c8cb08e579
2021-05-26 16:29:10 -07:00
Gagandeep Singh
237184e522 Add torch.cuda.amp support (#3460)
Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?
Fixes https://github.com/pytorch/fairseq/issues/3282
Add support for `torch.cuda.amp`
AMP can be enabled by `--amp`, instead of using `--fp16` for the already present full fp16 support.

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3460

Reviewed By: sshleifer, msbaines

Differential Revision: D27932253

Pulled By: myleott

fbshipit-source-id: 21637aefb5e788c59bf4f3c5de6c4a80f7319543
2021-05-26 14:39:10 -07:00
Weiyi Zheng
8df9e3a4a5 support FSDP sharded_state checkpoint loading during inference
Summary:
using the very useful feature added by QuentinDuval https://github.com/facebookresearch/fairscale/pull/683/files , we can consolidate sharded states into a full regular state. this allows inference on sharded states almost transparently.

The main complexity comes from trying to be smart about what kind of checkpoint the user wants to load; not sure if this is over-engineering:
1. if the file checkpoint-shard0.pt exists, and `--checkpoint-shard-count` is > 1, then we load sharded FSDP checkpoint
2. if checkpoint-shard0.pt exists but --checkpoint-shard-count=1, we load consolidated FSDP checkpoint
3. if checkpoint-shard0.pt does not exist, but --checkpoint-shard-count > 1, we load model parallel checkpoint
4. otherwise we are loading a single, plain checkpoint.

In theory we could be even smarter and load shard0.pt to check how many more checkpoints are needed. this is not implemented, though it will save the user having to specify --checkpoint-shard-count.
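The four cases read as a small decision function (a sketch with illustrative names; `exists` is injectable only to make the sketch testable):

```python
import os

def checkpoint_kind(path_prefix, shard_count, exists=os.path.exists):
    # Mirrors the four cases above: presence of checkpoint-shard0.pt
    # crossed with the requested --checkpoint-shard-count.
    has_shard0 = exists(path_prefix + "-shard0.pt")
    if has_shard0 and shard_count > 1:
        return "fsdp_sharded"
    if has_shard0:
        return "fsdp_consolidated"
    if shard_count > 1:
        return "model_parallel"
    return "plain"
```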

Reviewed By: sshleifer

Differential Revision: D28563441

fbshipit-source-id: dcafcaa7c9eaf5c9ff94f55c16bb3424c98dfa59
2021-05-25 17:45:51 -07:00
Kushal Lakhotia
95cf58056d Update model table in README (#1901)
Summary:
## What does this PR do?
Updated the model table in the README to show model sizes and to group pretrained models ahead of fine-tuned models.

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1901

Reviewed By: wnhsu

Differential Revision: D28688952

Pulled By: hikushalhere

fbshipit-source-id: 8621398a785caa3d7bdc68367789ad7f48499d0d
2021-05-25 13:44:33 -07:00
alexeib
5a75b079bf fix saving w2v args in config (#1896)
Summary:
previous changes broke saving/updating of w2v_args in the config, since the model kept a copy of the config. This change makes the task copy over the field to be saved. Not the nicest approach, but it works for now.

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1896

Reviewed By: arbabu123

Differential Revision: D28658802

Pulled By: alexeib

fbshipit-source-id: a13866c42c3b88c48b8b91864c1bf1aeaeba4e8a
2021-05-24 19:57:23 -07:00
alexeib
30003ba419 fix serialization on python 3.6 (#1894)
Summary:
fixes serialization errors when using Python 3.6

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1894

Reviewed By: arbabu123

Differential Revision: D28655932

Pulled By: alexeib

fbshipit-source-id: df40f972966e828817a2861e6e907835fe1d9573
2021-05-24 19:10:13 -07:00
Sam Shleifer
2be2f3c7c1 Plasma tests: ask for less disk (#1893)
Summary:
Old logs:
```
/arrow/cpp/src/plasma/store.cc:1274: Allowing the Plasma store to use up to 107.374GB of memory.
```

New logs:
```
... up to 1e-05GB of memory.
```

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1893

Reviewed By: myleott

Differential Revision: D28641488

Pulled By: sshleifer

fbshipit-source-id: 3373526042cdcbf434c61790be62a09f15e6ad06
2021-05-24 09:00:18 -07:00
alexeib
342d5daf34 propagate quantizer depth and factor args through w2v (#1892)
Summary:
makes the quantizer larger, which helps accuracy in certain cases

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1892

Reviewed By: arbabu123

Differential Revision: D28630035

Pulled By: alexeib

fbshipit-source-id: ba5a902ff1623025e7566e901aa81cdf377a7aa0
2021-05-23 21:25:30 -07:00
Patrick von Platen
366974d981 HF Wav2Vec2 Example (#3502)
Summary:
## What does this PR do?
This PR updates some outdated code from the Hugging Face Transformers library to the new, better format.

## PR review
alexeib

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3502

Reviewed By: arbabu123

Differential Revision: D28140574

Pulled By: alexeib

fbshipit-source-id: f03643e7ebba04015d942a3aa9529f7f6600c734
2021-05-23 16:19:53 -07:00
Changhan Wang
49cf3e0bc3 fixing s2t transformer and N-best checkpoint saving
Summary:
- fixing the default value for `encoder_freezing_updates` in s2t transformer
- fixing N-best checkpoint saving: the previous implementation compared a new checkpoint only against the single previous best one, not against all of the previous N best ones. This led to suboptimal results when averaging the N best checkpoints.
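The fix described above amounts to maintaining the full set of N best checkpoints rather than a single best score. A minimal sketch of that bookkeeping, using a heap (the function and variable names are illustrative, not fairseq's actual implementation):

```python
import heapq


def update_n_best(n_best, score, path, n=3, maximize=True):
    """Keep the N best (score, path) pairs seen so far.

    Each new checkpoint is compared against all previous N best entries,
    not just the single previous best. Returns the paths that fell out of
    the top N (safe to delete from disk).
    """
    key = score if maximize else -score
    heapq.heappush(n_best, (key, path))  # min-heap: worst entry at the root
    removed = []
    while len(n_best) > n:
        removed.append(heapq.heappop(n_best)[1])  # evict the worst entry
    return removed
```

With this structure a new checkpoint that beats only the worst of the N best still displaces it, which the single-best comparison missed.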

Reviewed By: jmp84

Differential Revision: D28546493

fbshipit-source-id: 44ec6d5ab49347f392d71269c5dcfd154b00c11e
2021-05-22 00:22:05 -07:00
Kushal Lakhotia
4aef9036ce Merge Hubert to master (#1877)
Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?
This PR adds the relevant code for pre-training HuBERT and fine-tuning a pretrained HuBERT model for ASR. It also shares trained models of different sizes.

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1877

Reviewed By: wnhsu

Differential Revision: D28513359

Pulled By: hikushalhere

fbshipit-source-id: 8755862f236b7d840105b0fa8f5461ac053d79cc
2021-05-21 18:40:56 -07:00
Weiyi Zheng
78e75fa3ed attempt to make non-sharded FSDP checkpoint behave like regular checkpoint
Summary:
Overall, just wondering if this feature is desirable. If it is, it makes the next diff, which supports loading a sharded checkpoint into a consolidated state dict, cleaner.

A couple of advantages:
1. allows resuming from other DDP trainers.
2. allows resuming into other DDP trainers, or FSDP of a different configuration.
3. a non-sharded FSDP checkpoint can be loaded with the regular load_model_ensemble_and_task().

For old training workflows that do not use `--use-sharded-state`, please rename the checkpoint to remove the "-shard0" suffix before resuming training.

Reviewed By: sshleifer

Differential Revision: D28563032

fbshipit-source-id: ced72bed969319ab6306059721f56e29b2c3d892
2021-05-21 16:18:21 -07:00