fairseq/tests
Liang Tan b3fa5100c6 Add mha prune to fairseq
Summary:
Support multihead attention (MHA) pruning in fairseq. For example, a user can apply pruning on top of the RoBERTa base model by specifying the argument "--mha-heads-to-keep 8". The user must also provide a ckpt that has already been pruned, so that the pruned ckpt can be loaded correctly.

The pruning workflow can be summarized as follows:
1. Fine-tune a model (e.g. the RoBERTa encoder) on a given dataset with regularization.
2. After the model is trained, use the get_reserve_head_index and _adaptive_prune_heads functions to find the top X most important heads. Then use this ranking to prune a new RoBERTa encoder and save the pruned ckpt manually.
3. Fine-tune the new RoBERTa encoder starting from the ckpt saved above.
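Step 2 can be sketched roughly as follows. This is a minimal pure-Python illustration, not fairseq's implementation: the helper names (head_importance, reserve_head_index, prune_heads) are hypothetical stand-ins for get_reserve_head_index and _adaptive_prune_heads. Scoring a head by the L1 norm of its q/k/v projection rows is also an assumption, because the summary does not say how importance is measured. Weights are nested lists in PyTorch nn.Linear layout (out_features x in_features), and biases are left out for brevity.

```python
from typing import List, Tuple

# Row-major matrix in nn.Linear layout: out_features x in_features.
Matrix = List[List[float]]


def head_importance(q_w: Matrix, k_w: Matrix, v_w: Matrix, num_heads: int) -> List[float]:
    """Hypothetical score: L1 norm of each head's rows in the q/k/v projections."""
    head_dim = len(q_w) // num_heads
    scores = []
    for h in range(num_heads):
        rows = range(h * head_dim, (h + 1) * head_dim)
        scores.append(sum(abs(x) for w in (q_w, k_w, v_w) for r in rows for x in w[r]))
    return scores


def reserve_head_index(scores: List[float], heads_to_keep: int) -> List[int]:
    """Indices of the `heads_to_keep` highest-scoring heads, in ascending order."""
    ranked = sorted(range(len(scores)), key=lambda h: scores[h], reverse=True)
    return sorted(ranked[:heads_to_keep])


def prune_heads(q_w: Matrix, k_w: Matrix, v_w: Matrix, out_w: Matrix,
                keep: List[int], head_dim: int) -> Tuple[Matrix, Matrix, Matrix, Matrix]:
    """Copy the q/k/v rows and out_proj columns of the kept heads (biases omitted)."""
    idx = [r for h in keep for r in range(h * head_dim, (h + 1) * head_dim)]
    q, k, v = ([list(w[r]) for r in idx] for w in (q_w, k_w, v_w))
    # out_proj consumes the concatenated head outputs, so its columns map to heads.
    out = [[row[c] for c in idx] for row in out_w]
    return q, k, v, out
```

In this sketch the pruned matrices, together with the matching bias slices, are what would be written into the pruned ckpt before step 3.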

To avoid registering a separate pruned version of RoBERTa for every configuration, I use the argument --mha-heads-to-keep to prune the RoBERTa model into a version whose shapes match the pruned ckpt.
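The heads-to-keep argument is needed because a pruned ckpt has narrower attention projections than the stock model, so the model must be built with matching shapes before the weights can load. Below is a hedged sketch of that shape bookkeeping. The helper names are hypothetical. The parameter names mirror the q_proj/k_proj/v_proj/out_proj submodules of fairseq's MultiheadAttention. Keeping head_dim = embed_dim // num_heads fixed while only the head count shrinks is an assumption.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


def mha_param_shapes(embed_dim: int, num_heads: int, heads_to_keep: int) -> Dict[str, Shape]:
    """Parameter shapes for one attention layer keeping `heads_to_keep` of `num_heads` heads."""
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    if not 1 <= heads_to_keep <= num_heads:
        raise ValueError("heads_to_keep must be in [1, num_heads]")
    head_dim = embed_dim // num_heads  # assumed unchanged by pruning
    inner = heads_to_keep * head_dim  # width of the concatenated kept heads
    shapes: Dict[str, Shape] = {}
    for name in ("q_proj", "k_proj", "v_proj"):
        shapes[f"{name}.weight"] = (inner, embed_dim)  # embed_dim -> kept heads
        shapes[f"{name}.bias"] = (inner,)
    shapes["out_proj.weight"] = (embed_dim, inner)  # kept heads -> embed_dim
    shapes["out_proj.bias"] = (embed_dim,)
    return shapes


def shape_mismatches(model: Dict[str, Shape],
                     ckpt: Dict[str, Shape]) -> Dict[str, Tuple[Shape, Optional[Shape]]]:
    """Map each model parameter whose ckpt shape differs (or is missing) to both shapes."""
    return {n: (s, ckpt.get(n)) for n, s in model.items() if ckpt.get(n) != s}
```

With RoBERTa base (embed_dim 768, 12 heads), mha_param_shapes(768, 12, 8) gives q_proj.weight the shape (512, 768). Comparing the unpruned shapes against it with shape_mismatches flags every q/k/v parameter and out_proj.weight, but not out_proj.bias. That is the kind of mismatch a model built without the heads-to-keep setting would hit when loading the pruned ckpt.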

Reviewed By: dianaml0

Differential Revision: D32449003

fbshipit-source-id: a952fd9ad723a6dbc5c2af574c42f2e9a1fa27dc
2022-01-11 10:09:07 -08:00
distributed fix flake8 issues (#2570) 2021-12-09 02:34:30 -08:00
gpu lint fixes (#2834) 2021-12-29 11:50:55 -08:00
speech conformer (#2859) 2022-01-10 16:18:38 -08:00
speech_recognition Enable Hydra configs in fairseq (#1343) (#1510) 2020-10-20 00:32:26 -07:00
__init__.py remediation of S205607 2020-07-17 17:21:51 -07:00
test_activation_checkpointing.py Make checkpoint wrapper pickleable (#1603) 2021-02-06 08:07:32 -08:00
test_amp_optimizer.py Add linting with black (#2678) 2021-11-29 12:32:59 -08:00
test_average_checkpoints.py Apply black+isort (#1357) 2020-10-18 18:14:51 -07:00
test_backtranslation_dataset.py Apply black+isort (#1357) 2020-10-18 18:14:51 -07:00
test_binaries.py Add linting with black (#2678) 2021-11-29 12:32:59 -08:00
test_character_token_embedder.py Apply black+isort (#1357) 2020-10-18 18:14:51 -07:00
test_checkpoint_utils.py Add linting with black (#2678) 2021-11-29 12:32:59 -08:00
test_concat_dataset.py Apply black+isort (#1357) 2020-10-18 18:14:51 -07:00
test_constraints.py fix flake8 issues (#2570) 2021-12-09 02:34:30 -08:00
test_convtbc.py Apply black+isort (#1357) 2020-10-18 18:14:51 -07:00
test_data_utils.py Add linting with black (#2678) 2021-11-29 12:32:59 -08:00
test_dataclass_utils.py Add linting with black (#2678) 2021-11-29 12:32:59 -08:00
test_dataset.py Add support for FullyShardedDataParallel (--ddp-backend=fully_sharded) (#1667) 2021-03-04 13:32:46 -08:00
test_dictionary.py Extract File Chunking to its own utils (#1955) 2021-06-28 01:46:32 -07:00
test_ema.py Add linting with black (#2678) 2021-11-29 12:32:59 -08:00
test_espnet_multihead_attention.py conformer (#2859) 2022-01-10 16:18:38 -08:00
test_export.py Add linting with black (#2678) 2021-11-29 12:32:59 -08:00
test_file_chunker_utils.py Extract File Chunking to its own utils (#1955) 2021-06-28 01:46:32 -07:00
test_file_io.py fix flake8 issues (#2570) 2021-12-09 02:34:30 -08:00
test_fp16_optimizer.py fix flake8 issues (#2570) 2021-12-09 02:34:30 -08:00
test_hf_hub.py formatting fix (#2816) 2021-12-16 16:11:19 -08:00
test_huffman.py Indexed Huffman Coded dataset (#2029) 2021-08-31 01:12:35 -07:00
test_inference_dropout.py Enable Hydra configs in fairseq (#1343) (#1510) 2020-10-20 00:32:26 -07:00
test_iopath.py Add linting with black (#2678) 2021-11-29 12:32:59 -08:00
test_iterators.py skip remainder batch (#2464) 2021-11-24 07:50:50 -08:00
test_label_smoothing.py Apply black+isort (#1357) 2020-10-18 18:14:51 -07:00
test_lm_context_window.py Add linting with black (#2678) 2021-11-29 12:32:59 -08:00
test_lstm_jitable.py Apply black+isort (#1357) 2020-10-18 18:14:51 -07:00
test_memory_efficient_fp16.py Enable Hydra configs in fairseq (#1343) (#1510) 2020-10-20 00:32:26 -07:00
test_metrics.py Apply black+isort (#1357) 2020-10-18 18:14:51 -07:00
test_multi_corpus_dataset.py Add linting with black (#2678) 2021-11-29 12:32:59 -08:00
test_multi_corpus_sampled_dataset.py fix flake8 issues (#2570) 2021-12-09 02:34:30 -08:00
test_multihead_attention.py Add mha prune to fairseq 2022-01-11 10:09:07 -08:00
test_noising.py Add linting with black (#2678) 2021-11-29 12:32:59 -08:00
test_online_backtranslation.py Obt 2 (#1614) 2021-03-30 09:56:03 -07:00
test_plasma_utils.py Plasma tests: ask for less disk (#1893) 2021-05-24 09:00:18 -07:00
test_positional_encoding.py conformer (#2859) 2022-01-10 16:18:38 -08:00
test_reproducibility.py fix flake8 issues (#2570) 2021-12-09 02:34:30 -08:00
test_resampling_dataset.py Apply black+isort (#1357) 2020-10-18 18:14:51 -07:00
test_roberta.py Add regularization for multihead attention module and ffn module 2021-12-30 02:02:05 -08:00
test_rotary_positional_embedding.py conformer (#2859) 2022-01-10 16:18:38 -08:00
test_sequence_generator.py formatting fix (#2816) 2021-12-16 16:11:19 -08:00
test_sequence_scorer.py Apply black+isort (#1357) 2020-10-18 18:14:51 -07:00
test_sparse_multihead_attention.py Apply black+isort (#1357) 2020-10-18 18:14:51 -07:00
test_token_block_dataset.py TokenBlockDataset np type promotion issue (#1658) 2021-02-26 21:00:38 -08:00
test_train.py fixes tests/test_train.py to mock checkpoint.save_dir config node (#3675) 2021-07-06 15:07:31 -07:00
test_transformer.py fix MultiHeadAttention assert (#1798) 2021-04-14 04:59:59 -07:00
test_utils.py Apply black+isort (#1357) 2020-10-18 18:14:51 -07:00
test_valid_subset_checks.py Migrate DummyMaskedLMTask to FairseqTask (#3593) 2021-06-10 09:43:08 -07:00
utils.py lint fixes (#2834) 2021-12-29 11:50:55 -08:00