From d871f6169f8185837d1c11fb28da56abfd83841c Mon Sep 17 00:00:00 2001 From: Alexei Baevski Date: Mon, 12 Dec 2022 08:53:56 -0800 Subject: [PATCH] data2vec v2.0 (#4903) data2vec 2.0 Co-authored-by: Arun Babu Co-authored-by: Wei-Ning Hsu --- examples/data2vec/README.md | 122 +++ examples/data2vec/__init__.py | 0 .../classification/base_classification.yaml | 70 ++ .../classification/run_config/slurm_1.yaml | 35 + .../classification/run_config/slurm_1g.yaml | 35 + .../classification/run_config/slurm_2.yaml | 35 + .../config/audio/pretraining/audioset.yaml | 91 ++ .../audio/pretraining/run_config/local.yaml | 15 + .../audio/pretraining/run_config/slurm_1.yaml | 37 + .../pretraining/run_config/slurm_1_aws.yaml | 36 + .../audio/pretraining/run_config/slurm_2.yaml | 37 + .../pretraining/run_config/slurm_2_aws.yaml | 37 + .../audio/pretraining/run_config/slurm_3.yaml | 36 + .../audio/pretraining/run_config/slurm_4.yaml | 36 + .../pretraining/run_config/slurm_4_aws.yaml | 37 + .../pretraining/run_config/slurm_6_aws.yaml | 36 + .../pretraining/run_config/slurm_8_aws.yaml | 36 + .../text/pretraining/run_config/local.yaml | 15 + .../pretraining/run_config/slurm_1_aws.yaml | 37 + .../text/pretraining/run_config/slurm_2.yaml | 37 + .../pretraining/run_config/slurm_2_aws.yaml | 37 + .../text/pretraining/run_config/slurm_3.yaml | 36 + .../text/pretraining/run_config/slurm_4.yaml | 36 + .../pretraining/run_config/slurm_4_aws.yaml | 41 + .../pretraining/run_config/slurm_8_aws.yaml | 41 + .../config/v2/base_audio_only_task.yaml | 113 +++ .../config/v2/base_images_only_task.yaml | 116 +++ .../config/v2/base_text_only_task.yaml | 112 +++ .../config/v2/huge_images14_only_task.yaml | 122 +++ .../config/v2/huge_images_only_task.yaml | 120 +++ .../config/v2/large_audio_only_task.yaml | 122 +++ .../config/v2/large_images_only_task.yaml | 120 +++ .../config/v2/large_text_only_task.yaml | 112 +++ .../v2/large_text_only_task_pgrp_1M.yaml | 123 +++ .../data2vec/config/v2/run_config/local.yaml | 15 + .../config/v2/run_config/slurm_1.yaml | 37 + .../config/v2/run_config/slurm_1_aws.yaml | 37 + .../config/v2/run_config/slurm_2.yaml | 37 + .../config/v2/run_config/slurm_2_aws.yaml | 39 + .../config/v2/run_config/slurm_3.yaml | 36 + .../config/v2/run_config/slurm_4.yaml | 36 + .../config/v2/run_config/slurm_4_aws.yaml | 37 + .../config/v2/run_config/slurm_6_aws.yaml | 36 + .../config/v2/run_config/slurm_8.yaml | 37 + .../config/v2/run_config/slurm_8_aws.yaml | 36 + .../config/v2/text_finetuning/cola.yaml | 60 ++ .../config/v2/text_finetuning/mnli.yaml | 60 ++ .../config/v2/text_finetuning/mrpc.yaml | 60 ++ .../config/v2/text_finetuning/qnli.yaml | 59 ++ .../config/v2/text_finetuning/qqp.yaml | 60 ++ .../config/v2/text_finetuning/rte.yaml | 59 ++ .../v2/text_finetuning/run_config/local.yaml | 15 + .../config/v2/text_finetuning/sst_2.yaml | 59 ++ .../config/v2/text_finetuning/sts_b.yaml | 61 ++ .../config/vision/finetuning/imagenet.yaml | 52 ++ .../vision/finetuning/mae_imagenet_clean.yaml | 65 ++ .../finetuning/mae_imagenet_huge_clean.yaml | 68 ++ .../finetuning/mae_imagenet_large_clean.yaml | 68 ++ .../vision/finetuning/run_config/local.yaml | 15 + .../vision/finetuning/run_config/slurm_1.yaml | 37 + .../finetuning/run_config/slurm_1_aws.yaml | 36 + .../vision/finetuning/run_config/slurm_2.yaml | 38 + .../finetuning/run_config/slurm_2_aws.yaml | 38 + .../vision/finetuning/run_config/slurm_3.yaml | 36 + .../vision/finetuning/run_config/slurm_4.yaml | 36 + .../finetuning/run_config/slurm_4_aws.yaml | 36 + 
.../finetuning/run_config/slurm_6_aws.yaml | 36 + .../finetuning/run_config/slurm_8_aws.yaml | 36 + .../vision/pretraining/base_imagenet.yaml | 52 ++ .../pretraining/base_imagenet_d2v1.yaml | 64 ++ .../vision/pretraining/base_mae_imagenet.yaml | 64 ++ .../vision/pretraining/run_config/local.yaml | 15 + .../pretraining/run_config/slurm_1.yaml | 37 + .../pretraining/run_config/slurm_1_aws.yaml | 36 + .../pretraining/run_config/slurm_2.yaml | 38 + .../pretraining/run_config/slurm_2_aws.yaml | 37 + .../pretraining/run_config/slurm_3.yaml | 36 + .../pretraining/run_config/slurm_4.yaml | 36 + .../pretraining/run_config/slurm_4_aws.yaml | 36 + .../pretraining/run_config/slurm_6_aws.yaml | 36 + .../pretraining/run_config/slurm_8_aws.yaml | 36 + examples/data2vec/data/__init__.py | 17 + .../data2vec/data/add_class_target_dataset.py | 63 ++ examples/data2vec/data/image_dataset.py | 127 +++ .../data/mae_finetuning_image_dataset.py | 135 +++ examples/data2vec/data/mae_image_dataset.py | 418 +++++++++ examples/data2vec/data/modality.py | 14 + examples/data2vec/data/path_dataset.py | 64 ++ examples/data2vec/fb_convert_beit_cp.py | 165 ++++ examples/data2vec/models/__init__.py | 0 .../data2vec/models/audio_classification.py | 614 +++++++++++++ examples/data2vec/models/data2vec2.py | 813 +++++++++++++++++ .../models/data2vec_image_classification.py | 143 +++ .../models/data2vec_text_classification.py | 141 +++ examples/data2vec/models/data2vec_vision.py | 727 +++++++++++++++ examples/data2vec/models/mae.py | 825 ++++++++++++++++++ .../models/mae_image_classification.py | 386 ++++++++ .../data2vec/models/modalities/__init__.py | 0 examples/data2vec/models/modalities/audio.py | 192 ++++ examples/data2vec/models/modalities/base.py | 684 +++++++++++++++ examples/data2vec/models/modalities/images.py | 256 ++++++ .../data2vec/models/modalities/modules.py | 589 +++++++++++++ examples/data2vec/models/modalities/text.py | 161 ++++ examples/data2vec/models/utils.py | 55 ++ .../scripts/convert_audioset_labels.py | 63 ++ .../multi/finetune_all_fair_aws_local_lr.sh | 18 + .../finetune_all_fair_aws_local_lr_nodep.sh | 16 + .../multi/finetune_all_fair_local_lr.sh | 28 + .../finetune_all_char_fair_aws_local_lr.sh | 17 + .../scripts/text/finetune_all_fair.sh | 21 + .../scripts/text/finetune_all_fair_aws.sh | 21 + .../text/finetune_all_fair_aws_local_lr.sh | 17 + .../scripts/text/finetune_all_fair_aws_lr.sh | 23 + .../text/finetune_all_fair_local_lr.sh | 25 + .../scripts/text/finetune_all_fair_nodep.sh | 19 + .../text/finetune_all_fair_nodep_aws.sh | 19 + .../finetune_all_fair_nodep_aws_local_lr.sh | 15 + .../text/finetune_all_fair_nodep_aws_lr.sh | 21 + .../finetune_all_fair_nodep_aws_lr_nopos.sh | 21 + .../finetune_all_large_fair_aws_local_lr.sh | 17 + .../text/finetune_all_large_fair_local_lr.sh | 26 + ...etune_all_large_fair_nodep_aws_local_lr.sh | 15 + .../finetune_sst2_qnli_sweep_fair_nodep.sh | 20 + examples/data2vec/scripts/text/glue.py | 34 + examples/data2vec/scripts/text/glue_lr.py | 143 +++ .../data2vec/scripts/text/unprocess_data.py | 188 ++++ examples/data2vec/scripts/text/valids.py | 301 +++++++ examples/data2vec/tasks/__init__.py | 18 + .../data2vec/tasks/audio_classification.py | 167 ++++ .../data2vec/tasks/image_classification.py | 129 +++ examples/data2vec/tasks/image_pretraining.py | 110 +++ .../tasks/mae_image_classification.py | 100 +++ .../data2vec/tasks/mae_image_pretraining.py | 119 +++ examples/data2vec/tasks/multimodal.py | 165 ++++ .../config/finetuning/run_config/local.yaml | 15 + 
.../finetuning/run_config/slurm_1g.yaml | 28 + .../finetuning/run_config/slurm_1g_aws.yaml | 25 + .../config/pretraining/run_config/local.yaml | 15 + .../pretraining/run_config/slurm_2.yaml | 37 + .../pretraining/run_config/slurm_2_aws.yaml | 39 + .../pretraining/run_config/slurm_3.yaml | 36 + .../pretraining/run_config/slurm_4.yaml | 36 + .../README.multilingual.pretraining.md | 26 + .../modules/monotonic_multihead_attention.py | 1 + .../tests/test_text_models.py | 2 +- .../new/conf/hydra/sweeper/ax_sil.yaml | 29 + .../speech_recognition/new/conf/infer.yaml | 2 + .../new/conf/run_config/fb_slurm_1.yaml | 28 + .../new/conf/run_config/fb_slurm_2g.yaml | 27 + .../config/finetuning/run_config/slurm_1.yaml | 26 + .../finetuning/run_config/slurm_16.yaml | 27 + .../finetuning/run_config/slurm_1_aws.yaml | 37 + .../finetuning/run_config/slurm_1_old.yaml | 27 + .../config/finetuning/run_config/slurm_2.yaml | 27 + .../finetuning/run_config/slurm_2_aws.yaml | 37 + .../finetuning/run_config/slurm_2g.yaml | 26 + .../config/finetuning/run_config/slurm_3.yaml | 27 + .../finetuning/run_config/slurm_4g.yaml | 26 + .../finetuning/run_config/slurm_4g_aws.yaml | 37 + .../config/finetuning/run_config/slurm_8.yaml | 26 + .../wav2vec/config/finetuning/vox_100h_2.yaml | 106 +++ .../config/finetuning/vox_100h_2_aws.yaml | 82 ++ .../wav2vec/config/finetuning/vox_100h_3.yaml | 101 +++ .../wav2vec/config/finetuning/vox_10h_2.yaml | 102 +++ .../config/finetuning/vox_10h_2_aws.yaml | 81 ++ .../config/finetuning/vox_10h_aws.yaml | 104 +++ .../config/finetuning/vox_10h_aws_v100.yaml | 102 +++ .../wav2vec/config/finetuning/vox_10m_2.yaml | 114 +++ .../config/finetuning/vox_10m_2_aws.yaml | 114 +++ .../wav2vec/config/finetuning/vox_10m_3.yaml | 105 +++ .../wav2vec/config/finetuning/vox_1h_2.yaml | 104 +++ .../config/finetuning/vox_1h_2_aws.yaml | 114 +++ .../wav2vec/config/finetuning/vox_1h_3.yaml | 104 +++ .../wav2vec/config/finetuning/vox_1h_4.yaml | 104 +++ .../wav2vec/config/finetuning/vox_1h_aws.yaml | 80 ++ .../wav2vec/config/finetuning/vox_960h_2.yaml | 105 +++ .../config/finetuning/vox_960h_2_aws.yaml | 82 ++ .../wav2vec/config/finetuning/vox_960h_3.yaml | 101 +++ fairseq/checkpoint_utils.py | 21 +- fairseq/config/fb_run_config/slurm.yaml | 29 + fairseq/criterions/__init__.py | 4 +- fairseq/criterions/ctc.py | 12 +- .../label_smoothed_cross_entropy.py | 3 +- fairseq/criterions/model_criterion.py | 33 +- fairseq/criterions/sentence_prediction.py | 151 +++- fairseq/data/__init__.py | 5 + fairseq/data/add_class_target_dataset.py | 79 ++ fairseq/data/audio/multi_modality_dataset.py | 29 +- fairseq/data/audio/raw_audio_dataset.py | 176 ++-- fairseq/data/data_utils.py | 606 ++++++++++++- fairseq/data/indexed_dataset.py | 5 + fairseq/data/iterators.py | 12 +- fairseq/data/mask_tokens_dataset.py | 202 ++--- fairseq/data/padding_mask_dataset.py | 38 + fairseq/data/subsample_dataset.py | 11 +- fairseq/dataclass/configs.py | 1 + fairseq/dataclass/utils.py | 21 +- fairseq/distributed/utils.py | 51 +- fairseq/iterative_refinement_generator.py | 4 +- fairseq/logging/meters.py | 30 + fairseq/logging/metrics.py | 20 + fairseq/models/__init__.py | 3 +- fairseq/models/fairseq_model.py | 5 + fairseq/models/wav2vec/wav2vec2_asr.py | 143 ++- fairseq/modules/__init__.py | 3 +- fairseq/modules/ema_module.py | 83 +- fairseq/modules/gumbel_vector_quantizer.py | 33 +- fairseq/modules/kmeans_vector_quantizer.py | 19 +- fairseq/modules/multihead_attention.py | 2 +- fairseq/modules/same_pad.py | 12 + fairseq/modules/transpose_last.py | 5 +- 
fairseq/nan_detector.py | 2 +- fairseq/optim/composite.py | 127 ++- fairseq/optim/fp16_optimizer.py | 5 + fairseq/optim/fused_adam.py | 3 + .../optim/lr_scheduler/cosine_lr_scheduler.py | 5 +- fairseq/registry.py | 6 +- fairseq/tasks/__init__.py | 6 +- fairseq/tasks/audio_finetuning.py | 1 + fairseq/tasks/audio_pretraining.py | 116 +-- fairseq/tasks/fairseq_task.py | 54 +- fairseq/tasks/masked_lm.py | 95 +- fairseq/tasks/sentence_prediction.py | 43 +- fairseq/trainer.py | 73 +- fairseq/utils.py | 4 +- fairseq_cli/hydra_validate.py | 188 ++++ fairseq_cli/train.py | 31 +- .../dependency_submitit_launcher/__init__.py | 3 + .../dependency_submitit_launcher/config.py | 23 + .../dependency_submitit_launcher/launcher.py | 121 +++ .../dependency_submitit_launcher/setup.py | 29 + setup.py | 4 +- tests/test_binaries.py | 7 + tests/test_ema.py | 16 +- tests/test_multihead_attention.py | 1 + tests/test_valid_subset_checks.py | 9 +- 236 files changed, 17327 insertions(+), 522 deletions(-) create mode 100644 examples/data2vec/__init__.py create mode 100644 examples/data2vec/config/audio/classification/base_classification.yaml create mode 100644 examples/data2vec/config/audio/classification/run_config/slurm_1.yaml create mode 100644 examples/data2vec/config/audio/classification/run_config/slurm_1g.yaml create mode 100644 examples/data2vec/config/audio/classification/run_config/slurm_2.yaml create mode 100644 examples/data2vec/config/audio/pretraining/audioset.yaml create mode 100644 examples/data2vec/config/audio/pretraining/run_config/local.yaml create mode 100644 examples/data2vec/config/audio/pretraining/run_config/slurm_1.yaml create mode 100644 examples/data2vec/config/audio/pretraining/run_config/slurm_1_aws.yaml create mode 100644 examples/data2vec/config/audio/pretraining/run_config/slurm_2.yaml create mode 100644 examples/data2vec/config/audio/pretraining/run_config/slurm_2_aws.yaml create mode 100644 examples/data2vec/config/audio/pretraining/run_config/slurm_3.yaml create mode 100644 examples/data2vec/config/audio/pretraining/run_config/slurm_4.yaml create mode 100644 examples/data2vec/config/audio/pretraining/run_config/slurm_4_aws.yaml create mode 100644 examples/data2vec/config/audio/pretraining/run_config/slurm_6_aws.yaml create mode 100644 examples/data2vec/config/audio/pretraining/run_config/slurm_8_aws.yaml create mode 100644 examples/data2vec/config/text/pretraining/run_config/local.yaml create mode 100644 examples/data2vec/config/text/pretraining/run_config/slurm_1_aws.yaml create mode 100644 examples/data2vec/config/text/pretraining/run_config/slurm_2.yaml create mode 100644 examples/data2vec/config/text/pretraining/run_config/slurm_2_aws.yaml create mode 100644 examples/data2vec/config/text/pretraining/run_config/slurm_3.yaml create mode 100644 examples/data2vec/config/text/pretraining/run_config/slurm_4.yaml create mode 100644 examples/data2vec/config/text/pretraining/run_config/slurm_4_aws.yaml create mode 100644 examples/data2vec/config/text/pretraining/run_config/slurm_8_aws.yaml create mode 100644 examples/data2vec/config/v2/base_audio_only_task.yaml create mode 100644 examples/data2vec/config/v2/base_images_only_task.yaml create mode 100644 examples/data2vec/config/v2/base_text_only_task.yaml create mode 100644 examples/data2vec/config/v2/huge_images14_only_task.yaml create mode 100644 examples/data2vec/config/v2/huge_images_only_task.yaml create mode 100644 examples/data2vec/config/v2/large_audio_only_task.yaml create mode 100644 
examples/data2vec/config/v2/large_images_only_task.yaml create mode 100644 examples/data2vec/config/v2/large_text_only_task.yaml create mode 100644 examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml create mode 100644 examples/data2vec/config/v2/run_config/local.yaml create mode 100644 examples/data2vec/config/v2/run_config/slurm_1.yaml create mode 100644 examples/data2vec/config/v2/run_config/slurm_1_aws.yaml create mode 100644 examples/data2vec/config/v2/run_config/slurm_2.yaml create mode 100644 examples/data2vec/config/v2/run_config/slurm_2_aws.yaml create mode 100644 examples/data2vec/config/v2/run_config/slurm_3.yaml create mode 100644 examples/data2vec/config/v2/run_config/slurm_4.yaml create mode 100644 examples/data2vec/config/v2/run_config/slurm_4_aws.yaml create mode 100644 examples/data2vec/config/v2/run_config/slurm_6_aws.yaml create mode 100644 examples/data2vec/config/v2/run_config/slurm_8.yaml create mode 100644 examples/data2vec/config/v2/run_config/slurm_8_aws.yaml create mode 100644 examples/data2vec/config/v2/text_finetuning/cola.yaml create mode 100644 examples/data2vec/config/v2/text_finetuning/mnli.yaml create mode 100644 examples/data2vec/config/v2/text_finetuning/mrpc.yaml create mode 100644 examples/data2vec/config/v2/text_finetuning/qnli.yaml create mode 100644 examples/data2vec/config/v2/text_finetuning/qqp.yaml create mode 100644 examples/data2vec/config/v2/text_finetuning/rte.yaml create mode 100644 examples/data2vec/config/v2/text_finetuning/run_config/local.yaml create mode 100644 examples/data2vec/config/v2/text_finetuning/sst_2.yaml create mode 100644 examples/data2vec/config/v2/text_finetuning/sts_b.yaml create mode 100644 examples/data2vec/config/vision/finetuning/imagenet.yaml create mode 100644 examples/data2vec/config/vision/finetuning/mae_imagenet_clean.yaml create mode 100644 examples/data2vec/config/vision/finetuning/mae_imagenet_huge_clean.yaml create mode 100644 examples/data2vec/config/vision/finetuning/mae_imagenet_large_clean.yaml create mode 100644 examples/data2vec/config/vision/finetuning/run_config/local.yaml create mode 100644 examples/data2vec/config/vision/finetuning/run_config/slurm_1.yaml create mode 100644 examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml create mode 100644 examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml create mode 100644 examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml create mode 100644 examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml create mode 100644 examples/data2vec/config/vision/finetuning/run_config/slurm_4.yaml create mode 100644 examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml create mode 100644 examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml create mode 100644 examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml create mode 100644 examples/data2vec/config/vision/pretraining/base_imagenet.yaml create mode 100644 examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml create mode 100644 examples/data2vec/config/vision/pretraining/base_mae_imagenet.yaml create mode 100644 examples/data2vec/config/vision/pretraining/run_config/local.yaml create mode 100644 examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml create mode 100644 examples/data2vec/config/vision/pretraining/run_config/slurm_1_aws.yaml create mode 100644 examples/data2vec/config/vision/pretraining/run_config/slurm_2.yaml create mode 100644 
examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml create mode 100644 examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml create mode 100644 examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml create mode 100644 examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml create mode 100644 examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml create mode 100644 examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml create mode 100644 examples/data2vec/data/__init__.py create mode 100644 examples/data2vec/data/add_class_target_dataset.py create mode 100644 examples/data2vec/data/image_dataset.py create mode 100644 examples/data2vec/data/mae_finetuning_image_dataset.py create mode 100644 examples/data2vec/data/mae_image_dataset.py create mode 100644 examples/data2vec/data/modality.py create mode 100644 examples/data2vec/data/path_dataset.py create mode 100644 examples/data2vec/fb_convert_beit_cp.py create mode 100644 examples/data2vec/models/__init__.py create mode 100644 examples/data2vec/models/audio_classification.py create mode 100644 examples/data2vec/models/data2vec2.py create mode 100644 examples/data2vec/models/data2vec_image_classification.py create mode 100644 examples/data2vec/models/data2vec_text_classification.py create mode 100644 examples/data2vec/models/data2vec_vision.py create mode 100644 examples/data2vec/models/mae.py create mode 100644 examples/data2vec/models/mae_image_classification.py create mode 100644 examples/data2vec/models/modalities/__init__.py create mode 100644 examples/data2vec/models/modalities/audio.py create mode 100644 examples/data2vec/models/modalities/base.py create mode 100644 examples/data2vec/models/modalities/images.py create mode 100644 examples/data2vec/models/modalities/modules.py create mode 100644 examples/data2vec/models/modalities/text.py create mode 100644 examples/data2vec/models/utils.py create mode 100644 examples/data2vec/scripts/convert_audioset_labels.py create mode 100755 examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh create mode 100644 examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh create mode 100755 examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh create mode 100755 examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh create mode 100755 examples/data2vec/scripts/text/finetune_all_fair.sh create mode 100755 examples/data2vec/scripts/text/finetune_all_fair_aws.sh create mode 100755 examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh create mode 100755 examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh create mode 100755 examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh create mode 100755 examples/data2vec/scripts/text/finetune_all_fair_nodep.sh create mode 100755 examples/data2vec/scripts/text/finetune_all_fair_nodep_aws.sh create mode 100755 examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_local_lr.sh create mode 100755 examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr.sh create mode 100755 examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr_nopos.sh create mode 100755 examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh create mode 100644 examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh create mode 100755 examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh create mode 100755 
examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh create mode 100644 examples/data2vec/scripts/text/glue.py create mode 100644 examples/data2vec/scripts/text/glue_lr.py create mode 100644 examples/data2vec/scripts/text/unprocess_data.py create mode 100644 examples/data2vec/scripts/text/valids.py create mode 100644 examples/data2vec/tasks/__init__.py create mode 100644 examples/data2vec/tasks/audio_classification.py create mode 100644 examples/data2vec/tasks/image_classification.py create mode 100644 examples/data2vec/tasks/image_pretraining.py create mode 100644 examples/data2vec/tasks/mae_image_classification.py create mode 100644 examples/data2vec/tasks/mae_image_pretraining.py create mode 100644 examples/data2vec/tasks/multimodal.py create mode 100644 examples/roberta/config/finetuning/run_config/local.yaml create mode 100644 examples/roberta/config/finetuning/run_config/slurm_1g.yaml create mode 100644 examples/roberta/config/finetuning/run_config/slurm_1g_aws.yaml create mode 100644 examples/roberta/config/pretraining/run_config/local.yaml create mode 100644 examples/roberta/config/pretraining/run_config/slurm_2.yaml create mode 100644 examples/roberta/config/pretraining/run_config/slurm_2_aws.yaml create mode 100644 examples/roberta/config/pretraining/run_config/slurm_3.yaml create mode 100644 examples/roberta/config/pretraining/run_config/slurm_4.yaml create mode 100644 examples/roberta/fb_multilingual/README.multilingual.pretraining.md create mode 100644 examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml create mode 100644 examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml create mode 100644 examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml create mode 100644 examples/wav2vec/config/finetuning/run_config/slurm_1.yaml create mode 100644 examples/wav2vec/config/finetuning/run_config/slurm_16.yaml create mode 100644 examples/wav2vec/config/finetuning/run_config/slurm_1_aws.yaml create mode 100644 examples/wav2vec/config/finetuning/run_config/slurm_1_old.yaml create mode 100644 examples/wav2vec/config/finetuning/run_config/slurm_2.yaml create mode 100644 examples/wav2vec/config/finetuning/run_config/slurm_2_aws.yaml create mode 100644 examples/wav2vec/config/finetuning/run_config/slurm_2g.yaml create mode 100644 examples/wav2vec/config/finetuning/run_config/slurm_3.yaml create mode 100644 examples/wav2vec/config/finetuning/run_config/slurm_4g.yaml create mode 100644 examples/wav2vec/config/finetuning/run_config/slurm_4g_aws.yaml create mode 100644 examples/wav2vec/config/finetuning/run_config/slurm_8.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_100h_2.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_100h_2_aws.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_100h_3.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_10h_2.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_10h_2_aws.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_10h_aws.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_10h_aws_v100.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_10m_2.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_10m_2_aws.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_10m_3.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_1h_2.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_1h_2_aws.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_1h_3.yaml create mode 
100644 examples/wav2vec/config/finetuning/vox_1h_4.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_1h_aws.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_960h_2.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_960h_2_aws.yaml create mode 100644 examples/wav2vec/config/finetuning/vox_960h_3.yaml create mode 100644 fairseq/config/fb_run_config/slurm.yaml create mode 100644 fairseq/data/add_class_target_dataset.py create mode 100644 fairseq/data/padding_mask_dataset.py create mode 100644 fairseq_cli/hydra_validate.py create mode 100644 hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/__init__.py create mode 100644 hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/config.py create mode 100644 hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/launcher.py create mode 100644 hydra_plugins/dependency_submitit_launcher/setup.py diff --git a/examples/data2vec/README.md b/examples/data2vec/README.md index 9fd05d804..5b680ef8c 100644 --- a/examples/data2vec/README.md +++ b/examples/data2vec/README.md @@ -1,3 +1,125 @@ +# data2vec 2.0 + +data2vec 2.0 improves the training efficiency of the original data2vec algorithm. We make the following efficiency improvements: we forward only the unmasked timesteps through the encoder, we use a convolutional decoder, and we use multimasking to amortize the compute overhead of the teacher model (a short illustrative sketch of these ideas follows the vision checkpoint table below). You can find details in [Efficient Self-supervised Learning with Contextualized Target Representations for Vision, Speech and Language](https://ai.facebook.com/research/xyz). + +## Pretrained and finetuned models +### Vision +| Model | Finetuning split | Link +|---|---|--- +data2vec ViT-B | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_imagenet.pt) +data2vec ViT-B | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_imagenet_ft.pt) +data2vec ViT-L | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_imagenet.pt) +data2vec ViT-L | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_imagenet_ft.pt) +data2vec ViT-H | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/huge_imagenet.pt) +data2vec ViT-H | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/huge_imagenet_ft.pt) + +Only the vision models are licensed under CC-BY-NC. 
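The efficiency ideas described at the top of this README can be pictured with a short, self-contained sketch. This is an illustration only, not the code added in this patch: the helper name is hypothetical, and the `mask_prob`/`clone_batch` values simply mirror `base_audio_only_task.yaml` further down in the diff. Each sample is cloned several times with independent masks, the teacher runs once per original sample, and the student encoder only sees the timesteps that survive masking.

```python
# Illustrative sketch only (hypothetical helper, not part of this patch):
# build several differently-masked student views per sample and keep only the
# unmasked timesteps, so the student encoder never processes masked positions.
import torch

def student_views(features: torch.Tensor, mask_prob: float = 0.5, clone_batch: int = 8):
    """features: (B, T, C) output of the local feature encoder."""
    B, T, C = features.shape
    # multimasking: one teacher forward per original sample, reused for
    # clone_batch differently-masked copies of that sample
    views = features.repeat_interleave(clone_batch, dim=0)          # (B*clone_batch, T, C)
    keep = int(T * (1 - mask_prob))
    # draw an independent random mask for every copy; keep only unmasked positions
    keep_idx = torch.rand(views.size(0), T).argsort(dim=1)[:, :keep]
    unmasked = torch.gather(views, 1, keep_idx.unsqueeze(-1).expand(-1, -1, C))
    return unmasked, keep_idx                                        # (B*clone_batch, keep, C)

x = torch.randn(2, 100, 768)    # 2 samples, 100 timesteps, 768 channels
views, kept = student_views(x)
print(views.shape)              # torch.Size([16, 50, 768]) -- half the timesteps survive
```

The masked positions are only reintroduced in the convolutional decoder mentioned above, which predicts the teacher's contextualized targets for them; `mask_prob: 0.5` and `clone_batch: 8` correspond to the audio pretraining config in this patch.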
+### Speech + +| Model | Finetuning split | Dataset | Link +|---|---|---|--- +data2vec Base | No fine-tuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_libri.pt) +data2vec Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_libri_960h.pt) +data2vec Large | No fine-tuning | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_vox.pt) +data2vec Large | 960 hours | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_vox_960h.pt) + +### NLP + +Model | Fine-tuning data | Dataset | Link +|---|---|---|---| +data2vec Base | No fine-tuning | Books + Wiki | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/nlp_base.pt) + +[//]: # (## Data Preparation) + +[//]: # () +[//]: # (### Vision) + +[//]: # (add details) + +[//]: # (### Speech) + +[//]: # (add details) + +[//]: # () +[//]: # (### NLP) + +[//]: # (add details) + + +## Commands to train different models using data2vec 2.0 + +### Vision + +Commands to pretrain different model configurations +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \ +--config-name base_images_only_task task.data=/path/to/dir +``` + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \ +--config-name large_images_only_task task.data=/path/to/dir +``` + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \ +--config-name huge_images14_only_task task.data=/path/to/dir +``` + +Commands to finetune different model configurations + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \ +--config-name mae_imagenet_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model +``` + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \ +--config-name mae_imagenet_large_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model +``` + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \ +--config-name mae_imagenet_huge_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model +``` + +### Speech + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \ +--config-name base_audio_only_task task.data=/path/to/manifests +``` + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \ +--config-name large_audio_only_task task.data=/path/to/manifests +``` + +Finetuning: + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/wav2vec/config/finetuning --config-name vox_10h \ +task.data=/path/to/manifests model.w2v_path=/path/to/pretrained/model common.user_dir=examples/data2vec +``` + +Replace vox_10h with the right config depending on your model and fine-tuning split. +See examples/wav2vec/config/finetuning for all available configs. 
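The `task.data` directory used by the speech commands above holds wav2vec-style manifest files. Below is a minimal sketch of how a `train.tsv` can be produced; the layout (root directory on the first line, then tab-separated relative path and sample count) and the helper name are assumptions based on the existing wav2vec tooling, and `examples/wav2vec/wav2vec_manifest.py` remains the standard way to generate these files.

```python
# Sketch (assumed manifest layout, hypothetical helper): write a wav2vec-style
# train.tsv with the audio root on the first line and one
# "relative_path<TAB>num_samples" entry per file.
import os
import soundfile as sf

def write_manifest(audio_root: str, out_tsv: str, ext: str = ".flac") -> None:
    with open(out_tsv, "w") as f:
        print(audio_root, file=f)                     # first line: root directory
        for dirpath, _, files in os.walk(audio_root):
            for name in sorted(files):
                if not name.endswith(ext):
                    continue
                path = os.path.join(dirpath, name)
                frames = sf.info(path).frames         # length in samples
                print(f"{os.path.relpath(path, audio_root)}\t{frames}", file=f)

write_manifest("/path/to/LibriSpeech/train-clean-100", "/path/to/manifests/train.tsv")
```

For fine-tuning, the configs referenced above additionally expect the corresponding label files (e.g. `.ltr` transcripts) next to the manifests.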
+ +### NLP + +Commands to pretrain +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \ +--config-name base_text_only_task task.data=/path/to/file +``` + +Commands to fine-tune all GLUE tasks +```shell script +$ task=cola # choose from [cola|qnli|mrpc|rte|sst_2|mnli|qqp|sts_b] +$ lr=1e-5 # sweep [1e-5|2e-5|4e-5|6e-5] for each task +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2/text_finetuning \ +--config-name $task task.data=/path/to/file model.model_path=/path/to/pretrained/model "optimization.lr=[${lr}]" +``` # data2vec diff --git a/examples/data2vec/__init__.py b/examples/data2vec/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/data2vec/config/audio/classification/base_classification.yaml b/examples/data2vec/config/audio/classification/base_classification.yaml new file mode 100644 index 000000000..fdb9c8d3d --- /dev/null +++ b/examples/data2vec/config/audio/classification/base_classification.yaml @@ -0,0 +1,70 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + all_gather_list_size: 70000 + tensorboard_logdir: tb + min_loss_scale: 1e-6 + +checkpoint: + save_interval: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: mAP + maximize_best_checkpoint_metric: true + +task: + _name: audio_classification + data: ??? + normalize: true + labels: lbl + +dataset: + num_workers: 6 + max_tokens: 2560000 + skip_invalid_size_inputs_valid_test: true + valid_subset: eval + validate_interval: 5 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + +criterion: + _name: model + can_sum: false + log_keys: + - _predictions + - _targets + +optimization: + max_update: 30000 + lr: [0.00006] # scratch 53-5 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: cosine + warmup_updates: 5000 + +model: + _name: audio_classification + model_path: ??? 
+ apply_mask: true + mask_prob: 0.6 + mask_length: 5 # scratch 1 + mask_channel_prob: 0 + mask_channel_length: 64 + layerdrop: 0.1 + dropout: 0.1 + activation_dropout: 0.1 + attention_dropout: 0.2 + feature_grad_mult: 0 # scratch 1 + label_mixup: true + source_mixup: 0.5 + prediction_mode: lin_softmax # scratch average_sigmoid + diff --git a/examples/data2vec/config/audio/classification/run_config/slurm_1.yaml b/examples/data2vec/config/audio/classification/run_config/slurm_1.yaml new file mode 100644 index 000000000..881a1583f --- /dev/null +++ b/examples/data2vec/config/audio/classification/run_config/slurm_1.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/audio/classification/run_config/slurm_1g.yaml b/examples/data2vec/config/audio/classification/run_config/slurm_1g.yaml new file mode 100644 index 000000000..de7894d9c --- /dev/null +++ b/examples/data2vec/config/audio/classification/run_config/slurm_1g.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 1 + tasks_per_node: 1 + mem_gb: 100 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/examples/data2vec/config/audio/classification/run_config/slurm_2.yaml b/examples/data2vec/config/audio/classification/run_config/slurm_2.yaml new file mode 100644 index 000000000..b016cac9b --- /dev/null +++ b/examples/data2vec/config/audio/classification/run_config/slurm_2.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + sweep: + dir: 
/checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/audio/pretraining/audioset.yaml b/examples/data2vec/config/audio/pretraining/audioset.yaml new file mode 100644 index 000000000..dd30fbedd --- /dev/null +++ b/examples/data2vec/config/audio/pretraining/audioset.yaml @@ -0,0 +1,91 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + min_loss_scale: 1e-6 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec + +checkpoint: + save_interval: 1 + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: audio_pretraining + data: /private/home/abaevski/data/audioset + max_sample_size: 320000 + min_sample_size: 32000 + normalize: true + +dataset: + num_workers: 6 + max_tokens: 3400000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 5 + required_batch_size_multiple: 1 + disable_validation: true + +distributed_training: + distributed_world_size: 24 + ddp_backend: legacy_ddp + +criterion: + _name: model + log_keys: + - ema_decay + - target_var + - pred_var +# - avg_self_attn +# - weights + +optimization: + max_update: 200000 + lr: [0.0005] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: cosine + warmup_updates: 10000 + +model: + _name: data2vec_audio + extractor_mode: layer_norm + encoder_layerdrop: 0.05 + dropout_input: 0.0 + dropout_features: 0.0 + feature_grad_mult: 1.0 + encoder_embed_dim: 768 + + mask_prob: 0.65 + mask_length: 10 + + loss_beta: 0 + loss_scale: null + + instance_norm_target_layer: true + layer_norm_targets: true + average_top_k_layers: 12 + + self_attn_norm_type: deepnorm + final_norm_type: deepnorm + + pos_conv_depth: 5 + conv_pos: 95 + + ema_decay: 0.999 + ema_end_decay: 0.9999 + ema_anneal_end_step: 30000 + ema_transformer_only: true + ema_layers_only: false + + require_same_masks: true + mask_dropout: 0 diff --git a/examples/data2vec/config/audio/pretraining/run_config/local.yaml b/examples/data2vec/config/audio/pretraining/run_config/local.yaml new file mode 100644 index 000000000..45595f9ee --- /dev/null +++ b/examples/data2vec/config/audio/pretraining/run_config/local.yaml @@ -0,0 +1,15 @@ +# @package _global_ +hydra: + sweep: + dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S} + +distributed_training: + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +common: + log_interval: 1 + +dataset: + num_workers: 0 diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_1.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_1.yaml new file mode 100644 index 000000000..732f01889 --- /dev/null +++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_1.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - 
checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 450 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_1_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_1_aws.yaml new file mode 100644 index 000000000..e2bab5675 --- /dev/null +++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_1_aws.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 0 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_2.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_2.yaml new file mode 100644 index 000000000..ec53dc2a9 --- /dev/null +++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_2.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_2_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_2_aws.yaml new file mode 100644 index 000000000..70cc8cbb5 --- /dev/null +++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_2_aws.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - 
distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - task.post_save_script + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_3.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_3.yaml new file mode 100644 index 000000000..14b47d14e --- /dev/null +++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_3.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 450 + nodes: 3 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_4.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_4.yaml new file mode 100644 index 000000000..c54d735fb --- /dev/null +++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_4.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 4 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_4_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_4_aws.yaml new file mode 100644 index 000000000..0231b2690 --- /dev/null +++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_4_aws.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + 
config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - task.post_save_script + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 4 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_6_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_6_aws.yaml new file mode 100644 index 000000000..9a4e43a98 --- /dev/null +++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_6_aws.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 6 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_8_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_8_aws.yaml new file mode 100644 index 000000000..78c9f57ae --- /dev/null +++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_8_aws.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 8 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/text/pretraining/run_config/local.yaml b/examples/data2vec/config/text/pretraining/run_config/local.yaml new file mode 100644 index 000000000..45595f9ee --- /dev/null +++ 
b/examples/data2vec/config/text/pretraining/run_config/local.yaml @@ -0,0 +1,15 @@ +# @package _global_ +hydra: + sweep: + dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S} + +distributed_training: + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +common: + log_interval: 1 + +dataset: + num_workers: 0 diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_1_aws.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_1_aws.yaml new file mode 100644 index 000000000..4bac45a58 --- /dev/null +++ b/examples/data2vec/config/text/pretraining/run_config/slurm_1_aws.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: '_' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir}/submitit + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 0 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec + max_num_timeout: 30 + exclude: a100-st-p4d24xlarge-471 diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_2.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_2.yaml new file mode 100644 index 000000000..006a0f211 --- /dev/null +++ b/examples/data2vec/config/text/pretraining/run_config/slurm_2.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 450 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_2_aws.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_2_aws.yaml new file mode 100644 index 000000000..4292198b4 --- /dev/null +++ b/examples/data2vec/config/text/pretraining/run_config/slurm_2_aws.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: '_' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - 
common.log_interval + - common.user_dir + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir}/submitit + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec + max_num_timeout: 30 + exclude: a100-st-p4d24xlarge-471 diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_3.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_3.yaml new file mode 100644 index 000000000..0e1555d20 --- /dev/null +++ b/examples/data2vec/config/text/pretraining/run_config/slurm_3.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 3 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_4.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_4.yaml new file mode 100644 index 000000000..c54d735fb --- /dev/null +++ b/examples/data2vec/config/text/pretraining/run_config/slurm_4.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 4 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_4_aws.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_4_aws.yaml new file mode 100644 index 000000000..5df84cd6d --- /dev/null +++ b/examples/data2vec/config/text/pretraining/run_config/slurm_4_aws.yaml @@ -0,0 +1,41 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: '_' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - 
task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir}/submitit + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 4 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec + max_num_timeout: 30 + exclude: a100-st-p4d24xlarge-471 + +distributed_training: + distributed_world_size: 32 + ddp_backend: legacy_ddp diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_8_aws.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_8_aws.yaml new file mode 100644 index 000000000..5b32c23a6 --- /dev/null +++ b/examples/data2vec/config/text/pretraining/run_config/slurm_8_aws.yaml @@ -0,0 +1,41 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: '_' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir}/submitit + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 8 + name: pt + partition: wav2vec + max_num_timeout: 30 + exclude: a100-st-p4d24xlarge-471 + +distributed_training: + distributed_world_size: 64 + ddp_backend: legacy_ddp diff --git a/examples/data2vec/config/v2/base_audio_only_task.yaml b/examples/data2vec/config/v2/base_audio_only_task.yaml new file mode 100644 index 000000000..65a9ab3e7 --- /dev/null +++ b/examples/data2vec/config/v2/base_audio_only_task.yaml @@ -0,0 +1,113 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + min_loss_scale: 1e-6 + fp16_no_flatten_grads: false + user_dir: ${env:PWD}/examples/data2vec + +checkpoint: + save_interval: 1 + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: audio_pretraining + data: /private/home/abaevski/data/librispeech/full + max_sample_size: 320000 + min_sample_size: 32000 + normalize: true + precompute_mask_config: {} + +dataset: + num_workers: 6 + max_tokens: 1000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 5 + required_batch_size_multiple: 1 + disable_validation: true + +distributed_training: + distributed_world_size: 8 + ddp_backend: legacy_ddp + +criterion: + _name: model + log_keys: + - ema_decay + - target_var + - pred_var + - model_norm + - ema_norm + - masked_pct + +optimization: + max_update: 400000 + lr: [0.00075] + debug_param_names: true + +optimizer: + _name: adam + adam_betas: [ 0.9,0.98 ] + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: cosine + warmup_updates: 8000 + +model: + _name: data2vec_multi + + loss_beta: 0 + loss_scale: null + + depth: 12 + embed_dim: 768 + clone_batch: 8 + + ema_decay: 0.999 + ema_end_decay: 0.99999 + ema_anneal_end_step: 75000 + 
ema_encoder_only: false + + average_top_k_layers: 8 + instance_norm_target_layer: true + layer_norm_target_layer: false + layer_norm_targets: false + + layerdrop: 0.05 + norm_eps: 1e-5 + + supported_modality: AUDIO + + modalities: + audio: + feature_encoder_spec: '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]' + conv_pos_depth: 5 + conv_pos_width: 95 + conv_pos_groups: 16 + prenet_depth: 0 + mask_prob: 0.5 + mask_prob_adjust: 0.05 + inverse_mask: false + mask_length: 5 + mask_noise_std: 0.01 + mask_dropout: 0 + add_masks: false + ema_local_encoder: false + use_alibi_encoder: true + prenet_layerdrop: 0.05 + prenet_dropout: 0.1 + learned_alibi_scale: true + learned_alibi_scale_per_head: true + decoder: + input_dropout: 0.1 + decoder_dim: 384 + decoder_groups: 16 + decoder_kernel: 7 + decoder_layers: 4 diff --git a/examples/data2vec/config/v2/base_images_only_task.yaml b/examples/data2vec/config/v2/base_images_only_task.yaml new file mode 100644 index 000000000..ff0c247b1 --- /dev/null +++ b/examples/data2vec/config/v2/base_images_only_task.yaml @@ -0,0 +1,116 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + min_loss_scale: 1e-6 + fp16_no_flatten_grads: true + user_dir: ${env:PWD}/examples/data2vec + +checkpoint: + save_interval: 5 + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: mae_image_pretraining + data: /datasets01/imagenet_full_size/061417/ + rebuild_batches: true + local_cache_path: /scratch/cache_abaevski/imagenet + key: source + precompute_mask_config: {} + +dataset: + num_workers: 10 + batch_size: 16 + skip_invalid_size_inputs_valid_test: true + required_batch_size_multiple: 1 + disable_validation: true + +distributed_training: + distributed_world_size: 16 + ddp_backend: c10d + +criterion: + _name: model + log_keys: + - ema_decay + - target_var + - pred_var + - model_norm + - ema_norm + - masked_pct + +optimization: + max_update: 375300 + lr: [ 0.001 ] + debug_param_names: true + clip_norm: 4 + +optimizer: + _name: composite + dynamic_groups: true + groups: + default: + lr_float: 1e-3 + optimizer: + _name: adam + adam_betas: [0.9,0.95] + weight_decay: 0.05 + lr_scheduler: + _name: cosine + warmup_updates: 50040 + +lr_scheduler: pass_through + +model: + _name: data2vec_multi + + ema_decay: 0.9998 + ema_end_decay: 0.99999 + ema_anneal_end_step: 100000 + instance_norm_target_layer: true + layer_norm_target_layer: false + layer_norm_targets: true + end_of_block_targets: false + + depth: 10 + average_top_k_layers: 10 + clone_batch: 16 + + norm_eps: 1e-6 + + min_target_var: 0 + min_pred_var: 0 + + encoder_dropout: 0 + post_mlp_drop: 0 + attention_dropout: 0 + activation_dropout: 0 + + supported_modality: IMAGE + cls_loss: 0.01 + + ema_encoder_only: false + + modalities: + image: + inverse_mask: true + mask_prob: 0.8 + mask_prob_adjust: 0.07 + mask_length: 3 + mask_noise_std: 0.01 + prenet_depth: 2 + ema_local_encoder: true + num_extra_tokens: 1 + init_extra_token_zero: false + use_alibi_encoder: false + decoder: + decoder_dim: 768 + decoder_groups: 16 + decoder_kernel: 3 + decoder_layers: 6 + input_dropout: 0 \ No newline at end of file diff --git a/examples/data2vec/config/v2/base_text_only_task.yaml b/examples/data2vec/config/v2/base_text_only_task.yaml new file mode 100644 index 000000000..62f22eb0f --- /dev/null +++ b/examples/data2vec/config/v2/base_text_only_task.yaml @@ -0,0 +1,112 @@ +# @package _group_ + +common: + fp16: true + log_format: json 
+ log_interval: 200 + tensorboard_logdir: tb + fp16_no_flatten_grads: true + user_dir: ${env:PWD}/examples/data2vec + +checkpoint: + no_epoch_checkpoints: true + save_interval_updates: 50000 + keep_interval_updates: 1 + +distributed_training: + distributed_world_size: 16 + ddp_backend: legacy_ddp + +task: + _name: masked_lm + data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin + sample_break_mode: none + tokens_per_sample: 512 + include_target_tokens: true + random_token_prob: 0 + leave_unmasked_prob: 0 + include_index: True + skip_masking: True + d2v2_multi: True + +criterion: + _name: model + log_keys: + - ema_decay + - target_var + - pred_var + - model_norm + - ema_norm + - masked_pct + +dataset: + batch_size: 4 + ignore_unused_valid_subsets: true + skip_invalid_size_inputs_valid_test: true + disable_validation: true + +optimization: + clip_norm: 1 + lr: [0.0002] + max_update: 1000000 + update_freq: [1] + +optimizer: + _name: composite + dynamic_groups: true + groups: + default: + lr_float: 0.0002 + optimizer: + _name: adam + adam_betas: [0.9,0.98] + adam_eps: 1e-06 + weight_decay: 0.01 + lr_scheduler: + _name: cosine + warmup_updates: 4000 + +lr_scheduler: pass_through + +model: + _name: data2vec_multi + + loss_beta: 0 + loss_scale: 1 + + depth: 12 + embed_dim: 768 + clone_batch: 8 + + ema_decay: 0.9999 + ema_end_decay: 0.99999 + ema_anneal_end_step: 100000 + ema_encoder_only: true + + average_top_k_layers: 12 + layer_norm_target_layer: false + instance_norm_target_layer: true + batch_norm_target_layer: false + instance_norm_targets: false + layer_norm_targets: false + + layerdrop: 0 + norm_eps: 1e-5 + + supported_modality: TEXT + + modalities: + text: + mask_prob: 0.48 + mask_length: 1 + mask_noise_std: 0.01 + prenet_depth: 0 + decoder: + input_dropout: 0.1 + decoder_dim: 768 + decoder_groups: 1 + decoder_kernel: 9 + decoder_layers: 5 + decoder_residual: false + projection_layers: 2 + projection_ratio: 2.0 diff --git a/examples/data2vec/config/v2/huge_images14_only_task.yaml b/examples/data2vec/config/v2/huge_images14_only_task.yaml new file mode 100644 index 000000000..a8a15253f --- /dev/null +++ b/examples/data2vec/config/v2/huge_images14_only_task.yaml @@ -0,0 +1,122 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + min_loss_scale: 1e-6 + fp16_no_flatten_grads: true + user_dir: ${env:PWD}/examples/data2vec + +checkpoint: + save_interval: 5 + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: mae_image_pretraining + data: /datasets01/imagenet_full_size/061417/ + rebuild_batches: true + local_cache_path: /scratch/cache_abaevski/imagenet + key: source + precompute_mask_config: {} + +dataset: + num_workers: 10 + batch_size: 8 + skip_invalid_size_inputs_valid_test: true + required_batch_size_multiple: 1 + disable_validation: true + +distributed_training: + distributed_world_size: 32 + ddp_backend: c10d + +criterion: + _name: model + log_keys: + - ema_decay + - target_var + - pred_var + - model_norm + - ema_norm + - masked_pct + +optimization: + max_update: 500000 + lr: [ 0.0004 ] + debug_param_names: true + clip_norm: 4 + +optimizer: + _name: composite + dynamic_groups: true + groups: + default: + lr_float: 4e-4 + optimizer: + _name: adam + adam_betas: [0.9,0.95] + weight_decay: 0.05 + lr_scheduler: + _name: cosine + warmup_updates: 50040 + +lr_scheduler: pass_through + +model: + _name: data2vec_multi + + ema_decay: 0.9998 + ema_end_decay: 1 + 
ema_anneal_end_step: 300000 + instance_norm_target_layer: true + layer_norm_target_layer: false + layer_norm_targets: true + end_of_block_targets: false + + depth: 32 + embed_dim: 1280 + num_heads: 16 + + average_top_k_layers: 24 + clone_batch: 16 + + norm_eps: 1e-6 + + min_target_var: 0 + min_pred_var: 0 + + encoder_dropout: 0 + post_mlp_drop: 0 + attention_dropout: 0 + activation_dropout: 0 + + supported_modality: IMAGE + cls_loss: 0.01 + + ema_encoder_only: false + + modalities: + image: + patch_size: 14 + inverse_mask: true + mask_prob: 0.75 + mask_prob_adjust: 0.1 + mask_length: 3 + mask_noise_std: 0.01 + prenet_depth: 0 + ema_local_encoder: true + num_extra_tokens: 1 + init_extra_token_zero: false + use_alibi_encoder: false + embed_dim: 1280 + decoder: + decoder_dim: 1024 + decoder_groups: 16 + decoder_kernel: 5 + decoder_layers: 3 + final_layer_norm: false + input_dropout: 0 \ No newline at end of file diff --git a/examples/data2vec/config/v2/huge_images_only_task.yaml b/examples/data2vec/config/v2/huge_images_only_task.yaml new file mode 100644 index 000000000..7a352ac3c --- /dev/null +++ b/examples/data2vec/config/v2/huge_images_only_task.yaml @@ -0,0 +1,120 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + min_loss_scale: 1e-6 + fp16_no_flatten_grads: true + user_dir: ${env:PWD}/examples/data2vec + +checkpoint: + save_interval: 5 + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: mae_image_pretraining + data: /datasets01/imagenet_full_size/061417/ + rebuild_batches: true + local_cache_path: /scratch/cache_abaevski/imagenet + key: source + precompute_mask_config: {} + +dataset: + num_workers: 10 + batch_size: 8 + skip_invalid_size_inputs_valid_test: true + required_batch_size_multiple: 1 + disable_validation: true + +distributed_training: + distributed_world_size: 16 + ddp_backend: c10d + +criterion: + _name: model + log_keys: + - ema_decay + - target_var + - pred_var + - model_norm + - ema_norm + - masked_pct + +optimization: + max_update: 375300 + lr: [ 0.0004 ] + debug_param_names: true + clip_norm: 4 + +optimizer: + _name: composite + dynamic_groups: true + groups: + default: + lr_float: 4e-4 + optimizer: + _name: adam + adam_betas: [0.9,0.95] + weight_decay: 0.05 + lr_scheduler: + _name: cosine + warmup_updates: 50040 + +lr_scheduler: pass_through + +model: + _name: data2vec_multi + + ema_decay: 0.9998 + ema_end_decay: 0.99995 + ema_anneal_end_step: 150000 + instance_norm_target_layer: true + layer_norm_target_layer: false + layer_norm_targets: true + end_of_block_targets: false + + depth: 32 + embed_dim: 1280 + num_heads: 16 + + average_top_k_layers: 24 + clone_batch: 16 + + norm_eps: 1e-6 + + min_target_var: 0 + min_pred_var: 0 + + encoder_dropout: 0 + post_mlp_drop: 0 + attention_dropout: 0 + activation_dropout: 0 + + supported_modality: IMAGE + cls_loss: 0.01 + + ema_encoder_only: false + + modalities: + image: + inverse_mask: true + mask_prob: 0.75 + mask_prob_adjust: 0.1 + mask_length: 3 + mask_noise_std: 0.01 + prenet_depth: 0 + ema_local_encoder: true + num_extra_tokens: 1 + init_extra_token_zero: false + use_alibi_encoder: false + embed_dim: 1280 + decoder: + decoder_dim: 1024 + decoder_groups: 16 + decoder_kernel: 5 + decoder_layers: 3 + input_dropout: 0 \ No newline at end of file diff --git a/examples/data2vec/config/v2/large_audio_only_task.yaml b/examples/data2vec/config/v2/large_audio_only_task.yaml new file mode 100644 index 000000000..3f6158972 
--- /dev/null +++ b/examples/data2vec/config/v2/large_audio_only_task.yaml @@ -0,0 +1,122 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + min_loss_scale: 1e-6 + fp16_no_flatten_grads: true + user_dir: ${env:PWD}/examples/data2vec + +checkpoint: + save_interval: 1 + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: audio_pretraining + data: /fsx-wav2vec/abaevski/data/librivox/no_silence + max_sample_size: 320000 + min_sample_size: 32000 + normalize: true + precompute_mask_config: {} + +dataset: + num_workers: 8 + max_tokens: 320000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 5 + required_batch_size_multiple: 1 + disable_validation: true + +distributed_training: + distributed_world_size: 48 + ddp_backend: c10d + +criterion: + _name: model + log_keys: + - ema_decay + - target_var + - pred_var + - model_norm + - ema_norm + - masked_pct + +optimization: + max_update: 600000 + debug_param_names: true + clip_norm: 1 + +optimizer: + _name: composite + dynamic_groups: true + groups: + default: + lr_float: 0.0004 + optimizer: + _name: adam + adam_betas: [0.9,0.98] + adam_eps: 1e-06 + weight_decay: 0.01 + lr_scheduler: + _name: cosine + warmup_updates: 10000 + +lr_scheduler: pass_through + +model: + _name: data2vec_multi + + loss_beta: 0 + loss_scale: null + + depth: 16 + embed_dim: 1024 + num_heads: 16 + + clone_batch: 12 + + ema_decay: 0.9997 + ema_end_decay: 1 + ema_anneal_end_step: 300000 + ema_encoder_only: false + + average_top_k_layers: 16 + instance_norm_target_layer: true + layer_norm_target_layer: false + layer_norm_targets: false + + layerdrop: 0 + norm_eps: 1e-5 + + supported_modality: AUDIO + + modalities: + audio: + feature_encoder_spec: '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]' + conv_pos_depth: 5 + conv_pos_width: 95 + conv_pos_groups: 16 + prenet_depth: 8 + mask_prob: 0.55 + mask_prob_adjust: 0.1 + inverse_mask: false + mask_length: 5 + mask_noise_std: 0.01 + mask_dropout: 0 + add_masks: false + ema_local_encoder: false + use_alibi_encoder: true + prenet_layerdrop: 0 + prenet_dropout: 0.1 + learned_alibi_scale: true + learned_alibi_scale_per_head: true + decoder: + input_dropout: 0.1 + decoder_dim: 768 + decoder_groups: 16 + decoder_kernel: 7 + decoder_layers: 4 diff --git a/examples/data2vec/config/v2/large_images_only_task.yaml b/examples/data2vec/config/v2/large_images_only_task.yaml new file mode 100644 index 000000000..6b957fc12 --- /dev/null +++ b/examples/data2vec/config/v2/large_images_only_task.yaml @@ -0,0 +1,120 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + min_loss_scale: 1e-6 + fp16_no_flatten_grads: true + user_dir: ${env:PWD}/examples/data2vec + +checkpoint: + save_interval: 5 + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: mae_image_pretraining + data: /datasets01/imagenet_full_size/061417/ + rebuild_batches: true + local_cache_path: /scratch/cache_abaevski/imagenet + key: source + precompute_mask_config: {} + +dataset: + num_workers: 10 + batch_size: 8 + skip_invalid_size_inputs_valid_test: true + required_batch_size_multiple: 1 + disable_validation: true + +distributed_training: + distributed_world_size: 16 + ddp_backend: c10d + +criterion: + _name: model + log_keys: + - ema_decay + - target_var + - pred_var + - model_norm + - ema_norm + - masked_pct + +optimization: + max_update: 
375300 + lr: [ 0.0004 ] + debug_param_names: true + clip_norm: 4 + +optimizer: + _name: composite + dynamic_groups: true + groups: + default: + lr_float: 4e-4 + optimizer: + _name: adam + adam_betas: [0.9,0.95] + weight_decay: 0.05 + lr_scheduler: + _name: cosine + warmup_updates: 50040 + +lr_scheduler: pass_through + +model: + _name: data2vec_multi + + ema_decay: 0.9998 + ema_end_decay: 0.99999 + ema_anneal_end_step: 150000 + instance_norm_target_layer: true + layer_norm_target_layer: false + layer_norm_targets: true + end_of_block_targets: false + + depth: 24 + embed_dim: 1024 + num_heads: 16 + + average_top_k_layers: 18 + clone_batch: 16 + + norm_eps: 1e-6 + + min_target_var: 0 + min_pred_var: 0 + + encoder_dropout: 0 + post_mlp_drop: 0 + attention_dropout: 0 + activation_dropout: 0 + + supported_modality: IMAGE + cls_loss: 0.01 + + ema_encoder_only: false + + modalities: + image: + inverse_mask: true + mask_prob: 0.75 + mask_prob_adjust: 0.1 + mask_length: 3 + mask_noise_std: 0.01 + prenet_depth: 0 + ema_local_encoder: true + num_extra_tokens: 1 + init_extra_token_zero: false + use_alibi_encoder: false + embed_dim: 1024 + decoder: + decoder_dim: 1024 + decoder_groups: 16 + decoder_kernel: 5 + decoder_layers: 3 + input_dropout: 0 \ No newline at end of file diff --git a/examples/data2vec/config/v2/large_text_only_task.yaml b/examples/data2vec/config/v2/large_text_only_task.yaml new file mode 100644 index 000000000..fd69048e7 --- /dev/null +++ b/examples/data2vec/config/v2/large_text_only_task.yaml @@ -0,0 +1,112 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + min_loss_scale: 1e-6 + fp16_no_flatten_grads: true + user_dir: ${env:PWD}/examples/data2vec + +checkpoint: + save_interval_updates: 50000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: masked_lm + data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin + sample_break_mode: none + tokens_per_sample: 512 + include_target_tokens: true + random_token_prob: 0 + leave_unmasked_prob: 0 + include_index: True + skip_masking: True + d2v2_multi: True + +dataset: + batch_size: 2 + ignore_unused_valid_subsets: true + skip_invalid_size_inputs_valid_test: true + disable_validation: true + +distributed_training: + distributed_world_size: 32 + ddp_backend: c10d + +criterion: + _name: model + log_keys: + - ema_decay + - target_var + - pred_var + - model_norm + - ema_norm + - masked_pct + +optimization: + max_update: 600000 + clip_norm: 1 + +optimizer: + _name: composite + dynamic_groups: true + groups: + default: + lr_float: 0.0001 + optimizer: + _name: adam + adam_betas: [0.9,0.98] + adam_eps: 1e-06 + weight_decay: 0.01 + lr_scheduler: + _name: cosine + warmup_updates: 4000 + +lr_scheduler: pass_through + +model: + _name: data2vec_multi + + loss_beta: 0 + loss_scale: 1 + + depth: 24 + num_heads: 16 + embed_dim: 1024 + clone_batch: 8 + + ema_decay: 0.9999 + ema_end_decay: 0.99999 + ema_anneal_end_step: 100000 + ema_encoder_only: true + + average_top_k_layers: 24 + layer_norm_target_layer: true + instance_norm_target_layer: false + batch_norm_target_layer: false + instance_norm_targets: true + layer_norm_targets: false + + layerdrop: 0 + norm_eps: 1e-5 + + supported_modality: TEXT + + modalities: + text: + mask_prob: 0.5 + mask_length: 1 + mask_noise_std: 0.01 + prenet_depth: 0 + decoder: + input_dropout: 0.1 + decoder_dim: 768 + decoder_groups: 1 + decoder_kernel: 9 + decoder_layers: 5 + decoder_residual: false + projection_layers: 2 + 
projection_ratio: 2.0 diff --git a/examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml b/examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml new file mode 100644 index 000000000..739e6f672 --- /dev/null +++ b/examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml @@ -0,0 +1,123 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + fp16_no_flatten_grads: true + user_dir: ${env:PWD}/examples/data2vec + +checkpoint: + no_epoch_checkpoints: true + save_interval_updates: 50000 + keep_interval_updates: 1 + +distributed_training: + distributed_world_size: 32 + ddp_backend: legacy_ddp + +task: + _name: masked_lm + data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin + sample_break_mode: none + tokens_per_sample: 512 + include_target_tokens: true + random_token_prob: 0 + leave_unmasked_prob: 0 + include_index: True + skip_masking: True + d2v2_multi: True + +criterion: + _name: model + log_keys: + - ema_decay + - target_var + - pred_var + - model_norm + - ema_norm + - masked_pct + +dataset: + batch_size: 2 + ignore_unused_valid_subsets: true + skip_invalid_size_inputs_valid_test: true + disable_validation: true + +optimization: + clip_norm: 1 + lr: [3e-4] + max_update: 1000000 + update_freq: [1] + +optimizer: + _name: composite + groups: + default: + lr_float: 1e-4 + optimizer: + _name: adam + adam_betas: [0.9,0.98] + adam_eps: 1e-06 + weight_decay: 0.01 + lr_scheduler: + _name: cosine + warmup_updates: 4000 + decoder: + lr_float: 1e-4 + optimizer: + _name: adam + adam_betas: [0.9,0.98] + adam_eps: 1e-06 + weight_decay: 0.01 + lr_scheduler: + _name: cosine + warmup_updates: 4000 + +lr_scheduler: pass_through + +model: + _name: data2vec_multi + + loss_beta: 4 + loss_scale: 1 + + depth: 24 + num_heads: 16 + embed_dim: 1024 + clone_batch: 8 + + ema_decay: 0.9999 + ema_end_decay: 0.99999 + ema_anneal_end_step: 100000 + ema_encoder_only: true + + average_top_k_layers: 24 + layer_norm_target_layer: true + instance_norm_target_layer: false + batch_norm_target_layer: false + instance_norm_targets: true + layer_norm_targets: false + + layerdrop: 0 + norm_eps: 1e-5 + + supported_modality: TEXT + decoder_group: true + + modalities: + text: + mask_prob: 0.5 + mask_length: 1 + mask_noise_std: 0.01 + prenet_depth: 0 + decoder: + input_dropout: 0.1 + decoder_dim: 768 + decoder_groups: 1 + decoder_kernel: 9 + decoder_layers: 5 + decoder_residual: false + projection_layers: 2 + projection_ratio: 2.0 diff --git a/examples/data2vec/config/v2/run_config/local.yaml b/examples/data2vec/config/v2/run_config/local.yaml new file mode 100644 index 000000000..45595f9ee --- /dev/null +++ b/examples/data2vec/config/v2/run_config/local.yaml @@ -0,0 +1,15 @@ +# @package _global_ +hydra: + sweep: + dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S} + +distributed_training: + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +common: + log_interval: 1 + +dataset: + num_workers: 0 diff --git a/examples/data2vec/config/v2/run_config/slurm_1.yaml b/examples/data2vec/config/v2/run_config/slurm_1.yaml new file mode 100644 index 000000000..732f01889 --- /dev/null +++ b/examples/data2vec/config/v2/run_config/slurm_1.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + 
- task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 450 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/v2/run_config/slurm_1_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_1_aws.yaml new file mode 100644 index 000000000..b2184f8cf --- /dev/null +++ b/examples/data2vec/config/v2/run_config/slurm_1_aws.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.local_cache_path + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 0 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/v2/run_config/slurm_2.yaml b/examples/data2vec/config/v2/run_config/slurm_2.yaml new file mode 100644 index 000000000..ec53dc2a9 --- /dev/null +++ b/examples/data2vec/config/v2/run_config/slurm_2.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/v2/run_config/slurm_2_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_2_aws.yaml new file mode 100644 index 000000000..553765597 --- /dev/null +++ b/examples/data2vec/config/v2/run_config/slurm_2_aws.yaml @@ -0,0 +1,39 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + 
- model.target_network_path + - next_script + - task.cache_in_scratch + - task.local_cache_path + - task.data + - task.post_save_script + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + - model.model_path + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 12 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec + max_num_timeout: 30 diff --git a/examples/data2vec/config/v2/run_config/slurm_3.yaml b/examples/data2vec/config/v2/run_config/slurm_3.yaml new file mode 100644 index 000000000..14b47d14e --- /dev/null +++ b/examples/data2vec/config/v2/run_config/slurm_3.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 450 + nodes: 3 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/v2/run_config/slurm_4.yaml b/examples/data2vec/config/v2/run_config/slurm_4.yaml new file mode 100644 index 000000000..c54d735fb --- /dev/null +++ b/examples/data2vec/config/v2/run_config/slurm_4.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 4 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/v2/run_config/slurm_4_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_4_aws.yaml new file mode 100644 index 000000000..a77f62aec --- /dev/null +++ b/examples/data2vec/config/v2/run_config/slurm_4_aws.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - 
model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - task.post_save_script + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 12 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 4 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec + max_num_timeout: 30 diff --git a/examples/data2vec/config/v2/run_config/slurm_6_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_6_aws.yaml new file mode 100644 index 000000000..20e06582b --- /dev/null +++ b/examples/data2vec/config/v2/run_config/slurm_6_aws.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 12 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 6 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/v2/run_config/slurm_8.yaml b/examples/data2vec/config/v2/run_config/slurm_8.yaml new file mode 100644 index 000000000..e3ec2c284 --- /dev/null +++ b/examples/data2vec/config/v2/run_config/slurm_8.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 8 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/v2/run_config/slurm_8_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_8_aws.yaml new file mode 100644 index 000000000..a9dce876c --- /dev/null +++ b/examples/data2vec/config/v2/run_config/slurm_8_aws.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - 
model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 12 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 8 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/v2/text_finetuning/cola.yaml b/examples/data2vec/config/v2/text_finetuning/cola.yaml new file mode 100644 index 000000000..d4ac4ec8b --- /dev/null +++ b/examples/data2vec/config/v2/text_finetuning/cola.yaml @@ -0,0 +1,60 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + log_format: json + log_interval: 200 + user_dir: ${env:PWD}/examples/data2vec + +task: + _name: sentence_prediction + data: ??? + init_token: 0 + separator_token: 2 + num_classes: 2 + max_positions: 512 + d2v2_multi: True + +checkpoint: + best_checkpoint_metric: mcc + maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +criterion: + _name: sentence_prediction + report_mcc: True + +dataset: + batch_size: 16 + required_batch_size_multiple: 1 + max_tokens: 4400 + num_workers: 1 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 320 + +optimization: + clip_norm: 0.0 + lr: [2e-05] + max_update: 5336 + max_epoch: 10 + +model: + _name: data2vec_text_classification + model_path: ??? diff --git a/examples/data2vec/config/v2/text_finetuning/mnli.yaml b/examples/data2vec/config/v2/text_finetuning/mnli.yaml new file mode 100644 index 000000000..1a9d6e52f --- /dev/null +++ b/examples/data2vec/config/v2/text_finetuning/mnli.yaml @@ -0,0 +1,60 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + log_format: json + log_interval: 200 + user_dir: ${env:PWD}/examples/data2vec + +task: + _name: sentence_prediction + data: ??? + init_token: 0 + separator_token: 2 + num_classes: 3 + max_positions: 512 + d2v2_multi: True + +checkpoint: + best_checkpoint_metric: accuracy + maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +criterion: + _name: sentence_prediction + +dataset: + batch_size: 32 + required_batch_size_multiple: 1 + max_tokens: 4400 + valid_subset: valid,valid1 + num_workers: 1 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 7432 + +optimization: + clip_norm: 0.0 + lr: [2e-05] + max_update: 123873 + max_epoch: 10 + +model: + _name: data2vec_text_classification + model_path: ??? 
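The GLUE fine-tuning configs above leave task.data and model.model_path as ??? placeholders to be filled in at launch time. A minimal sketch of how such a config might be launched through fairseq's hydra entry point, assuming a binarized MNLI dataset and a pretrained data2vec 2.0 text checkpoint at the (hypothetical) paths shown:

    # launch MNLI fine-tuning; both paths are placeholders, not part of this patch
    fairseq-hydra-train \
        --config-dir examples/data2vec/config/v2/text_finetuning \
        --config-name mnli \
        task.data=/path/to/MNLI-bin \
        model.model_path=/path/to/data2vec2_text_base.pt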
diff --git a/examples/data2vec/config/v2/text_finetuning/mrpc.yaml b/examples/data2vec/config/v2/text_finetuning/mrpc.yaml new file mode 100644 index 000000000..8f93d9d9e --- /dev/null +++ b/examples/data2vec/config/v2/text_finetuning/mrpc.yaml @@ -0,0 +1,60 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + log_format: json + log_interval: 200 + user_dir: ${env:PWD}/examples/data2vec + +task: + _name: sentence_prediction + data: ??? + init_token: 0 + separator_token: 2 + num_classes: 2 + max_positions: 512 + d2v2_multi: True + +checkpoint: + best_checkpoint_metric: acc_and_f1 + maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +criterion: + _name: sentence_prediction + report_acc_and_f1: True + +dataset: + batch_size: 16 + required_batch_size_multiple: 1 + max_tokens: 4400 + num_workers: 1 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 137 + +optimization: + clip_norm: 0.0 + lr: [2e-05] + max_update: 2296 + max_epoch: 10 + +model: + _name: data2vec_text_classification + model_path: ??? diff --git a/examples/data2vec/config/v2/text_finetuning/qnli.yaml b/examples/data2vec/config/v2/text_finetuning/qnli.yaml new file mode 100644 index 000000000..739fb53b6 --- /dev/null +++ b/examples/data2vec/config/v2/text_finetuning/qnli.yaml @@ -0,0 +1,59 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + log_format: json + log_interval: 200 + user_dir: ${env:PWD}/examples/data2vec + +task: + _name: sentence_prediction + data: ??? + init_token: 0 + separator_token: 2 + num_classes: 2 + max_positions: 512 + d2v2_multi: True + +checkpoint: + best_checkpoint_metric: accuracy + maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +criterion: + _name: sentence_prediction + +dataset: + batch_size: 32 + required_batch_size_multiple: 1 + max_tokens: 4400 + num_workers: 1 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 1986 + +optimization: + clip_norm: 0.0 + lr: [2e-05] + max_update: 33112 + max_epoch: 10 + +model: + _name: data2vec_text_classification + model_path: ??? diff --git a/examples/data2vec/config/v2/text_finetuning/qqp.yaml b/examples/data2vec/config/v2/text_finetuning/qqp.yaml new file mode 100644 index 000000000..9accbaa52 --- /dev/null +++ b/examples/data2vec/config/v2/text_finetuning/qqp.yaml @@ -0,0 +1,60 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + log_format: json + log_interval: 200 + user_dir: ${env:PWD}/examples/data2vec + +task: + _name: sentence_prediction + data: ??? 
+ init_token: 0 + separator_token: 2 + num_classes: 2 + max_positions: 512 + d2v2_multi: True + +checkpoint: + best_checkpoint_metric: acc_and_f1 + maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +criterion: + _name: sentence_prediction + report_acc_and_f1: True + +dataset: + batch_size: 32 + required_batch_size_multiple: 1 + max_tokens: 4400 + num_workers: 1 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 28318 + +optimization: + clip_norm: 0.0 + lr: [2e-05] + max_update: 113272 + max_epoch: 10 + +model: + _name: data2vec_text_classification + model_path: ??? diff --git a/examples/data2vec/config/v2/text_finetuning/rte.yaml b/examples/data2vec/config/v2/text_finetuning/rte.yaml new file mode 100644 index 000000000..ea07764d9 --- /dev/null +++ b/examples/data2vec/config/v2/text_finetuning/rte.yaml @@ -0,0 +1,59 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + log_format: json + log_interval: 200 + user_dir: ${env:PWD}/examples/data2vec + +task: + _name: sentence_prediction + data: ??? + init_token: 0 + separator_token: 2 + num_classes: 2 + max_positions: 512 + d2v2_multi: True + +checkpoint: + best_checkpoint_metric: accuracy + maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +criterion: + _name: sentence_prediction + +dataset: + batch_size: 16 + required_batch_size_multiple: 1 + max_tokens: 4400 + num_workers: 1 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 122 + +optimization: + clip_norm: 0.0 + lr: [2e-05] + max_update: 2036 + max_epoch: 10 + +model: + _name: data2vec_text_classification + model_path: ??? diff --git a/examples/data2vec/config/v2/text_finetuning/run_config/local.yaml b/examples/data2vec/config/v2/text_finetuning/run_config/local.yaml new file mode 100644 index 000000000..45595f9ee --- /dev/null +++ b/examples/data2vec/config/v2/text_finetuning/run_config/local.yaml @@ -0,0 +1,15 @@ +# @package _global_ +hydra: + sweep: + dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S} + +distributed_training: + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +common: + log_interval: 1 + +dataset: + num_workers: 0 diff --git a/examples/data2vec/config/v2/text_finetuning/sst_2.yaml b/examples/data2vec/config/v2/text_finetuning/sst_2.yaml new file mode 100644 index 000000000..a273e5b94 --- /dev/null +++ b/examples/data2vec/config/v2/text_finetuning/sst_2.yaml @@ -0,0 +1,59 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + log_format: json + log_interval: 200 + user_dir: ${env:PWD}/examples/data2vec + +task: + _name: sentence_prediction + data: ??? 
+ init_token: 0 + separator_token: 2 + num_classes: 2 + max_positions: 512 + d2v2_multi: True + +checkpoint: + best_checkpoint_metric: accuracy + maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +criterion: + _name: sentence_prediction + +dataset: + batch_size: 32 + required_batch_size_multiple: 1 + max_tokens: 4400 + num_workers: 1 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 1256 + +optimization: + clip_norm: 0.0 + lr: [2e-05] + max_update: 20935 + max_epoch: 10 + +model: + _name: data2vec_text_classification + model_path: ??? diff --git a/examples/data2vec/config/v2/text_finetuning/sts_b.yaml b/examples/data2vec/config/v2/text_finetuning/sts_b.yaml new file mode 100644 index 000000000..fb009ab95 --- /dev/null +++ b/examples/data2vec/config/v2/text_finetuning/sts_b.yaml @@ -0,0 +1,61 @@ +# @package _group_ + +common: + fp16: true + fp16_init_scale: 4 + threshold_loss_scale: 1 + fp16_scale_window: 128 + log_format: json + log_interval: 200 + user_dir: ${env:PWD}/examples/data2vec + +task: + _name: sentence_prediction + data: ??? + init_token: 0 + separator_token: 2 + num_classes: 1 + max_positions: 512 + d2v2_multi: True + +checkpoint: + best_checkpoint_metric: pearson_and_spearman + maximize_best_checkpoint_metric: true + no_epoch_checkpoints: true + +distributed_training: + find_unused_parameters: true + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +criterion: + _name: sentence_prediction + regression_target: true + report_pearson_and_spearman: True + +dataset: + batch_size: 16 + required_batch_size_multiple: 1 + max_tokens: 4400 + num_workers: 1 + +optimizer: + _name: adam + weight_decay: 0.1 + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 214 + +optimization: + clip_norm: 0.0 + lr: [4e-05] + max_update: 3598 + max_epoch: 10 + +model: + _name: data2vec_text_classification + model_path: ??? diff --git a/examples/data2vec/config/vision/finetuning/imagenet.yaml b/examples/data2vec/config/vision/finetuning/imagenet.yaml new file mode 100644 index 000000000..d6d4864cc --- /dev/null +++ b/examples/data2vec/config/vision/finetuning/imagenet.yaml @@ -0,0 +1,52 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + +checkpoint: + save_interval: 1 + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: accuracy + +task: + _name: image_classification + data: /datasets01/imagenet_full_size/061417 + +dataset: + num_workers: 6 + batch_size: 64 + skip_invalid_size_inputs_valid_test: true + required_batch_size_multiple: 1 + valid_subset: val + +distributed_training: + distributed_world_size: 8 + ddp_backend: c10d + +criterion: + _name: model + log_keys: + - correct + +optimization: + max_update: 100000 + lr: [0.0005] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: cosine + warmup_updates: 10000 + +model: + _name: data2vec_image_classification + model_path: ??? 
diff --git a/examples/data2vec/config/vision/finetuning/mae_imagenet_clean.yaml b/examples/data2vec/config/vision/finetuning/mae_imagenet_clean.yaml new file mode 100644 index 000000000..17d4c0a8f --- /dev/null +++ b/examples/data2vec/config/vision/finetuning/mae_imagenet_clean.yaml @@ -0,0 +1,65 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + fp16_no_flatten_grads: true + +checkpoint: + save_interval: 1 + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: accuracy + maximize_best_checkpoint_metric: true + +task: + _name: mae_image_classification + data: /datasets01/imagenet_full_size/061417 + +dataset: + num_workers: 6 + batch_size: 32 + skip_invalid_size_inputs_valid_test: true + required_batch_size_multiple: 2 + valid_subset: val + +distributed_training: + distributed_world_size: 16 + ddp_backend: c10d + +criterion: + _name: model + log_keys: + - correct + +optimization: + max_update: 250200 + lr: [0.001] + +optimizer: + _name: composite + dynamic_groups: true + groups: + default: + lr_float: 0.001 + optimizer: + _name: adam + adam_betas: [0.9,0.95] + weight_decay: 0.05 + lr_scheduler: + _name: cosine + warmup_updates: 16000 + min_lr: 1e-6 + + +lr_scheduler: pass_through + +model: + _name: mae_image_classification + mixup: 0.7 + mixup_prob: 0.9 + + model_path: ??? diff --git a/examples/data2vec/config/vision/finetuning/mae_imagenet_huge_clean.yaml b/examples/data2vec/config/vision/finetuning/mae_imagenet_huge_clean.yaml new file mode 100644 index 000000000..2d2eb57ba --- /dev/null +++ b/examples/data2vec/config/vision/finetuning/mae_imagenet_huge_clean.yaml @@ -0,0 +1,68 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + fp16_no_flatten_grads: true + +checkpoint: + save_interval: 1 + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: accuracy + maximize_best_checkpoint_metric: true + +task: + _name: mae_image_classification + data: /datasets01/imagenet_full_size/061417 + +dataset: + num_workers: 6 + batch_size: 32 + skip_invalid_size_inputs_valid_test: true + required_batch_size_multiple: 2 + valid_subset: val + +distributed_training: + distributed_world_size: 16 + ddp_backend: c10d + +criterion: + _name: model + log_keys: + - correct + +optimization: + max_update: 125200 + lr: [0.0005] + clip_norm: 4 + +optimizer: + _name: composite + dynamic_groups: true + groups: + default: + lr_float: 0.0005 + optimizer: + _name: adam + adam_betas: [0.9,0.95] + weight_decay: 0.05 + lr_scheduler: + _name: cosine + warmup_updates: 16000 + min_lr: 1e-20 + + +lr_scheduler: pass_through + +model: + _name: mae_image_classification + mixup: 0.7 + mixup_prob: 0.9 + layer_decay: 0.75 + drop_path_rate: 0.2 + + model_path: ??? 
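The image fine-tuning configs likewise expect model.model_path to point at a checkpoint produced by one of the pretraining configs in this patch. A sketch of a local debug launch, assuming the run_config group can be pulled in with a +run_config override, that multirun (-m) is used so the sweep directory from run_config/local.yaml applies, and that the data2vec user dir must be passed explicitly since this config does not set common.user_dir; all paths are placeholders:

    # local single-node debug run of the base ImageNet fine-tuning config
    fairseq-hydra-train -m \
        --config-dir examples/data2vec/config/vision/finetuning \
        --config-name mae_imagenet_clean \
        +run_config=local \
        common.user_dir=$PWD/examples/data2vec \
        task.data=/path/to/imagenet \
        model.model_path=/path/to/pretrained_d2v2_image.pt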
diff --git a/examples/data2vec/config/vision/finetuning/mae_imagenet_large_clean.yaml b/examples/data2vec/config/vision/finetuning/mae_imagenet_large_clean.yaml new file mode 100644 index 000000000..3a9413cef --- /dev/null +++ b/examples/data2vec/config/vision/finetuning/mae_imagenet_large_clean.yaml @@ -0,0 +1,68 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + fp16_no_flatten_grads: true + +checkpoint: + save_interval: 1 + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: accuracy + maximize_best_checkpoint_metric: true + +task: + _name: mae_image_classification + data: /datasets01/imagenet_full_size/061417 + +dataset: + num_workers: 6 + batch_size: 32 + skip_invalid_size_inputs_valid_test: true + required_batch_size_multiple: 2 + valid_subset: val + +distributed_training: + distributed_world_size: 16 + ddp_backend: c10d + +criterion: + _name: model + log_keys: + - correct + +optimization: + max_update: 125200 + lr: [0.0005] + clip_norm: 4 + +optimizer: + _name: composite + dynamic_groups: true + groups: + default: + lr_float: 0.0005 + optimizer: + _name: adam + adam_betas: [0.9,0.95] + weight_decay: 0.05 + lr_scheduler: + _name: cosine + warmup_updates: 16000 + min_lr: 1e-7 + + +lr_scheduler: pass_through + +model: + _name: mae_image_classification + mixup: 0.7 + mixup_prob: 0.9 + layer_decay: 0.75 + drop_path_rate: 0.2 + + model_path: ??? diff --git a/examples/data2vec/config/vision/finetuning/run_config/local.yaml b/examples/data2vec/config/vision/finetuning/run_config/local.yaml new file mode 100644 index 000000000..45595f9ee --- /dev/null +++ b/examples/data2vec/config/vision/finetuning/run_config/local.yaml @@ -0,0 +1,15 @@ +# @package _global_ +hydra: + sweep: + dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S} + +distributed_training: + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +common: + log_interval: 1 + +dataset: + num_workers: 0 diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_1.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_1.yaml new file mode 100644 index 000000000..732f01889 --- /dev/null +++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_1.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 450 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml new file mode 100644 index 000000000..e2bab5675 --- /dev/null +++ 
b/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 0 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml new file mode 100644 index 000000000..c8b0f02a9 --- /dev/null +++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml @@ -0,0 +1,38 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + - task.local_cache_path + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml new file mode 100644 index 000000000..93d0d9c20 --- /dev/null +++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml @@ -0,0 +1,38 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + - task.local_cache_path + - model.model_path + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git 
a/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml new file mode 100644 index 000000000..14b47d14e --- /dev/null +++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 450 + nodes: 3 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_4.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_4.yaml new file mode 100644 index 000000000..c54d735fb --- /dev/null +++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_4.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 4 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml new file mode 100644 index 000000000..d5d11cb75 --- /dev/null +++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 4 + name: 
${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml new file mode 100644 index 000000000..906f08a60 --- /dev/null +++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 6 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml new file mode 100644 index 000000000..d60e13f8b --- /dev/null +++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 8 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/pretraining/base_imagenet.yaml b/examples/data2vec/config/vision/pretraining/base_imagenet.yaml new file mode 100644 index 000000000..9bfc0f32b --- /dev/null +++ b/examples/data2vec/config/vision/pretraining/base_imagenet.yaml @@ -0,0 +1,52 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + +checkpoint: + save_interval: 5 + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: image_pretraining + data: /datasets01/imagenet_full_size/061417/ + +dataset: + num_workers: 6 + batch_size: 64 + skip_invalid_size_inputs_valid_test: true + required_batch_size_multiple: 1 + disable_validation: true + +distributed_training: + distributed_world_size: 16 + ddp_backend: c10d + +criterion: + _name: model + log_keys: + - ema_decay + - target_var + - pred_var + +optimization: + max_update: 400000 + lr: [0.0005] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + 
adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: cosine + warmup_updates: 10000 + +model: + _name: data2vec_vision diff --git a/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml b/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml new file mode 100644 index 000000000..5fd399b11 --- /dev/null +++ b/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml @@ -0,0 +1,64 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + +checkpoint: + save_interval: 5 + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: image_pretraining + data: /datasets01/imagenet_full_size/061417 + +dataset: + num_workers: 6 + batch_size: 128 + skip_invalid_size_inputs_valid_test: true + required_batch_size_multiple: 2 + disable_validation: true + +distributed_training: + distributed_world_size: 16 + ddp_backend: legacy_ddp + +criterion: + _name: model + log_keys: + - ema_decay + - target_var + - pred_var + +optimization: + max_update: 375300 #300*1251 + lr: [0.0005] + clip_norm: 3.0 + +optimizer: + _name: adam + adam_betas: (0.9,0.999) + adam_eps: 1e-08 + weight_decay: 0.05 + +lr_scheduler: + _name: cosine + warmup_updates: 12510 # it should be 10 epochs + +model: + _name: data2vec_vision + + attention_dropout: 0.05 + + ema_decay: 0.999 + ema_end_decay: 0.9998 + layer_norm_targets: True + average_top_k_layers: 6 + + loss_beta: 2.0 + + drop_path: 0.25 diff --git a/examples/data2vec/config/vision/pretraining/base_mae_imagenet.yaml b/examples/data2vec/config/vision/pretraining/base_mae_imagenet.yaml new file mode 100644 index 000000000..d7872b5e0 --- /dev/null +++ b/examples/data2vec/config/vision/pretraining/base_mae_imagenet.yaml @@ -0,0 +1,64 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + fp16_no_flatten_grads: true + +checkpoint: + save_interval: 5 + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: mae_image_pretraining + data: /datasets01/imagenet_full_size/061417/ + rebuild_batches: true + +dataset: + num_workers: 6 + batch_size: 64 + skip_invalid_size_inputs_valid_test: true + required_batch_size_multiple: 1 + disable_validation: true + +distributed_training: + distributed_world_size: 16 + ddp_backend: c10d + +criterion: + _name: model + +optimization: + max_update: 375300 + lr: [0.0006] + +optimizer: + _name: composite + groups: + with_decay: + lr_float: 6e-4 + optimizer: + _name: adam + adam_betas: [0.9,0.95] + weight_decay: 0.05 + lr_scheduler: + _name: cosine + warmup_updates: 50040 + no_decay: + lr_float: 6e-4 + optimizer: + _name: adam + adam_betas: [0.9,0.95] + weight_decay: 0 + lr_scheduler: + _name: cosine + warmup_updates: 50040 + +lr_scheduler: pass_through + +model: + _name: mae diff --git a/examples/data2vec/config/vision/pretraining/run_config/local.yaml b/examples/data2vec/config/vision/pretraining/run_config/local.yaml new file mode 100644 index 000000000..45595f9ee --- /dev/null +++ b/examples/data2vec/config/vision/pretraining/run_config/local.yaml @@ -0,0 +1,15 @@ +# @package _global_ +hydra: + sweep: + dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S} + +distributed_training: + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +common: + log_interval: 1 + +dataset: + num_workers: 0 diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml 
b/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml new file mode 100644 index 000000000..732f01889 --- /dev/null +++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 450 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_1_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_1_aws.yaml new file mode 100644 index 000000000..e2bab5675 --- /dev/null +++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_1_aws.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 0 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_2.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_2.yaml new file mode 100644 index 000000000..c8b0f02a9 --- /dev/null +++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_2.yaml @@ -0,0 +1,38 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + - task.local_cache_path + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + 
partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml new file mode 100644 index 000000000..032e53a30 --- /dev/null +++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + - task.local_cache_path + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml new file mode 100644 index 000000000..14b47d14e --- /dev/null +++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 450 + nodes: 3 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml new file mode 100644 index 000000000..c54d735fb --- /dev/null +++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 
4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 4 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml new file mode 100644 index 000000000..d5d11cb75 --- /dev/null +++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 4 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml new file mode 100644 index 000000000..906f08a60 --- /dev/null +++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 6 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml new file mode 100644 index 000000000..d60e13f8b --- /dev/null +++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: 
/checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 8 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/data2vec/data/__init__.py b/examples/data2vec/data/__init__.py new file mode 100644 index 000000000..d76112bfc --- /dev/null +++ b/examples/data2vec/data/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .image_dataset import ImageDataset +from .path_dataset import PathDataset +from .mae_image_dataset import MaeImageDataset +from .mae_finetuning_image_dataset import MaeFinetuningImageDataset + + +__all__ = [ + "ImageDataset", + "MaeImageDataset", + "MaeFinetuningImageDataset", + "PathDataset", +] \ No newline at end of file diff --git a/examples/data2vec/data/add_class_target_dataset.py b/examples/data2vec/data/add_class_target_dataset.py new file mode 100644 index 000000000..c346c83e5 --- /dev/null +++ b/examples/data2vec/data/add_class_target_dataset.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from fairseq.data import BaseWrapperDataset, data_utils + + +class AddClassTargetDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + labels, + multi_class, + num_classes=None, + label_indices=None, + add_to_input=True, + ): + super().__init__(dataset) + + self.label_indices = label_indices + self.labels = labels + self.multi_class = multi_class + self.add_to_input = add_to_input + if num_classes is None and multi_class: + assert self.label_indices is not None + num_classes = len(self.label_indices) + + self.num_classes = num_classes + + def __getitem__(self, index): + item = self.dataset[index] + item_labels = self.labels[index] + if self.multi_class: + item["label"] = torch.zeros(self.num_classes) + for il in item_labels: + if self.label_indices is not None: + il = self.label_indices[il] + item["label"][il] = 1.0 + else: + item["label"] = torch.tensor( + self.labels[index] + if self.label_indices is None + else self.label_indices[self.labels[index]] + ) + + return item + + def collater(self, samples): + collated = self.dataset.collater(samples) + if len(collated) == 0: + return collated + + indices = set(collated["id"].tolist()) + target = [s["label"] for s in samples if s["id"] in indices] + collated["label"] = torch.stack(target, dim=0) + + if self.add_to_input: + collated["net_input"]["label"] = collated["label"] + + return collated diff --git a/examples/data2vec/data/image_dataset.py b/examples/data2vec/data/image_dataset.py new file mode 100644 index 000000000..7f551057e --- /dev/null +++ b/examples/data2vec/data/image_dataset.py @@ -0,0 +1,127 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
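+
+# ImageDataset: recursively walks `root` for files with the given extensions, optionally
+# grouping them into classes by sub-directory name, applies a torchvision transform (or a
+# plain ToTensor fallback), and collates batches under net_input["img"] with optional
+# integer labels. (Descriptive comment summarizing the class defined below.)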
+ + +import logging + +import numpy as np +import os +from typing import Optional, Callable, Set + +import torch + +from torchvision.datasets.vision import VisionDataset +from torchvision.transforms import ToTensor + +from fairseq.data import FairseqDataset + + +logger = logging.getLogger(__name__) + + +class ImageDataset(FairseqDataset, VisionDataset): + def __init__( + self, + root: str, + extensions: Set[str], + load_classes: bool, + transform: Optional[Callable] = None, + shuffle=True, + ): + FairseqDataset.__init__(self) + VisionDataset.__init__(self, root=root, transform=transform) + + self.shuffle = shuffle + self.tensor_transform = ToTensor() + + self.classes = None + self.labels = None + if load_classes: + classes = [d.name for d in os.scandir(root) if d.is_dir()] + classes.sort() + self.classes = {cls_name: i for i, cls_name in enumerate(classes)} + logger.info(f"loaded {len(self.classes)} classes") + self.labels = [] + + def walk_path(root_path): + for root, _, fnames in sorted(os.walk(root_path, followlinks=True)): + for fname in sorted(fnames): + fname_ext = os.path.splitext(fname) + if fname_ext[-1].lower() not in extensions: + continue + + path = os.path.join(root, fname) + yield path + + logger.info(f"finding images in {root}") + if self.classes is not None: + self.files = [] + self.labels = [] + for c, i in self.classes.items(): + for f in walk_path(os.path.join(root, c)): + self.files.append(f) + self.labels.append(i) + else: + self.files = [f for f in walk_path(root)] + + logger.info(f"loaded {len(self.files)} examples") + + def __getitem__(self, index): + from PIL import Image + + fpath = self.files[index] + + with open(fpath, "rb") as f: + img = Image.open(f).convert("RGB") + + if self.transform is None: + img = self.tensor_transform(img) + else: + img = self.transform(img) + assert torch.is_tensor(img) + + res = {"id": index, "img": img} + + if self.labels is not None: + res["label"] = self.labels[index] + + return res + + def __len__(self): + return len(self.files) + + def collater(self, samples): + if len(samples) == 0: + return {} + + collated_img = torch.stack([s["img"] for s in samples], dim=0) + + res = { + "id": torch.LongTensor([s["id"] for s in samples]), + "net_input": { + "img": collated_img, + }, + } + + if "label" in samples[0]: + res["net_input"]["label"] = torch.LongTensor([s["label"] for s in samples]) + + return res + + def num_tokens(self, index): + return 1 + + def size(self, index): + return 1 + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + return order[0] diff --git a/examples/data2vec/data/mae_finetuning_image_dataset.py b/examples/data2vec/data/mae_finetuning_image_dataset.py new file mode 100644 index 000000000..28cbcb38a --- /dev/null +++ b/examples/data2vec/data/mae_finetuning_image_dataset.py @@ -0,0 +1,135 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
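+
+# MaeFinetuningImageDataset: wraps torchvision ImageFolder for ImageNet-style fine-tuning,
+# using timm's create_transform (auto-augment, random erasing) for training and a
+# resize/center-crop/normalize pipeline for evaluation; batches are collated under
+# net_input["imgs"] with labels in net_input["labels"].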
+ + +import logging + +import numpy as np +import os + +import torch + +from torchvision import datasets, transforms + +from timm.data import create_transform +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +import PIL + +from fairseq.data import FairseqDataset +from .mae_image_dataset import caching_loader + + +logger = logging.getLogger(__name__) + + +def build_transform(is_train, input_size, color_jitter, aa, reprob, remode, recount): + mean = IMAGENET_DEFAULT_MEAN + std = IMAGENET_DEFAULT_STD + # train transform + if is_train: + # this should always dispatch to transforms_imagenet_train + transform = create_transform( + input_size=input_size, + is_training=True, + color_jitter=color_jitter, + auto_augment=aa, + interpolation="bicubic", + re_prob=reprob, + re_mode=remode, + re_count=recount, + mean=mean, + std=std, + ) + return transform + + # eval transform + t = [] + if input_size <= 224: + crop_pct = 224 / 256 + else: + crop_pct = 1.0 + size = int(input_size / crop_pct) + t.append( + transforms.Resize( + size, interpolation=PIL.Image.BICUBIC + ), # to maintain same ratio w.r.t. 224 images + ) + t.append(transforms.CenterCrop(input_size)) + + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(mean, std)) + return transforms.Compose(t) + + +class MaeFinetuningImageDataset(FairseqDataset): + def __init__( + self, + root: str, + split: str, + is_train: bool, + input_size, + color_jitter=None, + aa="rand-m9-mstd0.5-inc1", + reprob=0.25, + remode="pixel", + recount=1, + local_cache_path=None, + shuffle=True, + ): + FairseqDataset.__init__(self) + + self.shuffle = shuffle + + transform = build_transform( + is_train, input_size, color_jitter, aa, reprob, remode, recount + ) + + path = os.path.join(root, split) + loader = caching_loader(local_cache_path, datasets.folder.default_loader) + + self.dataset = datasets.ImageFolder(path, loader=loader, transform=transform) + + logger.info(f"loaded {len(self.dataset)} examples") + + def __getitem__(self, index): + img, label = self.dataset[index] + return {"id": index, "img": img, "label": label} + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + if len(samples) == 0: + return {} + + collated_img = torch.stack([s["img"] for s in samples], dim=0) + + res = { + "id": torch.LongTensor([s["id"] for s in samples]), + "net_input": { + "imgs": collated_img, + }, + } + + if "label" in samples[0]: + res["net_input"]["labels"] = torch.LongTensor([s["label"] for s in samples]) + + return res + + def num_tokens(self, index): + return 1 + + def size(self, index): + return 1 + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + return order[0] diff --git a/examples/data2vec/data/mae_image_dataset.py b/examples/data2vec/data/mae_image_dataset.py new file mode 100644 index 000000000..4aacb9489 --- /dev/null +++ b/examples/data2vec/data/mae_image_dataset.py @@ -0,0 +1,418 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
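+
+# MaeImageDataset: image pretraining dataset supporting BEiT-style augmentation, a simple
+# random-resized-crop pipeline, or no augmentation; it can emit separate ColorJitter'ed
+# source/target views and precompute block masks via compute_block_mask_1d/2d
+# (clone_batch masks per image).
+# A minimal construction sketch, with illustrative values only:
+#   MaeImageDataset(root="/path/to/imagenet", split="train", input_size=224, compute_mask=True)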
+ + +from functools import partial +import logging +import math +import random +import time + +import numpy as np +import os + +import torch + +from torchvision import datasets, transforms +from .path_dataset import PathDataset + +from fairseq.data import FairseqDataset +from fairseq.data.data_utils import compute_block_mask_1d, compute_block_mask_2d + +from shutil import copyfile + +logger = logging.getLogger(__name__) + + +def load(path, loader, cache): + if hasattr(caching_loader, "cache_root"): + cache = caching_loader.cache_root + + cached_path = cache + path + + num_tries = 3 + for curr_try in range(num_tries): + try: + if curr_try == 2: + return loader(path) + if not os.path.exists(cached_path) or curr_try > 0: + os.makedirs(os.path.dirname(cached_path), exist_ok=True) + copyfile(path, cached_path) + os.chmod(cached_path, 0o777) + return loader(cached_path) + except Exception as e: + logger.warning(str(e)) + if "Errno 13" in str(e): + caching_loader.cache_root = f"/scratch/{random.randint(0, 69420)}" + logger.warning(f"setting cache root to {caching_loader.cache_root}") + cached_path = caching_loader.cache_root + path + if curr_try == (num_tries - 1): + raise + time.sleep(2) + + +def caching_loader(cache_root: str, loader): + if cache_root is None: + return loader + + if cache_root == "slurm_tmpdir": + cache_root = os.environ["SLURM_TMPDIR"] + assert len(cache_root) > 0 + + if not cache_root.endswith("/"): + cache_root += "/" + + return partial(load, loader=loader, cache=cache_root) + + +class RandomResizedCropAndInterpolationWithTwoPic: + """Crop the given PIL Image to random size and aspect ratio with random interpolation. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__( + self, + size, + second_size=None, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation="bilinear", + second_interpolation="lanczos", + ): + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if second_size is not None: + if isinstance(second_size, tuple): + self.second_size = second_size + else: + self.second_size = (second_size, second_size) + else: + self.second_size = None + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + logger.warning("range should be of kind (min, max)") + + if interpolation == "random": + from PIL import Image + + self.interpolation = (Image.BILINEAR, Image.BICUBIC) + else: + self.interpolation = self._pil_interp(interpolation) + + self.second_interpolation = ( + self._pil_interp(second_interpolation) + if second_interpolation is not None + else None + ) + self.scale = scale + self.ratio = ratio + + def _pil_interp(self, method): + from PIL import Image + + if method == "bicubic": + return Image.BICUBIC + elif method == "lanczos": + return Image.LANCZOS + elif method == "hamming": + return Image.HAMMING + else: + # default bilinear, do we want to allow nearest? + return Image.BILINEAR + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image): Image to be cropped. 
+ scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + area = img.size[0] * img.size[1] + + for attempt in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img): + import torchvision.transforms.functional as F + + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Randomly cropped and resized image. + """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + if self.second_size is None: + return F.resized_crop(img, i, j, h, w, self.size, interpolation) + else: + return F.resized_crop( + img, i, j, h, w, self.size, interpolation + ), F.resized_crop( + img, i, j, h, w, self.second_size, self.second_interpolation + ) + + +class MaeImageDataset(FairseqDataset): + def __init__( + self, + root: str, + split: str, + input_size, + local_cache_path=None, + shuffle=True, + key="imgs", + beit_transforms=False, + target_transform=False, + no_transform=False, + compute_mask=False, + patch_size: int = 16, + mask_prob: float = 0.75, + mask_prob_adjust: float = 0, + mask_length: int = 1, + inverse_mask: bool = False, + expand_adjacent: bool = False, + mask_dropout: float = 0, + non_overlapping: bool = False, + require_same_masks: bool = True, + clone_batch: int = 1, + dataset_type: str = "imagefolder", + ): + FairseqDataset.__init__(self) + + self.shuffle = shuffle + self.key = key + + loader = caching_loader(local_cache_path, datasets.folder.default_loader) + + self.transform_source = None + self.transform_target = None + + if target_transform: + self.transform_source = transforms.ColorJitter(0.4, 0.4, 0.4) + self.transform_target = transforms.ColorJitter(0.4, 0.4, 0.4) + + if no_transform: + if input_size <= 224: + crop_pct = 224 / 256 + else: + crop_pct = 1.0 + size = int(input_size / crop_pct) + + self.transform_train = transforms.Compose( + [ + transforms.Resize(size, interpolation=3), + transforms.CenterCrop(input_size), + ] + ) + + self.transform_train = transforms.Resize((input_size, input_size)) + elif beit_transforms: + beit_transform_list = [] + if not target_transform: + beit_transform_list.append(transforms.ColorJitter(0.4, 0.4, 0.4)) + beit_transform_list.extend( + [ + transforms.RandomHorizontalFlip(p=0.5), + RandomResizedCropAndInterpolationWithTwoPic( + size=input_size, + second_size=None, + interpolation="bicubic", + second_interpolation=None, + ), + ] + ) + self.transform_train = transforms.Compose(beit_transform_list) + else: + self.transform_train = 
transforms.Compose( + [ + transforms.RandomResizedCrop( + input_size, scale=(0.2, 1.0), interpolation=3 + ), # 3 is bicubic + transforms.RandomHorizontalFlip(), + ] + ) + self.final_transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + + if dataset_type == "imagefolder": + self.dataset = datasets.ImageFolder( + os.path.join(root, split), loader=loader + ) + elif dataset_type == "path": + self.dataset = PathDataset( + root, + loader, + None, + None, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ) + else: + raise Exception(f"invalid dataset type {dataset_type}") + + logger.info( + f"initial transform: {self.transform_train}, " + f"source transform: {self.transform_source}, " + f"target transform: {self.transform_target}, " + f"final transform: {self.final_transform}" + ) + logger.info(f"loaded {len(self.dataset)} examples") + + self.is_compute_mask = compute_mask + self.patches = (input_size // patch_size) ** 2 + self.mask_prob = mask_prob + self.mask_prob_adjust = mask_prob_adjust + self.mask_length = mask_length + self.inverse_mask = inverse_mask + self.expand_adjacent = expand_adjacent + self.mask_dropout = mask_dropout + self.non_overlapping = non_overlapping + self.require_same_masks = require_same_masks + self.clone_batch = clone_batch + + def __getitem__(self, index): + img, _ = self.dataset[index] + + img = self.transform_train(img) + + source = None + target = None + if self.transform_source is not None: + source = self.final_transform(self.transform_source(img)) + if self.transform_target is not None: + target = self.final_transform(self.transform_target(img)) + + if source is None: + img = self.final_transform(img) + + v = {"id": index, self.key: source if source is not None else img} + if target is not None: + v["target"] = target + + if self.is_compute_mask: + if self.mask_length == 1: + mask = compute_block_mask_1d( + shape=(self.clone_batch, self.patches), + mask_prob=self.mask_prob, + mask_length=self.mask_length, + mask_prob_adjust=self.mask_prob_adjust, + inverse_mask=self.inverse_mask, + require_same_masks=True, + ) + else: + mask = compute_block_mask_2d( + shape=(self.clone_batch, self.patches), + mask_prob=self.mask_prob, + mask_length=self.mask_length, + mask_prob_adjust=self.mask_prob_adjust, + inverse_mask=self.inverse_mask, + require_same_masks=True, + expand_adjcent=self.expand_adjacent, + mask_dropout=self.mask_dropout, + non_overlapping=self.non_overlapping, + ) + + v["precomputed_mask"] = mask + + return v + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + if len(samples) == 0: + return {} + + collated_img = torch.stack([s[self.key] for s in samples], dim=0) + + res = { + "id": torch.LongTensor([s["id"] for s in samples]), + "net_input": { + self.key: collated_img, + }, + } + + if "target" in samples[0]: + collated_target = torch.stack([s["target"] for s in samples], dim=0) + res["net_input"]["target"] = collated_target + + if "precomputed_mask" in samples[0]: + collated_mask = torch.cat([s["precomputed_mask"] for s in samples], dim=0) + res["net_input"]["precomputed_mask"] = collated_mask + + return res + + def num_tokens(self, index): + return 1 + + def size(self, index): + return 1 + + @property + def sizes(self): + return np.full((len(self),), 1) + + def ordered_indices(self): + """Return an ordered list of indices. 
Batches will be constructed based + on this order.""" + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + return order[0] diff --git a/examples/data2vec/data/modality.py b/examples/data2vec/data/modality.py new file mode 100644 index 000000000..aa23ac94f --- /dev/null +++ b/examples/data2vec/data/modality.py @@ -0,0 +1,14 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from enum import Enum, auto + + +class Modality(Enum): + AUDIO = auto() + IMAGE = auto() + TEXT = auto() diff --git a/examples/data2vec/data/path_dataset.py b/examples/data2vec/data/path_dataset.py new file mode 100644 index 000000000..02010058e --- /dev/null +++ b/examples/data2vec/data/path_dataset.py @@ -0,0 +1,64 @@ +import glob +import os +from typing import List, Optional, Tuple + +import logging +import numpy as np +import torchvision.transforms.functional as TF +import PIL +from PIL import Image +from torchvision.datasets import VisionDataset + +logger = logging.getLogger(__name__) + + +class PathDataset(VisionDataset): + def __init__( + self, + root: List[str], + loader: None = None, + transform: Optional[str] = None, + extra_transform: Optional[str] = None, + mean: Optional[List[float]] = None, + std: Optional[List[float]] = None, + ): + super().__init__(root=root) + + PIL.Image.MAX_IMAGE_PIXELS = 256000001 + + self.files = [] + for folder in self.root: + self.files.extend( + sorted(glob.glob(os.path.join(folder, "**", "*.jpg"), recursive=True)) + ) + self.files.extend( + sorted(glob.glob(os.path.join(folder, "**", "*.png"), recursive=True)) + ) + + self.transform = transform + self.extra_transform = extra_transform + self.mean = mean + self.std = std + + self.loader = loader + + logger.info(f"loaded {len(self.files)} samples from {root}") + + assert (mean is None) == (std is None) + + def __len__(self) -> int: + return len(self.files) + + def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]: + path = self.files[idx] + + if self.loader is not None: + return self.loader(path), None + + img = Image.open(path).convert("RGB") + if self.transform is not None: + img = self.transform(img) + img = TF.to_tensor(img) + if self.mean is not None and self.std is not None: + img = TF.normalize(img, self.mean, self.std) + return img, None diff --git a/examples/data2vec/fb_convert_beit_cp.py b/examples/data2vec/fb_convert_beit_cp.py new file mode 100644 index 000000000..cf42ace76 --- /dev/null +++ b/examples/data2vec/fb_convert_beit_cp.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
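+
+# Converts a BEiT checkpoint into a data2vec vision checkpoint: remaps parameter names
+# (e.g. cls_token -> cls_emb, patch_embed.proj -> patch_embed.conv, mlp.fc1/fc2 -> mlp.0/mlp.2),
+# builds a fairseq config for either the pretraining or image-classification model, and
+# saves a fairseq-style state dict.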
+ +import argparse +import torch + +from omegaconf import OmegaConf + +from fairseq.criterions.model_criterion import ModelCriterionConfig +from fairseq.dataclass.configs import FairseqConfig + +from tasks import ImageClassificationConfig, ImagePretrainingConfig +from models.data2vec_image_classification import ( + Data2VecImageClassificationConfig, + Data2VecImageClassificationModel, +) +from models.data2vec_vision import Data2VecVisionConfig, Data2VecVisionModel + + +def get_parser(): + parser = argparse.ArgumentParser( + description="convert beit checkpoint into data2vec - vision checkpoint" + ) + # fmt: off + parser.add_argument('checkpoint', help='checkpoint to convert') + parser.add_argument('--output', required=True, metavar='PATH', help='where to output converted checkpoint') + parser.add_argument('--type', type=str, choices=['vision', 'image_classification'], default='image_classification', help='type of model to upgrade') + parser.add_argument('--inception_norms', action='store_true', default=False) + # fmt: on + + return parser + + +def update_checkpoint(model_dict, prefix, is_nested): + + replace_paths = { + "cls_token": "model.cls_emb" if is_nested else "cls_emb", + "patch_embed": "model.patch_embed" if is_nested else "patch_embed", + "mask_token": "mask_emb", + } + + starts_with = { + "patch_embed.proj": "model.patch_embed.conv" + if is_nested + else "patch_embed.conv", + "lm_head": "final_proj", + "fc_norm": "fc_norm", + "head": "head", + } + + partial = { + "mlp.fc1": "mlp.0", + "mlp.fc2": "mlp.2", + } + + for k in list(model_dict.keys()): + for sw, r in starts_with.items(): + if k.startswith(sw): + replace_paths[k] = k.replace(sw, r) + for p, r in partial.items(): + if p in k: + replace_paths[k] = prefix + k.replace(p, r) + + if prefix != "": + for k in list(model_dict.keys()): + if k not in replace_paths: + replace_paths[k] = prefix + k + + for k in list(model_dict.keys()): + if k in replace_paths: + model_dict[replace_paths[k]] = model_dict[k] + if k != replace_paths[k]: + del model_dict[k] + + return model_dict + + +def main(): + parser = get_parser() + args = parser.parse_args() + + cp = torch.load(args.checkpoint, map_location="cpu") + + cfg = FairseqConfig( + criterion=ModelCriterionConfig(_name="model", log_keys=["correct"]), + ) + + if args.type == "image_classification": + + cfg.task = ImageClassificationConfig( + _name="image_classification", + data=".", + ) + + if args.inception_norms: + cfg.task.normalization_mean = [0.5, 0.5, 0.5] + cfg.task.normalization_std = [0.5, 0.5, 0.5] + + cfg.model = Data2VecImageClassificationConfig( + _name="data2vec_image_classification", + ) + cfg.model.pretrained_model_args = FairseqConfig( + model=Data2VecVisionConfig( + _name="data2vec_vision", shared_rel_pos_bias=False + ), + task=ImagePretrainingConfig( + _name="image_pretraining", + ), + ) + + cfg = OmegaConf.create(cfg) + + state = { + "cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True), + "model": cp["module"], + "best_loss": None, + "optimizer": None, + "extra_state": {}, + } + + model = Data2VecImageClassificationModel(cfg.model) + model.load_state_dict( + update_checkpoint(state["model"], prefix="model.encoder.", is_nested=True), + strict=True, + ) + elif args.type == "vision": + cfg.task = ImagePretrainingConfig( + _name="image_pretraining", + data=".", + ) + + if args.inception_norms: + cfg.task.normalization_mean = [0.5, 0.5, 0.5] + cfg.task.normalization_std = [0.5, 0.5, 0.5] + + cfg.model = Data2VecVisionConfig( + _name="data2vec_vision", + ) + cfg = 
OmegaConf.create(cfg) + + state = { + "cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True), + "model": cp["model"], + "best_loss": None, + "optimizer": None, + "extra_state": {}, + } + + model = Data2VecVisionModel(cfg.model) + model.load_state_dict( + update_checkpoint(state["model"], prefix="encoder.", is_nested=False), + strict=True, + ) + else: + raise Exception("unsupported type " + args.type) + + print(state["cfg"], state.keys()) + torch.save(state, args.output) + + +if __name__ == "__main__": + main() diff --git a/examples/data2vec/models/__init__.py b/examples/data2vec/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/data2vec/models/audio_classification.py b/examples/data2vec/models/audio_classification.py new file mode 100644 index 000000000..06d215826 --- /dev/null +++ b/examples/data2vec/models/audio_classification.py @@ -0,0 +1,614 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import logging +import re +from dataclasses import dataclass, field +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from omegaconf import II, MISSING, open_dict + +from fairseq import checkpoint_utils, tasks +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.models import ( + BaseFairseqModel, + register_model, +) +from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES +from fairseq.modules import TransposeLast +from fairseq.tasks import FairseqTask + +logger = logging.getLogger(__name__) + + +@dataclass +class AudioClassificationConfig(FairseqDataclass): + model_path: str = field( + default=MISSING, metadata={"help": "path to wav2vec 2.0 model"} + ) + no_pretrained_weights: bool = field( + default=False, metadata={"help": "if true, does not load pretrained weights"} + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + final_dropout: float = field( + default=0.0, + metadata={"help": "dropout after transformer and before final projection"}, + ) + dropout: float = field( + default=0.0, metadata={"help": "dropout probability inside wav2vec 2.0 model"} + ) + attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights inside wav2vec 2.0 model" + }, + ) + activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN inside wav2vec 2.0 model" + }, + ) + + # masking + apply_mask: bool = field( + default=False, metadata={"help": "apply masking during fine-tuning"} + ) + mask_length: int = field( + default=10, metadata={"help": "repeat the mask indices multiple times"} + ) + mask_prob: float = field( + default=0.5, + metadata={ + "help": "probability of replacing a token with mask (normalized by length)" + }, + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose masks"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + mask_min_space: Optional[int] = 
field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + require_same_masks: bool = field( + default=True, + metadata={ + "help": "whether the number of masked timesteps must be the same across all " + "examples in a batch" + }, + ) + mask_dropout: float = field( + default=0.0, + metadata={"help": "percent of masks to unmask for each sample"}, + ) + + # channel masking + mask_channel_length: int = field( + default=10, metadata={"help": "length of the mask for features (channels)"} + ) + mask_channel_prob: float = field( + default=0.0, metadata={"help": "probability of replacing a feature with 0"} + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, metadata={"help": "whether to allow channel masks to overlap"} + ) + freeze_finetune_updates: int = field( + default=0, metadata={"help": "don't fine-tune wav2vec for this many updates"} + ) + feature_grad_mult: float = field( + default=0.0, metadata={"help": "reset feature grad mult in wav2vec 2.0 to this"} + ) + layerdrop: float = field( + default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"} + ) + mask_channel_min_space: Optional[int] = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + mask_channel_before: bool = False + normalize: bool = II("task.normalize") + data: str = II("task.data") + # this holds the loaded wav2vec args + d2v_args: Any = None + offload_activations: bool = field( + default=False, metadata={"help": "offload_activations"} + ) + min_params_to_wrap: int = field( + default=int(1e8), + metadata={ + "help": "minimum number of params for a layer to be wrapped with FSDP() when " + "training with --ddp-backend=fully_sharded. Smaller values will " + "improve memory efficiency, but may make torch.distributed " + "communication less efficient due to smaller input sizes. This option " + "is set to 0 (i.e., always wrap) when --checkpoint-activations or " + "--offload-activations are passed."
+ }, + ) + + checkpoint_activations: bool = field( + default=False, + metadata={"help": "recompute activations and save memory for extra compute"}, + ) + ddp_backend: str = II("distributed_training.ddp_backend") + + prediction_mode: str = "lin_softmax" + eval_prediction_mode: Optional[str] = None + conv_kernel: int = -1 + conv_stride: int = 1 + two_convs: bool = False + extreme_factor: float = 1.0 + + conv_feature_layers: Optional[str] = field( + default=None, + metadata={ + "help": "string describing convolutional feature extraction layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + }, + ) + + mixup_prob: float = 1.0 + source_mixup: float = -1 + same_mixup: bool = True + label_mixup: bool = False + + gain_mode: str = "none" + + +@register_model("audio_classification", dataclass=AudioClassificationConfig) +class AudioClassificationModel(BaseFairseqModel): + def __init__(self, cfg: AudioClassificationConfig, num_classes): + super().__init__() + + self.apply_mask = cfg.apply_mask + self.cfg = cfg + + arg_overrides = { + "dropout": cfg.dropout, + "activation_dropout": cfg.activation_dropout, + "dropout_input": cfg.dropout_input, + "attention_dropout": cfg.attention_dropout, + "mask_length": cfg.mask_length, + "mask_prob": cfg.mask_prob, + "require_same_masks": getattr(cfg, "require_same_masks", True), + "mask_dropout": getattr(cfg, "mask_dropout", 0), + "mask_selection": cfg.mask_selection, + "mask_other": cfg.mask_other, + "no_mask_overlap": cfg.no_mask_overlap, + "mask_channel_length": cfg.mask_channel_length, + "mask_channel_prob": cfg.mask_channel_prob, + "mask_channel_before": cfg.mask_channel_before, + "mask_channel_selection": cfg.mask_channel_selection, + "mask_channel_other": cfg.mask_channel_other, + "no_mask_channel_overlap": cfg.no_mask_channel_overlap, + "encoder_layerdrop": cfg.layerdrop, + "feature_grad_mult": cfg.feature_grad_mult, + "checkpoint_activations": cfg.checkpoint_activations, + "offload_activations": cfg.offload_activations, + "min_params_to_wrap": cfg.min_params_to_wrap, + "mixup": -1, + } + + if cfg.conv_feature_layers is not None: + arg_overrides["conv_feature_layers"] = cfg.conv_feature_layers + + if cfg.d2v_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu( + cfg.model_path, arg_overrides + ) + d2v_args = state.get("cfg", None) + if d2v_args is None: + d2v_args = convert_namespace_to_omegaconf(state["args"]) + d2v_args.criterion = None + d2v_args.lr_scheduler = None + cfg.d2v_args = d2v_args + + logger.info(d2v_args) + + else: + state = None + d2v_args = cfg.d2v_args + + model_normalized = d2v_args.task.get( + "normalize", d2v_args.model.get("normalize", False) + ) + assert cfg.normalize == model_normalized, ( + "Fine-tuning works best when data normalization is the same. 
" + "Please check that --normalize is set or unset for both pre-training and here" + ) + + if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations: + with open_dict(d2v_args): + d2v_args.model.checkpoint_activations = cfg.checkpoint_activations + + d2v_args.task.data = cfg.data + task = tasks.setup_task(d2v_args.task) + model = task.build_model(d2v_args.model, from_checkpoint=True) + + model.remove_pretraining_modules() + + if state is not None and not cfg.no_pretrained_weights: + self.load_model_weights(state, model, cfg) + + d = d2v_args.model.encoder_embed_dim + + self.d2v_model = model + + self.final_dropout = nn.Dropout(cfg.final_dropout) + self.freeze_finetune_updates = cfg.freeze_finetune_updates + self.num_updates = 0 + + for p in self.parameters(): + p.param_group = "pretrained" + + if cfg.prediction_mode == "proj_avg_proj": + self.proj = nn.Linear(d, d * 2) + self.proj2 = nn.Linear(d * 2, num_classes) + + for p in self.proj.parameters(): + p.param_group = "projection" + for p in self.proj2.parameters(): + p.param_group = "projection" + elif self.cfg.prediction_mode == "summary_proj": + self.proj = nn.Linear(d // 3, num_classes) + for p in self.proj.parameters(): + p.param_group = "projection" + elif self.cfg.conv_kernel > 1 and not self.cfg.two_convs: + self.proj = nn.Sequential( + TransposeLast(), + nn.Conv1d(d, num_classes, kernel_size=self.cfg.conv_kernel, stride=self.cfg.conv_stride), + TransposeLast(), + ) + for p in self.proj.parameters(): + p.param_group = "projection" + elif self.cfg.conv_kernel > 0 and self.cfg.two_convs: + self.proj = nn.Sequential( + TransposeLast(), + nn.Conv1d(d, d, kernel_size=self.cfg.conv_kernel, stride=self.cfg.conv_stride), + TransposeLast(), + nn.GELU(), + nn.Linear(d, num_classes), + ) + for p in self.proj.parameters(): + p.param_group = "projection" + else: + self.proj = nn.Linear(d, num_classes) + for p in self.proj.parameters(): + p.param_group = "projection" + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, cfg: AudioClassificationConfig, task: FairseqTask): + """Build a new model instance.""" + + assert hasattr(task, "labels"), f"Task {task} must have an attribute 'labels'" + + return cls(cfg, len(task.labels)) + + def load_model_weights(self, state, model, cfg): + if cfg.ddp_backend == "fully_sharded": + from fairseq.distributed import FullyShardedDataParallel + + for name, module in model.named_modules(): + if "encoder.layers" in name and len(name.split(".")) == 3: + # Only for layers, we do a special handling and load the weights one by one + # We dont load all weights together as that wont be memory efficient and may + # cause oom + new_dict = { + k.replace(name + ".", ""): v + for (k, v) in state["model"].items() + if name + "." in k + } + assert isinstance(module, FullyShardedDataParallel) + with module.summon_full_params(): + module.load_state_dict(new_dict, strict=True) + module._reset_lazy_init() + + # Once layers are loaded, filter them out and load everything else. 
+ r = re.compile("encoder.layers.\d.") + filtered_list = list(filter(r.match, state["model"].keys())) + + new_big_dict = { + k: v for (k, v) in state["model"].items() if k not in filtered_list + } + + model.load_state_dict(new_big_dict, strict=False) + else: + if "_ema" in state["model"]: + del state["model"]["_ema"] + model.load_state_dict(state["model"], strict=False) + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + super().set_num_updates(num_updates) + self.num_updates = num_updates + + def compute_gain(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"): + if fs == 16000: + n_fft = 2048 + elif fs == 44100: + n_fft = 4096 + else: + raise Exception("Invalid fs {}".format(fs)) + stride = n_fft // 2 + + def a_weight(fs, n_fft, min_db=-80.0): + freq = np.linspace(0, fs // 2, n_fft // 2 + 1) + freq_sq = np.power(freq, 2) + freq_sq[0] = 1.0 + weight = 2.0 + 20.0 * ( + 2 * np.log10(12194) + + 2 * np.log10(freq_sq) + - np.log10(freq_sq + 12194 ** 2) + - np.log10(freq_sq + 20.6 ** 2) + - 0.5 * np.log10(freq_sq + 107.7 ** 2) + - 0.5 * np.log10(freq_sq + 737.9 ** 2) + ) + weight = np.maximum(weight, min_db) + + return weight + + gain = [] + for i in range(0, len(sound) - n_fft + 1, stride): + if mode == "RMSE": + g = np.mean(sound[i : i + n_fft] ** 2) + elif mode == "A_weighting": + spec = np.fft.rfft(np.hanning(n_fft + 1)[:-1] * sound[i : i + n_fft]) + power_spec = np.abs(spec) ** 2 + a_weighted_spec = power_spec * np.power(10, a_weight(fs, n_fft) / 10) + g = np.sum(a_weighted_spec) + else: + raise Exception("Invalid mode {}".format(mode)) + gain.append(g) + + gain = np.array(gain) + gain = np.maximum(gain, np.power(10, min_db / 10)) + gain_db = 10 * np.log10(gain) + + return gain_db + + # adapted from https://github.com/mil-tokyo/bc_learning_sound/blob/master/utils.py + def compute_gain_torch(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"): + if fs == 16000: + n_fft = 2048 + elif fs == 44100: + n_fft = 4096 + else: + raise Exception("Invalid fs {}".format(fs)) + + if mode == "A_weighting": + if not hasattr(self, f"a_weight"): + self.a_weight = {} + + if fs not in self.a_weight: + + def a_weight(fs, n_fft, min_db=-80.0): + freq = np.linspace(0, fs // 2, n_fft // 2 + 1) + freq_sq = freq ** 2 + freq_sq[0] = 1.0 + weight = 2.0 + 20.0 * ( + 2 * np.log10(12194) + + 2 * np.log10(freq_sq) + - np.log10(freq_sq + 12194 ** 2) + - np.log10(freq_sq + 20.6 ** 2) + - 0.5 * np.log10(freq_sq + 107.7 ** 2) + - 0.5 * np.log10(freq_sq + 737.9 ** 2) + ) + weight = np.maximum(weight, min_db) + + return weight + + self.a_weight[fs] = torch.from_numpy( + np.power(10, a_weight(fs, n_fft, min_db) / 10) + ).to(device=sound.device) + + sound = sound.unfold(-1, n_fft, n_fft // 2) + + if mode == "RMSE": + sound = sound ** 2 + g = sound.mean(-1) + elif mode == "A_weighting": + w = torch.hann_window(n_fft, device=sound.device) * sound + spec = torch.fft.rfft(w) + power_spec = spec.abs() ** 2 + a_weighted_spec = power_spec * self.a_weight[fs] + g = a_weighted_spec.sum(-1) + else: + raise Exception("Invalid mode {}".format(mode)) + + gain = torch.maximum(g, torch.tensor(10 ** (min_db / 10), device=g.device)) + gain_db = 10 * torch.log10(gain) + + return gain_db + + def forward(self, source, padding_mask, label=None, **kwargs): + + if self.cfg.source_mixup >= 0 and self.training and self.cfg.mixup_prob > 0: + with torch.no_grad(): + mixed_source = source + mix_mask = None + if self.cfg.mixup_prob < 1: + mix_mask = ( + torch.empty((source.size(0),), device=source.device) 
+ .bernoulli_(self.cfg.mixup_prob) + .bool() + ) + mixed_source = source[mix_mask] + + r = ( + torch.FloatTensor( + 1 if self.cfg.same_mixup else mixed_source.size(0) + ) + .uniform_(max(1e-6, self.cfg.source_mixup), 1) + .to(dtype=source.dtype, device=source.device) + ) + + mixup_perm = torch.randperm(source.size(0)) + s2 = source[mixup_perm] + + if self.cfg.gain_mode == "none": + p = r.unsqueeze(-1) + if mix_mask is not None: + s2 = s2[mix_mask] + else: + if self.cfg.gain_mode == "naive_rms": + G1 = source.pow(2).mean(dim=-1).sqrt() + else: + G1, _ = self.compute_gain_torch( + source, mode=self.cfg.gain_mode + ).max(-1) + G1 = G1.to(dtype=source.dtype) + + G2 = G1[mixup_perm] + + if mix_mask is not None: + G1 = G1[mix_mask] + G2 = G2[mix_mask] + s2 = s2[mix_mask] + + p = 1 / (1 + 10 ** ((G1 - G2) / 20) * (1 - r) / r) + p = p.unsqueeze(-1) + + mixed = (p * mixed_source) + (1 - p) * s2 + + if mix_mask is None: + source = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2) + else: + source[mix_mask] = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2) + + if label is not None and self.cfg.label_mixup: + r = r.unsqueeze(-1) + if mix_mask is None: + label = label * r + (1 - r) * label[mixup_perm] + else: + label[mix_mask] = ( + label[mix_mask] * r + (1 - r) * label[mixup_perm][mix_mask] + ) + + d2v_args = { + "source": source, + "padding_mask": padding_mask, + "mask": self.apply_mask and self.training, + } + + ft = self.freeze_finetune_updates <= self.num_updates + + with torch.no_grad() if not ft else contextlib.ExitStack(): + res = self.d2v_model.extract_features(**d2v_args) + + x = res["x"] + padding_mask = res["padding_mask"] + if padding_mask is not None: + x[padding_mask] = 0 + + x = self.final_dropout(x) + + if self.training or ( + self.cfg.eval_prediction_mode is None or self.cfg.eval_prediction_mode == "" + ): + prediction_mode = self.cfg.prediction_mode + else: + prediction_mode = self.cfg.eval_prediction_mode + + if prediction_mode == "average_before": + x = x.mean(dim=1) + + if prediction_mode != "summary_mha" and prediction_mode != "summary_proj" and prediction_mode != "cls": + x = self.proj(x) + + logits = True + if prediction_mode == "lin_softmax": + x = F.logsigmoid(x.float()) + x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x, dim=1) + x = x.clamp(max=0) + x = x - torch.log(-(torch.expm1(x))) + elif prediction_mode == "extremized_odds": + x = x.float().sum(dim=1) + x = x * self.cfg.extreme_factor + elif prediction_mode == "average_before": + x = x.float() + elif prediction_mode == "average": + x = x.float().mean(dim=1) + elif prediction_mode == "average_sigmoid": + x = torch.sigmoid(x.float()) + x = x.mean(dim=1) + logits = False + elif prediction_mode == "max": + x, _ = x.float().max(dim=1) + elif prediction_mode == "max_sigmoid": + x = torch.sigmoid(x.float()) + x, _ = x.float().max(dim=1) + logits = False + elif prediction_mode == "proj_avg_proj": + x = x.mean(dim=1) + x = self.proj2(x) + elif prediction_mode == "summary_mha" or prediction_mode == "summary_proj": + x = self.d2v_model.summary( + x, padding_mask, proj=prediction_mode == "summary_proj" + ) + x = x.type_as(source) + x = self.proj(x) + elif prediction_mode == "cls": + x = x[:,0] + x = self.proj(x) + else: + raise Exception(f"unknown prediction mode {prediction_mode}") + + if label is None: + return torch.sigmoid(x) if logits else x + + x = torch.nan_to_num(x) + + if logits: + loss = F.binary_cross_entropy_with_logits( + x, label.float(), reduction="none" + ) + else: + loss = F.binary_cross_entropy(x, label.float(), 
reduction="none") + + result = { + "losses": { + "main": loss, + }, + "sample_size": label.sum(), + } + + if not self.training: + result["_predictions"] = torch.sigmoid(x) if logits else x + result["_targets"] = label + + return result diff --git a/examples/data2vec/models/data2vec2.py b/examples/data2vec/models/data2vec2.py new file mode 100644 index 000000000..0c61b3708 --- /dev/null +++ b/examples/data2vec/models/data2vec2.py @@ -0,0 +1,813 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from dataclasses import dataclass, field +from typing import Optional, Callable +from functools import partial +import numpy as np + +from omegaconf import II + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from fairseq.modules import EMAModule, EMAModuleConfig + +from fairseq.dataclass import FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model + +from examples.data2vec.data.modality import Modality + +from examples.data2vec.models.modalities.base import ( + MaskSeed, + D2vModalityConfig, + ModalitySpecificEncoder, + get_annealed_rate, +) +from examples.data2vec.models.modalities.modules import ( + D2vDecoderConfig, + AltBlock, + Decoder1d, +) + +from examples.data2vec.models.modalities.audio import ( + D2vAudioConfig, + AudioEncoder, +) +from examples.data2vec.models.modalities.images import ( + D2vImageConfig, + ImageEncoder, +) +from examples.data2vec.models.modalities.text import ( + D2vTextConfig, + TextEncoder, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class D2vModalitiesConfig(FairseqDataclass): + audio: D2vAudioConfig = D2vAudioConfig() + image: D2vImageConfig = D2vImageConfig() + text: D2vTextConfig = D2vTextConfig() + + +@dataclass +class Data2VecMultiConfig(FairseqDataclass): + + loss_beta: float = field( + default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"} + ) + loss_scale: Optional[float] = field( + default=None, + metadata={ + "help": "scale the reconstruction loss by this constant. 
if None then scales by 1/sqrt(dim)" + }, + ) + + depth: int = 8 + start_drop_path_rate: float = 0 + end_drop_path_rate: float = 0 + num_heads: int = 12 + norm_eps: float = 1e-6 + norm_affine: bool = True + encoder_dropout: float = 0.1 + post_mlp_drop: float = 0.1 + attention_dropout: float = 0.1 + activation_dropout: float = 0.0 + dropout_input: float = 0.0 + layerdrop: float = 0.0 + embed_dim: int = 768 + mlp_ratio: float = 4 + layer_norm_first: bool = False + + average_top_k_layers: int = field( + default=8, metadata={"help": "how many layers to average"} + ) + + end_of_block_targets: bool = False + + clone_batch: int = 1 + + layer_norm_target_layer: bool = False + batch_norm_target_layer: bool = False + instance_norm_target_layer: bool = False + instance_norm_targets: bool = False + layer_norm_targets: bool = False + + ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"}) + ema_same_dtype: bool = True + log_norms: bool = True + ema_end_decay: float = field( + default=0.9999, metadata={"help": "final ema decay rate"} + ) + + # when to finish annealing ema decay rate + ema_anneal_end_step: int = II("optimization.max_update") + + ema_encoder_only: bool = field( + default=True, + metadata={ + "help": "whether to momentum update only the shared transformer encoder" + }, + ) + + max_update: int = II("optimization.max_update") + + modalities: D2vModalitiesConfig = D2vModalitiesConfig() + + shared_decoder: Optional[D2vDecoderConfig] = None + + min_target_var: float = field( + default=0.1, metadata={"help": "stop training if target var falls below this"} + ) + min_pred_var: float = field( + default=0.01, + metadata={"help": "stop training if prediction var falls below this"}, + ) + + supported_modality: Optional[Modality] = None + mae_init: bool = False + + seed: int = II("common.seed") + + skip_ema: bool = False + + cls_loss: float = 0 + recon_loss: float = 0 + d2v_loss: float = 1 + + decoder_group: bool = False + + +@register_model("data2vec_multi", dataclass=Data2VecMultiConfig) +class Data2VecMultiModel(BaseFairseqModel): + def make_modality_encoder( + self, + cfg: D2vModalityConfig, + embed_dim: int, + make_block: Callable[[float], nn.ModuleList], + norm_layer: Callable[[int], nn.LayerNorm], + layer_norm_first: bool, + alibi_biases, + task, + ) -> ModalitySpecificEncoder: + if cfg.type == Modality.AUDIO: + enc_cls = AudioEncoder + elif cfg.type == Modality.IMAGE: + enc_cls = ImageEncoder + elif cfg.type == Modality.TEXT: + enc_cls = TextEncoder + if hasattr(task, "text_task"): + task = task.text_task + else: + raise Exception(f"unsupported modality {cfg.type}") + + return enc_cls( + cfg, + embed_dim, + make_block, + norm_layer, + layer_norm_first, + alibi_biases, + task, + ) + + def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None): + super().__init__() + self.cfg = cfg + self.modalities = modalities + self.task = task + + make_layer_norm = partial( + nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine + ) + + def make_block(drop_path, dim=None, heads=None): + return AltBlock( + cfg.embed_dim if dim is None else dim, + cfg.num_heads if heads is None else heads, + cfg.mlp_ratio, + qkv_bias=True, + drop=cfg.encoder_dropout, + attn_drop=cfg.attention_dropout, + mlp_drop=cfg.activation_dropout, + post_mlp_drop=cfg.post_mlp_drop, + drop_path=drop_path, + norm_layer=make_layer_norm, + layer_norm_first=cfg.layer_norm_first, + ffn_targets=not cfg.end_of_block_targets, + ) + + self.alibi_biases = {} + self.modality_encoders 
= nn.ModuleDict() + for mod in self.modalities: + mod_cfg = getattr(cfg.modalities, mod.name.lower()) + enc = self.make_modality_encoder( + mod_cfg, + cfg.embed_dim, + make_block, + make_layer_norm, + cfg.layer_norm_first, + self.alibi_biases, + task, + ) + self.modality_encoders[mod.name] = enc + + self.ema = None + + self.average_top_k_layers = cfg.average_top_k_layers + self.loss_beta = cfg.loss_beta + self.loss_scale = cfg.loss_scale + + self.dropout_input = nn.Dropout(cfg.dropout_input) + + dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth) + + self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)]) + + self.norm = None + if cfg.layer_norm_first: + self.norm = make_layer_norm(cfg.embed_dim) + + if self.cfg.mae_init: + self.apply(self._init_weights) + else: + from fairseq.modules.transformer_sentence_encoder import init_bert_params + + self.apply(init_bert_params) + + for mod_enc in self.modality_encoders.values(): + mod_enc.reset_parameters() + + if not skip_ema: + self.ema = self.make_ema_teacher(cfg.ema_decay) + self.shared_decoder = ( + Decoder1d(cfg.shared_decoder, cfg.embed_dim) + if self.cfg.shared_decoder is not None + else None + ) + if self.shared_decoder is not None: + self.shared_decoder.apply(self._init_weights) + + self.recon_proj = None + if cfg.recon_loss > 0: + self.recon_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim) + + for pn, p in self.named_parameters(): + if len(p.shape) == 1 or pn.endswith(".bias") or "alibi_scale" in pn: + p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}} + if cfg.decoder_group and "decoder" in pn: + p.param_group = "decoder" + + self.num_updates = 0 + + def _init_weights(self, m): + + try: + from apex.normalization import FusedLayerNorm + + fn = FusedLayerNorm + except: + fn = nn.LayerNorm + + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, fn): + if m.bias is not None: + nn.init.constant_(m.bias, 0) + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + + @torch.no_grad() + def make_ema_teacher(self, ema_decay): + ema_config = EMAModuleConfig( + ema_decay=ema_decay, + ema_fp32=True, + log_norms=self.cfg.log_norms, + add_missing_params=False, + ) + + model_copy = self.make_target_model() + + return EMAModule( + model_copy, + ema_config, + copy_model=False, + ) + + def make_target_model(self): + logger.info("making target model") + + model_copy = Data2VecMultiModel( + self.cfg, self.modalities, skip_ema=True, task=self.task + ) + + if self.cfg.ema_encoder_only: + model_copy = model_copy.blocks + for p_s, p_t in zip(self.blocks.parameters(), model_copy.parameters()): + p_t.data.copy_(p_s.data) + else: + for p_s, p_t in zip(self.parameters(), model_copy.parameters()): + p_t.data.copy_(p_s.data) + + for mod_enc in model_copy.modality_encoders.values(): + mod_enc.decoder = None + if not mod_enc.modality_cfg.ema_local_encoder: + mod_enc.local_encoder = None + mod_enc.project_features = None + + model_copy.requires_grad_(False) + return model_copy + + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + + if self.ema is not None and ( + (self.num_updates == 0 and num_updates > 1) + or self.num_updates >= num_updates + ): + pass + elif self.training and self.ema is not None: + ema_weight_decay = None + if self.cfg.ema_decay != self.cfg.ema_end_decay: + if num_updates >= self.cfg.ema_anneal_end_step: 
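+ # annealing is finished: hold the EMA decay at its final value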
+ decay = self.cfg.ema_end_decay + else: + decay = get_annealed_rate( + self.cfg.ema_decay, + self.cfg.ema_end_decay, + num_updates, + self.cfg.ema_anneal_end_step, + ) + self.ema.set_decay(decay, weight_decay=ema_weight_decay) + if self.ema.get_decay() < 1: + self.ema.step(self.blocks if self.cfg.ema_encoder_only else self) + + self.num_updates = num_updates + + def state_dict(self, destination=None, prefix="", keep_vars=False): + state = super().state_dict(destination, prefix, keep_vars) + + if self.ema is not None: + state[prefix + "_ema"] = self.ema.fp32_params + + return state + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + k = prefix + "_ema" + if self.ema is not None: + assert k in state_dict + self.ema.restore(state_dict[k], True) + del state_dict[k] + elif k in state_dict: + del state_dict[k] + + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + @classmethod + def build_model(cls, cfg: Data2VecMultiConfig, task=None): + """Build a new model instance.""" + if task is None or not hasattr(task, "supported_modalities"): + modalities = ( + [cfg.supported_modality] + if cfg.supported_modality is not None + else [ + Modality.AUDIO, + Modality.IMAGE, + Modality.TEXT, + ] + ) + else: + modalities = task.supported_modalities + return cls(cfg, modalities, task=task, skip_ema=cfg.skip_ema) + + def forward( + self, + source, + target=None, + id=None, + mode=None, + padding_mask=None, + mask=True, + features_only=False, + force_remove_masked=False, + remove_extra_tokens=True, + precomputed_mask=None, + ): + if mode is None: + assert self.cfg.supported_modality is not None + mode = self.cfg.supported_modality + + if isinstance(mode, Modality): + mode = mode.name + + feature_extractor = self.modality_encoders[mode] + + mask_seeds = None + if id is not None: + mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id) + + extractor_out = feature_extractor( + source, + padding_mask, + mask, + remove_masked=not features_only or force_remove_masked, + clone_batch=self.cfg.clone_batch if not features_only else 1, + mask_seeds=mask_seeds, + precomputed_mask=precomputed_mask, + ) + + x = extractor_out["x"] + encoder_mask = extractor_out["encoder_mask"] + masked_padding_mask = extractor_out["padding_mask"] + masked_alibi_bias = extractor_out.get("alibi_bias", None) + alibi_scale = extractor_out.get("alibi_scale", None) + + if self.dropout_input is not None: + x = self.dropout_input(x) + + layer_results = [] + for i, blk in enumerate(self.blocks): + if ( + not self.training + or self.cfg.layerdrop == 0 + or (np.random.random() > self.cfg.layerdrop) + ): + ab = masked_alibi_bias + if ab is not None and alibi_scale is not None: + scale = ( + alibi_scale[i] + if alibi_scale.size(0) > 1 + else alibi_scale.squeeze(0) + ) + ab = ab * scale.type_as(ab) + + x, lr = blk( + x, + padding_mask=masked_padding_mask, + alibi_bias=ab, + ) + if features_only: + layer_results.append(lr) + + if self.norm is not None: + x = self.norm(x) + + if features_only: + if remove_extra_tokens: + x = x[:, feature_extractor.modality_cfg.num_extra_tokens :] + if masked_padding_mask is not None: + masked_padding_mask = masked_padding_mask[ + :, feature_extractor.modality_cfg.num_extra_tokens : + ] + + return { + "x": x, + "padding_mask": masked_padding_mask, + "layer_results": layer_results, + "mask": encoder_mask, + } + + xs = [] + + if self.shared_decoder is not None: + dx = self.forward_decoder( + x, + feature_extractor, + self.shared_decoder, + encoder_mask, + ) + 
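+ # collect the shared decoder's predictions for the masked positions; any
+ # modality-specific decoder output is appended next, and every entry in xs
+ # is later regressed against the EMA teacher targets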
xs.append(dx) + if feature_extractor.decoder is not None: + dx = self.forward_decoder( + x, + feature_extractor, + feature_extractor.decoder, + encoder_mask, + ) + xs.append(dx) + orig_x = x + + assert len(xs) > 0 + + p = next(self.ema.model.parameters()) + device = x.device + dtype = x.dtype + ema_device = p.device + ema_dtype = p.dtype + + if not self.cfg.ema_same_dtype: + dtype = ema_dtype + + if ema_device != device or ema_dtype != dtype: + logger.info(f"adjusting ema dtype to {dtype} and device to {device}") + self.ema.model = self.ema.model.to(dtype=dtype, device=device) + ema_dtype = dtype + + def to_device(d): + for k, p in d.items(): + if isinstance(d[k], dict): + to_device(d[k]) + else: + d[k] = p.to(device=device) + + to_device(self.ema.fp32_params) + tm = self.ema.model + + with torch.no_grad(): + tm.eval() + + if self.cfg.ema_encoder_only: + assert target is None + ema_input = extractor_out["local_features"] + ema_input = feature_extractor.contextualized_features( + ema_input.to(dtype=ema_dtype), + padding_mask, + mask=False, + remove_masked=False, + ) + ema_blocks = tm + else: + ema_blocks = tm.blocks + if feature_extractor.modality_cfg.ema_local_encoder: + inp = ( + target.to(dtype=ema_dtype) + if target is not None + else source.to(dtype=ema_dtype) + ) + ema_input = tm.modality_encoders[mode]( + inp, + padding_mask, + mask=False, + remove_masked=False, + ) + else: + assert target is None + ema_input = extractor_out["local_features"] + ema_feature_enc = tm.modality_encoders[mode] + ema_input = ema_feature_enc.contextualized_features( + ema_input.to(dtype=ema_dtype), + padding_mask, + mask=False, + remove_masked=False, + ) + + ema_padding_mask = ema_input["padding_mask"] + ema_alibi_bias = ema_input.get("alibi_bias", None) + ema_alibi_scale = ema_input.get("alibi_scale", None) + ema_input = ema_input["x"] + + y = [] + ema_x = [] + extra_tokens = feature_extractor.modality_cfg.num_extra_tokens + for i, blk in enumerate(ema_blocks): + ab = ema_alibi_bias + if ab is not None and alibi_scale is not None: + scale = ( + ema_alibi_scale[i] + if ema_alibi_scale.size(0) > 1 + else ema_alibi_scale.squeeze(0) + ) + ab = ab * scale.type_as(ab) + + ema_input, lr = blk( + ema_input, + padding_mask=ema_padding_mask, + alibi_bias=ab, + ) + y.append(lr[:, extra_tokens:]) + ema_x.append(ema_input[:, extra_tokens:]) + + y = self.make_targets(y, self.average_top_k_layers) + orig_targets = y + + if self.cfg.clone_batch > 1: + y = y.repeat_interleave(self.cfg.clone_batch, 0) + + masked = encoder_mask.mask.unsqueeze(-1) + masked_b = encoder_mask.mask.bool() + y = y[masked_b] + + if xs[0].size(1) == masked_b.size(1): + xs = [x[masked_b] for x in xs] + else: + xs = [x.reshape(-1, x.size(-1)) for x in xs] + + sample_size = masked.sum().long() + + result = { + "losses": {}, + "sample_size": sample_size, + } + + sample_size = result["sample_size"] + + if self.cfg.cls_loss > 0: + assert extra_tokens > 0 + cls_target = orig_targets.mean(dim=1) + if self.cfg.clone_batch > 1: + cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0) + cls_pred = x[:, extra_tokens - 1] + result["losses"]["cls"] = self.d2v_loss(cls_pred, cls_target) * ( + self.cfg.cls_loss * sample_size + ) + + if self.cfg.recon_loss > 0: + + with torch.no_grad(): + target = feature_extractor.patchify(source) + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.0e-6) ** 0.5 + + if self.cfg.clone_batch > 1: + target = target.repeat_interleave(self.cfg.clone_batch, 
0) + + if masked_b is not None: + target = target[masked_b] + + recon = xs[0] + if self.recon_proj is not None: + recon = self.recon_proj(recon) + + result["losses"]["recon"] = ( + self.d2v_loss(recon, target.float()) * self.cfg.recon_loss + ) + + if self.cfg.d2v_loss > 0: + for i, x in enumerate(xs): + reg_loss = self.d2v_loss(x, y) + n = f"{mode}_regression_{i}" if len(xs) > 1 else f"{mode}_regression" + result["losses"][n] = reg_loss * self.cfg.d2v_loss + + suffix = "" if len(self.modalities) == 1 else f"_{mode}" + with torch.no_grad(): + if encoder_mask is not None: + result["masked_pct"] = 1 - ( + encoder_mask.ids_keep.size(1) / encoder_mask.ids_restore.size(1) + ) + for i, x in enumerate(xs): + n = f"pred_var{suffix}_{i}" if len(xs) > 1 else f"pred_var{suffix}" + result[n] = self.compute_var(x.float()) + if self.ema is not None: + for k, v in self.ema.logs.items(): + result[k] = v + + y = y.float() + result[f"target_var{suffix}"] = self.compute_var(y) + + if self.num_updates > 5000: + if result[f"target_var{suffix}"] < self.cfg.min_target_var: + logger.error( + f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})" + ) + raise Exception( + f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})" + ) + + for k in result.keys(): + if k.startswith("pred_var") and result[k] < self.cfg.min_pred_var: + logger.error( + f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})" + ) + raise Exception( + f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})" + ) + + result["ema_decay"] = self.ema.get_decay() * 1000 + + return result + + def forward_decoder( + self, + x, + feature_extractor, + decoder, + mask_info, + ): + x = feature_extractor.decoder_input(x, mask_info) + x = decoder(*x) + + return x + + def d2v_loss(self, x, y): + x = x.view(-1, x.size(-1)).float() + y = y.view(-1, x.size(-1)) + + if self.loss_beta == 0: + loss = F.mse_loss(x, y, reduction="none") + else: + loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta) + + if self.loss_scale is not None: + scale = self.loss_scale + else: + scale = 1 / math.sqrt(x.size(-1)) + + reg_loss = loss * scale + + return reg_loss + + def make_targets(self, y, num_layers): + + with torch.no_grad(): + target_layer_results = y[-num_layers:] + + permuted = False + if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer: + target_layer_results = [ + tl.transpose(1, 2) for tl in target_layer_results # BTC -> BCT + ] + permuted = True + if self.cfg.batch_norm_target_layer: + target_layer_results = [ + F.batch_norm( + tl.float(), running_mean=None, running_var=None, training=True + ) + for tl in target_layer_results + ] + if self.cfg.instance_norm_target_layer: + target_layer_results = [ + F.instance_norm(tl.float()) for tl in target_layer_results + ] + if permuted: + target_layer_results = [ + tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC + ] + if self.cfg.layer_norm_target_layer: + target_layer_results = [ + F.layer_norm(tl.float(), tl.shape[-1:]) + for tl in target_layer_results + ] + + y = target_layer_results[0].float() + for tl in target_layer_results[1:]: + y.add_(tl.float()) + y = y.div_(len(target_layer_results)) + + if self.cfg.layer_norm_targets: + y = F.layer_norm(y, y.shape[-1:]) + + if self.cfg.instance_norm_targets: + y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2) + + return y + + @staticmethod + def compute_var(y): + y = y.view(-1, y.size(-1)) + if 
dist.is_initialized(): + zc = torch.tensor(y.size(0)).cuda() + zs = y.sum(dim=0) + zss = (y**2).sum(dim=0) + + dist.all_reduce(zc) + dist.all_reduce(zs) + dist.all_reduce(zss) + + var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1)) + return torch.sqrt(var + 1e-6).mean() + else: + return torch.sqrt(y.var(dim=0) + 1e-6).mean() + + def extract_features( + self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True + ): + res = self.forward( + source, + mode=mode, + padding_mask=padding_mask, + mask=mask, + features_only=True, + remove_extra_tokens=remove_extra_tokens, + ) + return res + + def remove_pretraining_modules(self, modality=None, keep_decoder=False): + self.ema = None + self.cfg.clone_batch = 1 + self.recon_proj = None + + if not keep_decoder: + self.shared_decoder = None + + modality = modality.lower() if modality is not None else None + for k in list(self.modality_encoders.keys()): + if modality is not None and k.lower() != modality: + del self.modality_encoders[k] + else: + self.modality_encoders[k].remove_pretraining_modules( + keep_decoder=keep_decoder + ) + if not keep_decoder: + self.modality_encoders[k].decoder = None diff --git a/examples/data2vec/models/data2vec_image_classification.py b/examples/data2vec/models/data2vec_image_classification.py new file mode 100644 index 000000000..851c9ce45 --- /dev/null +++ b/examples/data2vec/models/data2vec_image_classification.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# The code in this file is adapted from the BeiT implementation which can be found here: +# https://github.com/microsoft/unilm/tree/master/beit + +import logging + +from dataclasses import dataclass +from typing import Any + +from omegaconf import II, MISSING + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import checkpoint_utils, tasks + +from fairseq.dataclass import FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model + + +logger = logging.getLogger(__name__) + + +@dataclass +class Data2VecImageClassificationConfig(FairseqDataclass): + model_path: str = MISSING + no_pretrained_weights: bool = False + num_classes: int = 1000 + mixup: float = 0.8 + cutmix: float = 1.0 + label_smoothing: float = 0.1 + + pretrained_model_args: Any = None + data: str = II("task.data") + + +@register_model( + "data2vec_image_classification", dataclass=Data2VecImageClassificationConfig +) +class Data2VecImageClassificationModel(BaseFairseqModel): + def __init__(self, cfg: Data2VecImageClassificationConfig): + super().__init__() + self.cfg = cfg + + if cfg.pretrained_model_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {}) + pretrained_args = state.get("cfg", None) + pretrained_args.criterion = None + pretrained_args.lr_scheduler = None + cfg.pretrained_model_args = pretrained_args + + logger.info(pretrained_args) + else: + state = None + pretrained_args = cfg.pretrained_model_args + + pretrained_args.task.data = cfg.data + task = tasks.setup_task(pretrained_args.task) + model = task.build_model(pretrained_args.model, from_checkpoint=True) + + model.remove_pretraining_modules() + + self.model = model + + if state is not None and not cfg.no_pretrained_weights: + self.load_model_weights(state, model, cfg) + + self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim) + self.head = 
nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes) + + self.head.weight.data.mul_(1e-3) + self.head.bias.data.mul_(1e-3) + + self.mixup_fn = None + + if cfg.mixup > 0 or cfg.cutmix > 0: + from timm.data import Mixup + + self.mixup_fn = Mixup( + mixup_alpha=cfg.mixup, + cutmix_alpha=cfg.cutmix, + cutmix_minmax=None, + prob=1.0, + switch_prob=0.5, + mode="batch", + label_smoothing=cfg.label_smoothing, + num_classes=cfg.num_classes, + ) + + def load_model_weights(self, state, model, cfg): + if "_ema" in state["model"]: + del state["model"]["_ema"] + model.load_state_dict(state["model"], strict=True) + + @classmethod + def build_model(cls, cfg: Data2VecImageClassificationConfig, task=None): + """Build a new model instance.""" + + return cls(cfg) + + def forward( + self, + img, + label=None, + ): + if self.training and self.mixup_fn is not None and label is not None: + img, label = self.mixup_fn(img, label) + + x = self.model(img, mask=False) + x = x[:, 1:] + x = self.fc_norm(x.mean(1)) + x = self.head(x) + + if label is None: + return x + + if self.training and self.mixup_fn is not None: + loss = -label * F.log_softmax(x.float(), dim=-1) + else: + loss = F.cross_entropy( + x.float(), + label, + label_smoothing=self.cfg.label_smoothing if self.training else 0, + reduction="none", + ) + + result = { + "losses": {"regression": loss}, + "sample_size": img.size(0), + } + + if not self.training: + with torch.no_grad(): + pred = x.argmax(-1) + correct = (pred == label).sum() + result["correct"] = correct + + return result diff --git a/examples/data2vec/models/data2vec_text_classification.py b/examples/data2vec/models/data2vec_text_classification.py new file mode 100644 index 000000000..e787b916d --- /dev/null +++ b/examples/data2vec/models/data2vec_text_classification.py @@ -0,0 +1,141 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
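+
+# Fine-tuning wrapper for data2vec 2.0 text models: loads a pretrained
+# multimodal checkpoint and attaches RoBERTa-style classification heads
+# for sentence-level tasks such as GLUE.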
+ +# The code in this file is adapted from the BeiT implementation which can be found here: +# https://github.com/microsoft/unilm/tree/master/beit + +import logging + +from dataclasses import dataclass +from typing import Any + +from omegaconf import II, MISSING + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import checkpoint_utils, tasks + +from fairseq.dataclass import FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from fairseq.models.roberta.model import RobertaClassificationHead + +from examples.data2vec.data.modality import Modality + + +logger = logging.getLogger(__name__) + + +@dataclass +class Data2VecTextClassificationConfig(FairseqDataclass): + pooler_dropout: float = 0.0 + pooler_activation_fn: str = "tanh" + quant_noise_pq: int = 0 + quant_noise_pq_block_size: int = 8 + spectral_norm_classification_head: bool = False + + model_path: str = MISSING + no_pretrained_weights: bool = False + + pretrained_model_args: Any = None + + +@register_model( + "data2vec_text_classification", dataclass=Data2VecTextClassificationConfig +) +class Data2VecTextClassificationModel(BaseFairseqModel): + def __init__(self, cfg: Data2VecTextClassificationConfig): + super().__init__() + self.cfg = cfg + + if cfg.pretrained_model_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {}) + pretrained_args = state.get("cfg", None) + pretrained_args.criterion = None + pretrained_args.lr_scheduler = None + cfg.pretrained_model_args = pretrained_args + + logger.info(pretrained_args) + else: + state = None + pretrained_args = cfg.pretrained_model_args + + task = tasks.setup_task(pretrained_args.task) + model = task.build_model(pretrained_args.model, from_checkpoint=True) + + model.remove_pretraining_modules() + + self.model = model + + if state is not None and not cfg.no_pretrained_weights: + self.load_model_weights(state, model, cfg) + + self.classification_heads = nn.ModuleDict() + + + def load_model_weights(self, state, model, cfg): + for k in list(state["model"].keys()): + if ( + k.startswith("shared_decoder") or + k.startswith("_ema") or + "decoder" in k + ): + logger.info(f"Deleting {k} from checkpoint") + del state["model"][k] + model.load_state_dict(state["model"], strict=True) + + @classmethod + def build_model(cls, cfg: Data2VecTextClassificationConfig, task=None): + """Build a new model instance.""" + + return cls(cfg) + + def register_classification_head( + self, name, num_classes=None, inner_dim=None, **kwargs + ): + """Register a classification head.""" + if name in self.classification_heads: + prev_num_classes = self.classification_heads[name].out_proj.out_features + prev_inner_dim = self.classification_heads[name].dense.out_features + if num_classes != prev_num_classes or inner_dim != prev_inner_dim: + logger.warning( + 're-registering head "{}" with num_classes {} (prev: {}) ' + "and inner_dim {} (prev: {})".format( + name, num_classes, prev_num_classes, inner_dim, prev_inner_dim + ) + ) + embed_dim = self.cfg.pretrained_model_args.model.embed_dim + self.classification_heads[name] = RobertaClassificationHead( + input_dim=embed_dim, + inner_dim=inner_dim or embed_dim, + num_classes=num_classes, + activation_fn=self.cfg.pooler_activation_fn, + pooler_dropout=self.cfg.pooler_dropout, + q_noise=self.cfg.quant_noise_pq, + qn_block_size=self.cfg.quant_noise_pq_block_size, + do_spectral_norm=self.cfg.spectral_norm_classification_head, + ) + + def forward( + self, + source, + id, + padding_mask, + 
features_only=True, + remove_extra_tokens=True, + classification_head_name=None, + ): + encoder_out = self.model( + source, + id=id, + mode=Modality.TEXT, + padding_mask=padding_mask, + mask=False, + features_only=features_only, + remove_extra_tokens=remove_extra_tokens + ) + logits = self.classification_heads[classification_head_name](encoder_out["x"]) + return logits, encoder_out diff --git a/examples/data2vec/models/data2vec_vision.py b/examples/data2vec/models/data2vec_vision.py new file mode 100644 index 000000000..2f8989442 --- /dev/null +++ b/examples/data2vec/models/data2vec_vision.py @@ -0,0 +1,727 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# The code in this file is adapted from the BeiT implementation which can be found here: +# https://github.com/microsoft/unilm/tree/master/beit + +import logging +import math +import numpy as np +import random + +from dataclasses import dataclass, field +from typing import Optional + +from omegaconf import II + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from fairseq.modules import EMAModule, EMAModuleConfig +from fairseq.dataclass import FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model + + +logger = logging.getLogger(__name__) + + +@dataclass +class Data2VecVisionConfig(FairseqDataclass): + layer_scale_init_value: float = field( + default=1e-4, metadata={"help": "rescale layer outputs, 0 to disable"} + ) + num_mask_patches: int = field( + default=75, + metadata={"help": "number of the visual tokens/patches need be masked"}, + ) + min_mask_patches_per_block: int = 16 + max_mask_patches_per_block: int = 196 + image_size: int = 224 + patch_size: int = 16 + in_channels: int = 3 + + shared_rel_pos_bias: bool = True + + drop_path: float = 0.1 + attention_dropout: float = 0.0 + + depth: int = 12 + embed_dim: int = 768 + num_heads: int = 12 + mlp_ratio: int = 4 + + loss_beta: float = field( + default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"} + ) + loss_scale: Optional[float] = field( + default=None, + metadata={ + "help": "scale the reconstruction loss by this constant. 
if None then scales by 1/sqrt(dim)" + }, + ) + average_top_k_layers: int = field( + default=8, metadata={"help": "how many layers to average"} + ) + + end_of_block_targets: bool = True + layer_norm_target_layer: bool = False + instance_norm_target_layer: bool = False + batch_norm_target_layer: bool = False + instance_norm_targets: bool = False + layer_norm_targets: bool = False + + ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"}) + ema_end_decay: float = field( + default=0.9999, metadata={"help": "final ema decay rate"} + ) + + # when to finish annealing ema decay rate + ema_anneal_end_step: int = II("optimization.max_update") + + ema_transformer_only: bool = field( + default=True, + metadata={"help": "whether to momentum update only the transformer layers"}, + ) + + +def get_annealed_rate(start, end, curr_step, total_steps): + r = end - start + pct_remaining = 1 - curr_step / total_steps + return end - r * pct_remaining + + +@register_model("data2vec_vision", dataclass=Data2VecVisionConfig) +class Data2VecVisionModel(BaseFairseqModel): + def __init__(self, cfg: Data2VecVisionConfig): + super().__init__() + self.cfg = cfg + + self.ema = None + + self.average_top_k_layers = cfg.average_top_k_layers + self.loss_beta = cfg.loss_beta + self.loss_scale = ( + cfg.loss_scale + if cfg.loss_scale is not None + else 1 / math.sqrt(cfg.embed_dim) + ) + + self.patch_embed = PatchEmbed( + img_size=cfg.image_size, + patch_size=cfg.patch_size, + in_chans=cfg.in_channels, + embed_dim=cfg.embed_dim, + ) + + patch_size = self.patch_embed.patch_size + self.window_size = ( + cfg.image_size // patch_size[0], + cfg.image_size // patch_size[1], + ) + + self.cls_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim)) + self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim)) + + nn.init.trunc_normal_(self.cls_emb, 0.02) + nn.init.trunc_normal_(self.mask_emb, 0.02) + + self.encoder = TransformerEncoder(cfg, self.patch_embed.patch_shape) + + self.final_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim) + self.num_updates = 0 + + def make_ema_teacher(self): + ema_config = EMAModuleConfig( + ema_decay=self.cfg.ema_decay, + ema_fp32=True, + ) + self.ema = EMAModule( + self.encoder if self.cfg.ema_transformer_only else self, + ema_config, + ) + + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + + if self.ema is None and self.final_proj is not None: + logger.info(f"making ema teacher") + self.make_ema_teacher() + elif self.training and self.ema is not None: + if self.cfg.ema_decay != self.cfg.ema_end_decay: + if num_updates >= self.cfg.ema_anneal_end_step: + decay = self.cfg.ema_end_decay + else: + decay = get_annealed_rate( + self.cfg.ema_decay, + self.cfg.ema_end_decay, + num_updates, + self.cfg.ema_anneal_end_step, + ) + self.ema.set_decay(decay) + if self.ema.get_decay() < 1: + self.ema.step(self.encoder if self.cfg.ema_transformer_only else self) + + self.num_updates = num_updates + + def state_dict(self, destination=None, prefix="", keep_vars=False): + state = super().state_dict(destination, prefix, keep_vars) + + if self.ema is not None: + state[prefix + "_ema"] = self.ema.fp32_params + + return state + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + if self.ema is not None: + k = prefix + "_ema" + assert k in state_dict + self.ema.restore(state_dict[k], True) + del state_dict[k] + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + @classmethod + def build_model(cls, cfg: 
Data2VecVisionConfig, task=None): + """Build a new model instance.""" + + return cls(cfg) + + def make_mask(self, bsz, num_masks, min_masks, max_masks): + height, width = self.window_size + + masks = np.zeros(shape=(bsz, height, width), dtype=np.int) + + for i in range(bsz): + mask = masks[i] + mask_count = 0 + + min_aspect = 0.3 + max_aspect = 1 / min_aspect + log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + + def _mask(mask, max_mask_patches): + delta = 0 + for attempt in range(10): + target_area = random.uniform(min_masks, max_mask_patches) + aspect_ratio = math.exp(random.uniform(*log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < width and h < height: + top = random.randint(0, height - h) + left = random.randint(0, width - w) + + num_masked = mask[top : top + h, left : left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + for i in range(top, top + h): + for j in range(left, left + w): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + + if delta > 0: + break + return delta + + while mask_count < num_masks: + max_mask_patches = min(num_masks - mask_count, max_masks) + + delta = _mask(mask, max_mask_patches) + if delta == 0: + break + else: + mask_count += delta + + return torch.from_numpy(masks) + + def forward( + self, + img, + mask: bool = True, + layer_results: bool = False, + ): + x = self.patch_embed(img) + batch_size, seq_len, _ = x.size() + + if mask: + mask_indices = self.make_mask( + img.size(0), + self.cfg.num_mask_patches, + self.cfg.min_mask_patches_per_block, + self.cfg.max_mask_patches_per_block, + ) + bool_mask = mask_indices.view(mask_indices.size(0), -1).bool() + else: + mask_indices = bool_mask = None + + cls_tokens = self.cls_emb.expand(batch_size, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + if self.ema is not None: + with torch.no_grad(): + self.ema.model.eval() + + if self.cfg.ema_transformer_only: + y = self.ema.model( + x, + layer_results="end" if self.cfg.end_of_block_targets else "fc", + ) + else: + y = self.ema.model( + img, + mask=False, + layer_results=True, + ) + + y = y[-self.cfg.average_top_k_layers :] + + permuted = False + if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer: + y = [tl.transpose(1, 2) for tl in y] # BTC -> BCT + permuted = True + + if self.cfg.batch_norm_target_layer: + y = [ + F.batch_norm( + tl.float(), running_mean=None, running_var=None, training=True + ) + for tl in y + ] + + if self.cfg.instance_norm_target_layer: + y = [F.instance_norm(tl.float()) for tl in y] + + if permuted: + y = [tl.transpose(1, 2) for tl in y] # BCT -> BTC + + if self.cfg.layer_norm_target_layer: + y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y] + + y = sum(y) / len(y) + + if self.cfg.layer_norm_targets: + y = F.layer_norm(y.float(), y.shape[-1:]) + + if self.cfg.instance_norm_targets: + y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2) + + y = y[bool_mask].float() + + if mask_indices is not None: + mask_token = self.mask_emb.expand(batch_size, seq_len, -1) + w = mask_indices.view(mask_indices.size(0), -1, 1).type_as(mask_token) + x[:, 1:] = x[:, 1:] * (1 - w) + mask_token * w + + if layer_results: + enc_layer_results = "end" if self.cfg.end_of_block_targets else "fc" + else: + enc_layer_results = None + + x = self.encoder(x, layer_results=enc_layer_results) + if layer_results or mask_indices is None: + return x + + x = x[bool_mask].float() + + if self.loss_beta == 0: + loss 
= F.mse_loss(x, y, reduction="none").sum(dim=-1) + else: + loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta).sum( + dim=-1 + ) + + if self.loss_scale > 0: + loss = loss * self.loss_scale + + result = { + "losses": {"regression": loss.sum()}, + "sample_size": loss.numel(), + "target_var": self.compute_var(y), + "pred_var": self.compute_var(x), + "ema_decay": self.ema.get_decay() * 1000, + } + return result + + @staticmethod + def compute_var(y): + y = y.view(-1, y.size(-1)) + if dist.is_initialized(): + zc = torch.tensor(y.size(0)).cuda() + zs = y.sum(dim=0) + zss = (y ** 2).sum(dim=0) + + dist.all_reduce(zc) + dist.all_reduce(zs) + dist.all_reduce(zss) + + var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1)) + return torch.sqrt(var + 1e-6).mean() + else: + return torch.sqrt(y.var(dim=0) + 1e-6).mean() + + def remove_pretraining_modules(self, last_layer=None): + self.final_proj = None + self.ema = None + self.encoder.norm = nn.Identity() + self.mask_emb = None + if last_layer is not None: + self.encoder.layers = nn.ModuleList( + l for i, l in enumerate(self.encoder.layers) if i <= last_layer + ) + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + if isinstance(img_size, int): + img_size = img_size, img_size + if isinstance(patch_size, int): + patch_size = patch_size, patch_size + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.conv = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, x): + # BCHW -> BTC + x = self.conv(x).flatten(2).transpose(1, 2) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + attn_drop=0.0, + proj_drop=0.0, + window_size=None, + attn_head_dim=None, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * ( + 2 * window_size[1] - 1 + ) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros( + size=(window_size[0] * 
window_size[1] + 1,) * 2, + dtype=relative_coords.dtype, + ) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat( + ( + self.q_bias, + torch.zeros_like(self.v_bias, requires_grad=False), + self.v_bias, + ) + ) + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if self.relative_position_bias_table is not None: + assert 1==2 + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + print("attn.size() :", attn.size()) + print("rel_pos_bias.size() :", rel_pos_bias.size()) + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class RelativePositionBias(nn.Module): + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * ( + 2 * window_size[1] - 1 + ) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype + ) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + def 
forward(self): + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, + -1, + ) # Wh*Ww,Wh*Ww,nH + print("self.window_size :", self.window_size) + print("self.num_relative_distance :", self.num_relative_distance) + print("self.relative_position_index :", self.relative_position_index.size(), self.relative_position_index) + print("relative_position_bias.size(), relative_position_bias :",relative_position_bias.size(), relative_position_bias) + print("self.relative_position_bias_table.size(), self.relative_position_bias_table :",self.relative_position_bias_table.size(), self.relative_position_bias_table) + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() + output = x.div(keep_prob) * random_tensor + return output + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + init_values=None, + window_size=None, + ): + super().__init__() + + self.norm1 = nn.LayerNorm(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = nn.LayerNorm(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + self.mlp = nn.Sequential( + nn.Linear(dim, mlp_hidden_dim), + nn.GELU(), + nn.Linear(mlp_hidden_dim, dim), + nn.Dropout(drop), + ) + + if init_values > 0: + self.gamma_1 = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True + ) + self.gamma_2 = nn.Parameter( + init_values * torch.ones((dim)), requires_grad=True + ) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, rel_pos_bias=None): + print("inside block :", x.size()) + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + fc_feature = self.drop_path(self.mlp(self.norm2(x))) + x = x + fc_feature + else: + x = x + self.drop_path( + self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias) + ) + fc_feature = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + x = x + fc_feature + return x, fc_feature + + +class TransformerEncoder(nn.Module): + def __init__(self, cfg: Data2VecVisionConfig, patch_shape): + super().__init__() + + self.rel_pos_bias = None + if cfg.shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias( + window_size=patch_shape, num_heads=cfg.num_heads + ) + + dpr = [ + x.item() for x in torch.linspace(0, cfg.drop_path, cfg.depth) + ] # stochastic depth decay rule + + print("TransformerEncoder > patch_shape :", patch_shape) + self.blocks = nn.ModuleList( + Block( + dim=cfg.embed_dim, + num_heads=cfg.num_heads, + attn_drop=cfg.attention_dropout, + drop_path=dpr[i], + 
init_values=cfg.layer_scale_init_value, + window_size=patch_shape if not cfg.shared_rel_pos_bias else None, + ) + for i in range(cfg.depth) + ) + + self.norm = nn.LayerNorm(cfg.embed_dim) + + self.apply(self.init_weights) + self.fix_init_weight() + + def init_weights(self, m): + std = 0.02 + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + nn.init.trunc_normal_(m.weight, std=std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp[2].weight.data, layer_id + 1) + + def extract_features(self, x, layer_results): + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + + z = [] + for i, blk in enumerate(self.blocks): + x, fc_feature = blk(x, rel_pos_bias=rel_pos_bias) + if layer_results == "end": + z.append(x) + elif layer_results == "fc": + z.append(fc_feature) + + return z if layer_results else self.norm(x) + + def forward(self, x, layer_results=None): + x = self.extract_features(x, layer_results=layer_results) + if layer_results: + return [z[:, 1:] for z in x] + + x = x[:, 1:] + return x diff --git a/examples/data2vec/models/mae.py b/examples/data2vec/models/mae.py new file mode 100644 index 000000000..5101e070e --- /dev/null +++ b/examples/data2vec/models/mae.py @@ -0,0 +1,825 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
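+
+# Masked autoencoder (MAE) image pretraining model: a ViT encoder paired with
+# a lightweight decoder that reconstructs (optionally pixel-normalized) patches
+# of a randomly masked input image; see MaeConfig below for the hyperparameters.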
+ +# The code in this file is adapted from the BeiT implementation which can be found here: +# https://github.com/microsoft/unilm/tree/master/beit + +import logging +from dataclasses import dataclass +from functools import partial + +from timm.models.vision_transformer import PatchEmbed, Block + +import torch +import torch.nn as nn + +import numpy as np + +from fairseq.dataclass import FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer + +from apex.normalization import FusedLayerNorm +import torch.nn.functional as F + + +logger = logging.getLogger(__name__) + + +@dataclass +class MaeConfig(FairseqDataclass): + input_size: int = 224 + in_chans: int = 3 + patch_size: int = 16 + embed_dim: int = 768 + depth: int = 12 + num_heads: int = 12 + decoder_embed_dim: int = 512 + decoder_depth: int = 8 + decoder_num_heads: int = 16 + mlp_ratio: int = 4 + norm_eps: float = 1e-6 + + drop_path_rate: float = 0.0 + + mask_ratio: float = 0.75 + norm_pix_loss: bool = True + + w2v_block: bool = False + alt_block: bool = False + alt_block2: bool = False + alt_attention: bool = False + block_dropout: float = 0 + attention_dropout: float = 0 + activation_dropout: float = 0 + layer_norm_first: bool = False + + fused_ln: bool = True + end_of_block_targets: bool = True + + no_decoder_embed: bool = False + no_decoder_pos_embed: bool = False + mask_noise_std: float = 0 + + single_qkv: bool = False + use_rel_pos_bias: bool = False + no_cls: bool = False + + +def modify_relative_position_bias(orig_bias, bsz, mask): + if mask is None: + return orig_bias.unsqueeze(0).repeat( + bsz, 1, 1, 1 + ) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len + heads, max_seq_len, max_seq_len = orig_bias.shape # includes CLS token + mask_for_rel_pos_bias = torch.cat( + (torch.zeros(bsz, 1, dtype=mask.dtype, device=mask.device), mask), dim=1 + ).bool() # bsz x seqlen (add CLS token) + unmasked_for_rel_pos_bias = ~mask_for_rel_pos_bias + unmasked_for_rel_pos_bias = unmasked_for_rel_pos_bias.unsqueeze(1).repeat( + 1, heads, 1 + ) # bsz x seq_len => bsz x heads x seq_len + b_t_t_rel_pos_bias = orig_bias.unsqueeze(0).repeat( + bsz, 1, 1, 1 + ) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len + b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select( + unmasked_for_rel_pos_bias.unsqueeze(-1) + ) + b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, -1, max_seq_len) + new_len = b_t_t_rel_pos_bias.size(-2) + b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select( + unmasked_for_rel_pos_bias.unsqueeze(-2) + ) + b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, new_len, new_len) + return b_t_t_rel_pos_bias + + +class AltBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + layer_norm_first=True, + ffn_targets=False, + use_rel_pos_bias=False, + window_size=None, + alt_attention=False, + ): + super().__init__() + + self.layer_norm_first = layer_norm_first + self.ffn_targets = ffn_targets + + from timm.models.vision_transformer import Attention, DropPath, Mlp + + self.norm1 = norm_layer(dim) + self.use_rel_pos_bias = use_rel_pos_bias + if use_rel_pos_bias: + self.attn = AltAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + ) + else: + if alt_attention: + from 
.multi.modules import AltAttention as AltAttention2 + self.attn = AltAttention2( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + else: + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, x, rel_pos_bias=None, pos_mask=None): + if self.layer_norm_first: + if self.use_rel_pos_bias: + x = x + self.drop_path( + self.attn( + self.norm1(x), rel_pos_bias=rel_pos_bias, pos_mask=pos_mask + ) + ) + else: + x = x + self.drop_path(self.attn(self.norm1(x))) + t = self.mlp(self.norm2(x)) + x = x + self.drop_path(t) + if not self.ffn_targets: + t = x + return x, t + else: + if self.use_rel_pos_bias: + x = x + self.drop_path( + self.attn(x, rel_pos_bias=rel_pos_bias, pos_mask=pos_mask) + ) + else: + x = x + self.drop_path(self.attn(x)) + r = x = self.norm1(x) + x = self.mlp(x) + t = x + x = self.norm2(r + self.drop_path(x)) + if not self.ffn_targets: + t = x + return x, t + + +class AltAttention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + window_size=None, + attn_head_dim=None, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * ( + 2 * window_size[1] - 1 + ) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1,) * 2, + dtype=relative_coords.dtype, + ) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + 
self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None, pos_mask=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat( + ( + self.q_bias, + torch.zeros_like(self.v_bias, requires_grad=False), + self.v_bias, + ) + ) + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if self.relative_position_bias_table is not None: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + modify_relative_position_bias( + relative_position_bias, x.size(0), pos_mask + ) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class RelativePositionBias(nn.Module): + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * ( + 2 * window_size[1] - 1 + ) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype + ) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + def forward(self): + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, + -1, + ) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: 
[grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=float)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000 ** omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+def interpolate_pos_embed(model, checkpoint_model):
+    if "pos_embed" in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model["pos_embed"]
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches ** 0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print(
+                "Position interpolate from %dx%d to %dx%d"
+                % (orig_size, orig_size, new_size, new_size)
+            )
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(
+                -1, orig_size, orig_size, embedding_size
+            ).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens,
+                size=(new_size, new_size),
+                mode="bicubic",
+                align_corners=False,
+            )
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model["pos_embed"] = new_pos_embed
+
+
+@register_model("mae", dataclass=MaeConfig)
+class MaeModel(BaseFairseqModel):
+    def __init__(self, cfg: MaeConfig):
+        super().__init__()
+        self.cfg = cfg
+
+        self.mask_ratio = cfg.mask_ratio
+
+        # --------------------------------------------------------------------------
+        # MAE encoder specifics
+        self.patch_embed = PatchEmbed(
+            cfg.input_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim
+        )
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg.embed_dim)) if not cfg.no_cls else None
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, num_patches + int(not cfg.no_cls), cfg.embed_dim), requires_grad=False
+        )  # fixed sin-cos embedding
+
+        norm_layer = partial(nn.LayerNorm, eps=cfg.norm_eps)
+
+        dpr = [
+            x.item() for x in
torch.linspace(0, cfg.drop_path_rate, cfg.depth) + ] # stochastic depth decay rule + + def make_block(drop_path): + if cfg.w2v_block: + return TransformerSentenceEncoderLayer( + embedding_dim=cfg.embed_dim, + ffn_embedding_dim=cfg.embed_dim * cfg.mlp_ratio, + num_attention_heads=cfg.num_heads, + dropout=cfg.block_dropout, + attention_dropout=cfg.attention_dropout, + activation_dropout=cfg.activation_dropout, + activation_fn="gelu", + layer_norm_first=cfg.layer_norm_first, + drop_path=drop_path, + norm_eps=1e-6, + single_qkv=cfg.single_qkv, + fused_ln=cfg.fused_ln, + ) + elif cfg.alt_block: + window_size = ( + cfg.input_size // self.patch_embed.patch_size[0], + cfg.input_size // self.patch_embed.patch_size[1], + ) + return AltBlock( + cfg.embed_dim, + cfg.num_heads, + cfg.mlp_ratio, + qkv_bias=True, + qk_scale=None, + norm_layer=norm_layer, + drop_path=drop_path, + layer_norm_first=cfg.layer_norm_first, + ffn_targets=not cfg.end_of_block_targets, + use_rel_pos_bias=cfg.use_rel_pos_bias, + window_size=window_size + if (self.cfg.use_rel_pos_bias and not self.cfg.shared_rel_pos_bias) + else None, + alt_attention=cfg.alt_attention, + ) + elif cfg.alt_block2: + from .multi.modules import AltBlock as AltBlock2 + return AltBlock2( + cfg.embed_dim, + cfg.num_heads, + cfg.mlp_ratio, + qkv_bias=True, + qk_scale=None, + norm_layer=norm_layer, + drop_path=drop_path, + layer_norm_first=cfg.layer_norm_first, + ffn_targets=not cfg.end_of_block_targets, + ) + else: + return Block( + cfg.embed_dim, + cfg.num_heads, + cfg.mlp_ratio, + qkv_bias=True, + qk_scale=None, + norm_layer=norm_layer, + drop_path=drop_path, + ) + + self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)]) + self.norm = norm_layer(cfg.embed_dim) + # -------------------------------------------------------------------------- + + # -------------------------------------------------------------------------- + # MAE decoder specifics + self.decoder_embed = ( + nn.Linear(cfg.embed_dim, cfg.decoder_embed_dim, bias=True) + if not cfg.no_decoder_embed + else None + ) + + self.mask_token = ( + nn.Parameter( + torch.zeros( + 1, + 1, + cfg.decoder_embed_dim + if not cfg.no_decoder_embed + else cfg.embed_dim, + ) + ) + if cfg.mask_noise_std <= 0 + else None + ) + + self.decoder_pos_embed = ( + nn.Parameter( + torch.zeros( + 1, + num_patches + 1, + cfg.decoder_embed_dim + if not cfg.no_decoder_embed + else cfg.embed_dim, + ), + requires_grad=False, + ) + if not cfg.no_decoder_pos_embed + else None + ) + + self.decoder_blocks = nn.ModuleList( + [ + Block( + cfg.decoder_embed_dim, + cfg.decoder_num_heads, + cfg.mlp_ratio, + qkv_bias=True, + qk_scale=None, + norm_layer=norm_layer, + ) + for _ in range(cfg.decoder_depth) + ] + ) + + self.decoder_norm = norm_layer(cfg.decoder_embed_dim) + self.decoder_pred = nn.Linear( + cfg.decoder_embed_dim, cfg.patch_size ** 2 * cfg.in_chans, bias=True + ) # decoder to patch + # -------------------------------------------------------------------------- + + self.norm_pix_loss = cfg.norm_pix_loss + + self.initialize_weights() + + for pn, p in self.named_parameters(): + if len(p.shape) == 1 or pn.endswith(".bias"): + p.param_group = "no_decay" + else: + p.param_group = "with_decay" + + def initialize_weights(self): + # initialization + # initialize (and freeze) pos_embed by sin-cos embedding + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.patch_embed.num_patches ** 0.5), + cls_token=not self.cfg.no_cls, + ) + 
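+        # the sin-cos table is computed once in numpy and copied into the frozen
+        # (requires_grad=False) pos_embed parameter below, so it is never updated
+        # by the optimizer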
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + if self.decoder_pos_embed is not None: + decoder_pos_embed = get_2d_sincos_pos_embed( + self.decoder_pos_embed.shape[-1], + int(self.patch_embed.num_patches ** 0.5), + cls_token=not self.cfg.no_cls, + ) + self.decoder_pos_embed.data.copy_( + torch.from_numpy(decoder_pos_embed).float().unsqueeze(0) + ) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + if self.cls_token is not None: + torch.nn.init.normal_(self.cls_token, std=0.02) + + if self.mask_token is not None: + torch.nn.init.normal_(self.mask_token, std=0.02) + + # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + """ + p = self.patch_embed.patch_size[0] + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum("nchpwq->nhwpqc", x) + x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + p = self.patch_embed.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. 
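+        Keeps the first len_keep = int(L * (1 - mask_ratio)) tokens of the shuffled
+        sequence; returns (x_masked, mask, ids_restore), where mask is 0 for kept
+        and 1 for removed positions.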
+ x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore # x_masked is actually unmasked x + + @classmethod + def build_model(cls, cfg: MaeConfig, task=None): + """Build a new model instance.""" + + return cls(cfg) + + def forward_encoder(self, x, mask_ratio): + # embed patches + x = self.patch_embed(x) + + # add pos embed w/o cls token + # if self.cls_token is not None: + # x = x + self.pos_embed + # else: + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + if mask_ratio > 0: + x, mask, ids_restore = self.random_masking(x, mask_ratio) + else: + mask = ids_restore = None + + # append cls token + if self.cls_token is not None: + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + + if self.norm is not None: + x = self.norm(x) + + return x, mask, ids_restore + + def forward_decoder(self, x, ids_restore): + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1 + ) + if self.cls_token is not None: + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token + else: + x_ = torch.cat([x, mask_tokens], dim=1) # no cls token + + x_ = torch.gather( + x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]) + ) # unshuffle + + if self.cls_token is not None: + x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token + + # add pos embed + x = x + self.decoder_pos_embed + + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + if self.cls_token is not None: + # remove cls token + x = x[:, 1:, :] + + return x + + def forward_loss(self, imgs, pred, mask): + """ + imgs: [N, 3, H, W] + pred: [N, L, p*p*3] + mask: [N, L], 0 is keep, 1 is remove, + """ + target = self.patchify(imgs) + if self.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.0e-6) ** 0.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + + loss = (loss * mask).sum() + return loss, mask.sum() + + def forward(self, imgs, predictions_only=False): + latent, mask, ids_restore = self.forward_encoder( + imgs, self.mask_ratio if not predictions_only else 0 + ) + + if predictions_only: + return latent + + pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] + loss, sample_size = self.forward_loss(imgs, pred, mask) + + result = { + "losses": {"regression": loss}, + "sample_size": sample_size, + } + return result + + def remove_pretraining_modules(self): + self.decoder_embed = None + 
self.decoder_blocks = None + self.decoder_norm = None + self.decoder_pos_embed = None + self.decoder_pred = None + self.mask_token = None + if self.cfg.layer_norm_first: + self.norm = None diff --git a/examples/data2vec/models/mae_image_classification.py b/examples/data2vec/models/mae_image_classification.py new file mode 100644 index 000000000..e304618dc --- /dev/null +++ b/examples/data2vec/models/mae_image_classification.py @@ -0,0 +1,386 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# The code in this file is adapted from the BeiT implementation which can be found here: +# https://github.com/microsoft/unilm/tree/master/beit + +import logging + +from dataclasses import dataclass +from enum import Enum, auto +from typing import Any, Optional + +import numpy as np +from omegaconf import II, MISSING + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import checkpoint_utils, tasks +from omegaconf import open_dict + +from fairseq.dataclass import FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from .mae import interpolate_pos_embed + + +logger = logging.getLogger(__name__) + + +class PredictionMode(Enum): + MEAN_POOLING = auto() + CLS_TOKEN = auto() + LIN_SOFTMAX = auto() + + +@dataclass +class MaeImageClassificationConfig(FairseqDataclass): + model_path: str = MISSING + no_pretrained_weights: bool = False + linear_classifier: bool = False + num_classes: int = 1000 + mixup: float = 0.8 + cutmix: float = 1.0 + label_smoothing: float = 0.1 + + drop_path_rate: float = 0.1 + layer_decay: float = 0.65 + + mixup_prob: float = 1.0 + mixup_switch_prob: float = 0.5 + mixup_mode: str = "batch" + + pretrained_model_args: Any = None + data: str = II("task.data") + + norm_eps: Optional[float] = None + + remove_alibi: bool = False + + # regularization overwrites + encoder_dropout: float = 0 + post_mlp_drop: float = 0 + attention_dropout: float = 0 + activation_dropout: float = 0.0 + dropout_input: float = 0.0 + layerdrop: float = 0.0 + + prenet_layerdrop: float = 0 + prenet_dropout: float = 0 + + use_fc_norm: bool = True + prediction_mode: PredictionMode = PredictionMode.MEAN_POOLING + + no_decay_blocks: bool = True + + +def get_layer_id_for_vit(name, num_layers): + """ + Assign a parameter with its layer id + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + """ + if name in ["cls_token", "pos_embed"]: + return 0 + elif name.startswith("patch_embed"): + return 0 + elif name.startswith("rel_pos_bias"): + return num_layers - 1 + elif name.startswith("blocks"): + return int(name.split(".")[1]) + 1 + else: + return num_layers + + +@register_model("mae_image_classification", dataclass=MaeImageClassificationConfig) +class MaeImageClassificationModel(BaseFairseqModel): + def __init__(self, cfg: MaeImageClassificationConfig): + super().__init__() + self.cfg = cfg + + if cfg.pretrained_model_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {}) + pretrained_args = state.get("cfg", None) + + pretrained_args.criterion = None + pretrained_args.lr_scheduler = None + + logger.info(pretrained_args.model) + + with open_dict(pretrained_args.model): + pretrained_args.model.drop_path_rate = cfg.drop_path_rate + if cfg.norm_eps is not None: + pretrained_args.model.norm_eps = cfg.norm_eps + + cfg.pretrained_model_args = pretrained_args + + 
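+            # the resolved pretraining config is stored on cfg (above) so that a
+            # fine-tuned checkpoint can later be rebuilt without access to the
+            # original pretraining checkpoint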
logger.info(pretrained_args) + else: + state = None + pretrained_args = cfg.pretrained_model_args + + if "data" in pretrained_args.task: + pretrained_args.task.data = cfg.data + elif "image" in pretrained_args.task: + pretrained_args.task.image.data = cfg.data + + if "modalities" in pretrained_args.model: + prenet_blocks = pretrained_args.model["modalities"]["image"]["prenet_depth"] + model_blocks = pretrained_args.model["depth"] + with open_dict(pretrained_args): + dpr = np.linspace(0, cfg.drop_path_rate, model_blocks).tolist() + pretrained_args.model["modalities"]["image"][ + "start_drop_path_rate" + ] = dpr[0] + pretrained_args.model["modalities"]["image"][ + "end_drop_path_rate" + ] = max(0, dpr[prenet_blocks - 1]) + pretrained_args.model["start_drop_path_rate"] = dpr[prenet_blocks] + pretrained_args.model["end_drop_path_rate"] = dpr[-1] + + if "mae_masking" in pretrained_args.model["modalities"]["image"]: + del pretrained_args.model["modalities"]["image"]["mae_masking"] + + if cfg.remove_alibi: + pretrained_args.model["modalities"]["image"][ + "use_alibi_encoder" + ] = False + if ( + state is not None + and "modality_encoders.IMAGE.alibi_bias" in state["model"] + ): + del state["model"]["modality_encoders.IMAGE.alibi_bias"] + + pretrained_args.model["encoder_dropout"] = cfg.encoder_dropout + pretrained_args.model["post_mlp_drop"] = cfg.post_mlp_drop + pretrained_args.model["attention_dropout"] = cfg.attention_dropout + pretrained_args.model["activation_dropout"] = cfg.activation_dropout + pretrained_args.model["dropout_input"] = cfg.dropout_input + pretrained_args.model["layerdrop"] = cfg.layerdrop + + pretrained_args.model["modalities"]["image"][ + "prenet_layerdrop" + ] = cfg.prenet_layerdrop + pretrained_args.model["modalities"]["image"][ + "prenet_dropout" + ] = cfg.prenet_dropout + else: + # not d2v multi + with open_dict(pretrained_args): + pretrained_args.model["drop_path_rate"] = cfg.drop_path_rate + pretrained_args.model["block_dropout"] = cfg.encoder_dropout + pretrained_args.model["attention_dropout"] = cfg.attention_dropout + pretrained_args.model["activation_dropout"] = cfg.activation_dropout + + task = tasks.setup_task(pretrained_args.task) + model = task.build_model(pretrained_args.model, from_checkpoint=True) + + self.d2v_multi = "data2vec_multi" in pretrained_args.model._name + self.linear_classifier = cfg.linear_classifier + + self.model = model + + if state is not None and not cfg.no_pretrained_weights: + interpolate_pos_embed(model, state) + + if "modality_encoders.IMAGE.positional_encoder.pos_embed" in state["model"]: + state["model"][ + "modality_encoders.IMAGE.positional_encoder.positions" + ] = state["model"][ + "modality_encoders.IMAGE.positional_encoder.pos_embed" + ] + del state["model"][ + "modality_encoders.IMAGE.positional_encoder.pos_embed" + ] + if "modality_encoders.IMAGE.encoder_mask" in state["model"]: + del state["model"]["modality_encoders.IMAGE.encoder_mask"] + + model.load_state_dict(state["model"], strict=True) + + if self.d2v_multi: + model.remove_pretraining_modules(modality="image") + else: + model.remove_pretraining_modules() + + if self.linear_classifier: + model.requires_grad_(False) + + self.fc_norm = None + if self.cfg.use_fc_norm: + self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim, eps=1e-6) + nn.init.constant_(self.fc_norm.bias, 0) + nn.init.constant_(self.fc_norm.weight, 1.0) + + self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes) + + nn.init.trunc_normal_(self.head.weight, std=0.02) + 
nn.init.constant_(self.head.bias, 0) + + self.mixup_fn = None + + if cfg.mixup > 0 or cfg.cutmix > 0: + from timm.data import Mixup + + self.mixup_fn = Mixup( + mixup_alpha=cfg.mixup, + cutmix_alpha=cfg.cutmix, + cutmix_minmax=None, + prob=cfg.mixup_prob, + switch_prob=cfg.mixup_switch_prob, + mode=cfg.mixup_mode, + label_smoothing=cfg.label_smoothing, + num_classes=cfg.num_classes, + ) + + if self.model.norm is not None: + for pn, p in self.model.norm.named_parameters(): + if len(p.shape) == 1 or pn.endswith(".bias"): + p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}} + + if self.fc_norm is not None: + for pn, p in self.fc_norm.named_parameters(): + if len(p.shape) == 1 or pn.endswith(".bias"): + p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}} + + for pn, p in self.head.named_parameters(): + if len(p.shape) == 1 or pn.endswith(".bias"): + p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}} + + if self.d2v_multi: + mod_encs = list(model.modality_encoders.values()) + assert len(mod_encs) == 1, len(mod_encs) + blocks = list(mod_encs[0].context_encoder.blocks) + list(model.blocks) + else: + blocks = model.blocks + + num_layers = len(blocks) + 1 + layer_scales = list( + cfg.layer_decay ** (num_layers - i) for i in range(num_layers + 1) + ) + + if self.d2v_multi: + for n, p in self.model.named_parameters(): + optimizer_override_dict = {} + + if len(p.shape) == 1 or n.endswith(".bias"): + optimizer_override_dict["weight_decay_scale"] = 0 + + p.optim_overrides = {"optimizer": optimizer_override_dict} + + if cfg.layer_decay > 0: + for i, b in enumerate(blocks): + lid = i + 1 + if layer_scales[lid] == 1.0: + continue + + for n, p in b.named_parameters(): + optim_override = getattr(p, "optim_overrides", {}) + if "optimizer" not in optim_override: + optim_override["optimizer"] = {} + + if cfg.no_decay_blocks: + optim_override["optimizer"]["lr_scale"] = layer_scales[lid] + p.optim_overrides = optim_override + else: + optim_override["optimizer"] = { + "lr_scale": layer_scales[lid] + } + p.optim_overrides = optim_override + + else: + for n, p in self.model.named_parameters(): + optimizer_override_dict = {} + layer_id = get_layer_id_for_vit(n, num_layers) + + if len(p.shape) == 1 or n.endswith(".bias"): + optimizer_override_dict["weight_decay_scale"] = 0 + + if cfg.layer_decay > 0: + optimizer_override_dict["lr_scale"] = layer_scales[layer_id] + p.optim_overrides = {"optimizer": optimizer_override_dict} + + @classmethod + def build_model(cls, cfg: MaeImageClassificationConfig, task=None): + """Build a new model instance.""" + + return cls(cfg) + + def forward( + self, + imgs, + labels=None, + ): + if self.training and self.mixup_fn is not None and labels is not None: + imgs, labels = self.mixup_fn(imgs, labels) + + if self.linear_classifier: + with torch.no_grad(): + x = self.model_forward(imgs) + else: + x = self.model_forward(imgs) + + if self.cfg.prediction_mode == PredictionMode.MEAN_POOLING: + x = x.mean(dim=1) + elif self.cfg.prediction_mode == PredictionMode.CLS_TOKEN: + x = x[:, 0] + elif self.cfg.prediction_mode == PredictionMode.LIN_SOFTMAX: + dtype = x.dtype + x = F.logsigmoid(x.float()) + x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x + 1e-6, dim=1) + x = x.clamp(max=0) + x = x - torch.log(-(torch.expm1(x))) + x = torch.nan_to_num(x, nan=0, posinf=0, neginf=0) + x = x.to(dtype=dtype) + else: + raise Exception(f"unknown prediction mode {self.cfg.prediction_mode.name}") + + if self.fc_norm is not None: + x = self.fc_norm(x) + + x = self.head(x) + + if 
labels is None: + return x + + if self.training and self.mixup_fn is not None: + loss = -labels * F.log_softmax(x.float(), dim=-1) + else: + loss = F.cross_entropy( + x.float(), + labels, + label_smoothing=self.cfg.label_smoothing if self.training else 0, + reduction="none", + ) + + result = { + "losses": {"regression": loss}, + "sample_size": imgs.size(0), + } + + if not self.training: + with torch.no_grad(): + pred = x.argmax(-1) + correct = (pred == labels).sum() + result["correct"] = correct + + return result + + def model_forward(self, imgs): + if self.d2v_multi: + x = self.model.extract_features( + imgs, + mode="IMAGE", + mask=False, + remove_extra_tokens=( + self.cfg.prediction_mode != PredictionMode.CLS_TOKEN + ), + )["x"] + else: + x = self.model(imgs, predictions_only=True) + if ( + "no_cls" not in self.model.cfg or not self.model.cfg.no_cls + ) and not self.cfg.prediction_mode == PredictionMode.CLS_TOKEN: + x = x[:, 1:] + return x diff --git a/examples/data2vec/models/modalities/__init__.py b/examples/data2vec/models/modalities/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/data2vec/models/modalities/audio.py b/examples/data2vec/models/modalities/audio.py new file mode 100644 index 000000000..80d2857b2 --- /dev/null +++ b/examples/data2vec/models/modalities/audio.py @@ -0,0 +1,192 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +import torch +import torch.nn as nn +import numpy as np +from dataclasses import dataclass, field +from typing import Callable, Dict, Optional +from fairseq.models.wav2vec import ConvFeatureExtractionModel +from fairseq.modules import ( + LayerNorm, + SamePad, + TransposeLast, +) +from fairseq.tasks import FairseqTask +from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias +from .modules import BlockEncoder, Decoder1d +from examples.data2vec.data.modality import Modality + + +@dataclass +class D2vAudioConfig(D2vModalityConfig): + type: Modality = Modality.AUDIO + extractor_mode: str = "layer_norm" + feature_encoder_spec: str = field( + default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", + metadata={ + "help": "string describing convolutional feature extraction layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + }, + ) + conv_pos_width: int = field( + default=95, + metadata={"help": "number of filters for convolutional positional embeddings"}, + ) + conv_pos_groups: int = field( + default=16, + metadata={"help": "number of groups for convolutional positional embedding"}, + ) + conv_pos_depth: int = field( + default=5, + metadata={"help": "depth of positional encoder network"}, + ) + conv_pos_pre_ln: bool = False + + +class AudioEncoder(ModalitySpecificEncoder): + + modality_cfg: D2vAudioConfig + + def __init__( + self, + modality_cfg: D2vAudioConfig, + embed_dim: int, + make_block: Callable[[float], nn.ModuleList], + norm_layer: Callable[[int], nn.LayerNorm], + layer_norm_first: bool, + alibi_biases: Dict, + task: Optional[FairseqTask], + ): + + self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec) + feature_embed_dim = self.feature_enc_layers[-1][0] + + local_encoder = ConvFeatureExtractionModel( + conv_layers=self.feature_enc_layers, + dropout=0.0, + mode=modality_cfg.extractor_mode, + conv_bias=False, + ) + + project_features = nn.Sequential( + TransposeLast(), 
+ nn.LayerNorm(feature_embed_dim), + nn.Linear(feature_embed_dim, embed_dim), + ) + + num_pos_layers = modality_cfg.conv_pos_depth + k = max(3, modality_cfg.conv_pos_width // num_pos_layers) + + positional_encoder = nn.Sequential( + TransposeLast(), + *[ + nn.Sequential( + nn.Conv1d( + embed_dim, + embed_dim, + kernel_size=k, + padding=k // 2, + groups=modality_cfg.conv_pos_groups, + ), + SamePad(k), + TransposeLast(), + LayerNorm(embed_dim, elementwise_affine=False), + TransposeLast(), + nn.GELU(), + ) + for _ in range(num_pos_layers) + ], + TransposeLast(), + ) + + if modality_cfg.conv_pos_pre_ln: + positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder) + + dpr = np.linspace( + modality_cfg.start_drop_path_rate, + modality_cfg.end_drop_path_rate, + modality_cfg.prenet_depth, + ) + context_encoder = BlockEncoder( + nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)), + norm_layer(embed_dim) if not layer_norm_first else None, + layer_norm_first, + modality_cfg.prenet_layerdrop, + modality_cfg.prenet_dropout, + ) + + decoder = ( + Decoder1d(modality_cfg.decoder, embed_dim) + if modality_cfg.decoder is not None + else None + ) + + alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases) + + super().__init__( + modality_cfg=modality_cfg, + embed_dim=embed_dim, + local_encoder=local_encoder, + project_features=project_features, + fixed_positional_encoder=None, + relative_positional_encoder=positional_encoder, + context_encoder=context_encoder, + decoder=decoder, + get_alibi_bias=alibi_bias_fn, + ) + + def convert_padding_mask(self, x, padding_mask): + def get_feat_extract_output_lengths(input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + return torch.floor((input_length - kernel_size) / stride + 1) + + for i in range(len(self.feature_enc_layers)): + input_lengths = _conv_out_length( + input_lengths, + self.feature_enc_layers[i][1], + self.feature_enc_layers[i][2], + ) + + return input_lengths.to(torch.long) + + if padding_mask is not None: + input_lengths = (1 - padding_mask.long()).sum(-1) + # apply conv formula to get real output_lengths + output_lengths = get_feat_extract_output_lengths(input_lengths) + + if padding_mask.any(): + padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + padding_mask[ + ( + torch.arange(padding_mask.shape[0], device=padding_mask.device), + output_lengths - 1, + ) + ] = 1 + padding_mask = ( + 1 - padding_mask.flip([-1]).cumsum(-1).flip([-1]) + ).bool() + else: + padding_mask = torch.zeros( + x.shape[:2], dtype=torch.bool, device=x.device + ) + + return padding_mask + + def reset_parameters(self): + super().reset_parameters() + for mod in self.project_features.children(): + if isinstance(mod, nn.Linear): + mod.reset_parameters() + if self.decoder is not None: + self.decoder.reset_parameters() diff --git a/examples/data2vec/models/modalities/base.py b/examples/data2vec/models/modalities/base.py new file mode 100644 index 000000000..642cc8466 --- /dev/null +++ b/examples/data2vec/models/modalities/base.py @@ -0,0 +1,684 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
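+
+# Shared infrastructure for the per-modality encoders (audio/image/text):
+# D2vModalityConfig, the ModalitySpecificEncoder pipeline (local feature encoder
+# -> positional encoding -> masking -> prenet transformer blocks -> optional
+# decoder), plus helpers for random/block masking and (optionally learned)
+# ALiBi attention biases.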
+ +import logging +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections import namedtuple +from dataclasses import dataclass +from functools import partial +from omegaconf import MISSING, II +from typing import Optional, Callable +from fairseq.data.data_utils import compute_mask_indices +from fairseq.modules import GradMultiply +from fairseq.utils import index_put +from examples.data2vec.data.modality import Modality +from .modules import D2vDecoderConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class D2vModalityConfig: + type: Modality = MISSING + prenet_depth: int = 4 + prenet_layerdrop: float = 0 + prenet_dropout: float = 0 + start_drop_path_rate: float = 0 + end_drop_path_rate: float = 0 + + num_extra_tokens: int = 0 + init_extra_token_zero: bool = True + + mask_noise_std: float = 0.01 + mask_prob_min: Optional[float] = None + mask_prob: float = 0.7 + inverse_mask: bool = False + mask_prob_adjust: float = 0 + keep_masked_pct: float = 0 + + mask_length: int = 5 + add_masks: bool = False + remove_masks: bool = False + mask_dropout: float = 0.0 + encoder_zero_mask: bool = True + + mask_channel_prob: float = 0.0 + mask_channel_length: int = 64 + + ema_local_encoder: bool = False # used in data2vec_multi + local_grad_mult: float = 1.0 + + use_alibi_encoder: bool = False + alibi_scale: float = 1.0 + learned_alibi: bool = False + alibi_max_pos: Optional[int] = None + learned_alibi_scale: bool = False + learned_alibi_scale_per_head: bool = False + learned_alibi_scale_per_layer: bool = False + + num_alibi_heads: int = II("model.num_heads") + model_depth: int = II("model.depth") + + decoder: Optional[D2vDecoderConfig] = D2vDecoderConfig() + + +MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"]) +MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"]) + + +class ModalitySpecificEncoder(nn.Module): + def __init__( + self, + modality_cfg: D2vModalityConfig, + embed_dim: int, + local_encoder: nn.Module, + project_features: nn.Module, + fixed_positional_encoder: Optional[nn.Module], + relative_positional_encoder: Optional[nn.Module], + context_encoder: nn.Module, + decoder: nn.Module, + get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]], + ): + super().__init__() + + self.modality_cfg = modality_cfg + self.local_encoder = local_encoder + self.project_features = project_features + self.fixed_positional_encoder = fixed_positional_encoder + self.relative_positional_encoder = relative_positional_encoder + self.context_encoder = context_encoder + + self.decoder = decoder + self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None + + self.local_grad_mult = self.modality_cfg.local_grad_mult + + self.extra_tokens = None + if modality_cfg.num_extra_tokens > 0: + self.extra_tokens = nn.Parameter( + torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim) + ) + if not modality_cfg.init_extra_token_zero: + nn.init.normal_(self.extra_tokens) + elif self.extra_tokens.size(1) > 1: + nn.init.normal_(self.extra_tokens[:, 1:]) + + self.alibi_scale = None + if self.get_alibi_bias is not None: + self.alibi_scale = nn.Parameter( + torch.full( + ( + (modality_cfg.prenet_depth + modality_cfg.model_depth) + if modality_cfg.learned_alibi_scale_per_layer + else 1, + 1, + self.modality_cfg.num_alibi_heads + if modality_cfg.learned_alibi_scale_per_head + else 1, + 1, + 1, + ), + modality_cfg.alibi_scale, + dtype=torch.float, + ), + 
requires_grad=modality_cfg.learned_alibi_scale, + ) + + if modality_cfg.learned_alibi and self.get_alibi_bias is not None: + assert modality_cfg.alibi_max_pos is not None + alibi_bias = self.get_alibi_bias( + batch_size=1, + time_steps=modality_cfg.alibi_max_pos, + heads=modality_cfg.num_alibi_heads, + scale=1.0, + dtype=torch.float, + device="cpu", + ) + self.alibi_bias = nn.Parameter(alibi_bias) + self.get_alibi_bias = partial( + _learned_alibi_bias, alibi_bias=self.alibi_bias + ) + + def upgrade_state_dict_named(self, state_dict, name): + k = f"{name}.alibi_scale" + if k in state_dict and state_dict[k].dim() == 4: + state_dict[k] = state_dict[k].unsqueeze(0) + + return state_dict + + def convert_padding_mask(self, x, padding_mask): + return padding_mask + + def decoder_input(self, x, mask_info: MaskInfo): + inp_drop = self.modality_cfg.decoder.input_dropout + if inp_drop > 0: + x = F.dropout(x, inp_drop, training=self.training, inplace=True) + + num_extra = self.modality_cfg.num_extra_tokens + + if mask_info is not None: + num_masked = mask_info.ids_restore.shape[1] - x.shape[1] + num_extra + + mask_tokens = x.new_empty( + x.size(0), + num_masked, + x.size(-1), + ).normal_(0, self.modality_cfg.mask_noise_std) + + x_ = torch.cat([x[:, num_extra:], mask_tokens], dim=1) + x = torch.gather(x_, dim=1, index=mask_info.ids_restore) + + if self.modality_cfg.decoder.add_positions_masked: + assert self.fixed_positional_encoder is not None + pos = self.fixed_positional_encoder(x, None) + x = x + (pos * mask_info.mask.unsqueeze(-1)) + else: + x = x[:, num_extra:] + + if self.modality_cfg.decoder.add_positions_all: + assert self.fixed_positional_encoder is not None + x = x + self.fixed_positional_encoder(x, None) + + return x, mask_info + + def local_features(self, features): + if self.local_grad_mult > 0: + if self.local_grad_mult == 1.0: + x = self.local_encoder(features) + else: + x = GradMultiply.apply( + self.local_encoder(features), self.local_grad_mult + ) + else: + with torch.no_grad(): + x = self.local_encoder(features) + + x = self.project_features(x) + return x + + def contextualized_features( + self, + x, + padding_mask, + mask, + remove_masked, + clone_batch: int = 1, + mask_seeds: Optional[torch.Tensor] = None, + precomputed_mask=None, + ): + + if padding_mask is not None: + padding_mask = self.convert_padding_mask(x, padding_mask) + + local_features = x + if mask and clone_batch == 1: + local_features = local_features.clone() + + orig_B, orig_T, _ = x.shape + pre_mask_B = orig_B + mask_info = None + + x_pos = None + if self.fixed_positional_encoder is not None: + x = x + self.fixed_positional_encoder(x, padding_mask) + + if mask: + if clone_batch > 1: + x = x.repeat_interleave(clone_batch, 0) + if mask_seeds is not None: + clone_hash = [ + int(hash((mask_seeds.seed, ind)) % 1e10) + for ind in range(clone_batch - 1) + ] + clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1) + + id = mask_seeds.ids + id = id.repeat_interleave(clone_batch, 0) + id = id.view(-1, clone_batch) + clone_hash.to(id) + id = id.view(-1) + mask_seeds = MaskSeed( + seed=mask_seeds.seed, update=mask_seeds.update, ids=id + ) + if padding_mask is not None: + padding_mask = padding_mask.repeat_interleave(clone_batch, 0) + + x, mask_info = self.compute_mask( + x, + padding_mask, + mask_seed=mask_seeds, + apply=self.relative_positional_encoder is not None or not remove_masked, + precomputed_mask=precomputed_mask, + ) + + if self.relative_positional_encoder is not None: + x_pos = 
self.relative_positional_encoder(x) + + masked_padding_mask = padding_mask + if mask and remove_masked: + x = mask_info.x_unmasked + if x_pos is not None: + x = x + gather_unmasked(x_pos, mask_info) + + if padding_mask is not None and padding_mask.any(): + masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info) + if not masked_padding_mask.any(): + masked_padding_mask = None + else: + masked_padding_mask = None + + elif x_pos is not None: + x = x + x_pos + + alibi_bias = None + alibi_scale = self.alibi_scale + + if self.get_alibi_bias is not None: + alibi_bias = self.get_alibi_bias( + batch_size=pre_mask_B, + time_steps=orig_T, + heads=self.modality_cfg.num_alibi_heads, + dtype=torch.float32, + device=x.device, + ) + + if alibi_scale is not None: + alibi_scale = alibi_scale.clamp_min(0) + if alibi_scale.size(0) == 1: + alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias) + alibi_scale = None + + if clone_batch > 1: + alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0) + + if mask_info is not None and remove_masked: + alibi_bias = masked_alibi(alibi_bias, mask_info) + + if self.extra_tokens is not None: + num = self.extra_tokens.size(1) + x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1) + if masked_padding_mask is not None: + # B x T + masked_padding_mask = F.pad(masked_padding_mask, (num, 0)) + if alibi_bias is not None: + # B x H x T x T + alibi_bias = F.pad(alibi_bias, (num, 0, num, 0)) + + x = self.context_encoder( + x, + masked_padding_mask, + alibi_bias, + alibi_scale[: self.modality_cfg.prenet_depth] + if alibi_scale is not None + else None, + ) + + return { + "x": x, + "local_features": local_features, + "padding_mask": masked_padding_mask, + "alibi_bias": alibi_bias, + "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :] + if alibi_scale is not None and alibi_scale.size(0) > 1 + else alibi_scale, + "encoder_mask": mask_info, + } + + def forward( + self, + features, + padding_mask, + mask: bool, + remove_masked: bool, + clone_batch: int = 1, + mask_seeds: Optional[torch.Tensor] = None, + precomputed_mask=None, + ): + x = self.local_features(features) + return self.contextualized_features( + x, + padding_mask, + mask, + remove_masked, + clone_batch, + mask_seeds, + precomputed_mask, + ) + + def reset_parameters(self): + pass + + def compute_mask( + self, + x, + padding_mask, + mask_seed: Optional[MaskSeed], + apply, + precomputed_mask, + ): + if precomputed_mask is not None: + mask = precomputed_mask + mask_info = self.make_maskinfo(x, mask) + else: + B, T, C = x.shape + cfg = self.modality_cfg + + mask_prob = cfg.mask_prob + + if ( + cfg.mask_prob_min is not None + and cfg.mask_prob_min >= 0 + and cfg.mask_prob_min < mask_prob + ): + mask_prob = np.random.uniform(cfg.mask_prob_min, mask_prob) + + if mask_prob > 0: + if cfg.mask_length == 1: + mask_info = random_masking(x, mask_prob, mask_seed) + else: + if self.modality_cfg.inverse_mask: + mask_prob = 1 - mask_prob + + mask = compute_mask_indices( + (B, T), + padding_mask, + mask_prob, + cfg.mask_length, + min_masks=1, + require_same_masks=True, + mask_dropout=cfg.mask_dropout, + add_masks=cfg.add_masks, + seed=mask_seed.seed if mask_seed is not None else None, + epoch=mask_seed.update if mask_seed is not None else None, + indices=mask_seed.ids if mask_seed is not None else None, + ) + + mask = torch.from_numpy(mask).to(device=x.device) + if self.modality_cfg.inverse_mask: + mask = 1 - mask + mask_info = self.make_maskinfo(x, mask) + else: + mask_info = None + + if 
apply: + x = self.apply_mask(x, mask_info) + + return x, mask_info + + def make_maskinfo(self, x, mask, shape=None): + if shape is None: + B, T, D = x.shape + else: + B, T, D = shape + + mask = mask.to(torch.uint8) + ids_shuffle = mask.argsort(dim=1) + ids_restore = ids_shuffle.argsort(dim=1).unsqueeze(-1).expand(-1, -1, D) + + len_keep = T - mask[0].sum() + if self.modality_cfg.keep_masked_pct > 0: + len_keep += round((T - int(len_keep)) * self.modality_cfg.keep_masked_pct) + + ids_keep = ids_shuffle[:, :len_keep] + + if shape is not None: + x_unmasked = None + else: + ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D) + x_unmasked = torch.gather(x, dim=1, index=ids_keep) + + mask_info = MaskInfo( + x_unmasked=x_unmasked, + mask=mask, + ids_restore=ids_restore, + ids_keep=ids_keep, + ) + return mask_info + + def apply_mask(self, x, mask_info): + cfg = self.modality_cfg + B, T, C = x.shape + + if mask_info is not None: + mask = mask_info.mask + if cfg.encoder_zero_mask: + x = x * (1 - mask.type_as(x).unsqueeze(-1)) + else: + num_masks = mask.sum().item() + masks = x.new_empty(num_masks, x.size(-1)).normal_( + 0, cfg.mask_noise_std + ) + x = index_put(x, mask, masks) + if cfg.mask_channel_prob > 0: + mask_channel = compute_mask_indices( + (B, C), + None, + cfg.mask_channel_prob, + cfg.mask_channel_length, + ) + mask_channel = ( + torch.from_numpy(mask_channel) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x = index_put(x, mask_channel, 0) + return x + + def remove_pretraining_modules(self, keep_decoder=False): + if not keep_decoder: + self.decoder = None + + +def get_annealed_rate(start, end, curr_step, total_steps): + if curr_step >= total_steps: + return end + r = end - start + pct_remaining = 1 - curr_step / total_steps + return end - r * pct_remaining + + +# adapted from MAE +def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]): + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + generator = None + if mask_seed is not None: + seed = int( + hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6 + ) + generator = torch.Generator(device=x.device) + generator.manual_seed(seed) + + noise = torch.rand(N, L, generator=generator, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = noise.argsort(dim=1) # ascend: small is keep, large is remove + ids_restore = ids_shuffle.argsort(dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D) + x_unmasked = torch.gather(x, dim=1, index=ids_keep) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], dtype=x.dtype, device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D) + + return MaskInfo( + x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep + ) + + +def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor: + return torch.gather( + x, + dim=1, + index=mask_info.ids_keep, + ) + + +def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor: + return torch.gather( + x, + dim=1, + index=mask_info.ids_keep[..., 0], # ignore the feature dimension + ) + + +def get_alibi( + max_positions: int, + attention_heads: int, + dims: int = 1, + distance: str = "manhattan", +): + def get_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 
3))) + ratio = start + return [start * ratio**i for i in range(n)] + + # In the paper, we only train models that have 2^a heads for some + # a. This function has some good properties that only occur when + # the input is a power of 2. To maintain that even when the number + # of heads is not a power of 2, we use this workaround. + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + maxpos = max_positions + attn_heads = attention_heads + slopes = torch.Tensor(get_slopes(attn_heads)) + + if dims == 1: + # prepare alibi position linear bias. Note that wav2vec2 is non + # autoregressive model so we want a symmetric mask with 0 on the + # diagonal and other wise linear decreasing valuees + pos_bias = ( + torch.abs( + torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1) + ) + * -1 + ) + elif dims == 2: + if distance == "manhattan": + df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2) + elif distance == "euclidean": + df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) + + n = math.sqrt(max_positions) + assert n.is_integer(), n + n = int(n) + + pos_bias = torch.zeros((max_positions, max_positions)) + + for i in range(n): + for j in range(n): + for k in range(n): + for l in range(n): + new_x = i * n + j + new_y = k * n + l + pos_bias[new_x, new_y] = -df(i, j, k, l) + + else: + raise Exception(f"unsupported number of alibi dims: {dims}") + + alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand( + attn_heads, -1, -1 + ) + + return alibi_bias + + +def get_alibi_bias( + alibi_biases, + batch_size, + time_steps, + heads, + dtype, + device, + dims=1, + distance="manhattan", +): + cache_key = f"{dims}_{heads}_{distance}" + + buffered = alibi_biases.get(cache_key, None) + + target_size = heads * batch_size + if ( + buffered is None + or buffered.size(0) < target_size + or buffered.size(1) < time_steps + or buffered.dtype != dtype + or buffered.device != device + ): + bt = max(time_steps, buffered.size(1) if buffered is not None else 0) + bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads + + buffered = ( + get_alibi(bt, heads, dims=dims, distance=distance) + .to(dtype=dtype, device=device) + .repeat(bn, 1, 1) + ) + + alibi_biases[cache_key] = buffered + + b = buffered[:target_size, :time_steps, :time_steps] + b = b.view(batch_size, heads, time_steps, time_steps) + return b + + +def _learned_alibi_bias( + alibi_bias, + batch_size, + time_steps, + heads, + scale, + dtype, + device, +): + assert alibi_bias.size(1) == heads, alibi_bias.shape + assert alibi_bias.dtype == dtype, alibi_bias.dtype + assert alibi_bias.device == device, alibi_bias.device + + if alibi_bias.size(-1) < time_steps: + psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2) + alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate") + + alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale + return alibi_bias[..., :time_steps, :time_steps] + + +def masked_alibi(alibi_bias, mask_info): + H = alibi_bias.size(1) + + orig_bias = alibi_bias + + index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1) + alibi_bias = torch.gather( + orig_bias, + dim=-2, + index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)), + ) + alibi_bias = torch.gather( + alibi_bias, + dim=-1, + index=index.transpose(-1, -2).expand(-1, H, 
alibi_bias.size(-2), -1), + ) + + return alibi_bias diff --git a/examples/data2vec/models/modalities/images.py b/examples/data2vec/models/modalities/images.py new file mode 100644 index 000000000..a6b738cb0 --- /dev/null +++ b/examples/data2vec/models/modalities/images.py @@ -0,0 +1,256 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from functools import partial +from dataclasses import dataclass +from typing import Callable, Dict, Optional +from timm.models.layers import to_2tuple +from fairseq.tasks import FairseqTask +from examples.data2vec.models.mae import get_2d_sincos_pos_embed, PatchEmbed +from .base import ( + D2vModalityConfig, + ModalitySpecificEncoder, + get_alibi_bias, + MaskSeed, +) +from .modules import ( + BlockEncoder, + Decoder2d, + FixedPositionalEncoder, + TransformerDecoder, + EncDecTransformerDecoder, +) +from examples.data2vec.data.modality import Modality + + +@dataclass +class D2vImageConfig(D2vModalityConfig): + type: Modality = Modality.IMAGE + + input_size: int = 224 + in_chans: int = 3 + patch_size: int = 16 + embed_dim: int = 768 + + alibi_dims: int = 2 + alibi_distance: str = "manhattan" + + fixed_positions: bool = True + + transformer_decoder: bool = False + enc_dec_transformer: bool = False + + +class ImageEncoder(ModalitySpecificEncoder): + + modality_cfg: D2vImageConfig + + def __init__( + self, + modality_cfg: D2vImageConfig, + embed_dim: int, + make_block: Callable[[float, Optional[int], Optional[int]], nn.ModuleList], + norm_layer: Callable[[int], nn.LayerNorm], + layer_norm_first: bool, + alibi_biases: Dict, + task: Optional[FairseqTask], + ): + + img_size = to_2tuple(modality_cfg.input_size) + patch_size = to_2tuple(modality_cfg.patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + + local_encoder = PatchEmbed( + modality_cfg.input_size, + modality_cfg.patch_size, + modality_cfg.in_chans, + modality_cfg.embed_dim, + ) + + w = local_encoder.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + if modality_cfg.embed_dim != embed_dim: + local_encoder = nn.Sequential( + local_encoder, + nn.Linear(modality_cfg.embed_dim, embed_dim), + ) + + project_features = nn.Identity() + + pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dim), requires_grad=False + ) + + side_n = int(num_patches ** 0.5) + + emb = get_2d_sincos_pos_embed( + pos_embed.shape[-1], + side_n, + cls_token=False, + ) + pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0)) + fixed_positional_encoder = ( + FixedPositionalEncoder(pos_embed) if modality_cfg.fixed_positions else None + ) + + dpr = np.linspace( + modality_cfg.start_drop_path_rate, + modality_cfg.end_drop_path_rate, + modality_cfg.prenet_depth, + ) + + context_encoder = BlockEncoder( + nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)), + norm_layer(embed_dim) if not layer_norm_first else None, + layer_norm_first, + modality_cfg.prenet_layerdrop, + modality_cfg.prenet_dropout, + ) + + if modality_cfg.transformer_decoder: + if modality_cfg.enc_dec_transformer: + decoder = EncDecTransformerDecoder(modality_cfg.decoder, embed_dim) + else: + dec_enc = BlockEncoder( + nn.ModuleList( + make_block(0, modality_cfg.decoder.decoder_dim, 8) + for _ in range(modality_cfg.decoder.decoder_layers) + ), + None, 
+ layer_norm_first, + 0, + 0, + ) + decoder = TransformerDecoder(modality_cfg.decoder, embed_dim, dec_enc) + else: + decoder = ( + Decoder2d(modality_cfg.decoder, embed_dim, side_n, side_n) + if modality_cfg.decoder is not None + else None + ) + + alibi_bias_fn = partial( + get_alibi_bias, + alibi_biases=alibi_biases, + heads=modality_cfg.num_alibi_heads, + dims=modality_cfg.alibi_dims, + distance=modality_cfg.alibi_distance, + ) + + super().__init__( + modality_cfg=modality_cfg, + embed_dim=embed_dim, + local_encoder=local_encoder, + project_features=project_features, + fixed_positional_encoder=fixed_positional_encoder, + relative_positional_encoder=None, + context_encoder=context_encoder, + decoder=decoder, + get_alibi_bias=alibi_bias_fn, + ) + + def reset_parameters(self): + super().reset_parameters() + if self.decoder is not None: + self.decoder.reset_parameters() + + @torch.no_grad() + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + """ + p = self.modality_cfg.patch_size + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum("nchpwq->nhwpqc", x) + x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) + + return x + + @torch.no_grad() + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + p = self.modality_cfg.patch_size + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs + + def compute_mask( + self, + x, + padding_mask, + mask_seed: Optional[MaskSeed], + apply, + shape=None, + precomputed_mask=None, + ): + mlen = self.modality_cfg.mask_length + if mlen <= 1: + return super().compute_mask( + x, padding_mask, mask_seed, apply, precomputed_mask + ) + + if precomputed_mask is not None: + mask = precomputed_mask + else: + from fairseq.data.data_utils import compute_block_mask_2d + + if shape is not None: + B, L, D = shape + else: + B, L, D = x.shape + + mask = compute_block_mask_2d( + shape=(B, L), + mask_prob=self.modality_cfg.mask_prob, + mask_length=self.modality_cfg.mask_length, + mask_prob_adjust=self.modality_cfg.mask_prob_adjust, + inverse_mask=self.modality_cfg.inverse_mask, + require_same_masks=True, + mask_dropout=self.modality_cfg.mask_dropout, + ) + + mask_info = self.make_maskinfo(x, mask, shape) + if apply: + x = self.apply_mask(x, mask_info) + + return x, mask_info + + def decoder_input(self, x, mask_info): + if ( + not self.modality_cfg.transformer_decoder + or not self.modality_cfg.enc_dec_transformer + ): + return super().decoder_input(x, mask_info) + + inp_drop = self.modality_cfg.decoder.input_dropout + if inp_drop > 0: + x = F.dropout(x, inp_drop, training=self.training, inplace=True) + + kv = x[:, self.modality_cfg.num_extra_tokens :] + + assert self.fixed_positional_encoder is not None + pos = self.fixed_positional_encoder(x, None).expand(x.size(0), -1, -1) + + mask = mask_info.mask.bool() + if self.modality_cfg.decoder.add_positions_all: + kv = kv + pos[~mask].view(kv.shape) + + q = pos[mask].view(x.size(0), -1, x.size(-1)) + + return q, kv diff --git a/examples/data2vec/models/modalities/modules.py b/examples/data2vec/models/modalities/modules.py new file mode 100644 index 000000000..a4e1a4ea0 --- /dev/null +++ b/examples/data2vec/models/modalities/modules.py @@ -0,0 +1,589 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from dataclasses import dataclass +from fairseq.modules import ( + LayerNorm, + SamePad, + SamePad2d, + TransposeLast, +) + + +@dataclass +class D2vDecoderConfig: + decoder_dim: int = 384 + decoder_groups: int = 16 + decoder_kernel: int = 5 + decoder_layers: int = 5 + input_dropout: float = 0.1 + + add_positions_masked: bool = False + add_positions_all: bool = False + + decoder_residual: bool = True + projection_layers: int = 1 + projection_ratio: float = 2.0 + + +class FixedPositionalEncoder(nn.Module): + def __init__(self, pos_embed): + super().__init__() + self.positions = pos_embed + + def forward(self, x, padding_mask): + return self.positions + + +class TextFeatPositionalEncoder(nn.Module): + """ + Original encoder expects (B, T) long input. This module wraps it to take + local_encoder output which are (B, T, D) float tensors + """ + + def __init__(self, pos_encoder): + super().__init__() + self.pos_encoder = pos_encoder + + def forward(self, x, padding_mask): + # assume padded token embeddings are 0s + # TODO: consider using padding_mask as input + return self.pos_encoder(x[..., 0]) + + +class BlockEncoder(nn.Module): + def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout): + super().__init__() + self.blocks = blocks + self.norm = norm_layer + self.layer_norm_first = layer_norm_first + self.layerdrop = layerdrop + self.dropout = nn.Dropout(dropout, inplace=True) + + def forward(self, x, padding_mask, alibi_bias, alibi_scale): + if self.norm is not None and not self.layer_norm_first: + x = self.norm(x) + + x = self.dropout(x) + + for i, blk in enumerate(self.blocks): + if ( + not self.training + or self.layerdrop == 0 + or (np.random.random() > self.layerdrop) + ): + ab = alibi_bias + if ab is not None and alibi_scale is not None: + scale = ( + alibi_scale[i] + if alibi_scale.size(0) > 1 + else alibi_scale.squeeze(0) + ) + ab = ab * scale.type_as(ab) + x, _ = blk(x, padding_mask, ab) + + if self.norm is not None and self.layer_norm_first: + x = self.norm(x) + + return x + + +class DecoderBase(nn.Module): + decoder_cfg: D2vDecoderConfig + + def __init__(self, cfg: D2vDecoderConfig): + super().__init__() + + self.decoder_cfg = cfg + + def reset_parameters(self): + for mod in self.proj.modules(): + if isinstance(mod, nn.Linear): + mod.reset_parameters() + + def add_residual(self, x, residual, i, mask_info): + if ( + residual is None + or not self.decoder_cfg.decoder_residual + or residual.size(1) != x.size(1) + ): + return x + + ret = x + residual + + return ret + + +class Decoder1d(DecoderBase): + def __init__(self, cfg: D2vDecoderConfig, input_dim): + super().__init__(cfg) + + def make_block(in_dim): + block = [ + nn.Conv1d( + in_dim, + cfg.decoder_dim, + kernel_size=cfg.decoder_kernel, + padding=cfg.decoder_kernel // 2, + groups=cfg.decoder_groups, + ), + SamePad(cfg.decoder_kernel), + TransposeLast(), + LayerNorm(cfg.decoder_dim, elementwise_affine=False), + TransposeLast(), + nn.GELU(), + ] + + return nn.Sequential(*block) + + self.blocks = nn.Sequential( + *[ + make_block(input_dim if i == 0 else cfg.decoder_dim) + for i in range(cfg.decoder_layers) + ] + ) + + projs = [] + curr_dim = cfg.decoder_dim + for i in range(cfg.projection_layers - 1): + next_dim = int(curr_dim * cfg.projection_ratio) if i == 0 else curr_dim + 
projs.append(nn.Linear(curr_dim, next_dim)) + projs.append(nn.GELU()) + curr_dim = next_dim + projs.append(nn.Linear(curr_dim, input_dim)) + if len(projs) == 1: + self.proj = projs[0] + else: + self.proj = nn.Sequential(*projs) + + def forward(self, x, mask_info): + + x = x.transpose(1, 2) + + residual = x + + for i, layer in enumerate(self.blocks): + x = layer(x) + x = self.add_residual(x, residual, i, mask_info) + residual = x + + x = x.transpose(1, 2) + x = self.proj(x) + return x + + +class Decoder2d(DecoderBase): + def __init__(self, cfg: D2vDecoderConfig, input_dim, h_size, w_size): + super().__init__(cfg) + + self.h_size = h_size + self.w_size = w_size + + def make_block(in_dim): + block = [ + nn.Conv2d( + in_dim, + cfg.decoder_dim, + kernel_size=cfg.decoder_kernel, + padding=cfg.decoder_kernel // 2, + groups=cfg.decoder_groups, + ), + SamePad2d(cfg.decoder_kernel), + TransposeLast(tranpose_dim=-3), + LayerNorm(cfg.decoder_dim, elementwise_affine=False), + TransposeLast(tranpose_dim=-3), + nn.GELU(), + ] + + return nn.Sequential(*block) + + self.blocks = nn.Sequential( + *[ + make_block(input_dim if i == 0 else cfg.decoder_dim) + for i in range(cfg.decoder_layers) + ] + ) + + self.proj = nn.Linear(cfg.decoder_dim, input_dim) + + def forward(self, x, mask_info): + B, T, C = x.shape + + x = x.transpose(1, 2).reshape(B, C, self.h_size, self.w_size) + + residual = x + + for i, layer in enumerate(self.blocks): + x = layer(x) + x = self.add_residual(x, residual, i, mask_info) + residual = x + + x = x.reshape(B, -1, T).transpose(1, 2) + x = self.proj(x) + return x + + +class TransformerDecoder(nn.Module): + decoder_cfg: D2vDecoderConfig + + def __init__(self, cfg: D2vDecoderConfig, input_dim, encoder): + super().__init__() + + self.decoder_cfg = cfg + + self.input_proj = nn.Linear(input_dim, cfg.decoder_dim) + + self.encoder = encoder + + self.proj = nn.Linear(cfg.decoder_dim, input_dim) + + def reset_parameters(self): + from fairseq.modules.transformer_sentence_encoder import init_bert_params + + self.apply(init_bert_params) + + def forward(self, x, mask_info): + x = self.input_proj(x) + x = self.encoder(x, None, None, 1) + x = self.proj(x) + return x + + +class AltBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + mlp_drop=0.0, + post_mlp_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + layer_norm_first=True, + ffn_targets=False, + cosine_attention=False, + ): + super().__init__() + + self.layer_norm_first = layer_norm_first + self.ffn_targets = ffn_targets + + from timm.models.vision_transformer import DropPath, Mlp + + self.norm1 = norm_layer(dim) + self.attn = AltAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + cosine_attention=cosine_attention, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=mlp_drop, + ) + self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False) + + def forward(self, x, padding_mask=None, alibi_bias=None): + if self.layer_norm_first: + x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias)) + r = x = self.mlp(self.norm2(x)) + t = x + x = r + self.drop_path(self.post_mlp_dropout(x)) + if not self.ffn_targets: + t = x + else: + x = x + 
self.drop_path(self.attn(x, padding_mask, alibi_bias)) + r = x = self.norm1(x) + x = self.mlp(x) + t = x + x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x))) + if not self.ffn_targets: + t = x + + return x, t + + +class AltAttention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + cosine_attention=False, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.cosine_attention = cosine_attention + + if cosine_attention: + self.logit_scale = nn.Parameter( + torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True + ) + + def forward(self, x, padding_mask=None, alibi_bias=None): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) # qkv x B x H x L x D + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + dtype = q.dtype + + if self.cosine_attention: + # cosine attention + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) + logit_scale = torch.clamp( + self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01)) + ).exp() + attn = attn * logit_scale + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if alibi_bias is not None: + attn = attn.type_as(alibi_bias) + attn[:, : alibi_bias.size(1)] += alibi_bias + + if padding_mask is not None and padding_mask.any(): + attn = attn.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + + attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2) # + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class EncDecAttention(nn.Module): + def __init__( + self, + q_dim, + kv_dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + cosine_attention=False, + ): + super().__init__() + self.num_heads = num_heads + head_dim = q_dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q_proj = nn.Linear(q_dim, q_dim, bias=qkv_bias) + self.kv_proj = nn.Linear(kv_dim, 2 * q_dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(q_dim, q_dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.cosine_attention = cosine_attention + + if cosine_attention: + self.logit_scale = nn.Parameter( + torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True + ) + + def forward(self, q, kv, padding_mask=None, alibi_bias=None): + B, N, C = q.shape + + q = ( + self.q_proj(q) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) # B x H x L x D + kv = ( + self.kv_proj(kv) + .reshape(B, -1, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) # kv x B x H x L x D + k, v = ( + kv[0], + kv[1], + ) # make torchscript happy (cannot use tensor as tuple) + + dtype = q.dtype + + if self.cosine_attention: + # cosine attention + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) + logit_scale = torch.clamp( + self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01)) + ).exp() + attn = attn * logit_scale + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if alibi_bias is not None: + attn 
= attn.type_as(alibi_bias) + attn[:, : alibi_bias.size(1)] += alibi_bias + + if padding_mask is not None and padding_mask.any(): + attn = attn.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + + attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2) # + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class EncDecBlock(nn.Module): + def __init__( + self, + q_dim, + kv_dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + mlp_drop=0.0, + post_mlp_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + layer_norm_first=True, + cosine_attention=False, + first_residual=True, + ): + super().__init__() + + self.layer_norm_first = layer_norm_first + + from timm.models.vision_transformer import DropPath, Mlp + + self.norm1 = norm_layer(q_dim) + self.attn = EncDecAttention( + q_dim, + kv_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + cosine_attention=cosine_attention, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(q_dim) + mlp_hidden_dim = int(q_dim * mlp_ratio) + self.mlp = Mlp( + in_features=q_dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=mlp_drop, + ) + self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False) + self.first_residual = first_residual + + def forward(self, q, kv, padding_mask=None, alibi_bias=None): + r = q if self.first_residual else 0 + if self.layer_norm_first: + x = r + self.drop_path( + self.attn(self.norm1(q), kv, padding_mask, alibi_bias) + ) + r = x = self.mlp(self.norm2(x)) + x = r + self.drop_path(self.post_mlp_dropout(x)) + else: + x = r + self.drop_path(self.attn(q, kv, padding_mask, alibi_bias)) + r = x = self.norm1(x) + x = self.mlp(x) + x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x))) + + return x + + +class EncDecTransformerDecoder(nn.Module): + def __init__(self, cfg: D2vDecoderConfig, input_dim): + super().__init__() + + self.input_proj = nn.Linear(input_dim, cfg.decoder_dim) + + self.blocks = nn.Sequential( + *[ + EncDecBlock( + q_dim=cfg.decoder_dim, + kv_dim=input_dim, + num_heads=8, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + mlp_drop=0.0, + post_mlp_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + layer_norm_first=False, + cosine_attention=False, + first_residual=i > 0, + ) + for i in range(cfg.decoder_layers) + ] + ) + + self.proj = nn.Linear(cfg.decoder_dim, input_dim) + + def reset_parameters(self): + from fairseq.modules.transformer_sentence_encoder import init_bert_params + + self.apply(init_bert_params) + + def forward(self, x, kv): + x = self.input_proj(x) + for i, layer in enumerate(self.blocks): + x = layer(x, kv) + + x = self.proj(x) + return x diff --git a/examples/data2vec/models/modalities/text.py b/examples/data2vec/models/modalities/text.py new file mode 100644 index 000000000..adfac1ca4 --- /dev/null +++ b/examples/data2vec/models/modalities/text.py @@ -0,0 +1,161 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
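+#
+# Text modality encoder for data2vec 2.0: a local encoder (token embeddings,
+# optional learned positional embeddings, LayerNorm and dropout) feeds a stack
+# of prenet transformer blocks; an optional Decoder1d convolutional decoder is
+# attached for pretraining.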
+ +import math +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, Optional + +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from fairseq.modules import PositionalEmbedding, FairseqDropout, LayerNorm +from fairseq.tasks import FairseqTask +from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias +from .modules import BlockEncoder, Decoder1d +from examples.data2vec.data.modality import Modality + + +@dataclass +class D2vTextConfig(D2vModalityConfig): + type: Modality = Modality.TEXT + max_source_positions: int = 512 + learned_pos: bool = True + dropout: float = 0.1 # used for both local_encoder and contextualized encoder. tied with global transformer in data2vec_text + + no_scale_embedding: bool = True + layernorm_embedding: bool = True + no_token_positional_embeddings: bool = False + + +class TextEncoder(ModalitySpecificEncoder): + + modality_cfg: D2vTextConfig + + def __init__( + self, + modality_cfg: D2vTextConfig, + embed_dim: int, + make_block: Callable[[float], nn.ModuleList], + norm_layer: Callable[[int], nn.LayerNorm], + layer_norm_first: bool, + alibi_biases: Dict, + task: Optional[FairseqTask], + ): + self.pad_idx = task.source_dictionary.pad() + self.vocab_size = len(task.source_dictionary) + + local_encoder = TextLocalEncoder( + vocab_size=self.vocab_size, + embed_dim=embed_dim, + max_source_positions=modality_cfg.max_source_positions, + pad_idx=self.pad_idx, + no_scale_embedding=modality_cfg.no_scale_embedding, + layernorm_embedding=modality_cfg.layernorm_embedding, + dropout=modality_cfg.dropout, + no_token_positional_embeddings=modality_cfg.no_token_positional_embeddings, + learned_pos=modality_cfg.learned_pos, + ) + dpr = np.linspace( + modality_cfg.start_drop_path_rate, + modality_cfg.end_drop_path_rate, + modality_cfg.prenet_depth, + ) + context_encoder = BlockEncoder( + nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)), + norm_layer(embed_dim) + if not layer_norm_first and modality_cfg.prenet_depth > 0 + else None, + layer_norm_first, + modality_cfg.prenet_layerdrop, + modality_cfg.prenet_dropout if modality_cfg.prenet_depth > 0 else 0.0, + ) + decoder = ( + Decoder1d(modality_cfg.decoder, embed_dim) + if modality_cfg.decoder is not None + else None + ) + + alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases) + + super().__init__( + modality_cfg=modality_cfg, + embed_dim=embed_dim, + local_encoder=local_encoder, + project_features=nn.Identity(), + fixed_positional_encoder=None, + relative_positional_encoder=None, + context_encoder=context_encoder, + decoder=decoder, + get_alibi_bias=alibi_bias_fn, + ) + + def reset_parameters(self): + super().reset_parameters() + + def convert_padding_mask(self, x, padding_mask): + if padding_mask is None or padding_mask.size(1) == x.size(1): + return padding_mask + + diff = self.downsample - padding_mask.size(1) % self.downsample + if 0 < diff < self.downsample: + padding_mask = F.pad(padding_mask, (0, diff), value=True) + + padding_mask = padding_mask.view(padding_mask.size(0), -1, self.downsample) + padding_mask = padding_mask.all(-1) + if padding_mask.size(1) > x.size(1): + padding_mask = padding_mask[:, : x.size(1)] + + assert x.size(1) == padding_mask.size( + 1 + ), f"{x.size(1), padding_mask.size(1), diff, self.downsample}" + + return padding_mask + + +class TextLocalEncoder(nn.Module): + def __init__( + self, + vocab_size, + embed_dim, + max_source_positions, + pad_idx, + no_scale_embedding, + 
layernorm_embedding, + dropout, + no_token_positional_embeddings, + learned_pos, + ): + super().__init__() + self.pad_idx = pad_idx + self.dropout_module = FairseqDropout(dropout) + + self.embed_tokens = nn.Embedding(vocab_size, embed_dim, pad_idx) + self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim) + self.embed_positions = ( + PositionalEmbedding( + max_source_positions, + embed_dim, + pad_idx, + learned=learned_pos, + ) + if not no_token_positional_embeddings + else None + ) + self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim) + + self.layernorm_embedding = None + if layernorm_embedding: + self.layernorm_embedding = LayerNorm(embed_dim) + + def forward(self, src_tokens): + x = self.embed_scale * self.embed_tokens(src_tokens) + if self.embed_positions is not None: + x = x + self.embed_positions(src_tokens) + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + x = self.dropout_module(x) + return x diff --git a/examples/data2vec/models/utils.py b/examples/data2vec/models/utils.py new file mode 100644 index 000000000..0e2f240d4 --- /dev/null +++ b/examples/data2vec/models/utils.py @@ -0,0 +1,55 @@ +import math +import torch + +def get_alibi( + max_positions: int, + attention_heads: int, +): + def get_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio ** i for i in range(n)] + + # In the paper, we only train models that have 2^a heads for some + # a. This function has some good properties that only occur when + # the input is a power of 2. To maintain that even when the number + # of heads is not a power of 2, we use this workaround. + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + maxpos = max_positions + attn_heads = attention_heads + slopes = torch.Tensor(get_slopes(attn_heads)) + # prepare alibi position linear bias. Note that wav2vec2 is non + # autoregressive model so we want a symmetric mask with 0 on the + # diagonal and other wise linear decreasing valuees + pos_bias = ( + torch.abs( + torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1) + ) + * -1 + ) + alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand( + attn_heads, -1, -1 + ) + return alibi_bias + +def masked_alibi(alibi_bias, mask_indices, orig_B, orig_T): + alibi_bias = alibi_bias.view(orig_B, -1, orig_T, orig_T) + H = alibi_bias.size(1) + alibi_mask = mask_indices.unsqueeze(1) + alibi_bias = alibi_bias.masked_select(alibi_mask.unsqueeze(-1)) + alibi_bias = alibi_bias.view(orig_B, H, -1, orig_T) + M = alibi_bias.size(-2) + alibi_bias = alibi_bias.masked_select(alibi_mask.unsqueeze(-2)) + alibi_bias = alibi_bias.view(-1, M, M) + return alibi_bias + + diff --git a/examples/data2vec/scripts/convert_audioset_labels.py b/examples/data2vec/scripts/convert_audioset_labels.py new file mode 100644 index 000000000..7d720e606 --- /dev/null +++ b/examples/data2vec/scripts/convert_audioset_labels.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
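+#
+# Converts an AudioSet segment CSV (clip id, start, end, label ids) into a
+# tab-separated label file aligned line-by-line with a wav2vec-style manifest.
+# Each output line is "start<TAB>end<TAB>comma-separated label indices", with
+# the indices taken from the label descriptor file.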
+ +import argparse +import os + + +def get_parser(): + parser = argparse.ArgumentParser(description="convert audioset labels") + # fmt: off + parser.add_argument('in_file', help='audioset csv file to convert') + parser.add_argument('--manifest', required=True, metavar='PATH', help='wav2vec-like manifest') + parser.add_argument('--descriptors', required=True, metavar='PATH', help='path to label descriptor file') + parser.add_argument('--output', required=True, metavar='PATH', help='where to output converted labels') + # fmt: on + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + label_descriptors = {} + with open(args.descriptors, "r") as ldf: + next(ldf) + for line in ldf: + if line.strip() == "": + continue + + items = line.split(",") + assert len(items) > 2, line + idx = items[0] + lbl = items[1] + assert lbl not in label_descriptors, lbl + label_descriptors[lbl] = idx + + labels = {} + with open(args.in_file, "r") as ifd: + for line in ifd: + if line.lstrip().startswith("#"): + continue + items = line.rstrip().split(",") + id = items[0].strip() + start = items[1].strip() + end = items[2].strip() + lbls = [label_descriptors[it.strip(' "')] for it in items[3:]] + labels[id] = [start, end, ",".join(lbls)] + + with open(args.manifest, "r") as mf, open(args.output, "w") as of: + next(mf) + for line in mf: + path, _ = line.split("\t") + id = os.path.splitext(os.path.basename(path))[0] + lbl = labels[id] + print("\t".join(lbl), file=of) + + +if __name__ == "__main__": + main() diff --git a/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh b/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh new file mode 100755 index 000000000..41bcd31fc --- /dev/null +++ b/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +set -eu + +job_id="$1" +task_id="$2" +dir="$3" + +echo "job_id: $job_id, task_id: $task_id, dir: $dir" + +mkdir -p "$dir/log" +sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1" +sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00" +sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out" +sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err" + +sbatch $sbatch_args examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh $dir + diff --git a/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh b/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh new file mode 100644 index 000000000..fc85908b7 --- /dev/null +++ b/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +set -eu + +dir="$1" + +echo "dir: $dir" + +mkdir -p "$dir/log" +sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1" +sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00" +sbatch_args="$sbatch_args -o $dir/log/decode_sweep_%A.out" +sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err" + +sbatch $sbatch_args examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh $dir + diff --git a/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh b/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh new file mode 100755 index 000000000..121226972 --- /dev/null +++ b/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env zsh + +dir="$1" +cp="$dir/checkpoints/checkpoint_last.pt" + +echo "dir: $dir" + +declare -A tasks 
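+# associative array mapping each GLUE task to its binarized data directory
+# (paths are cluster-specific; adjust them for your environment)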
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin" +tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin" +tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin" +tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin" +tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin" +tasks[mnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MNLI-bin" +tasks[qqp]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QQP-bin" +tasks[sts_b]="/fsx-wav2vec/abaevski/data/nlp/GLUE/STS-B-bin" + +lrs=(5e-6 8e-6 1e-5 2e-5) + +for task data_path in ${(kv)tasks}; do + for lr in $lrs; do + echo $lr $task + PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" \ + python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/multi/text_finetuning \ + --config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \ + model.model_path="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" +model=text_wrap + done +done diff --git a/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh new file mode 100755 index 000000000..18b862c24 --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +set -eu + +job_id="$1" +task_id="$2" +dir="$3" + +echo "job_id: $job_id, task_id: $task_id, dir: $dir" + +mkdir -p "$dir/log" +sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1" +sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00" +sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/ft_%A.out" +sbatch_args="$sbatch_args -e $dir/log/ft_%A.err" + +sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_char_fair_local_lr.sh $dir diff --git a/examples/data2vec/scripts/text/finetune_all_fair.sh b/examples/data2vec/scripts/text/finetune_all_fair.sh new file mode 100755 index 000000000..34a2df399 --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_all_fair.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env zsh + +job_id=$1 +task_id=$2 +dir="$3" +cp="$dir/$task_id/checkpoints/checkpoint_last.pt" + +echo "job_id: $job_id, task_id: $task_id, dir: $dir" + +declare -A tasks +tasks[cola]="/private/home/jgu/data/GLUE/CoLA-bin" +tasks[qnli]="/private/home/jgu/data/GLUE/QNLI-bin" +tasks[mrpc]="/private/home/jgu/data/GLUE/MRPC-bin" +tasks[rte]="/private/home/jgu/data/GLUE/RTE-bin" +tasks[sst_2]="/private/home/jgu/data/GLUE/SST-2-bin" + +for task data_path in ${(kv)tasks}; do + PYTHONPATH=. 
PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \ + --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \ + checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune/$task" & +done diff --git a/examples/data2vec/scripts/text/finetune_all_fair_aws.sh b/examples/data2vec/scripts/text/finetune_all_fair_aws.sh new file mode 100755 index 000000000..b417c2002 --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_all_fair_aws.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env zsh + +job_id=$1 +task_id=$2 +dir="$3" +cp="$dir/checkpoints/checkpoint_last.pt" + +echo "job_id: $job_id, task_id: $task_id, dir: $dir" + +declare -A tasks +tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin" +tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin" +tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin" +tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin" +tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin" + +for task data_path in ${(kv)tasks}; do + PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \ + --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \ + checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune/$task" & +done diff --git a/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh new file mode 100755 index 000000000..64dbcb111 --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +set -eu + +job_id="$1" +task_id="$2" +dir="$3" + +echo "job_id: $job_id, task_id: $task_id, dir: $dir" + +mkdir -p "$dir/log" +sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1" +sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00" +sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out" +sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err" + +sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh $dir diff --git a/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh new file mode 100755 index 000000000..d75c54957 --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env zsh + +job_id=$1 +task_id=$2 +dir="$3" +cp="$dir/checkpoints/checkpoint_last.pt" + +echo "job_id: $job_id, task_id: $task_id, dir: $dir" + +declare -A tasks +tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin" +tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin" +tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin" +tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin" +tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin" + +for task data_path in ${(kv)tasks}; do + for lr in 5e-6 8e-6 1e-5 2e-5 5e-5 8e-5 1e-4 2e-4; do + PYTHONPATH=. 
PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \ + --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \ + checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" & + done +done diff --git a/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh new file mode 100755 index 000000000..8be98c084 --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env zsh + +dir="$1" +cp="$dir/checkpoints/checkpoint_last.pt" + +echo "dir: $dir" + +declare -A tasks +tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin" +tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin" +tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin" +tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin" +tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin" + +lrs=(5e-6 8e-6 1e-5 2e-5) + +for task data_path in ${(kv)tasks}; do + for lr in $lrs; do + echo $lr $task + PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" \ + python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \ + --config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \ + checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" + done +done diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep.sh new file mode 100755 index 000000000..d02bcc0f7 --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env zsh + +dir="$1" +cp="$dir/checkpoints/checkpoint_last.pt" + +echo "dir: $dir" + +declare -A tasks +tasks[cola]="/private/home/jgu/data/GLUE/CoLA-bin" +tasks[qnli]="/private/home/jgu/data/GLUE/QNLI-bin" +tasks[mrpc]="/private/home/jgu/data/GLUE/MRPC-bin" +tasks[rte]="/private/home/jgu/data/GLUE/RTE-bin" +tasks[sst_2]="/private/home/jgu/data/GLUE/SST-2-bin" + +for task data_path in ${(kv)tasks}; do + PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \ + --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \ + checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune/$task" & +done diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws.sh new file mode 100755 index 000000000..75538354e --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env zsh + +dir="$1" +cp="$dir/checkpoints/checkpoint_last.pt" + +echo "dir: $dir" + +declare -A tasks +tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin" +tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin" +tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin" +tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin" +tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin" + +for task data_path in ${(kv)tasks}; do + PYTHONPATH=. 
PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \ + --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \ + checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune/$task" & +done diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_local_lr.sh new file mode 100755 index 000000000..16c1358b2 --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_local_lr.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +set -eu + +dir="$1" + +echo "dir: $dir" + +mkdir -p "$dir/log" +sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1" +sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00" +sbatch_args="$sbatch_args -o $dir/log/decode_sweep_%A.out" +sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err" + +sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh $dir diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr.sh new file mode 100755 index 000000000..fb5ddbe22 --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env zsh + +dir="$1" +cp="$dir/checkpoints/checkpoint_last.pt" + +echo "dir: $dir" + +declare -A tasks +tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin" +tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin" +tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin" +tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin" +tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin" + +for task data_path in ${(kv)tasks}; do + for lr in 5e-6 8e-6 1e-5 2e-5 5e-5 8e-5 1e-4 2e-4; do + PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \ + --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \ + checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" & + done +done diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr_nopos.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr_nopos.sh new file mode 100755 index 000000000..1ffab1c85 --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr_nopos.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env zsh + +dir="$1" +cp="$dir/checkpoints/checkpoint_last.pt" + +echo "dir: $dir" + +declare -A tasks +tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin" +tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin" +tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin" +tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin" +tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin" + +for task data_path in ${(kv)tasks}; do + for lr in 5e-6 8e-6 1e-5 2e-5 5e-5 8e-5 1e-4 2e-4; do + PYTHONPATH=. 
PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \ + --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \ + checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" +model.encoder_learned_pos=False & + done +done diff --git a/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh new file mode 100755 index 000000000..c3c58adcb --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +set -eu + +job_id="$1" +task_id="$2" +dir="$3" + +echo "job_id: $job_id, task_id: $task_id, dir: $dir" + +mkdir -p "$dir/log" +sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1" +sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00" +sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out" +sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err" + +sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh $dir diff --git a/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh new file mode 100644 index 000000000..5efb00e0d --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env zsh + +dir="$1" +cp="$dir/checkpoints/checkpoint_last.pt" + +echo "dir: $dir" + +declare -A tasks +tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin" +tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin" +tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin" +tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin" +tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin" + +lrs=(5e-6 8e-6 1e-5 2e-5) + +for task data_path in ${(kv)tasks}; do + for lr in $lrs; do + echo $lr $task + PYTHONPATH=. 
PREFIX="${PREFIX}" SUFFIX="" \ + python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \ + --config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \ + checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" \ + model._name=roberta_large + done +done diff --git a/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh new file mode 100755 index 000000000..4fb21bce7 --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +set -eu + +dir="$1" + +echo "dir: $dir" + +mkdir -p "$dir/log" +sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1" +sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00" +sbatch_args="$sbatch_args -o $dir/log/decode_sweep_%A.out" +sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err" + +sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh $dir diff --git a/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh b/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh new file mode 100755 index 000000000..d7b43bee8 --- /dev/null +++ b/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env zsh + +dir="$1" +cp="$dir/checkpoints/checkpoint_last.pt" + +echo "dir: $dir" + +declare -A tasks +tasks[qnli]="/private/home/jgu/data/GLUE/QNLI-bin" +tasks[sst_2]="/private/home/jgu/data/GLUE/SST-2-bin" + +lrs="5e-6 1e-5 2e-5 5e-5 1e-4 2e-4 5e-4 1e-3" + +for task data_path in ${(kv)tasks}; do + for lr in $(echo "$lrs"); do + PYTHONPATH=. 
PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \ + --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \ + checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_sweep/$task/lr_$lr" "optimization.lr=[${lr}]" & + done +done diff --git a/examples/data2vec/scripts/text/glue.py b/examples/data2vec/scripts/text/glue.py new file mode 100644 index 000000000..5382d3183 --- /dev/null +++ b/examples/data2vec/scripts/text/glue.py @@ -0,0 +1,34 @@ +from valids import parser, main as valids_main +import os.path as osp + + +args = parser.parse_args() +args.target = "valid_accuracy" +args.best_biggest = True +args.best = True +args.last = 0 +args.path_contains = None + +res = valids_main(args, print_output=False) + +grouped = {} +for k, v in res.items(): + k = osp.dirname(k) + run = osp.dirname(k) + task = osp.basename(k) + val = v["valid_accuracy"] + + if run not in grouped: + grouped[run] = {} + + grouped[run][task] = val + +for run, tasks in grouped.items(): + print(run) + avg = sum(float(v) for v in tasks.values()) / len(tasks) + avg_norte = sum(float(v) for k,v in tasks.items() if k != 'rte') / (len(tasks) -1) + try: + print(f"{tasks['cola']}\t{tasks['qnli']}\t{tasks['mrpc']}\t{tasks['rte']}\t{tasks['sst_2']}\t{avg:.2f}\t{avg_norte:.2f}") + except: + print(tasks) + print() diff --git a/examples/data2vec/scripts/text/glue_lr.py b/examples/data2vec/scripts/text/glue_lr.py new file mode 100644 index 000000000..75bdfe036 --- /dev/null +++ b/examples/data2vec/scripts/text/glue_lr.py @@ -0,0 +1,143 @@ +import os.path as osp +import re +from collections import defaultdict + +from valids import parser, main as valids_main + + +TASK_TO_METRIC = { + "cola": "mcc", + "qnli": "accuracy", + "mrpc": "acc_and_f1", + "rte": "accuracy", + "sst_2": "accuracy", + "mnli": "accuracy", + "qqp": "acc_and_f1", + "sts_b": "pearson_and_spearman", +} +TASKS = ["cola", "qnli", "mrpc", "rte", "sst_2", "mnli", "qqp", "sts_b"] + + +def get_best_stat_str(task_vals, show_subdir): + task_to_best_val = {} + task_to_best_dir = {} + for task, subdir_to_val in task_vals.items(): + task_to_best_val[task] = max(subdir_to_val.values()) + task_to_best_dir[task] = max(subdir_to_val.keys(), key=lambda x: subdir_to_val[x]) + + # import pdb; pdb.set_trace() + N1 = len(task_to_best_val) + N2 = len([k for k in task_to_best_val if k != "rte"]) + avg1 = sum(task_to_best_val.values()) / N1 + avg2 = sum(v for task, v in task_to_best_val.items() if task != "rte") / N2 + + try: + msg = "" + for task in TASKS: + dir = task_to_best_dir.get(task, 'null') + val = task_to_best_val.get(task, -100) + msg += f"({dir}, {val})\t" if show_subdir else f"{val}\t" + msg += f"{avg1:.2f}\t{avg2:.2f}" + except Exception as e: + msg = str(e) + msg += str(sorted(task_vals.items())) + return msg + +def get_all_stat_str(task_vals): + msg = "" + for task in [task for task in TASKS if task in task_vals]: + msg += f"=== {task}\n" + for subdir in sorted(task_vals[task].keys()): + msg += f"\t{subdir}\t{task_vals[task][subdir]}\n" + return msg + +def get_tabular_stat_str(task_vals): + """assume subdir is /run_*/0""" + msg = "" + for task in [task for task in TASKS if task in task_vals]: + msg += f"=== {task}\n" + param_to_runs = defaultdict(dict) + for subdir in task_vals[task]: + match = re.match("(.*)/(run_.*)/0", subdir) + assert match, "subdir" + param, run = match.groups() + param_to_runs[param][run] = 
task_vals[task][subdir] + params = sorted(param_to_runs, key=lambda x: float(x)) + runs = sorted(set(run for runs in param_to_runs.values() for run in runs)) + msg += ("runs:" + "\t".join(runs) + "\n") + msg += ("params:" + "\t".join(params) + "\n") + for param in params: + msg += "\t".join([str(param_to_runs[param].get(run, None)) for run in runs]) + msg += "\n" + # for subdir in sorted(task_vals[task].keys()): + # msg += f"\t{subdir}\t{task_vals[task][subdir]}\n" + return msg + + + +def main(): + parser.add_argument("--show_glue", action="store_true", help="show glue metric for each task instead of accuracy") + parser.add_argument("--print_mode", default="best", help="best|all|tabular") + parser.add_argument("--show_subdir", action="store_true", help="print the subdir that has the best results for each run") + parser.add_argument("--override_target", default="valid_accuracy", help="override target") + + args = parser.parse_args() + args.target = args.override_target + args.best_biggest = True + args.best = True + args.last = 0 + args.path_contains = None + + res = valids_main(args, print_output=False) + grouped_acc = {} + grouped_met = {} # use official metric for each task + for path, v in res.items(): + path = "/".join([args.base, path]) + path = re.sub("//*", "/", path) + match = re.match("(.*)finetune[^/]*/([^/]*)/(.*)", path) + if not match: + continue + run, task, subdir = match.groups() + + if run not in grouped_acc: + grouped_acc[run] = {} + grouped_met[run] = {} + if task not in grouped_acc[run]: + grouped_acc[run][task] = {} + grouped_met[run][task] = {} + + if v is not None: + grouped_acc[run][task][subdir] = float(v.get("valid_accuracy", -100)) + grouped_met[run][task][subdir] = float(v.get(f"valid_{TASK_TO_METRIC[task]}", -100)) + else: + print(f"{path} has None return") + + header = "\t".join(TASKS) + for run in sorted(grouped_acc): + print(run) + if args.print_mode == "all": + if args.show_glue: + print("===== GLUE =====") + print(get_all_stat_str(grouped_met[run])) + else: + print("===== ACC =====") + print(get_all_stat_str(grouped_acc[run])) + elif args.print_mode == "best": + print(f" {header}") + if args.show_glue: + print(f"GLEU: {get_best_stat_str(grouped_met[run], args.show_subdir)}") + else: + print(f"ACC: {get_best_stat_str(grouped_acc[run], args.show_subdir)}") + elif args.print_mode == "tabular": + if args.show_glue: + print("===== GLUE =====") + print(get_tabular_stat_str(grouped_met[run])) + else: + print("===== ACC =====") + print(get_tabular_stat_str(grouped_acc[run])) + else: + raise ValueError(args.print_mode) + print() + +if __name__ == "__main__": + main() diff --git a/examples/data2vec/scripts/text/unprocess_data.py b/examples/data2vec/scripts/text/unprocess_data.py new file mode 100644 index 000000000..f1acb624b --- /dev/null +++ b/examples/data2vec/scripts/text/unprocess_data.py @@ -0,0 +1,188 @@ +import json +import os +import tqdm +from fairseq.data import Dictionary, data_utils + + +def load_dictionary(dict_path): + return Dictionary.load(dict_path) + +def load_dataset(split_path, src_dict): + dataset = data_utils.load_indexed_dataset( + split_path, + src_dict, + combine=False, # set to true for loading `train*` + ) + if dataset is None: + raise FileNotFoundError(f"Dataset not found: {split_path}") + return dataset + +def load_bpe(enc_path): + with open(enc_path) as f: + bpe2idx = json.load(f) + idx2bpe = {v: k for k, v in bpe2idx.items()} + return bpe2idx, idx2bpe + +def detokenize(tokens, src_dict, idx2bpe): + raw_inds = map(int, 
src_dict.string(tokens).split()) + raw_chrs = "".join([idx2bpe[raw_ind] for raw_ind in raw_inds]) + raw_chrs = raw_chrs.replace("\u0120", " ") + return raw_chrs + +def _main(src_root, src_dict_path, src_bpe_path, src_splits, tgt_root, tgt_splits): + src_dict = load_dictionary(src_dict_path) + bpe2idx, idx2bpe = load_bpe(src_bpe_path) + + assert len(src_splits) == len(tgt_splits) + for src_split, tgt_split in zip(src_splits, tgt_splits): + src_dataset = load_dataset(f"{src_root}/{src_split}", src_dict) + tgt_path = f"{tgt_root}/{tgt_split}.txt" + print(f"processing {src_split} (dump to {tgt_path})...") + os.makedirs(os.path.dirname(tgt_path), exist_ok=True) + with open(tgt_path, "w") as f: + for tokens in tqdm.tqdm(src_dataset): + raw_str = detokenize(tokens, src_dict, idx2bpe) + f.write(raw_str + "\n") + +def main_pt(): + src_root = "/datasets01/bookwiki_CC-NEWS_openwebtext_stories-mmap2-bin/121219/bookwiki_CC-NEWS_openwebtext_stories-mmap2-bin" + src_dict_path = f"{src_root}/dict.txt" + src_bpe_path = f"{src_root}/encoder.json" + src_splits = [ + "bookwiki_aml-mmap2-bin/shard0/train", + "bookwiki_aml-mmap2-bin/shard1/train", + "bookwiki_aml-mmap2-bin/shard2/train", + "bookwiki_aml-mmap2-bin/shard3/train", + "bookwiki_aml-mmap2-bin/shard4/train", + "bookwiki_aml-mmap2-bin/valid/valid", + ] + + tgt_root = "/checkpoint/wnhsu/data/data2vec2/data/text/bookwiki_aml-full-mmap2-txt" + tgt_splits = [ + "train0", + "train1", + "train2", + "train3", + "train4", + "valid", + ] + _main(src_root, src_dict_path, src_bpe_path, src_splits, tgt_root, tgt_splits) + +def main_ft(): + src_root = "/fsx-wav2vec/wnhsu/data/data2vec2/data/text/GLUE" + src_dict_path = f"{src_root}/dict.txt" + src_bpe_path = f"{src_root}/encoder.json" + src_splits = [ + "CoLA-bin/input0/train", + "CoLA-bin/input0/valid", + "CoLA-bin/input0/test", + + "MNLI-bin/input0/train", + "MNLI-bin/input0/valid", + "MNLI-bin/input0/test", + "MNLI-bin/input0/test1", + "MNLI-bin/input1/train", + "MNLI-bin/input1/valid", + "MNLI-bin/input1/test", + "MNLI-bin/input1/test1", + + "MRPC-bin/input0/train", + "MRPC-bin/input0/valid", + "MRPC-bin/input0/test", + "MRPC-bin/input1/train", + "MRPC-bin/input1/valid", + "MRPC-bin/input1/test", + + "QNLI-bin/input0/train", + "QNLI-bin/input0/valid", + "QNLI-bin/input0/test", + "QNLI-bin/input1/train", + "QNLI-bin/input1/valid", + "QNLI-bin/input1/test", + + "QQP-bin/input0/train", + "QQP-bin/input0/valid", + "QQP-bin/input0/test", + "QQP-bin/input1/train", + "QQP-bin/input1/valid", + "QQP-bin/input1/test", + + "RTE-bin/input0/train", + "RTE-bin/input0/valid", + "RTE-bin/input0/test", + "RTE-bin/input1/train", + "RTE-bin/input1/valid", + "RTE-bin/input1/test", + + "SST-2-bin/input0/train", + "SST-2-bin/input0/valid", + "SST-2-bin/input0/test", + + "STS-B-bin/input0/train", + "STS-B-bin/input0/valid", + "STS-B-bin/input0/test", + "STS-B-bin/input1/train", + "STS-B-bin/input1/valid", + "STS-B-bin/input1/test", + ] + + tgt_root = "/fsx-wav2vec/wnhsu/data/data2vec2/data/text/GLUE_chr" + tgt_splits = [ + "CoLA-bin/input0/train", + "CoLA-bin/input0/valid", + "CoLA-bin/input0/test", + + "MNLI-bin/input0/train", + "MNLI-bin/input0/valid", + "MNLI-bin/input0/test", + "MNLI-bin/input0/test1", + "MNLI-bin/input1/train", + "MNLI-bin/input1/valid", + "MNLI-bin/input1/test", + "MNLI-bin/input1/test1", + + "MRPC-bin/input0/train", + "MRPC-bin/input0/valid", + "MRPC-bin/input0/test", + "MRPC-bin/input1/train", + "MRPC-bin/input1/valid", + "MRPC-bin/input1/test", + + "QNLI-bin/input0/train", + "QNLI-bin/input0/valid", + 
"QNLI-bin/input0/test", + "QNLI-bin/input1/train", + "QNLI-bin/input1/valid", + "QNLI-bin/input1/test", + + "QQP-bin/input0/train", + "QQP-bin/input0/valid", + "QQP-bin/input0/test", + "QQP-bin/input1/train", + "QQP-bin/input1/valid", + "QQP-bin/input1/test", + + "RTE-bin/input0/train", + "RTE-bin/input0/valid", + "RTE-bin/input0/test", + "RTE-bin/input1/train", + "RTE-bin/input1/valid", + "RTE-bin/input1/test", + + "SST-2-bin/input0/train", + "SST-2-bin/input0/valid", + "SST-2-bin/input0/test", + + "STS-B-bin/input0/train", + "STS-B-bin/input0/valid", + "STS-B-bin/input0/test", + "STS-B-bin/input1/train", + "STS-B-bin/input1/valid", + "STS-B-bin/input1/test", + ] + _main(src_root, src_dict_path, src_bpe_path, src_splits, tgt_root, tgt_splits) + + +if __name__ == "__main__": + main_pt() + main_ft() diff --git a/examples/data2vec/scripts/text/valids.py b/examples/data2vec/scripts/text/valids.py new file mode 100644 index 000000000..b2e5cfb25 --- /dev/null +++ b/examples/data2vec/scripts/text/valids.py @@ -0,0 +1,301 @@ +import os, argparse, re, json, copy, math +from collections import OrderedDict +import numpy as np + +parser = argparse.ArgumentParser(description='Process some integers.') +parser.add_argument('base', help='base log path') +parser.add_argument('--file_name', default='train.log', help='the log file name') +parser.add_argument('--target', default='valid_loss', help='target metric') +parser.add_argument('--last', type=int, default=999999999, help='print last n matches') +parser.add_argument('--last_files', type=int, default=None, help='print last x files') +parser.add_argument('--everything', action='store_true', help='print everything instead of only last match') +parser.add_argument('--path_contains', help='only consider matching file pattern') +parser.add_argument('--group_on', help='if set, groups by this metric and shows table of differences') +parser.add_argument('--epoch', help='epoch for comparison', type=int) +parser.add_argument('--skip_empty', action='store_true', help='skip empty results') +parser.add_argument('--skip_containing', help='skips entries containing this attribute') +parser.add_argument('--unique_epochs', action='store_true', help='only consider the last line fore each epoch') +parser.add_argument('--best', action='store_true', help='print the last best result') +parser.add_argument('--avg_params', help='average these params through entire log') +parser.add_argument('--extract_prev', help='extracts this metric from previous line') + +parser.add_argument('--remove_metric', help='extracts this metric from previous line') + +parser.add_argument('--compact', action='store_true', help='if true, just prints checkpoint best val') +parser.add_argument('--hydra', action='store_true', help='if true, uses hydra param conventions') + +parser.add_argument('--best_biggest', action='store_true', help='if true, best is the biggest number, not smallest') +parser.add_argument('--key_len', type=int, default=10, help='max length of key') + +parser.add_argument('--best_only', action='store_true', help='if set, only prints the best value') +parser.add_argument('--flat', action='store_true', help='just print the best results') + + +def main(args, print_output): + ret = {} + + entries = [] + + def extract_metric(s, metric): + try: + j = json.loads(s) + except: + return None + if args.epoch is not None and ('epoch' not in j or j['epoch'] != args.epoch): + return None + return j[metric] if metric in j else None + + + def extract_params(s): + s = s.replace(args.base, '', 1) + if 
args.path_contains is not None: + s = s.replace(args.path_contains, '', 1) + + if args.hydra: + num_matches = re.findall(r'(?:/|__)([^/:]+):(\d+\.?\d*)', s) + # str_matches = re.findall(r'(?:/|__)([^/:]+):([^\.]*[^\d\.]+)(?:/|__)', s) + str_matches = re.findall(r'(?:/|__)?((?:(?!(?:\:|__)).)+):([^\.]*[^\d\.]+\d*)(?:/|__)', s) + lr_matches = re.findall(r'optimization.(lr):\[([\d\.,]+)\]', s) + task_matches = re.findall(r'.*/(\d+)$', s) + else: + num_matches = re.findall(r'\.?([^\.]+?)(\d+(e\-\d+)?(?:\.\d+)?)(\.|$)', s) + str_matches = re.findall(r'[/\.]([^\.]*[^\d\.]+\d*)(?=\.)', s) + lr_matches = [] + task_matches = [] + + cp_matches = re.findall(r'checkpoint(?:_\d+)?_(\d+).pt', s) + + items = OrderedDict() + for m in str_matches: + if isinstance(m, tuple): + if 'checkpoint' not in m[0]: + items[m[0]] = m[1] + else: + items[m] = '' + + for m in num_matches: + items[m[0]] = m[1] + + for m in lr_matches: + items[m[0]] = m[1] + + for m in task_matches: + items["hydra_task"] = m + + for m in cp_matches: + items['checkpoint'] = m + + return items + + abs_best = None + + sources = [] + for root, _, files in os.walk(args.base): + if args.path_contains is not None and not args.path_contains in root: + continue + for f in files: + if f.endswith(args.file_name): + sources.append((root, f)) + + if args.last_files is not None: + sources = sources[-args.last_files:] + + for root, file in sources: + with open(os.path.join(root, file), 'r') as fin: + found = [] + avg = {} + prev = None + for line in fin: + line = line.rstrip() + if line.find(args.target) != -1 and ( + args.skip_containing is None or line.find(args.skip_containing) == -1): + try: + idx = line.index("{") + line = line[idx:] + line_json = json.loads(line) + except: + continue + if prev is not None: + try: + prev.update(line_json) + line_json = prev + except: + pass + if args.target in line_json: + found.append(line_json) + if args.avg_params: + avg_params = args.avg_params.split(',') + for p in avg_params: + m = extract_metric(line, p) + if m is not None: + prev_v, prev_c = avg.get(p, (0, 0)) + avg[p] = prev_v + float(m), prev_c + 1 + if args.extract_prev: + try: + prev = json.loads(line) + except: + pass + best = None + if args.best: + curr_best = None + for i in range(len(found)): + cand_best = found[i][args.target] if args.target in found[i] else None + + def cmp(a, b): + a = float(a) + b = float(b) + if args.best_biggest: + return a > b + return a < b + + if cand_best is not None and not math.isnan(float(cand_best)) and ( + curr_best is None or cmp(cand_best, curr_best)): + curr_best = cand_best + if abs_best is None or cmp(curr_best, abs_best): + abs_best = curr_best + best = found[i] + if args.unique_epochs or args.epoch: + last_found = [] + last_epoch = None + for i in reversed(range(len(found))): + epoch = found[i]['epoch'] + if args.epoch and args.epoch != epoch: + continue + if epoch != last_epoch: + last_epoch = epoch + last_found.append(found[i]) + found = list(reversed(last_found)) + + if len(found) == 0: + if print_output and (args.last_files is not None or not args.skip_empty): + # print(root.split('/')[-1]) + print(root[len(args.base):]) + print('Nothing') + else: + if not print_output: + ret[root[len(args.base):]] = best + continue + + if args.compact: + # print('{}\t{}'.format(root.split('/')[-1], curr_best)) + print('{}\t{}'.format(root[len(args.base)+1:], curr_best)) + continue + + if args.group_on is None and not args.best_only: + # print(root.split('/')[-1]) + print(root[len(args.base):]) + if not args.everything: + if 
best is not None and args.group_on is None and not args.best_only and not args.flat: + print(best, '(best)') + if args.group_on is None and args.last and not args.best_only and not args.flat: + for f in found[-args.last:]: + if args.extract_prev is not None: + try: + print('{}\t{}'.format(f[args.extract_prev], f[args.target])) + except Exception as e: + print('Exception!', e) + else: + print(f) + try: + metric = found[-1][args.target] if not args.best or best is None else best[args.target] + except: + print(found[-1]) + raise + if metric is not None: + entries.append((extract_params(root), metric)) + else: + for f in found: + print(f) + if not args.group_on and print_output: + print() + + if len(avg) > 0: + for k, (v, c) in avg.items(): + print(f'{k}: {v/c}') + + if args.best_only: + print(abs_best) + + if args.flat: + print("\t".join(m for _, m in entries)) + + if args.group_on is not None: + by_val = OrderedDict() + for e, m in entries: + k = args.group_on + if k not in e: + m_keys = [x for x in e.keys() if x.startswith(k)] + if len(m_keys) == 0: + val = "False" + else: + assert len(m_keys) == 1 + k = m_keys[0] + val = m_keys[0] + else: + val = e[args.group_on] + if val == "": + val = "True" + scrubbed_entry = copy.deepcopy(e) + if k in scrubbed_entry: + del scrubbed_entry[k] + if args.remove_metric and args.remove_metric in scrubbed_entry: + val += '_' + scrubbed_entry[args.remove_metric] + del scrubbed_entry[args.remove_metric] + by_val.setdefault(tuple(scrubbed_entry.items()), dict())[val] = m + distinct_vals = set() + for v in by_val.values(): + distinct_vals.update(v.keys()) + try: + distinct_vals = {int(d) for d in distinct_vals} + except: + print(distinct_vals) + print() + print("by_val", len(by_val)) + for k,v in by_val.items(): + print(k, '=>', v) + print() + + # , by_val, entries) + raise + from natsort import natsorted + svals = list(map(str, natsorted(distinct_vals))) + print('{}\t{}'.format(args.group_on, '\t'.join(svals))) + sums = OrderedDict({n:[] for n in svals}) + for k, v in by_val.items(): + kstr = '.'.join(':'.join(x) for x in k) + vstr = '' + for mv in svals: + x = v[mv] if mv in v else '' + vstr += '\t{}'.format(round(x, 5) if isinstance(x, float) else x) + try: + sums[mv].append(float(x)) + except: + pass + print('{}{}'.format(kstr[:args.key_len], vstr)) + if any(len(x) > 0 for x in sums.values()): + print('min:', end='') + for v in sums.values(): + min = np.min(v) + print(f'\t{round(min, 5)}', end='') + print() + print('max:', end='') + for v in sums.values(): + max = np.max(v) + print(f'\t{round(max, 5)}', end='') + print() + print('avg:', end='') + for v in sums.values(): + mean = np.mean(v) + print(f'\t{round(mean, 5)}', end='') + print() + print('median:', end='') + for v in sums.values(): + median = np.median(v) + print(f'\t{round(median, 5)}', end='') + print() + + return ret + +if __name__ == "__main__": + args = parser.parse_args() + main(args, print_output=True) \ No newline at end of file diff --git a/examples/data2vec/tasks/__init__.py b/examples/data2vec/tasks/__init__.py new file mode 100644 index 000000000..a7422e4b3 --- /dev/null +++ b/examples/data2vec/tasks/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
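+
+# Importing this package registers the data2vec image tasks with fairseq and
+# re-exports their Task/Config classes (listed in __all__ below).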
+ +from .image_pretraining import ImagePretrainingTask, ImagePretrainingConfig +from .image_classification import ImageClassificationTask, ImageClassificationConfig +from .mae_image_pretraining import MaeImagePretrainingTask, MaeImagePretrainingConfig + + +__all__ = [ + "ImageClassificationTask", + "ImageClassificationConfig", + "ImagePretrainingTask", + "ImagePretrainingConfig", + "MaeImagePretrainingTask", + "MaeImagePretrainingConfig", +] \ No newline at end of file diff --git a/examples/data2vec/tasks/audio_classification.py b/examples/data2vec/tasks/audio_classification.py new file mode 100644 index 000000000..2925a04cf --- /dev/null +++ b/examples/data2vec/tasks/audio_classification.py @@ -0,0 +1,167 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import os +import numpy as np +import math +import torch + +from sklearn import metrics as sklearn_metrics +from dataclasses import dataclass + +from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig +from fairseq.tasks import register_task +from fairseq.logging import metrics + +from ..data.add_class_target_dataset import AddClassTargetDataset + + +logger = logging.getLogger(__name__) + + +@dataclass +class AudioClassificationConfig(AudioPretrainingConfig): + label_descriptors: str = "label_descriptors.csv" + labels: str = "lbl" + + +@register_task("audio_classification", dataclass=AudioClassificationConfig) +class AudioClassificationTask(AudioPretrainingTask): + """ """ + + cfg: AudioClassificationConfig + + def __init__( + self, + cfg: AudioClassificationConfig, + ): + super().__init__(cfg) + + self.state.add_factory("labels", self.load_labels) + + def load_labels(self): + labels = {} + path = os.path.join(self.cfg.data, self.cfg.label_descriptors) + with open(path, "r") as ldf: + for line in ldf: + if line.strip() == "": + continue + items = line.split(",") + idx = items[0] + lbl = items[1] + assert lbl not in labels, lbl + labels[lbl] = idx + return labels + + @property + def labels(self): + return self.state.labels + + def load_dataset( + self, split: str, task_cfg: AudioClassificationConfig = None, **kwargs + ): + super().load_dataset(split, task_cfg, **kwargs) + + task_cfg = task_cfg or self.cfg + + data_path = self.cfg.data + label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") + skipped_indices = getattr(self.datasets[split], "skipped_indices", set()) + labels = [] + with open(label_path, "r") as f: + for i, line in enumerate(f): + if i not in skipped_indices: + lbl_items = line.rstrip().split("\t") + labels.append([int(x) for x in lbl_items[2].split(",")]) + + assert len(labels) == len(self.datasets[split]), ( + f"labels length ({len(labels)}) and dataset length " + f"({len(self.datasets[split])}) do not match" + ) + + self.datasets[split] = AddClassTargetDataset( + self.datasets[split], + labels, + multi_class=True, + add_to_input=True, + num_classes=len(self.labels), + ) + + def calculate_stats(self, output, target): + + classes_num = target.shape[-1] + stats = [] + + # Accuracy, only used for single-label classification such as esc-50, not for multiple label one such as AudioSet + # acc = sklearn_metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1)) + + # Class-wise statistics + for k in range(classes_num): 
+ # Average precision + avg_precision = sklearn_metrics.average_precision_score( + target[:, k], output[:, k], average=None + ) + + dict = { + "AP": avg_precision, + } + + # # AUC + # try: + # auc = sklearn_metrics.roc_auc_score(target[:, k], output[:, k], average=None) + # except: + # auc = 0 + # + # # Precisions, recalls + # (precisions, recalls, thresholds) = sklearn_metrics.precision_recall_curve( + # target[:, k], output[:, k] + # ) + # + # # FPR, TPR + # (fpr, tpr, thresholds) = sklearn_metrics.roc_curve(target[:, k], output[:, k]) + # + # save_every_steps = 1000 # Sample statistics to reduce size + # dict = { + # "precisions": precisions[0::save_every_steps], + # "recalls": recalls[0::save_every_steps], + # "AP": avg_precision, + # "fpr": fpr[0::save_every_steps], + # "fnr": 1.0 - tpr[0::save_every_steps], + # "auc": auc, + # # note acc is not class-wise, this is just to keep consistent with other metrics + # "acc": acc, + # } + stats.append(dict) + + return stats + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + return loss, sample_size, logging_output + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + if "_predictions" in logging_outputs[0]: + metrics.log_concat_tensor( + "_predictions", + torch.cat([l["_predictions"].cpu() for l in logging_outputs], dim=0), + ) + metrics.log_concat_tensor( + "_targets", + torch.cat([l["_targets"].cpu() for l in logging_outputs], dim=0), + ) + + def compute_stats(meters): + if meters["_predictions"].tensor.shape[0] < 100: + return 0 + stats = self.calculate_stats( + meters["_predictions"].tensor, meters["_targets"].tensor + ) + return np.nanmean([stat["AP"] for stat in stats]) + + metrics.log_derived("mAP", compute_stats) diff --git a/examples/data2vec/tasks/image_classification.py b/examples/data2vec/tasks/image_classification.py new file mode 100644 index 000000000..1ea4c2afe --- /dev/null +++ b/examples/data2vec/tasks/image_classification.py @@ -0,0 +1,129 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
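+
+# Supervised image classification fine-tuning task. Extends
+# ImagePretrainingTask: training data uses timm's create_transform with
+# RandAugment and random erasing, evaluation uses resize + center crop, and
+# top-1 accuracy is reported as a derived metric from the "correct" counts
+# logged by the criterion.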
+ +import os.path as osp +import logging + +from dataclasses import dataclass +import torch +from torchvision import transforms + +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import register_task +from fairseq.logging import metrics + +try: + from ..data import ImageDataset +except: + import sys + + sys.path.append("..") + from data import ImageDataset + +from .image_pretraining import ( + ImagePretrainingConfig, + ImagePretrainingTask, + IMG_EXTENSIONS, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class ImageClassificationConfig(ImagePretrainingConfig): + pass + + +@register_task("image_classification", dataclass=ImageClassificationConfig) +class ImageClassificationTask(ImagePretrainingTask): + + cfg: ImageClassificationConfig + + @classmethod + def setup_task(cls, cfg: ImageClassificationConfig, **kwargs): + return cls(cfg) + + def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + data_path = self.cfg.data + cfg = task_cfg or self.cfg + + path_with_split = osp.join(data_path, split) + if osp.exists(path_with_split): + data_path = path_with_split + + from timm.data import create_transform + + if split == "train": + # this should always dispatch to transforms_imagenet_train + transform = create_transform( + input_size=cfg.input_size, + is_training=True, + auto_augment="rand-m9-mstd0.5-inc1", + interpolation="bicubic", + re_prob=0.25, + re_mode="pixel", + re_count=1, + mean=cfg.normalization_mean, + std=cfg.normalization_std, + ) + if not cfg.input_size > 32: + transform.transforms[0] = transforms.RandomCrop( + cfg.input_size, padding=4 + ) + else: + t = [] + if cfg.input_size > 32: + crop_pct = 1 + if cfg.input_size < 384: + crop_pct = 224 / 256 + size = int(cfg.input_size / crop_pct) + t.append( + transforms.Resize( + size, interpolation=3 + ), # to maintain same ratio w.r.t. 224 images + ) + t.append(transforms.CenterCrop(cfg.input_size)) + + t.append(transforms.ToTensor()) + t.append( + transforms.Normalize(cfg.normalization_mean, cfg.normalization_std) + ) + transform = transforms.Compose(t) + logger.info(transform) + + self.datasets[split] = ImageDataset( + root=data_path, + extensions=IMG_EXTENSIONS, + load_classes=True, + transform=transform, + ) + for k in self.datasets.keys(): + if k != split: + assert self.datasets[k].classes == self.datasets[split].classes + + def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False): + model = super().build_model(model_cfg, from_checkpoint) + + actualized_cfg = getattr(model, "cfg", None) + if actualized_cfg is not None: + if hasattr(actualized_cfg, "pretrained_model_args"): + model_cfg.pretrained_model_args = actualized_cfg.pretrained_model_args + + return model + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + if "correct" in logging_outputs[0]: + zero = torch.scalar_tensor(0.0) + correct = sum(log.get("correct", zero) for log in logging_outputs) + metrics.log_scalar_sum("_correct", correct) + + metrics.log_derived( + "accuracy", + lambda meters: 100 * meters["_correct"].sum / meters["sample_size"].sum, + ) diff --git a/examples/data2vec/tasks/image_pretraining.py b/examples/data2vec/tasks/image_pretraining.py new file mode 100644 index 000000000..cd688fd13 --- /dev/null +++ b/examples/data2vec/tasks/image_pretraining.py @@ -0,0 +1,110 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. 
+# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import sys +import os.path as osp + +from dataclasses import dataclass, field +from typing import List +from omegaconf import MISSING + +import torch +from torchvision import transforms + +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task + +try: + from ..data import ImageDataset +except: + sys.path.append("..") + from data import ImageDataset + +logger = logging.getLogger(__name__) + +IMG_EXTENSIONS = { + ".jpg", + ".jpeg", + ".png", + ".ppm", + ".bmp", + ".pgm", + ".tif", + ".tiff", + ".webp", +} + + +@dataclass +class ImagePretrainingConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + input_size: int = 224 + normalization_mean: List[float] = (0.485, 0.456, 0.406) + normalization_std: List[float] = (0.229, 0.224, 0.225) + + +@register_task("image_pretraining", dataclass=ImagePretrainingConfig) +class ImagePretrainingTask(FairseqTask): + """ """ + + cfg: ImagePretrainingConfig + + @classmethod + def setup_task(cls, cfg: ImagePretrainingConfig, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + cfg (AudioPretrainingConfig): configuration of this task + """ + + return cls(cfg) + + def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + data_path = self.cfg.data + cfg = task_cfg or self.cfg + + path_with_split = osp.join(data_path, split) + if osp.exists(path_with_split): + data_path = path_with_split + + transform = transforms.Compose( + [ + transforms.ColorJitter(0.4, 0.4, 0.4), + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomResizedCrop( + size=cfg.input_size, + interpolation=transforms.InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(cfg.normalization_mean), + std=torch.tensor(cfg.normalization_std), + ), + ] + ) + + logger.info(transform) + + self.datasets[split] = ImageDataset( + root=data_path, + extensions=IMG_EXTENSIONS, + load_classes=False, + transform=transform, + ) + + @property + def source_dictionary(self): + return None + + @property + def target_dictionary(self): + return None + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return sys.maxsize, sys.maxsize diff --git a/examples/data2vec/tasks/mae_image_classification.py b/examples/data2vec/tasks/mae_image_classification.py new file mode 100644 index 000000000..1bf935879 --- /dev/null +++ b/examples/data2vec/tasks/mae_image_classification.py @@ -0,0 +1,100 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
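+
+# Fine-tuning task for image classification with MAE-style models. Each split
+# is loaded as a MaeFinetuningImageDataset, and accuracy is derived from the
+# "correct" counts in the logging outputs.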
+ +import logging +import sys +import torch + +from typing import Optional +from dataclasses import dataclass, field +from omegaconf import MISSING + +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task +from fairseq.logging import metrics + +try: + from ..data import MaeFinetuningImageDataset +except: + sys.path.append("..") + from data import MaeFinetuningImageDataset + +logger = logging.getLogger(__name__) + + +@dataclass +class MaeImageClassificationConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + input_size: int = 224 + local_cache_path: Optional[str] = None + + rebuild_batches: bool = True + + +@register_task("mae_image_classification", dataclass=MaeImageClassificationConfig) +class MaeImageClassificationTask(FairseqTask): + """ """ + + cfg: MaeImageClassificationConfig + + @classmethod + def setup_task(cls, cfg: MaeImageClassificationConfig, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + cfg (AudioPretrainingConfig): configuration of this task + """ + + return cls(cfg) + + def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + data_path = self.cfg.data + cfg = task_cfg or self.cfg + + self.datasets[split] = MaeFinetuningImageDataset( + root=data_path, + split=split, + is_train=split == "train", + input_size=cfg.input_size, + local_cache_path=cfg.local_cache_path, + shuffle=split == "train", + ) + + def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False): + model = super().build_model(model_cfg, from_checkpoint) + + actualized_cfg = getattr(model, "cfg", None) + if actualized_cfg is not None: + if hasattr(actualized_cfg, "pretrained_model_args"): + model_cfg.pretrained_model_args = actualized_cfg.pretrained_model_args + + return model + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + if "correct" in logging_outputs[0]: + zero = torch.scalar_tensor(0.0) + correct = sum(log.get("correct", zero) for log in logging_outputs) + metrics.log_scalar_sum("_correct", correct) + + metrics.log_derived( + "accuracy", + lambda meters: 100 * meters["_correct"].sum / meters["sample_size"].sum, + ) + + @property + def source_dictionary(self): + return None + + @property + def target_dictionary(self): + return None + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return sys.maxsize, sys.maxsize diff --git a/examples/data2vec/tasks/mae_image_pretraining.py b/examples/data2vec/tasks/mae_image_pretraining.py new file mode 100644 index 000000000..35a14891c --- /dev/null +++ b/examples/data2vec/tasks/mae_image_pretraining.py @@ -0,0 +1,119 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
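+
+# Self-supervised image pretraining task for data2vec 2.0 / MAE-style models.
+# Each split is loaded as a MaeImageDataset; when precompute_mask_config is
+# set, masking parameters (interpolated from the model config via II) are
+# passed through so masks can be precomputed in the data loader, and
+# subsample < 1 wraps the dataset in SubsampleDataset.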
+ +import logging +import sys + +from typing import Optional, List +from dataclasses import dataclass, field +from omegaconf import MISSING, II + +from fairseq.data import SubsampleDataset +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task + +try: + from ..data import MaeImageDataset +except: + sys.path.append("..") + from data import MaeImageDataset + +logger = logging.getLogger(__name__) + + +@dataclass +class ImageMaskingConfig: + patch_size: int = II("model.modalities.image.patch_size") + mask_prob: float = II("model.modalities.image.mask_prob") + mask_prob_adjust: float = II("model.modalities.image.mask_prob_adjust") + mask_length: int = II("model.modalities.image.mask_length") + inverse_mask: bool = II("model.modalities.image.inverse_mask") + mask_dropout: float = II("model.modalities.image.mask_dropout") + clone_batch: int = II("model.clone_batch") + expand_adjacent: bool = False + non_overlapping: bool = False + + +@dataclass +class MaeImagePretrainingConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + multi_data: Optional[List[str]] = None + input_size: int = 224 + local_cache_path: Optional[str] = None + key: str = "imgs" + + beit_transforms: bool = False + target_transform: bool = False + no_transform: bool = False + + rebuild_batches: bool = True + + precompute_mask_config: Optional[ImageMaskingConfig] = None + + subsample: float = 1 + seed: int = II("common.seed") + dataset_type: str = "imagefolder" + + +@register_task("mae_image_pretraining", dataclass=MaeImagePretrainingConfig) +class MaeImagePretrainingTask(FairseqTask): + """ """ + + cfg: MaeImagePretrainingConfig + + @classmethod + def setup_task(cls, cfg: MaeImagePretrainingConfig, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + cfg (AudioPretrainingConfig): configuration of this task + """ + + return cls(cfg) + + def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + data_path = self.cfg.data + cfg = task_cfg or self.cfg + + compute_mask = cfg.precompute_mask_config is not None + mask_args = {} + if compute_mask: + mask_args = cfg.precompute_mask_config + + self.datasets[split] = MaeImageDataset( + root=data_path if cfg.multi_data is None else cfg.multi_data, + split=split, + input_size=cfg.input_size, + local_cache_path=cfg.local_cache_path, + key=cfg.key, + beit_transforms=cfg.beit_transforms, + target_transform=cfg.target_transform, + no_transform=cfg.no_transform, + compute_mask=compute_mask, + dataset_type=cfg.dataset_type, + **mask_args, + ) + + if cfg.subsample < 1: + self.datasets[split] = SubsampleDataset( + self.datasets[split], + cfg.subsample, + shuffle=True, + seed=cfg.seed, + ) + + @property + def source_dictionary(self): + return None + + @property + def target_dictionary(self): + return None + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return sys.maxsize, sys.maxsize diff --git a/examples/data2vec/tasks/multimodal.py b/examples/data2vec/tasks/multimodal.py new file mode 100644 index 000000000..74648e918 --- /dev/null +++ b/examples/data2vec/tasks/multimodal.py @@ -0,0 +1,165 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
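+
+# Joint pretraining task over audio, image and text. Each configured modality
+# is handled by its own task (AudioPretrainingTask, MaeImagePretrainingTask,
+# MaskedLMTask); their datasets are combined into a MultiModalityDataset and
+# batched with a GroupedEpochBatchIterator using the audio/image/text ratios.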
+ +import sys + +from dataclasses import dataclass +from typing import Optional, List +from omegaconf import II + +from fairseq.data.iterators import GroupedEpochBatchIterator + +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask +from fairseq.tasks.masked_lm import MaskedLMConfig, MaskedLMTask +from .mae_image_pretraining import MaeImagePretrainingConfig, MaeImagePretrainingTask +from examples.data2vec.data.modality import Modality + +from fairseq.data.audio.multi_modality_dataset import ( + MultiModalityDataset, + ModalityDatasetItem, +) + + +@dataclass +class MultimodalPretrainingConfig(FairseqDataclass): + audio: Optional[AudioPretrainingConfig] = None + image: Optional[MaeImagePretrainingConfig] = None + text: Optional[MaskedLMConfig] = None + + audio_ratio: float = 1 + image_ratio: float = 1 + text_ratio: float = 1 + + max_tokens: Optional[int] = II("dataset.max_tokens") + batch_size: Optional[int] = II("dataset.batch_size") + update_freq: List[int] = II("optimization.update_freq") + + rebuild_batches: bool = True + + +@register_task("multimodal_pretraining", dataclass=MultimodalPretrainingConfig) +class MultimodalPretrainingTask(FairseqTask): + """ """ + + cfg: MultimodalPretrainingConfig + + def __init__(self, cfg: MultimodalPretrainingConfig): + super().__init__(cfg) + self.audio_task = ( + AudioPretrainingTask(cfg.audio) if cfg.audio is not None else None + ) + self.image_task = ( + MaeImagePretrainingTask(cfg.image) if cfg.image is not None else None + ) + self.text_task = MaskedLMTask(cfg.text) if cfg.text is not None else None + + self.mult_ratios = [] + + @classmethod + def setup_task(cls, cfg: MultimodalPretrainingConfig, **kwargs): + """Setup the task (e.g., load dictionaries). 
+ + Args: + cfg (AudioPretrainingConfig): configuration of this task + """ + + return cls(cfg) + + def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + datasets = [] + self.mult_ratios = [] + + def load_ds(task, name, ratio): + if task is not None: + task.load_dataset(split) + ds = ModalityDatasetItem( + datasetname=name, + dataset=task.dataset(split), + max_positions=task.max_positions(), + max_tokens=self.cfg.max_tokens, + max_sentences=self.cfg.batch_size, + ) + datasets.append(ds) + self.mult_ratios.append(ratio) + + load_ds(self.audio_task, Modality.AUDIO, self.cfg.audio_ratio) + load_ds(self.image_task, Modality.IMAGE, self.cfg.image_ratio) + load_ds(self.text_task, Modality.TEXT, self.cfg.text_ratio) + + assert len(datasets) > 0 + + self.datasets[split] = MultiModalityDataset(datasets) + + @property + def supported_modalities(self): + modalities = [] + if self.cfg.text is not None: + modalities.append(Modality.TEXT) + if self.cfg.audio is not None: + modalities.append(Modality.AUDIO) + if self.cfg.image is not None: + modalities.append(Modality.IMAGE) + + return modalities + + def get_batch_iterator( + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=0, + data_buffer_size=0, + disable_iterator_cache=False, + skip_remainder_batch=False, + grouped_shuffling=False, + update_epoch_batch_itr=False, + ): + + # initialize the dataset with the correct starting epoch + dataset.set_epoch(epoch) + + batch_samplers = dataset.get_batch_samplers( + self.mult_ratios, required_batch_size_multiple, seed + ) + + # return a reusable, sharded iterator + epoch_iter = GroupedEpochBatchIterator( + dataset=dataset, + collate_fn=dataset.collater, + batch_samplers=batch_samplers, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + mult_rate=max(self.cfg.update_freq), + buffer_size=data_buffer_size, + skip_remainder_batch=skip_remainder_batch, + ) + self.dataset_to_epoch_iter[dataset] = {} # refresh it every epoch + return epoch_iter + + @property + def source_dictionary(self): + return None + + @property + def target_dictionary(self): + return None + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return sys.maxsize, sys.maxsize diff --git a/examples/roberta/config/finetuning/run_config/local.yaml b/examples/roberta/config/finetuning/run_config/local.yaml new file mode 100644 index 000000000..45595f9ee --- /dev/null +++ b/examples/roberta/config/finetuning/run_config/local.yaml @@ -0,0 +1,15 @@ +# @package _global_ +hydra: + sweep: + dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S} + +distributed_training: + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +common: + log_interval: 1 + +dataset: + num_workers: 0 diff --git a/examples/roberta/config/finetuning/run_config/slurm_1g.yaml b/examples/roberta/config/finetuning/run_config/slurm_1g.yaml new file mode 100644 index 000000000..8bc21854d --- /dev/null +++ b/examples/roberta/config/finetuning/run_config/slurm_1g.yaml @@ -0,0 +1,28 @@ + +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: '_' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/roberta_ft/${env:PREFIX}/${hydra.job.config_name}/${env:SUFFIX} + subdir: ${hydra.job.num} + launcher: + submitit_folder: 
${hydra.sweep.dir}/submitit + timeout_min: 1000 + cpus_per_task: 8 + gpus_per_node: 1 + tasks_per_node: 1 + mem_gb: 60 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 + exclude: learnfair1381,learnfair5192,learnfair2304 diff --git a/examples/roberta/config/finetuning/run_config/slurm_1g_aws.yaml b/examples/roberta/config/finetuning/run_config/slurm_1g_aws.yaml new file mode 100644 index 000000000..085391cff --- /dev/null +++ b/examples/roberta/config/finetuning/run_config/slurm_1g_aws.yaml @@ -0,0 +1,25 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: '_' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /fsx-wav2vec/${env:USER}/roberta_ft/${env:PREFIX}/${hydra.job.config_name}/${env:SUFFIX} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir}/submitit + timeout_min: 1000 + cpus_per_task: 8 + gpus_per_node: 1 + tasks_per_node: 1 + mem_gb: 0 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: learnfair,wav2vec + max_num_timeout: 30 diff --git a/examples/roberta/config/pretraining/run_config/local.yaml b/examples/roberta/config/pretraining/run_config/local.yaml new file mode 100644 index 000000000..45595f9ee --- /dev/null +++ b/examples/roberta/config/pretraining/run_config/local.yaml @@ -0,0 +1,15 @@ +# @package _global_ +hydra: + sweep: + dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S} + +distributed_training: + distributed_world_size: 1 + nprocs_per_node: 1 + distributed_port: -1 + +common: + log_interval: 1 + +dataset: + num_workers: 0 diff --git a/examples/roberta/config/pretraining/run_config/slurm_2.yaml b/examples/roberta/config/pretraining/run_config/slurm_2.yaml new file mode 100644 index 000000000..006a0f211 --- /dev/null +++ b/examples/roberta/config/pretraining/run_config/slurm_2.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 450 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/roberta/config/pretraining/run_config/slurm_2_aws.yaml b/examples/roberta/config/pretraining/run_config/slurm_2_aws.yaml new file mode 100644 index 000000000..a5937ea5a --- /dev/null +++ b/examples/roberta/config/pretraining/run_config/slurm_2_aws.yaml @@ -0,0 +1,39 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.local_cache_path + - 
task.data + - task.post_save_script + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + - model.model_path + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 0 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec + max_num_timeout: 30 diff --git a/examples/roberta/config/pretraining/run_config/slurm_3.yaml b/examples/roberta/config/pretraining/run_config/slurm_3.yaml new file mode 100644 index 000000000..0e1555d20 --- /dev/null +++ b/examples/roberta/config/pretraining/run_config/slurm_3.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 3 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/roberta/config/pretraining/run_config/slurm_4.yaml b/examples/roberta/config/pretraining/run_config/slurm_4.yaml new file mode 100644 index 000000000..c54d735fb --- /dev/null +++ b/examples/roberta/config/pretraining/run_config/slurm_4.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 4 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb,ib4 + max_num_timeout: 30 diff --git a/examples/roberta/fb_multilingual/README.multilingual.pretraining.md b/examples/roberta/fb_multilingual/README.multilingual.pretraining.md new file mode 100644 index 000000000..234fd7470 --- /dev/null +++ b/examples/roberta/fb_multilingual/README.multilingual.pretraining.md @@ -0,0 +1,26 @@ +# Multilingual pretraining RoBERTa + +This tutorial will walk you through pretraining multilingual RoBERTa. 
+ +### 1) Preprocess the data + +```bash +DICTIONARY="/private/home/namangoyal/dataset/XLM/wiki/17/175k/vocab" +DATA_LOCATION="/private/home/namangoyal/dataset/XLM/wiki/17/175k" + +for LANG in en es it +do + fairseq-preprocess \ + --only-source \ + --srcdict $DICTIONARY \ + --trainpref "$DATA_LOCATION/train.$LANG" \ + --validpref "$DATA_LOCATION/valid.$LANG" \ + --testpref "$DATA_LOCATION/test.$LANG" \ + --destdir "wiki_17-bin/$LANG" \ + --workers 60; +done +``` + +### 2) Train RoBERTa base + +[COMING UP...] diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py index 11ef60c94..06d20d8d4 100644 --- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py @@ -396,6 +396,7 @@ class MonotonicAttention(MultiheadAttention): "p_choose": p_choose, "alpha": alpha, "beta": beta, + "soft_energy": soft_energy, } def _get_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): diff --git a/examples/simultaneous_translation/tests/test_text_models.py b/examples/simultaneous_translation/tests/test_text_models.py index 127adfa63..19d635630 100644 --- a/examples/simultaneous_translation/tests/test_text_models.py +++ b/examples/simultaneous_translation/tests/test_text_models.py @@ -334,7 +334,7 @@ class InfiniteLookbackTestCase( self.model.decoder.layers[0].encoder_attn, "chunk_size", int(1e10) - ) + ) or int(1e10) ) self.assertTrue( diff --git a/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml b/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml new file mode 100644 index 000000000..eaaebcf5f --- /dev/null +++ b/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml @@ -0,0 +1,29 @@ +# @package hydra.sweeper +_target_: hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper +max_batch_size: null +ax_config: + max_trials: 64 + early_stop: + minimize: true + max_epochs_without_improvement: 10 + epsilon: 0.025 + experiment: + name: ${dataset.gen_subset} + objective_name: wer + minimize: true + parameter_constraints: null + outcome_constraints: null + status_quo: null + client: + verbose_logging: false + random_seed: null + params: + decoding.lmweight: + type: range + bounds: [0.0, 10.0] + decoding.wordscore: + type: range + bounds: [-10.0, 10.0] + decoding.silweight: + type: range + bounds: [ -10.0, 0.0 ] diff --git a/examples/speech_recognition/new/conf/infer.yaml b/examples/speech_recognition/new/conf/infer.yaml index 21dd19fad..2d168d06a 100644 --- a/examples/speech_recognition/new/conf/infer.yaml +++ b/examples/speech_recognition/new/conf/infer.yaml @@ -10,6 +10,8 @@ hydra: sweep: dir: /checkpoint/${env:USER}/${env:PREFIX}/${common_eval.results_path} subdir: ${dataset.gen_subset} +common: + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec common_eval: results_path: null path: null diff --git a/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml b/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml new file mode 100644 index 000000000..d0a9b0e58 --- /dev/null +++ b/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml @@ -0,0 +1,28 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - common_eval.path + sweep: + dir: 
/checkpoint/abaevski/asr/d2v2/decoding/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} +# subdir: ${hydra.job.override_dirname} + launcher: + cpus_per_task: 16 + gpus_per_node: 1 + tasks_per_node: 1 + nodes: 1 + partition: devlab,learnlab + mem_gb: 100 + timeout_min: 2000 + max_num_timeout: 10 + name: ${env:PREFIX}_${hydra.job.config_name} + submitit_folder: ${hydra.sweep.dir}/%j + constraint: volta32gb + exclude: learnfair7598 \ No newline at end of file diff --git a/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml b/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml new file mode 100644 index 000000000..c0c442f76 --- /dev/null +++ b/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml @@ -0,0 +1,27 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - common_eval.path + sweep: + dir: /checkpoint/abaevski/asr/d2v2/decoding/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} +# subdir: ${hydra.job.override_dirname} + launcher: + cpus_per_task: 16 + gpus_per_node: 2 + tasks_per_node: 2 + nodes: 1 + partition: devlab,learnlab + mem_gb: 100 + timeout_min: 2000 + max_num_timeout: 10 + name: ${env:PREFIX}_${hydra.job.config_name} + submitit_folder: ${hydra.sweep.dir}/%j + constraint: volta32gb \ No newline at end of file diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_1.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_1.yaml new file mode 100644 index 000000000..4a848435c --- /dev/null +++ b/examples/wav2vec/config/finetuning/run_config/slurm_1.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 \ No newline at end of file diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_16.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_16.yaml new file mode 100644 index 000000000..041843a9b --- /dev/null +++ b/examples/wav2vec/config/finetuning/run_config/slurm_16.yaml @@ -0,0 +1,27 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 450 + nodes: 16 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 + exclude: learnfair1381,learnfair5192,learnfair2304 \ No newline at end of file diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_1_aws.yaml 
b/examples/wav2vec/config/finetuning/run_config/slurm_1_aws.yaml new file mode 100644 index 000000000..b9335df78 --- /dev/null +++ b/examples/wav2vec/config/finetuning/run_config/slurm_1_aws.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.local_cache_path + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 0 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_1_old.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_1_old.yaml new file mode 100644 index 000000000..a8d2363dc --- /dev/null +++ b/examples/wav2vec/config/finetuning/run_config/slurm_1_old.yaml @@ -0,0 +1,27 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 450 + nodes: 1 + name: ${env:PREFIX}_wav2vec3_small_librispeech + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 + exclude: learnfair1381 \ No newline at end of file diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_2.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_2.yaml new file mode 100644 index 000000000..65ec48920 --- /dev/null +++ b/examples/wav2vec/config/finetuning/run_config/slurm_2.yaml @@ -0,0 +1,27 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 + exclude: learnfair7491,learnfair7477,learnfair7487 \ No newline at end of file diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_2_aws.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_2_aws.yaml new file mode 100644 index 000000000..e7590efc0 --- /dev/null +++ b/examples/wav2vec/config/finetuning/run_config/slurm_2_aws.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - 
distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.local_cache_path + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 8 + tasks_per_node: 1 + mem_gb: 0 + nodes: 2 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_2g.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_2g.yaml new file mode 100644 index 000000000..aaa20ebd0 --- /dev/null +++ b/examples/wav2vec/config/finetuning/run_config/slurm_2g.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 2 + tasks_per_node: 2 + mem_gb: 200 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_3.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_3.yaml new file mode 100644 index 000000000..9614ececa --- /dev/null +++ b/examples/wav2vec/config/finetuning/run_config/slurm_3.yaml @@ -0,0 +1,27 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 450 + nodes: 3 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 + exclude: learnfair7491,learnfair7477,learnfair7487 \ No newline at end of file diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_4g.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_4g.yaml new file mode 100644 index 000000000..c0c9f6043 --- /dev/null +++ b/examples/wav2vec/config/finetuning/run_config/slurm_4g.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 200 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: 
devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_4g_aws.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_4g_aws.yaml new file mode 100644 index 000000000..6bbbf3b64 --- /dev/null +++ b/examples/wav2vec/config/finetuning/run_config/slurm_4g_aws.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '/' + exclude_keys: + - run_config + - distributed_training.distributed_port + - distributed_training.distributed_world_size + - model.pretrained_model_path + - model.target_network_path + - next_script + - task.cache_in_scratch + - task.local_cache_path + - task.data + - checkpoint.save_interval_updates + - checkpoint.keep_interval_updates + - checkpoint.save_on_overflow + - common.log_interval + - common.user_dir + sweep: + dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: '' + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 80 + gpus_per_node: 4 + tasks_per_node: 1 + mem_gb: 0 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab,learnfair + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_8.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_8.yaml new file mode 100644 index 000000000..984f21888 --- /dev/null +++ b/examples/wav2vec/config/finetuning/run_config/slurm_8.yaml @@ -0,0 +1,26 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 4320 + cpus_per_task: 10 + gpus_per_node: 8 + tasks_per_node: 8 + mem_gb: 400 + nodes: 8 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/vox_100h_2.yaml b/examples/wav2vec/config/finetuning/vox_100h_2.yaml new file mode 100644 index 000000000..9bf588f58 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_100h_2.yaml @@ -0,0 +1,106 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 1 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: 0 + wer_sil_weight: -2 + +optimization: + max_update: 100000 + lr: [1e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + 
update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: null + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 72000 + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.4 + mask_length: 5 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 250 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/vox_100h_2_aws.yaml b/examples/wav2vec/config/finetuning/vox_100h_2_aws.yaml new file mode 100644 index 000000000..3a0d517eb --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_100h_2_aws.yaml @@ -0,0 +1,82 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /data/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /fsx-wav2vec/abaevski/data/libri/100h/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 1 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin + wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: 0 + wer_sil_weight: -2 + +optimization: + max_update: 100000 + lr: [1e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: null + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 82000 + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
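+# ??? is Hydra/OmegaConf's "missing value" marker: the pretrained checkpoint must be
+# supplied as an override at launch time, e.g. model.w2v_path=/path/to/checkpoint.pt
+# (example path only; point it at the wav2vec/data2vec checkpoint being fine-tuned)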
+ apply_mask: true + mask_prob: 0.4 + mask_length: 7 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + diff --git a/examples/wav2vec/config/finetuning/vox_100h_3.yaml b/examples/wav2vec/config/finetuning/vox_100h_3.yaml new file mode 100644 index 000000000..46778666f --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_100h_3.yaml @@ -0,0 +1,101 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1000000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 1 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: -1.0 + +optimization: + max_update: 100000 + lr: [1e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: cosine + warmup_updates: 8000 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.4 + mask_length: 5 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 250 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/vox_10h_2.yaml b/examples/wav2vec/config/finetuning/vox_10h_2.yaml new file mode 100644 index 000000000..05ee76f14 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_10h_2.yaml @@ -0,0 +1,102 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 10 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + keep_interval_updates: 1 + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 10 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: -1.0 + +optimization: + max_update: 60000 + lr: [2e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: cosine + warmup_updates: 8000 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.5 + mask_length: 5 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 250 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/vox_10h_2_aws.yaml b/examples/wav2vec/config/finetuning/vox_10h_2_aws.yaml new file mode 100644 index 000000000..a0afc9c5d --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_10h_2_aws.yaml @@ -0,0 +1,81 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /data/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 10 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 10 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin + wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: 4 + wer_sil_weight: -5 + +optimization: + max_update: 60000 + lr: [1e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: null + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 72000 + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
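+# the mask_* options below apply SpecAugment-style masking to the encoder input during
+# fine-tuning: mask_prob/mask_length mask spans of time steps, while mask_channel_prob/
+# mask_channel_length mask feature channels; values vary with the labeled-data budget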
+ apply_mask: true + mask_prob: 0.75 + mask_length: 5 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 diff --git a/examples/wav2vec/config/finetuning/vox_10h_aws.yaml b/examples/wav2vec/config/finetuning/vox_10h_aws.yaml new file mode 100644 index 000000000..c75437365 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_10h_aws.yaml @@ -0,0 +1,104 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /data/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 10 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 10 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter +# wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin +# wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst +# wer_lm_weight: 2.0 +# wer_word_score: -1.0 + +optimization: + max_update: 60000 + lr: [2e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: null + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 72000 + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.4 + mask_length: 5 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 0 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: wav2vec,learnlab + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/vox_10h_aws_v100.yaml b/examples/wav2vec/config/finetuning/vox_10h_aws_v100.yaml new file mode 100644 index 000000000..58ad2acf7 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_10h_aws_v100.yaml @@ -0,0 +1,102 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 +# tensorboard_logdir: tb + +checkpoint: + save_interval: 10 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /fsx/abaevski/data/libri/10h/wav2vec/raw + labels: ltr + cache_in_scratch: true + + +dataset: + num_workers: 10 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 10 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_lexicon: /fsx/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: -1.0 + +optimization: + max_update: 60000 + lr: [2e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: null + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 72000 + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.6 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /fsx/${env:USER}/w2v_ft/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 0 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: learnfair + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/vox_10m_2.yaml b/examples/wav2vec/config/finetuning/vox_10m_2.yaml new file mode 100644 index 000000000..1ac7c1217 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_10m_2.yaml @@ -0,0 +1,114 @@ +# @package _group_ + +common: + fp16: true + fp16_no_flatten_grads: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 500 + save_interval_updates: 500 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/10m/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1000000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 500 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 5 + wer_word_score: 2 + wer_sil_weight: -2 + +optimization: + max_update: 10000 + lr: [2e-6] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [4] # base 10h we -> 2/4 + +optimizer: + _name: composite + dynamic_groups: true + groups: + default: + lr_float: 2e-6 + optimizer: + _name: adam + adam_betas: [0.9,0.95] + lr_scheduler: + _name: cosine + warmup_updates: 1000 + +lr_scheduler: pass_through + +model: + _name: wav2vec_ctc + w2v_path: ??? 
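+# the composite optimizer above defines per-group optimizers with their own schedulers;
+# the top-level "lr_scheduler: pass_through" defers scheduling to the group-level cosine
+# scheduler instead of the flat adam + tri_stage setup used in the other configs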
+ apply_mask: true + mask_prob: 0.4 + mask_length: 3 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.25 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + freeze_finetune_updates: 100 + + zero_mask: true + feature_grad_mult: 0.0 + activation_dropout: 0.1 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + update_alibi: false + +#hydra: +# job: +# config: +# override_dirname: +# kv_sep: ':' +# item_sep: '__' +# exclude_keys: +# - run_config +# - distributed_training.distributed_port +# sweep: +# dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} +# subdir: ${hydra.job.num} +# launcher: +# submitit_folder: ${hydra.sweep.dir} +# timeout_min: 3000 +# cpus_per_task: 10 +# gpus_per_node: 4 +# tasks_per_node: 4 +# mem_gb: 250 +# nodes: 1 +# name: ${env:PREFIX}_${hydra.job.config_name} +# partition: devlab,learnlab,learnfair,scavenge +# constraint: volta32gb +# max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/vox_10m_2_aws.yaml b/examples/wav2vec/config/finetuning/vox_10m_2_aws.yaml new file mode 100644 index 000000000..a9c270855 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_10m_2_aws.yaml @@ -0,0 +1,114 @@ +# @package _group_ + +common: + fp16: true + fp16_no_flatten_grads: true + log_format: json + log_interval: 200 + user_dir: /data/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 500 + save_interval_updates: 500 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /fsx-wav2vec/abaevski/data/libri/10m/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1000000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 500 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin + wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 5 + wer_word_score: 2 + wer_sil_weight: -2 + +optimization: + max_update: 10000 + lr: [2e-6] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [4] # base 10h we -> 2/4 + +optimizer: + _name: composite + dynamic_groups: true + groups: + default: + lr_float: 2e-6 + optimizer: + _name: adam + adam_betas: [0.9,0.95] + lr_scheduler: + _name: cosine + warmup_updates: 1000 + +lr_scheduler: pass_through + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.4 + mask_length: 3 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.25 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + freeze_finetune_updates: 100 + + zero_mask: true + feature_grad_mult: 0.0 + activation_dropout: 0.1 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + update_alibi: false + +#hydra: +# job: +# config: +# override_dirname: +# kv_sep: ':' +# item_sep: '__' +# exclude_keys: +# - run_config +# - distributed_training.distributed_port +# sweep: +# dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} +# subdir: ${hydra.job.num} +# launcher: +# submitit_folder: ${hydra.sweep.dir} +# timeout_min: 3000 +# cpus_per_task: 10 +# gpus_per_node: 4 +# tasks_per_node: 4 +# mem_gb: 250 +# nodes: 1 +# name: ${env:PREFIX}_${hydra.job.config_name} +# partition: devlab,learnlab,learnfair,scavenge +# constraint: volta32gb +# max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/vox_10m_3.yaml b/examples/wav2vec/config/finetuning/vox_10m_3.yaml new file mode 100644 index 000000000..b6804126c --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_10m_3.yaml @@ -0,0 +1,105 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 1000 + save_interval_updates: 100 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/10m/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 500 + valid_subset: dev_other + required_batch_size_multiple: 8 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 8 + wer_word_score: 5.8 + wer_sil_weight: -8 + +optimization: + max_update: 13000 + lr: [2e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [5] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.65 + mask_length: 10 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.25 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 250 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/vox_1h_2.yaml b/examples/wav2vec/config/finetuning/vox_1h_2.yaml new file mode 100644 index 000000000..75f4aafd7 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_1h_2.yaml @@ -0,0 +1,104 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 100 + save_interval_updates: 500 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1000000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 100 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 6 + wer_word_score: -0.1 + wer_sil_weight: -4.7 + +optimization: + max_update: 60000 + lr: [1e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: cosine + warmup_updates: 4000 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
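+# validation WER here comes from a beam-search decoder backed by the 4-gram KenLM and
+# lexicon above; wer_lm_weight, wer_word_score and wer_sil_weight are decoder scores
+# (wer_sil_weight is the new CtcCriterionConfig field added later in this patch)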
+ apply_mask: true + mask_prob: 0.65 + mask_length: 5 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.25 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 250 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/vox_1h_2_aws.yaml b/examples/wav2vec/config/finetuning/vox_1h_2_aws.yaml new file mode 100644 index 000000000..cc4d511d1 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_1h_2_aws.yaml @@ -0,0 +1,114 @@ +# @package _group_ + +common: + fp16: true + fp16_no_flatten_grads: true + log_format: json + log_interval: 200 + user_dir: /data/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 100 + save_interval_updates: 500 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /fsx-wav2vec/abaevski/data/libri/1h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1000000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 500 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin + wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 5 + wer_word_score: 0 + wer_sil_weight: -4 + +optimization: + max_update: 10000 + lr: [2e-6] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [4] # base 10h we -> 2/4 + +optimizer: + _name: composite + dynamic_groups: true + groups: + default: + lr_float: 2e-6 + optimizer: + _name: adam + adam_betas: [0.9,0.95] + lr_scheduler: + _name: cosine + warmup_updates: 1000 + +lr_scheduler: pass_through + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.4 + mask_length: 3 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.25 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + freeze_finetune_updates: 100 + + zero_mask: true + feature_grad_mult: 0.0 + activation_dropout: 0.1 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + update_alibi: false + +#hydra: +# job: +# config: +# override_dirname: +# kv_sep: ':' +# item_sep: '__' +# exclude_keys: +# - run_config +# - distributed_training.distributed_port +# sweep: +# dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} +# subdir: ${hydra.job.num} +# launcher: +# submitit_folder: ${hydra.sweep.dir} +# timeout_min: 3000 +# cpus_per_task: 10 +# gpus_per_node: 4 +# tasks_per_node: 4 +# mem_gb: 250 +# nodes: 1 +# name: ${env:PREFIX}_${hydra.job.config_name} +# partition: devlab,learnlab,learnfair,scavenge +# constraint: volta32gb +# max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/vox_1h_3.yaml b/examples/wav2vec/config/finetuning/vox_1h_3.yaml new file mode 100644 index 000000000..842c89717 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_1h_3.yaml @@ -0,0 +1,104 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 100 + save_interval_updates: 500 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 640000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 100 + valid_subset: dev_other + required_batch_size_multiple: 8 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 6 + wer_word_score: -0.1 + wer_sil_weight: -4.7 + +optimization: + max_update: 13000 + lr: [6e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [5] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: cosine + warmup_updates: 4000 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.3 + mask_length: 3 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.25 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 250 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/vox_1h_4.yaml b/examples/wav2vec/config/finetuning/vox_1h_4.yaml new file mode 100644 index 000000000..698ed8c4d --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_1h_4.yaml @@ -0,0 +1,104 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 100 + save_interval_updates: 1000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 640000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 100 + valid_subset: dev_other + required_batch_size_multiple: 8 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: -1.0 + +optimization: + max_update: 13000 + lr: [6e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [5] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
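+# tri_stage with phase_ratio: [0.1, 0.4, 0.5] above splits max_update into 10% warmup,
+# 40% hold and 50% decay instead of specifying explicit warmup/hold/decay step counts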
+ apply_mask: true + mask_prob: 0.65 + mask_length: 10 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.25 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 250 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/vox_1h_aws.yaml b/examples/wav2vec/config/finetuning/vox_1h_aws.yaml new file mode 100644 index 000000000..aa6700415 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_1h_aws.yaml @@ -0,0 +1,80 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /data/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 100 + save_interval_updates: 500 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /fsx-wav2vec/abaevski/data/libri/10m/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1000000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 100 + valid_subset: dev_other + required_batch_size_multiple: 8 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin + wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 5 + wer_word_score: -0.1 + wer_sil_weight: -4.7 + +optimization: + max_update: 13000 + lr: [6e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [5] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: cosine + warmup_updates: 4000 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.3 + mask_length: 3 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.25 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + update_alibi: false diff --git a/examples/wav2vec/config/finetuning/vox_960h_2.yaml b/examples/wav2vec/config/finetuning/vox_960h_2.yaml new file mode 100644 index 000000000..d96e2325b --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_960h_2.yaml @@ -0,0 +1,105 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/960h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1000000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 1 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 16 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: -1.0 + +optimization: + max_update: 200000 + lr: [1e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: null + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 200000 + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.4 + mask_length: 5 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 250 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/examples/wav2vec/config/finetuning/vox_960h_2_aws.yaml b/examples/wav2vec/config/finetuning/vox_960h_2_aws.yaml new file mode 100644 index 000000000..41d2b38f8 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_960h_2_aws.yaml @@ -0,0 +1,82 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /data/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /fsx-wav2vec/abaevski/data/librispeech + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 1 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 16 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin + wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 1.5 + wer_word_score: 0 + wer_sil_weight: -1 + +optimization: + max_update: 200000 + lr: [2e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: null + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 192000 + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
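+# with hold_steps: 0 the tri_stage schedule above spans the whole run:
+# warmup_steps + decay_steps = 8k + 192k = 200k = max_update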
+ apply_mask: true + mask_prob: 0.3 + mask_length: 5 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + diff --git a/examples/wav2vec/config/finetuning/vox_960h_3.yaml b/examples/wav2vec/config/finetuning/vox_960h_3.yaml new file mode 100644 index 000000000..ef6597aa6 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_960h_3.yaml @@ -0,0 +1,101 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1000000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 1 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 16 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: -1.0 + +optimization: + max_update: 200000 + lr: [1e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: cosine + warmup_updates: 8000 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.4 + mask_length: 5 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 250 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 138b4d1eb..ff1da2553 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -45,14 +45,14 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): save_checkpoint.best = best_function(val_loss, prev_best) if cfg.no_save: - return + return None trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state if not trainer.should_save_checkpoint_on_current_rank: if trainer.always_call_state_dict_during_save_checkpoint: trainer.state_dict() - return + return None write_timer = meters.StopwatchMeter() write_timer.start() @@ -111,8 +111,9 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): checkpoints = [ os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond ] + saved_cp = None if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank: - trainer.save_checkpoint(checkpoints[0], extra_state) + saved_cp = trainer.save_checkpoint(checkpoints[0], extra_state) for cp in checkpoints[1:]: if cfg.write_checkpoints_asynchronously: # TODO[ioPath]: Need to implement a delayed asynchronous @@ -133,7 +134,11 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): ) ) - if not end_of_epoch and cfg.keep_interval_updates > 0: + if ( + not end_of_epoch + and cfg.keep_interval_updates > 0 + and trainer.should_save_checkpoint_on_current_rank + ): # remove old checkpoints; checkpoints are sorted in descending order if cfg.keep_interval_updates_pattern == -1: checkpoints = checkpoint_paths( @@ -157,7 +162,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): elif PathManager.exists(old_chk): PathManager.rm(old_chk) - if cfg.keep_last_epochs > 0: + if cfg.keep_last_epochs > 0 and trainer.should_save_checkpoint_on_current_rank: # remove old epoch checkpoints; checkpoints are sorted in descending order checkpoints = checkpoint_paths( cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) @@ -168,7 +173,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): elif PathManager.exists(old_chk): PathManager.rm(old_chk) - if cfg.keep_best_checkpoints > 0: + if cfg.keep_best_checkpoints > 0 and trainer.should_save_checkpoint_on_current_rank: # only keep the best N checkpoints according to validation metric checkpoints = checkpoint_paths( cfg.save_dir, @@ -184,6 +189,8 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): elif PathManager.exists(old_chk): PathManager.rm(old_chk) + return 
saved_cp + def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): """ @@ -574,6 +581,8 @@ def _torch_persistent_save(obj, f): if i == 2: logger.error(traceback.format_exc()) raise + else: + time.sleep(2.5) def _upgrade_state_dict(state): diff --git a/fairseq/config/fb_run_config/slurm.yaml b/fairseq/config/fb_run_config/slurm.yaml new file mode 100644 index 000000000..20cf8f520 --- /dev/null +++ b/fairseq/config/fb_run_config/slurm.yaml @@ -0,0 +1,29 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - fb_run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + launcher: + cpus_per_task: 60 + gpus_per_node: ??? + tasks_per_node: 1 + nodes: 1 + partition: learnfair + mem_gb: 400 + timeout_min: 4320 + max_num_timeout: 10 + name: ${env:PREFIX}_${hydra.job.config_name} + submitit_folder: ${hydra.sweep.dir} + +distributed_training: + ddp_backend: c10d + distributed_world_size: ??? + distributed_port: ??? diff --git a/fairseq/criterions/__init__.py b/fairseq/criterions/__init__.py index 4dbf46a1c..ecd65d34a 100644 --- a/fairseq/criterions/__init__.py +++ b/fairseq/criterions/__init__.py @@ -25,8 +25,8 @@ from omegaconf import DictConfig ) -def build_criterion(cfg: DictConfig, task): - return build_criterion_(cfg, task) +def build_criterion(cfg: DictConfig, task, from_checkpoint=False): + return build_criterion_(cfg, task, from_checkpoint=from_checkpoint) # automatically import any Python files in the criterions/ directory diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py index 6d53198b0..e55e928b4 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -7,18 +7,17 @@ import math from argparse import Namespace from dataclasses import dataclass, field +from omegaconf import II from typing import Optional import torch import torch.nn.functional as F -from omegaconf import II - from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion -from fairseq.data.data_utils import post_process from fairseq.dataclass import FairseqDataclass -from fairseq.logging.meters import safe_round +from fairseq.data.data_utils import post_process from fairseq.tasks import FairseqTask +from fairseq.logging.meters import safe_round @dataclass @@ -54,6 +53,10 @@ class CtcCriterionConfig(FairseqDataclass): default=-1.0, metadata={"help": "lm word score to use with wer_kenlm_model"}, ) + wer_sil_weight: float = field( + default=0, + metadata={"help": "lm word score to use with wer_kenlm_model"}, + ) wer_args: Optional[str] = field( default=None, @@ -101,6 +104,7 @@ class CtcCriterion(FairseqCriterion): dec_args.beam_threshold = min(50, len(task.target_dictionary)) dec_args.lm_weight = cfg.wer_lm_weight dec_args.word_score = cfg.wer_word_score + dec_args.sil_weight = cfg.wer_sil_weight dec_args.unk_weight = -math.inf dec_args.sil_weight = 0 diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py index 257466903..cb43be0ca 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy.py +++ b/fairseq/criterions/label_smoothed_cross_entropy.py @@ -7,11 +7,10 @@ import math from dataclasses import dataclass, field import torch -from omegaconf import II - from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion from 
fairseq.dataclass import FairseqDataclass +from omegaconf import II @dataclass diff --git a/fairseq/criterions/model_criterion.py b/fairseq/criterions/model_criterion.py index f9a810d83..2c9fbb255 100644 --- a/fairseq/criterions/model_criterion.py +++ b/fairseq/criterions/model_criterion.py @@ -12,6 +12,7 @@ import torch from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass +from fairseq.logging.meters import safe_round logger = logging.getLogger(__name__) @@ -27,6 +28,7 @@ class ModelCriterionConfig(FairseqDataclass): default_factory=list, metadata={"help": "additional output keys to log"}, ) + can_sum: bool = True @register_criterion("model", dataclass=ModelCriterionConfig) @@ -43,10 +45,11 @@ class ModelCriterion(FairseqCriterion): net_output dict can be logged via the log_keys parameter. """ - def __init__(self, task, loss_weights=None, log_keys=None): + def __init__(self, task, loss_weights=None, log_keys=None, can_sum=True): super().__init__(task) self.loss_weights = loss_weights self.log_keys = log_keys + self.can_sum = can_sum def forward(self, model, sample, reduce=True): net_output = model(**sample["net_input"]) @@ -69,7 +72,7 @@ class ModelCriterion(FairseqCriterion): ) raise if coef != 0 and p is not None: - scaled_losses[lk] = coef * p.float() + scaled_losses[lk] = coef * p.float().sum() loss = sum(scaled_losses.values()) @@ -93,6 +96,8 @@ class ModelCriterion(FairseqCriterion): if lk in net_output and net_output[lk] is not None: if not torch.is_tensor(net_output[lk]) or net_output[lk].numel() == 1: logging_output[lk] = float(net_output[lk]) + elif lk.startswith("_"): + logging_output[lk] = net_output[lk] else: for i, v in enumerate(net_output[lk]): logging_output[f"{lk}_{i}"] = float(v) @@ -124,6 +129,7 @@ class ModelCriterion(FairseqCriterion): metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3) metrics.log_scalar("ntokens", ntokens) metrics.log_scalar("nsentences", nsentences) + metrics.log_scalar("sample_size", sample_size) builtin_keys = { "loss", @@ -138,18 +144,33 @@ class ModelCriterion(FairseqCriterion): ) for k in logging_outputs[0]: - if k not in builtin_keys: + if k not in builtin_keys and not k.startswith("_"): val = sum(log.get(k, 0) for log in logging_outputs) if k.startswith("loss_"): metrics.log_scalar(k, val / sample_size, sample_size, round=3) else: metrics.log_scalar(k, val / world_size, round=3) - @staticmethod - def logging_outputs_can_be_summed() -> bool: + correct = sum(log.get("correct", 0) for log in logging_outputs) + total = sum(log.get("count", 0) for log in logging_outputs) + + if total > 0: + metrics.log_scalar("_correct", correct) + metrics.log_scalar("_total", total) + + metrics.log_derived( + "accuracy", + lambda meters: safe_round( + meters["_correct"].sum / meters["_total"].sum, 5 + ) + if meters["_total"].sum > 0 + else float("nan"), + ) + + def logging_outputs_can_be_summed(self) -> bool: """ Whether the logging outputs returned by `forward` can be summed across workers prior to calling `reduce_metrics`. Setting this to True will improves distributed training speed. 
""" - return True + return self.can_sum diff --git a/fairseq/criterions/sentence_prediction.py b/fairseq/criterions/sentence_prediction.py index b402d7603..01c2a2ba6 100644 --- a/fairseq/criterions/sentence_prediction.py +++ b/fairseq/criterions/sentence_prediction.py @@ -5,13 +5,49 @@ import math from dataclasses import dataclass, field +from itertools import chain +import numpy as np import torch import torch.nn.functional as F +from sklearn.metrics import f1_score +from sklearn.metrics import matthews_corrcoef as _matthews_corrcoef +from scipy.stats import pearsonr, spearmanr from fairseq import metrics from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass +from fairseq.logging.meters import safe_round + + +def simple_accuracy(preds, labels): + return (preds == labels).mean() + + +def acc_and_f1(preds, labels): + acc = simple_accuracy(preds, labels) + f1 = f1_score(y_true=labels, y_pred=preds) + return { + "acc": acc, + "f1": f1, + "acc_and_f1": (acc + f1) / 2, + } + + +def pearson_and_spearman(preds, labels): + pearson_corr = pearsonr(preds, labels)[0] + spearman_corr = spearmanr(preds, labels)[0] + return { + "pearson": pearson_corr, + "spearmanr": spearman_corr, + "corr": (pearson_corr + spearman_corr) / 2, + } + + +def matthews_corrcoef(preds, labels): + # make it consistent with other metrics taking (preds, labels) as input + mcc = _matthews_corrcoef(labels, preds) + return mcc @dataclass @@ -23,6 +59,9 @@ class SentencePredictionConfig(FairseqDataclass): regression_target: bool = field( default=False, ) + report_mcc: bool = False + report_acc_and_f1: bool = False + report_pearson_and_spearman: bool = False @register_criterion("sentence_prediction", dataclass=SentencePredictionConfig) @@ -31,6 +70,13 @@ class SentencePredictionCriterion(FairseqCriterion): super().__init__(task) self.classification_head_name = cfg.classification_head_name self.regression_target = cfg.regression_target + self.keep_pred_and_targ = ( + cfg.report_mcc or cfg.report_acc_and_f1 or cfg.report_pearson_and_spearman + ) + self.report_mcc = cfg.report_mcc + self.report_acc_and_f1 = cfg.report_acc_and_f1 + self.report_pearson_and_spearman = cfg.report_pearson_and_spearman + self.label_dict = task.label_dictionary def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. 
@@ -65,14 +111,16 @@ class SentencePredictionCriterion(FairseqCriterion): loss = task_loss # mha & ffn regularization update if ( - hasattr(model.args, "mha_reg_scale_factor") + hasattr(model, "args") + and hasattr(model.args, "mha_reg_scale_factor") and model.args.mha_reg_scale_factor != 0.0 ): mha_reg_loss = model._get_adaptive_head_loss() loss += mha_reg_loss logging_output.update({"mha_reg_loss": mha_reg_loss}) if ( - hasattr(model.args, "ffn_reg_scale_factor") + hasattr(model, "args") + and hasattr(model.args, "ffn_reg_scale_factor") and model.args.ffn_reg_scale_factor != 0.0 ): ffn_reg_loss = model._get_adaptive_ffn_loss() @@ -90,6 +138,25 @@ class SentencePredictionCriterion(FairseqCriterion): if not self.regression_target: preds = logits.argmax(dim=1) logging_output["ncorrect"] = (preds == targets).sum() + if self.keep_pred_and_targ and not model.training: + if self.regression_target: + logging_output["pred"] = logits.detach().cpu().tolist() + logging_output["targ"] = targets.detach().cpu().tolist() + else: + # remove offset `self.label_dict.nspecial` from OffsetTokensDataset + preds = self.label_dict.string(preds + self.label_dict.nspecial).split() + targets = self.label_dict.string( + targets + self.label_dict.nspecial + ).split() + logging_output["pred"] = list(map(int, preds)) + logging_output["targ"] = list(map(int, targets)) + + if self.report_mcc: + logging_output["report_mcc"] = True + if self.report_acc_and_f1: + logging_output["report_acc_and_f1"] = True + if self.report_pearson_and_spearman: + logging_output["report_pearson_and_spearman"] = True return loss, sample_size, logging_output @@ -131,6 +198,86 @@ class SentencePredictionCriterion(FairseqCriterion): "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1 ) + # Metrics used by GLUE + pred = np.array( + list(chain.from_iterable(log.get("pred", []) for log in logging_outputs)) + ) + targ = np.array( + list(chain.from_iterable(log.get("targ", []) for log in logging_outputs)) + ) + if len(pred): + metrics.log_concat_tensor("pred", torch.from_numpy(pred), dim=0) + metrics.log_concat_tensor("targ", torch.from_numpy(targ), dim=0) + if any("report_mcc" in log for log in logging_outputs): + metrics.log_derived( + "mcc", + lambda meters: safe_round( + matthews_corrcoef( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + ) + * 100, + 1, + ), + ) + if any("report_acc_and_f1" in log for log in logging_outputs): + metrics.log_derived( + "acc_and_f1", + lambda meters: safe_round( + acc_and_f1( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["acc_and_f1"] + * 100, + 1, + ), + ) + metrics.log_derived( + "f1", + lambda meters: safe_round( + acc_and_f1( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["f1"] + * 100, + 1, + ), + ) + if any("report_pearson_and_spearman" in log for log in logging_outputs): + metrics.log_derived( + "pearson_and_spearman", + lambda meters: safe_round( + pearson_and_spearman( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["corr"] + * 100, + 1, + ), + ) + metrics.log_derived( + "pearson", + lambda meters: safe_round( + pearson_and_spearman( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["pearson"] + * 100, + 1, + ), + ) + metrics.log_derived( + "spearman", + lambda meters: safe_round( + pearson_and_spearman( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["spearmanr"] + * 100, + 1, + ), + ) + @staticmethod def logging_outputs_can_be_summed() -> bool: """ diff --git 
a/fairseq/data/__init__.py b/fairseq/data/__init__.py index 8acf2ca17..a27e3184a 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -39,6 +39,11 @@ from .noising import NoisingDataset from .numel_dataset import NumelDataset from .num_samples_dataset import NumSamplesDataset from .offset_tokens_dataset import OffsetTokensDataset +from .padding_mask_dataset import ( + LeftPaddingMaskDataset, + PaddingMaskDataset, + RightPaddingMaskDataset, +) from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset from .prepend_dataset import PrependDataset from .prepend_token_dataset import PrependTokenDataset diff --git a/fairseq/data/add_class_target_dataset.py b/fairseq/data/add_class_target_dataset.py new file mode 100644 index 000000000..bf89f2565 --- /dev/null +++ b/fairseq/data/add_class_target_dataset.py @@ -0,0 +1,79 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import BaseWrapperDataset, data_utils +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel + + +class AddTargetDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + labels, + pad, + eos, + batch_targets, + process_label=None, + label_len_fn=None, + add_to_input=False, + text_compression_level=TextCompressionLevel.none, + ): + super().__init__(dataset) + self.labels = labels + self.batch_targets = batch_targets + self.pad = pad + self.eos = eos + self.process_label = process_label + self.label_len_fn = label_len_fn + self.add_to_input = add_to_input + self.text_compressor = TextCompressor(level=text_compression_level) + + def get_label(self, index, process_fn=None): + lbl = self.labels[index] + lbl = self.text_compressor.decompress(lbl) + return lbl if process_fn is None else process_fn(lbl) + + def __getitem__(self, index): + item = self.dataset[index] + item["label"] = self.get_label(index, process_fn=self.process_label) + return item + + def size(self, index): + sz = self.dataset.size(index) + own_sz = self.label_len_fn(self.get_label(index)) + return sz, own_sz + + def collater(self, samples): + collated = self.dataset.collater(samples) + if len(collated) == 0: + return collated + indices = set(collated["id"].tolist()) + target = [s["label"] for s in samples if s["id"] in indices] + + if self.batch_targets: + collated["target_lengths"] = torch.LongTensor([len(t) for t in target]) + target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False) + collated["ntokens"] = collated["target_lengths"].sum().item() + else: + collated["ntokens"] = sum([len(t) for t in target]) + + collated["target"] = target + + if self.add_to_input: + eos = target.new_full((target.size(0), 1), self.eos) + collated["target"] = torch.cat([target, eos], dim=-1).long() + collated["net_input"]["prev_output_tokens"] = torch.cat( + [eos, target], dim=-1 + ).long() + collated["ntokens"] += target.size(0) + return collated + + def filter_indices_by_size(self, indices, max_sizes): + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + return indices, ignored diff --git a/fairseq/data/audio/multi_modality_dataset.py b/fairseq/data/audio/multi_modality_dataset.py index 2db163734..0a42c1061 100644 --- a/fairseq/data/audio/multi_modality_dataset.py +++ b/fairseq/data/audio/multi_modality_dataset.py @@ -10,7 +10,6 @@ import math from typing import List, Optional, NamedTuple import numpy as np 
-from fairseq.data.resampling_dataset import ResamplingDataset import torch from fairseq.data import ( ConcatDataset, @@ -31,16 +30,6 @@ class ModalityDatasetItem(NamedTuple): max_sentences: Optional[int] = None -def resampling_dataset_present(ds): - if isinstance(ds, ResamplingDataset): - return True - if isinstance(ds, ConcatDataset): - return any(resampling_dataset_present(d) for d in ds.datasets) - if hasattr(ds, "dataset"): - return resampling_dataset_present(ds.dataset) - return False - - # MultiModalityDataset: it concate multiple datasets with different modalities. # Compared with ConcatDataset it can 1) sample data given the ratios for different datasets # 2) it adds mode to indicate what type of the data samples come from. @@ -106,7 +95,7 @@ class MultiModalityDataset(ConcatDataset): Returns indices sorted by length. So less padding is needed. """ if len(self.datasets) == 1: - return self.datasets[0].ordered_indices() + return [self.datasets[0].ordered_indices()] indices_group = [] for d_idx, ds in enumerate(self.datasets): sample_num = self.cumulative_sizes[d_idx] @@ -117,16 +106,13 @@ class MultiModalityDataset(ConcatDataset): return indices_group def get_raw_batch_samplers(self, required_batch_size_multiple, seed): + if len(self.raw_sub_batch_samplers) > 0: + logger.info(" raw_sub_batch_samplers exists. No action is taken") + return with data_utils.numpy_seed(seed): indices = self.ordered_indices() + for i, ds in enumerate(self.datasets): - # If we have ResamplingDataset, the same id can correpond to a different - # sample in the next epoch, so we need to rebuild this at every epoch - if i < len(self.raw_sub_batch_samplers) and not resampling_dataset_present( - ds - ): - logger.info(f"dataset {i} is valid and it is not re-sampled") - continue indices[i] = ds.filter_indices_by_size( indices[i], self.max_positions[i], @@ -137,10 +123,7 @@ class MultiModalityDataset(ConcatDataset): max_sentences=self.max_sentences[i], required_batch_size_multiple=required_batch_size_multiple, ) - if i < len(self.raw_sub_batch_samplers): - self.raw_sub_batch_samplers[i] = sub_batch_sampler - else: - self.raw_sub_batch_samplers.append(sub_batch_sampler) + self.raw_sub_batch_samplers.append(sub_batch_sampler) def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed): self.get_raw_batch_samplers(required_batch_size_multiple, seed) diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 181e2bbc9..edb307e68 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -7,6 +7,7 @@ import logging import os import sys +import time import io import numpy as np @@ -14,7 +15,7 @@ import torch import torch.nn.functional as F from .. 
import FairseqDataset -from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes +from ..data_utils import compute_block_mask_1d, get_buckets, get_bucketed_sizes from fairseq.data.audio.audio_utils import ( parse_path, read_from_stored_zip, @@ -35,8 +36,17 @@ class RawAudioDataset(FairseqDataset): shuffle=True, pad=False, normalize=False, - compute_mask_indices=False, - **mask_compute_kwargs, + compute_mask=False, + feature_encoder_spec: str = "None", + mask_prob: float = 0.75, + mask_prob_adjust: float = 0, + mask_length: int = 1, + inverse_mask: bool = False, + require_same_masks: bool = True, + clone_batch: int = 1, + expand_adjacent: bool = False, + mask_dropout: float = 0, + non_overlapping: bool = False, ): super().__init__() @@ -49,12 +59,19 @@ class RawAudioDataset(FairseqDataset): self.pad = pad self.shuffle = shuffle self.normalize = normalize - self.compute_mask_indices = compute_mask_indices - if self.compute_mask_indices: - self.mask_compute_kwargs = mask_compute_kwargs - self._features_size_map = {} - self._C = mask_compute_kwargs["encoder_embed_dim"] - self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"]) + + self.is_compute_mask = compute_mask + self.feature_encoder_spec = eval(feature_encoder_spec) + self._features_size_map = {} + self.mask_prob = mask_prob + self.mask_prob_adjust = mask_prob_adjust + self.mask_length = mask_length + self.inverse_mask = inverse_mask + self.require_same_masks = require_same_masks + self.clone_batch = clone_batch + self.expand_adjacent = expand_adjacent + self.mask_dropout = mask_dropout + self.non_overlapping = non_overlapping def __getitem__(self, index): raise NotImplementedError() @@ -76,48 +93,21 @@ class RawAudioDataset(FairseqDataset): feats = F.layer_norm(feats, feats.shape) return feats - def crop_to_max_size(self, wav, target_size): - size = len(wav) + def crop_to_max_size(self, t, target_size, dim=0): + size = t.size(dim) diff = size - target_size if diff <= 0: - return wav + return t start = np.random.randint(0, diff + 1) end = size - diff + start - return wav[start:end] - def _compute_mask_indices(self, dims, padding_mask): - B, T, C = dims - mask_indices, mask_channel_indices = None, None - if self.mask_compute_kwargs["mask_prob"] > 0: - mask_indices = compute_mask_indices( - (B, T), - padding_mask, - self.mask_compute_kwargs["mask_prob"], - self.mask_compute_kwargs["mask_length"], - self.mask_compute_kwargs["mask_selection"], - self.mask_compute_kwargs["mask_other"], - min_masks=2, - no_overlap=self.mask_compute_kwargs["no_mask_overlap"], - min_space=self.mask_compute_kwargs["mask_min_space"], - ) - mask_indices = torch.from_numpy(mask_indices) - if self.mask_compute_kwargs["mask_channel_prob"] > 0: - mask_channel_indices = compute_mask_indices( - (B, C), - None, - self.mask_compute_kwargs["mask_channel_prob"], - self.mask_compute_kwargs["mask_channel_length"], - self.mask_compute_kwargs["mask_channel_selection"], - self.mask_compute_kwargs["mask_channel_other"], - no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"], - min_space=self.mask_compute_kwargs["mask_channel_min_space"], - ) - mask_channel_indices = ( - torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1) - ) + slices = [] + for d in range(dim): + slices.append(slice(None)) + slices.append(slice(start, end)) - return mask_indices, mask_channel_indices + return t[slices] @staticmethod def _bucket_tensor(tensor, num_pad, value): @@ -166,33 +156,24 @@ class RawAudioDataset(FairseqDataset): 
input["source"] = self._bucket_tensor(collated_sources, num_pad, 0) input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True) - if self.compute_mask_indices: - B = input["source"].size(0) - T = self._get_mask_indices_dims(input["source"].size(-1)) - padding_mask_reshaped = input["padding_mask"].clone() - extra = padding_mask_reshaped.size(1) % T - if extra > 0: - padding_mask_reshaped = padding_mask_reshaped[:, :-extra] - padding_mask_reshaped = padding_mask_reshaped.view( - padding_mask_reshaped.size(0), T, -1 + if "precomputed_mask" in samples[0]: + target_size = self._get_mask_indices_dims(target_size) + collated_mask = torch.cat( + [ + self.crop_to_max_size(s["precomputed_mask"], target_size, dim=1) + for s in samples + ], + dim=0, ) - padding_mask_reshaped = padding_mask_reshaped.all(-1) - input["padding_count"] = padding_mask_reshaped.sum(-1).max().item() - mask_indices, mask_channel_indices = self._compute_mask_indices( - (B, T, self._C), - padding_mask_reshaped, - ) - input["mask_indices"] = mask_indices - input["mask_channel_indices"] = mask_channel_indices - out["sample_size"] = mask_indices.sum().item() + input["precomputed_mask"] = collated_mask out["net_input"] = input return out def _get_mask_indices_dims(self, size, padding=0, dilation=1): - if size not in self._features_size_map: + if size not in self.feature_encoder_spec: L_in = size - for (_, kernel_size, stride) in self._conv_feature_layers: + for (_, kernel_size, stride) in self.feature_encoder_spec: L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1 L_out = 1 + L_out // stride L_in = L_out @@ -244,6 +225,9 @@ class RawAudioDataset(FairseqDataset): f"{self.buckets}" ) + def filter_indices_by_size(self, indices, max_sizes): + return indices, [] + class FileAudioDataset(RawAudioDataset): def __init__( @@ -256,7 +240,7 @@ class FileAudioDataset(RawAudioDataset): pad=False, normalize=False, num_buckets=0, - compute_mask_indices=False, + compute_mask=False, text_compression_level=TextCompressionLevel.none, **mask_compute_kwargs, ): @@ -267,7 +251,7 @@ class FileAudioDataset(RawAudioDataset): shuffle=shuffle, pad=pad, normalize=normalize, - compute_mask_indices=compute_mask_indices, + compute_mask=compute_mask, **mask_compute_kwargs, ) @@ -319,11 +303,43 @@ class FileAudioDataset(RawAudioDataset): assert is_sf_audio_data(byte_data) path_or_fp = io.BytesIO(byte_data) - wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32") + retry = 3 + wav = None + for i in range(retry): + try: + wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32") + break + except Exception as e: + logger.warning( + f"Failed to read {path_or_fp}: {e}. 
Sleeping for {1 * i}" + ) + time.sleep(1 * i) + + if wav is None: + raise Exception(f"Failed to load {path_or_fp}") feats = torch.from_numpy(wav).float() feats = self.postprocess(feats, curr_sample_rate) - return {"id": index, "source": feats} + + v = {"id": index, "source": feats} + + if self.is_compute_mask: + T = self._get_mask_indices_dims(feats.size(-1)) + mask = compute_block_mask_1d( + shape=(self.clone_batch, T), + mask_prob=self.mask_prob, + mask_length=self.mask_length, + mask_prob_adjust=self.mask_prob_adjust, + inverse_mask=self.inverse_mask, + require_same_masks=True, + expand_adjcent=self.expand_adjacent, + mask_dropout=self.mask_dropout, + non_overlapping=self.non_overlapping, + ) + + v["precomputed_mask"] = mask + + return v class BinarizedAudioDataset(RawAudioDataset): @@ -338,7 +354,7 @@ class BinarizedAudioDataset(RawAudioDataset): pad=False, normalize=False, num_buckets=0, - compute_mask_indices=False, + compute_mask=False, **mask_compute_kwargs, ): super().__init__( @@ -348,7 +364,7 @@ class BinarizedAudioDataset(RawAudioDataset): shuffle=shuffle, pad=pad, normalize=normalize, - compute_mask_indices=compute_mask_indices, + compute_mask=compute_mask, **mask_compute_kwargs, ) @@ -390,4 +406,22 @@ class BinarizedAudioDataset(RawAudioDataset): wav, curr_sample_rate = sf.read(fname) feats = torch.from_numpy(wav).float() feats = self.postprocess(feats, curr_sample_rate) - return {"id": index, "source": feats} + v = {"id": index, "source": feats} + + if self.is_compute_mask: + T = self._get_mask_indices_dims(feats.size(-1)) + mask = compute_block_mask_1d( + shape=(self.clone_batch, T), + mask_prob=self.mask_prob, + mask_length=self.mask_length, + mask_prob_adjust=self.mask_prob_adjust, + inverse_mask=self.inverse_mask, + require_same_masks=True, + expand_adjcent=self.expand_adjacent, + mask_dropout=self.mask_dropout, + non_overlapping=self.non_overlapping, + ) + + v["precomputed_mask"] = mask + + return v diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 0372d52b0..9a19cc3c1 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -14,6 +14,7 @@ import re import warnings from typing import Optional, Tuple +import math import numpy as np import torch @@ -337,7 +338,7 @@ def batch_by_size( if fixed_shapes is None: if num_tokens_vec is None: - return batch_by_size_fn( + b = batch_by_size_fn( indices, num_tokens_fn, max_tokens, @@ -345,7 +346,7 @@ def batch_by_size( bsz_mult, ) else: - return batch_by_size_vec( + b = batch_by_size_vec( indices, num_tokens_vec, max_tokens, @@ -353,6 +354,11 @@ def batch_by_size( bsz_mult, ) + if bsz_mult > 1 and len(b[-1]) % bsz_mult != 0: + b = b[:-1] + + return b + else: fixed_shapes = np.array(fixed_shapes, dtype=np.int64) sort_order = np.lexsort( @@ -402,6 +408,12 @@ def compute_mask_indices( min_space: int = 0, require_same_masks: bool = True, mask_dropout: float = 0.0, + add_masks: bool = False, + seed: Optional[int] = None, + epoch: Optional[int] = None, + indices: Optional[torch.Tensor] = None, + idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset + num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset ) -> np.ndarray: """ Computes random mask spans for a given shape @@ -428,49 +440,73 @@ def compute_mask_indices( bsz, all_sz = shape mask = np.full((bsz, all_sz), False) - all_num_mask = int( - # add a random number for probabilistic rounding - mask_prob * all_sz / float(mask_length) - + np.random.rand() - ) - - all_num_mask = max(min_masks, all_num_mask) + if num_mask_ver == 1: + 
all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + all_num_mask = max(min_masks, all_num_mask) mask_idcs = [] for i in range(bsz): + if seed is not None and epoch is not None and indices is not None: + seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) + else: + seed_i = None + + rng = np.random.default_rng(seed_i) + if padding_mask is not None: sz = all_sz - padding_mask[i].long().sum().item() + assert sz >= 0, sz + else: + sz = all_sz + + if num_mask_ver == 1: + if padding_mask is not None: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + num_mask = all_num_mask + elif num_mask_ver == 2: num_mask = int( # add a random number for probabilistic rounding mask_prob * sz / float(mask_length) - + np.random.rand() + + rng.random() ) num_mask = max(min_masks, num_mask) else: - sz = all_sz - num_mask = all_num_mask + raise ValueError() if mask_type == "static": lengths = np.full(num_mask, mask_length) elif mask_type == "uniform": - lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask) elif mask_type == "normal": - lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = rng.normal(mask_length, mask_other, size=num_mask) lengths = [max(1, int(round(x))) for x in lengths] elif mask_type == "poisson": - lengths = np.random.poisson(mask_length, size=num_mask) + lengths = rng.poisson(mask_length, size=num_mask) lengths = [int(round(x)) for x in lengths] else: raise Exception("unknown mask selection " + mask_type) if sum(lengths) == 0: - lengths[0] = min(mask_length, sz - 1) + if mask_type == "static": + raise ValueError(f"this should never happens") + else: + lengths = [min(mask_length, sz - 1)] if no_overlap: mask_idc = [] def arrange(s, e, length, keep_length): - span_start = np.random.randint(s, e - length) + span_start = rng.randint(s, e - length) mask_idc.extend(span_start + i for i in range(length)) new_parts = [] @@ -491,16 +527,20 @@ def compute_mask_indices( if l_sum == 0: break probs = lens / np.sum(lens) - c = np.random.choice(len(parts), p=probs) + c = rng.choice(len(parts), p=probs) s, e = parts.pop(c) parts.extend(arrange(s, e, length, min_length)) mask_idc = np.asarray(mask_idc) else: - min_len = min(lengths) - if sz - min_len <= num_mask: - min_len = sz - num_mask - 1 - - mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + if idc_select_ver == 1: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + mask_idc = rng.choice(sz - min_len, num_mask, replace=False) + elif idc_select_ver == 2: + mask_idc = rng.choice(sz, num_mask, replace=False) + else: + raise ValueError() mask_idc = np.asarray( [ @@ -510,20 +550,300 @@ def compute_mask_indices( ] ) - mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) - - min_len = min([len(m) for m in mask_idcs]) - for i, mask_idc in enumerate(mask_idcs): - if len(mask_idc) > min_len and require_same_masks: - mask_idc = np.random.choice(mask_idc, min_len, replace=False) - if mask_dropout > 0: - num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int) - mask_idc = np.random.choice( - mask_idc, len(mask_idc) - num_holes, replace=False + mask_idc = np.unique(mask_idc[mask_idc < sz]) + if len(mask_idc) >= sz: + raise ValueError( + ( + f"the entire sequence is masked. 
" + f"sz={sz}; mask_idc[mask_idc]; " + f"index={indices[i] if indices is not None else None}" + ) ) + mask_idcs.append(mask_idc) + + target_len = None + if require_same_masks: + if add_masks: + target_len = max([len(m) for m in mask_idcs]) + else: + target_len = min([len(m) for m in mask_idcs]) + + for i, mask_idc in enumerate(mask_idcs): + if target_len is not None and len(mask_idc) > target_len: + mask_idc = rng.choice(mask_idc, target_len, replace=False) mask[i, mask_idc] = True + if target_len is not None and len(mask_idc) < target_len: + unmasked = np.flatnonzero(~mask[i]) + to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False) + mask[i, to_mask] = True + + if mask_dropout > 0: + masked = np.flatnonzero(mask[i]) + num_holes = np.rint(len(masked) * mask_dropout).astype(int) + to_drop = rng.choice(masked, num_holes, replace=False) + mask[i, to_drop] = False + + return mask + + +def compute_block_mask_2d( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + mask_prob_adjust: float = 0, + inverse_mask: bool = False, + require_same_masks: bool = True, + expand_adjcent: bool = False, + mask_dropout: float = 0, + non_overlapping: bool = False, +) -> torch.Tensor: + + assert mask_length > 1 + + B, L = shape + + d = int(L**0.5) + + if inverse_mask: + mask_prob = 1 - mask_prob + + if non_overlapping: + sz = math.ceil(d / mask_length) + inp_len = sz * sz + + inp = torch.zeros((B, 1, sz, sz)) + w = torch.ones((1, 1, mask_length, mask_length)) + + mask_inds = torch.multinomial( + 1 - inp.view(B, -1), + int(inp_len * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)), + replacement=False, + ) + inp.view(B, -1).scatter_(1, mask_inds, 1) + + mask = torch.nn.functional.conv_transpose2d(inp, w, stride=mask_length).squeeze( + 1 + ) + if mask.size(-1) > d: + mask = mask[..., :d, :d] + else: + mask = torch.zeros((B, d, d)) + mask_inds = torch.randint( + 0, + L, + size=( + B, + int( + L + * ((mask_prob + mask_prob_adjust) / mask_length**2) + * (1 + mask_dropout) + ), + ), + ) + mask.view(B, -1).scatter_(1, mask_inds, 1) + centers = mask.nonzero(as_tuple=True) + + inds = ([], [], []) + + offset = mask_length // 2 + for i in range(mask_length): + for j in range(mask_length): + k1 = i - offset + k2 = j - offset + inds[0].append(centers[0]) + inds[1].append(centers[1] + k1) + inds[2].append(centers[2] + k2) + + i0 = torch.cat(inds[0]) + i1 = torch.cat(inds[1]).clamp_(min=0, max=d - 1) + i2 = torch.cat(inds[2]).clamp_(min=0, max=d - 1) + + mask[(i0, i1, i2)] = 1 + + def get_nbs(b, m, w): + all_nbs = torch.nn.functional.conv2d(m.unsqueeze(1), w, padding="same") + all_nbs = all_nbs.clamp_max_(1).view(b, -1) + return all_nbs + + if require_same_masks and expand_adjcent: + w = torch.zeros((1, 1, 3, 3)) + w[..., 0, 1] = 1 + w[..., 2, 1] = 1 + w[..., 1, 0] = 1 + w[..., 1, 2] = 1 + + all_nbs = get_nbs(B, mask, w) + + mask = mask.reshape(B, -1) + + if require_same_masks: + n_masks = mask.sum(dim=-1) + final_target_len = int(L * (mask_prob)) + target_len = int(final_target_len * (1 + mask_dropout)) + + for i in range(len(mask)): + n = n_masks[i] + m = mask[i] + r = 0 + while expand_adjcent and n < target_len: + if r == 0: + nbs = all_nbs[i] + else: + nbs = get_nbs(1, m.view(1, d, d), w).flatten() + + cands = (1 - m + nbs) > 1 + cand_sz = int(cands.sum().item()) + + assert cand_sz > 0, f"{nbs} {cand_sz}" + + to_mask = torch.multinomial( + cands.float(), min(cand_sz, int(target_len - n)), replacement=False + ) + m[to_mask] = 1 + assert to_mask.numel() > 0 + n += to_mask.numel() + r += 
1 + + if n > final_target_len: + to_unmask = torch.multinomial( + m, int(n - final_target_len), replacement=False + ) + m[to_unmask] = 0 + elif n < final_target_len: + to_mask = torch.multinomial( + (1 - m), int(final_target_len - n), replacement=False + ) + m[to_mask] = 1 + + if inverse_mask: + mask = 1 - mask + + return mask + + +def compute_block_mask_1d( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + mask_prob_adjust: float = 0, + inverse_mask: bool = False, + require_same_masks: bool = True, + expand_adjcent: bool = False, + mask_dropout: float = 0, + non_overlapping: bool = False, +) -> torch.Tensor: + + B, L = shape + + if inverse_mask: + mask_prob = 1 - mask_prob + + if non_overlapping: + sz = math.ceil(L / mask_length) + + inp = torch.zeros((B, 1, sz)) + w = torch.ones((1, 1, mask_length)) + + mask_inds = torch.multinomial( + 1 - inp.view(B, -1), + int(sz * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)), + replacement=False, + ) + inp.view(B, -1).scatter_(1, mask_inds, 1) + + mask = torch.nn.functional.conv_transpose1d(inp, w, stride=mask_length).squeeze( + 1 + ) + if mask.size(-1) > L: + mask = mask[..., :L] + + else: + mask = torch.zeros((B, L)) + mask_inds = torch.randint( + 0, + L, + size=( + B, + int( + L + * ((mask_prob + mask_prob_adjust) / mask_length) + * (1 + mask_dropout) + ), + ), + ) + + mask.view(B, -1).scatter_(1, mask_inds, 1) + centers = mask.nonzero(as_tuple=True) + + inds = ([], []) + + offset = mask_length // 2 + for i in range(mask_length): + k1 = i - offset + inds[0].append(centers[0]) + inds[1].append(centers[1] + k1) + + i0 = torch.cat(inds[0]) + i1 = torch.cat(inds[1]).clamp_(min=0, max=L - 1) + + mask[(i0, i1)] = 1 + + def get_nbs(b, m, w): + all_nbs = torch.nn.functional.conv1d(m.unsqueeze(1), w, padding="same") + all_nbs = all_nbs.clamp_max_(1).view(b, -1) + return all_nbs + + if require_same_masks and expand_adjcent: + w = torch.ones((1, 1, 3)) + w[..., 1] = 0 + all_nbs = get_nbs(B, mask, w) + + mask = mask.view(B, -1) + + if require_same_masks: + n_masks = mask.sum(dim=-1) + final_target_len = int(L * (mask_prob)) + target_len = int(final_target_len * (1 + mask_dropout)) + + for i in range(len(mask)): + n = n_masks[i] + m = mask[i] + r = 0 + while expand_adjcent and n < target_len: + if r == 0: + nbs = all_nbs[i] + else: + nbs = get_nbs(1, m.unsqueeze(0), w).squeeze(0) + + cands = (1 - m + nbs) > 1 + cand_sz = int(cands.sum().item()) + + assert cand_sz > 0, f"{nbs} {cand_sz}" + + to_mask = torch.multinomial( + cands.float(), min(cand_sz, int(target_len - n)), replacement=False + ) + m[to_mask] = 1 + assert to_mask.numel() > 0 + n += to_mask.numel() + r += 1 + + if n > final_target_len: + to_unmask = torch.multinomial( + m, int(n - final_target_len), replacement=False + ) + m[to_unmask] = 0 + elif n < final_target_len: + to_mask = torch.multinomial( + (1 - m), int(final_target_len - n), replacement=False + ) + m[to_mask] = 1 + + if inverse_mask: + mask = 1 - mask + return mask @@ -602,3 +922,223 @@ def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None: advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them." msg = f"Valid paths {ignored_paths} will be ignored. 
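compute_block_mask_1d and compute_block_mask_2d above generate the precomputed block masks that data2vec 2.0 consumes: mask centers are sampled per example, expanded into contiguous blocks of mask_length, and then topped up or trimmed so every row masks the same number of positions. A rough usage sketch, assuming this patch is applied so the helper is importable from fairseq.data.data_utils (sizes below are illustrative):

import torch
from fairseq.data.data_utils import compute_block_mask_1d

T = 200          # feature-extractor frames for one utterance (illustrative)
clone_batch = 8  # the audio datasets above request one mask per clone

mask = compute_block_mask_1d(
    shape=(clone_batch, T),
    mask_prob=0.5,            # fraction of frames to mask
    mask_length=5,            # contiguous block size around each sampled center
    require_same_masks=True,  # equalize the mask count across the clone dimension
)

assert mask.shape == (clone_batch, T)
# With require_same_masks=True and mask_dropout=0, every row masks int(T * mask_prob) frames.
print(mask.sum(dim=-1))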
{advice}" raise ValueError(msg) + + +def compute_mask_indices_for_one( + sz, + mask_prob: float, + mask_length: int, + seed=None, + epoch=None, + index=None, + min_masks=0, +): + """ + set seed, epoch, index for deterministic masking + """ + seed = int(hash((seed, epoch, index)) % 1e6) if seed else None + rng = np.random.default_rng(seed) + + # decide elements to mask + mask = np.full(sz, False) + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + rng.random() + ) + num_mask = max(min_masks, num_mask) + + # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453) + mask_idc = rng.choice(sz, num_mask, replace=False) + mask_idc = np.concatenate([mask_idc + i for i in range(mask_length)]) + mask_idc = mask_idc[mask_idc < len(mask)] + try: + mask[mask_idc] = True + except: # something wrong + print(f"Assigning mask indexes {mask_idc} to mask {mask} failed!") + raise + + return mask + + +def compute_mask_indices_v2( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + min_masks: int = 0, + require_same_masks: bool = True, + seed: Optional[int] = None, + epoch: Optional[int] = None, + indices: Optional[torch.Tensor] = None, +) -> np.ndarray: + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + else: + sz = all_sz + index = indices[i].item() if indices is not None else None + mask_for_one = compute_mask_indices_for_one( + sz, mask_prob, mask_length, seed, epoch, index, min_masks + ) + mask[i, :sz] = mask_for_one + + if require_same_masks: + index_sum = indices.sum().item() if indices is not None else None + seed = int(hash((seed, epoch, index_sum)) % 1e6) if seed else None + rng = np.random.default_rng(seed) + + num_mask = mask.sum(-1).min() + for i in range(bsz): + extra = mask[i].sum() - num_mask + if extra > 0: + to_unmask = rng.choice(np.nonzero(mask[i])[0], extra, replace=False) + mask[i, to_unmask] = False + + return mask + + +# TODO: a copy of the original compute_mask_indices +def compute_mask_indices_v3( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, + require_same_masks: bool = True, + mask_dropout: float = 0.0, + seed: Optional[int] = None, + epoch: Optional[int] = None, + indices: Optional[torch.Tensor] = None, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample + mask_dropout: randomly dropout this percentage of masks in each example + """ + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if seed is not None and epoch is not None and indices is not None: + seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) + else: + seed_i = None + rng = np.random.default_rng(seed_i) + + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + rng.random() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = rng.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = rng.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = rng.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = rng.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = rng.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len and require_same_masks: + mask_idc = rng.choice(mask_idc, min_len, replace=False) + if mask_dropout > 0: + num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int) + mask_idc = rng.choice(mask_idc, len(mask_idc) - num_holes, replace=False) + + mask[i, mask_idc] = True + + return mask diff --git a/fairseq/data/indexed_dataset.py 
b/fairseq/data/indexed_dataset.py index 81cba4af6..1947d9940 100644 --- a/fairseq/data/indexed_dataset.py +++ b/fairseq/data/indexed_dataset.py @@ -542,6 +542,11 @@ class MMapIndexedDataset(torch.utils.data.Dataset): data_file_path(path) ) + @property + def can_reuse_epoch_itr_across_epochs(self): + # TODO: a quick fix. make it a child class of FairseqDataset instead? + return True + def get_indexed_dataset_to_local(path) -> str: local_index_path = PathManager.get_local_path(index_file_path(path)) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 45a8c65fa..a48826513 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -156,7 +156,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating): num_workers=0, buffer_size=0, timeout=0, - persistent_workers=False, + persistent_workers=True, ): assert isinstance(dataset, torch.utils.data.IterableDataset) self.dataset = dataset @@ -164,11 +164,11 @@ class StreamingEpochBatchIterator(EpochBatchIterating): self.collate_fn = collate_fn self.epoch = max(epoch, 1) # we use 1-based indexing for epochs self.num_workers = num_workers + self.persistent_workers = persistent_workers and num_workers > 0 # This upper limit here is to prevent people from abusing this feature # in a shared computing environment. self.buffer_size = min(buffer_size, 20) self.timeout = timeout - self.persistent_workers = persistent_workers self._current_epoch_iterator = None @@ -321,7 +321,7 @@ class EpochBatchIterator(EpochBatchIterating): skip_remainder_batch=False, grouped_shuffling=False, reuse_dataloader=False, - persistent_workers=False, + persistent_workers=True, ): assert isinstance(dataset, torch.utils.data.Dataset) self.dataset = dataset @@ -334,6 +334,7 @@ class EpochBatchIterator(EpochBatchIterating): self.num_shards = num_shards self.shard_id = shard_id self.num_workers = num_workers + self.persistent_workers = persistent_workers and num_workers > 0 # This upper limit here is to prevent people from abusing this feature # in a shared computing environment. 
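The iterator changes turn persistent_workers on by default but gate it on num_workers > 0, because torch.utils.data.DataLoader rejects persistent_workers=True when no worker processes are spawned. A minimal sketch of the guard (dataset and sizes are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10))
num_workers = 0  # e.g. single-process debugging

loader = DataLoader(
    ds,
    batch_size=2,
    num_workers=num_workers,
    # Mirrors `persistent_workers and num_workers > 0` in the iterator constructors above;
    # passing True with num_workers=0 raises a ValueError in PyTorch.
    persistent_workers=(num_workers > 0),
)

for batch in loader:
    pass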
self.buffer_size = min(buffer_size, 20) @@ -350,7 +351,6 @@ class EpochBatchIterator(EpochBatchIterating): self.dataloader = None self.reuse_dataloader = reuse_dataloader - self.persistent_workers = persistent_workers @property def frozen_batches(self): @@ -777,8 +777,6 @@ class GroupedEpochBatchIterator(EpochBatchIterator): mult_rate=1, buffer_size=0, skip_remainder_batch=False, - reuse_dataloader=False, - persistent_workers=False, ): super().__init__( dataset, @@ -791,8 +789,6 @@ class GroupedEpochBatchIterator(EpochBatchIterator): epoch, buffer_size, skip_remainder_batch=skip_remainder_batch, - reuse_dataloader=reuse_dataloader, - persistent_workers=persistent_workers, ) # level 0: sub-samplers 1: batch_idx 2: batches self._frozen_batches = tuple([tuple(sub_batch) for sub_batch in batch_samplers]) diff --git a/fairseq/data/mask_tokens_dataset.py b/fairseq/data/mask_tokens_dataset.py index 912323559..0ca9051c9 100644 --- a/fairseq/data/mask_tokens_dataset.py +++ b/fairseq/data/mask_tokens_dataset.py @@ -69,6 +69,7 @@ class MaskTokensDataset(BaseWrapperDataset): mask_whole_words: torch.Tensor = None, mask_multiple_length: int = 1, mask_stdev: float = 0.0, + skip_masking: bool = False, ): assert 0.0 < mask_prob < 1.0 assert 0.0 <= random_token_prob <= 1.0 @@ -89,6 +90,7 @@ class MaskTokensDataset(BaseWrapperDataset): self.mask_whole_words = mask_whole_words self.mask_multiple_length = mask_multiple_length self.mask_stdev = mask_stdev + self.skip_masking = skip_masking if random_token_prob > 0.0: if freq_weighted_replacement: @@ -113,108 +115,112 @@ class MaskTokensDataset(BaseWrapperDataset): @lru_cache(maxsize=8) def __getitem_cached__(self, seed: int, epoch: int, index: int): - with data_utils.numpy_seed(self.seed, self.epoch, index): - item = self.dataset[index] - sz = len(item) + seed = int(hash((seed, epoch, index)) % 1e6) + rng = np.random.default_rng(seed) + item = self.dataset[index] + sz = len(item) - assert ( - self.mask_idx not in item - ), "Dataset contains mask_idx (={}), this is not expected!".format( - self.mask_idx, + assert ( + self.mask_idx not in item + ), "Dataset contains mask_idx (={}), this is not expected!".format( + self.mask_idx, + ) + if self.skip_masking: + return torch.from_numpy(np.copy(item)) + + if self.mask_whole_words is not None: + word_begins_mask = self.mask_whole_words.gather(0, item) + word_begins_idx = word_begins_mask.nonzero().view(-1) + sz = len(word_begins_idx) + words = np.split(word_begins_mask, word_begins_idx)[1:] + assert len(words) == sz + word_lens = list(map(len, words)) + + # decide elements to mask + mask = np.full(sz, False) + num_mask = int( + # add a random number for probabilistic rounding + self.mask_prob * sz / float(self.mask_multiple_length) + + rng.random() + ) + + # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453) + mask_idc = rng.choice(sz, num_mask, replace=False) + if self.mask_stdev > 0.0: + lengths = rng.normal( + self.mask_multiple_length, self.mask_stdev, size=num_mask ) - - if self.mask_whole_words is not None: - word_begins_mask = self.mask_whole_words.gather(0, item) - word_begins_idx = word_begins_mask.nonzero().view(-1) - sz = len(word_begins_idx) - words = np.split(word_begins_mask, word_begins_idx)[1:] - assert len(words) == sz - word_lens = list(map(len, words)) - - # decide elements to mask - mask = np.full(sz, False) - num_mask = int( - # add a random number for probabilistic rounding - self.mask_prob * sz / float(self.mask_multiple_length) - + np.random.rand() + lengths 
= [max(0, int(round(x))) for x in lengths] + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ], + dtype=np.int64, ) + else: + mask_idc = np.concatenate( + [mask_idc + i for i in range(self.mask_multiple_length)] + ) + mask_idc = mask_idc[mask_idc < len(mask)] + try: + mask[mask_idc] = True + except: # something wrong + print("Assigning mask indexes {} to mask {} failed!".format(mask_idc, mask)) + raise - # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453) - mask_idc = np.random.choice(sz, num_mask, replace=False) - if self.mask_stdev > 0.0: - lengths = np.random.normal( - self.mask_multiple_length, self.mask_stdev, size=num_mask - ) - lengths = [max(0, int(round(x))) for x in lengths] - mask_idc = np.asarray( - [ - mask_idc[j] + offset - for j in range(len(mask_idc)) - for offset in range(lengths[j]) - ], - dtype=np.int64, - ) - else: - mask_idc = np.concatenate( - [mask_idc + i for i in range(self.mask_multiple_length)] - ) - mask_idc = mask_idc[mask_idc < len(mask)] - try: - mask[mask_idc] = True - except: # something wrong - print( - "Assigning mask indexes {} to mask {} failed!".format( - mask_idc, mask - ) - ) - raise - - if self.return_masked_tokens: - # exit early if we're just returning the masked tokens - # (i.e., the targets for masked LM training) - if self.mask_whole_words is not None: - mask = np.repeat(mask, word_lens) - new_item = np.full(len(mask), self.pad_idx) - new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1] - return torch.from_numpy(new_item) - - # decide unmasking and random replacement - rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob - if rand_or_unmask_prob > 0.0: - rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob) - if self.random_token_prob == 0.0: - unmask = rand_or_unmask - rand_mask = None - elif self.leave_unmasked_prob == 0.0: - unmask = None - rand_mask = rand_or_unmask - else: - unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob - decision = np.random.rand(sz) < unmask_prob - unmask = rand_or_unmask & decision - rand_mask = rand_or_unmask & (~decision) - else: - unmask = rand_mask = None - - if unmask is not None: - mask = mask ^ unmask - + # if self.return_masked_tokens: + # print(( + # f"IDX={index}; seed={seed}; epoch={epoch}; is_tgt={self.return_masked_tokens}: " + # f"{np.nonzero(mask)[0].sum()}" + # )) + if self.return_masked_tokens: + # exit early if we're just returning the masked tokens + # (i.e., the targets for masked LM training) if self.mask_whole_words is not None: mask = np.repeat(mask, word_lens) - - new_item = np.copy(item) - new_item[mask] = self.mask_idx - if rand_mask is not None: - num_rand = rand_mask.sum() - if num_rand > 0: - if self.mask_whole_words is not None: - rand_mask = np.repeat(rand_mask, word_lens) - num_rand = rand_mask.sum() - - new_item[rand_mask] = np.random.choice( - len(self.vocab), - num_rand, - p=self.weights, - ) - + new_item = np.full(len(mask), self.pad_idx) + new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1] return torch.from_numpy(new_item) + + # decide unmasking and random replacement + rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob + if rand_or_unmask_prob > 0.0: + rand_or_unmask = mask & (rng.random(sz) < rand_or_unmask_prob) + if self.random_token_prob == 0.0: + unmask = rand_or_unmask + rand_mask = None + elif self.leave_unmasked_prob == 0.0: + unmask = None + rand_mask = rand_or_unmask 
+ else: + unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob + decision = rng.random(sz) < unmask_prob + unmask = rand_or_unmask & decision + rand_mask = rand_or_unmask & (~decision) + else: + unmask = rand_mask = None + + if unmask is not None: + mask = mask ^ unmask + + if self.mask_whole_words is not None: + mask = np.repeat(mask, word_lens) + + new_item = np.copy(item) + new_item[mask] = self.mask_idx + if rand_mask is not None: + num_rand = rand_mask.sum() + if num_rand > 0: + if self.mask_whole_words is not None: + rand_mask = np.repeat(rand_mask, word_lens) + num_rand = rand_mask.sum() + + new_item[rand_mask] = rng.choice( + len(self.vocab), + num_rand, + p=self.weights, + ) + + return torch.from_numpy(new_item) diff --git a/fairseq/data/padding_mask_dataset.py b/fairseq/data/padding_mask_dataset.py new file mode 100644 index 000000000..d7f7b88db --- /dev/null +++ b/fairseq/data/padding_mask_dataset.py @@ -0,0 +1,38 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from fairseq.data import data_utils +from . import BaseWrapperDataset + + +class PaddingMaskDataset(BaseWrapperDataset): + def __init__(self, dataset, left_pad, pad_length=None): + super().__init__(dataset) + self.left_pad = left_pad + self.pad_length = pad_length + + def __getitem__(self, index): + item = self.dataset[index] + return torch.zeros_like(item).bool() + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + return data_utils.collate_tokens( + samples, True, left_pad=self.left_pad, pad_to_length=self.pad_length + ) + + +class LeftPaddingMaskDataset(PaddingMaskDataset): + def __init__(self, dataset): + super().__init__(dataset, left_pad=True) + + +class RightPaddingMaskDataset(PaddingMaskDataset): + def __init__(self, dataset): + super().__init__(dataset, left_pad=False) diff --git a/fairseq/data/subsample_dataset.py b/fairseq/data/subsample_dataset.py index 48feaf883..fe5c7e2ac 100644 --- a/fairseq/data/subsample_dataset.py +++ b/fairseq/data/subsample_dataset.py @@ -3,9 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import contextlib import logging import numpy as np +from fairseq.data.data_utils import numpy_seed from . import BaseWrapperDataset @@ -21,13 +23,14 @@ class SubsampleDataset(BaseWrapperDataset): size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive) """ - def __init__(self, dataset, size_ratio, shuffle=False): + def __init__(self, dataset, size_ratio, shuffle=False, seed=None): super().__init__(dataset) assert size_ratio < 1 self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int) - self.indices = np.random.choice( - list(range(len(self.dataset))), self.actual_size, replace=False - ) + with numpy_seed(seed) if seed is not None else contextlib.ExitStack(): + self.indices = np.random.choice( + list(range(len(self.dataset))), self.actual_size, replace=False + ) self.shuffle = shuffle logger.info( "subsampled dataset from {} to {} (ratio={})".format( diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 5fdfab38d..af957fec6 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -636,6 +636,7 @@ class OptimizationConfig(FairseqDataclass): " (default is to skip it)." 
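MaskTokensDataset (and the compute_mask_indices_* helpers earlier in this patch) now derive a numpy Generator from the (seed, epoch, index) triple instead of reseeding the global RNG, so each example's mask is reproducible within an epoch yet varies across epochs. A standalone sketch of that seeding pattern (sizes and counts are illustrative):

import numpy as np

def make_rng(seed, epoch, index):
    # Same derivation as in the patch: hash the triple and fold it into a small seed.
    return np.random.default_rng(int(hash((seed, epoch, index)) % 1e6))

mask_a = make_rng(42, epoch=3, index=17).choice(100, 15, replace=False)
mask_b = make_rng(42, epoch=3, index=17).choice(100, 15, replace=False)
mask_c = make_rng(42, epoch=4, index=17).choice(100, 15, replace=False)

assert (mask_a == mask_b).all()  # same (seed, epoch, index) -> identical mask
# mask_c will almost surely differ, since the epoch enters the hash.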
}, ) + debug_param_names: bool = False @dataclass diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 69b77962e..f6467d5f4 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -487,15 +487,22 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=False): if remove_missing: - if is_dataclass(dc): - target_keys = set(dc.__dataclass_fields__.keys()) - else: - target_keys = set(dc.keys()) + def remove_missing_rec(src_keys, target_cfg): + if is_dataclass(target_cfg): + target_keys = set(target_cfg.__dataclass_fields__.keys()) + else: + target_keys = set(target_cfg.keys()) + + for k in list(src_keys.keys()): + if k not in target_keys: + del src_keys[k] + elif OmegaConf.is_config(src_keys[k]): + tgt = getattr(target_cfg, k) + if tgt is not None and (is_dataclass(tgt) or hasattr(tgt, "keys")): + remove_missing_rec(src_keys[k], tgt) with open_dict(cfg): - for k in list(cfg.keys()): - if k not in target_keys: - del cfg[k] + remove_missing_rec(cfg, dc) merged_cfg = OmegaConf.merge(dc, cfg) merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"] diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index 2c52f76aa..968830d58 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -51,18 +51,21 @@ def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False): if cfg.pipeline_model_parallel: num_pipeline_devices, num_pipelines_per_node = _pipeline_parallel_pre_init(cfg) + if cfg.distributed_world_size == 1: + return if all( key in os.environ for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"] ): # support torch.distributed.launch _infer_torch_distributed_launch_init(cfg) - elif cfg.distributed_port > 0: + else: # we can determine the init method automatically for Slurm - _infer_slurm_init(cfg, num_pipelines_per_node) - elif cfg.distributed_world_size > 1 or force_distributed: - # fallback for single node with multiple GPUs - _infer_single_node_init(cfg) + if not _infer_slurm_init(cfg, num_pipelines_per_node): + if cfg.distributed_port <= 0 or force_distributed: + _infer_single_node_init(cfg) + elif cfg.distributed_port <= 0: + _infer_single_node_init(cfg) if cfg.pipeline_model_parallel: _pipeline_parallel_post_init(cfg, num_pipeline_devices, num_pipelines_per_node) @@ -71,12 +74,21 @@ def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False): cfg.distributed_num_procs = min( torch.cuda.device_count(), cfg.distributed_world_size ) + else: + if cfg.device_id > 0: + logger.info( + "setting CUDA device={} on rank {}".format( + cfg.device_id, cfg.distributed_rank + ) + ) + torch.cuda.set_device(cfg.device_id) def _infer_torch_distributed_launch_init(cfg: DistributedTrainingConfig): cfg.distributed_init_method = "env://" cfg.distributed_world_size = int(os.environ["WORLD_SIZE"]) cfg.distributed_rank = int(os.environ["RANK"]) + cfg.device_id = cfg.distributed_rank % torch.cuda.device_count() # processes are created by torch.distributed.launch cfg.distributed_no_spawn = True @@ -127,22 +139,44 @@ def _infer_slurm_init(cfg: DistributedTrainingConfig, num_pipelines_per_node): # number of pipelines across all nodes. 
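merge_with_parent's remove_missing path used to prune unknown keys only at the top level of the checkpoint config; the recursive variant above also walks nested sub-configs, which matters when an old checkpoint carries fields that a nested dataclass no longer defines. The pruning idea, illustrated with plain OmegaConf containers (the config contents are made up):

from omegaconf import OmegaConf

target = OmegaConf.create({"model": {"dim": 768, "dropout": 0.1}, "seed": 1})
ckpt_cfg = OmegaConf.create(
    {"model": {"dim": 1024, "legacy_flag": True}, "seed": 2, "old_top_level": "x"}
)

def prune_unknown(src, tgt):
    # Drop keys the target no longer defines, recursing into nested configs.
    for k in list(src.keys()):
        if k not in tgt:
            del src[k]
        elif OmegaConf.is_config(src[k]) and OmegaConf.is_config(tgt[k]):
            prune_unknown(src[k], tgt[k])

prune_unknown(ckpt_cfg, target)
merged = OmegaConf.merge(target, ckpt_cfg)
# Keeps model.dim=1024 and seed=2; legacy_flag and old_top_level are gone.
print(OmegaConf.to_yaml(merged))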
cfg.distributed_world_size = nnodes * num_pipelines_per_node else: - assert ntasks_per_node == cfg.distributed_world_size // nnodes + assert ( + ntasks_per_node == cfg.distributed_world_size // nnodes + ), f"{ntasks_per_node}, {cfg.distributed_world_size}, {nnodes}" cfg.distributed_no_spawn = True cfg.distributed_rank = int(os.environ.get("SLURM_PROCID")) cfg.device_id = int(os.environ.get("SLURM_LOCALID")) + logger.info(f"Rank {cfg.distributed_rank}, device_id: {cfg.device_id}") + return True except subprocess.CalledProcessError as e: # scontrol failed raise e except FileNotFoundError: # Slurm is not installed pass + return False + def _infer_single_node_init(cfg: DistributedTrainingConfig): assert ( cfg.distributed_world_size <= torch.cuda.device_count() ), f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices" - port = random.randint(10000, 20000) - cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port) + + if cfg.distributed_port <= 0: + jobid = os.environ.get("SLURM_JOB_ID") + task_id = os.environ.get("SLURM_ARRAY_TASK_ID") + + if jobid is not None: + if task_id is not None: + jobid += str(task_id) + jobid = int(jobid) + rng = random.Random(jobid) + port = rng.randint(10000, 60000) + else: + port = random.randint(10000, 60000) + + cfg.distributed_port = port + cfg.distributed_init_method = "tcp://localhost:{port}".format( + port=cfg.distributed_port + ) def _pipeline_parallel_pre_init(cfg: DistributedTrainingConfig): @@ -341,6 +375,7 @@ def call_main(cfg: FairseqConfig, main, **kwargs): start_rank = cfg.distributed_training.distributed_rank cfg.distributed_training.distributed_rank = None # assign automatically kwargs["start_rank"] = start_rank + torch.multiprocessing.spawn( fn=distributed_main, args=(main, cfg, kwargs), diff --git a/fairseq/iterative_refinement_generator.py b/fairseq/iterative_refinement_generator.py index 4fb0946f4..3d32c6bf4 100644 --- a/fairseq/iterative_refinement_generator.py +++ b/fairseq/iterative_refinement_generator.py @@ -235,7 +235,7 @@ class IterativeRefinementGenerator(object): terminated.fill_(1) # collect finalized sentences - finalized_idxs = sent_idxs[terminated] + finalized_idxs = sent_idxs[terminated.to(sent_idxs.device)] finalized_tokens = decoder_out.output_tokens[terminated] finalized_scores = decoder_out.output_scores[terminated] finalized_attn = ( @@ -285,7 +285,7 @@ class IterativeRefinementGenerator(object): encoder_out = model.encoder.reorder_encoder_out( encoder_out, not_terminated.nonzero(as_tuple=False).squeeze() ) - sent_idxs = sent_idxs[not_terminated] + sent_idxs = sent_idxs[not_terminated.to(sent_idxs.device)] prev_output_tokens = prev_decoder_out.output_tokens.clone() if self.beam_size > 1: diff --git a/fairseq/logging/meters.py b/fairseq/logging/meters.py index d5f7c775d..495bd0830 100644 --- a/fairseq/logging/meters.py +++ b/fairseq/logging/meters.py @@ -139,6 +139,36 @@ class SumMeter(Meter): return val +class ConcatTensorMeter(Meter): + """Concatenates tensors""" + + def __init__(self, dim=0): + super().__init__() + self.reset() + self.dim = dim + + def reset(self): + self.tensor = None + + def update(self, val): + if self.tensor is None: + self.tensor = val + else: + self.tensor = torch.cat([self.tensor, val], dim=self.dim) + + def state_dict(self): + return { + "tensor": self.tensor, + } + + def load_state_dict(self, state_dict): + self.tensor = state_dict["tensor"] + + @property + def smoothed_value(self) -> float: + return [] # return a dummy value + + class 
TimeMeter(Meter): """Computes the average occurrence of some event per second""" diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py index 892b0ea4d..49301f27f 100644 --- a/fairseq/logging/metrics.py +++ b/fairseq/logging/metrics.py @@ -151,6 +151,26 @@ def log_scalar_sum( agg[key].update(value) +def log_concat_tensor( + key: str, + value: torch.Tensor, + priority: int = 10, + dim: int = 0, +): + """Log a scalar value that is summed for reporting. + + Args: + key (str): name of the field to log + value (float): value to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, ConcatTensorMeter(dim=dim), priority) + agg[key].update(value) + + def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20): """Log a scalar value derived from other meters. diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index 616e3051e..11cf6ee53 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -128,7 +128,8 @@ def register_model(name, dataclass=None): def register_model_cls(cls): if name in MODEL_REGISTRY: - raise ValueError("Cannot register duplicate model ({})".format(name)) + return MODEL_REGISTRY[name] + if not issubclass(cls, BaseFairseqModel): raise ValueError( "Model ({}: {}) must extend BaseFairseqModel".format(name, cls.__name__) diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 42f9134a3..65ead9dcf 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -160,6 +160,11 @@ class BaseFairseqModel(nn.Module): if hasattr(m, "set_num_updates") and m != self: m.set_num_updates(num_updates) + def set_epoch(self, epoch): + for m in self.modules(): + if hasattr(m, "set_epoch") and m != self: + m.set_epoch(epoch) + def prepare_for_inference_(self, cfg: DictConfig): """Prepare model for inference.""" kwargs = {} diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index 9d3ffa154..bf261624e 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -47,6 +47,7 @@ class Wav2Vec2AsrConfig(FairseqDataclass): default=0.0, metadata={"help": "dropout to apply to the input (after feat extr)"}, ) + final_dropout: float = field( default=0.0, metadata={"help": "dropout after transformer and before final projection"}, @@ -66,19 +67,6 @@ class Wav2Vec2AsrConfig(FairseqDataclass): "help": "dropout probability after activation in FFN inside wav2vec 2.0 model" }, ) - conv_feature_layers: Optional[str] = field( - default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", - metadata={ - "help": ( - "string describing convolutional feature extraction " - "layers in form of a python list that contains " - "[(dim, kernel_size, stride), ...]" - ), - }, - ) - encoder_embed_dim: Optional[int] = field( - default=768, metadata={"help": "encoder embedding dimension"} - ) # masking apply_mask: bool = field( @@ -152,12 +140,14 @@ class Wav2Vec2AsrConfig(FairseqDataclass): layerdrop: float = field( default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"} ) + drop_path: float = 0 mask_channel_min_space: Optional[int] = field( default=1, metadata={"help": "min space between spans (if no overlap is enabled)"}, ) mask_channel_before: bool = False normalize: bool = II("task.normalize") + update_alibi: bool = True 
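ConcatTensorMeter and metrics.log_concat_tensor are the plumbing behind the corpus-level GLUE metrics earlier in this patch: each validation step appends its predictions and targets, and log_derived later reads the concatenated tensors back to compute MCC, F1, or correlation over the full set. A minimal sketch using just the new meter (tensor values are illustrative):

import torch
from fairseq.logging.meters import ConcatTensorMeter

meter = ConcatTensorMeter(dim=0)
meter.update(torch.tensor([1, 0, 1]))  # predictions from one validation batch
meter.update(torch.tensor([0, 0]))     # predictions from the next batch

preds = meter.tensor.numpy()  # array([1, 0, 1, 0, 0]); fed to e.g. matthews_corrcoef
print(preds)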
data: str = II("task.data") # this holds the loaded wav2vec args w2v_args: Any = None @@ -182,6 +172,11 @@ class Wav2Vec2AsrConfig(FairseqDataclass): ) ddp_backend: str = II("distributed_training.ddp_backend") + zero_mask: bool = False + load_ema: bool = False + + layer_decay: float = 1 + @dataclass class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig): @@ -224,6 +219,12 @@ class Wav2VecCtc(BaseFairseqModel): number_of_classes, device=logits.device ) * float("-inf") masking_tensor[0] = 0 + + if logits.size(0) > net_output["padding_mask"].size(1): + net_output["padding_mask"] = F.pad( + net_output["padding_mask"], (1, 0), value=False + ) + logits[net_output["padding_mask"].T] = masking_tensor.type_as(logits) if normalize: @@ -371,6 +372,19 @@ class Wav2VecEncoder(FairseqEncoder): "checkpoint_activations": cfg.checkpoint_activations, "offload_activations": cfg.offload_activations, "min_params_to_wrap": cfg.min_params_to_wrap, + # d2v multi args + "encoder_dropout": cfg.dropout, + "drop_path": getattr(cfg, "drop_path", 0), + "mask_dropout": getattr(cfg, "mask_dropout", 0), + "zero_mask": getattr(cfg, "zero_mask", False), + "local_grad_mult": cfg.feature_grad_mult, + "layerdrop": cfg.layerdrop, + "prenet_layerdrop": cfg.layerdrop, + "prenet_dropout": cfg.dropout, + "post_mlp_drop": cfg.dropout, + "encoder_zero_mask": getattr(cfg, "zero_mask", False), + "inverse_mask": False, + "learned_alibi_scale": getattr(cfg, "update_alibi", True), } if cfg.w2v_args is None: @@ -380,6 +394,7 @@ class Wav2VecEncoder(FairseqEncoder): w2v_args = convert_namespace_to_omegaconf(state["args"]) w2v_args.criterion = None w2v_args.lr_scheduler = None + cfg.w2v_args = w2v_args logger.info(w2v_args) @@ -390,31 +405,52 @@ class Wav2VecEncoder(FairseqEncoder): if isinstance(w2v_args, Namespace): cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args) - model_normalized = w2v_args.task.get( - "normalize", w2v_args.model.get("normalize", False) - ) - assert cfg.normalize == model_normalized, ( - "Fine-tuning works best when data normalization is the same. " - "Please check that --normalize is set or unset for both pre-training and here" - ) + self.is_d2v_multi = "data2vec_multi" in w2v_args.model.get("_name", None) - if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations: - with open_dict(w2v_args): - w2v_args.model.checkpoint_activations = cfg.checkpoint_activations + if not self.is_d2v_multi: + model_normalized = w2v_args.task.get( + "normalize", w2v_args.model.get("normalize", False) + ) + assert cfg.normalize == model_normalized, ( + "Fine-tuning works best when data normalization is the same. 
" + "Please check that --normalize is set or unset for both pre-training and here" + ) - w2v_args.task.data = cfg.data - task = tasks.setup_task(w2v_args.task) - model = task.build_model(w2v_args.model, from_checkpoint=True) + if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations: + with open_dict(w2v_args): + w2v_args.model.checkpoint_activations = cfg.checkpoint_activations - model.remove_pretraining_modules() + w2v_args.task.data = cfg.data + task = tasks.setup_task(w2v_args.task, from_checkpoint=True) + model = task.build_model(w2v_args.model, from_checkpoint=True) + + model.remove_pretraining_modules() + d = w2v_args.model.encoder_embed_dim + else: + assert cfg.normalize + + if hasattr(w2v_args.task, "audio"): + w2v_args.task.audio.data = cfg.data + else: + w2v_args.task.data = cfg.data + task = tasks.setup_task(w2v_args.task, from_checkpoint=True) + + model = task.build_model(w2v_args.model, from_checkpoint=True) + + model.remove_pretraining_modules(modality="audio") + d = w2v_args.model.embed_dim if state is not None and not cfg.no_pretrained_weights: + if cfg.load_ema: + assert "_ema" in state["model"] + for k in state["model"]["_ema"]: + mk = "encoder." + k + assert mk in state["model"], mk + state["model"][mk] = state["model"]["_ema"][k] self.load_model_weights(state, model, cfg) super().__init__(task.source_dictionary) - d = w2v_args.model.encoder_embed_dim - self.w2v_model = model self.final_dropout = nn.Dropout(cfg.final_dropout) @@ -432,6 +468,29 @@ class Wav2VecEncoder(FairseqEncoder): if targ_d is not None: self.proj = Linear(d, targ_d) + layer_decay = getattr(cfg, "layer_decay", 1) + if layer_decay < 1: + mod_encs = list(model.modality_encoders.values()) + assert len(mod_encs) == 1, len(mod_encs) + blocks = list(mod_encs[0].context_encoder.blocks) + list(model.blocks) + num_layers = len(blocks) + 1 + layer_scales = list( + layer_decay ** (num_layers - i) for i in range(num_layers + 1) + ) + + for i, b in enumerate(blocks): + lid = i + 1 + if layer_scales[lid] == 1.0: + continue + + for n, p in b.named_parameters(): + optim_override = getattr(p, "optim_overrides", {}) + if "optimizer" not in optim_override: + optim_override["optimizer"] = {} + + optim_override["optimizer"]["lr_scale"] = layer_scales[lid] + p.optim_overrides = optim_override + def load_model_weights(self, state, model, cfg): if cfg.ddp_backend == "fully_sharded": from fairseq.distributed import FullyShardedDataParallel @@ -461,8 +520,25 @@ class Wav2VecEncoder(FairseqEncoder): model.load_state_dict(new_big_dict, strict=False) else: - if "_ema" in state["model"]: - del state["model"]["_ema"] + to_delete = {"_ema", "target_proj", "decoder"} + for k in to_delete: + if k in state["model"]: + del state["model"][k] + + if hasattr(model, "modality_encoders"): + if "modality_encoders.AUDIO.encoder_mask" not in state["model"]: + model.modality_encoders["AUDIO"].encoder_mask = None + elif not cfg.zero_mask: + model.modality_encoders["AUDIO"].encoder_mask = None + del state["model"]["modality_encoders.AUDIO.encoder_mask"] + + for k in list(state["model"].keys()): + if k.startswith("modality_encoders.") and not k.startswith( + "modality_encoders.AUDIO" + ): + del state["model"][k] + + print(model) model.load_state_dict(state["model"], strict=True) def set_num_updates(self, num_updates): @@ -478,6 +554,9 @@ class Wav2VecEncoder(FairseqEncoder): "mask": self.apply_mask and self.training, } + if self.is_d2v_multi: + w2v_args["mode"] = "AUDIO" + ft = self.freeze_finetune_updates <= self.num_updates with 
torch.no_grad() if not ft else contextlib.ExitStack(): @@ -626,6 +705,7 @@ class TransformerDecoder(FairseqIncrementalDecoder): - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ + if type(prev_output_tokens) == list: max_len = max((len(x) for x in prev_output_tokens)) tmp = torch.zeros( @@ -634,6 +714,7 @@ class TransformerDecoder(FairseqIncrementalDecoder): for (i, p) in enumerate(prev_output_tokens): tmp[i, : len(p)] = p prev_output_tokens = tmp + prev_output_tokens = prev_output_tokens.long() x, extra = self.extract_features( prev_output_tokens, encoder_out, incremental_state diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py index 1df76ef24..dcfda9b82 100644 --- a/fairseq/modules/__init__.py +++ b/fairseq/modules/__init__.py @@ -32,7 +32,7 @@ from .location_attention import LocationAttention from .lstm_cell_with_zoneout import LSTMCellWithZoneOut from .multihead_attention import MultiheadAttention from .positional_embedding import PositionalEmbedding -from .same_pad import SamePad +from .same_pad import SamePad, SamePad2d from .scalar_bias import ScalarBias from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer @@ -87,6 +87,7 @@ __all__ = [ "MultiheadAttention", "PositionalEmbedding", "SamePad", + "SamePad2d", "ScalarBias", "SinusoidalPositionalEmbedding", "TransformerSentenceEncoderLayer", diff --git a/fairseq/modules/ema_module.py b/fairseq/modules/ema_module.py index a5b98861d..f0ece842d 100644 --- a/fairseq/modules/ema_module.py +++ b/fairseq/modules/ema_module.py @@ -11,8 +11,18 @@ import logging import torch +from omegaconf import II from fairseq.dataclass import FairseqDataclass +try: + from amp_C import multi_tensor_l2norm + + multi_tensor_l2norm_available = True +except ImportError: + multi_tensor_l2norm_available = False + +logger = logging.getLogger(__name__) + @dataclass class EMAModuleConfig(FairseqDataclass): @@ -23,12 +33,21 @@ class EMAModuleConfig(FairseqDataclass): default=False, metadata={"help": "If true, store EMA model in fp32 even if model is in fp16"}, ) + add_missing_params: bool = True + log_norms: bool = False class EMAModule: """Exponential Moving Average of Fairseq Models""" - def __init__(self, model, config: EMAModuleConfig, device=None, skip_keys=None): + def __init__( + self, + model, + config: EMAModuleConfig, + copy_model=True, + device=None, + skip_keys=None, + ): """ @param model model to initialize the EMA with @param config EMAConfig object with configuration like @@ -37,11 +56,18 @@ class EMAModule: Otherwise EMA is in the same device as the model. 
""" - self.decay = config.ema_decay - self.model = copy.deepcopy(model) - self.model.requires_grad_(False) self.config = config + + if copy_model: + self.model = copy.deepcopy(model) + self.model.requires_grad_(False) + else: + self.model = model + + self.config = config + self.decay = config.ema_decay self.skip_keys = skip_keys or set() + self.add_missing_params = config.add_missing_params self.fp32_params = {} if device is not None: @@ -51,7 +77,8 @@ class EMAModule: if self.config.ema_fp32: self.build_fp32_params() - self.update_freq_counter = 0 + self.log_norms = config.log_norms and multi_tensor_l2norm_available + self.logs = {} def build_fp32_params(self, state_dict=None): """ @@ -74,9 +101,16 @@ class EMAModule: for param_key in state_dict: if param_key in self.fp32_params: - self.fp32_params[param_key].copy_(state_dict[param_key]) + if param_key == "__sq_mom": + self.fp32_params[param_key] = state_dict[param_key] + else: + self.fp32_params[param_key].copy_(state_dict[param_key]) else: self.fp32_params[param_key] = _to_float(state_dict[param_key]) + if "__sq_mom" in self.fp32_params: + self.fp32_params["__sq_mom"][param_key] = torch.zeros_like( + self.fp32_params[param_key] + ) def restore(self, state_dict, build_fp32_params=False): """Load data from a model spec into EMA model""" @@ -84,8 +118,10 @@ class EMAModule: if build_fp32_params: self.build_fp32_params(state_dict) - def set_decay(self, decay): + def set_decay(self, decay, weight_decay=None): self.decay = decay + if weight_decay is not None: + self.weight_decay = weight_decay def get_decay(self): return self.decay @@ -98,9 +134,17 @@ class EMAModule: ema_params = ( self.fp32_params if self.config.ema_fp32 else self.model.state_dict() ) + + new_p = [] + ema_p = [] + for key, param in new_model.named_parameters(): if isinstance(param, dict): continue + + if not self.add_missing_params and key not in ema_params: + continue + try: ema_param = ema_params[key] except KeyError: @@ -119,18 +163,39 @@ class EMAModule: # Do not decay a model.version pytorch param continue + lr = 1 - decay + if key in self.skip_keys or not param.requires_grad: ema_params[key].copy_(param.to(dtype=ema_param.dtype).data) ema_param = ema_params[key] else: - ema_param.mul_(decay) - ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - decay) + if self.log_norms: + new_p.append(param) + ema_p.append(ema_param) + + ema_param.mul_(1 - lr) + ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=lr) ema_state_dict[key] = ema_param for key, param in new_model.named_buffers(): ema_state_dict[key] = param + if self.log_norms: + if "model_norm" in self.logs: + self.prev_model_norm = self.logs["model_norm"] + + chunk_size = 2048 * 32 + has_inf = torch.zeros( + (1, 1), dtype=torch.int, device=next(new_model.parameters()).device + ) + + new_norm = multi_tensor_l2norm(chunk_size, has_inf, [new_p], False) + old_norm = multi_tensor_l2norm(chunk_size, has_inf, [ema_p], False) + + self.logs["model_norm"] = new_norm[0] + self.logs["ema_norm"] = old_norm[0] + self.restore(ema_state_dict, build_fp32_params=False) @torch.no_grad() diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py index 91655bc5e..867b019f6 100644 --- a/fairseq/modules/gumbel_vector_quantizer.py +++ b/fairseq/modules/gumbel_vector_quantizer.py @@ -21,6 +21,8 @@ class GumbelVectorQuantizer(nn.Module): activation=nn.GELU(), weight_proj_depth=1, weight_proj_factor=1, + hard=True, + std=0, ): """Vector quantization using gumbel softmax @@ -44,6 +46,7 
@@ class GumbelVectorQuantizer(nn.Module): self.input_dim = dim self.num_vars = num_vars self.time_first = time_first + self.hard = hard assert ( vq_dim % groups == 0 @@ -53,7 +56,10 @@ class GumbelVectorQuantizer(nn.Module): num_groups = groups if not combine_groups else 1 self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim)) - nn.init.uniform_(self.vars) + if std == 0: + nn.init.uniform_(self.vars) + else: + nn.init.normal_(self.vars, mean=0, std=std) if weight_proj_depth > 1: @@ -151,16 +157,17 @@ class GumbelVectorQuantizer(nn.Module): x = self.weight_proj(x) x = x.view(bsz * tsz * self.groups, -1) - _, k = x.max(-1) - hard_x = ( - x.new_zeros(*x.shape) - .scatter_(-1, k.view(-1, 1), 1.0) - .view(bsz * tsz, self.groups, -1) - ) - hard_probs = torch.mean(hard_x.float(), dim=0) - result["code_perplexity"] = torch.exp( - -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) - ).sum() + with torch.no_grad(): + _, k = x.max(-1) + hard_x = ( + x.new_zeros(*x.shape) + .scatter_(-1, k.view(-1, 1), 1.0) + .view(bsz * tsz, self.groups, -1) + ) + hard_probs = torch.mean(hard_x.float(), dim=0) + result["code_perplexity"] = torch.exp( + -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) + ).sum() avg_probs = torch.softmax( x.view(bsz * tsz, self.groups, -1).float(), dim=-1 @@ -172,7 +179,9 @@ class GumbelVectorQuantizer(nn.Module): result["temp"] = self.curr_temp if self.training: - x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=True).type_as(x) + x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=self.hard).type_as( + x + ) else: x = hard_x diff --git a/fairseq/modules/kmeans_vector_quantizer.py b/fairseq/modules/kmeans_vector_quantizer.py index 040db1e83..1015c3899 100644 --- a/fairseq/modules/kmeans_vector_quantizer.py +++ b/fairseq/modules/kmeans_vector_quantizer.py @@ -100,15 +100,16 @@ class KmeansVectorQuantizer(nn.Module): assert ze.shape == zq.shape, (ze.shape, zq.shape) x = self._pass_grad(ze, zq) - hard_x = ( - idx.new_zeros(bsz * tsz * self.groups, self.num_vars) - .scatter_(-1, idx.view(-1, 1), 1.0) - .view(bsz * tsz, self.groups, -1) - ) - hard_probs = torch.mean(hard_x.float(), dim=0) - result["code_perplexity"] = torch.exp( - -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) - ).sum() + with torch.no_grad(): + hard_x = ( + idx.new_zeros(bsz * tsz * self.groups, self.num_vars) + .scatter_(-1, idx.view(-1, 1), 1.0) + .view(bsz * tsz, self.groups, -1) + ) + hard_probs = torch.mean(hard_x.float(), dim=0) + result["code_perplexity"] = torch.exp( + -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) + ).sum() if produce_targets: result["targets"] = idx diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index 1cf3abe94..262132dfe 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -551,7 +551,7 @@ class MultiheadAttention(FairseqIncrementalDecoder): self.out_proj.weight, self.out_proj.bias, self.training or self.dropout_module.apply_during_inference, - key_padding_mask, + key_padding_mask.bool() if key_padding_mask is not None else None, need_weights, attn_mask, use_separate_proj_weight=True, diff --git a/fairseq/modules/same_pad.py b/fairseq/modules/same_pad.py index 4c04990ea..a3ce4131c 100644 --- a/fairseq/modules/same_pad.py +++ b/fairseq/modules/same_pad.py @@ -19,3 +19,15 @@ class SamePad(nn.Module): if self.remove > 0: x = x[:, :, : -self.remove] return x + + +class SamePad2d(nn.Module): + def __init__(self, 
kernel_size): + super().__init__() + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + assert len(x.size()) == 4 + if self.remove > 0: + x = x[:, :, : -self.remove, : -self.remove] + return x diff --git a/fairseq/modules/transpose_last.py b/fairseq/modules/transpose_last.py index e578b3ec5..d7cca9a4b 100644 --- a/fairseq/modules/transpose_last.py +++ b/fairseq/modules/transpose_last.py @@ -10,11 +10,12 @@ import torch.nn as nn class TransposeLast(nn.Module): - def __init__(self, deconstruct_idx=None): + def __init__(self, deconstruct_idx=None, tranpose_dim=-2): super().__init__() self.deconstruct_idx = deconstruct_idx + self.tranpose_dim = tranpose_dim def forward(self, x): if self.deconstruct_idx is not None: x = x[self.deconstruct_idx] - return x.transpose(-2, -1) + return x.transpose(self.tranpose_dim, -1) diff --git a/fairseq/nan_detector.py b/fairseq/nan_detector.py index 7d46d766d..bd0f91107 100644 --- a/fairseq/nan_detector.py +++ b/fairseq/nan_detector.py @@ -38,7 +38,7 @@ class NanDetector: for name, param in self.named_parameters: if param.grad is not None: grad_norm = torch.norm(param.grad.data.float(), p=2) - norm[name] = grad_norm.item() + norm[name] = param.norm().item() if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any(): gradients[name] = param.grad.data if len(gradients) > 0: diff --git a/fairseq/optim/composite.py b/fairseq/optim/composite.py index 63701ee8b..1ef0114ed 100644 --- a/fairseq/optim/composite.py +++ b/fairseq/optim/composite.py @@ -13,6 +13,7 @@ from fairseq.dataclass import FairseqDataclass from fairseq.optim import FairseqOptimizer, register_optimizer, _build_optimizer from fairseq.optim.lr_scheduler import FairseqLRScheduler, build_lr_scheduler from omegaconf import II, open_dict +import copy logger = logging.getLogger(__name__) @@ -37,6 +38,12 @@ class CompositeOptimizerConfig(FairseqDataclass): "Configures a different optimizer and (optionally) lr scheduler for each parameter group" }, ) + dynamic_groups: bool = field( + default=False, + metadata={ + "help": "create groups dynamically based on parameters, if set to False, all parameters needs to have group_names" + }, + ) @register_optimizer("composite", dataclass=CompositeOptimizerConfig) @@ -54,31 +61,107 @@ class FairseqCompositeOptimizer(FairseqOptimizer): len(params) > 1 ), "Composite optimizer only works when there are multiple parameter groups (try fp16_no_flatten_grads: true)" + def dict_hash(dictionary: Dict[str, Any]) -> str: + import hashlib + import json + + dhash = hashlib.md5() + encoded = json.dumps(dictionary, sort_keys=True).encode() + dhash.update(encoded) + return dhash.hexdigest() + groupped_params = defaultdict(list) - for p in params: - group = getattr(p, "param_group", "default") - groupped_params[group].append(p) - - assert groupped_params.keys() == cfg.groups.keys(), ( - f"Parameter groups {groupped_params.keys()} and optimizer groups {cfg.groups.keys()} are not the same! " - "Try setting 'param_group' on your parameters in the model." 
- ) - - for group, group_params in groupped_params.items(): - group_cfg = cfg.groups[group] - with open_dict(group_cfg): - if group_cfg.lr_float is not None: - group_cfg.optimizer.lr = [group_cfg.lr_float] - group_cfg.lr_scheduler.lr = [group_cfg.lr_float] + overrides = defaultdict(dict) + if not cfg.dynamic_groups: + for p in params: + group = getattr(p, "param_group", "default") + override_config = getattr(p, "optim_overrides", None) + if override_config is not None and bool(override_config): + overrides[group] = override_config else: - group_cfg.optimizer.lr = group_cfg.lr - group_cfg.lr_scheduler.lr = group_cfg.lr - self.optimizers[group] = _build_optimizer(group_cfg.optimizer, group_params) - if group_cfg.lr_scheduler is not None: - self.lr_schedulers[group] = build_lr_scheduler( - group_cfg.lr_scheduler, self.optimizers[group] - ) + assert ( + override_config == None or override_config == overrides[group] + ), f"For group {group}, different overrides found {override_config} v/s {overrides[group]}" + groupped_params[group].append(p) + for p, params in groupped_params.items(): + override_config = getattr(params[0], "optim_overrides", None) + if override_config is not None: + for pp in params[1:]: + assert override_config == getattr( + pp, "optim_overrides", None + ), f" {str(override_config)} != {str(getattr(pp, 'optim_overrides', None))}" + else: + for p in params: + group = getattr(p, "param_group", "default") + override_config = getattr(p, "optim_overrides", None) + if override_config is not None: + override_config["group_name"] = group + group_name = dict_hash(override_config) + overrides[group_name] = override_config + else: + group_name = group + groupped_params[group_name].append(p) + + self.optimizers_config = {} + for group, group_params in groupped_params.items(): + p_group = group + if group in overrides and "group_name" in overrides[group]: + p_group = overrides[group]["group_name"] + if group in cfg.groups: + group_cfg = cfg.groups[group] + optimizer_config = copy.deepcopy(group_cfg.optimizer) + scheduler_config = copy.deepcopy(group_cfg.lr_scheduler) + explicit_group_present = True + else: + group_cfg = cfg.groups[p_group] + optimizer_config = copy.deepcopy(group_cfg.optimizer) + scheduler_config = copy.deepcopy(group_cfg.lr_scheduler) + explicit_group_present = False + + if getattr(group_cfg, "lr_float", None) is not None: + with open_dict(optimizer_config): + optimizer_config.lr = [group_cfg.lr_float] + + if group in overrides and "optimizer" in overrides[group]: + with open_dict(optimizer_config): + if "lr_scale" in overrides[group]["optimizer"]: + lr_scale = overrides[group]["optimizer"]["lr_scale"] + optimizer_config.lr = [ + lr * lr_scale for lr in optimizer_config.lr + ] + + if explicit_group_present: + logger.info( + f"For group:{group}, config as well as override present for lr" + ) + + if ( + "weight_decay_scale" in overrides[group]["optimizer"] + and "optimizer_config" in optimizer_config + ): + weight_decay_scale = overrides[group]["optimizer"][ + "weight_decay_scale" + ] + optimizer_config.weight_decay = ( + optimizer_config.weight_decay * weight_decay_scale + ) + if explicit_group_present: + logger.info( + f"For group:{group}, config as well as override present for weight_decay" + ) + + with open_dict(scheduler_config): + scheduler_config.lr = optimizer_config.lr + self.optimizers[group] = _build_optimizer(optimizer_config, group_params) + self.optimizers_config[group] = optimizer_config + if scheduler_config is not None: + self.lr_schedulers[group] = 
build_lr_scheduler( + scheduler_config, self.optimizers[group] + ) + logger.info("Optimizers for different groups are as below") + for group in self.optimizers_config.keys(): + logger.info(f"Group : {group}:{self.optimizers_config[group]}") if len(self.lr_schedulers) > 0: assert len(self.lr_schedulers) == len(self.optimizers), ( f"Please provide an lr scheduler for each optimizer to use pass_through scheduler. " diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 2c4ee326e..6a4da342c 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -72,6 +72,8 @@ class _FP16OptimizerMixin(object): p32.grad = torch.zeros_like(p32.data) if hasattr(p, "param_group"): p32.param_group = p.param_group + if hasattr(p, "optim_overrides"): + p32.optim_overrides = p.optim_overrides fp32_params.append(p32) return fp32_params @@ -194,6 +196,9 @@ class _FP16OptimizerMixin(object): 0, aggregate_norm_fn ) + if torch.is_tensor(self._multiply_factor): + self._multiply_factor = self._multiply_factor.to(grad_norm.device) + if self.scaler is not None: if grad_norm > max_norm > 0.0: self._multiply_factor *= max_norm / grad_norm diff --git a/fairseq/optim/fused_adam.py b/fairseq/optim/fused_adam.py index 1290ecfdb..39a2a8369 100644 --- a/fairseq/optim/fused_adam.py +++ b/fairseq/optim/fused_adam.py @@ -210,6 +210,9 @@ class FusedAdamV1(torch.optim.Optimizer): exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"] beta1, beta2 = group["betas"] + if "step" not in state: + state["step"] = group["step"] + state["step"] += 1 with torch.cuda.device(p_data_fp32.device): diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index d01143498..5fcaea25d 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -77,9 +77,8 @@ class CosineLRSchedule(FairseqLRScheduler): ) self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr - assert ( - self.max_lr > cfg.min_lr - ), f"max_lr (={cfg.lr}) must be more than min_lr (={cfg.min_lr})" + if self.max_lr < cfg.min_lr: + cfg.min_lr = self.max_lr warmup_end_lr = self.max_lr if cfg.warmup_init_lr < 0: diff --git a/fairseq/registry.py b/fairseq/registry.py index f3b940604..904ffcd60 100644 --- a/fairseq/registry.py +++ b/fairseq/registry.py @@ -36,8 +36,9 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F choice = cfg._name if choice and choice in DATACLASS_REGISTRY: + from_checkpoint = extra_kwargs.get("from_checkpoint", False) dc = DATACLASS_REGISTRY[choice] - cfg = merge_with_parent(dc(), cfg) + cfg = merge_with_parent(dc(), cfg, remove_missing=from_checkpoint) elif isinstance(cfg, str): choice = cfg if choice in DATACLASS_REGISTRY: @@ -58,6 +59,9 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F else: builder = cls + if "from_checkpoint" in extra_kwargs: + del extra_kwargs["from_checkpoint"] + return builder(cfg, *extra_args, **extra_kwargs) def register_x(name, dataclass=None): diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index 9a46b012c..6da1f001f 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -35,8 +35,9 @@ def setup_task(cfg: FairseqDataclass, **kwargs): task_name = getattr(cfg, "_name", None) if task_name and task_name in TASK_DATACLASS_REGISTRY: + remove_missing = "from_checkpoint" in kwargs and kwargs["from_checkpoint"] dc = TASK_DATACLASS_REGISTRY[task_name] - cfg = 
merge_with_parent(dc(), cfg) + cfg = merge_with_parent(dc(), cfg, remove_missing=remove_missing) task = TASK_REGISTRY[task_name] assert ( @@ -68,7 +69,8 @@ def register_task(name, dataclass=None): def register_task_cls(cls): if name in TASK_REGISTRY: - raise ValueError("Cannot register duplicate task ({})".format(name)) + return TASK_REGISTRY[name] + if not issubclass(cls, FairseqTask): raise ValueError( "Task ({}: {}) must extend FairseqTask".format(name, cls.__name__) diff --git a/fairseq/tasks/audio_finetuning.py b/fairseq/tasks/audio_finetuning.py index 5e04a1b79..77634f1ba 100644 --- a/fairseq/tasks/audio_finetuning.py +++ b/fairseq/tasks/audio_finetuning.py @@ -100,6 +100,7 @@ class AudioFinetuningConfig(AudioPretrainingConfig): "adds 'prev_output_tokens' to input and appends eos to target" }, ) + rebuild_batches: bool = True @register_task("audio_finetuning", dataclass=AudioFinetuningConfig) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index a55c70400..e6de0666a 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -12,9 +12,9 @@ import sys from argparse import Namespace from dataclasses import dataclass, field from typing import Optional -from omegaconf import MISSING, II, OmegaConf +from omegaconf import MISSING, II -from fairseq.data import BinarizedAudioDataset, FileAudioDataset +from fairseq.data import BinarizedAudioDataset, FileAudioDataset, SubsampleDataset from fairseq.dataclass import FairseqDataclass, ChoiceEnum from fairseq.data.text_compressor import TextCompressionLevel @@ -25,24 +25,16 @@ logger = logging.getLogger(__name__) @dataclass -class InferredW2vConfig: - # The following are needed to precompute mask and mask channel indices - # before model's forward. 
- mask_length: Optional[int] = II("model.mask_length") - mask_prob: Optional[float] = II("model.mask_prob") - mask_selection: Optional[str] = II("model.mask_selection") - mask_other: Optional[float] = II("model.mask_other") - no_mask_overlap: Optional[bool] = II("model.no_mask_overlap") - mask_min_space: Optional[int] = II("model.mask_min_space") - mask_channel_length: Optional[int] = II("model.mask_channel_length") - mask_channel_prob: Optional[float] = II("model.mask_channel_prob") - mask_channel_selection: Optional[str] = II("model.mask_channel_selection") - mask_channel_other: Optional[float] = II("model.mask_channel_other") - no_mask_channel_overlap: Optional[bool] = II("model.no_mask_channel_overlap") - mask_channel_min_space: Optional[int] = II("model.mask_channel_min_space") - - conv_feature_layers: Optional[str] = II("model.conv_feature_layers") - encoder_embed_dim: Optional[int] = II("model.encoder_embed_dim") +class AudioMaskingConfig: + feature_encoder_spec: str = II("model.modalities.audio.feature_encoder_spec") + mask_prob: float = II("model.modalities.audio.mask_prob") + mask_prob_adjust: float = II("model.modalities.audio.mask_prob_adjust") + mask_length: int = II("model.modalities.audio.mask_length") + inverse_mask: bool = II("model.modalities.audio.inverse_mask") + mask_dropout: float = II("model.modalities.audio.mask_dropout") + clone_batch: int = II("model.clone_batch") + expand_adjacent: bool = False + non_overlapping: bool = False @dataclass @@ -82,20 +74,6 @@ class AudioPretrainingConfig(FairseqDataclass): default=0, metadata={"help": "number of buckets"}, ) - precompute_mask_indices: bool = field( - default=False, - metadata={ - "help": "flag to compute mask indices in data preparation.", - }, - ) - - inferred_w2v_config: Optional[InferredW2vConfig] = field( - default=None, - metadata={ - "help": "wav2vec 2.0 masking arguments used to pre-compute masks (required for TPU)", - }, - ) - tpu: bool = II("common.tpu") text_compression_level: ChoiceEnum([x.name for x in TextCompressionLevel]) = field( default="none", @@ -105,6 +83,14 @@ class AudioPretrainingConfig(FairseqDataclass): }, ) + rebuild_batches: bool = True + precompute_mask_config: Optional[AudioMaskingConfig] = None + + post_save_script: Optional[str] = None + + subsample: float = 1 + seed: int = II("common.seed") + @register_task("audio_pretraining", dataclass=AudioPretrainingConfig) class AudioPretrainingTask(FairseqTask): @@ -122,17 +108,6 @@ class AudioPretrainingTask(FairseqTask): return cls(cfg) - def _get_mask_precompute_kwargs(self, cfg): - if self.cfg.precompute_mask_indices or self.cfg.tpu: - assert ( - cfg.inferred_w2v_config is not None - ), "inferred_w2v_config must be set" - return OmegaConf.to_container( - cfg.inferred_w2v_config, resolve=True, enum_to_str=True - ) - else: - return {} - def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): data_path = self.cfg.data task_cfg = task_cfg or self.cfg @@ -145,6 +120,12 @@ class AudioPretrainingTask(FairseqTask): text_compression_level = getattr( TextCompressionLevel, str(self.cfg.text_compression_level) ) + + compute_mask = task_cfg.precompute_mask_config is not None + mask_args = {} + if compute_mask: + mask_args = task_cfg.precompute_mask_config + if getattr(task_cfg, "binarized_dataset", False): self.datasets[split] = BinarizedAudioDataset( data_path, @@ -155,8 +136,8 @@ class AudioPretrainingTask(FairseqTask): pad=task_cfg.labels is not None or task_cfg.enable_padding, normalize=task_cfg.normalize, 
num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), - compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu), - **self._get_mask_precompute_kwargs(task_cfg), + compute_mask=compute_mask, + **mask_args, ) else: manifest_path = os.path.join(data_path, "{}.tsv".format(split)) @@ -169,9 +150,17 @@ class AudioPretrainingTask(FairseqTask): pad=task_cfg.labels is not None or task_cfg.enable_padding, normalize=task_cfg.normalize, num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), - compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu), text_compression_level=text_compression_level, - **self._get_mask_precompute_kwargs(task_cfg), + compute_mask=compute_mask, + **mask_args, + ) + + if getattr(task_cfg, "subsample", 1) < 1: + self.datasets[split] = SubsampleDataset( + self.datasets[split], + task_cfg.subsample, + shuffle=True, + seed=task_cfg.seed, ) if self.cfg.tpu and task_cfg.inferred_w2v_config.mask_channel_prob == 0.0: @@ -181,14 +170,6 @@ class AudioPretrainingTask(FairseqTask): "0. You may want to set this to a low value close to 0." ) - @property - def source_dictionary(self): - return None - - @property - def target_dictionary(self): - return None - def max_positions(self): """Maximum input length supported by the encoder.""" return sys.maxsize, sys.maxsize @@ -203,3 +184,24 @@ class AudioPretrainingTask(FairseqTask): model_cfg.w2v_args = actualized_cfg.w2v_args return model + + def post_save(self, cp_path, num_updates): + if self.cfg.post_save_script is not None: + logger.info(f"launching {self.cfg.post_save_script}") + import os.path as osp + from fairseq.file_io import PathManager + + eval_cp_path = osp.join( + osp.dirname(cp_path), f"checkpoint_eval_{num_updates}.pt" + ) + + print(cp_path, eval_cp_path, osp.dirname(cp_path)) + + assert PathManager.copy( + cp_path, eval_cp_path, overwrite=True + ), f"Failed to copy {cp_path} to {eval_cp_path}" + + import subprocess + import shlex + + subprocess.call(shlex.split(f"{self.cfg.post_save_script} {eval_cp_path}")) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 3cba8f224..7481e0f4b 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -272,6 +272,7 @@ class FairseqTask(object): and not update_epoch_batch_itr and self.can_reuse_epoch_itr(dataset) ) + logger.info(f"can_reuse_epoch_itr = {can_reuse_epoch_itr}") if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter: logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch)) return self.dataset_to_epoch_iter[dataset] @@ -281,26 +282,39 @@ class FairseqTask(object): # initialize the dataset with the correct starting epoch dataset.set_epoch(epoch) - # get indices ordered by example size - with data_utils.numpy_seed(seed): - indices = dataset.ordered_indices() + def make_batches(dataset, epoch): + logger.info(f"creating new batches for epoch {epoch}") - # filter examples that are too large - if max_positions is not None: - indices = self.filter_indices_by_size( - indices, dataset, max_positions, ignore_invalid_inputs + # get indices ordered by example size + with data_utils.numpy_seed(seed + epoch): + indices = dataset.ordered_indices() + + # filter examples that are too large + if max_positions is not None: + indices = self.filter_indices_by_size( + indices, dataset, max_positions, ignore_invalid_inputs + ) + + # create mini-batches with given size constraints + batches = dataset.batch_by_size( + indices, + max_tokens=max_tokens, + max_sentences=max_sentences, + 
required_batch_size_multiple=required_batch_size_multiple, ) - - # create mini-batches with given size constraints - batch_sampler = dataset.batch_by_size( - indices, - max_tokens=max_tokens, - max_sentences=max_sentences, - required_batch_size_multiple=required_batch_size_multiple, - ) + return batches reuse_dataloader = getattr(self.cfg, "reuse_dataloader", True) - persistent_workers = getattr(self.cfg, "persistent_workers", False) + persistent_workers = getattr(self.cfg, "persistent_workers", True) + rebuild_batches = getattr(self.cfg, "rebuild_batches", False) + logger.info(f"reuse_dataloader = {reuse_dataloader}") + logger.info(f"rebuild_batches = {rebuild_batches}") + + if rebuild_batches: + logger.info("batches will be rebuilt for each epoch") + batch_sampler = make_batches + else: + batch_sampler = make_batches(dataset, epoch) # return a reusable, sharded iterator epoch_iter = iterators.EpochBatchIterator( @@ -341,7 +355,7 @@ class FairseqTask(object): model = quantization_utils.quantize_model_scalar(model, cfg) return model - def build_criterion(self, cfg: DictConfig): + def build_criterion(self, cfg: DictConfig, from_checkpoint=False): """ Build the :class:`~fairseq.criterions.FairseqCriterion` instance for this task. @@ -354,7 +368,7 @@ class FairseqTask(object): """ from fairseq import criterions - return criterions.build_criterion(cfg, self) + return criterions.build_criterion(cfg, self, from_checkpoint=from_checkpoint) def build_generator( self, @@ -614,13 +628,13 @@ class FairseqTask(object): def source_dictionary(self): """Return the source :class:`~fairseq.data.Dictionary` (if applicable for this task).""" - raise NotImplementedError + return None @property def target_dictionary(self): """Return the target :class:`~fairseq.data.Dictionary` (if applicable for this task).""" - raise NotImplementedError + return None def build_tokenizer(self, args): """Build the pre-tokenizer for this task.""" diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py index 6393ee480..b064907a5 100644 --- a/fairseq/tasks/masked_lm.py +++ b/fairseq/tasks/masked_lm.py @@ -20,6 +20,7 @@ from fairseq.data import ( NumSamplesDataset, PrependTokenDataset, RightPadDataset, + RightPaddingMaskDataset, SortDataset, TokenBlockDataset, data_utils, @@ -106,6 +107,22 @@ class MaskedLMConfig(FairseqDataclass): "help": "include target tokens in model input. this is used for data2vec" }, ) + include_index: bool = field( + default=True, + metadata={"help": "include index in model input. 
this is used for data2vec"}, + ) + skip_masking: bool = field( + default=False, + metadata={"help": "skip masking at dataset"}, + ) + # subsample_train: float = field( + # default=1, + # metadata={"help": "shorten training set for debugging"}, + # ) + d2v2_multi: bool = field( + default=False, + metadata={"help": "prepare dataset for data2vec_multi"}, + ) @register_task("masked_lm", dataclass=MaskedLMConfig) @@ -115,20 +132,25 @@ class MaskedLMTask(FairseqTask): """Task for training masked language models (e.g., BERT, RoBERTa).""" - def __init__(self, cfg: MaskedLMConfig, dictionary): + def __init__(self, cfg: MaskedLMConfig, dictionary=None): super().__init__(cfg) - self.dictionary = dictionary + self.dictionary = dictionary or self.load_dict(cfg) # add mask token - self.mask_idx = dictionary.add_symbol("") + self.mask_idx = self.dictionary.add_symbol("") @classmethod def setup_task(cls, cfg: MaskedLMConfig, **kwargs): + dictionary = cls.load_dict(cfg) + return cls(cfg, dictionary) + + @classmethod + def load_dict(cls, cfg): paths = utils.split_paths(cfg.data) assert len(paths) > 0 dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) logger.info("dictionary: {} types".format(len(dictionary))) - return cls(cfg, dictionary) + return dictionary def _load_dataset_split(self, split, epoch, combine): paths = utils.split_paths(self.cfg.data) @@ -197,6 +219,7 @@ class MaskedLMTask(FairseqTask): mask_whole_words=mask_whole_words, mask_multiple_length=self.cfg.mask_multiple_length, mask_stdev=self.cfg.mask_stdev, + skip_masking=self.cfg.skip_masking, ) with data_utils.numpy_seed(self.cfg.seed): @@ -207,6 +230,16 @@ class MaskedLMTask(FairseqTask): pad_idx=self.source_dictionary.pad(), ) + if self.cfg.d2v2_multi: + dataset = self._d2v2_multi_dataset(src_dataset) + else: + dataset = self._regular_dataset(src_dataset, target_dataset) + + self.datasets[split] = SortDataset( + dataset, sort_order=[shuffle, src_dataset.sizes] + ) + + def _regular_dataset(self, src_dataset, target_dataset): input_dict = { "src_tokens": RightPadDataset( src_dataset, @@ -216,23 +249,41 @@ class MaskedLMTask(FairseqTask): } if self.cfg.include_target_tokens: input_dict["target_tokens"] = target_dataset + if self.cfg.include_index: + input_dict["src_id"] = IdDataset() - self.datasets[split] = SortDataset( - NestedDictionaryDataset( - { - "id": IdDataset(), - "net_input": input_dict, - "target": target_dataset, - "nsentences": NumSamplesDataset(), - "ntokens": NumelDataset(src_dataset, reduce=True), - }, - sizes=[src_dataset.sizes], - ), - sort_order=[ - shuffle, - src_dataset.sizes, - ], + dataset = NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": input_dict, + "target": target_dataset, + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_dataset, reduce=True), + }, + sizes=[src_dataset.sizes], ) + return dataset + + def _d2v2_multi_dataset(self, src_dataset): + input_dict = { + "source": RightPadDataset( + src_dataset, + pad_idx=self.source_dictionary.pad(), + ), + "id": IdDataset(), + "padding_mask": RightPaddingMaskDataset(src_dataset), + } + + dataset = NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": input_dict, + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_dataset, reduce=True), + }, + sizes=[src_dataset.sizes], + ) + return dataset def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): src_dataset = RightPadDataset( @@ -268,3 +319,9 @@ class MaskedLMTask(FairseqTask): @property def target_dictionary(self): return 
self.dictionary + + def begin_epoch(self, epoch, model): + model.set_epoch(epoch) + + def max_positions(self): + return self.cfg.tokens_per_sample diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py index 52532ff61..de80addaf 100644 --- a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -23,6 +23,7 @@ from fairseq.data import ( PrependTokenDataset, RawLabelDataset, RightPadDataset, + RightPaddingMaskDataset, RollDataset, SortDataset, StripTokenDataset, @@ -83,6 +84,11 @@ class SentencePredictionConfig(FairseqDataclass): classification_head_name: str = II("criterion.classification_head_name") seed: int = II("common.seed") + d2v2_multi: bool = field( + default=False, + metadata={"help": "prepare dataset for data2vec_multi"}, + ) + @register_task("sentence_prediction", dataclass=SentencePredictionConfig) class SentencePredictionTask(FairseqTask): @@ -181,28 +187,39 @@ class SentencePredictionTask(FairseqTask): self.cfg.seed, ) - dataset = { - "id": IdDataset(), - "net_input": { + if self.cfg.d2v2_multi: + net_input = { + "source": RightPadDataset( + src_tokens, + pad_idx=self.source_dictionary.pad(), + ), + "id": IdDataset(), + "padding_mask": RightPaddingMaskDataset(src_tokens), + } + else: + net_input = { "src_tokens": RightPadDataset( src_tokens, pad_idx=self.source_dictionary.pad(), ), "src_lengths": NumelDataset(src_tokens, reduce=False), - }, + } + if self.cfg.add_prev_output_tokens: + prev_tokens_dataset = RightPadDataset( + RollDataset(src_tokens, 1), + pad_idx=self.dictionary.pad(), + ) + net_input.update( + prev_output_tokens=prev_tokens_dataset, + ) + + dataset = { + "id": IdDataset(), + "net_input": net_input, "nsentences": NumSamplesDataset(), "ntokens": NumelDataset(src_tokens, reduce=True), } - if self.cfg.add_prev_output_tokens: - prev_tokens_dataset = RightPadDataset( - RollDataset(src_tokens, 1), - pad_idx=self.dictionary.pad(), - ) - dataset["net_input"].update( - prev_output_tokens=prev_tokens_dataset, - ) - if not self.cfg.regression_target: label_dataset = make_dataset("label", self.label_dictionary) if label_dataset is not None: diff --git a/fairseq/trainer.py b/fairseq/trainer.py index da1f94910..16b1b9169 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -30,7 +30,6 @@ from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler from fairseq.utils import safe_hasattr - logger = logging.getLogger(__name__) @@ -288,12 +287,27 @@ class Trainer(object): return self._lr_scheduler def _build_optimizer(self): - params = list( - filter( - lambda p: p.requires_grad, - chain(self.model.parameters(), self.criterion.parameters()), + + if ( + self.cfg.optimization.debug_param_names + and self.cfg.common.fp16_no_flatten_grads + ): + params = [] + self.param_names = [] + + for n, p in chain( + self.model.named_parameters(), self.criterion.named_parameters() + ): + if p.requires_grad: + params.append(p) + self.param_names.append(n) + else: + params = list( + filter( + lambda p: p.requires_grad, + chain(self.model.parameters(), self.criterion.parameters()), + ) ) - ) if self.is_fsdp and self.cfg.common.fp16: # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper, @@ -432,18 +446,21 @@ class Trainer(object): def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" - - logger.info(f"Saving checkpoint to {os.path.abspath(filename)}") - # call state_dict on all ranks in case it needs internal communication - state_dict = 
utils.move_to_cpu(self.state_dict()) - state_dict["extra_state"].update(extra_state) if self.should_save_checkpoint_on_current_rank: + + logger.info(f"Saving checkpoint to {os.path.abspath(filename)}") + # call state_dict on all ranks in case it needs internal communication + state_dict = utils.move_to_cpu(self.state_dict()) + state_dict["extra_state"].update(extra_state) + checkpoint_utils.torch_persistent_save( state_dict, filename, async_write=self.cfg.checkpoint.write_checkpoints_asynchronously, ) - logger.info(f"Finished saving checkpoint to {os.path.abspath(filename)}") + logger.info(f"Finished saving checkpoint to {os.path.abspath(filename)}") + return os.path.abspath(filename) + return None def load_checkpoint( self, @@ -793,6 +810,8 @@ class Trainer(object): if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False): extra_kwargs["ema_model"] = self.ema.get_model() + has_oom = False + # forward and backward pass logging_outputs, sample_size, ooms = [], 0, 0 for i, sample in enumerate(samples): # delayed update loop @@ -842,17 +861,9 @@ class Trainer(object): except RuntimeError as e: if "out of memory" in str(e): self._log_oom(e) + has_oom = True if raise_oom: raise e - logger.warning( - "attempting to recover from OOM in forward/backward pass" - ) - ooms += 1 - self.zero_grad() - if self.cuda: - torch.cuda.empty_cache() - if self.cfg.distributed_training.distributed_world_size == 1: - return None else: raise e except Exception: @@ -862,6 +873,18 @@ class Trainer(object): ) raise + if has_oom: + logger.warning( + "attempting to recover from OOM in forward/backward pass" + ) + ooms += 1 + self.zero_grad() + if self.cuda: + torch.cuda.empty_cache() + + if self.cfg.distributed_training.distributed_world_size == 1: + return None + if self.tpu and i < len(samples) - 1: # tpu-comment: every XLA operation before marking step is # appended to the IR graph, and processing too many batches @@ -989,6 +1012,14 @@ class Trainer(object): logger.info( f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}" ) + + if hasattr(self, "param_names") and hasattr( + self.optimizer, "fp32_optimizer" + ): + for p, n in zip(self.optimizer.fp32_optimizer.params, self.param_names): + if torch.isinf(p.grad).any() or torch.isnan(p.grad).any(): + logger.info(f"overflow in param {n}") + grad_norm = torch.tensor(0.0).cuda() self.zero_grad() except RuntimeError as e: diff --git a/fairseq/utils.py b/fairseq/utils.py index e6196863f..e5c77e4b9 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -270,9 +270,9 @@ def strip_pad(tensor, pad): return tensor[tensor.ne(pad)] -def buffered_arange(max): +def buffered_arange(max, device="cpu"): if not hasattr(buffered_arange, "buf"): - buffered_arange.buf = torch.LongTensor() + buffered_arange.buf = torch.LongTensor().to(device) if max > buffered_arange.buf.numel(): buffered_arange.buf.resize_(max) torch.arange(max, out=buffered_arange.buf) diff --git a/fairseq_cli/hydra_validate.py b/fairseq_cli/hydra_validate.py new file mode 100644 index 000000000..cb6f7612d --- /dev/null +++ b/fairseq_cli/hydra_validate.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import os +import sys +from itertools import chain + +import torch +from hydra.core.hydra_config import HydraConfig +from omegaconf import OmegaConf, open_dict +import hydra + +from fairseq import checkpoint_utils, distributed_utils, utils +from fairseq.dataclass.configs import FairseqConfig +from fairseq.dataclass.initialize import add_defaults, hydra_init +from fairseq.dataclass.utils import omegaconf_no_object_check +from fairseq.distributed import utils as distributed_utils +from fairseq.logging import metrics, progress_bar +from fairseq.utils import reset_logging + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.validate") + + +@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config") +def hydra_main(cfg: FairseqConfig) -> float: + return _hydra_main(cfg) + + +def _hydra_main(cfg: FairseqConfig, **kwargs) -> float: + add_defaults(cfg) + + if cfg.common.reset_logging: + reset_logging() # Hydra hijacks logging, fix that + else: + # check if directly called or called through hydra_main + if HydraConfig.initialized(): + with open_dict(cfg): + # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) + cfg.job_logging_cfg = OmegaConf.to_container( + HydraConfig.get().job_logging, resolve=True + ) + + with omegaconf_no_object_check(): + cfg = OmegaConf.create( + OmegaConf.to_container(cfg, resolve=True, enum_to_str=True) + ) + OmegaConf.set_struct(cfg, True) + + assert ( + cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None + ), "Must specify batch size either with --max-tokens or --batch-size" + + distributed_utils.call_main(cfg, validate, **kwargs) + + +def validate(cfg): + utils.import_user_module(cfg.common) + + use_fp16 = cfg.common.fp16 + use_cuda = torch.cuda.is_available() and not cfg.common.cpu + + if use_cuda: + torch.cuda.set_device(cfg.distributed_training.device_id) + + if cfg.distributed_training.distributed_world_size > 1: + data_parallel_world_size = distributed_utils.get_data_parallel_world_size() + data_parallel_rank = distributed_utils.get_data_parallel_rank() + else: + data_parallel_world_size = 1 + data_parallel_rank = 0 + + overrides = {"task": {"data": cfg.task.data}} + + # Load ensemble + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [cfg.common_eval.path], + arg_overrides=overrides, + suffix=cfg.checkpoint.checkpoint_suffix, + ) + model = models[0] + + # Move models to GPU + for model in models: + model.eval() + if use_fp16: + model.half() + if use_cuda: + model.cuda() + + # Print args + logger.info(saved_cfg) + + # Build criterion + criterion = task.build_criterion(saved_cfg.criterion, from_checkpoint=True) + criterion.eval() + + for subset in cfg.dataset.valid_subset.split(","): + try: + task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task) + dataset = task.dataset(subset) + except KeyError: + raise Exception("Cannot find dataset: " + subset) + + # Initialize data iterator + itr = task.get_batch_iterator( + dataset=dataset, + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + task.max_positions(), + *[m.max_positions() for m in models], + ), + 
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=data_parallel_world_size, + shard_id=data_parallel_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, + ).next_epoch_itr(shuffle=False) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + prefix=f"valid on '{subset}' subset", + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + ) + + def apply_half(t): + if t.dtype is torch.float32: + return t.to(dtype=torch.half) + return t + + log_outputs = [] + for i, sample in enumerate(progress): + sample = utils.move_to_cuda(sample) if use_cuda else sample + + if use_fp16: + sample = utils.apply_to_sample(apply_half, sample) + + _loss, _sample_size, log_output = task.valid_step(sample, model, criterion) + with metrics.aggregate() as agg: + task.reduce_metrics([log_output], criterion) + progress.log(agg.get_smoothed_values(), step=i) + # progress.log(log_output, step=i) from vision + log_outputs.append(log_output) + + if data_parallel_world_size > 1: + log_outputs = distributed_utils.all_gather_list( + log_outputs, + max_size=cfg.common.all_gather_list_size, + group=distributed_utils.get_data_parallel_group(), + ) + log_outputs = list(chain.from_iterable(log_outputs)) + + with metrics.aggregate() as agg: + task.reduce_metrics(log_outputs, criterion) + log_output = agg.get_smoothed_values() + + progress.print(log_output, tag=subset, step=i) + + +def cli_main(): + try: + from hydra._internal.utils import get_args + + cfg_name = get_args().config_name or "config" + except: + logger.warning("Failed to get config name from hydra args") + cfg_name = "config" + + hydra_init(cfg_name) + hydra_main() + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 376bd1d03..f771bff65 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -125,12 +125,13 @@ def main(cfg: FairseqConfig) -> None: # Load valid dataset (we load training data below, based on the latest checkpoint) # We load the valid dataset AFTER building the model - data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg) - if cfg.dataset.combine_valid_subsets: - task.load_dataset("valid", combine=True, epoch=1) - else: - for valid_sub_split in cfg.dataset.valid_subset.split(","): - task.load_dataset(valid_sub_split, combine=False, epoch=1) + if not cfg.dataset.disable_validation: + data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg) + if cfg.dataset.combine_valid_subsets: + task.load_dataset("valid", combine=True, epoch=1) + else: + for valid_sub_split in cfg.dataset.valid_subset.split(","): + task.load_dataset(valid_sub_split, combine=False, epoch=1) # (optionally) Configure quantization if cfg.common.quantization_config_path is not None: @@ -175,6 +176,20 @@ def main(cfg: FairseqConfig) -> None: max_epoch = cfg.optimization.max_epoch or math.inf lr = trainer.get_lr() + # TODO: a dry run on validation set to pin the memory + valid_subsets = cfg.dataset.valid_subset.split(",") + if not cfg.dataset.disable_validation: + for subset in valid_subsets: + logger.info('begin dry-run validation on "{}" subset'.format(subset)) + itr = trainer.get_valid_iterator(subset).next_epoch_itr( + shuffle=False, set_dataset_epoch=False # use a fixed valid set + ) + if cfg.common.tpu: + itr = utils.tpu_data_loader(itr) + for _ in 
itr:
+                pass
+    # TODO: end of dry run section
+
     train_meter = meters.StopwatchMeter()
     train_meter.start()
     while epoch_itr.next_epoch_idx <= max_epoch:
@@ -424,9 +439,11 @@ def validate_and_save(
 
     # Save checkpoint
     if do_save or should_stop:
-        checkpoint_utils.save_checkpoint(
+        cp_path = checkpoint_utils.save_checkpoint(
             cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
         )
+        if cp_path is not None and hasattr(task, "post_save"):
+            task.post_save(cp_path, num_updates)
 
     return valid_losses, should_stop
diff --git a/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/__init__.py b/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/__init__.py
new file mode 100644
index 000000000..4884f5bdc
--- /dev/null
+++ b/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+__version__ = "0.1"
diff --git a/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/config.py b/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/config.py
new file mode 100644
index 000000000..91926c4ab
--- /dev/null
+++ b/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/config.py
@@ -0,0 +1,23 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+from dataclasses import dataclass, field
+
+from hydra.core.config_store import ConfigStore
+
+from hydra_plugins.hydra_submitit_launcher.config import SlurmQueueConf
+
+
+@dataclass
+class DependencySubmititConf(SlurmQueueConf):
+    """Slurm configuration overrides and specific parameters"""
+
+    _target_: str = (
+        "hydra_plugins.dependency_submitit_launcher.launcher.DependencySubmititLauncher"
+    )
+
+
+ConfigStore.instance().store(
+    group="hydra/launcher",
+    name="dependency_submitit_slurm",
+    node=DependencySubmititConf(),
+    provider="dependency_submitit_slurm",
+)
diff --git a/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/launcher.py b/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/launcher.py
new file mode 100644
index 000000000..b3fcf79e1
--- /dev/null
+++ b/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/launcher.py
@@ -0,0 +1,121 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import logging
+import os
+import subprocess
+from pathlib import Path
+from typing import Any, List, Sequence
+
+from hydra.core.singleton import Singleton
+from hydra.core.utils import JobReturn, filter_overrides
+from omegaconf import OmegaConf
+
+log = logging.getLogger(__name__)
+
+from .config import DependencySubmititConf
+from hydra_plugins.hydra_submitit_launcher.submitit_launcher import BaseSubmititLauncher
+
+
+class DependencySubmititLauncher(BaseSubmititLauncher):
+    _EXECUTOR = "slurm"
+
+    def launch(
+        self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int
+    ) -> Sequence[JobReturn]:
+
+        # lazy import to ensure plugin discovery remains fast
+        import submitit
+
+        assert self.config is not None
+
+        num_jobs = len(job_overrides)
+        assert num_jobs > 0
+
+        next_script = None
+
+        for jo in job_overrides:
+            if next_script is None:
+                for item in jo:
+                    if "next_script=" in item:
+                        next_script = item
+                        break
+                assert (
+                    next_script is not None
+                ), "job overrides must contain +next_script=path/to/next/script"
+            jo.remove(next_script)
+
+        idx = next_script.find("=")
+        next_script = next_script[idx + 1 :]
+
+        params = self.params
+        # build executor
+        init_params = {"folder": self.params["submitit_folder"]}
+        specific_init_keys = {"max_num_timeout"}
+
+        init_params.update(
+            **{
+                f"{self._EXECUTOR}_{x}": y
+                for x, y in params.items()
+                if x in specific_init_keys
+            }
+        )
+        init_keys = specific_init_keys | {"submitit_folder"}
+        executor = submitit.AutoExecutor(cluster=self._EXECUTOR, **init_params)
+
+        # specify resources/parameters
+        baseparams = set(OmegaConf.structured(DependencySubmititConf).keys())
+        params = {
+            x if x in baseparams else f"{self._EXECUTOR}_{x}": y
+            for x, y in params.items()
+            if x not in init_keys
+        }
+        executor.update_parameters(**params)
+
+        log.info(
+            f"Submitit '{self._EXECUTOR}' sweep output dir : "
+            f"{self.config.hydra.sweep.dir}"
+        )
+        sweep_dir = Path(str(self.config.hydra.sweep.dir))
+        sweep_dir.mkdir(parents=True, exist_ok=True)
+        if "mode" in self.config.hydra.sweep:
+            mode = int(str(self.config.hydra.sweep.mode), 8)
+            os.chmod(sweep_dir, mode=mode)
+
+        job_params: List[Any] = []
+        for idx, overrides in enumerate(job_overrides):
+            idx = initial_job_idx + idx
+            lst = " ".join(filter_overrides(overrides))
+            log.info(f"\t#{idx} : {lst}")
+            job_params.append(
+                (
+                    list(overrides),
+                    "hydra.sweep.dir",
+                    idx,
+                    f"job_id_for_{idx}",
+                    Singleton.get_state(),
+                )
+            )
+
+        jobs = executor.map_array(self, *zip(*job_params))
+
+        for j, jp in zip(jobs, job_params):
+            job_id = str(j.job_id)
+            task_id = "0" if "_" not in job_id else job_id.split("_")[1]
+            sweep_config = self.config_loader.load_sweep_config(self.config, jp[0])
+            dir = sweep_config.hydra.sweep.dir
+
+            dir = (
+                dir.replace("[", "")
+                .replace("]", "")
+                .replace("{", "")
+                .replace("}", "")
+                .replace(",", "_")
+                .replace("'", "")
+                .replace('"', "")
+            )
+
+            subprocess.call(
+                [next_script, job_id, task_id, dir],
+                shell=False,
+            )
+
+        return [j.results()[0] for j in jobs]
diff --git a/hydra_plugins/dependency_submitit_launcher/setup.py b/hydra_plugins/dependency_submitit_launcher/setup.py
new file mode 100644
index 000000000..bf795462b
--- /dev/null
+++ b/hydra_plugins/dependency_submitit_launcher/setup.py
@@ -0,0 +1,29 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# type: ignore
+from pathlib import Path
+
+from read_version import read_version
+from setuptools import find_namespace_packages, setup
+
+setup(
+    name="dependency-submitit-launcher",
+    version=read_version("hydra_plugins/dependency_submitit_launcher", "__init__.py"),
+    author="Alexei Baevski",
+    author_email="abaevski@fb.com",
+    description="Dependency-supporting Submitit Launcher for Hydra apps",
+    packages=find_namespace_packages(include=["hydra_plugins.*"]),
+    classifiers=[
+        "License :: OSI Approved :: MIT License",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Operating System :: MacOS",
+        "Operating System :: POSIX :: Linux",
+        "Development Status :: 4 - Beta",
+    ],
+    install_requires=[
+        "hydra-core>=1.0.4",
+        "submitit>=1.0.0",
+    ],
+    include_package_data=True,
+)
diff --git a/setup.py b/setup.py
index 8a9b2f977..dae06080c 100644
--- a/setup.py
+++ b/setup.py
@@ -184,10 +184,12 @@ def do_setup(package_data):
             "numpy>=1.21.3",
             "regex",
             "sacrebleu>=1.4.12",
-            "torch>=1.10",
+            "torch>=1.13",
             "tqdm",
             "bitarray",
             "torchaudio>=0.8.0",
+            "scikit-learn",
+            "packaging",
         ],
         extras_require={
             "dev": ["flake8", "pytest", "black==22.3.0"],
diff --git a/tests/test_binaries.py b/tests/test_binaries.py
index 1ab92f5f7..41d9210e7 100644
--- a/tests/test_binaries.py
+++ b/tests/test_binaries.py
@@ -11,6 +11,7 @@ import random
 import sys
 import tempfile
 import unittest
+from packaging import version
 from io import StringIO
 from typing import Dict, List
 
@@ -625,6 +626,10 @@ class TestTranslation(unittest.TestCase):
         )
         generate_main(data_dir, extra_flags=[])
 
+    @unittest.skipIf(
+        version.parse(torch.__version__) > version.parse("1.8"),
+        "skip for latest torch versions",
+    )
     def test_transformer_pointer_generator(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory(
@@ -1827,6 +1832,8 @@ def train_masked_lm(data_dir, arch, extra_flags=None):
         "masked_lm",
         "--batch-size",
         "500",
+        "--required-batch-size-multiple",
+        "1",
         "--save-dir",
         data_dir,
         "--max-epoch",
diff --git a/tests/test_ema.py b/tests/test_ema.py
index 847316ff4..bd2cf2c78 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -6,6 +6,7 @@ import unittest
 from copy import deepcopy
 from dataclasses import dataclass
+import pytest
 from typing import Optional
 from unittest.mock import patch
 
@@ -160,8 +161,7 @@ class TestEMA(unittest.TestCase):
         self._test_ema_start_update(updates=1)
 
     def test_ema_fp32(self):
-        # CPU no longer supports Linear in half precision
-        dtype = torch.half if torch.cuda.is_available() else torch.float
+        dtype = torch.float
 
         model = DummyModule().to(dtype)
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
@@ -213,12 +213,12 @@ class TestEMA(unittest.TestCase):
             ).to(dtype),
         )
 
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(),
+        reason="CPU no longer supports Linear in half precision",
+    )
     def test_ema_fp16(self):
-        # CPU no longer supports Linear in half precision
-        if not torch.cuda.is_available():
-            return
-
-        model = DummyModule().half()
+        model = DummyModule().cuda().half()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
         state = deepcopy(model.state_dict())
         config = EMAConfig(ema_fp32=False)
@@ -227,7 +227,7 @@ class TestEMA(unittest.TestCase):
         # Since fp32 params is not used, it should be of size 0
         self.assertEqual(len(ema.fp32_params), 0)
 
-        x = torch.randn(32)
+        x = torch.randn(32).cuda()
         y = model(x.half())
         loss = y.sum()
         loss.backward()
diff --git a/tests/test_multihead_attention.py b/tests/test_multihead_attention.py
index 5301318c2..4a0b430b6 100644
--- a/tests/test_multihead_attention.py
+++ b/tests/test_multihead_attention.py
@@ -101,6 +101,7 @@ def test_mask_for_xformers():
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="blocksparse requires gpu")
+@pytest.mark.skip(reason="not part of latest xformers")
 @pytest.mark.parametrize("device", ["cuda"])
 @pytest.mark.parametrize("add_zero_attn", [False])
 @pytest.mark.parametrize("batch_size", [20])
diff --git a/tests/test_valid_subset_checks.py b/tests/test_valid_subset_checks.py
index 3e9191bda..c39fb8982 100644
--- a/tests/test_valid_subset_checks.py
+++ b/tests/test_valid_subset_checks.py
@@ -126,13 +126,18 @@ class TestCombineValidSubsets(unittest.TestCase):
         return [x.message for x in logs.records]
 
     def test_combined(self):
-        flags = ["--combine-valid-subsets"]
+        flags = ["--combine-valid-subsets", "--required-batch-size-multiple", "1"]
         logs = self._train(flags)
         assert any(["valid1" in x for x in logs])  # loaded 100 examples from valid1
         assert not any(["valid1_ppl" in x for x in logs])  # metrics are combined
 
     def test_subsets(self):
-        flags = ["--valid-subset", "valid,valid1"]
+        flags = [
+            "--valid-subset",
+            "valid,valid1",
+            "--required-batch-size-multiple",
+            "1",
+        ]
         logs = self._train(flags)
         assert any(["valid_ppl" in x for x in logs])  # loaded 100 examples from valid1
         assert any(["valid1_ppl" in x for x in logs])  # metrics are combined
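
For context on the fairseq_cli/train.py change earlier in this patch: validate_and_save now calls task.post_save(cp_path, num_updates) only when save_checkpoint returned a path and the task defines that method. The following is a minimal sketch of what such a hook could look like on a custom task; the class name, the destination path, and the use of shutil are illustrative assumptions and are not part of this patch.

# Illustrative sketch only (not part of this patch): the optional hook contract
# used by validate_and_save() -- post_save() runs after a checkpoint is written,
# e.g. to hand the file to a dependent fine-tuning job scheduled through the
# dependency submitit launcher added above.
import shutil

from fairseq.tasks import FairseqTask


class MyPretrainingTask(FairseqTask):  # hypothetical task name
    def post_save(self, cp_path: str, num_updates: int) -> None:
        # copy the freshly written checkpoint to a location the next job expects;
        # the destination below is an assumed example path
        shutil.copyfile(cp_path, "/checkpoints/latest_for_finetune.pt")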