diff --git a/examples/data2vec/README.md b/examples/data2vec/README.md
index 9fd05d804..5b680ef8c 100644
--- a/examples/data2vec/README.md
+++ b/examples/data2vec/README.md
@@ -1,3 +1,125 @@
+# data2vec 2.0
+
+data2vec 2.0 improves the training efficiency of the original data2vec algorithm. We make the following changes for efficiency: we forward only the unmasked timesteps through the encoder, we use a convolutional decoder, and we use multi-masking to amortize the compute overhead of the teacher model. You can find details in [Efficient Self-supervised Learning with Contextualized Target Representations for Vision, Speech and Language](https://ai.facebook.com/research/xyz).
+
+## Pretrained and finetuned models
+### Vision
+| Model | Finetuning split | Link
+|---|---|---
+data2vec ViT-B | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_imagenet.pt)
+data2vec ViT-B | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_imagenet_ft.pt)
+data2vec ViT-L | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_imagenet.pt)
+data2vec ViT-L | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_imagenet_ft.pt)
+data2vec ViT-H | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/huge_imagenet.pt)
+data2vec ViT-H | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/huge_imagenet_ft.pt)
+
+Only the vision models are licensed under CC-BY-NC.
+
+### Speech
+
+| Model | Finetuning split | Dataset | Link
+|---|---|---|---
+data2vec Base | No fine-tuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_libri.pt)
+data2vec Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_libri_960h.pt)
+data2vec Large | No fine-tuning | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_vox.pt)
+data2vec Large | 960 hours | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_vox_960h.pt)
+
+### NLP
+
+| Model | Finetuning split | Dataset | Link
+|---|---|---|---
+data2vec Base | No fine-tuning | Books + Wiki | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/nlp_base.pt)
+
+[//]: # (## Data Preparation)
+
+[//]: # ()
+[//]: # (### Vision)
+
+[//]: # (add details)
+
+[//]: # (### Speech)
+
+[//]: # (add details)
+
+[//]: # ()
+[//]: # (### NLP)
+
+[//]: # (add details)
+
+
+## Commands to train different models using data2vec 2.0
+
+### Vision
+
+Commands to pretrain different model configurations:
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
+--config-name base_images_only_task task.data=/path/to/dir
+```
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
+--config-name large_images_only_task task.data=/path/to/dir
+```
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
+--config-name huge_images14_only_task task.data=/path/to/dir
+```
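+
+All of these commands launch through Hydra, and scheduler/hardware settings live in the `run_config` files shipped alongside the configs (see `examples/data2vec/config/v2/run_config`). As a sketch, assuming these follow the usual fairseq Hydra convention of being selected with a `+run_config=...` override, a single-GPU local debugging run would look like:
+
+```shell script
+# +run_config=local applies run_config/local.yaml (assumed Hydra group override)
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
+--config-name base_images_only_task task.data=/path/to/dir +run_config=local
+```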
+
+Commands to fine-tune different model configurations:
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \
+--config-name mae_imagenet_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model
+```
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \
+--config-name mae_imagenet_large_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model
+```
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \
+--config-name mae_imagenet_huge_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model
+```
+
+### Speech
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
+--config-name base_audio_only_task task.data=/path/to/manifests
+```
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
+--config-name large_audio_only_task task.data=/path/to/manifests
+```
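+
+Here `task.data` points to a directory of wav2vec-style manifest (`.tsv`) files. If you do not have them yet, they can be created with the wav2vec manifest script; a minimal sketch, holding out 1% of the data for validation:
+
+```shell script
+# creates train.tsv and valid.tsv under /path/to/manifests
+$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /path/to/manifests \
+--ext flac --valid-percent 0.01
+```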
+
+Fine-tuning:
+
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/wav2vec/config/finetuning --config-name vox_10h \
+task.data=/path/to/manifests model.w2v_path=/path/to/pretrained/model common.user_dir=examples/data2vec
+```
+
+Replace `vox_10h` with the appropriate config for your model and fine-tuning split;
+see `examples/wav2vec/config/finetuning` for all available configs.
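+
+For example, fine-tuning the base Librispeech model on 100 hours of labeled data would use the `base_100h` config (assumed to be present in your wav2vec examples checkout):
+
+```shell script
+# base_100h: assumed standard wav2vec fine-tuning config for the 100h split
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/wav2vec/config/finetuning --config-name base_100h \
+task.data=/path/to/manifests model.w2v_path=/path/to/pretrained/model common.user_dir=examples/data2vec
+```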
+
+### NLP
+
+Commands to pretrain:
+```shell script
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
+--config-name base_text_only_task task.data=/path/to/file
+```
+
+Commands to fine-tune all GLUE tasks:
+```shell script
+$ task=cola # choose from [cola|qnli|mrpc|rte|sst_2|mnli|qqp|sts_b]
+$ lr=1e-5 # sweep [1e-5|2e-5|4e-5|6e-5] for each task
+$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2/text_finetuning \
+--config-name $task task.data=/path/to/file model.model_path=/path/to/pretrained/model "optimization.lr=[${lr}]"
+```
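+
+One way to run the suggested sweep is a simple loop over the learning rates for a given task (a sketch; adjust paths to your setup):
+
+```shell script
+# sweeps the four suggested learning rates for one GLUE task
+$ task=cola
+$ for lr in 1e-5 2e-5 4e-5 6e-5; do
+    python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2/text_finetuning \
+    --config-name $task task.data=/path/to/file model.model_path=/path/to/pretrained/model "optimization.lr=[${lr}]"
+  done
+```
+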
# data2vec
diff --git a/examples/data2vec/__init__.py b/examples/data2vec/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/data2vec/config/audio/classification/base_classification.yaml b/examples/data2vec/config/audio/classification/base_classification.yaml
new file mode 100644
index 000000000..fdb9c8d3d
--- /dev/null
+++ b/examples/data2vec/config/audio/classification/base_classification.yaml
@@ -0,0 +1,70 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ all_gather_list_size: 70000
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+
+checkpoint:
+ save_interval: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: mAP
+ maximize_best_checkpoint_metric: true
+
+task:
+ _name: audio_classification
+ data: ???
+ normalize: true
+ labels: lbl
+
+dataset:
+ num_workers: 6
+ max_tokens: 2560000
+ skip_invalid_size_inputs_valid_test: true
+ valid_subset: eval
+ validate_interval: 5
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 8
+
+criterion:
+ _name: model
+ can_sum: false
+ log_keys:
+ - _predictions
+ - _targets
+
+optimization:
+ max_update: 30000
+ lr: [0.00006] # scratch 53-5
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 5000
+
+model:
+ _name: audio_classification
+ model_path: ???
+ apply_mask: true
+ mask_prob: 0.6
+ mask_length: 5 # scratch 1
+ mask_channel_prob: 0
+ mask_channel_length: 64
+ layerdrop: 0.1
+ dropout: 0.1
+ activation_dropout: 0.1
+ attention_dropout: 0.2
+ feature_grad_mult: 0 # scratch 1
+ label_mixup: true
+ source_mixup: 0.5
+ prediction_mode: lin_softmax # scratch average_sigmoid
+
diff --git a/examples/data2vec/config/audio/classification/run_config/slurm_1.yaml b/examples/data2vec/config/audio/classification/run_config/slurm_1.yaml
new file mode 100644
index 000000000..881a1583f
--- /dev/null
+++ b/examples/data2vec/config/audio/classification/run_config/slurm_1.yaml
@@ -0,0 +1,35 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/classification/run_config/slurm_1g.yaml b/examples/data2vec/config/audio/classification/run_config/slurm_1g.yaml
new file mode 100644
index 000000000..de7894d9c
--- /dev/null
+++ b/examples/data2vec/config/audio/classification/run_config/slurm_1g.yaml
@@ -0,0 +1,35 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 1
+ tasks_per_node: 1
+ mem_gb: 100
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/classification/run_config/slurm_2.yaml b/examples/data2vec/config/audio/classification/run_config/slurm_2.yaml
new file mode 100644
index 000000000..b016cac9b
--- /dev/null
+++ b/examples/data2vec/config/audio/classification/run_config/slurm_2.yaml
@@ -0,0 +1,35 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/audioset.yaml b/examples/data2vec/config/audio/pretraining/audioset.yaml
new file mode 100644
index 000000000..dd30fbedd
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/audioset.yaml
@@ -0,0 +1,91 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
+
+checkpoint:
+ save_interval: 1
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: audio_pretraining
+ data: /private/home/abaevski/data/audioset
+ max_sample_size: 320000
+ min_sample_size: 32000
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 3400000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: 5
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 24
+ ddp_backend: legacy_ddp
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+# - avg_self_attn
+# - weights
+
+optimization:
+ max_update: 200000
+ lr: [0.0005]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 10000
+
+model:
+ _name: data2vec_audio
+ extractor_mode: layer_norm
+ encoder_layerdrop: 0.05
+ dropout_input: 0.0
+ dropout_features: 0.0
+ feature_grad_mult: 1.0
+ encoder_embed_dim: 768
+
+ mask_prob: 0.65
+ mask_length: 10
+
+ loss_beta: 0
+ loss_scale: null
+
+ instance_norm_target_layer: true
+ layer_norm_targets: true
+ average_top_k_layers: 12
+
+ self_attn_norm_type: deepnorm
+ final_norm_type: deepnorm
+
+ pos_conv_depth: 5
+ conv_pos: 95
+
+ ema_decay: 0.999
+ ema_end_decay: 0.9999
+ ema_anneal_end_step: 30000
+ ema_transformer_only: true
+ ema_layers_only: false
+
+ require_same_masks: true
+ mask_dropout: 0
diff --git a/examples/data2vec/config/audio/pretraining/run_config/local.yaml b/examples/data2vec/config/audio/pretraining/run_config/local.yaml
new file mode 100644
index 000000000..45595f9ee
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/local.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+ sweep:
+ dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+common:
+ log_interval: 1
+
+dataset:
+ num_workers: 0
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_1.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_1.yaml
new file mode 100644
index 000000000..732f01889
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_1.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_1_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_1_aws.yaml
new file mode 100644
index 000000000..e2bab5675
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_1_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_2.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_2.yaml
new file mode 100644
index 000000000..ec53dc2a9
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_2.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_2_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_2_aws.yaml
new file mode 100644
index 000000000..70cc8cbb5
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_2_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - task.post_save_script
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_3.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_3.yaml
new file mode 100644
index 000000000..14b47d14e
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_3.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 3
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_4.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_4.yaml
new file mode 100644
index 000000000..c54d735fb
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_4.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_4_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_4_aws.yaml
new file mode 100644
index 000000000..0231b2690
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_4_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - task.post_save_script
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_6_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_6_aws.yaml
new file mode 100644
index 000000000..9a4e43a98
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_6_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 6
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/audio/pretraining/run_config/slurm_8_aws.yaml b/examples/data2vec/config/audio/pretraining/run_config/slurm_8_aws.yaml
new file mode 100644
index 000000000..78c9f57ae
--- /dev/null
+++ b/examples/data2vec/config/audio/pretraining/run_config/slurm_8_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 8
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/text/pretraining/run_config/local.yaml b/examples/data2vec/config/text/pretraining/run_config/local.yaml
new file mode 100644
index 000000000..45595f9ee
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/local.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+ sweep:
+ dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+common:
+ log_interval: 1
+
+dataset:
+ num_workers: 0
diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_1_aws.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_1_aws.yaml
new file mode 100644
index 000000000..4bac45a58
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/slurm_1_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '_'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}/submitit
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec
+ max_num_timeout: 30
+ exclude: a100-st-p4d24xlarge-471
diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_2.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_2.yaml
new file mode 100644
index 000000000..006a0f211
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/slurm_2.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_2_aws.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_2_aws.yaml
new file mode 100644
index 000000000..4292198b4
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/slurm_2_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '_'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}/submitit
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec
+ max_num_timeout: 30
+ exclude: a100-st-p4d24xlarge-471
diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_3.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_3.yaml
new file mode 100644
index 000000000..0e1555d20
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/slurm_3.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 3
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_4.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_4.yaml
new file mode 100644
index 000000000..c54d735fb
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/slurm_4.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_4_aws.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_4_aws.yaml
new file mode 100644
index 000000000..5df84cd6d
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/slurm_4_aws.yaml
@@ -0,0 +1,41 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '_'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}/submitit
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec
+ max_num_timeout: 30
+ exclude: a100-st-p4d24xlarge-471
+
+distributed_training:
+ distributed_world_size: 32
+ ddp_backend: legacy_ddp
diff --git a/examples/data2vec/config/text/pretraining/run_config/slurm_8_aws.yaml b/examples/data2vec/config/text/pretraining/run_config/slurm_8_aws.yaml
new file mode 100644
index 000000000..5b32c23a6
--- /dev/null
+++ b/examples/data2vec/config/text/pretraining/run_config/slurm_8_aws.yaml
@@ -0,0 +1,41 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '_'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}/submitit
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 8
+ name: pt
+ partition: wav2vec
+ max_num_timeout: 30
+ exclude: a100-st-p4d24xlarge-471
+
+distributed_training:
+ distributed_world_size: 64
+ ddp_backend: legacy_ddp
diff --git a/examples/data2vec/config/v2/base_audio_only_task.yaml b/examples/data2vec/config/v2/base_audio_only_task.yaml
new file mode 100644
index 000000000..65a9ab3e7
--- /dev/null
+++ b/examples/data2vec/config/v2/base_audio_only_task.yaml
@@ -0,0 +1,113 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ fp16_no_flatten_grads: false
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ save_interval: 1
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: audio_pretraining
+ data: /private/home/abaevski/data/librispeech/full
+ max_sample_size: 320000
+ min_sample_size: 32000
+ normalize: true
+ precompute_mask_config: {}
+
+dataset:
+ num_workers: 6
+ max_tokens: 1000000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: 5
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 8
+ ddp_backend: legacy_ddp
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+optimization:
+ max_update: 400000
+ lr: [0.00075]
+ debug_param_names: true
+
+optimizer:
+ _name: adam
+ adam_betas: [ 0.9,0.98 ]
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 8000
+
+model:
+ _name: data2vec_multi
+
+ loss_beta: 0
+ loss_scale: null
+
+ depth: 12
+ embed_dim: 768
+ clone_batch: 8
+
+ ema_decay: 0.999
+ ema_end_decay: 0.99999
+ ema_anneal_end_step: 75000
+ ema_encoder_only: false
+
+ average_top_k_layers: 8
+ instance_norm_target_layer: true
+ layer_norm_target_layer: false
+ layer_norm_targets: false
+
+ layerdrop: 0.05
+ norm_eps: 1e-5
+
+ supported_modality: AUDIO
+
+ modalities:
+ audio:
+ feature_encoder_spec: '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]'
+ conv_pos_depth: 5
+ conv_pos_width: 95
+ conv_pos_groups: 16
+ prenet_depth: 0
+ mask_prob: 0.5
+ mask_prob_adjust: 0.05
+ inverse_mask: false
+ mask_length: 5
+ mask_noise_std: 0.01
+ mask_dropout: 0
+ add_masks: false
+ ema_local_encoder: false
+ use_alibi_encoder: true
+ prenet_layerdrop: 0.05
+ prenet_dropout: 0.1
+ learned_alibi_scale: true
+ learned_alibi_scale_per_head: true
+ decoder:
+ input_dropout: 0.1
+ decoder_dim: 384
+ decoder_groups: 16
+ decoder_kernel: 7
+ decoder_layers: 4
diff --git a/examples/data2vec/config/v2/base_images_only_task.yaml b/examples/data2vec/config/v2/base_images_only_task.yaml
new file mode 100644
index 000000000..ff0c247b1
--- /dev/null
+++ b/examples/data2vec/config/v2/base_images_only_task.yaml
@@ -0,0 +1,116 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: mae_image_pretraining
+ data: /datasets01/imagenet_full_size/061417/
+ rebuild_batches: true
+ local_cache_path: /scratch/cache_abaevski/imagenet
+ key: source
+ precompute_mask_config: {}
+
+dataset:
+ num_workers: 10
+ batch_size: 16
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+optimization:
+ max_update: 375300
+ lr: [ 0.001 ]
+ debug_param_names: true
+ clip_norm: 4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 1e-3
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 50040
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ ema_decay: 0.9998
+ ema_end_decay: 0.99999
+ ema_anneal_end_step: 100000
+ instance_norm_target_layer: true
+ layer_norm_target_layer: false
+ layer_norm_targets: true
+ end_of_block_targets: false
+
+ depth: 10
+ average_top_k_layers: 10
+ clone_batch: 16
+
+ norm_eps: 1e-6
+
+ min_target_var: 0
+ min_pred_var: 0
+
+ encoder_dropout: 0
+ post_mlp_drop: 0
+ attention_dropout: 0
+ activation_dropout: 0
+
+ supported_modality: IMAGE
+ cls_loss: 0.01
+
+ ema_encoder_only: false
+
+ modalities:
+ image:
+ inverse_mask: true
+ mask_prob: 0.8
+ mask_prob_adjust: 0.07
+ mask_length: 3
+ mask_noise_std: 0.01
+ prenet_depth: 2
+ ema_local_encoder: true
+ num_extra_tokens: 1
+ init_extra_token_zero: false
+ use_alibi_encoder: false
+ decoder:
+ decoder_dim: 768
+ decoder_groups: 16
+ decoder_kernel: 3
+ decoder_layers: 6
+ input_dropout: 0
\ No newline at end of file
diff --git a/examples/data2vec/config/v2/base_text_only_task.yaml b/examples/data2vec/config/v2/base_text_only_task.yaml
new file mode 100644
index 000000000..62f22eb0f
--- /dev/null
+++ b/examples/data2vec/config/v2/base_text_only_task.yaml
@@ -0,0 +1,112 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ no_epoch_checkpoints: true
+ save_interval_updates: 50000
+ keep_interval_updates: 1
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: legacy_ddp
+
+task:
+ _name: masked_lm
+ data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin
+ sample_break_mode: none
+ tokens_per_sample: 512
+ include_target_tokens: true
+ random_token_prob: 0
+ leave_unmasked_prob: 0
+ include_index: True
+ skip_masking: True
+ d2v2_multi: True
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+dataset:
+ batch_size: 4
+ ignore_unused_valid_subsets: true
+ skip_invalid_size_inputs_valid_test: true
+ disable_validation: true
+
+optimization:
+ clip_norm: 1
+ lr: [0.0002]
+ max_update: 1000000
+ update_freq: [1]
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 0.0002
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.98]
+ adam_eps: 1e-06
+ weight_decay: 0.01
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 4000
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ loss_beta: 0
+ loss_scale: 1
+
+ depth: 12
+ embed_dim: 768
+ clone_batch: 8
+
+ ema_decay: 0.9999
+ ema_end_decay: 0.99999
+ ema_anneal_end_step: 100000
+ ema_encoder_only: true
+
+ average_top_k_layers: 12
+ layer_norm_target_layer: false
+ instance_norm_target_layer: true
+ batch_norm_target_layer: false
+ instance_norm_targets: false
+ layer_norm_targets: false
+
+ layerdrop: 0
+ norm_eps: 1e-5
+
+ supported_modality: TEXT
+
+ modalities:
+ text:
+ mask_prob: 0.48
+ mask_length: 1
+ mask_noise_std: 0.01
+ prenet_depth: 0
+ decoder:
+ input_dropout: 0.1
+ decoder_dim: 768
+ decoder_groups: 1
+ decoder_kernel: 9
+ decoder_layers: 5
+ decoder_residual: false
+ projection_layers: 2
+ projection_ratio: 2.0
diff --git a/examples/data2vec/config/v2/huge_images14_only_task.yaml b/examples/data2vec/config/v2/huge_images14_only_task.yaml
new file mode 100644
index 000000000..a8a15253f
--- /dev/null
+++ b/examples/data2vec/config/v2/huge_images14_only_task.yaml
@@ -0,0 +1,122 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: mae_image_pretraining
+ data: /datasets01/imagenet_full_size/061417/
+ rebuild_batches: true
+ local_cache_path: /scratch/cache_abaevski/imagenet
+ key: source
+ precompute_mask_config: {}
+
+dataset:
+ num_workers: 10
+ batch_size: 8
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 32
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+optimization:
+ max_update: 500000
+ lr: [ 0.0004 ]
+ debug_param_names: true
+ clip_norm: 4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 4e-4
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 50040
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ ema_decay: 0.9998
+ ema_end_decay: 1
+ ema_anneal_end_step: 300000
+ instance_norm_target_layer: true
+ layer_norm_target_layer: false
+ layer_norm_targets: true
+ end_of_block_targets: false
+
+ depth: 32
+ embed_dim: 1280
+ num_heads: 16
+
+ average_top_k_layers: 24
+ clone_batch: 16
+
+ norm_eps: 1e-6
+
+ min_target_var: 0
+ min_pred_var: 0
+
+ encoder_dropout: 0
+ post_mlp_drop: 0
+ attention_dropout: 0
+ activation_dropout: 0
+
+ supported_modality: IMAGE
+ cls_loss: 0.01
+
+ ema_encoder_only: false
+
+ modalities:
+ image:
+ patch_size: 14
+ inverse_mask: true
+ mask_prob: 0.75
+ mask_prob_adjust: 0.1
+ mask_length: 3
+ mask_noise_std: 0.01
+ prenet_depth: 0
+ ema_local_encoder: true
+ num_extra_tokens: 1
+ init_extra_token_zero: false
+ use_alibi_encoder: false
+ embed_dim: 1280
+ decoder:
+ decoder_dim: 1024
+ decoder_groups: 16
+ decoder_kernel: 5
+ decoder_layers: 3
+ final_layer_norm: false
+ input_dropout: 0
\ No newline at end of file
diff --git a/examples/data2vec/config/v2/huge_images_only_task.yaml b/examples/data2vec/config/v2/huge_images_only_task.yaml
new file mode 100644
index 000000000..7a352ac3c
--- /dev/null
+++ b/examples/data2vec/config/v2/huge_images_only_task.yaml
@@ -0,0 +1,120 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: mae_image_pretraining
+ data: /datasets01/imagenet_full_size/061417/
+ rebuild_batches: true
+ local_cache_path: /scratch/cache_abaevski/imagenet
+ key: source
+ precompute_mask_config: {}
+
+dataset:
+ num_workers: 10
+ batch_size: 8
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+optimization:
+ max_update: 375300
+ lr: [ 0.0004 ]
+ debug_param_names: true
+ clip_norm: 4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 4e-4
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 50040
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ ema_decay: 0.9998
+ ema_end_decay: 0.99995
+ ema_anneal_end_step: 150000
+ instance_norm_target_layer: true
+ layer_norm_target_layer: false
+ layer_norm_targets: true
+ end_of_block_targets: false
+
+ depth: 32
+ embed_dim: 1280
+ num_heads: 16
+
+ average_top_k_layers: 24
+ clone_batch: 16
+
+ norm_eps: 1e-6
+
+ min_target_var: 0
+ min_pred_var: 0
+
+ encoder_dropout: 0
+ post_mlp_drop: 0
+ attention_dropout: 0
+ activation_dropout: 0
+
+ supported_modality: IMAGE
+ cls_loss: 0.01
+
+ ema_encoder_only: false
+
+ modalities:
+ image:
+ inverse_mask: true
+ mask_prob: 0.75
+ mask_prob_adjust: 0.1
+ mask_length: 3
+ mask_noise_std: 0.01
+ prenet_depth: 0
+ ema_local_encoder: true
+ num_extra_tokens: 1
+ init_extra_token_zero: false
+ use_alibi_encoder: false
+ embed_dim: 1280
+ decoder:
+ decoder_dim: 1024
+ decoder_groups: 16
+ decoder_kernel: 5
+ decoder_layers: 3
+ input_dropout: 0
\ No newline at end of file
diff --git a/examples/data2vec/config/v2/large_audio_only_task.yaml b/examples/data2vec/config/v2/large_audio_only_task.yaml
new file mode 100644
index 000000000..3f6158972
--- /dev/null
+++ b/examples/data2vec/config/v2/large_audio_only_task.yaml
@@ -0,0 +1,122 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ save_interval: 1
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: audio_pretraining
+ data: /fsx-wav2vec/abaevski/data/librivox/no_silence
+ max_sample_size: 320000
+ min_sample_size: 32000
+ normalize: true
+ precompute_mask_config: {}
+
+dataset:
+ num_workers: 8
+ max_tokens: 320000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: 5
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 48
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+optimization:
+ max_update: 600000
+ debug_param_names: true
+ clip_norm: 1
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 0.0004
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.98]
+ adam_eps: 1e-06
+ weight_decay: 0.01
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 10000
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ loss_beta: 0
+ loss_scale: null
+
+ depth: 16
+ embed_dim: 1024
+ num_heads: 16
+
+ clone_batch: 12
+
+ ema_decay: 0.9997
+ ema_end_decay: 1
+ ema_anneal_end_step: 300000
+ ema_encoder_only: false
+
+ average_top_k_layers: 16
+ instance_norm_target_layer: true
+ layer_norm_target_layer: false
+ layer_norm_targets: false
+
+ layerdrop: 0
+ norm_eps: 1e-5
+
+ supported_modality: AUDIO
+
+ modalities:
+ audio:
+ feature_encoder_spec: '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]'
+ conv_pos_depth: 5
+ conv_pos_width: 95
+ conv_pos_groups: 16
+ prenet_depth: 8
+ mask_prob: 0.55
+ mask_prob_adjust: 0.1
+ inverse_mask: false
+ mask_length: 5
+ mask_noise_std: 0.01
+ mask_dropout: 0
+ add_masks: false
+ ema_local_encoder: false
+ use_alibi_encoder: true
+ prenet_layerdrop: 0
+ prenet_dropout: 0.1
+ learned_alibi_scale: true
+ learned_alibi_scale_per_head: true
+ decoder:
+ input_dropout: 0.1
+ decoder_dim: 768
+ decoder_groups: 16
+ decoder_kernel: 7
+ decoder_layers: 4
diff --git a/examples/data2vec/config/v2/large_images_only_task.yaml b/examples/data2vec/config/v2/large_images_only_task.yaml
new file mode 100644
index 000000000..6b957fc12
--- /dev/null
+++ b/examples/data2vec/config/v2/large_images_only_task.yaml
@@ -0,0 +1,120 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: mae_image_pretraining
+ data: /datasets01/imagenet_full_size/061417/
+ rebuild_batches: true
+ local_cache_path: /scratch/cache_abaevski/imagenet
+ key: source
+ precompute_mask_config: {}
+
+dataset:
+ num_workers: 10
+ batch_size: 8
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+optimization:
+ max_update: 375300
+ lr: [ 0.0004 ]
+ debug_param_names: true
+ clip_norm: 4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 4e-4
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 50040
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ ema_decay: 0.9998
+ ema_end_decay: 0.99999
+ ema_anneal_end_step: 150000
+ instance_norm_target_layer: true
+ layer_norm_target_layer: false
+ layer_norm_targets: true
+ end_of_block_targets: false
+
+ depth: 24
+ embed_dim: 1024
+ num_heads: 16
+
+ average_top_k_layers: 18
+ clone_batch: 16
+
+ norm_eps: 1e-6
+
+ min_target_var: 0
+ min_pred_var: 0
+
+ encoder_dropout: 0
+ post_mlp_drop: 0
+ attention_dropout: 0
+ activation_dropout: 0
+
+ supported_modality: IMAGE
+ cls_loss: 0.01
+
+ ema_encoder_only: false
+
+ modalities:
+ image:
+ inverse_mask: true
+ mask_prob: 0.75
+ mask_prob_adjust: 0.1
+ mask_length: 3
+ mask_noise_std: 0.01
+ prenet_depth: 0
+ ema_local_encoder: true
+ num_extra_tokens: 1
+ init_extra_token_zero: false
+ use_alibi_encoder: false
+ embed_dim: 1024
+ decoder:
+ decoder_dim: 1024
+ decoder_groups: 16
+ decoder_kernel: 5
+ decoder_layers: 3
+ input_dropout: 0
\ No newline at end of file
diff --git a/examples/data2vec/config/v2/large_text_only_task.yaml b/examples/data2vec/config/v2/large_text_only_task.yaml
new file mode 100644
index 000000000..fd69048e7
--- /dev/null
+++ b/examples/data2vec/config/v2/large_text_only_task.yaml
@@ -0,0 +1,112 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ min_loss_scale: 1e-6
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ save_interval_updates: 50000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: masked_lm
+ data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin
+ sample_break_mode: none
+ tokens_per_sample: 512
+ include_target_tokens: true
+ random_token_prob: 0
+ leave_unmasked_prob: 0
+ include_index: True
+ skip_masking: True
+ d2v2_multi: True
+
+dataset:
+ batch_size: 2
+ ignore_unused_valid_subsets: true
+ skip_invalid_size_inputs_valid_test: true
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 32
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+optimization:
+ max_update: 600000
+ clip_norm: 1
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 0.0001
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.98]
+ adam_eps: 1e-06
+ weight_decay: 0.01
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 4000
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ loss_beta: 0
+ loss_scale: 1
+
+ depth: 24
+ num_heads: 16
+ embed_dim: 1024
+ clone_batch: 8
+
+ ema_decay: 0.9999
+ ema_end_decay: 0.99999
+ ema_anneal_end_step: 100000
+ ema_encoder_only: true
+
+ average_top_k_layers: 24
+ layer_norm_target_layer: true
+ instance_norm_target_layer: false
+ batch_norm_target_layer: false
+ instance_norm_targets: true
+ layer_norm_targets: false
+
+ layerdrop: 0
+ norm_eps: 1e-5
+
+ supported_modality: TEXT
+
+ modalities:
+ text:
+ mask_prob: 0.5
+ mask_length: 1
+ mask_noise_std: 0.01
+ prenet_depth: 0
+ decoder:
+ input_dropout: 0.1
+ decoder_dim: 768
+ decoder_groups: 1
+ decoder_kernel: 9
+ decoder_layers: 5
+ decoder_residual: false
+ projection_layers: 2
+ projection_ratio: 2.0
diff --git a/examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml b/examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml
new file mode 100644
index 000000000..739e6f672
--- /dev/null
+++ b/examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml
@@ -0,0 +1,123 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ fp16_no_flatten_grads: true
+ user_dir: ${env:PWD}/examples/data2vec
+
+checkpoint:
+ no_epoch_checkpoints: true
+ save_interval_updates: 50000
+ keep_interval_updates: 1
+
+distributed_training:
+ distributed_world_size: 32
+ ddp_backend: legacy_ddp
+
+task:
+ _name: masked_lm
+ data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin
+ sample_break_mode: none
+ tokens_per_sample: 512
+ include_target_tokens: true
+ random_token_prob: 0
+ leave_unmasked_prob: 0
+ include_index: True
+ skip_masking: True
+ d2v2_multi: True
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+ - model_norm
+ - ema_norm
+ - masked_pct
+
+dataset:
+ batch_size: 2
+ ignore_unused_valid_subsets: true
+ skip_invalid_size_inputs_valid_test: true
+ disable_validation: true
+
+optimization:
+ clip_norm: 1
+ lr: [3e-4]
+ max_update: 1000000
+ update_freq: [1]
+
+optimizer:
+ _name: composite
+ groups:
+ default:
+ lr_float: 1e-4
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.98]
+ adam_eps: 1e-06
+ weight_decay: 0.01
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 4000
+ decoder:
+ lr_float: 1e-4
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.98]
+ adam_eps: 1e-06
+ weight_decay: 0.01
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 4000
+
+lr_scheduler: pass_through
+
+model:
+ _name: data2vec_multi
+
+ loss_beta: 4
+ loss_scale: 1
+
+ depth: 24
+ num_heads: 16
+ embed_dim: 1024
+ clone_batch: 8
+
+ ema_decay: 0.9999
+ ema_end_decay: 0.99999
+ ema_anneal_end_step: 100000
+ ema_encoder_only: true
+
+ average_top_k_layers: 24
+ layer_norm_target_layer: true
+ instance_norm_target_layer: false
+ batch_norm_target_layer: false
+ instance_norm_targets: true
+ layer_norm_targets: false
+
+ layerdrop: 0
+ norm_eps: 1e-5
+
+ supported_modality: TEXT
+ decoder_group: true
+
+ modalities:
+ text:
+ mask_prob: 0.5
+ mask_length: 1
+ mask_noise_std: 0.01
+ prenet_depth: 0
+ decoder:
+ input_dropout: 0.1
+ decoder_dim: 768
+ decoder_groups: 1
+ decoder_kernel: 9
+ decoder_layers: 5
+ decoder_residual: false
+ projection_layers: 2
+ projection_ratio: 2.0
diff --git a/examples/data2vec/config/v2/run_config/local.yaml b/examples/data2vec/config/v2/run_config/local.yaml
new file mode 100644
index 000000000..45595f9ee
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/local.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+ sweep:
+ dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+common:
+ log_interval: 1
+
+dataset:
+ num_workers: 0
diff --git a/examples/data2vec/config/v2/run_config/slurm_1.yaml b/examples/data2vec/config/v2/run_config/slurm_1.yaml
new file mode 100644
index 000000000..732f01889
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_1.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_1_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_1_aws.yaml
new file mode 100644
index 000000000..b2184f8cf
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_1_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.local_cache_path
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_2.yaml b/examples/data2vec/config/v2/run_config/slurm_2.yaml
new file mode 100644
index 000000000..ec53dc2a9
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_2.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_2_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_2_aws.yaml
new file mode 100644
index 000000000..553765597
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_2_aws.yaml
@@ -0,0 +1,39 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.local_cache_path
+ - task.data
+ - task.post_save_script
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ - model.model_path
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 12
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_3.yaml b/examples/data2vec/config/v2/run_config/slurm_3.yaml
new file mode 100644
index 000000000..14b47d14e
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_3.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 3
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_4.yaml b/examples/data2vec/config/v2/run_config/slurm_4.yaml
new file mode 100644
index 000000000..c54d735fb
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_4.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_4_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_4_aws.yaml
new file mode 100644
index 000000000..a77f62aec
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_4_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - task.post_save_script
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 12
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_6_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_6_aws.yaml
new file mode 100644
index 000000000..20e06582b
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_6_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 12
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 6
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_8.yaml b/examples/data2vec/config/v2/run_config/slurm_8.yaml
new file mode 100644
index 000000000..e3ec2c284
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_8.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 8
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/run_config/slurm_8_aws.yaml b/examples/data2vec/config/v2/run_config/slurm_8_aws.yaml
new file mode 100644
index 000000000..a9dce876c
--- /dev/null
+++ b/examples/data2vec/config/v2/run_config/slurm_8_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 12
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 8
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/v2/text_finetuning/cola.yaml b/examples/data2vec/config/v2/text_finetuning/cola.yaml
new file mode 100644
index 000000000..d4ac4ec8b
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/cola.yaml
@@ -0,0 +1,60 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: mcc
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+ report_mcc: True
+
+dataset:
+ batch_size: 16
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 320
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 5336
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
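
The remaining GLUE configs in this directory (mnli, mrpc, qnli, qqp, rte, sst_2, sts_b) share this layout and differ mainly in `num_classes`, the checkpoint metric, and the warmup/update budget; sts_b is the one regression task (`num_classes: 1` with `regression_target: true`). Since `task.data` and `model.model_path` are required `???` fields, a finetuning launch looks roughly like this sketch, with illustrative paths:

```shell script
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2/text_finetuning \
--config-name cola task.data=/path/to/CoLA-bin model.model_path=/path/to/nlp_base.pt
```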
diff --git a/examples/data2vec/config/v2/text_finetuning/mnli.yaml b/examples/data2vec/config/v2/text_finetuning/mnli.yaml
new file mode 100644
index 000000000..1a9d6e52f
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/mnli.yaml
@@ -0,0 +1,60 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 3
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 32
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ valid_subset: valid,valid1
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 7432
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 123873
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/v2/text_finetuning/mrpc.yaml b/examples/data2vec/config/v2/text_finetuning/mrpc.yaml
new file mode 100644
index 000000000..8f93d9d9e
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/mrpc.yaml
@@ -0,0 +1,60 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: acc_and_f1
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+ report_acc_and_f1: True
+
+dataset:
+ batch_size: 16
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 137
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 2296
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/v2/text_finetuning/qnli.yaml b/examples/data2vec/config/v2/text_finetuning/qnli.yaml
new file mode 100644
index 000000000..739fb53b6
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/qnli.yaml
@@ -0,0 +1,59 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 32
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 1986
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 33112
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/v2/text_finetuning/qqp.yaml b/examples/data2vec/config/v2/text_finetuning/qqp.yaml
new file mode 100644
index 000000000..9accbaa52
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/qqp.yaml
@@ -0,0 +1,60 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: acc_and_f1
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+ report_acc_and_f1: True
+
+dataset:
+ batch_size: 32
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 28318
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 113272
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/v2/text_finetuning/rte.yaml b/examples/data2vec/config/v2/text_finetuning/rte.yaml
new file mode 100644
index 000000000..ea07764d9
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/rte.yaml
@@ -0,0 +1,59 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 16
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 122
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 2036
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/v2/text_finetuning/run_config/local.yaml b/examples/data2vec/config/v2/text_finetuning/run_config/local.yaml
new file mode 100644
index 000000000..45595f9ee
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/run_config/local.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+ sweep:
+ dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+common:
+ log_interval: 1
+
+dataset:
+ num_workers: 0
diff --git a/examples/data2vec/config/v2/text_finetuning/sst_2.yaml b/examples/data2vec/config/v2/text_finetuning/sst_2.yaml
new file mode 100644
index 000000000..a273e5b94
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/sst_2.yaml
@@ -0,0 +1,59 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 32
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 1256
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 20935
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/v2/text_finetuning/sts_b.yaml b/examples/data2vec/config/v2/text_finetuning/sts_b.yaml
new file mode 100644
index 000000000..fb009ab95
--- /dev/null
+++ b/examples/data2vec/config/v2/text_finetuning/sts_b.yaml
@@ -0,0 +1,61 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+ user_dir: ${env:PWD}/examples/data2vec
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 1
+ max_positions: 512
+ d2v2_multi: True
+
+checkpoint:
+ best_checkpoint_metric: pearson_and_spearman
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+criterion:
+ _name: sentence_prediction
+ regression_target: true
+ report_pearson_and_spearman: True
+
+dataset:
+ batch_size: 16
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+ num_workers: 1
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 214
+
+optimization:
+ clip_norm: 0.0
+ lr: [4e-05]
+ max_update: 3598
+ max_epoch: 10
+
+model:
+ _name: data2vec_text_classification
+ model_path: ???
diff --git a/examples/data2vec/config/vision/finetuning/imagenet.yaml b/examples/data2vec/config/vision/finetuning/imagenet.yaml
new file mode 100644
index 000000000..d6d4864cc
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/imagenet.yaml
@@ -0,0 +1,52 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 1
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: accuracy
+
+task:
+ _name: image_classification
+ data: /datasets01/imagenet_full_size/061417
+
+dataset:
+ num_workers: 6
+ batch_size: 64
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 1
+ valid_subset: val
+
+distributed_training:
+ distributed_world_size: 8
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - correct
+
+optimization:
+ max_update: 100000
+ lr: [0.0005]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 10000
+
+model:
+ _name: data2vec_image_classification
+ model_path: ???
diff --git a/examples/data2vec/config/vision/finetuning/mae_imagenet_clean.yaml b/examples/data2vec/config/vision/finetuning/mae_imagenet_clean.yaml
new file mode 100644
index 000000000..17d4c0a8f
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/mae_imagenet_clean.yaml
@@ -0,0 +1,65 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ fp16_no_flatten_grads: true
+
+checkpoint:
+ save_interval: 1
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+
+task:
+ _name: mae_image_classification
+ data: /datasets01/imagenet_full_size/061417
+
+dataset:
+ num_workers: 6
+ batch_size: 32
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 2
+ valid_subset: val
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - correct
+
+optimization:
+ max_update: 250200
+ lr: [0.001]
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 0.001
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 16000
+ min_lr: 1e-6
+
+
+lr_scheduler: pass_through
+
+model:
+ _name: mae_image_classification
+ mixup: 0.7
+ mixup_prob: 0.9
+
+ model_path: ???
diff --git a/examples/data2vec/config/vision/finetuning/mae_imagenet_huge_clean.yaml b/examples/data2vec/config/vision/finetuning/mae_imagenet_huge_clean.yaml
new file mode 100644
index 000000000..2d2eb57ba
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/mae_imagenet_huge_clean.yaml
@@ -0,0 +1,68 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ fp16_no_flatten_grads: true
+
+checkpoint:
+ save_interval: 1
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+
+task:
+ _name: mae_image_classification
+ data: /datasets01/imagenet_full_size/061417
+
+dataset:
+ num_workers: 6
+ batch_size: 32
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 2
+ valid_subset: val
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - correct
+
+optimization:
+ max_update: 125200
+ lr: [0.0005]
+ clip_norm: 4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 0.0005
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 16000
+ min_lr: 1e-20
+
+
+lr_scheduler: pass_through
+
+model:
+ _name: mae_image_classification
+ mixup: 0.7
+ mixup_prob: 0.9
+ layer_decay: 0.75
+ drop_path_rate: 0.2
+
+ model_path: ???
diff --git a/examples/data2vec/config/vision/finetuning/mae_imagenet_large_clean.yaml b/examples/data2vec/config/vision/finetuning/mae_imagenet_large_clean.yaml
new file mode 100644
index 000000000..3a9413cef
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/mae_imagenet_large_clean.yaml
@@ -0,0 +1,68 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ fp16_no_flatten_grads: true
+
+checkpoint:
+ save_interval: 1
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+
+task:
+ _name: mae_image_classification
+ data: /datasets01/imagenet_full_size/061417
+
+dataset:
+ num_workers: 6
+ batch_size: 32
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 2
+ valid_subset: val
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - correct
+
+optimization:
+ max_update: 125200
+ lr: [0.0005]
+ clip_norm: 4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 0.0005
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 16000
+ min_lr: 1e-7
+
+
+lr_scheduler: pass_through
+
+model:
+ _name: mae_image_classification
+ mixup: 0.7
+ mixup_prob: 0.9
+ layer_decay: 0.75
+ drop_path_rate: 0.2
+
+ model_path: ???
diff --git a/examples/data2vec/config/vision/finetuning/run_config/local.yaml b/examples/data2vec/config/vision/finetuning/run_config/local.yaml
new file mode 100644
index 000000000..45595f9ee
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/local.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+ sweep:
+ dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+common:
+ log_interval: 1
+
+dataset:
+ num_workers: 0
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_1.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_1.yaml
new file mode 100644
index 000000000..732f01889
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_1.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml
new file mode 100644
index 000000000..e2bab5675
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml
new file mode 100644
index 000000000..c8b0f02a9
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml
@@ -0,0 +1,38 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ - task.local_cache_path
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml
new file mode 100644
index 000000000..93d0d9c20
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml
@@ -0,0 +1,38 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ - task.local_cache_path
+ - model.model_path
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml
new file mode 100644
index 000000000..14b47d14e
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 3
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_4.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_4.yaml
new file mode 100644
index 000000000..c54d735fb
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_4.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml
new file mode 100644
index 000000000..d5d11cb75
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml
new file mode 100644
index 000000000..906f08a60
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 6
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml b/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml
new file mode 100644
index 000000000..d60e13f8b
--- /dev/null
+++ b/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 8
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/base_imagenet.yaml b/examples/data2vec/config/vision/pretraining/base_imagenet.yaml
new file mode 100644
index 000000000..9bfc0f32b
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/base_imagenet.yaml
@@ -0,0 +1,52 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: image_pretraining
+ data: /datasets01/imagenet_full_size/061417/
+
+dataset:
+ num_workers: 6
+ batch_size: 64
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+
+optimization:
+ max_update: 400000
+ lr: [0.0005]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 10000
+
+model:
+ _name: data2vec_vision
diff --git a/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml b/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml
new file mode 100644
index 000000000..5fd399b11
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml
@@ -0,0 +1,64 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: image_pretraining
+ data: /datasets01/imagenet_full_size/061417
+
+dataset:
+ num_workers: 6
+ batch_size: 128
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 2
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: legacy_ddp
+
+criterion:
+ _name: model
+ log_keys:
+ - ema_decay
+ - target_var
+ - pred_var
+
+optimization:
+ max_update: 375300 # 300 epochs * 1251 updates per epoch
+ lr: [0.0005]
+ clip_norm: 3.0
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.999)
+ adam_eps: 1e-08
+ weight_decay: 0.05
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 12510 # 10 epochs at 1251 updates per epoch
+
+model:
+ _name: data2vec_vision
+
+ attention_dropout: 0.05
+
+ ema_decay: 0.999
+ ema_end_decay: 0.9998
+ layer_norm_targets: True
+ average_top_k_layers: 6
+
+ loss_beta: 2.0
+
+ drop_path: 0.25
diff --git a/examples/data2vec/config/vision/pretraining/base_mae_imagenet.yaml b/examples/data2vec/config/vision/pretraining/base_mae_imagenet.yaml
new file mode 100644
index 000000000..d7872b5e0
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/base_mae_imagenet.yaml
@@ -0,0 +1,64 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ tensorboard_logdir: tb
+ fp16_no_flatten_grads: true
+
+checkpoint:
+ save_interval: 5
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: mae_image_pretraining
+ data: /datasets01/imagenet_full_size/061417/
+ rebuild_batches: true
+
+dataset:
+ num_workers: 6
+ batch_size: 64
+ skip_invalid_size_inputs_valid_test: true
+ required_batch_size_multiple: 1
+ disable_validation: true
+
+distributed_training:
+ distributed_world_size: 16
+ ddp_backend: c10d
+
+criterion:
+ _name: model
+
+optimization:
+ max_update: 375300
+ lr: [0.0006]
+
+optimizer:
+ _name: composite
+ groups:
+ with_decay:
+ lr_float: 6e-4
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0.05
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 50040
+ no_decay:
+ lr_float: 6e-4
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ weight_decay: 0
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 50040
+
+lr_scheduler: pass_through
+
+model:
+ _name: mae
diff --git a/examples/data2vec/config/vision/pretraining/run_config/local.yaml b/examples/data2vec/config/vision/pretraining/run_config/local.yaml
new file mode 100644
index 000000000..45595f9ee
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/local.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+ sweep:
+ dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+common:
+ log_interval: 1
+
+dataset:
+ num_workers: 0
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml
new file mode 100644
index 000000000..732f01889
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_1_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_1_aws.yaml
new file mode 100644
index 000000000..e2bab5675
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_1_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_2.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_2.yaml
new file mode 100644
index 000000000..c8b0f02a9
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_2.yaml
@@ -0,0 +1,38 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ - task.local_cache_path
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml
new file mode 100644
index 000000000..032e53a30
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ - task.local_cache_path
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml
new file mode 100644
index 000000000..14b47d14e
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 3
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml
new file mode 100644
index 000000000..c54d735fb
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml
new file mode 100644
index 000000000..d5d11cb75
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml
new file mode 100644
index 000000000..906f08a60
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 6
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml b/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml
new file mode 100644
index 000000000..d60e13f8b
--- /dev/null
+++ b/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 8
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/data2vec/data/__init__.py b/examples/data2vec/data/__init__.py
new file mode 100644
index 000000000..d76112bfc
--- /dev/null
+++ b/examples/data2vec/data/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .image_dataset import ImageDataset
+from .path_dataset import PathDataset
+from .mae_image_dataset import MaeImageDataset
+from .mae_finetuning_image_dataset import MaeFinetuningImageDataset
+
+
+__all__ = [
+ "ImageDataset",
+ "MaeImageDataset",
+ "MaeFinetuningImageDataset",
+ "PathDataset",
+]
\ No newline at end of file
diff --git a/examples/data2vec/data/add_class_target_dataset.py b/examples/data2vec/data/add_class_target_dataset.py
new file mode 100644
index 000000000..c346c83e5
--- /dev/null
+++ b/examples/data2vec/data/add_class_target_dataset.py
@@ -0,0 +1,63 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from fairseq.data import BaseWrapperDataset, data_utils
+
+
+class AddClassTargetDataset(BaseWrapperDataset):
+ def __init__(
+ self,
+ dataset,
+ labels,
+ multi_class,
+ num_classes=None,
+ label_indices=None,
+ add_to_input=True,
+ ):
+ super().__init__(dataset)
+
+ self.label_indices = label_indices
+ self.labels = labels
+ self.multi_class = multi_class
+ self.add_to_input = add_to_input
+ if num_classes is None and multi_class:
+ assert self.label_indices is not None
+ num_classes = len(self.label_indices)
+
+ self.num_classes = num_classes
+
+ def __getitem__(self, index):
+ item = self.dataset[index]
+ item_labels = self.labels[index]
+ if self.multi_class:
+ item["label"] = torch.zeros(self.num_classes)
+ for il in item_labels:
+ if self.label_indices is not None:
+ il = self.label_indices[il]
+ item["label"][il] = 1.0
+ else:
+ item["label"] = torch.tensor(
+ self.labels[index]
+ if self.label_indices is None
+ else self.label_indices[self.labels[index]]
+ )
+
+ return item
+
+ def collater(self, samples):
+ collated = self.dataset.collater(samples)
+ if len(collated) == 0:
+ return collated
+
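+ # the wrapped collater may drop samples, so keep labels only for ids that survived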
+ indices = set(collated["id"].tolist())
+ target = [s["label"] for s in samples if s["id"] in indices]
+ collated["label"] = torch.stack(target, dim=0)
+
+ if self.add_to_input:
+ collated["net_input"]["label"] = collated["label"]
+
+ return collated
diff --git a/examples/data2vec/data/image_dataset.py b/examples/data2vec/data/image_dataset.py
new file mode 100644
index 000000000..7f551057e
--- /dev/null
+++ b/examples/data2vec/data/image_dataset.py
@@ -0,0 +1,127 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+
+import numpy as np
+import os
+from typing import Optional, Callable, Set
+
+import torch
+
+from torchvision.datasets.vision import VisionDataset
+from torchvision.transforms import ToTensor
+
+from fairseq.data import FairseqDataset
+
+
+logger = logging.getLogger(__name__)
+
+
+class ImageDataset(FairseqDataset, VisionDataset):
+ def __init__(
+ self,
+ root: str,
+ extensions: Set[str],
+ load_classes: bool,
+ transform: Optional[Callable] = None,
+ shuffle=True,
+ ):
+ FairseqDataset.__init__(self)
+ VisionDataset.__init__(self, root=root, transform=transform)
+
+ self.shuffle = shuffle
+ self.tensor_transform = ToTensor()
+
+ self.classes = None
+ self.labels = None
+ if load_classes:
+ classes = [d.name for d in os.scandir(root) if d.is_dir()]
+ classes.sort()
+ self.classes = {cls_name: i for i, cls_name in enumerate(classes)}
+ logger.info(f"loaded {len(self.classes)} classes")
+ self.labels = []
+
+ def walk_path(root_path):
+ for root, _, fnames in sorted(os.walk(root_path, followlinks=True)):
+ for fname in sorted(fnames):
+ fname_ext = os.path.splitext(fname)
+ if fname_ext[-1].lower() not in extensions:
+ continue
+
+ path = os.path.join(root, fname)
+ yield path
+
+ logger.info(f"finding images in {root}")
+ if self.classes is not None:
+ self.files = []
+ self.labels = []
+ for c, i in self.classes.items():
+ for f in walk_path(os.path.join(root, c)):
+ self.files.append(f)
+ self.labels.append(i)
+ else:
+ self.files = [f for f in walk_path(root)]
+
+ logger.info(f"loaded {len(self.files)} examples")
+
+ def __getitem__(self, index):
+ from PIL import Image
+
+ fpath = self.files[index]
+
+ with open(fpath, "rb") as f:
+ img = Image.open(f).convert("RGB")
+
+ if self.transform is None:
+ img = self.tensor_transform(img)
+ else:
+ img = self.transform(img)
+ assert torch.is_tensor(img)
+
+ res = {"id": index, "img": img}
+
+ if self.labels is not None:
+ res["label"] = self.labels[index]
+
+ return res
+
+ def __len__(self):
+ return len(self.files)
+
+ def collater(self, samples):
+ if len(samples) == 0:
+ return {}
+
+ collated_img = torch.stack([s["img"] for s in samples], dim=0)
+
+ res = {
+ "id": torch.LongTensor([s["id"] for s in samples]),
+ "net_input": {
+ "img": collated_img,
+ },
+ }
+
+ if "label" in samples[0]:
+ res["net_input"]["label"] = torch.LongTensor([s["label"] for s in samples])
+
+ return res
+
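+ # images are expected to collate to a fixed shape, so length-based batching
+ # is a no-op and num_tokens/size return a constant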
+ def num_tokens(self, index):
+ return 1
+
+ def size(self, index):
+ return 1
+
+ def ordered_indices(self):
+ """Return an ordered list of indices. Batches will be constructed based
+ on this order."""
+ if self.shuffle:
+ order = [np.random.permutation(len(self))]
+ else:
+ order = [np.arange(len(self))]
+
+ return order[0]
diff --git a/examples/data2vec/data/mae_finetuning_image_dataset.py b/examples/data2vec/data/mae_finetuning_image_dataset.py
new file mode 100644
index 000000000..28cbcb38a
--- /dev/null
+++ b/examples/data2vec/data/mae_finetuning_image_dataset.py
@@ -0,0 +1,135 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+
+import numpy as np
+import os
+
+import torch
+
+from torchvision import datasets, transforms
+
+from timm.data import create_transform
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+import PIL
+
+from fairseq.data import FairseqDataset
+from .mae_image_dataset import caching_loader
+
+
+logger = logging.getLogger(__name__)
+
+
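+# train: timm's create_transform (RandAugment + random erasing); eval: a
+# deterministic resize + center-crop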
+def build_transform(is_train, input_size, color_jitter, aa, reprob, remode, recount):
+ mean = IMAGENET_DEFAULT_MEAN
+ std = IMAGENET_DEFAULT_STD
+ # train transform
+ if is_train:
+ # this should always dispatch to transforms_imagenet_train
+ transform = create_transform(
+ input_size=input_size,
+ is_training=True,
+ color_jitter=color_jitter,
+ auto_augment=aa,
+ interpolation="bicubic",
+ re_prob=reprob,
+ re_mode=remode,
+ re_count=recount,
+ mean=mean,
+ std=std,
+ )
+ return transform
+
+ # eval transform
+ t = []
+ if input_size <= 224:
+ crop_pct = 224 / 256
+ else:
+ crop_pct = 1.0
+ size = int(input_size / crop_pct)
+ t.append(
+ transforms.Resize(
+ size, interpolation=PIL.Image.BICUBIC
+ ), # to maintain same ratio w.r.t. 224 images
+ )
+ t.append(transforms.CenterCrop(input_size))
+
+ t.append(transforms.ToTensor())
+ t.append(transforms.Normalize(mean, std))
+ return transforms.Compose(t)
+
+
+class MaeFinetuningImageDataset(FairseqDataset):
+ def __init__(
+ self,
+ root: str,
+ split: str,
+ is_train: bool,
+ input_size,
+ color_jitter=None,
+ aa="rand-m9-mstd0.5-inc1",
+ reprob=0.25,
+ remode="pixel",
+ recount=1,
+ local_cache_path=None,
+ shuffle=True,
+ ):
+ FairseqDataset.__init__(self)
+
+ self.shuffle = shuffle
+
+ transform = build_transform(
+ is_train, input_size, color_jitter, aa, reprob, remode, recount
+ )
+
+ path = os.path.join(root, split)
+ loader = caching_loader(local_cache_path, datasets.folder.default_loader)
+
+ self.dataset = datasets.ImageFolder(path, loader=loader, transform=transform)
+
+ logger.info(f"loaded {len(self.dataset)} examples")
+
+ def __getitem__(self, index):
+ img, label = self.dataset[index]
+ return {"id": index, "img": img, "label": label}
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def collater(self, samples):
+ if len(samples) == 0:
+ return {}
+
+ collated_img = torch.stack([s["img"] for s in samples], dim=0)
+
+ res = {
+ "id": torch.LongTensor([s["id"] for s in samples]),
+ "net_input": {
+ "imgs": collated_img,
+ },
+ }
+
+ if "label" in samples[0]:
+ res["net_input"]["labels"] = torch.LongTensor([s["label"] for s in samples])
+
+ return res
+
+ def num_tokens(self, index):
+ return 1
+
+ def size(self, index):
+ return 1
+
+ def ordered_indices(self):
+ """Return an ordered list of indices. Batches will be constructed based
+ on this order."""
+ if self.shuffle:
+ order = [np.random.permutation(len(self))]
+ else:
+ order = [np.arange(len(self))]
+
+ return order[0]
diff --git a/examples/data2vec/data/mae_image_dataset.py b/examples/data2vec/data/mae_image_dataset.py
new file mode 100644
index 000000000..4aacb9489
--- /dev/null
+++ b/examples/data2vec/data/mae_image_dataset.py
@@ -0,0 +1,418 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from functools import partial
+import logging
+import math
+import random
+import time
+
+import numpy as np
+import os
+
+import torch
+
+from torchvision import datasets, transforms
+from .path_dataset import PathDataset
+
+from fairseq.data import FairseqDataset
+from fairseq.data.data_utils import compute_block_mask_1d, compute_block_mask_2d
+
+from shutil import copyfile
+
+logger = logging.getLogger(__name__)
+
+
+def load(path, loader, cache):
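+ # an earlier permission failure may have redirected the cache root; it is
+ # stashed as an attribute on caching_loader so the override persists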
+ if hasattr(caching_loader, "cache_root"):
+ cache = caching_loader.cache_root
+
+ cached_path = cache + path
+
+ num_tries = 3
+ for curr_try in range(num_tries):
+ try:
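+ # final attempt: bypass the cache and read straight from the source path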
+ if curr_try == 2:
+ return loader(path)
+ if not os.path.exists(cached_path) or curr_try > 0:
+ os.makedirs(os.path.dirname(cached_path), exist_ok=True)
+ copyfile(path, cached_path)
+ os.chmod(cached_path, 0o777)
+ return loader(cached_path)
+ except Exception as e:
+ logger.warning(str(e))
+ if "Errno 13" in str(e):
+ caching_loader.cache_root = f"/scratch/{random.randint(0, 69420)}"
+ logger.warning(f"setting cache root to {caching_loader.cache_root}")
+ cached_path = caching_loader.cache_root + path
+ if curr_try == (num_tries - 1):
+ raise
+ time.sleep(2)
+
+
+def caching_loader(cache_root: str, loader):
+ if cache_root is None:
+ return loader
+
+ if cache_root == "slurm_tmpdir":
+ cache_root = os.environ["SLURM_TMPDIR"]
+ assert len(cache_root) > 0
+
+ if not cache_root.endswith("/"):
+ cache_root += "/"
+
+ return partial(load, loader=loader, cache=cache_root)
+
+
+class RandomResizedCropAndInterpolationWithTwoPic:
+ """Crop the given PIL Image to random size and aspect ratio with random interpolation.
+
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+ aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+ is finally resized to given size.
+ This is popularly used to train the Inception networks.
+
+ Args:
+ size: expected output size of each edge
+ scale: range of size of the origin size cropped
+ ratio: range of aspect ratio of the origin aspect ratio cropped
+ interpolation: Default: PIL.Image.BILINEAR
+ """
+
+ def __init__(
+ self,
+ size,
+ second_size=None,
+ scale=(0.08, 1.0),
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
+ interpolation="bilinear",
+ second_interpolation="lanczos",
+ ):
+ if isinstance(size, tuple):
+ self.size = size
+ else:
+ self.size = (size, size)
+ if second_size is not None:
+ if isinstance(second_size, tuple):
+ self.second_size = second_size
+ else:
+ self.second_size = (second_size, second_size)
+ else:
+ self.second_size = None
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+ logger.warning("range should be of kind (min, max)")
+
+ if interpolation == "random":
+ from PIL import Image
+
+ self.interpolation = (Image.BILINEAR, Image.BICUBIC)
+ else:
+ self.interpolation = self._pil_interp(interpolation)
+
+ self.second_interpolation = (
+ self._pil_interp(second_interpolation)
+ if second_interpolation is not None
+ else None
+ )
+ self.scale = scale
+ self.ratio = ratio
+
+ def _pil_interp(self, method):
+ from PIL import Image
+
+ if method == "bicubic":
+ return Image.BICUBIC
+ elif method == "lanczos":
+ return Image.LANCZOS
+ elif method == "hamming":
+ return Image.HAMMING
+ else:
+ # default bilinear, do we want to allow nearest?
+ return Image.BILINEAR
+
+ @staticmethod
+ def get_params(img, scale, ratio):
+ """Get parameters for ``crop`` for a random sized crop.
+
+ Args:
+ img (PIL Image): Image to be cropped.
+ scale (tuple): range of size of the origin size cropped
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+ sized crop.
+ """
+ area = img.size[0] * img.size[1]
+
+ for attempt in range(10):
+ target_area = random.uniform(*scale) * area
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if w <= img.size[0] and h <= img.size[1]:
+ i = random.randint(0, img.size[1] - h)
+ j = random.randint(0, img.size[0] - w)
+ return i, j, h, w
+
+ # Fallback to central crop
+ in_ratio = img.size[0] / img.size[1]
+ if in_ratio < min(ratio):
+ w = img.size[0]
+ h = int(round(w / min(ratio)))
+ elif in_ratio > max(ratio):
+ h = img.size[1]
+ w = int(round(h * max(ratio)))
+ else: # whole image
+ w = img.size[0]
+ h = img.size[1]
+ i = (img.size[1] - h) // 2
+ j = (img.size[0] - w) // 2
+ return i, j, h, w
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be cropped and resized.
+
+ Returns:
+ PIL Image: Randomly cropped and resized image.
+ """
+ import torchvision.transforms.functional as F
+
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
+ if isinstance(self.interpolation, (tuple, list)):
+ interpolation = random.choice(self.interpolation)
+ else:
+ interpolation = self.interpolation
+ if self.second_size is None:
+ return F.resized_crop(img, i, j, h, w, self.size, interpolation)
+ else:
+ return F.resized_crop(
+ img, i, j, h, w, self.size, interpolation
+ ), F.resized_crop(
+ img, i, j, h, w, self.second_size, self.second_interpolation
+ )
+
+
+class MaeImageDataset(FairseqDataset):
+ def __init__(
+ self,
+ root: str,
+ split: str,
+ input_size,
+ local_cache_path=None,
+ shuffle=True,
+ key="imgs",
+ beit_transforms=False,
+ target_transform=False,
+ no_transform=False,
+ compute_mask=False,
+ patch_size: int = 16,
+ mask_prob: float = 0.75,
+ mask_prob_adjust: float = 0,
+ mask_length: int = 1,
+ inverse_mask: bool = False,
+ expand_adjacent: bool = False,
+ mask_dropout: float = 0,
+ non_overlapping: bool = False,
+ require_same_masks: bool = True,
+ clone_batch: int = 1,
+ dataset_type: str = "imagefolder",
+ ):
+ FairseqDataset.__init__(self)
+
+ self.shuffle = shuffle
+ self.key = key
+
+ loader = caching_loader(local_cache_path, datasets.folder.default_loader)
+
+ self.transform_source = None
+ self.transform_target = None
+
+ if target_transform:
+ self.transform_source = transforms.ColorJitter(0.4, 0.4, 0.4)
+ self.transform_target = transforms.ColorJitter(0.4, 0.4, 0.4)
+
+ if no_transform:
+ # no augmentation: deterministically resize to the model input size
+ self.transform_train = transforms.Resize((input_size, input_size))
+ elif beit_transforms:
+ beit_transform_list = []
+ if not target_transform:
+ beit_transform_list.append(transforms.ColorJitter(0.4, 0.4, 0.4))
+ beit_transform_list.extend(
+ [
+ transforms.RandomHorizontalFlip(p=0.5),
+ RandomResizedCropAndInterpolationWithTwoPic(
+ size=input_size,
+ second_size=None,
+ interpolation="bicubic",
+ second_interpolation=None,
+ ),
+ ]
+ )
+ self.transform_train = transforms.Compose(beit_transform_list)
+ else:
+ self.transform_train = transforms.Compose(
+ [
+ transforms.RandomResizedCrop(
+ input_size, scale=(0.2, 1.0), interpolation=3
+ ), # 3 is bicubic
+ transforms.RandomHorizontalFlip(),
+ ]
+ )
+ self.final_transform = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ ),
+ ]
+ )
+
+ if dataset_type == "imagefolder":
+ self.dataset = datasets.ImageFolder(
+ os.path.join(root, split), loader=loader
+ )
+ elif dataset_type == "path":
+ self.dataset = PathDataset(
+ root,
+ loader,
+ None,
+ None,
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225],
+ )
+ else:
+ raise Exception(f"invalid dataset type {dataset_type}")
+
+ logger.info(
+ f"initial transform: {self.transform_train}, "
+ f"source transform: {self.transform_source}, "
+ f"target transform: {self.transform_target}, "
+ f"final transform: {self.final_transform}"
+ )
+ logger.info(f"loaded {len(self.dataset)} examples")
+
+ self.is_compute_mask = compute_mask
+ self.patches = (input_size // patch_size) ** 2
+ self.mask_prob = mask_prob
+ self.mask_prob_adjust = mask_prob_adjust
+ self.mask_length = mask_length
+ self.inverse_mask = inverse_mask
+ self.expand_adjacent = expand_adjacent
+ self.mask_dropout = mask_dropout
+ self.non_overlapping = non_overlapping
+ self.require_same_masks = require_same_masks
+ self.clone_batch = clone_batch
+
+ def __getitem__(self, index):
+ img, _ = self.dataset[index]
+
+ img = self.transform_train(img)
+
+ source = None
+ target = None
+ if self.transform_source is not None:
+ source = self.final_transform(self.transform_source(img))
+ if self.transform_target is not None:
+ target = self.final_transform(self.transform_target(img))
+
+ if source is None:
+ img = self.final_transform(img)
+
+ v = {"id": index, self.key: source if source is not None else img}
+ if target is not None:
+ v["target"] = target
+
+ if self.is_compute_mask:
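+ # draw clone_batch masks for this image in one call: 1d span masking when
+ # mask_length == 1, block-wise 2d masking over the patch grid otherwise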
+ if self.mask_length == 1:
+ mask = compute_block_mask_1d(
+ shape=(self.clone_batch, self.patches),
+ mask_prob=self.mask_prob,
+ mask_length=self.mask_length,
+ mask_prob_adjust=self.mask_prob_adjust,
+ inverse_mask=self.inverse_mask,
+ require_same_masks=True,
+ )
+ else:
+ mask = compute_block_mask_2d(
+ shape=(self.clone_batch, self.patches),
+ mask_prob=self.mask_prob,
+ mask_length=self.mask_length,
+ mask_prob_adjust=self.mask_prob_adjust,
+ inverse_mask=self.inverse_mask,
+ require_same_masks=True,
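+ # "expand_adjcent" (sic) is the spelling expected by fairseq's compute_block_mask_2d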
+ expand_adjcent=self.expand_adjacent,
+ mask_dropout=self.mask_dropout,
+ non_overlapping=self.non_overlapping,
+ )
+
+ v["precomputed_mask"] = mask
+
+ return v
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def collater(self, samples):
+ if len(samples) == 0:
+ return {}
+
+ collated_img = torch.stack([s[self.key] for s in samples], dim=0)
+
+ res = {
+ "id": torch.LongTensor([s["id"] for s in samples]),
+ "net_input": {
+ self.key: collated_img,
+ },
+ }
+
+ if "target" in samples[0]:
+ collated_target = torch.stack([s["target"] for s in samples], dim=0)
+ res["net_input"]["target"] = collated_target
+
+ if "precomputed_mask" in samples[0]:
+ collated_mask = torch.cat([s["precomputed_mask"] for s in samples], dim=0)
+ res["net_input"]["precomputed_mask"] = collated_mask
+
+ return res
+
+ def num_tokens(self, index):
+ return 1
+
+ def size(self, index):
+ return 1
+
+ @property
+ def sizes(self):
+ return np.full((len(self),), 1)
+
+ def ordered_indices(self):
+ """Return an ordered list of indices. Batches will be constructed based
+ on this order."""
+ if self.shuffle:
+ order = [np.random.permutation(len(self))]
+ else:
+ order = [np.arange(len(self))]
+
+ return order[0]
diff --git a/examples/data2vec/data/modality.py b/examples/data2vec/data/modality.py
new file mode 100644
index 000000000..aa23ac94f
--- /dev/null
+++ b/examples/data2vec/data/modality.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+from enum import Enum, auto
+
+
+class Modality(Enum):
+ AUDIO = auto()
+ IMAGE = auto()
+ TEXT = auto()
diff --git a/examples/data2vec/data/path_dataset.py b/examples/data2vec/data/path_dataset.py
new file mode 100644
index 000000000..02010058e
--- /dev/null
+++ b/examples/data2vec/data/path_dataset.py
@@ -0,0 +1,64 @@
+import glob
+import os
+from typing import Any, Callable, List, Optional, Tuple
+
+import logging
+import torchvision.transforms.functional as TF
+import PIL
+from PIL import Image
+from torchvision.datasets import VisionDataset
+
+logger = logging.getLogger(__name__)
+
+
+class PathDataset(VisionDataset):
+ def __init__(
+ self,
+ root: List[str],
+ loader: Optional[Callable] = None,
+ transform: Optional[Callable] = None,
+ extra_transform: Optional[Callable] = None,
+ mean: Optional[List[float]] = None,
+ std: Optional[List[float]] = None,
+ ):
+ super().__init__(root=root)
+
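+ # raise PIL's decompression-bomb limit so very large images can be opened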
+ PIL.Image.MAX_IMAGE_PIXELS = 256000001
+
+ self.files = []
+ for folder in self.root:
+ self.files.extend(
+ sorted(glob.glob(os.path.join(folder, "**", "*.jpg"), recursive=True))
+ )
+ self.files.extend(
+ sorted(glob.glob(os.path.join(folder, "**", "*.png"), recursive=True))
+ )
+
+ self.transform = transform
+ self.extra_transform = extra_transform
+ self.mean = mean
+ self.std = std
+
+ self.loader = loader
+
+ logger.info(f"loaded {len(self.files)} samples from {root}")
+
+ assert (mean is None) == (std is None)
+
+ def __len__(self) -> int:
+ return len(self.files)
+
+ def __getitem__(self, idx) -> Tuple[Any, None]:
+ path = self.files[idx]
+
+ if self.loader is not None:
+ return self.loader(path), None
+
+ img = Image.open(path).convert("RGB")
+ if self.transform is not None:
+ img = self.transform(img)
+ img = TF.to_tensor(img)
+ if self.mean is not None and self.std is not None:
+ img = TF.normalize(img, self.mean, self.std)
+ return img, None
diff --git a/examples/data2vec/fb_convert_beit_cp.py b/examples/data2vec/fb_convert_beit_cp.py
new file mode 100644
index 000000000..cf42ace76
--- /dev/null
+++ b/examples/data2vec/fb_convert_beit_cp.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import torch
+
+from omegaconf import OmegaConf
+
+from fairseq.criterions.model_criterion import ModelCriterionConfig
+from fairseq.dataclass.configs import FairseqConfig
+
+from tasks import ImageClassificationConfig, ImagePretrainingConfig
+from models.data2vec_image_classification import (
+ Data2VecImageClassificationConfig,
+ Data2VecImageClassificationModel,
+)
+from models.data2vec_vision import Data2VecVisionConfig, Data2VecVisionModel
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="convert beit checkpoint into data2vec - vision checkpoint"
+ )
+ # fmt: off
+ parser.add_argument('checkpoint', help='checkpoint to convert')
+ parser.add_argument('--output', required=True, metavar='PATH', help='where to output converted checkpoint')
+ parser.add_argument('--type', type=str, choices=['vision', 'image_classification'], default='image_classification', help='type of model to upgrade')
+ parser.add_argument('--inception_norms', action='store_true', default=False)
+ # fmt: on
+
+ return parser
+
+
+def update_checkpoint(model_dict, prefix, is_nested):
+
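+ # renames beit checkpoint keys to the data2vec layout; with is_nested the
+ # encoder parameters are re-homed under "model." (the classification model
+ # wraps the vision encoder as a submodule)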
+ replace_paths = {
+ "cls_token": "model.cls_emb" if is_nested else "cls_emb",
+ "patch_embed": "model.patch_embed" if is_nested else "patch_embed",
+ "mask_token": "mask_emb",
+ }
+
+ starts_with = {
+ "patch_embed.proj": "model.patch_embed.conv"
+ if is_nested
+ else "patch_embed.conv",
+ "lm_head": "final_proj",
+ "fc_norm": "fc_norm",
+ "head": "head",
+ }
+
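+ # timm-style MLP submodules fc1/fc2 map onto nn.Sequential indices 0 and 2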
+ partial = {
+ "mlp.fc1": "mlp.0",
+ "mlp.fc2": "mlp.2",
+ }
+
+ for k in list(model_dict.keys()):
+ for sw, r in starts_with.items():
+ if k.startswith(sw):
+ replace_paths[k] = k.replace(sw, r)
+ for p, r in partial.items():
+ if p in k:
+ replace_paths[k] = prefix + k.replace(p, r)
+
+ if prefix != "":
+ for k in list(model_dict.keys()):
+ if k not in replace_paths:
+ replace_paths[k] = prefix + k
+
+ for k in list(model_dict.keys()):
+ if k in replace_paths:
+ model_dict[replace_paths[k]] = model_dict[k]
+ if k != replace_paths[k]:
+ del model_dict[k]
+
+ return model_dict
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ cp = torch.load(args.checkpoint, map_location="cpu")
+
+ cfg = FairseqConfig(
+ criterion=ModelCriterionConfig(_name="model", log_keys=["correct"]),
+ )
+
+ if args.type == "image_classification":
+
+ cfg.task = ImageClassificationConfig(
+ _name="image_classification",
+ data=".",
+ )
+
+ if args.inception_norms:
+ cfg.task.normalization_mean = [0.5, 0.5, 0.5]
+ cfg.task.normalization_std = [0.5, 0.5, 0.5]
+
+ cfg.model = Data2VecImageClassificationConfig(
+ _name="data2vec_image_classification",
+ )
+ cfg.model.pretrained_model_args = FairseqConfig(
+ model=Data2VecVisionConfig(
+ _name="data2vec_vision", shared_rel_pos_bias=False
+ ),
+ task=ImagePretrainingConfig(
+ _name="image_pretraining",
+ ),
+ )
+
+ cfg = OmegaConf.create(cfg)
+
+ state = {
+ "cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True),
+ "model": cp["module"],
+ "best_loss": None,
+ "optimizer": None,
+ "extra_state": {},
+ }
+
+ model = Data2VecImageClassificationModel(cfg.model)
+ model.load_state_dict(
+ update_checkpoint(state["model"], prefix="model.encoder.", is_nested=True),
+ strict=True,
+ )
+ elif args.type == "vision":
+ cfg.task = ImagePretrainingConfig(
+ _name="image_pretraining",
+ data=".",
+ )
+
+ if args.inception_norms:
+ cfg.task.normalization_mean = [0.5, 0.5, 0.5]
+ cfg.task.normalization_std = [0.5, 0.5, 0.5]
+
+ cfg.model = Data2VecVisionConfig(
+ _name="data2vec_vision",
+ )
+ cfg = OmegaConf.create(cfg)
+
+ state = {
+ "cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True),
+ "model": cp["model"],
+ "best_loss": None,
+ "optimizer": None,
+ "extra_state": {},
+ }
+
+ model = Data2VecVisionModel(cfg.model)
+ model.load_state_dict(
+ update_checkpoint(state["model"], prefix="encoder.", is_nested=False),
+ strict=True,
+ )
+ else:
+ raise Exception("unsupported type " + args.type)
+
+ print(state["cfg"], state.keys())
+ torch.save(state, args.output)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/data2vec/models/__init__.py b/examples/data2vec/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/data2vec/models/audio_classification.py b/examples/data2vec/models/audio_classification.py
new file mode 100644
index 000000000..06d215826
--- /dev/null
+++ b/examples/data2vec/models/audio_classification.py
@@ -0,0 +1,614 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import logging
+import re
+from dataclasses import dataclass, field
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from omegaconf import II, MISSING, open_dict
+
+from fairseq import checkpoint_utils, tasks
+from fairseq.dataclass import FairseqDataclass
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.models import (
+ BaseFairseqModel,
+ register_model,
+)
+from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
+from fairseq.modules import TransposeLast
+from fairseq.tasks import FairseqTask
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AudioClassificationConfig(FairseqDataclass):
+ model_path: str = field(
+ default=MISSING, metadata={"help": "path to wav2vec 2.0 model"}
+ )
+ no_pretrained_weights: bool = field(
+ default=False, metadata={"help": "if true, does not load pretrained weights"}
+ )
+ dropout_input: float = field(
+ default=0.0,
+ metadata={"help": "dropout to apply to the input (after feat extr)"},
+ )
+ final_dropout: float = field(
+ default=0.0,
+ metadata={"help": "dropout after transformer and before final projection"},
+ )
+ dropout: float = field(
+ default=0.0, metadata={"help": "dropout probability inside wav2vec 2.0 model"}
+ )
+ attention_dropout: float = field(
+ default=0.0,
+ metadata={
+ "help": "dropout probability for attention weights inside wav2vec 2.0 model"
+ },
+ )
+ activation_dropout: float = field(
+ default=0.0,
+ metadata={
+ "help": "dropout probability after activation in FFN inside wav2vec 2.0 model"
+ },
+ )
+
+ # masking
+ apply_mask: bool = field(
+ default=False, metadata={"help": "apply masking during fine-tuning"}
+ )
+ mask_length: int = field(
+ default=10, metadata={"help": "length of each masked span (in timesteps)"}
+ )
+ mask_prob: float = field(
+ default=0.5,
+ metadata={
+ "help": "probability of replacing a token with mask (normalized by length)"
+ },
+ )
+ mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
+ default="static", metadata={"help": "how to choose masks"}
+ )
+ mask_other: float = field(
+ default=0,
+ metadata={
+ "help": "secondary mask argument (used for more complex distributions), "
+ "see help in compute_mask_indices"
+ },
+ )
+ no_mask_overlap: bool = field(
+ default=False, metadata={"help": "whether to allow masks to overlap"}
+ )
+ mask_min_space: Optional[int] = field(
+ default=1,
+ metadata={"help": "min space between spans (if no overlap is enabled)"},
+ )
+ require_same_masks: bool = field(
+ default=True,
+ metadata={
+ "help": "whether the number of masked timesteps must be the same across all "
+ "examples in a batch"
+ },
+ )
+ mask_dropout: float = field(
+ default=0.0,
+ metadata={"help": "percent of masks to unmask for each sample"},
+ )
+
+ # channel masking
+ mask_channel_length: int = field(
+ default=10, metadata={"help": "length of the mask for features (channels)"}
+ )
+ mask_channel_prob: float = field(
+ default=0.0, metadata={"help": "probability of replacing a feature with 0"}
+ )
+ mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
+ default="static",
+ metadata={"help": "how to choose mask length for channel masking"},
+ )
+ mask_channel_other: float = field(
+ default=0,
+ metadata={
+ "help": "secondary mask argument (used for more complex distributions), "
+ "see help in compute_mask_indices"
+ },
+ )
+ no_mask_channel_overlap: bool = field(
+ default=False, metadata={"help": "whether to allow channel masks to overlap"}
+ )
+ freeze_finetune_updates: int = field(
+ default=0, metadata={"help": "don't fine-tune wav2vec for this many updates"}
+ )
+ feature_grad_mult: float = field(
+ default=0.0, metadata={"help": "reset feature grad mult in wav2vec 2.0 to this"}
+ )
+ layerdrop: float = field(
+ default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"}
+ )
+ mask_channel_min_space: Optional[int] = field(
+ default=1,
+ metadata={"help": "min space between spans (if no overlap is enabled)"},
+ )
+ mask_channel_before: bool = False
+ normalize: bool = II("task.normalize")
+ data: str = II("task.data")
+ # this holds the loaded wav2vec args
+ d2v_args: Any = None
+ offload_activations: bool = field(
+ default=False, metadata={"help": "offload_activations"}
+ )
+ min_params_to_wrap: int = field(
+ default=int(1e8),
+ metadata={
+ "help": "minimum number of params for a layer to be wrapped with FSDP() when "
+ "training with --ddp-backend=fully_sharded. Smaller values will "
+ "improve memory efficiency, but may make torch.distributed "
+ "communication less efficient due to smaller input sizes. This option "
+ "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
+ "--offload-activations are passed."
+ },
+ )
+
+ checkpoint_activations: bool = field(
+ default=False,
+ metadata={"help": "recompute activations and save memory for extra compute"},
+ )
+ ddp_backend: str = II("distributed_training.ddp_backend")
+
+ prediction_mode: str = "lin_softmax"
+ eval_prediction_mode: Optional[str] = None
+ conv_kernel: int = -1
+ conv_stride: int = 1
+ two_convs: bool = False
+ extreme_factor: float = 1.0
+
+ conv_feature_layers: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "string describing convolutional feature extraction layers in form of a python list that contains "
+ "[(dim, kernel_size, stride), ...]"
+ },
+ )
+
+ mixup_prob: float = 1.0
+ source_mixup: float = -1
+ same_mixup: bool = True
+ label_mixup: bool = False
+
+ gain_mode: str = "none"
+
+
+@register_model("audio_classification", dataclass=AudioClassificationConfig)
+class AudioClassificationModel(BaseFairseqModel):
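+ """Audio classifier built on top of a pretrained data2vec / wav2vec 2.0 encoder.
+
+ Encoder outputs are pooled according to cfg.prediction_mode and trained with
+ binary cross-entropy, so multi-label targets are supported.
+ """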
+ def __init__(self, cfg: AudioClassificationConfig, num_classes):
+ super().__init__()
+
+ self.apply_mask = cfg.apply_mask
+ self.cfg = cfg
+
+ arg_overrides = {
+ "dropout": cfg.dropout,
+ "activation_dropout": cfg.activation_dropout,
+ "dropout_input": cfg.dropout_input,
+ "attention_dropout": cfg.attention_dropout,
+ "mask_length": cfg.mask_length,
+ "mask_prob": cfg.mask_prob,
+ "require_same_masks": getattr(cfg, "require_same_masks", True),
+ "mask_dropout": getattr(cfg, "mask_dropout", 0),
+ "mask_selection": cfg.mask_selection,
+ "mask_other": cfg.mask_other,
+ "no_mask_overlap": cfg.no_mask_overlap,
+ "mask_channel_length": cfg.mask_channel_length,
+ "mask_channel_prob": cfg.mask_channel_prob,
+ "mask_channel_before": cfg.mask_channel_before,
+ "mask_channel_selection": cfg.mask_channel_selection,
+ "mask_channel_other": cfg.mask_channel_other,
+ "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
+ "encoder_layerdrop": cfg.layerdrop,
+ "feature_grad_mult": cfg.feature_grad_mult,
+ "checkpoint_activations": cfg.checkpoint_activations,
+ "offload_activations": cfg.offload_activations,
+ "min_params_to_wrap": cfg.min_params_to_wrap,
+ "mixup": -1,
+ }
+
+ if cfg.conv_feature_layers is not None:
+ arg_overrides["conv_feature_layers"] = cfg.conv_feature_layers
+
+ if cfg.d2v_args is None:
+ state = checkpoint_utils.load_checkpoint_to_cpu(
+ cfg.model_path, arg_overrides
+ )
+ d2v_args = state.get("cfg", None)
+ if d2v_args is None:
+ d2v_args = convert_namespace_to_omegaconf(state["args"])
+ d2v_args.criterion = None
+ d2v_args.lr_scheduler = None
+ cfg.d2v_args = d2v_args
+
+ logger.info(d2v_args)
+
+ else:
+ state = None
+ d2v_args = cfg.d2v_args
+
+ model_normalized = d2v_args.task.get(
+ "normalize", d2v_args.model.get("normalize", False)
+ )
+ assert cfg.normalize == model_normalized, (
+ "Fine-tuning works best when data normalization is the same. "
+ "Please check that --normalize is set or unset for both pre-training and here"
+ )
+
+ if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations:
+ with open_dict(d2v_args):
+ d2v_args.model.checkpoint_activations = cfg.checkpoint_activations
+
+ d2v_args.task.data = cfg.data
+ task = tasks.setup_task(d2v_args.task)
+ model = task.build_model(d2v_args.model, from_checkpoint=True)
+
+ model.remove_pretraining_modules()
+
+ if state is not None and not cfg.no_pretrained_weights:
+ self.load_model_weights(state, model, cfg)
+
+ d = d2v_args.model.encoder_embed_dim
+
+ self.d2v_model = model
+
+ self.final_dropout = nn.Dropout(cfg.final_dropout)
+ self.freeze_finetune_updates = cfg.freeze_finetune_updates
+ self.num_updates = 0
+
+ for p in self.parameters():
+ p.param_group = "pretrained"
+
+ if cfg.prediction_mode == "proj_avg_proj":
+ self.proj = nn.Linear(d, d * 2)
+ self.proj2 = nn.Linear(d * 2, num_classes)
+
+ for p in self.proj.parameters():
+ p.param_group = "projection"
+ for p in self.proj2.parameters():
+ p.param_group = "projection"
+ elif self.cfg.prediction_mode == "summary_proj":
+ self.proj = nn.Linear(d // 3, num_classes)
+ for p in self.proj.parameters():
+ p.param_group = "projection"
+ elif self.cfg.conv_kernel > 1 and not self.cfg.two_convs:
+ self.proj = nn.Sequential(
+ TransposeLast(),
+ nn.Conv1d(d, num_classes, kernel_size=self.cfg.conv_kernel, stride=self.cfg.conv_stride),
+ TransposeLast(),
+ )
+ for p in self.proj.parameters():
+ p.param_group = "projection"
+ elif self.cfg.conv_kernel > 0 and self.cfg.two_convs:
+ self.proj = nn.Sequential(
+ TransposeLast(),
+ nn.Conv1d(d, d, kernel_size=self.cfg.conv_kernel, stride=self.cfg.conv_stride),
+ TransposeLast(),
+ nn.GELU(),
+ nn.Linear(d, num_classes),
+ )
+ for p in self.proj.parameters():
+ p.param_group = "projection"
+ else:
+ self.proj = nn.Linear(d, num_classes)
+ for p in self.proj.parameters():
+ p.param_group = "projection"
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ super().upgrade_state_dict_named(state_dict, name)
+ return state_dict
+
+ @classmethod
+ def build_model(cls, cfg: AudioClassificationConfig, task: FairseqTask):
+ """Build a new model instance."""
+
+ assert hasattr(task, "labels"), f"Task {task} must have an attribute 'labels'"
+
+ return cls(cfg, len(task.labels))
+
+ def load_model_weights(self, state, model, cfg):
+ if cfg.ddp_backend == "fully_sharded":
+ from fairseq.distributed import FullyShardedDataParallel
+
+ for name, module in model.named_modules():
+ if "encoder.layers" in name and len(name.split(".")) == 3:
+ # Transformer layers get special handling: load their weights one by one.
+ # Loading everything at once would not be memory efficient and may cause OOM.
+ new_dict = {
+ k.replace(name + ".", ""): v
+ for (k, v) in state["model"].items()
+ if name + "." in k
+ }
+ assert isinstance(module, FullyShardedDataParallel)
+ with module.summon_full_params():
+ module.load_state_dict(new_dict, strict=True)
+ module._reset_lazy_init()
+
+ # Once layers are loaded, filter them out and load everything else.
+ r = re.compile(r"encoder.layers.\d.")
+ filtered_list = list(filter(r.match, state["model"].keys()))
+
+ new_big_dict = {
+ k: v for (k, v) in state["model"].items() if k not in filtered_list
+ }
+
+ model.load_state_dict(new_big_dict, strict=False)
+ else:
+ if "_ema" in state["model"]:
+ del state["model"]["_ema"]
+ model.load_state_dict(state["model"], strict=False)
+
+ def set_num_updates(self, num_updates):
+ """Set the number of parameters updates."""
+ super().set_num_updates(num_updates)
+ self.num_updates = num_updates
+
+ def compute_gain(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"):
+ if fs == 16000:
+ n_fft = 2048
+ elif fs == 44100:
+ n_fft = 4096
+ else:
+ raise Exception("Invalid fs {}".format(fs))
+ stride = n_fft // 2
+
+ def a_weight(fs, n_fft, min_db=-80.0):
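+ # A-weighting curve in dB, evaluated at each FFT bin frequency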
+ freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
+ freq_sq = np.power(freq, 2)
+ freq_sq[0] = 1.0
+ weight = 2.0 + 20.0 * (
+ 2 * np.log10(12194)
+ + 2 * np.log10(freq_sq)
+ - np.log10(freq_sq + 12194 ** 2)
+ - np.log10(freq_sq + 20.6 ** 2)
+ - 0.5 * np.log10(freq_sq + 107.7 ** 2)
+ - 0.5 * np.log10(freq_sq + 737.9 ** 2)
+ )
+ weight = np.maximum(weight, min_db)
+
+ return weight
+
+ gain = []
+ for i in range(0, len(sound) - n_fft + 1, stride):
+ if mode == "RMSE":
+ g = np.mean(sound[i : i + n_fft] ** 2)
+ elif mode == "A_weighting":
+ spec = np.fft.rfft(np.hanning(n_fft + 1)[:-1] * sound[i : i + n_fft])
+ power_spec = np.abs(spec) ** 2
+ a_weighted_spec = power_spec * np.power(10, a_weight(fs, n_fft) / 10)
+ g = np.sum(a_weighted_spec)
+ else:
+ raise Exception("Invalid mode {}".format(mode))
+ gain.append(g)
+
+ gain = np.array(gain)
+ gain = np.maximum(gain, np.power(10, min_db / 10))
+ gain_db = 10 * np.log10(gain)
+
+ return gain_db
+
+ # adapted from https://github.com/mil-tokyo/bc_learning_sound/blob/master/utils.py
+ def compute_gain_torch(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"):
+ if fs == 16000:
+ n_fft = 2048
+ elif fs == 44100:
+ n_fft = 4096
+ else:
+ raise Exception("Invalid fs {}".format(fs))
+
+ if mode == "A_weighting":
+ if not hasattr(self, "a_weight"):
+ self.a_weight = {}
+
+ if fs not in self.a_weight:
+
+ def a_weight(fs, n_fft, min_db=-80.0):
+ freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
+ freq_sq = freq ** 2
+ freq_sq[0] = 1.0
+ weight = 2.0 + 20.0 * (
+ 2 * np.log10(12194)
+ + 2 * np.log10(freq_sq)
+ - np.log10(freq_sq + 12194 ** 2)
+ - np.log10(freq_sq + 20.6 ** 2)
+ - 0.5 * np.log10(freq_sq + 107.7 ** 2)
+ - 0.5 * np.log10(freq_sq + 737.9 ** 2)
+ )
+ weight = np.maximum(weight, min_db)
+
+ return weight
+
+ self.a_weight[fs] = torch.from_numpy(
+ np.power(10, a_weight(fs, n_fft, min_db) / 10)
+ ).to(device=sound.device)
+
+ sound = sound.unfold(-1, n_fft, n_fft // 2)
+
+ if mode == "RMSE":
+ sound = sound ** 2
+ g = sound.mean(-1)
+ elif mode == "A_weighting":
+ w = torch.hann_window(n_fft, device=sound.device) * sound
+ spec = torch.fft.rfft(w)
+ power_spec = spec.abs() ** 2
+ a_weighted_spec = power_spec * self.a_weight[fs]
+ g = a_weighted_spec.sum(-1)
+ else:
+ raise Exception("Invalid mode {}".format(mode))
+
+ gain = torch.maximum(g, torch.tensor(10 ** (min_db / 10), device=g.device))
+ gain_db = 10 * torch.log10(gain)
+
+ return gain_db
+
+ def forward(self, source, padding_mask, label=None, **kwargs):
+
+ if self.cfg.source_mixup >= 0 and self.training and self.cfg.mixup_prob > 0:
+ with torch.no_grad():
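+ # waveform-level mixup between random pairs of samples; if a gain_mode is
+ # set, mixing weights are corrected by each clip's estimated loudness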
+ mixed_source = source
+ mix_mask = None
+ if self.cfg.mixup_prob < 1:
+ mix_mask = (
+ torch.empty((source.size(0),), device=source.device)
+ .bernoulli_(self.cfg.mixup_prob)
+ .bool()
+ )
+ mixed_source = source[mix_mask]
+
+ r = (
+ torch.FloatTensor(
+ 1 if self.cfg.same_mixup else mixed_source.size(0)
+ )
+ .uniform_(max(1e-6, self.cfg.source_mixup), 1)
+ .to(dtype=source.dtype, device=source.device)
+ )
+
+ mixup_perm = torch.randperm(source.size(0))
+ s2 = source[mixup_perm]
+
+ if self.cfg.gain_mode == "none":
+ p = r.unsqueeze(-1)
+ if mix_mask is not None:
+ s2 = s2[mix_mask]
+ else:
+ if self.cfg.gain_mode == "naive_rms":
+ G1 = source.pow(2).mean(dim=-1).sqrt()
+ else:
+ G1, _ = self.compute_gain_torch(
+ source, mode=self.cfg.gain_mode
+ ).max(-1)
+ G1 = G1.to(dtype=source.dtype)
+
+ G2 = G1[mixup_perm]
+
+ if mix_mask is not None:
+ G1 = G1[mix_mask]
+ G2 = G2[mix_mask]
+ s2 = s2[mix_mask]
+
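+ # turn the dB gain difference into a linear amplitude ratio so the mix
+ # respects the sampled loudness balance r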
+ p = 1 / (1 + 10 ** ((G1 - G2) / 20) * (1 - r) / r)
+ p = p.unsqueeze(-1)
+
+ mixed = (p * mixed_source) + (1 - p) * s2
+
+ if mix_mask is None:
+ source = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)
+ else:
+ source[mix_mask] = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)
+
+ if label is not None and self.cfg.label_mixup:
+ r = r.unsqueeze(-1)
+ if mix_mask is None:
+ label = label * r + (1 - r) * label[mixup_perm]
+ else:
+ label[mix_mask] = (
+ label[mix_mask] * r + (1 - r) * label[mixup_perm][mix_mask]
+ )
+
+ d2v_args = {
+ "source": source,
+ "padding_mask": padding_mask,
+ "mask": self.apply_mask and self.training,
+ }
+
+ ft = self.freeze_finetune_updates <= self.num_updates
+
+ with torch.no_grad() if not ft else contextlib.ExitStack():
+ res = self.d2v_model.extract_features(**d2v_args)
+
+ x = res["x"]
+ padding_mask = res["padding_mask"]
+ if padding_mask is not None:
+ x[padding_mask] = 0
+
+ x = self.final_dropout(x)
+
+ if self.training or (
+ self.cfg.eval_prediction_mode is None or self.cfg.eval_prediction_mode == ""
+ ):
+ prediction_mode = self.cfg.prediction_mode
+ else:
+ prediction_mode = self.cfg.eval_prediction_mode
+
+ if prediction_mode == "average_before":
+ x = x.mean(dim=1)
+
+ if prediction_mode not in ("summary_mha", "summary_proj", "cls"):
+ x = self.proj(x)
+
+ logits = True
+ if prediction_mode == "lin_softmax":
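+ # linear-softmax pooling in log space: log(sum_t p_t^2 / sum_t p_t),
+ # converted back to a logit by the last line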
+ x = F.logsigmoid(x.float())
+ x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x, dim=1)
+ x = x.clamp(max=0)
+ x = x - torch.log(-(torch.expm1(x)))
+ elif prediction_mode == "extremized_odds":
+ x = x.float().sum(dim=1)
+ x = x * self.cfg.extreme_factor
+ elif prediction_mode == "average_before":
+ x = x.float()
+ elif prediction_mode == "average":
+ x = x.float().mean(dim=1)
+ elif prediction_mode == "average_sigmoid":
+ x = torch.sigmoid(x.float())
+ x = x.mean(dim=1)
+ logits = False
+ elif prediction_mode == "max":
+ x, _ = x.float().max(dim=1)
+ elif prediction_mode == "max_sigmoid":
+ x = torch.sigmoid(x.float())
+ x, _ = x.float().max(dim=1)
+ logits = False
+ elif prediction_mode == "proj_avg_proj":
+ x = x.mean(dim=1)
+ x = self.proj2(x)
+ elif prediction_mode in ("summary_mha", "summary_proj"):
+ x = self.d2v_model.summary(
+ x, padding_mask, proj=prediction_mode == "summary_proj"
+ )
+ x = x.type_as(source)
+ x = self.proj(x)
+ elif prediction_mode == "cls":
+ x = x[:, 0]
+ x = self.proj(x)
+ else:
+ raise Exception(f"unknown prediction mode {prediction_mode}")
+
+ if label is None:
+ return torch.sigmoid(x) if logits else x
+
+ x = torch.nan_to_num(x)
+
+ if logits:
+ loss = F.binary_cross_entropy_with_logits(
+ x, label.float(), reduction="none"
+ )
+ else:
+ loss = F.binary_cross_entropy(x, label.float(), reduction="none")
+
+ result = {
+ "losses": {
+ "main": loss,
+ },
+ "sample_size": label.sum(),
+ }
+
+ if not self.training:
+ result["_predictions"] = torch.sigmoid(x) if logits else x
+ result["_targets"] = label
+
+ return result
diff --git a/examples/data2vec/models/data2vec2.py b/examples/data2vec/models/data2vec2.py
new file mode 100644
index 000000000..0c61b3708
--- /dev/null
+++ b/examples/data2vec/models/data2vec2.py
@@ -0,0 +1,813 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import math
+from dataclasses import dataclass, field
+from typing import Optional, Callable
+from functools import partial
+import numpy as np
+
+from omegaconf import II
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from fairseq.modules import EMAModule, EMAModuleConfig
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model
+
+from examples.data2vec.data.modality import Modality
+
+from examples.data2vec.models.modalities.base import (
+ MaskSeed,
+ D2vModalityConfig,
+ ModalitySpecificEncoder,
+ get_annealed_rate,
+)
+from examples.data2vec.models.modalities.modules import (
+ D2vDecoderConfig,
+ AltBlock,
+ Decoder1d,
+)
+
+from examples.data2vec.models.modalities.audio import (
+ D2vAudioConfig,
+ AudioEncoder,
+)
+from examples.data2vec.models.modalities.images import (
+ D2vImageConfig,
+ ImageEncoder,
+)
+from examples.data2vec.models.modalities.text import (
+ D2vTextConfig,
+ TextEncoder,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class D2vModalitiesConfig(FairseqDataclass):
+ audio: D2vAudioConfig = D2vAudioConfig()
+ image: D2vImageConfig = D2vImageConfig()
+ text: D2vTextConfig = D2vTextConfig()
+
+
+@dataclass
+class Data2VecMultiConfig(FairseqDataclass):
+
+ loss_beta: float = field(
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
+ )
+ loss_scale: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
+ },
+ )
+
+ depth: int = 8
+ start_drop_path_rate: float = 0
+ end_drop_path_rate: float = 0
+ num_heads: int = 12
+ norm_eps: float = 1e-6
+ norm_affine: bool = True
+ encoder_dropout: float = 0.1
+ post_mlp_drop: float = 0.1
+ attention_dropout: float = 0.1
+ activation_dropout: float = 0.0
+ dropout_input: float = 0.0
+ layerdrop: float = 0.0
+ embed_dim: int = 768
+ mlp_ratio: float = 4
+ layer_norm_first: bool = False
+
+ average_top_k_layers: int = field(
+ default=8, metadata={"help": "how many layers to average"}
+ )
+
+ end_of_block_targets: bool = False
+
+ clone_batch: int = 1
+
+ layer_norm_target_layer: bool = False
+ batch_norm_target_layer: bool = False
+ instance_norm_target_layer: bool = False
+ instance_norm_targets: bool = False
+ layer_norm_targets: bool = False
+
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
+ ema_same_dtype: bool = True
+ log_norms: bool = True
+ ema_end_decay: float = field(
+ default=0.9999, metadata={"help": "final ema decay rate"}
+ )
+
+ # when to finish annealing ema decay rate
+ ema_anneal_end_step: int = II("optimization.max_update")
+
+ ema_encoder_only: bool = field(
+ default=True,
+ metadata={
+ "help": "whether to momentum update only the shared transformer encoder"
+ },
+ )
+
+ max_update: int = II("optimization.max_update")
+
+ modalities: D2vModalitiesConfig = D2vModalitiesConfig()
+
+ shared_decoder: Optional[D2vDecoderConfig] = None
+
+ min_target_var: float = field(
+ default=0.1, metadata={"help": "stop training if target var falls below this"}
+ )
+ min_pred_var: float = field(
+ default=0.01,
+ metadata={"help": "stop training if prediction var falls below this"},
+ )
+
+ supported_modality: Optional[Modality] = None
+ mae_init: bool = False
+
+ seed: int = II("common.seed")
+
+ skip_ema: bool = False
+
+ cls_loss: float = 0
+ recon_loss: float = 0
+ d2v_loss: float = 1
+
+ decoder_group: bool = False
+
+
+@register_model("data2vec_multi", dataclass=Data2VecMultiConfig)
+class Data2VecMultiModel(BaseFairseqModel):
+ def make_modality_encoder(
+ self,
+ cfg: D2vModalityConfig,
+ embed_dim: int,
+ make_block: Callable[[float], nn.ModuleList],
+ norm_layer: Callable[[int], nn.LayerNorm],
+ layer_norm_first: bool,
+ alibi_biases,
+ task,
+ ) -> ModalitySpecificEncoder:
+ if cfg.type == Modality.AUDIO:
+ enc_cls = AudioEncoder
+ elif cfg.type == Modality.IMAGE:
+ enc_cls = ImageEncoder
+ elif cfg.type == Modality.TEXT:
+ enc_cls = TextEncoder
+ if hasattr(task, "text_task"):
+ task = task.text_task
+ else:
+ raise Exception(f"unsupported modality {cfg.type}")
+
+ return enc_cls(
+ cfg,
+ embed_dim,
+ make_block,
+ norm_layer,
+ layer_norm_first,
+ alibi_biases,
+ task,
+ )
+
+ def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None):
+ super().__init__()
+ self.cfg = cfg
+ self.modalities = modalities
+ self.task = task
+
+ make_layer_norm = partial(
+ nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
+ )
+
+ def make_block(drop_path, dim=None, heads=None):
+ return AltBlock(
+ cfg.embed_dim if dim is None else dim,
+ cfg.num_heads if heads is None else heads,
+ cfg.mlp_ratio,
+ qkv_bias=True,
+ drop=cfg.encoder_dropout,
+ attn_drop=cfg.attention_dropout,
+ mlp_drop=cfg.activation_dropout,
+ post_mlp_drop=cfg.post_mlp_drop,
+ drop_path=drop_path,
+ norm_layer=make_layer_norm,
+ layer_norm_first=cfg.layer_norm_first,
+ ffn_targets=not cfg.end_of_block_targets,
+ )
+
+ self.alibi_biases = {}
+ self.modality_encoders = nn.ModuleDict()
+ for mod in self.modalities:
+ mod_cfg = getattr(cfg.modalities, mod.name.lower())
+ enc = self.make_modality_encoder(
+ mod_cfg,
+ cfg.embed_dim,
+ make_block,
+ make_layer_norm,
+ cfg.layer_norm_first,
+ self.alibi_biases,
+ task,
+ )
+ self.modality_encoders[mod.name] = enc
+
+ self.ema = None
+
+ self.average_top_k_layers = cfg.average_top_k_layers
+ self.loss_beta = cfg.loss_beta
+ self.loss_scale = cfg.loss_scale
+
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
+
+ dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
+
+ self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
+
+ self.norm = None
+ if cfg.layer_norm_first:
+ self.norm = make_layer_norm(cfg.embed_dim)
+
+ if self.cfg.mae_init:
+ self.apply(self._init_weights)
+ else:
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+ self.apply(init_bert_params)
+
+ for mod_enc in self.modality_encoders.values():
+ mod_enc.reset_parameters()
+
+ if not skip_ema:
+ self.ema = self.make_ema_teacher(cfg.ema_decay)
+ self.shared_decoder = (
+ Decoder1d(cfg.shared_decoder, cfg.embed_dim)
+ if self.cfg.shared_decoder is not None
+ else None
+ )
+ if self.shared_decoder is not None:
+ self.shared_decoder.apply(self._init_weights)
+
+ self.recon_proj = None
+ if cfg.recon_loss > 0:
+ self.recon_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
+
+ for pn, p in self.named_parameters():
+ if len(p.shape) == 1 or pn.endswith(".bias") or "alibi_scale" in pn:
+ p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
+ if cfg.decoder_group and "decoder" in pn:
+ p.param_group = "decoder"
+
+ self.num_updates = 0
+
+ def _init_weights(self, m):
+
+ try:
+ from apex.normalization import FusedLayerNorm
+
+ fn = FusedLayerNorm
+ except ImportError:
+ fn = nn.LayerNorm
+
+ if isinstance(m, nn.Linear):
+ torch.nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm) or isinstance(m, fn):
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.no_grad()
+ def make_ema_teacher(self, ema_decay):
+ ema_config = EMAModuleConfig(
+ ema_decay=ema_decay,
+ ema_fp32=True,
+ log_norms=self.cfg.log_norms,
+ add_missing_params=False,
+ )
+
+ model_copy = self.make_target_model()
+
+ return EMAModule(
+ model_copy,
+ ema_config,
+ copy_model=False,
+ )
+
+ def make_target_model(self):
+ logger.info("making target model")
+
+ model_copy = Data2VecMultiModel(
+ self.cfg, self.modalities, skip_ema=True, task=self.task
+ )
+
+ if self.cfg.ema_encoder_only:
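+ # EMA-track only the shared transformer blocks; modality encoders are
+ # reused from the student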
+ model_copy = model_copy.blocks
+ for p_s, p_t in zip(self.blocks.parameters(), model_copy.parameters()):
+ p_t.data.copy_(p_s.data)
+ else:
+ for p_s, p_t in zip(self.parameters(), model_copy.parameters()):
+ p_t.data.copy_(p_s.data)
+
+ for mod_enc in model_copy.modality_encoders.values():
+ mod_enc.decoder = None
+ if not mod_enc.modality_cfg.ema_local_encoder:
+ mod_enc.local_encoder = None
+ mod_enc.project_features = None
+
+ model_copy.requires_grad_(False)
+ return model_copy
+
+ def set_num_updates(self, num_updates):
+ super().set_num_updates(num_updates)
+
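+ # avoid stepping the EMA teacher when the update counter is not advancing,
+ # e.g. right after restoring from a checkpoint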
+ if self.ema is not None and (
+ (self.num_updates == 0 and num_updates > 1)
+ or self.num_updates >= num_updates
+ ):
+ pass
+ elif self.training and self.ema is not None:
+ ema_weight_decay = None
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
+ if num_updates >= self.cfg.ema_anneal_end_step:
+ decay = self.cfg.ema_end_decay
+ else:
+ decay = get_annealed_rate(
+ self.cfg.ema_decay,
+ self.cfg.ema_end_decay,
+ num_updates,
+ self.cfg.ema_anneal_end_step,
+ )
+ self.ema.set_decay(decay, weight_decay=ema_weight_decay)
+ if self.ema.get_decay() < 1:
+ self.ema.step(self.blocks if self.cfg.ema_encoder_only else self)
+
+ self.num_updates = num_updates
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ state = super().state_dict(destination, prefix, keep_vars)
+
+ if self.ema is not None:
+ state[prefix + "_ema"] = self.ema.fp32_params
+
+ return state
+
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+ k = prefix + "_ema"
+ if self.ema is not None:
+ assert k in state_dict
+ self.ema.restore(state_dict[k], True)
+ del state_dict[k]
+ elif k in state_dict:
+ del state_dict[k]
+
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+ @classmethod
+ def build_model(cls, cfg: Data2VecMultiConfig, task=None):
+ """Build a new model instance."""
+ if task is None or not hasattr(task, "supported_modalities"):
+ modalities = (
+ [cfg.supported_modality]
+ if cfg.supported_modality is not None
+ else [
+ Modality.AUDIO,
+ Modality.IMAGE,
+ Modality.TEXT,
+ ]
+ )
+ else:
+ modalities = task.supported_modalities
+ return cls(cfg, modalities, task=task, skip_ema=cfg.skip_ema)
+
+ def forward(
+ self,
+ source,
+ target=None,
+ id=None,
+ mode=None,
+ padding_mask=None,
+ mask=True,
+ features_only=False,
+ force_remove_masked=False,
+ remove_extra_tokens=True,
+ precomputed_mask=None,
+ ):
+ if mode is None:
+ assert self.cfg.supported_modality is not None
+ mode = self.cfg.supported_modality
+
+ if isinstance(mode, Modality):
+ mode = mode.name
+
+ feature_extractor = self.modality_encoders[mode]
+
+ mask_seeds = None
+ if id is not None:
+ mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)
+
+ extractor_out = feature_extractor(
+ source,
+ padding_mask,
+ mask,
+ remove_masked=not features_only or force_remove_masked,
+ clone_batch=self.cfg.clone_batch if not features_only else 1,
+ mask_seeds=mask_seeds,
+ precomputed_mask=precomputed_mask,
+ )
+
+ x = extractor_out["x"]
+ encoder_mask = extractor_out["encoder_mask"]
+ masked_padding_mask = extractor_out["padding_mask"]
+ masked_alibi_bias = extractor_out.get("alibi_bias", None)
+ alibi_scale = extractor_out.get("alibi_scale", None)
+
+ if self.dropout_input is not None:
+ x = self.dropout_input(x)
+
+ layer_results = []
+ for i, blk in enumerate(self.blocks):
+ if (
+ not self.training
+ or self.cfg.layerdrop == 0
+ or (np.random.random() > self.cfg.layerdrop)
+ ):
+ ab = masked_alibi_bias
+ if ab is not None and alibi_scale is not None:
+ scale = (
+ alibi_scale[i]
+ if alibi_scale.size(0) > 1
+ else alibi_scale.squeeze(0)
+ )
+ ab = ab * scale.type_as(ab)
+
+ x, lr = blk(
+ x,
+ padding_mask=masked_padding_mask,
+ alibi_bias=ab,
+ )
+ if features_only:
+ layer_results.append(lr)
+
+ if self.norm is not None:
+ x = self.norm(x)
+
+ if features_only:
+ if remove_extra_tokens:
+ x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
+ if masked_padding_mask is not None:
+ masked_padding_mask = masked_padding_mask[
+ :, feature_extractor.modality_cfg.num_extra_tokens :
+ ]
+
+ return {
+ "x": x,
+ "padding_mask": masked_padding_mask,
+ "layer_results": layer_results,
+ "mask": encoder_mask,
+ }
+
+ xs = []
+
+ if self.shared_decoder is not None:
+ dx = self.forward_decoder(
+ x,
+ feature_extractor,
+ self.shared_decoder,
+ encoder_mask,
+ )
+ xs.append(dx)
+ if feature_extractor.decoder is not None:
+ dx = self.forward_decoder(
+ x,
+ feature_extractor,
+ feature_extractor.decoder,
+ encoder_mask,
+ )
+ xs.append(dx)
+ orig_x = x
+
+ assert len(xs) > 0
+
+ p = next(self.ema.model.parameters())
+ device = x.device
+ dtype = x.dtype
+ ema_device = p.device
+ ema_dtype = p.dtype
+
+ if not self.cfg.ema_same_dtype:
+ dtype = ema_dtype
+
+ if ema_device != device or ema_dtype != dtype:
+ logger.info(f"adjusting ema dtype to {dtype} and device to {device}")
+ self.ema.model = self.ema.model.to(dtype=dtype, device=device)
+ ema_dtype = dtype
+
+ def to_device(d):
+ for k, p in d.items():
+ if isinstance(d[k], dict):
+ to_device(d[k])
+ else:
+ d[k] = p.to(device=device)
+
+ to_device(self.ema.fp32_params)
+ tm = self.ema.model
+
+ with torch.no_grad():
+ tm.eval()
+
+ if self.cfg.ema_encoder_only:
+ assert target is None
+ ema_input = extractor_out["local_features"]
+ ema_input = feature_extractor.contextualized_features(
+ ema_input.to(dtype=ema_dtype),
+ padding_mask,
+ mask=False,
+ remove_masked=False,
+ )
+ ema_blocks = tm
+ else:
+ ema_blocks = tm.blocks
+ if feature_extractor.modality_cfg.ema_local_encoder:
+ inp = (
+ target.to(dtype=ema_dtype)
+ if target is not None
+ else source.to(dtype=ema_dtype)
+ )
+ ema_input = tm.modality_encoders[mode](
+ inp,
+ padding_mask,
+ mask=False,
+ remove_masked=False,
+ )
+ else:
+ assert target is None
+ ema_input = extractor_out["local_features"]
+ ema_feature_enc = tm.modality_encoders[mode]
+ ema_input = ema_feature_enc.contextualized_features(
+ ema_input.to(dtype=ema_dtype),
+ padding_mask,
+ mask=False,
+ remove_masked=False,
+ )
+
+ ema_padding_mask = ema_input["padding_mask"]
+ ema_alibi_bias = ema_input.get("alibi_bias", None)
+ ema_alibi_scale = ema_input.get("alibi_scale", None)
+ ema_input = ema_input["x"]
+
+ y = []
+ ema_x = []
+ extra_tokens = feature_extractor.modality_cfg.num_extra_tokens
+ for i, blk in enumerate(ema_blocks):
+ ab = ema_alibi_bias
+ if ab is not None and ema_alibi_scale is not None:
+ scale = (
+ ema_alibi_scale[i]
+ if ema_alibi_scale.size(0) > 1
+ else ema_alibi_scale.squeeze(0)
+ )
+ ab = ab * scale.type_as(ab)
+
+ ema_input, lr = blk(
+ ema_input,
+ padding_mask=ema_padding_mask,
+ alibi_bias=ab,
+ )
+ y.append(lr[:, extra_tokens:])
+ ema_x.append(ema_input[:, extra_tokens:])
+
+ y = self.make_targets(y, self.average_top_k_layers)
+ orig_targets = y
+
+ if self.cfg.clone_batch > 1:
+ y = y.repeat_interleave(self.cfg.clone_batch, 0)
+
+ masked = encoder_mask.mask.unsqueeze(-1)
+ masked_b = encoder_mask.mask.bool()
+ y = y[masked_b]
+
+ if xs[0].size(1) == masked_b.size(1):
+ xs = [x[masked_b] for x in xs]
+ else:
+ xs = [x.reshape(-1, x.size(-1)) for x in xs]
+
+ sample_size = masked.sum().long()
+
+ result = {
+ "losses": {},
+ "sample_size": sample_size,
+ }
+
+
+ if self.cfg.cls_loss > 0:
+ assert extra_tokens > 0
+ cls_target = orig_targets.mean(dim=1)
+ if self.cfg.clone_batch > 1:
+ cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0)
+ cls_pred = x[:, extra_tokens - 1]
+ result["losses"]["cls"] = self.d2v_loss(cls_pred, cls_target) * (
+ self.cfg.cls_loss * sample_size
+ )
+
+ if self.cfg.recon_loss > 0:
+
+ with torch.no_grad():
+ target = feature_extractor.patchify(source)
+ mean = target.mean(dim=-1, keepdim=True)
+ var = target.var(dim=-1, keepdim=True)
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
+
+ if self.cfg.clone_batch > 1:
+ target = target.repeat_interleave(self.cfg.clone_batch, 0)
+
+ if masked_b is not None:
+ target = target[masked_b]
+
+ recon = xs[0]
+ if self.recon_proj is not None:
+ recon = self.recon_proj(recon)
+
+ result["losses"]["recon"] = (
+ self.d2v_loss(recon, target.float()) * self.cfg.recon_loss
+ )
+
+ if self.cfg.d2v_loss > 0:
+ for i, x in enumerate(xs):
+ reg_loss = self.d2v_loss(x, y)
+ n = f"{mode}_regression_{i}" if len(xs) > 1 else f"{mode}_regression"
+ result["losses"][n] = reg_loss * self.cfg.d2v_loss
+
+ suffix = "" if len(self.modalities) == 1 else f"_{mode}"
+ with torch.no_grad():
+ if encoder_mask is not None:
+ result["masked_pct"] = 1 - (
+ encoder_mask.ids_keep.size(1) / encoder_mask.ids_restore.size(1)
+ )
+ for i, x in enumerate(xs):
+ n = f"pred_var{suffix}_{i}" if len(xs) > 1 else f"pred_var{suffix}"
+ result[n] = self.compute_var(x.float())
+ if self.ema is not None:
+ for k, v in self.ema.logs.items():
+ result[k] = v
+
+ y = y.float()
+ result[f"target_var{suffix}"] = self.compute_var(y)
+
+ if self.num_updates > 5000:
+ if result[f"target_var{suffix}"] < self.cfg.min_target_var:
+ logger.error(
+ f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
+ )
+ raise Exception(
+ f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
+ )
+
+ for k in result.keys():
+ if k.startswith("pred_var") and result[k] < self.cfg.min_pred_var:
+ msg = f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
+ logger.error(msg)
+ raise Exception(msg)
+
+ result["ema_decay"] = self.ema.get_decay() * 1000
+
+ return result
+
+ def forward_decoder(
+ self,
+ x,
+ feature_extractor,
+ decoder,
+ mask_info,
+ ):
+ x = feature_extractor.decoder_input(x, mask_info)
+ x = decoder(*x)
+
+ return x
+
+ def d2v_loss(self, x, y):
+ x = x.view(-1, x.size(-1)).float()
+ y = y.view(-1, x.size(-1))
+
+ if self.loss_beta == 0:
+ loss = F.mse_loss(x, y, reduction="none")
+ else:
+ loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta)
+
+ if self.loss_scale is not None:
+ scale = self.loss_scale
+ else:
+ scale = 1 / math.sqrt(x.size(-1))
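+ # default: scale by 1/sqrt(dim) so the loss magnitude stays comparable across model sizes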
+
+ reg_loss = loss * scale
+
+ return reg_loss
+
+ def make_targets(self, y, num_layers):
+
+ with torch.no_grad():
+ target_layer_results = y[-num_layers:]
+
+ permuted = False
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
+ target_layer_results = [
+ tl.transpose(1, 2) for tl in target_layer_results # BTC -> BCT
+ ]
+ permuted = True
+ if self.cfg.batch_norm_target_layer:
+ target_layer_results = [
+ F.batch_norm(
+ tl.float(), running_mean=None, running_var=None, training=True
+ )
+ for tl in target_layer_results
+ ]
+ if self.cfg.instance_norm_target_layer:
+ target_layer_results = [
+ F.instance_norm(tl.float()) for tl in target_layer_results
+ ]
+ if permuted:
+ target_layer_results = [
+ tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
+ ]
+ if self.cfg.layer_norm_target_layer:
+ target_layer_results = [
+ F.layer_norm(tl.float(), tl.shape[-1:])
+ for tl in target_layer_results
+ ]
+
+ y = target_layer_results[0].float()
+ for tl in target_layer_results[1:]:
+ y.add_(tl.float())
+ y = y.div_(len(target_layer_results))
+
+ if self.cfg.layer_norm_targets:
+ y = F.layer_norm(y, y.shape[-1:])
+
+ if self.cfg.instance_norm_targets:
+ y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)
+
+ return y
+
+ @staticmethod
+ def compute_var(y):
+ y = y.view(-1, y.size(-1))
+ if dist.is_initialized():
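+ # reduce per-feature count, sum and sum of squares across workers, then
+ # apply the unbiased identity var = (zss - zs**2 / zc) / (zc - 1)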
+ zc = torch.tensor(y.size(0)).cuda()
+ zs = y.sum(dim=0)
+ zss = (y**2).sum(dim=0)
+
+ dist.all_reduce(zc)
+ dist.all_reduce(zs)
+ dist.all_reduce(zss)
+
+ var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
+ return torch.sqrt(var + 1e-6).mean()
+ else:
+ return torch.sqrt(y.var(dim=0) + 1e-6).mean()
+
+ def extract_features(
+ self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
+ ):
+ res = self.forward(
+ source,
+ mode=mode,
+ padding_mask=padding_mask,
+ mask=mask,
+ features_only=True,
+ remove_extra_tokens=remove_extra_tokens,
+ )
+ return res
+
+ def remove_pretraining_modules(self, modality=None, keep_decoder=False):
+ self.ema = None
+ self.cfg.clone_batch = 1
+ self.recon_proj = None
+
+ if not keep_decoder:
+ self.shared_decoder = None
+
+ modality = modality.lower() if modality is not None else None
+ for k in list(self.modality_encoders.keys()):
+ if modality is not None and k.lower() != modality:
+ del self.modality_encoders[k]
+ else:
+ self.modality_encoders[k].remove_pretraining_modules(
+ keep_decoder=keep_decoder
+ )
+ if not keep_decoder:
+ self.modality_encoders[k].decoder = None
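
A minimal feature-extraction sketch for the model above. The checkpoint path, modality name, and input shape here are assumptions, not part of this patch:

```python
import torch
from fairseq import checkpoint_utils

# load a pretrained data2vec 2.0 checkpoint together with its task and config
models, _cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
    ["/path/to/base_libri.pt"]
)
model = models[0].eval()

wav = torch.randn(1, 16000)  # one second of 16 kHz audio
with torch.no_grad():
    res = model.extract_features(wav, mode="AUDIO", mask=False)

features = res["x"]  # B x T x C contextualized representations
padding_mask = res["padding_mask"]
```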
diff --git a/examples/data2vec/models/data2vec_image_classification.py b/examples/data2vec/models/data2vec_image_classification.py
new file mode 100644
index 000000000..851c9ce45
--- /dev/null
+++ b/examples/data2vec/models/data2vec_image_classification.py
@@ -0,0 +1,143 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# The code in this file is adapted from the BeiT implementation which can be found here:
+# https://github.com/microsoft/unilm/tree/master/beit
+
+import logging
+
+from dataclasses import dataclass
+from typing import Any
+
+from omegaconf import II, MISSING
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fairseq import checkpoint_utils, tasks
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Data2VecImageClassificationConfig(FairseqDataclass):
+ model_path: str = MISSING
+ no_pretrained_weights: bool = False
+ num_classes: int = 1000
+ mixup: float = 0.8
+ cutmix: float = 1.0
+ label_smoothing: float = 0.1
+
+ pretrained_model_args: Any = None
+ data: str = II("task.data")
+
+
+@register_model(
+ "data2vec_image_classification", dataclass=Data2VecImageClassificationConfig
+)
+class Data2VecImageClassificationModel(BaseFairseqModel):
+ def __init__(self, cfg: Data2VecImageClassificationConfig):
+ super().__init__()
+ self.cfg = cfg
+
+ if cfg.pretrained_model_args is None:
+ state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
+ pretrained_args = state.get("cfg", None)
+ pretrained_args.criterion = None
+ pretrained_args.lr_scheduler = None
+ cfg.pretrained_model_args = pretrained_args
+
+ logger.info(pretrained_args)
+ else:
+ state = None
+ pretrained_args = cfg.pretrained_model_args
+
+ pretrained_args.task.data = cfg.data
+ task = tasks.setup_task(pretrained_args.task)
+ model = task.build_model(pretrained_args.model, from_checkpoint=True)
+
+ model.remove_pretraining_modules()
+
+ self.model = model
+
+ if state is not None and not cfg.no_pretrained_weights:
+ self.load_model_weights(state, model, cfg)
+
+ self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim)
+ self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)
+
+ self.head.weight.data.mul_(1e-3)
+ self.head.bias.data.mul_(1e-3)
+
+ self.mixup_fn = None
+
+ if cfg.mixup > 0 or cfg.cutmix > 0:
+ from timm.data import Mixup
+
+ self.mixup_fn = Mixup(
+ mixup_alpha=cfg.mixup,
+ cutmix_alpha=cfg.cutmix,
+ cutmix_minmax=None,
+ prob=1.0,
+ switch_prob=0.5,
+ mode="batch",
+ label_smoothing=cfg.label_smoothing,
+ num_classes=cfg.num_classes,
+ )
+
+ def load_model_weights(self, state, model, cfg):
+ if "_ema" in state["model"]:
+ del state["model"]["_ema"]
+ model.load_state_dict(state["model"], strict=True)
+
+ @classmethod
+ def build_model(cls, cfg: Data2VecImageClassificationConfig, task=None):
+ """Build a new model instance."""
+
+ return cls(cfg)
+
+ def forward(
+ self,
+ img,
+ label=None,
+ ):
+ if self.training and self.mixup_fn is not None and label is not None:
+ img, label = self.mixup_fn(img, label)
+
+ x = self.model(img, mask=False)
+ x = x[:, 1:]
+ x = self.fc_norm(x.mean(1))
+ x = self.head(x)
+
+ if label is None:
+ return x
+
+ if self.training and self.mixup_fn is not None:
+ loss = -label * F.log_softmax(x.float(), dim=-1)
+ else:
+ loss = F.cross_entropy(
+ x.float(),
+ label,
+ label_smoothing=self.cfg.label_smoothing if self.training else 0,
+ reduction="none",
+ )
+
+ result = {
+ "losses": {"regression": loss},
+ "sample_size": img.size(0),
+ }
+
+ if not self.training:
+ with torch.no_grad():
+ pred = x.argmax(-1)
+ correct = (pred == label).sum()
+ result["correct"] = correct
+
+ return result
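
When mixup is active the labels are soft class distributions, so the model computes `-label * F.log_softmax(x)` rather than `F.cross_entropy` over class indices. Summed over classes, the two agree for one-hot targets; a quick illustrative sanity check:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
hard = torch.randint(0, 10, (4,))
soft = F.one_hot(hard, num_classes=10).float()  # one-hot "soft" targets

soft_loss = (-soft * F.log_softmax(logits, dim=-1)).sum(dim=-1)
hard_loss = F.cross_entropy(logits, hard, reduction="none")
assert torch.allclose(soft_loss, hard_loss, atol=1e-6)
```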
diff --git a/examples/data2vec/models/data2vec_text_classification.py b/examples/data2vec/models/data2vec_text_classification.py
new file mode 100644
index 000000000..e787b916d
--- /dev/null
+++ b/examples/data2vec/models/data2vec_text_classification.py
@@ -0,0 +1,141 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# The code in this file is adapted from the BeiT implementation which can be found here:
+# https://github.com/microsoft/unilm/tree/master/beit
+
+import logging
+
+from dataclasses import dataclass
+from typing import Any
+
+from omegaconf import II, MISSING
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fairseq import checkpoint_utils, tasks
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model
+from fairseq.models.roberta.model import RobertaClassificationHead
+
+from examples.data2vec.data.modality import Modality
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Data2VecTextClassificationConfig(FairseqDataclass):
+ pooler_dropout: float = 0.0
+ pooler_activation_fn: str = "tanh"
+ quant_noise_pq: int = 0
+ quant_noise_pq_block_size: int = 8
+ spectral_norm_classification_head: bool = False
+
+ model_path: str = MISSING
+ no_pretrained_weights: bool = False
+
+ pretrained_model_args: Any = None
+
+
+@register_model(
+ "data2vec_text_classification", dataclass=Data2VecTextClassificationConfig
+)
+class Data2VecTextClassificationModel(BaseFairseqModel):
+ def __init__(self, cfg: Data2VecTextClassificationConfig):
+ super().__init__()
+ self.cfg = cfg
+
+ if cfg.pretrained_model_args is None:
+ state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
+ pretrained_args = state.get("cfg", None)
+ pretrained_args.criterion = None
+ pretrained_args.lr_scheduler = None
+ cfg.pretrained_model_args = pretrained_args
+
+ logger.info(pretrained_args)
+ else:
+ state = None
+ pretrained_args = cfg.pretrained_model_args
+
+ task = tasks.setup_task(pretrained_args.task)
+ model = task.build_model(pretrained_args.model, from_checkpoint=True)
+
+ model.remove_pretraining_modules()
+
+ self.model = model
+
+ if state is not None and not cfg.no_pretrained_weights:
+ self.load_model_weights(state, model, cfg)
+
+ self.classification_heads = nn.ModuleDict()
+
+ def load_model_weights(self, state, model, cfg):
+ for k in list(state["model"].keys()):
+ if (
+ k.startswith("shared_decoder") or
+ k.startswith("_ema") or
+ "decoder" in k
+ ):
+ logger.info(f"Deleting {k} from checkpoint")
+ del state["model"][k]
+ model.load_state_dict(state["model"], strict=True)
+
+ @classmethod
+ def build_model(cls, cfg: Data2VecTextClassificationConfig, task=None):
+ """Build a new model instance."""
+
+ return cls(cfg)
+
+ def register_classification_head(
+ self, name, num_classes=None, inner_dim=None, **kwargs
+ ):
+ """Register a classification head."""
+ if name in self.classification_heads:
+ prev_num_classes = self.classification_heads[name].out_proj.out_features
+ prev_inner_dim = self.classification_heads[name].dense.out_features
+ if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
+ logger.warning(
+ 're-registering head "{}" with num_classes {} (prev: {}) '
+ "and inner_dim {} (prev: {})".format(
+ name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
+ )
+ )
+ embed_dim = self.cfg.pretrained_model_args.model.embed_dim
+ self.classification_heads[name] = RobertaClassificationHead(
+ input_dim=embed_dim,
+ inner_dim=inner_dim or embed_dim,
+ num_classes=num_classes,
+ activation_fn=self.cfg.pooler_activation_fn,
+ pooler_dropout=self.cfg.pooler_dropout,
+ q_noise=self.cfg.quant_noise_pq,
+ qn_block_size=self.cfg.quant_noise_pq_block_size,
+ do_spectral_norm=self.cfg.spectral_norm_classification_head,
+ )
+
+ def forward(
+ self,
+ source,
+ id,
+ padding_mask,
+ features_only=True,
+ remove_extra_tokens=True,
+ classification_head_name=None,
+ ):
+ encoder_out = self.model(
+ source,
+ id=id,
+ mode=Modality.TEXT,
+ padding_mask=padding_mask,
+ mask=False,
+ features_only=features_only,
+ remove_extra_tokens=remove_extra_tokens
+ )
+ logits = self.classification_heads[classification_head_name](encoder_out["x"])
+ return logits, encoder_out
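
A rough wiring sketch for the classifier above; `cfg`, `tokens`, `sample_ids`, and `padding_mask` are assumed to come from the surrounding task and are not defined by this patch:

```python
model = Data2VecTextClassificationModel(cfg)
model.register_classification_head("sentence_classification_head", num_classes=2)

logits, encoder_out = model(
    source=tokens,              # B x T token ids
    id=sample_ids,              # per-example ids, forwarded to the backbone
    padding_mask=padding_mask,  # B x T boolean padding mask
    classification_head_name="sentence_classification_head",
)
```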
diff --git a/examples/data2vec/models/data2vec_vision.py b/examples/data2vec/models/data2vec_vision.py
new file mode 100644
index 000000000..2f8989442
--- /dev/null
+++ b/examples/data2vec/models/data2vec_vision.py
@@ -0,0 +1,727 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# The code in this file is adapted from the BeiT implementation which can be found here:
+# https://github.com/microsoft/unilm/tree/master/beit
+
+import logging
+import math
+import numpy as np
+import random
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+from omegaconf import II
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from fairseq.modules import EMAModule, EMAModuleConfig
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Data2VecVisionConfig(FairseqDataclass):
+ layer_scale_init_value: float = field(
+ default=1e-4, metadata={"help": "rescale layer outputs, 0 to disable"}
+ )
+ num_mask_patches: int = field(
+ default=75,
+ metadata={"help": "number of the visual tokens/patches need be masked"},
+ )
+ min_mask_patches_per_block: int = 16
+ max_mask_patches_per_block: int = 196
+ image_size: int = 224
+ patch_size: int = 16
+ in_channels: int = 3
+
+ shared_rel_pos_bias: bool = True
+
+ drop_path: float = 0.1
+ attention_dropout: float = 0.0
+
+ depth: int = 12
+ embed_dim: int = 768
+ num_heads: int = 12
+ mlp_ratio: int = 4
+
+ loss_beta: float = field(
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
+ )
+ loss_scale: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
+ },
+ )
+ average_top_k_layers: int = field(
+ default=8, metadata={"help": "how many layers to average"}
+ )
+
+ end_of_block_targets: bool = True
+ layer_norm_target_layer: bool = False
+ instance_norm_target_layer: bool = False
+ batch_norm_target_layer: bool = False
+ instance_norm_targets: bool = False
+ layer_norm_targets: bool = False
+
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
+ ema_end_decay: float = field(
+ default=0.9999, metadata={"help": "final ema decay rate"}
+ )
+
+ # when to finish annealing ema decay rate
+ ema_anneal_end_step: int = II("optimization.max_update")
+
+ ema_transformer_only: bool = field(
+ default=True,
+ metadata={"help": "whether to momentum update only the transformer layers"},
+ )
+
+
+def get_annealed_rate(start, end, curr_step, total_steps):
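+ # linear anneal from `start` to `end`; e.g. with start=0.999, end=0.9999 and
+ # curr_step = total_steps / 2 this returns 0.9999 - 0.0009 * 0.5 = 0.99945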
+ r = end - start
+ pct_remaining = 1 - curr_step / total_steps
+ return end - r * pct_remaining
+
+
+@register_model("data2vec_vision", dataclass=Data2VecVisionConfig)
+class Data2VecVisionModel(BaseFairseqModel):
+ def __init__(self, cfg: Data2VecVisionConfig):
+ super().__init__()
+ self.cfg = cfg
+
+ self.ema = None
+
+ self.average_top_k_layers = cfg.average_top_k_layers
+ self.loss_beta = cfg.loss_beta
+ self.loss_scale = (
+ cfg.loss_scale
+ if cfg.loss_scale is not None
+ else 1 / math.sqrt(cfg.embed_dim)
+ )
+
+ self.patch_embed = PatchEmbed(
+ img_size=cfg.image_size,
+ patch_size=cfg.patch_size,
+ in_chans=cfg.in_channels,
+ embed_dim=cfg.embed_dim,
+ )
+
+ patch_size = self.patch_embed.patch_size
+ self.window_size = (
+ cfg.image_size // patch_size[0],
+ cfg.image_size // patch_size[1],
+ )
+
+ self.cls_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim))
+ self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim))
+
+ nn.init.trunc_normal_(self.cls_emb, 0.02)
+ nn.init.trunc_normal_(self.mask_emb, 0.02)
+
+ self.encoder = TransformerEncoder(cfg, self.patch_embed.patch_shape)
+
+ self.final_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
+ self.num_updates = 0
+
+ def make_ema_teacher(self):
+ ema_config = EMAModuleConfig(
+ ema_decay=self.cfg.ema_decay,
+ ema_fp32=True,
+ )
+ self.ema = EMAModule(
+ self.encoder if self.cfg.ema_transformer_only else self,
+ ema_config,
+ )
+
+ def set_num_updates(self, num_updates):
+ super().set_num_updates(num_updates)
+
+ if self.ema is None and self.final_proj is not None:
+ logger.info(f"making ema teacher")
+ self.make_ema_teacher()
+ elif self.training and self.ema is not None:
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
+ if num_updates >= self.cfg.ema_anneal_end_step:
+ decay = self.cfg.ema_end_decay
+ else:
+ decay = get_annealed_rate(
+ self.cfg.ema_decay,
+ self.cfg.ema_end_decay,
+ num_updates,
+ self.cfg.ema_anneal_end_step,
+ )
+ self.ema.set_decay(decay)
+ if self.ema.get_decay() < 1:
+ self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
+
+ self.num_updates = num_updates
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ state = super().state_dict(destination, prefix, keep_vars)
+
+ if self.ema is not None:
+ state[prefix + "_ema"] = self.ema.fp32_params
+
+ return state
+
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+ if self.ema is not None:
+ k = prefix + "_ema"
+ assert k in state_dict
+ self.ema.restore(state_dict[k], True)
+ del state_dict[k]
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+ @classmethod
+ def build_model(cls, cfg: Data2VecVisionConfig, task=None):
+ """Build a new model instance."""
+
+ return cls(cfg)
+
+ def make_mask(self, bsz, num_masks, min_masks, max_masks):
+ height, width = self.window_size
+
+ masks = np.zeros(shape=(bsz, height, width), dtype=int)
+
+ for i in range(bsz):
+ mask = masks[i]
+ mask_count = 0
+
+ min_aspect = 0.3
+ max_aspect = 1 / min_aspect
+ log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
+
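+ # sample rectangles with area up to max_mask_patches and a random aspect
+ # ratio; only newly masked patches count toward the masking budget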
+ def _mask(mask, max_mask_patches):
+ delta = 0
+ for attempt in range(10):
+ target_area = random.uniform(min_masks, max_mask_patches)
+ aspect_ratio = math.exp(random.uniform(*log_aspect_ratio))
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
+ if w < width and h < height:
+ top = random.randint(0, height - h)
+ left = random.randint(0, width - w)
+
+ num_masked = mask[top : top + h, left : left + w].sum()
+ # Overlap
+ if 0 < h * w - num_masked <= max_mask_patches:
+ for i in range(top, top + h):
+ for j in range(left, left + w):
+ if mask[i, j] == 0:
+ mask[i, j] = 1
+ delta += 1
+
+ if delta > 0:
+ break
+ return delta
+
+ while mask_count < num_masks:
+ max_mask_patches = min(num_masks - mask_count, max_masks)
+
+ delta = _mask(mask, max_mask_patches)
+ if delta == 0:
+ break
+ else:
+ mask_count += delta
+
+ return torch.from_numpy(masks)
+
+ def forward(
+ self,
+ img,
+ mask: bool = True,
+ layer_results: bool = False,
+ ):
+ x = self.patch_embed(img)
+ batch_size, seq_len, _ = x.size()
+
+ if mask:
+ mask_indices = self.make_mask(
+ img.size(0),
+ self.cfg.num_mask_patches,
+ self.cfg.min_mask_patches_per_block,
+ self.cfg.max_mask_patches_per_block,
+ )
+ bool_mask = mask_indices.view(mask_indices.size(0), -1).bool()
+ else:
+ mask_indices = bool_mask = None
+
+ cls_tokens = self.cls_emb.expand(batch_size, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ if self.ema is not None:
+ with torch.no_grad():
+ self.ema.model.eval()
+
+ if self.cfg.ema_transformer_only:
+ y = self.ema.model(
+ x,
+ layer_results="end" if self.cfg.end_of_block_targets else "fc",
+ )
+ else:
+ y = self.ema.model(
+ img,
+ mask=False,
+ layer_results=True,
+ )
+
+ y = y[-self.cfg.average_top_k_layers :]
+
+ permuted = False
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
+ y = [tl.transpose(1, 2) for tl in y] # BTC -> BCT
+ permuted = True
+
+ if self.cfg.batch_norm_target_layer:
+ y = [
+ F.batch_norm(
+ tl.float(), running_mean=None, running_var=None, training=True
+ )
+ for tl in y
+ ]
+
+ if self.cfg.instance_norm_target_layer:
+ y = [F.instance_norm(tl.float()) for tl in y]
+
+ if permuted:
+ y = [tl.transpose(1, 2) for tl in y] # BCT -> BTC
+
+ if self.cfg.layer_norm_target_layer:
+ y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
+
+ y = sum(y) / len(y)
+
+ if self.cfg.layer_norm_targets:
+ y = F.layer_norm(y.float(), y.shape[-1:])
+
+ if self.cfg.instance_norm_targets:
+ y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
+
+ y = y[bool_mask].float()
+
+ if mask_indices is not None:
+ mask_token = self.mask_emb.expand(batch_size, seq_len, -1)
+ w = mask_indices.view(mask_indices.size(0), -1, 1).type_as(mask_token)
+ x[:, 1:] = x[:, 1:] * (1 - w) + mask_token * w
+
+ if layer_results:
+ enc_layer_results = "end" if self.cfg.end_of_block_targets else "fc"
+ else:
+ enc_layer_results = None
+
+ x = self.encoder(x, layer_results=enc_layer_results)
+ if layer_results or mask_indices is None:
+ return x
+
+ x = x[bool_mask].float()
+
+ if self.loss_beta == 0:
+ loss = F.mse_loss(x, y, reduction="none").sum(dim=-1)
+ else:
+ loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta).sum(
+ dim=-1
+ )
+
+ if self.loss_scale > 0:
+ loss = loss * self.loss_scale
+
+ result = {
+ "losses": {"regression": loss.sum()},
+ "sample_size": loss.numel(),
+ "target_var": self.compute_var(y),
+ "pred_var": self.compute_var(x),
+ "ema_decay": self.ema.get_decay() * 1000,
+ }
+ return result
+
+ @staticmethod
+ def compute_var(y):
+ y = y.view(-1, y.size(-1))
+ if dist.is_initialized():
+ zc = torch.tensor(y.size(0)).cuda()
+ zs = y.sum(dim=0)
+ zss = (y ** 2).sum(dim=0)
+
+ dist.all_reduce(zc)
+ dist.all_reduce(zs)
+ dist.all_reduce(zss)
+
+ var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
+ return torch.sqrt(var + 1e-6).mean()
+ else:
+ return torch.sqrt(y.var(dim=0) + 1e-6).mean()
+
+ def remove_pretraining_modules(self, last_layer=None):
+ self.final_proj = None
+ self.ema = None
+ self.encoder.norm = nn.Identity()
+ self.mask_emb = None
+ if last_layer is not None:
+ self.encoder.layers = nn.ModuleList(
+ l for i, l in enumerate(self.encoder.layers) if i <= last_layer
+ )
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ if isinstance(img_size, int):
+ img_size = img_size, img_size
+ if isinstance(patch_size, int):
+ patch_size = patch_size, patch_size
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.conv = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
+ )
+
+ def forward(self, x):
+ # BCHW -> BTC
+ x = self.conv(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=True,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ window_size=None,
+ attn_head_dim=None,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ if window_size:
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
+ 2 * window_size[1] - 1
+ ) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token to cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ ) # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(
+ size=(window_size[0] * window_size[1] + 1,) * 2,
+ dtype=relative_coords.dtype,
+ )
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+ else:
+ self.window_size = None
+ self.relative_position_bias_table = None
+ self.relative_position_index = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, rel_pos_bias=None):
+ B, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat(
+ (
+ self.q_bias,
+ torch.zeros_like(self.v_bias, requires_grad=False),
+ self.v_bias,
+ )
+ )
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ if self.relative_position_bias_table is not None:
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1,
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+ print("attn.size() :", attn.size())
+ print("rel_pos_bias.size() :", rel_pos_bias.size())
+ if rel_pos_bias is not None:
+ attn = attn + rel_pos_bias
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class RelativePositionBias(nn.Module):
+ def __init__(self, window_size, num_heads):
+ super().__init__()
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
+ 2 * window_size[1] - 1
+ ) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)
+ )
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ ) # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
+ )
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ def forward(self):
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1,
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ print("self.window_size :", self.window_size)
+ print("self.num_relative_distance :", self.num_relative_distance)
+ print("self.relative_position_index :", self.relative_position_index.size(), self.relative_position_index)
+ print("relative_position_bias.size(), relative_position_bias :",relative_position_bias.size(), relative_position_bias)
+ print("self.relative_position_bias_table.size(), self.relative_position_bias_table :",self.relative_position_bias_table.size(), self.relative_position_bias_table)
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ if self.drop_prob == 0.0 or not self.training:
+ return x
+ keep_prob = 1 - self.drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_()
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ init_values=None,
+ window_size=None,
+ ):
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ window_size=window_size,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = nn.LayerNorm(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, mlp_hidden_dim),
+ nn.GELU(),
+ nn.Linear(mlp_hidden_dim, dim),
+ nn.Dropout(drop),
+ )
+
+ if init_values is not None and init_values > 0:
+ self.gamma_1 = nn.Parameter(
+ init_values * torch.ones((dim)), requires_grad=True
+ )
+ self.gamma_2 = nn.Parameter(
+ init_values * torch.ones((dim)), requires_grad=True
+ )
+ else:
+ self.gamma_1, self.gamma_2 = None, None
+
+ def forward(self, x, rel_pos_bias=None):
+ print("inside block :", x.size())
+ if self.gamma_1 is None:
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+ fc_feature = self.drop_path(self.mlp(self.norm2(x)))
+ x = x + fc_feature
+ else:
+ x = x + self.drop_path(
+ self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
+ )
+ fc_feature = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ x = x + fc_feature
+ return x, fc_feature
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, cfg: Data2VecVisionConfig, patch_shape):
+ super().__init__()
+
+ self.rel_pos_bias = None
+ if cfg.shared_rel_pos_bias:
+ self.rel_pos_bias = RelativePositionBias(
+ window_size=patch_shape, num_heads=cfg.num_heads
+ )
+
+ dpr = [
+ x.item() for x in torch.linspace(0, cfg.drop_path, cfg.depth)
+ ] # stochastic depth decay rule
+
+ print("TransformerEncoder > patch_shape :", patch_shape)
+ self.blocks = nn.ModuleList(
+ Block(
+ dim=cfg.embed_dim,
+ num_heads=cfg.num_heads,
+ attn_drop=cfg.attention_dropout,
+ drop_path=dpr[i],
+ init_values=cfg.layer_scale_init_value,
+ window_size=patch_shape if not cfg.shared_rel_pos_bias else None,
+ )
+ for i in range(cfg.depth)
+ )
+
+ self.norm = nn.LayerNorm(cfg.embed_dim)
+
+ self.apply(self.init_weights)
+ self.fix_init_weight()
+
+ def init_weights(self, m):
+ std = 0.02
+ if isinstance(m, nn.Linear):
+ nn.init.trunc_normal_(m.weight, std=std)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ nn.init.trunc_normal_(m.weight, std=std)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp[2].weight.data, layer_id + 1)
+
+ def extract_features(self, x, layer_results):
+
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+
+ z = []
+ for i, blk in enumerate(self.blocks):
+ x, fc_feature = blk(x, rel_pos_bias=rel_pos_bias)
+ if layer_results == "end":
+ z.append(x)
+ elif layer_results == "fc":
+ z.append(fc_feature)
+
+ return z if layer_results else self.norm(x)
+
+ def forward(self, x, layer_results=None):
+ x = self.extract_features(x, layer_results=layer_results)
+ if layer_results:
+ return [z[:, 1:] for z in x]
+
+ x = x[:, 1:]
+ return x
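
A single self-supervised training step with the model above might look as follows; this is a sketch assuming `cfg` is a fully resolved `Data2VecVisionConfig` (fairseq's hydra setup normally fills interpolations such as `ema_anneal_end_step`):

```python
import torch

model = Data2VecVisionModel(cfg)
model.set_num_updates(0)            # first call instantiates the EMA teacher
imgs = torch.randn(2, 3, 224, 224)  # B x C x H x W

out = model(imgs, mask=True)
loss = out["losses"]["regression"] / out["sample_size"]
loss.backward()
```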
diff --git a/examples/data2vec/models/mae.py b/examples/data2vec/models/mae.py
new file mode 100644
index 000000000..5101e070e
--- /dev/null
+++ b/examples/data2vec/models/mae.py
@@ -0,0 +1,825 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# The code in this file is adapted from the BeiT implementation which can be found here:
+# https://github.com/microsoft/unilm/tree/master/beit
+
+import logging
+from dataclasses import dataclass
+from functools import partial
+
+from timm.models.vision_transformer import PatchEmbed, Block
+
+import torch
+import torch.nn as nn
+
+import numpy as np
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model
+from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer
+
+try:
+ from apex.normalization import FusedLayerNorm
+except ImportError:  # apex is optional; fall back to torch's LayerNorm
+ FusedLayerNorm = nn.LayerNorm
+import torch.nn.functional as F
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class MaeConfig(FairseqDataclass):
+ input_size: int = 224
+ in_chans: int = 3
+ patch_size: int = 16
+ embed_dim: int = 768
+ depth: int = 12
+ num_heads: int = 12
+ decoder_embed_dim: int = 512
+ decoder_depth: int = 8
+ decoder_num_heads: int = 16
+ mlp_ratio: int = 4
+ norm_eps: float = 1e-6
+
+ drop_path_rate: float = 0.0
+
+ mask_ratio: float = 0.75
+ norm_pix_loss: bool = True
+
+ w2v_block: bool = False
+ alt_block: bool = False
+ alt_block2: bool = False
+ alt_attention: bool = False
+ block_dropout: float = 0
+ attention_dropout: float = 0
+ activation_dropout: float = 0
+ layer_norm_first: bool = False
+
+ fused_ln: bool = True
+ end_of_block_targets: bool = True
+
+ no_decoder_embed: bool = False
+ no_decoder_pos_embed: bool = False
+ mask_noise_std: float = 0
+
+ single_qkv: bool = False
+ use_rel_pos_bias: bool = False
+ shared_rel_pos_bias: bool = False
+ no_cls: bool = False
+
+
+def modify_relative_position_bias(orig_bias, bsz, mask):
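+ # restrict a full heads x T x T relative-position bias to each example's kept
+ # (unmasked) timesteps; position 0 (the CLS token) is always kept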
+ if mask is None:
+ return orig_bias.unsqueeze(0).repeat(
+ bsz, 1, 1, 1
+ ) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
+ heads, max_seq_len, _ = orig_bias.shape  # includes CLS token
+ mask_for_rel_pos_bias = torch.cat(
+ (torch.zeros(bsz, 1, dtype=mask.dtype, device=mask.device), mask), dim=1
+ ).bool() # bsz x seqlen (add CLS token)
+ unmasked_for_rel_pos_bias = ~mask_for_rel_pos_bias
+ unmasked_for_rel_pos_bias = unmasked_for_rel_pos_bias.unsqueeze(1).repeat(
+ 1, heads, 1
+ ) # bsz x seq_len => bsz x heads x seq_len
+ b_t_t_rel_pos_bias = orig_bias.unsqueeze(0).repeat(
+ bsz, 1, 1, 1
+ ) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
+ b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
+ unmasked_for_rel_pos_bias.unsqueeze(-1)
+ )
+ b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, -1, max_seq_len)
+ new_len = b_t_t_rel_pos_bias.size(-2)
+ b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
+ unmasked_for_rel_pos_bias.unsqueeze(-2)
+ )
+ b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, new_len, new_len)
+ return b_t_t_rel_pos_bias
+
+
+class AltBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ layer_norm_first=True,
+ ffn_targets=False,
+ use_rel_pos_bias=False,
+ window_size=None,
+ alt_attention=False,
+ ):
+ super().__init__()
+
+ self.layer_norm_first = layer_norm_first
+ self.ffn_targets = ffn_targets
+
+ from timm.models.vision_transformer import Attention, DropPath, Mlp
+
+ self.norm1 = norm_layer(dim)
+ self.use_rel_pos_bias = use_rel_pos_bias
+ if use_rel_pos_bias:
+ self.attn = AltAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ window_size=window_size,
+ )
+ else:
+ if alt_attention:
+ from .multi.modules import AltAttention as AltAttention2
+ self.attn = AltAttention2(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ else:
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ def forward(self, x, rel_pos_bias=None, pos_mask=None):
+ if self.layer_norm_first:
+ if self.use_rel_pos_bias:
+ x = x + self.drop_path(
+ self.attn(
+ self.norm1(x), rel_pos_bias=rel_pos_bias, pos_mask=pos_mask
+ )
+ )
+ else:
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ t = self.mlp(self.norm2(x))
+ x = x + self.drop_path(t)
+ if not self.ffn_targets:
+ t = x
+ return x, t
+ else:
+ if self.use_rel_pos_bias:
+ x = x + self.drop_path(
+ self.attn(x, rel_pos_bias=rel_pos_bias, pos_mask=pos_mask)
+ )
+ else:
+ x = x + self.drop_path(self.attn(x))
+ r = x = self.norm1(x)
+ x = self.mlp(x)
+ t = x
+ x = self.norm2(r + self.drop_path(x))
+ if not self.ffn_targets:
+ t = x
+ return x, t
+
+
+class AltAttention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ window_size=None,
+ attn_head_dim=None,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ if window_size:
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
+ 2 * window_size[1] - 1
+ ) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token to cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ ) # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(
+ size=(window_size[0] * window_size[1] + 1,) * 2,
+ dtype=relative_coords.dtype,
+ )
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+ else:
+ self.window_size = None
+ self.relative_position_bias_table = None
+ self.relative_position_index = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, rel_pos_bias=None, pos_mask=None):
+ B, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat(
+ (
+ self.q_bias,
+ torch.zeros_like(self.v_bias, requires_grad=False),
+ self.v_bias,
+ )
+ )
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ if self.relative_position_bias_table is not None:
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1,
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + modify_relative_position_bias(
+ relative_position_bias, x.size(0), pos_mask
+ )
+
+ if rel_pos_bias is not None:
+ attn = attn + rel_pos_bias
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class RelativePositionBias(nn.Module):
+ def __init__(self, window_size, num_heads):
+ super().__init__()
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
+ 2 * window_size[1] - 1
+ ) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)
+ )
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ ) # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
+ )
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ def forward(self):
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1,
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000 ** omega # (D/2,)
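+ # i.e. the standard transformer sinusoidal frequencies omega_d = 1 / 10000**(2d/D)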
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+def interpolate_pos_embed(model, checkpoint_model):
+ if "pos_embed" in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model["pos_embed"]
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print(
+ "Position interpolate from %dx%d to %dx%d"
+ % (orig_size, orig_size, new_size, new_size)
+ )
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(
+ -1, orig_size, orig_size, embedding_size
+ ).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens,
+ size=(new_size, new_size),
+ mode="bicubic",
+ align_corners=False,
+ )
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model["pos_embed"] = new_pos_embed
+
+
+@register_model("mae", dataclass=MaeConfig)
+class MaeModel(BaseFairseqModel):
+ def __init__(self, cfg: MaeConfig):
+ super().__init__()
+ self.cfg = cfg
+
+ self.mask_ratio = cfg.mask_ratio
+
+ # --------------------------------------------------------------------------
+ # MAE encoder specifics
+ self.patch_embed = PatchEmbed(
+ cfg.input_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg.embed_dim)) if not cfg.no_cls else None
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches + int(not cfg.no_cls), cfg.embed_dim), requires_grad=False
+ ) # fixed sin-cos embedding
+
+ norm_layer = partial(nn.LayerNorm, eps=cfg.norm_eps)
+
+ dpr = [
+ x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)
+ ] # stochastic depth decay rule
+
+ def make_block(drop_path):
+ if cfg.w2v_block:
+ return TransformerSentenceEncoderLayer(
+ embedding_dim=cfg.embed_dim,
+ ffn_embedding_dim=cfg.embed_dim * cfg.mlp_ratio,
+ num_attention_heads=cfg.num_heads,
+ dropout=cfg.block_dropout,
+ attention_dropout=cfg.attention_dropout,
+ activation_dropout=cfg.activation_dropout,
+ activation_fn="gelu",
+ layer_norm_first=cfg.layer_norm_first,
+ drop_path=drop_path,
+ norm_eps=1e-6,
+ single_qkv=cfg.single_qkv,
+ fused_ln=cfg.fused_ln,
+ )
+ elif cfg.alt_block:
+ window_size = (
+ cfg.input_size // self.patch_embed.patch_size[0],
+ cfg.input_size // self.patch_embed.patch_size[1],
+ )
+ return AltBlock(
+ cfg.embed_dim,
+ cfg.num_heads,
+ cfg.mlp_ratio,
+ qkv_bias=True,
+ qk_scale=None,
+ norm_layer=norm_layer,
+ drop_path=drop_path,
+ layer_norm_first=cfg.layer_norm_first,
+ ffn_targets=not cfg.end_of_block_targets,
+ use_rel_pos_bias=cfg.use_rel_pos_bias,
+ window_size=window_size
+ if (self.cfg.use_rel_pos_bias and not self.cfg.shared_rel_pos_bias)
+ else None,
+ alt_attention=cfg.alt_attention,
+ )
+ elif cfg.alt_block2:
+ from .multi.modules import AltBlock as AltBlock2
+ return AltBlock2(
+ cfg.embed_dim,
+ cfg.num_heads,
+ cfg.mlp_ratio,
+ qkv_bias=True,
+ qk_scale=None,
+ norm_layer=norm_layer,
+ drop_path=drop_path,
+ layer_norm_first=cfg.layer_norm_first,
+ ffn_targets=not cfg.end_of_block_targets,
+ )
+ else:
+ return Block(
+ cfg.embed_dim,
+ cfg.num_heads,
+ cfg.mlp_ratio,
+ qkv_bias=True,
+ qk_scale=None,
+ norm_layer=norm_layer,
+ drop_path=drop_path,
+ )
+
+ self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
+ self.norm = norm_layer(cfg.embed_dim)
+ # --------------------------------------------------------------------------
+
+ # --------------------------------------------------------------------------
+ # MAE decoder specifics
+ self.decoder_embed = (
+ nn.Linear(cfg.embed_dim, cfg.decoder_embed_dim, bias=True)
+ if not cfg.no_decoder_embed
+ else None
+ )
+
+ self.mask_token = (
+ nn.Parameter(
+ torch.zeros(
+ 1,
+ 1,
+ cfg.decoder_embed_dim
+ if not cfg.no_decoder_embed
+ else cfg.embed_dim,
+ )
+ )
+ if cfg.mask_noise_std <= 0
+ else None
+ )
+
+ self.decoder_pos_embed = (
+ nn.Parameter(
+ torch.zeros(
+ 1,
+ num_patches + 1,
+ cfg.decoder_embed_dim
+ if not cfg.no_decoder_embed
+ else cfg.embed_dim,
+ ),
+ requires_grad=False,
+ )
+ if not cfg.no_decoder_pos_embed
+ else None
+ )
+
+ self.decoder_blocks = nn.ModuleList(
+ [
+ Block(
+ cfg.decoder_embed_dim,
+ cfg.decoder_num_heads,
+ cfg.mlp_ratio,
+ qkv_bias=True,
+ qk_scale=None,
+ norm_layer=norm_layer,
+ )
+ for _ in range(cfg.decoder_depth)
+ ]
+ )
+
+ self.decoder_norm = norm_layer(cfg.decoder_embed_dim)
+ self.decoder_pred = nn.Linear(
+ cfg.decoder_embed_dim, cfg.patch_size ** 2 * cfg.in_chans, bias=True
+ ) # decoder to patch
+ # --------------------------------------------------------------------------
+
+ self.norm_pix_loss = cfg.norm_pix_loss
+
+ self.initialize_weights()
+
+ for pn, p in self.named_parameters():
+ if len(p.shape) == 1 or pn.endswith(".bias"):
+ p.param_group = "no_decay"
+ else:
+ p.param_group = "with_decay"
+
+ def initialize_weights(self):
+ # initialization
+ # initialize (and freeze) pos_embed by sin-cos embedding
+ pos_embed = get_2d_sincos_pos_embed(
+ self.pos_embed.shape[-1],
+ int(self.patch_embed.num_patches ** 0.5),
+ cls_token=not self.cfg.no_cls,
+ )
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+ if self.decoder_pos_embed is not None:
+ decoder_pos_embed = get_2d_sincos_pos_embed(
+ self.decoder_pos_embed.shape[-1],
+ int(self.patch_embed.num_patches ** 0.5),
+ cls_token=not self.cfg.no_cls,
+ )
+ self.decoder_pos_embed.data.copy_(
+ torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
+ )
+
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+ w = self.patch_embed.proj.weight.data
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
+ if self.cls_token is not None:
+ torch.nn.init.normal_(self.cls_token, std=0.02)
+
+ if self.mask_token is not None:
+ torch.nn.init.normal_(self.mask_token, std=0.02)
+
+ # initialize nn.Linear and nn.LayerNorm
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ # we use xavier_uniform following official JAX ViT:
+ torch.nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def patchify(self, imgs):
+ """
+ imgs: (N, 3, H, W)
+ x: (N, L, patch_size**2 *3)
+ """
+ p = self.patch_embed.patch_size[0]
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+
+ h = w = imgs.shape[2] // p
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+ x = torch.einsum("nchpwq->nhwpqc", x)
+ x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
+ return x
+
+ def unpatchify(self, x):
+ """
+ x: (N, L, patch_size**2 *3)
+ imgs: (N, 3, H, W)
+ """
+ p = self.patch_embed.patch_size[0]
+ h = w = int(x.shape[1] ** 0.5)
+ assert h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+ x = torch.einsum("nhwpqc->nchpwq", x)
+ imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
+ return imgs
+
+ def random_masking(self, x, mask_ratio):
+ """
+ Perform per-sample random masking by per-sample shuffling.
+ Per-sample shuffling is done by argsort random noise.
+ x: [N, L, D], sequence
+ """
+ N, L, D = x.shape # batch, length, dim
+ len_keep = int(L * (1 - mask_ratio))
+
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
+
+ # sort noise for each sample
+ ids_shuffle = torch.argsort(
+ noise, dim=1
+ ) # ascend: small is keep, large is remove
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+ # keep the first subset
+ ids_keep = ids_shuffle[:, :len_keep]
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+ # generate the binary mask: 0 is keep, 1 is remove
+ mask = torch.ones([N, L], device=x.device)
+ mask[:, :len_keep] = 0
+ # unshuffle to get the binary mask
+ mask = torch.gather(mask, dim=1, index=ids_restore)
+
+ return x_masked, mask, ids_restore # note: x_masked holds only the kept (visible) tokens
+
+ @classmethod
+ def build_model(cls, cfg: MaeConfig, task=None):
+ """Build a new model instance."""
+
+ return cls(cfg)
+
+ def forward_encoder(self, x, mask_ratio):
+ # embed patches
+ x = self.patch_embed(x)
+
+ # add pos embed w/o cls token
+ # if self.cls_token is not None:
+ # x = x + self.pos_embed
+ # else:
+ x = x + self.pos_embed[:, 1:, :]
+
+ # masking: length -> length * mask_ratio
+ if mask_ratio > 0:
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
+ else:
+ mask = ids_restore = None
+
+ # append cls token
+ if self.cls_token is not None:
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # apply Transformer blocks
+ for blk in self.blocks:
+ x = blk(x)
+
+ if self.norm is not None:
+ x = self.norm(x)
+
+ return x, mask, ids_restore
+
+ def forward_decoder(self, x, ids_restore):
+ # embed tokens
+ x = self.decoder_embed(x)
+
+ # append mask tokens to sequence
+ mask_tokens = self.mask_token.repeat(
+ x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
+ )
+ if self.cls_token is not None:
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
+ else:
+ x_ = torch.cat([x, mask_tokens], dim=1) # no cls token
+
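+ # gathering with ids_restore scatters each token (and appended mask
+ # token) back to its original timestep, undoing the encoder's shuffle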
+ x_ = torch.gather(
+ x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
+ ) # unshuffle
+
+ if self.cls_token is not None:
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
+
+ # add pos embed
+ x = x + self.decoder_pos_embed
+
+ # apply Transformer blocks
+ for blk in self.decoder_blocks:
+ x = blk(x)
+ x = self.decoder_norm(x)
+
+ # predictor projection
+ x = self.decoder_pred(x)
+
+ if self.cls_token is not None:
+ # remove cls token
+ x = x[:, 1:, :]
+
+ return x
+
+ def forward_loss(self, imgs, pred, mask):
+ """
+ imgs: [N, 3, H, W]
+ pred: [N, L, p*p*3]
+ mask: [N, L], 0 is keep, 1 is remove
+ """
+ target = self.patchify(imgs)
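+ # with norm_pix_loss, regression targets are standardized per patch
+ # (zero mean, unit variance), which the MAE paper reports improves
+ # downstream accuracy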
+ if self.norm_pix_loss:
+ mean = target.mean(dim=-1, keepdim=True)
+ var = target.var(dim=-1, keepdim=True)
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
+
+ loss = (pred - target) ** 2
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
+
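+ # the loss is summed over masked patches only; mask.sum() is returned as
+ # the sample size so the fairseq criterion can normalize to a mean over
+ # masked patches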
+ loss = (loss * mask).sum()
+ return loss, mask.sum()
+
+ def forward(self, imgs, predictions_only=False):
+ latent, mask, ids_restore = self.forward_encoder(
+ imgs, self.mask_ratio if not predictions_only else 0
+ )
+
+ if predictions_only:
+ return latent
+
+ pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
+ loss, sample_size = self.forward_loss(imgs, pred, mask)
+
+ result = {
+ "losses": {"regression": loss},
+ "sample_size": sample_size,
+ }
+ return result
+
+ def remove_pretraining_modules(self):
+ self.decoder_embed = None
+ self.decoder_blocks = None
+ self.decoder_norm = None
+ self.decoder_pos_embed = None
+ self.decoder_pred = None
+ self.mask_token = None
+ if self.cfg.layer_norm_first:
+ self.norm = None
diff --git a/examples/data2vec/models/mae_image_classification.py b/examples/data2vec/models/mae_image_classification.py
new file mode 100644
index 000000000..e304618dc
--- /dev/null
+++ b/examples/data2vec/models/mae_image_classification.py
@@ -0,0 +1,386 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# The code in this file is adapted from the BeiT implementation which can be found here:
+# https://github.com/microsoft/unilm/tree/master/beit
+
+import logging
+
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import Any, Optional
+
+import numpy as np
+from omegaconf import II, MISSING
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fairseq import checkpoint_utils, tasks
+from omegaconf import open_dict
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model
+from .mae import interpolate_pos_embed
+
+
+logger = logging.getLogger(__name__)
+
+
+class PredictionMode(Enum):
+ MEAN_POOLING = auto()
+ CLS_TOKEN = auto()
+ LIN_SOFTMAX = auto()
+
+
+@dataclass
+class MaeImageClassificationConfig(FairseqDataclass):
+ model_path: str = MISSING
+ no_pretrained_weights: bool = False
+ linear_classifier: bool = False
+ num_classes: int = 1000
+ mixup: float = 0.8
+ cutmix: float = 1.0
+ label_smoothing: float = 0.1
+
+ drop_path_rate: float = 0.1
+ layer_decay: float = 0.65
+
+ mixup_prob: float = 1.0
+ mixup_switch_prob: float = 0.5
+ mixup_mode: str = "batch"
+
+ pretrained_model_args: Any = None
+ data: str = II("task.data")
+
+ norm_eps: Optional[float] = None
+
+ remove_alibi: bool = False
+
+ # regularization overwrites
+ encoder_dropout: float = 0
+ post_mlp_drop: float = 0
+ attention_dropout: float = 0
+ activation_dropout: float = 0.0
+ dropout_input: float = 0.0
+ layerdrop: float = 0.0
+
+ prenet_layerdrop: float = 0
+ prenet_dropout: float = 0
+
+ use_fc_norm: bool = True
+ prediction_mode: PredictionMode = PredictionMode.MEAN_POOLING
+
+ no_decay_blocks: bool = True
+
+
+def get_layer_id_for_vit(name, num_layers):
+ """
+ Assign a parameter with its layer id
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+ """
+ if name in ["cls_token", "pos_embed"]:
+ return 0
+ elif name.startswith("patch_embed"):
+ return 0
+ elif name.startswith("rel_pos_bias"):
+ return num_layers - 1
+ elif name.startswith("blocks"):
+ return int(name.split(".")[1]) + 1
+ else:
+ return num_layers
+
+
+@register_model("mae_image_classification", dataclass=MaeImageClassificationConfig)
+class MaeImageClassificationModel(BaseFairseqModel):
+ def __init__(self, cfg: MaeImageClassificationConfig):
+ super().__init__()
+ self.cfg = cfg
+
+ if cfg.pretrained_model_args is None:
+ state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
+ pretrained_args = state.get("cfg", None)
+
+ pretrained_args.criterion = None
+ pretrained_args.lr_scheduler = None
+
+ logger.info(pretrained_args.model)
+
+ with open_dict(pretrained_args.model):
+ pretrained_args.model.drop_path_rate = cfg.drop_path_rate
+ if cfg.norm_eps is not None:
+ pretrained_args.model.norm_eps = cfg.norm_eps
+
+ cfg.pretrained_model_args = pretrained_args
+
+ logger.info(pretrained_args)
+ else:
+ state = None
+ pretrained_args = cfg.pretrained_model_args
+
+ if "data" in pretrained_args.task:
+ pretrained_args.task.data = cfg.data
+ elif "image" in pretrained_args.task:
+ pretrained_args.task.image.data = cfg.data
+
+ if "modalities" in pretrained_args.model:
+ prenet_blocks = pretrained_args.model["modalities"]["image"]["prenet_depth"]
+ model_blocks = pretrained_args.model["depth"]
+ with open_dict(pretrained_args):
+ dpr = np.linspace(0, cfg.drop_path_rate, model_blocks).tolist()
+ pretrained_args.model["modalities"]["image"][
+ "start_drop_path_rate"
+ ] = dpr[0]
+ pretrained_args.model["modalities"]["image"][
+ "end_drop_path_rate"
+ ] = max(0, dpr[prenet_blocks - 1])
+ pretrained_args.model["start_drop_path_rate"] = dpr[prenet_blocks]
+ pretrained_args.model["end_drop_path_rate"] = dpr[-1]
+
+ if "mae_masking" in pretrained_args.model["modalities"]["image"]:
+ del pretrained_args.model["modalities"]["image"]["mae_masking"]
+
+ if cfg.remove_alibi:
+ pretrained_args.model["modalities"]["image"][
+ "use_alibi_encoder"
+ ] = False
+ if (
+ state is not None
+ and "modality_encoders.IMAGE.alibi_bias" in state["model"]
+ ):
+ del state["model"]["modality_encoders.IMAGE.alibi_bias"]
+
+ pretrained_args.model["encoder_dropout"] = cfg.encoder_dropout
+ pretrained_args.model["post_mlp_drop"] = cfg.post_mlp_drop
+ pretrained_args.model["attention_dropout"] = cfg.attention_dropout
+ pretrained_args.model["activation_dropout"] = cfg.activation_dropout
+ pretrained_args.model["dropout_input"] = cfg.dropout_input
+ pretrained_args.model["layerdrop"] = cfg.layerdrop
+
+ pretrained_args.model["modalities"]["image"][
+ "prenet_layerdrop"
+ ] = cfg.prenet_layerdrop
+ pretrained_args.model["modalities"]["image"][
+ "prenet_dropout"
+ ] = cfg.prenet_dropout
+ else:
+ # not d2v multi
+ with open_dict(pretrained_args):
+ pretrained_args.model["drop_path_rate"] = cfg.drop_path_rate
+ pretrained_args.model["block_dropout"] = cfg.encoder_dropout
+ pretrained_args.model["attention_dropout"] = cfg.attention_dropout
+ pretrained_args.model["activation_dropout"] = cfg.activation_dropout
+
+ task = tasks.setup_task(pretrained_args.task)
+ model = task.build_model(pretrained_args.model, from_checkpoint=True)
+
+ self.d2v_multi = "data2vec_multi" in pretrained_args.model._name
+ self.linear_classifier = cfg.linear_classifier
+
+ self.model = model
+
+ if state is not None and not cfg.no_pretrained_weights:
+ interpolate_pos_embed(model, state)
+
+ if "modality_encoders.IMAGE.positional_encoder.pos_embed" in state["model"]:
+ state["model"][
+ "modality_encoders.IMAGE.positional_encoder.positions"
+ ] = state["model"][
+ "modality_encoders.IMAGE.positional_encoder.pos_embed"
+ ]
+ del state["model"][
+ "modality_encoders.IMAGE.positional_encoder.pos_embed"
+ ]
+ if "modality_encoders.IMAGE.encoder_mask" in state["model"]:
+ del state["model"]["modality_encoders.IMAGE.encoder_mask"]
+
+ model.load_state_dict(state["model"], strict=True)
+
+ if self.d2v_multi:
+ model.remove_pretraining_modules(modality="image")
+ else:
+ model.remove_pretraining_modules()
+
+ if self.linear_classifier:
+ model.requires_grad_(False)
+
+ self.fc_norm = None
+ if self.cfg.use_fc_norm:
+ self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim, eps=1e-6)
+ nn.init.constant_(self.fc_norm.bias, 0)
+ nn.init.constant_(self.fc_norm.weight, 1.0)
+
+ self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)
+
+ nn.init.trunc_normal_(self.head.weight, std=0.02)
+ nn.init.constant_(self.head.bias, 0)
+
+ self.mixup_fn = None
+
+ if cfg.mixup > 0 or cfg.cutmix > 0:
+ from timm.data import Mixup
+
+ self.mixup_fn = Mixup(
+ mixup_alpha=cfg.mixup,
+ cutmix_alpha=cfg.cutmix,
+ cutmix_minmax=None,
+ prob=cfg.mixup_prob,
+ switch_prob=cfg.mixup_switch_prob,
+ mode=cfg.mixup_mode,
+ label_smoothing=cfg.label_smoothing,
+ num_classes=cfg.num_classes,
+ )
+
+ if self.model.norm is not None:
+ for pn, p in self.model.norm.named_parameters():
+ if len(p.shape) == 1 or pn.endswith(".bias"):
+ p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
+
+ if self.fc_norm is not None:
+ for pn, p in self.fc_norm.named_parameters():
+ if len(p.shape) == 1 or pn.endswith(".bias"):
+ p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
+
+ for pn, p in self.head.named_parameters():
+ if len(p.shape) == 1 or pn.endswith(".bias"):
+ p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
+
+ if self.d2v_multi:
+ mod_encs = list(model.modality_encoders.values())
+ assert len(mod_encs) == 1, len(mod_encs)
+ blocks = list(mod_encs[0].context_encoder.blocks) + list(model.blocks)
+ else:
+ blocks = model.blocks
+
+ num_layers = len(blocks) + 1
+ layer_scales = list(
+ cfg.layer_decay ** (num_layers - i) for i in range(num_layers + 1)
+ )
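+ # layer-wise lr decay: parameters in block i train at lr * layer_decay^(num_layers - i);
+ # e.g., with layer_decay=0.65 and 12 blocks (num_layers=13), the earliest
+ # parameters train at roughly 0.65^13 = 0.004x the rate of the topmost ones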
+
+ if self.d2v_multi:
+ for n, p in self.model.named_parameters():
+ optimizer_override_dict = {}
+
+ if len(p.shape) == 1 or n.endswith(".bias"):
+ optimizer_override_dict["weight_decay_scale"] = 0
+
+ p.optim_overrides = {"optimizer": optimizer_override_dict}
+
+ if cfg.layer_decay > 0:
+ for i, b in enumerate(blocks):
+ lid = i + 1
+ if layer_scales[lid] == 1.0:
+ continue
+
+ for n, p in b.named_parameters():
+ optim_override = getattr(p, "optim_overrides", {})
+ if "optimizer" not in optim_override:
+ optim_override["optimizer"] = {}
+
+ if cfg.no_decay_blocks:
+ optim_override["optimizer"]["lr_scale"] = layer_scales[lid]
+ p.optim_overrides = optim_override
+ else:
+ optim_override["optimizer"] = {
+ "lr_scale": layer_scales[lid]
+ }
+ p.optim_overrides = optim_override
+
+ else:
+ for n, p in self.model.named_parameters():
+ optimizer_override_dict = {}
+ layer_id = get_layer_id_for_vit(n, num_layers)
+
+ if len(p.shape) == 1 or n.endswith(".bias"):
+ optimizer_override_dict["weight_decay_scale"] = 0
+
+ if cfg.layer_decay > 0:
+ optimizer_override_dict["lr_scale"] = layer_scales[layer_id]
+ p.optim_overrides = {"optimizer": optimizer_override_dict}
+
+ @classmethod
+ def build_model(cls, cfg: MaeImageClassificationConfig, task=None):
+ """Build a new model instance."""
+
+ return cls(cfg)
+
+ def forward(
+ self,
+ imgs,
+ labels=None,
+ ):
+ if self.training and self.mixup_fn is not None and labels is not None:
+ imgs, labels = self.mixup_fn(imgs, labels)
+
+ if self.linear_classifier:
+ with torch.no_grad():
+ x = self.model_forward(imgs)
+ else:
+ x = self.model_forward(imgs)
+
+ if self.cfg.prediction_mode == PredictionMode.MEAN_POOLING:
+ x = x.mean(dim=1)
+ elif self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
+ x = x[:, 0]
+ elif self.cfg.prediction_mode == PredictionMode.LIN_SOFTMAX:
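+ # pools per-timestep sigmoid probabilities with a self-weighted mean,
+ # log(sum(p^2) / sum(p)), then maps the pooled probability back to a
+ # logit via log(p) - log(1 - p)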
+ dtype = x.dtype
+ x = F.logsigmoid(x.float())
+ x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x + 1e-6, dim=1)
+ x = x.clamp(max=0)
+ x = x - torch.log(-(torch.expm1(x)))
+ x = torch.nan_to_num(x, nan=0, posinf=0, neginf=0)
+ x = x.to(dtype=dtype)
+ else:
+ raise Exception(f"unknown prediction mode {self.cfg.prediction_mode.name}")
+
+ if self.fc_norm is not None:
+ x = self.fc_norm(x)
+
+ x = self.head(x)
+
+ if labels is None:
+ return x
+
+ if self.training and self.mixup_fn is not None:
+ loss = -labels * F.log_softmax(x.float(), dim=-1)
+ else:
+ loss = F.cross_entropy(
+ x.float(),
+ labels,
+ label_smoothing=self.cfg.label_smoothing if self.training else 0,
+ reduction="none",
+ )
+
+ result = {
+ "losses": {"regression": loss},
+ "sample_size": imgs.size(0),
+ }
+
+ if not self.training:
+ with torch.no_grad():
+ pred = x.argmax(-1)
+ correct = (pred == labels).sum()
+ result["correct"] = correct
+
+ return result
+
+ def model_forward(self, imgs):
+ if self.d2v_multi:
+ x = self.model.extract_features(
+ imgs,
+ mode="IMAGE",
+ mask=False,
+ remove_extra_tokens=(
+ self.cfg.prediction_mode != PredictionMode.CLS_TOKEN
+ ),
+ )["x"]
+ else:
+ x = self.model(imgs, predictions_only=True)
+ if (
+ "no_cls" not in self.model.cfg or not self.model.cfg.no_cls
+ ) and not self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
+ x = x[:, 1:]
+ return x
diff --git a/examples/data2vec/models/modalities/__init__.py b/examples/data2vec/models/modalities/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/data2vec/models/modalities/audio.py b/examples/data2vec/models/modalities/audio.py
new file mode 100644
index 000000000..80d2857b2
--- /dev/null
+++ b/examples/data2vec/models/modalities/audio.py
@@ -0,0 +1,192 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import partial
+import torch
+import torch.nn as nn
+import numpy as np
+from dataclasses import dataclass, field
+from typing import Callable, Dict, Optional
+from fairseq.models.wav2vec import ConvFeatureExtractionModel
+from fairseq.modules import (
+ LayerNorm,
+ SamePad,
+ TransposeLast,
+)
+from fairseq.tasks import FairseqTask
+from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
+from .modules import BlockEncoder, Decoder1d
+from examples.data2vec.data.modality import Modality
+
+
+@dataclass
+class D2vAudioConfig(D2vModalityConfig):
+ type: Modality = Modality.AUDIO
+ extractor_mode: str = "layer_norm"
+ feature_encoder_spec: str = field(
+ default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
+ metadata={
+ "help": "string describing convolutional feature extraction layers in form of a python list that contains "
+ "[(dim, kernel_size, stride), ...]"
+ },
+ )
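+ # the default spec downsamples the raw waveform by 5*2*2*2*2*2*2 = 320x,
+ # i.e. 16 kHz audio becomes roughly 50 feature frames per second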
+ conv_pos_width: int = field(
+ default=95,
+ metadata={"help": "number of filters for convolutional positional embeddings"},
+ )
+ conv_pos_groups: int = field(
+ default=16,
+ metadata={"help": "number of groups for convolutional positional embedding"},
+ )
+ conv_pos_depth: int = field(
+ default=5,
+ metadata={"help": "depth of positional encoder network"},
+ )
+ conv_pos_pre_ln: bool = False
+
+
+class AudioEncoder(ModalitySpecificEncoder):
+
+ modality_cfg: D2vAudioConfig
+
+ def __init__(
+ self,
+ modality_cfg: D2vAudioConfig,
+ embed_dim: int,
+ make_block: Callable[[float], nn.ModuleList],
+ norm_layer: Callable[[int], nn.LayerNorm],
+ layer_norm_first: bool,
+ alibi_biases: Dict,
+ task: Optional[FairseqTask],
+ ):
+
+ self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
+ feature_embed_dim = self.feature_enc_layers[-1][0]
+
+ local_encoder = ConvFeatureExtractionModel(
+ conv_layers=self.feature_enc_layers,
+ dropout=0.0,
+ mode=modality_cfg.extractor_mode,
+ conv_bias=False,
+ )
+
+ project_features = nn.Sequential(
+ TransposeLast(),
+ nn.LayerNorm(feature_embed_dim),
+ nn.Linear(feature_embed_dim, embed_dim),
+ )
+
+ num_pos_layers = modality_cfg.conv_pos_depth
+ k = max(3, modality_cfg.conv_pos_width // num_pos_layers)
+
+ positional_encoder = nn.Sequential(
+ TransposeLast(),
+ *[
+ nn.Sequential(
+ nn.Conv1d(
+ embed_dim,
+ embed_dim,
+ kernel_size=k,
+ padding=k // 2,
+ groups=modality_cfg.conv_pos_groups,
+ ),
+ SamePad(k),
+ TransposeLast(),
+ LayerNorm(embed_dim, elementwise_affine=False),
+ TransposeLast(),
+ nn.GELU(),
+ )
+ for _ in range(num_pos_layers)
+ ],
+ TransposeLast(),
+ )
+
+ if modality_cfg.conv_pos_pre_ln:
+ positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)
+
+ dpr = np.linspace(
+ modality_cfg.start_drop_path_rate,
+ modality_cfg.end_drop_path_rate,
+ modality_cfg.prenet_depth,
+ )
+ context_encoder = BlockEncoder(
+ nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
+ norm_layer(embed_dim) if not layer_norm_first else None,
+ layer_norm_first,
+ modality_cfg.prenet_layerdrop,
+ modality_cfg.prenet_dropout,
+ )
+
+ decoder = (
+ Decoder1d(modality_cfg.decoder, embed_dim)
+ if modality_cfg.decoder is not None
+ else None
+ )
+
+ alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
+
+ super().__init__(
+ modality_cfg=modality_cfg,
+ embed_dim=embed_dim,
+ local_encoder=local_encoder,
+ project_features=project_features,
+ fixed_positional_encoder=None,
+ relative_positional_encoder=positional_encoder,
+ context_encoder=context_encoder,
+ decoder=decoder,
+ get_alibi_bias=alibi_bias_fn,
+ )
+
+ def convert_padding_mask(self, x, padding_mask):
+ def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ return torch.floor((input_length - kernel_size) / stride + 1)
+
+ for i in range(len(self.feature_enc_layers)):
+ input_lengths = _conv_out_length(
+ input_lengths,
+ self.feature_enc_layers[i][1],
+ self.feature_enc_layers[i][2],
+ )
+
+ return input_lengths.to(torch.long)
+
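+ # e.g., with the default conv spec, 16000 input samples (1 s at 16 kHz)
+ # map to 49 output frames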
+ if padding_mask is not None:
+ input_lengths = (1 - padding_mask.long()).sum(-1)
+ # apply conv formula to get real output_lengths
+ output_lengths = get_feat_extract_output_lengths(input_lengths)
+
+ if padding_mask.any():
+ padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)
+
+ # these two operations make sure that all values
+ # before the output length indices are attended to
+ padding_mask[
+ (
+ torch.arange(padding_mask.shape[0], device=padding_mask.device),
+ output_lengths - 1,
+ )
+ ] = 1
+ padding_mask = (
+ 1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
+ ).bool()
+ else:
+ padding_mask = torch.zeros(
+ x.shape[:2], dtype=torch.bool, device=x.device
+ )
+
+ return padding_mask
+
+ def reset_parameters(self):
+ super().reset_parameters()
+ for mod in self.project_features.children():
+ if isinstance(mod, nn.Linear):
+ mod.reset_parameters()
+ if self.decoder is not None:
+ self.decoder.reset_parameters()
diff --git a/examples/data2vec/models/modalities/base.py b/examples/data2vec/models/modalities/base.py
new file mode 100644
index 000000000..642cc8466
--- /dev/null
+++ b/examples/data2vec/models/modalities/base.py
@@ -0,0 +1,684 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from collections import namedtuple
+from dataclasses import dataclass
+from functools import partial
+from omegaconf import MISSING, II
+from typing import Optional, Callable
+from fairseq.data.data_utils import compute_mask_indices
+from fairseq.modules import GradMultiply
+from fairseq.utils import index_put
+from examples.data2vec.data.modality import Modality
+from .modules import D2vDecoderConfig
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class D2vModalityConfig:
+ type: Modality = MISSING
+ prenet_depth: int = 4
+ prenet_layerdrop: float = 0
+ prenet_dropout: float = 0
+ start_drop_path_rate: float = 0
+ end_drop_path_rate: float = 0
+
+ num_extra_tokens: int = 0
+ init_extra_token_zero: bool = True
+
+ mask_noise_std: float = 0.01
+ mask_prob_min: Optional[float] = None
+ mask_prob: float = 0.7
+ inverse_mask: bool = False
+ mask_prob_adjust: float = 0
+ keep_masked_pct: float = 0
+
+ mask_length: int = 5
+ add_masks: bool = False
+ remove_masks: bool = False
+ mask_dropout: float = 0.0
+ encoder_zero_mask: bool = True
+
+ mask_channel_prob: float = 0.0
+ mask_channel_length: int = 64
+
+ ema_local_encoder: bool = False # used in data2vec_multi
+ local_grad_mult: float = 1.0
+
+ use_alibi_encoder: bool = False
+ alibi_scale: float = 1.0
+ learned_alibi: bool = False
+ alibi_max_pos: Optional[int] = None
+ learned_alibi_scale: bool = False
+ learned_alibi_scale_per_head: bool = False
+ learned_alibi_scale_per_layer: bool = False
+
+ num_alibi_heads: int = II("model.num_heads")
+ model_depth: int = II("model.depth")
+
+ decoder: Optional[D2vDecoderConfig] = D2vDecoderConfig()
+
+
+MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
+MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
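+ # MaskInfo fields: x_unmasked holds only the kept timesteps, mask is the
+ # binary mask (1 = masked), ids_restore undoes the keep/remove shuffle
+ # and ids_keep indexes the kept timesteps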
+
+
+class ModalitySpecificEncoder(nn.Module):
+ def __init__(
+ self,
+ modality_cfg: D2vModalityConfig,
+ embed_dim: int,
+ local_encoder: nn.Module,
+ project_features: nn.Module,
+ fixed_positional_encoder: Optional[nn.Module],
+ relative_positional_encoder: Optional[nn.Module],
+ context_encoder: nn.Module,
+ decoder: nn.Module,
+ get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
+ ):
+ super().__init__()
+
+ self.modality_cfg = modality_cfg
+ self.local_encoder = local_encoder
+ self.project_features = project_features
+ self.fixed_positional_encoder = fixed_positional_encoder
+ self.relative_positional_encoder = relative_positional_encoder
+ self.context_encoder = context_encoder
+
+ self.decoder = decoder
+ self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None
+
+ self.local_grad_mult = self.modality_cfg.local_grad_mult
+
+ self.extra_tokens = None
+ if modality_cfg.num_extra_tokens > 0:
+ self.extra_tokens = nn.Parameter(
+ torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
+ )
+ if not modality_cfg.init_extra_token_zero:
+ nn.init.normal_(self.extra_tokens)
+ elif self.extra_tokens.size(1) > 1:
+ nn.init.normal_(self.extra_tokens[:, 1:])
+
+ self.alibi_scale = None
+ if self.get_alibi_bias is not None:
+ self.alibi_scale = nn.Parameter(
+ torch.full(
+ (
+ (modality_cfg.prenet_depth + modality_cfg.model_depth)
+ if modality_cfg.learned_alibi_scale_per_layer
+ else 1,
+ 1,
+ self.modality_cfg.num_alibi_heads
+ if modality_cfg.learned_alibi_scale_per_head
+ else 1,
+ 1,
+ 1,
+ ),
+ modality_cfg.alibi_scale,
+ dtype=torch.float,
+ ),
+ requires_grad=modality_cfg.learned_alibi_scale,
+ )
+
+ if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
+ assert modality_cfg.alibi_max_pos is not None
+ alibi_bias = self.get_alibi_bias(
+ batch_size=1,
+ time_steps=modality_cfg.alibi_max_pos,
+ heads=modality_cfg.num_alibi_heads,
+ scale=1.0,
+ dtype=torch.float,
+ device="cpu",
+ )
+ self.alibi_bias = nn.Parameter(alibi_bias)
+ self.get_alibi_bias = partial(
+ _learned_alibi_bias, alibi_bias=self.alibi_bias
+ )
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ k = f"{name}.alibi_scale"
+ if k in state_dict and state_dict[k].dim() == 4:
+ state_dict[k] = state_dict[k].unsqueeze(0)
+
+ return state_dict
+
+ def convert_padding_mask(self, x, padding_mask):
+ return padding_mask
+
+ def decoder_input(self, x, mask_info: MaskInfo):
+ inp_drop = self.modality_cfg.decoder.input_dropout
+ if inp_drop > 0:
+ x = F.dropout(x, inp_drop, training=self.training, inplace=True)
+
+ num_extra = self.modality_cfg.num_extra_tokens
+
+ if mask_info is not None:
+ num_masked = mask_info.ids_restore.shape[1] - x.shape[1] + num_extra
+
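+ # unlike MAE's learned mask token, the decoder input fills masked slots
+ # with Gaussian noise of std mask_noise_std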
+ mask_tokens = x.new_empty(
+ x.size(0),
+ num_masked,
+ x.size(-1),
+ ).normal_(0, self.modality_cfg.mask_noise_std)
+
+ x_ = torch.cat([x[:, num_extra:], mask_tokens], dim=1)
+ x = torch.gather(x_, dim=1, index=mask_info.ids_restore)
+
+ if self.modality_cfg.decoder.add_positions_masked:
+ assert self.fixed_positional_encoder is not None
+ pos = self.fixed_positional_encoder(x, None)
+ x = x + (pos * mask_info.mask.unsqueeze(-1))
+ else:
+ x = x[:, num_extra:]
+
+ if self.modality_cfg.decoder.add_positions_all:
+ assert self.fixed_positional_encoder is not None
+ x = x + self.fixed_positional_encoder(x, None)
+
+ return x, mask_info
+
+ def local_features(self, features):
+ if self.local_grad_mult > 0:
+ if self.local_grad_mult == 1.0:
+ x = self.local_encoder(features)
+ else:
+ x = GradMultiply.apply(
+ self.local_encoder(features), self.local_grad_mult
+ )
+ else:
+ with torch.no_grad():
+ x = self.local_encoder(features)
+
+ x = self.project_features(x)
+ return x
+
+ def contextualized_features(
+ self,
+ x,
+ padding_mask,
+ mask,
+ remove_masked,
+ clone_batch: int = 1,
+ mask_seeds: Optional[torch.Tensor] = None,
+ precomputed_mask=None,
+ ):
+
+ if padding_mask is not None:
+ padding_mask = self.convert_padding_mask(x, padding_mask)
+
+ local_features = x
+ if mask and clone_batch == 1:
+ local_features = local_features.clone()
+
+ orig_B, orig_T, _ = x.shape
+ pre_mask_B = orig_B
+ mask_info = None
+
+ x_pos = None
+ if self.fixed_positional_encoder is not None:
+ x = x + self.fixed_positional_encoder(x, padding_mask)
+
+ if mask:
+ if clone_batch > 1:
+ x = x.repeat_interleave(clone_batch, 0)
+ if mask_seeds is not None:
+ clone_hash = [
+ int(hash((mask_seeds.seed, ind)) % 1e10)
+ for ind in range(clone_batch - 1)
+ ]
+ clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)
+
+ id = mask_seeds.ids
+ id = id.repeat_interleave(clone_batch, 0)
+ id = id.view(-1, clone_batch) + clone_hash.to(id)
+ id = id.view(-1)
+ mask_seeds = MaskSeed(
+ seed=mask_seeds.seed, update=mask_seeds.update, ids=id
+ )
+ if padding_mask is not None:
+ padding_mask = padding_mask.repeat_interleave(clone_batch, 0)
+
+ x, mask_info = self.compute_mask(
+ x,
+ padding_mask,
+ mask_seed=mask_seeds,
+ apply=self.relative_positional_encoder is not None or not remove_masked,
+ precomputed_mask=precomputed_mask,
+ )
+
+ if self.relative_positional_encoder is not None:
+ x_pos = self.relative_positional_encoder(x)
+
+ masked_padding_mask = padding_mask
+ if mask and remove_masked:
+ x = mask_info.x_unmasked
+ if x_pos is not None:
+ x = x + gather_unmasked(x_pos, mask_info)
+
+ if padding_mask is not None and padding_mask.any():
+ masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info)
+ if not masked_padding_mask.any():
+ masked_padding_mask = None
+ else:
+ masked_padding_mask = None
+
+ elif x_pos is not None:
+ x = x + x_pos
+
+ alibi_bias = None
+ alibi_scale = self.alibi_scale
+
+ if self.get_alibi_bias is not None:
+ alibi_bias = self.get_alibi_bias(
+ batch_size=pre_mask_B,
+ time_steps=orig_T,
+ heads=self.modality_cfg.num_alibi_heads,
+ dtype=torch.float32,
+ device=x.device,
+ )
+
+ if alibi_scale is not None:
+ alibi_scale = alibi_scale.clamp_min(0)
+ if alibi_scale.size(0) == 1:
+ alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
+ alibi_scale = None
+
+ if clone_batch > 1:
+ alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)
+
+ if mask_info is not None and remove_masked:
+ alibi_bias = masked_alibi(alibi_bias, mask_info)
+
+ if self.extra_tokens is not None:
+ num = self.extra_tokens.size(1)
+ x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
+ if masked_padding_mask is not None:
+ # B x T
+ masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
+ if alibi_bias is not None:
+ # B x H x T x T
+ alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))
+
+ x = self.context_encoder(
+ x,
+ masked_padding_mask,
+ alibi_bias,
+ alibi_scale[: self.modality_cfg.prenet_depth]
+ if alibi_scale is not None
+ else None,
+ )
+
+ return {
+ "x": x,
+ "local_features": local_features,
+ "padding_mask": masked_padding_mask,
+ "alibi_bias": alibi_bias,
+ "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
+ if alibi_scale is not None and alibi_scale.size(0) > 1
+ else alibi_scale,
+ "encoder_mask": mask_info,
+ }
+
+ def forward(
+ self,
+ features,
+ padding_mask,
+ mask: bool,
+ remove_masked: bool,
+ clone_batch: int = 1,
+ mask_seeds: Optional[torch.Tensor] = None,
+ precomputed_mask=None,
+ ):
+ x = self.local_features(features)
+ return self.contextualized_features(
+ x,
+ padding_mask,
+ mask,
+ remove_masked,
+ clone_batch,
+ mask_seeds,
+ precomputed_mask,
+ )
+
+ def reset_parameters(self):
+ pass
+
+ def compute_mask(
+ self,
+ x,
+ padding_mask,
+ mask_seed: Optional[MaskSeed],
+ apply,
+ precomputed_mask,
+ ):
+ if precomputed_mask is not None:
+ mask = precomputed_mask
+ mask_info = self.make_maskinfo(x, mask)
+ else:
+ B, T, C = x.shape
+ cfg = self.modality_cfg
+
+ mask_prob = cfg.mask_prob
+
+ if (
+ cfg.mask_prob_min is not None
+ and cfg.mask_prob_min >= 0
+ and cfg.mask_prob_min < mask_prob
+ ):
+ mask_prob = np.random.uniform(cfg.mask_prob_min, mask_prob)
+
+ if mask_prob > 0:
+ if cfg.mask_length == 1:
+ mask_info = random_masking(x, mask_prob, mask_seed)
+ else:
+ if self.modality_cfg.inverse_mask:
+ mask_prob = 1 - mask_prob
+
+ mask = compute_mask_indices(
+ (B, T),
+ padding_mask,
+ mask_prob,
+ cfg.mask_length,
+ min_masks=1,
+ require_same_masks=True,
+ mask_dropout=cfg.mask_dropout,
+ add_masks=cfg.add_masks,
+ seed=mask_seed.seed if mask_seed is not None else None,
+ epoch=mask_seed.update if mask_seed is not None else None,
+ indices=mask_seed.ids if mask_seed is not None else None,
+ )
+
+ mask = torch.from_numpy(mask).to(device=x.device)
+ if self.modality_cfg.inverse_mask:
+ mask = 1 - mask
+ mask_info = self.make_maskinfo(x, mask)
+ else:
+ mask_info = None
+
+ if apply:
+ x = self.apply_mask(x, mask_info)
+
+ return x, mask_info
+
+ def make_maskinfo(self, x, mask, shape=None):
+ if shape is None:
+ B, T, D = x.shape
+ else:
+ B, T, D = shape
+
+ mask = mask.to(torch.uint8)
+ ids_shuffle = mask.argsort(dim=1)
+ ids_restore = ids_shuffle.argsort(dim=1).unsqueeze(-1).expand(-1, -1, D)
+
+ len_keep = T - mask[0].sum()
+ if self.modality_cfg.keep_masked_pct > 0:
+ len_keep += round((T - int(len_keep)) * self.modality_cfg.keep_masked_pct)
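+ # keep_masked_pct > 0 keeps a fraction of the masked timesteps in the
+ # encoder input in addition to the unmasked ones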
+
+ ids_keep = ids_shuffle[:, :len_keep]
+
+ if shape is not None:
+ x_unmasked = None
+ else:
+ ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
+ x_unmasked = torch.gather(x, dim=1, index=ids_keep)
+
+ mask_info = MaskInfo(
+ x_unmasked=x_unmasked,
+ mask=mask,
+ ids_restore=ids_restore,
+ ids_keep=ids_keep,
+ )
+ return mask_info
+
+ def apply_mask(self, x, mask_info):
+ cfg = self.modality_cfg
+ B, T, C = x.shape
+
+ if mask_info is not None:
+ mask = mask_info.mask
+ if cfg.encoder_zero_mask:
+ x = x * (1 - mask.type_as(x).unsqueeze(-1))
+ else:
+ num_masks = mask.sum().item()
+ masks = x.new_empty(num_masks, x.size(-1)).normal_(
+ 0, cfg.mask_noise_std
+ )
+ x = index_put(x, mask, masks)
+ if cfg.mask_channel_prob > 0:
+ mask_channel = compute_mask_indices(
+ (B, C),
+ None,
+ cfg.mask_channel_prob,
+ cfg.mask_channel_length,
+ )
+ mask_channel = (
+ torch.from_numpy(mask_channel)
+ .to(x.device)
+ .unsqueeze(1)
+ .expand(-1, T, -1)
+ )
+ x = index_put(x, mask_channel, 0)
+ return x
+
+ def remove_pretraining_modules(self, keep_decoder=False):
+ if not keep_decoder:
+ self.decoder = None
+
+
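+ # linearly interpolates from `start` at step 0 to `end` at total_steps;
+ # data2vec uses this to anneal the teacher's EMA decay rate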
+def get_annealed_rate(start, end, curr_step, total_steps):
+ if curr_step >= total_steps:
+ return end
+ r = end - start
+ pct_remaining = 1 - curr_step / total_steps
+ return end - r * pct_remaining
+
+
+# adapted from MAE
+def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]):
+ N, L, D = x.shape # batch, length, dim
+ len_keep = int(L * (1 - mask_ratio))
+
+ generator = None
+ if mask_seed is not None:
+ seed = int(
+ hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6
+ )
+ generator = torch.Generator(device=x.device)
+ generator.manual_seed(seed)
+
+ noise = torch.rand(N, L, generator=generator, device=x.device) # noise in [0, 1]
+
+ # sort noise for each sample
+ ids_shuffle = noise.argsort(dim=1) # ascend: small is keep, large is remove
+ ids_restore = ids_shuffle.argsort(dim=1)
+
+ # keep the first subset
+ ids_keep = ids_shuffle[:, :len_keep]
+ ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
+ x_unmasked = torch.gather(x, dim=1, index=ids_keep)
+
+ # generate the binary mask: 0 is keep, 1 is remove
+ mask = torch.ones([N, L], dtype=x.dtype, device=x.device)
+ mask[:, :len_keep] = 0
+ # unshuffle to get the binary mask
+ mask = torch.gather(mask, dim=1, index=ids_restore)
+
+ ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D)
+
+ return MaskInfo(
+ x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep
+ )
+
+
+def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
+ return torch.gather(
+ x,
+ dim=1,
+ index=mask_info.ids_keep,
+ )
+
+
+def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
+ return torch.gather(
+ x,
+ dim=1,
+ index=mask_info.ids_keep[..., 0], # ignore the feature dimension
+ )
+
+
+def get_alibi(
+ max_positions: int,
+ attention_heads: int,
+ dims: int = 1,
+ distance: str = "manhattan",
+):
+ def get_slopes(n):
+ def get_slopes_power_of_2(n):
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+ ratio = start
+ return [start * ratio**i for i in range(n)]
+
+ # In the paper, we only train models that have 2^a heads for some
+ # a. This function has some good properties that only occur when
+ # the input is a power of 2. To maintain that even when the number
+ # of heads is not a power of 2, we use this workaround.
+ if math.log2(n).is_integer():
+ return get_slopes_power_of_2(n)
+ else:
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
+ return (
+ get_slopes_power_of_2(closest_power_of_2)
+ + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+ )
+
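+ # e.g., for 8 heads get_slopes returns 2^-1, 2^-2, ..., 2^-8: every head
+ # applies the same linear distance penalty at a different geometric rate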
+ maxpos = max_positions
+ attn_heads = attention_heads
+ slopes = torch.Tensor(get_slopes(attn_heads))
+
+ if dims == 1:
+ # prepare the alibi linear position bias. Note that wav2vec2 is a
+ # non-autoregressive model, so we want a symmetric bias with 0 on the
+ # diagonal and otherwise linearly decreasing values
+ pos_bias = (
+ torch.abs(
+ torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
+ )
+ * -1
+ )
+ elif dims == 2:
+ if distance == "manhattan":
+ df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
+ elif distance == "euclidean":
+ df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
+
+ n = math.sqrt(max_positions)
+ assert n.is_integer(), n
+ n = int(n)
+
+ pos_bias = torch.zeros((max_positions, max_positions))
+
+ for i in range(n):
+ for j in range(n):
+ for k in range(n):
+ for l in range(n):
+ new_x = i * n + j
+ new_y = k * n + l
+ pos_bias[new_x, new_y] = -df(i, j, k, l)
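+ # e.g., on a 14x14 patch grid (max_positions = 196) the bias between
+ # flattened cells (i, j) and (k, l) is -(|i-k| + |j-l|) under the
+ # default manhattan distance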
+
+ else:
+ raise Exception(f"unsupported number of alibi dims: {dims}")
+
+ alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
+ attn_heads, -1, -1
+ )
+
+ return alibi_bias
+
+
+def get_alibi_bias(
+ alibi_biases,
+ batch_size,
+ time_steps,
+ heads,
+ dtype,
+ device,
+ dims=1,
+ distance="manhattan",
+):
+ cache_key = f"{dims}_{heads}_{distance}"
+
+ buffered = alibi_biases.get(cache_key, None)
+
+ target_size = heads * batch_size
+ if (
+ buffered is None
+ or buffered.size(0) < target_size
+ or buffered.size(1) < time_steps
+ or buffered.dtype != dtype
+ or buffered.device != device
+ ):
+ bt = max(time_steps, buffered.size(1) if buffered is not None else 0)
+ bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads
+
+ buffered = (
+ get_alibi(bt, heads, dims=dims, distance=distance)
+ .to(dtype=dtype, device=device)
+ .repeat(bn, 1, 1)
+ )
+
+ alibi_biases[cache_key] = buffered
+
+ b = buffered[:target_size, :time_steps, :time_steps]
+ b = b.view(batch_size, heads, time_steps, time_steps)
+ return b
+
+
+def _learned_alibi_bias(
+ alibi_bias,
+ batch_size,
+ time_steps,
+ heads,
+ scale,
+ dtype,
+ device,
+):
+ assert alibi_bias.size(1) == heads, alibi_bias.shape
+ assert alibi_bias.dtype == dtype, alibi_bias.dtype
+ assert alibi_bias.device == device, alibi_bias.device
+
+ if alibi_bias.size(-1) < time_steps:
+ psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
+ alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")
+
+ alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
+ return alibi_bias[..., :time_steps, :time_steps]
+
+
+def masked_alibi(alibi_bias, mask_info):
+ H = alibi_bias.size(1)
+
+ orig_bias = alibi_bias
+
+ index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)
+ alibi_bias = torch.gather(
+ orig_bias,
+ dim=-2,
+ index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)),
+ )
+ alibi_bias = torch.gather(
+ alibi_bias,
+ dim=-1,
+ index=index.transpose(-1, -2).expand(-1, H, alibi_bias.size(-2), -1),
+ )
+
+ return alibi_bias
diff --git a/examples/data2vec/models/modalities/images.py b/examples/data2vec/models/modalities/images.py
new file mode 100644
index 000000000..a6b738cb0
--- /dev/null
+++ b/examples/data2vec/models/modalities/images.py
@@ -0,0 +1,256 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from functools import partial
+from dataclasses import dataclass
+from typing import Callable, Dict, Optional
+from timm.models.layers import to_2tuple
+from fairseq.tasks import FairseqTask
+from examples.data2vec.models.mae import get_2d_sincos_pos_embed, PatchEmbed
+from .base import (
+ D2vModalityConfig,
+ ModalitySpecificEncoder,
+ get_alibi_bias,
+ MaskSeed,
+)
+from .modules import (
+ BlockEncoder,
+ Decoder2d,
+ FixedPositionalEncoder,
+ TransformerDecoder,
+ EncDecTransformerDecoder,
+)
+from examples.data2vec.data.modality import Modality
+
+
+@dataclass
+class D2vImageConfig(D2vModalityConfig):
+ type: Modality = Modality.IMAGE
+
+ input_size: int = 224
+ in_chans: int = 3
+ patch_size: int = 16
+ embed_dim: int = 768
+
+ alibi_dims: int = 2
+ alibi_distance: str = "manhattan"
+
+ fixed_positions: bool = True
+
+ transformer_decoder: bool = False
+ enc_dec_transformer: bool = False
+
+
+class ImageEncoder(ModalitySpecificEncoder):
+
+ modality_cfg: D2vImageConfig
+
+ def __init__(
+ self,
+ modality_cfg: D2vImageConfig,
+ embed_dim: int,
+ make_block: Callable[[float, Optional[int], Optional[int]], nn.ModuleList],
+ norm_layer: Callable[[int], nn.LayerNorm],
+ layer_norm_first: bool,
+ alibi_biases: Dict,
+ task: Optional[FairseqTask],
+ ):
+
+ img_size = to_2tuple(modality_cfg.input_size)
+ patch_size = to_2tuple(modality_cfg.patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+
+ local_encoder = PatchEmbed(
+ modality_cfg.input_size,
+ modality_cfg.patch_size,
+ modality_cfg.in_chans,
+ modality_cfg.embed_dim,
+ )
+
+ w = local_encoder.proj.weight.data
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+ if modality_cfg.embed_dim != embed_dim:
+ local_encoder = nn.Sequential(
+ local_encoder,
+ nn.Linear(modality_cfg.embed_dim, embed_dim),
+ )
+
+ project_features = nn.Identity()
+
+ pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches, embed_dim), requires_grad=False
+ )
+
+ side_n = int(num_patches ** 0.5)
+
+ emb = get_2d_sincos_pos_embed(
+ pos_embed.shape[-1],
+ side_n,
+ cls_token=False,
+ )
+ pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0))
+ fixed_positional_encoder = (
+ FixedPositionalEncoder(pos_embed) if modality_cfg.fixed_positions else None
+ )
+
+ dpr = np.linspace(
+ modality_cfg.start_drop_path_rate,
+ modality_cfg.end_drop_path_rate,
+ modality_cfg.prenet_depth,
+ )
+
+ context_encoder = BlockEncoder(
+ nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
+ norm_layer(embed_dim) if not layer_norm_first else None,
+ layer_norm_first,
+ modality_cfg.prenet_layerdrop,
+ modality_cfg.prenet_dropout,
+ )
+
+ if modality_cfg.transformer_decoder:
+ if modality_cfg.enc_dec_transformer:
+ decoder = EncDecTransformerDecoder(modality_cfg.decoder, embed_dim)
+ else:
+ dec_enc = BlockEncoder(
+ nn.ModuleList(
+ make_block(0, modality_cfg.decoder.decoder_dim, 8)
+ for _ in range(modality_cfg.decoder.decoder_layers)
+ ),
+ None,
+ layer_norm_first,
+ 0,
+ 0,
+ )
+ decoder = TransformerDecoder(modality_cfg.decoder, embed_dim, dec_enc)
+ else:
+ decoder = (
+ Decoder2d(modality_cfg.decoder, embed_dim, side_n, side_n)
+ if modality_cfg.decoder is not None
+ else None
+ )
+
+ alibi_bias_fn = partial(
+ get_alibi_bias,
+ alibi_biases=alibi_biases,
+ heads=modality_cfg.num_alibi_heads,
+ dims=modality_cfg.alibi_dims,
+ distance=modality_cfg.alibi_distance,
+ )
+
+ super().__init__(
+ modality_cfg=modality_cfg,
+ embed_dim=embed_dim,
+ local_encoder=local_encoder,
+ project_features=project_features,
+ fixed_positional_encoder=fixed_positional_encoder,
+ relative_positional_encoder=None,
+ context_encoder=context_encoder,
+ decoder=decoder,
+ get_alibi_bias=alibi_bias_fn,
+ )
+
+ def reset_parameters(self):
+ super().reset_parameters()
+ if self.decoder is not None:
+ self.decoder.reset_parameters()
+
+ @torch.no_grad()
+ def patchify(self, imgs):
+ """
+ imgs: (N, 3, H, W)
+ x: (N, L, patch_size**2 *3)
+ """
+ p = self.modality_cfg.patch_size
+ h = w = imgs.shape[2] // p
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+ x = torch.einsum("nchpwq->nhwpqc", x)
+ x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
+
+ return x
+
+ @torch.no_grad()
+ def unpatchify(self, x):
+ """
+ x: (N, L, patch_size**2 *3)
+ imgs: (N, 3, H, W)
+ """
+ p = self.modality_cfg.patch_size
+ h = w = int(x.shape[1] ** 0.5)
+ assert h * w == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+ x = torch.einsum("nhwpqc->nchpwq", x)
+ imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
+ return imgs
+
+ def compute_mask(
+ self,
+ x,
+ padding_mask,
+ mask_seed: Optional[MaskSeed],
+ apply,
+ shape=None,
+ precomputed_mask=None,
+ ):
+ mlen = self.modality_cfg.mask_length
+ if mlen <= 1:
+ return super().compute_mask(
+ x, padding_mask, mask_seed, apply, precomputed_mask
+ )
+
+ if precomputed_mask is not None:
+ mask = precomputed_mask
+ else:
+ from fairseq.data.data_utils import compute_block_mask_2d
+
+ if shape is not None:
+ B, L, D = shape
+ else:
+ B, L, D = x.shape
+
+ mask = compute_block_mask_2d(
+ shape=(B, L),
+ mask_prob=self.modality_cfg.mask_prob,
+ mask_length=self.modality_cfg.mask_length,
+ mask_prob_adjust=self.modality_cfg.mask_prob_adjust,
+ inverse_mask=self.modality_cfg.inverse_mask,
+ require_same_masks=True,
+ mask_dropout=self.modality_cfg.mask_dropout,
+ )
+
+ mask_info = self.make_maskinfo(x, mask, shape)
+ if apply:
+ x = self.apply_mask(x, mask_info)
+
+ return x, mask_info
+
+ def decoder_input(self, x, mask_info):
+ if (
+ not self.modality_cfg.transformer_decoder
+ or not self.modality_cfg.enc_dec_transformer
+ ):
+ return super().decoder_input(x, mask_info)
+
+ inp_drop = self.modality_cfg.decoder.input_dropout
+ if inp_drop > 0:
+ x = F.dropout(x, inp_drop, training=self.training, inplace=True)
+
+ kv = x[:, self.modality_cfg.num_extra_tokens :]
+
+ assert self.fixed_positional_encoder is not None
+ pos = self.fixed_positional_encoder(x, None).expand(x.size(0), -1, -1)
+
+ mask = mask_info.mask.bool()
+ if self.modality_cfg.decoder.add_positions_all:
+ kv = kv + pos[~mask].view(kv.shape)
+
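+ # enc-dec decoding: the queries are positional embeddings of the masked
+ # patches, while keys/values come from the encoded visible patches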
+ q = pos[mask].view(x.size(0), -1, x.size(-1))
+
+ return q, kv
diff --git a/examples/data2vec/models/modalities/modules.py b/examples/data2vec/models/modalities/modules.py
new file mode 100644
index 000000000..a4e1a4ea0
--- /dev/null
+++ b/examples/data2vec/models/modalities/modules.py
@@ -0,0 +1,589 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from dataclasses import dataclass
+from fairseq.modules import (
+ LayerNorm,
+ SamePad,
+ SamePad2d,
+ TransposeLast,
+)
+
+
+@dataclass
+class D2vDecoderConfig:
+ decoder_dim: int = 384
+ decoder_groups: int = 16
+ decoder_kernel: int = 5
+ decoder_layers: int = 5
+ input_dropout: float = 0.1
+
+ add_positions_masked: bool = False
+ add_positions_all: bool = False
+
+ decoder_residual: bool = True
+ projection_layers: int = 1
+ projection_ratio: float = 2.0
+
+
+class FixedPositionalEncoder(nn.Module):
+ def __init__(self, pos_embed):
+ super().__init__()
+ self.positions = pos_embed
+
+ def forward(self, x, padding_mask):
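+ # return the full precomputed table (1, T, D); broadcasting over the
+ # batch happens at the call site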
+ return self.positions
+
+
+class TextFeatPositionalEncoder(nn.Module):
+ """
+ The original encoder expects (B, T) long input. This module wraps it to
+ accept the local_encoder output, which is a (B, T, D) float tensor.
+ """
+
+ def __init__(self, pos_encoder):
+ super().__init__()
+ self.pos_encoder = pos_encoder
+
+ def forward(self, x, padding_mask):
+ # assume padded token embeddings are 0s
+ # TODO: consider using padding_mask as input
+ return self.pos_encoder(x[..., 0])
+
+
+class BlockEncoder(nn.Module):
+ def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout):
+ super().__init__()
+ self.blocks = blocks
+ self.norm = norm_layer
+ self.layer_norm_first = layer_norm_first
+ self.layerdrop = layerdrop
+ self.dropout = nn.Dropout(dropout, inplace=True)
+
+ def forward(self, x, padding_mask, alibi_bias, alibi_scale):
+ if self.norm is not None and not self.layer_norm_first:
+ x = self.norm(x)
+
+ x = self.dropout(x)
+
+ for i, blk in enumerate(self.blocks):
+ if (
+ not self.training
+ or self.layerdrop == 0
+ or (np.random.random() > self.layerdrop)
+ ):
+ ab = alibi_bias
+ if ab is not None and alibi_scale is not None:
+ scale = (
+ alibi_scale[i]
+ if alibi_scale.size(0) > 1
+ else alibi_scale.squeeze(0)
+ )
+ ab = ab * scale.type_as(ab)
+ x, _ = blk(x, padding_mask, ab)
+
+ if self.norm is not None and self.layer_norm_first:
+ x = self.norm(x)
+
+ return x
+
+
+class DecoderBase(nn.Module):
+ decoder_cfg: D2vDecoderConfig
+
+ def __init__(self, cfg: D2vDecoderConfig):
+ super().__init__()
+
+ self.decoder_cfg = cfg
+
+ def reset_parameters(self):
+ for mod in self.proj.modules():
+ if isinstance(mod, nn.Linear):
+ mod.reset_parameters()
+
+ def add_residual(self, x, residual, i, mask_info):
+ if (
+ residual is None
+ or not self.decoder_cfg.decoder_residual
+ or residual.size(1) != x.size(1)
+ ):
+ return x
+
+ ret = x + residual
+
+ return ret
+
+
+class Decoder1d(DecoderBase):
+ def __init__(self, cfg: D2vDecoderConfig, input_dim):
+ super().__init__(cfg)
+
+ def make_block(in_dim):
+ block = [
+ nn.Conv1d(
+ in_dim,
+ cfg.decoder_dim,
+ kernel_size=cfg.decoder_kernel,
+ padding=cfg.decoder_kernel // 2,
+ groups=cfg.decoder_groups,
+ ),
+ SamePad(cfg.decoder_kernel),
+ TransposeLast(),
+ LayerNorm(cfg.decoder_dim, elementwise_affine=False),
+ TransposeLast(),
+ nn.GELU(),
+ ]
+
+ return nn.Sequential(*block)
+
+ self.blocks = nn.Sequential(
+ *[
+ make_block(input_dim if i == 0 else cfg.decoder_dim)
+ for i in range(cfg.decoder_layers)
+ ]
+ )
+
+ projs = []
+ curr_dim = cfg.decoder_dim
+ for i in range(cfg.projection_layers - 1):
+ next_dim = int(curr_dim * cfg.projection_ratio) if i == 0 else curr_dim
+ projs.append(nn.Linear(curr_dim, next_dim))
+ projs.append(nn.GELU())
+ curr_dim = next_dim
+ projs.append(nn.Linear(curr_dim, input_dim))
+ if len(projs) == 1:
+ self.proj = projs[0]
+ else:
+ self.proj = nn.Sequential(*projs)
+
+ def forward(self, x, mask_info):
+
+ x = x.transpose(1, 2)
+
+ residual = x
+
+ for i, layer in enumerate(self.blocks):
+ x = layer(x)
+ x = self.add_residual(x, residual, i, mask_info)
+ residual = x
+
+ x = x.transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+
+class Decoder2d(DecoderBase):
+ def __init__(self, cfg: D2vDecoderConfig, input_dim, h_size, w_size):
+ super().__init__(cfg)
+
+ self.h_size = h_size
+ self.w_size = w_size
+
+ def make_block(in_dim):
+ block = [
+ nn.Conv2d(
+ in_dim,
+ cfg.decoder_dim,
+ kernel_size=cfg.decoder_kernel,
+ padding=cfg.decoder_kernel // 2,
+ groups=cfg.decoder_groups,
+ ),
+ SamePad2d(cfg.decoder_kernel),
+ TransposeLast(tranpose_dim=-3),
+ LayerNorm(cfg.decoder_dim, elementwise_affine=False),
+ TransposeLast(tranpose_dim=-3),
+ nn.GELU(),
+ ]
+
+ return nn.Sequential(*block)
+
+ self.blocks = nn.Sequential(
+ *[
+ make_block(input_dim if i == 0 else cfg.decoder_dim)
+ for i in range(cfg.decoder_layers)
+ ]
+ )
+
+ self.proj = nn.Linear(cfg.decoder_dim, input_dim)
+
+ def forward(self, x, mask_info):
+ B, T, C = x.shape
+
+ x = x.transpose(1, 2).reshape(B, C, self.h_size, self.w_size)
+
+ residual = x
+
+ for i, layer in enumerate(self.blocks):
+ x = layer(x)
+ x = self.add_residual(x, residual, i, mask_info)
+ residual = x
+
+ x = x.reshape(B, -1, T).transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+
+class TransformerDecoder(nn.Module):
+ decoder_cfg: D2vDecoderConfig
+
+ def __init__(self, cfg: D2vDecoderConfig, input_dim, encoder):
+ super().__init__()
+
+ self.decoder_cfg = cfg
+
+ self.input_proj = nn.Linear(input_dim, cfg.decoder_dim)
+
+ self.encoder = encoder
+
+ self.proj = nn.Linear(cfg.decoder_dim, input_dim)
+
+ def reset_parameters(self):
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+ self.apply(init_bert_params)
+
+ def forward(self, x, mask_info):
+ x = self.input_proj(x)
+ x = self.encoder(x, None, None, 1)
+ x = self.proj(x)
+ return x
+
+
+class AltBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ mlp_drop=0.0,
+ post_mlp_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ layer_norm_first=True,
+ ffn_targets=False,
+ cosine_attention=False,
+ ):
+ super().__init__()
+
+ self.layer_norm_first = layer_norm_first
+ self.ffn_targets = ffn_targets
+
+ from timm.models.vision_transformer import DropPath, Mlp
+
+ self.norm1 = norm_layer(dim)
+ self.attn = AltAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ cosine_attention=cosine_attention,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=mlp_drop,
+ )
+ self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
+
+ def forward(self, x, padding_mask=None, alibi_bias=None):
+ if self.layer_norm_first:
+ x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
+ r = x = self.mlp(self.norm2(x))
+ t = x
+ x = r + self.drop_path(self.post_mlp_dropout(x))
+ if not self.ffn_targets:
+ t = x
+ else:
+ x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
+ r = x = self.norm1(x)
+ x = self.mlp(x)
+ t = x
+ x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
+ if not self.ffn_targets:
+ t = x
+
+ return x, t
+
+
+class AltAttention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ cosine_attention=False,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.cosine_attention = cosine_attention
+
+ if cosine_attention:
+ self.logit_scale = nn.Parameter(
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
+ )
+
+ def forward(self, x, padding_mask=None, alibi_bias=None):
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4) # qkv x B x H x L x D
+ )
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ dtype = q.dtype
+
+ if self.cosine_attention:
+ # cosine attention
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
+ logit_scale = torch.clamp(
+ self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
+ ).exp()
+ attn = attn * logit_scale
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ if alibi_bias is not None:
+ attn = attn.type_as(alibi_bias)
+ attn[:, : alibi_bias.size(1)] += alibi_bias
+
+ if padding_mask is not None and padding_mask.any():
+ attn = attn.masked_fill(
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+ float("-inf"),
+ )
+
+ attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2) # B x N x H x (C // H)
+ x = x.reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class EncDecAttention(nn.Module):
+ def __init__(
+ self,
+ q_dim,
+ kv_dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ cosine_attention=False,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = q_dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.q_proj = nn.Linear(q_dim, q_dim, bias=qkv_bias)
+ self.kv_proj = nn.Linear(kv_dim, 2 * q_dim, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(q_dim, q_dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.cosine_attention = cosine_attention
+
+ if cosine_attention:
+ self.logit_scale = nn.Parameter(
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
+ )
+
+ def forward(self, q, kv, padding_mask=None, alibi_bias=None):
+ B, N, C = q.shape
+
+ q = (
+ self.q_proj(q)
+ .reshape(B, N, self.num_heads, C // self.num_heads)
+ .permute(0, 2, 1, 3)
+ ) # B x H x L x D
+ kv = (
+ self.kv_proj(kv)
+ .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ ) # kv x B x H x L x D
+ k, v = (
+ kv[0],
+ kv[1],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ dtype = q.dtype
+
+ if self.cosine_attention:
+ # cosine attention
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
+ logit_scale = torch.clamp(
+ self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
+ ).exp()
+ attn = attn * logit_scale
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ if alibi_bias is not None:
+ attn = attn.type_as(alibi_bias)
+ attn[:, : alibi_bias.size(1)] += alibi_bias
+
+ if padding_mask is not None and padding_mask.any():
+ attn = attn.masked_fill(
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+ float("-inf"),
+ )
+
+ attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2) # B x H x L x D -> B x L x H x D
+ x = x.reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class EncDecBlock(nn.Module):
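+ """Transformer block pairing cross-attention (EncDecAttention) with an MLP; supports pre-norm and post-norm variants."""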
+ def __init__(
+ self,
+ q_dim,
+ kv_dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ mlp_drop=0.0,
+ post_mlp_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ layer_norm_first=True,
+ cosine_attention=False,
+ first_residual=True,
+ ):
+ super().__init__()
+
+ self.layer_norm_first = layer_norm_first
+
+ from timm.models.vision_transformer import DropPath, Mlp
+
+ self.norm1 = norm_layer(q_dim)
+ self.attn = EncDecAttention(
+ q_dim,
+ kv_dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ cosine_attention=cosine_attention,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(q_dim)
+ mlp_hidden_dim = int(q_dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=q_dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=mlp_drop,
+ )
+ self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
+ self.first_residual = first_residual
+
+ def forward(self, q, kv, padding_mask=None, alibi_bias=None):
+ r = q if self.first_residual else 0
+ if self.layer_norm_first:
+ x = r + self.drop_path(
+ self.attn(self.norm1(q), kv, padding_mask, alibi_bias)
+ )
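+ # note: the next line rebinds r to the MLP output, so the residual added
+ # afterwards is the MLP output itself rather than the attention output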
+ r = x = self.mlp(self.norm2(x))
+ x = r + self.drop_path(self.post_mlp_dropout(x))
+ else:
+ x = r + self.drop_path(self.attn(q, kv, padding_mask, alibi_bias))
+ r = x = self.norm1(x)
+ x = self.mlp(x)
+ x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
+
+ return x
+
+
+class EncDecTransformerDecoder(nn.Module):
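+ """Transformer decoder whose blocks cross-attend from the projected input queries to the encoder features kv; the first block omits the initial residual."""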
+ def __init__(self, cfg: D2vDecoderConfig, input_dim):
+ super().__init__()
+
+ self.input_proj = nn.Linear(input_dim, cfg.decoder_dim)
+
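+ # nn.Sequential is used only as a module container; forward() iterates
+ # the blocks manually because each block takes (q, kv)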
+ self.blocks = nn.Sequential(
+ *[
+ EncDecBlock(
+ q_dim=cfg.decoder_dim,
+ kv_dim=input_dim,
+ num_heads=8,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ mlp_drop=0.0,
+ post_mlp_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ layer_norm_first=False,
+ cosine_attention=False,
+ first_residual=i > 0,
+ )
+ for i in range(cfg.decoder_layers)
+ ]
+ )
+
+ self.proj = nn.Linear(cfg.decoder_dim, input_dim)
+
+ def reset_parameters(self):
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
+
+ self.apply(init_bert_params)
+
+ def forward(self, x, kv):
+ x = self.input_proj(x)
+ for layer in self.blocks:
+ x = layer(x, kv)
+
+ x = self.proj(x)
+ return x
diff --git a/examples/data2vec/models/modalities/text.py b/examples/data2vec/models/modalities/text.py
new file mode 100644
index 000000000..adfac1ca4
--- /dev/null
+++ b/examples/data2vec/models/modalities/text.py
@@ -0,0 +1,161 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, Dict, Optional
+
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from fairseq.modules import PositionalEmbedding, FairseqDropout, LayerNorm
+from fairseq.tasks import FairseqTask
+from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
+from .modules import BlockEncoder, Decoder1d
+from examples.data2vec.data.modality import Modality
+
+
+@dataclass
+class D2vTextConfig(D2vModalityConfig):
+ type: Modality = Modality.TEXT
+ max_source_positions: int = 512
+ learned_pos: bool = True
+ dropout: float = 0.1 # used for both the local encoder and the contextualized encoder; tied to the global transformer dropout in data2vec_text
+
+ no_scale_embedding: bool = True
+ layernorm_embedding: bool = True
+ no_token_positional_embeddings: bool = False
+
+
+class TextEncoder(ModalitySpecificEncoder):
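+ """Text modality encoder: TextLocalEncoder token embeddings, a transformer prenet (BlockEncoder), ALiBi attention biases, and an optional Decoder1d."""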
+
+ modality_cfg: D2vTextConfig
+
+ def __init__(
+ self,
+ modality_cfg: D2vTextConfig,
+ embed_dim: int,
+ make_block: Callable[[float], nn.ModuleList],
+ norm_layer: Callable[[int], nn.LayerNorm],
+ layer_norm_first: bool,
+ alibi_biases: Dict,
+ task: Optional[FairseqTask],
+ ):
+ self.pad_idx = task.source_dictionary.pad()
+ self.vocab_size = len(task.source_dictionary)
+
+ local_encoder = TextLocalEncoder(
+ vocab_size=self.vocab_size,
+ embed_dim=embed_dim,
+ max_source_positions=modality_cfg.max_source_positions,
+ pad_idx=self.pad_idx,
+ no_scale_embedding=modality_cfg.no_scale_embedding,
+ layernorm_embedding=modality_cfg.layernorm_embedding,
+ dropout=modality_cfg.dropout,
+ no_token_positional_embeddings=modality_cfg.no_token_positional_embeddings,
+ learned_pos=modality_cfg.learned_pos,
+ )
+ dpr = np.linspace(
+ modality_cfg.start_drop_path_rate,
+ modality_cfg.end_drop_path_rate,
+ modality_cfg.prenet_depth,
+ )
+ context_encoder = BlockEncoder(
+ nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
+ norm_layer(embed_dim)
+ if not layer_norm_first and modality_cfg.prenet_depth > 0
+ else None,
+ layer_norm_first,
+ modality_cfg.prenet_layerdrop,
+ modality_cfg.prenet_dropout if modality_cfg.prenet_depth > 0 else 0.0,
+ )
+ decoder = (
+ Decoder1d(modality_cfg.decoder, embed_dim)
+ if modality_cfg.decoder is not None
+ else None
+ )
+
+ alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
+
+ super().__init__(
+ modality_cfg=modality_cfg,
+ embed_dim=embed_dim,
+ local_encoder=local_encoder,
+ project_features=nn.Identity(),
+ fixed_positional_encoder=None,
+ relative_positional_encoder=None,
+ context_encoder=context_encoder,
+ decoder=decoder,
+ get_alibi_bias=alibi_bias_fn,
+ )
+
+ def reset_parameters(self):
+ super().reset_parameters()
+
+ def convert_padding_mask(self, x, padding_mask):
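+ # pool the token-level padding mask down to the encoder's output length;
+ # a downsampled position counts as padding only if every token in its window is padding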
+ if padding_mask is None or padding_mask.size(1) == x.size(1):
+ return padding_mask
+
+ diff = self.downsample - padding_mask.size(1) % self.downsample
+ if 0 < diff < self.downsample:
+ padding_mask = F.pad(padding_mask, (0, diff), value=True)
+
+ padding_mask = padding_mask.view(padding_mask.size(0), -1, self.downsample)
+ padding_mask = padding_mask.all(-1)
+ if padding_mask.size(1) > x.size(1):
+ padding_mask = padding_mask[:, : x.size(1)]
+
+ assert x.size(1) == padding_mask.size(
+ 1
+ ), f"{x.size(1), padding_mask.size(1), diff, self.downsample}"
+
+ return padding_mask
+
+
+class TextLocalEncoder(nn.Module):
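+ """Embeds tokens, optionally adds (learned) positional embeddings and a LayerNorm, then applies dropout."""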
+ def __init__(
+ self,
+ vocab_size,
+ embed_dim,
+ max_source_positions,
+ pad_idx,
+ no_scale_embedding,
+ layernorm_embedding,
+ dropout,
+ no_token_positional_embeddings,
+ learned_pos,
+ ):
+ super().__init__()
+ self.pad_idx = pad_idx
+ self.dropout_module = FairseqDropout(dropout)
+
+ self.embed_tokens = nn.Embedding(vocab_size, embed_dim, pad_idx)
+ self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
+ self.embed_positions = (
+ PositionalEmbedding(
+ max_source_positions,
+ embed_dim,
+ pad_idx,
+ learned=learned_pos,
+ )
+ if not no_token_positional_embeddings
+ else None
+ )
+
+ self.layernorm_embedding = None
+ if layernorm_embedding:
+ self.layernorm_embedding = LayerNorm(embed_dim)
+
+ def forward(self, src_tokens):
+ x = self.embed_scale * self.embed_tokens(src_tokens)
+ if self.embed_positions is not None:
+ x = x + self.embed_positions(src_tokens)
+
+ if self.layernorm_embedding is not None:
+ x = self.layernorm_embedding(x)
+ x = self.dropout_module(x)
+ return x
diff --git a/examples/data2vec/models/utils.py b/examples/data2vec/models/utils.py
new file mode 100644
index 000000000..0e2f240d4
--- /dev/null
+++ b/examples/data2vec/models/utils.py
@@ -0,0 +1,55 @@
+import math
+import torch
+
+def get_alibi(
+ max_positions: int,
+ attention_heads: int,
+):
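+ """Build a symmetric (attention_heads, max_positions, max_positions) tensor of ALiBi biases."""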
+ def get_slopes(n):
+ def get_slopes_power_of_2(n):
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+ ratio = start
+ return [start * ratio ** i for i in range(n)]
+
+ # In the paper, we only train models that have 2^a heads for some
+ # a. This function has some good properties that only occur when
+ # the input is a power of 2. To maintain that even when the number
+ # of heads is not a power of 2, we use this workaround.
+ if math.log2(n).is_integer():
+ return get_slopes_power_of_2(n)
+ else:
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
+ return (
+ get_slopes_power_of_2(closest_power_of_2)
+ + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
+ )
+
+ maxpos = max_positions
+ attn_heads = attention_heads
+ slopes = torch.Tensor(get_slopes(attn_heads))
+ # prepare the ALiBi position linear bias. Note that wav2vec2 is a
+ # non-autoregressive model, so we want a symmetric bias with 0 on the
+ # diagonal and otherwise linearly decreasing values
+ pos_bias = (
+ torch.abs(
+ torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
+ )
+ * -1
+ )
+ alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
+ attn_heads, -1, -1
+ )
+ return alibi_bias
+
+def masked_alibi(alibi_bias, mask_indices, orig_B, orig_T):
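+ # gather the bias entries for the timesteps selected by mask_indices,
+ # returning a (B * H, M, M) tensor where M is the number of selected steps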
+ alibi_bias = alibi_bias.view(orig_B, -1, orig_T, orig_T)
+ H = alibi_bias.size(1)
+ alibi_mask = mask_indices.unsqueeze(1)
+ alibi_bias = alibi_bias.masked_select(alibi_mask.unsqueeze(-1))
+ alibi_bias = alibi_bias.view(orig_B, H, -1, orig_T)
+ M = alibi_bias.size(-2)
+ alibi_bias = alibi_bias.masked_select(alibi_mask.unsqueeze(-2))
+ alibi_bias = alibi_bias.view(-1, M, M)
+ return alibi_bias
+
+
diff --git a/examples/data2vec/scripts/convert_audioset_labels.py b/examples/data2vec/scripts/convert_audioset_labels.py
new file mode 100644
index 000000000..7d720e606
--- /dev/null
+++ b/examples/data2vec/scripts/convert_audioset_labels.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+
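+# Example invocation (file names are illustrative, following the AudioSet release):
+#   python convert_audioset_labels.py eval_segments.csv --manifest eval.tsv \
+#     --descriptors class_labels_indices.csv --output eval.lbl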
+
+def get_parser():
+ parser = argparse.ArgumentParser(description="convert audioset labels")
+ # fmt: off
+ parser.add_argument('in_file', help='audioset csv file to convert')
+ parser.add_argument('--manifest', required=True, metavar='PATH', help='wav2vec-like manifest')
+ parser.add_argument('--descriptors', required=True, metavar='PATH', help='path to label descriptor file')
+ parser.add_argument('--output', required=True, metavar='PATH', help='where to output converted labels')
+ # fmt: on
+
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ label_descriptors = {}
+ with open(args.descriptors, "r") as ldf:
+ next(ldf)
+ for line in ldf:
+ if line.strip() == "":
+ continue
+
+ items = line.split(",")
+ assert len(items) > 2, line
+ idx = items[0]
+ lbl = items[1]
+ assert lbl not in label_descriptors, lbl
+ label_descriptors[lbl] = idx
+
+ labels = {}
+ with open(args.in_file, "r") as ifd:
+ for line in ifd:
+ if line.lstrip().startswith("#"):
+ continue
+ items = line.rstrip().split(",")
+ id = items[0].strip()
+ start = items[1].strip()
+ end = items[2].strip()
+ lbls = [label_descriptors[it.strip(' "')] for it in items[3:]]
+ labels[id] = [start, end, ",".join(lbls)]
+
+ with open(args.manifest, "r") as mf, open(args.output, "w") as of:
+ next(mf)
+ for line in mf:
+ path, _ = line.split("\t")
+ id = os.path.splitext(os.path.basename(path))[0]
+ lbl = labels[id]
+ print("\t".join(lbl), file=of)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh b/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh
new file mode 100755
index 000000000..41bcd31fc
--- /dev/null
+++ b/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+set -eu
+
+job_id="$1"
+task_id="$2"
+dir="$3"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh $dir
+
diff --git a/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh b/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh
new file mode 100644
index 000000000..fc85908b7
--- /dev/null
+++ b/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+set -eu
+
+dir="$1"
+
+echo "dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -o $dir/log/decode_sweep_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh $dir
+
diff --git a/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh b/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh
new file mode 100755
index 000000000..121226972
--- /dev/null
+++ b/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+tasks[mnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MNLI-bin"
+tasks[qqp]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QQP-bin"
+tasks[sts_b]="/fsx-wav2vec/abaevski/data/nlp/GLUE/STS-B-bin"
+
+lrs=(5e-6 8e-6 1e-5 2e-5)
+
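+# zsh-specific: ${(kv)tasks} expands the associative array into alternating
+# key/value words, so the loop binds task and data_path pairwise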
+for task data_path in ${(kv)tasks}; do
+ for lr in $lrs; do
+ echo $lr $task
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" \
+ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/multi/text_finetuning \
+ --config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \
+ model.model_path="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" +model=text_wrap
+ done
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh
new file mode 100755
index 000000000..18b862c24
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+set -eu
+
+job_id="$1"
+task_id="$2"
+dir="$3"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/ft_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/ft_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_char_fair_local_lr.sh $dir
diff --git a/examples/data2vec/scripts/text/finetune_all_fair.sh b/examples/data2vec/scripts/text/finetune_all_fair.sh
new file mode 100755
index 000000000..34a2df399
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env zsh
+
+job_id=$1
+task_id=$2
+dir="$3"
+cp="$dir/$task_id/checkpoints/checkpoint_last.pt"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+declare -A tasks
+tasks[cola]="/private/home/jgu/data/GLUE/CoLA-bin"
+tasks[qnli]="/private/home/jgu/data/GLUE/QNLI-bin"
+tasks[mrpc]="/private/home/jgu/data/GLUE/MRPC-bin"
+tasks[rte]="/private/home/jgu/data/GLUE/RTE-bin"
+tasks[sst_2]="/private/home/jgu/data/GLUE/SST-2-bin"
+
+for task data_path in ${(kv)tasks}; do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune/$task" &
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_aws.sh b/examples/data2vec/scripts/text/finetune_all_fair_aws.sh
new file mode 100755
index 000000000..b417c2002
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_aws.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env zsh
+
+job_id=$1
+task_id=$2
+dir="$3"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+for task data_path in ${(kv)tasks}; do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune/$task" &
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh
new file mode 100755
index 000000000..64dbcb111
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+set -eu
+
+job_id="$1"
+task_id="$2"
+dir="$3"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh $dir
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh
new file mode 100755
index 000000000..d75c54957
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env zsh
+
+job_id=$1
+task_id=$2
+dir="$3"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+for task data_path in ${(kv)tasks}; do
+ for lr in 5e-6 8e-6 1e-5 2e-5 5e-5 8e-5 1e-4 2e-4; do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" &
+ done
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh
new file mode 100755
index 000000000..8be98c084
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh
@@ -0,0 +1,25 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+lrs=(5e-6 8e-6 1e-5 2e-5)
+
+for task data_path in ${(kv)tasks}; do
+ for lr in $lrs; do
+ echo $lr $task
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" \
+ python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]"
+ done
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep.sh
new file mode 100755
index 000000000..d02bcc0f7
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/private/home/jgu/data/GLUE/CoLA-bin"
+tasks[qnli]="/private/home/jgu/data/GLUE/QNLI-bin"
+tasks[mrpc]="/private/home/jgu/data/GLUE/MRPC-bin"
+tasks[rte]="/private/home/jgu/data/GLUE/RTE-bin"
+tasks[sst_2]="/private/home/jgu/data/GLUE/SST-2-bin"
+
+for task data_path in ${(kv)tasks}; do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune/$task" &
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws.sh
new file mode 100755
index 000000000..75538354e
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+for task data_path in ${(kv)tasks}; do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune/$task" &
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_local_lr.sh
new file mode 100755
index 000000000..16c1358b2
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_local_lr.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+set -eu
+
+dir="$1"
+
+echo "dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -o $dir/log/decode_sweep_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh $dir
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr.sh
new file mode 100755
index 000000000..fb5ddbe22
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+for task data_path in ${(kv)tasks}; do
+ for lr in 5e-6 8e-6 1e-5 2e-5 5e-5 8e-5 1e-4 2e-4; do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" &
+ done
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr_nopos.sh b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr_nopos.sh
new file mode 100755
index 000000000..1ffab1c85
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr_nopos.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+for task data_path in ${(kv)tasks}; do
+ for lr in 5e-6 8e-6 1e-5 2e-5 5e-5 8e-5 1e-4 2e-4; do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" +model.encoder_learned_pos=False &
+ done
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh
new file mode 100755
index 000000000..c3c58adcb
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+set -eu
+
+job_id="$1"
+task_id="$2"
+dir="$3"
+
+echo "job_id: $job_id, task_id: $task_id, dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh $dir
diff --git a/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh
new file mode 100644
index 000000000..5efb00e0d
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh
@@ -0,0 +1,26 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
+tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
+tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
+tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
+tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
+
+lrs=(5e-6 8e-6 1e-5 2e-5)
+
+for task data_path in ${(kv)tasks}; do
+ for lr in $lrs; do
+ echo $lr $task
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" \
+ python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" \
+ model._name=roberta_large
+ done
+done
diff --git a/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh b/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh
new file mode 100755
index 000000000..4fb21bce7
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+set -eu
+
+dir="$1"
+
+echo "dir: $dir"
+
+mkdir -p "$dir/log"
+sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
+sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
+sbatch_args="$sbatch_args -o $dir/log/decode_sweep_%A.out"
+sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
+
+sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh $dir
diff --git a/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh b/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh
new file mode 100755
index 000000000..d7b43bee8
--- /dev/null
+++ b/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh
@@ -0,0 +1,20 @@
+#!/usr/bin/env zsh
+
+dir="$1"
+cp="$dir/checkpoints/checkpoint_last.pt"
+
+echo "dir: $dir"
+
+declare -A tasks
+tasks[qnli]="/private/home/jgu/data/GLUE/QNLI-bin"
+tasks[sst_2]="/private/home/jgu/data/GLUE/SST-2-bin"
+
+lrs="5e-6 1e-5 2e-5 5e-5 1e-4 2e-4 5e-4 1e-3"
+
+for task data_path in ${(kv)tasks}; do
+ for lr in $(echo "$lrs"); do
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_sweep/$task/lr_$lr" "optimization.lr=[${lr}]" &
+ done
+done
diff --git a/examples/data2vec/scripts/text/glue.py b/examples/data2vec/scripts/text/glue.py
new file mode 100644
index 000000000..5382d3183
--- /dev/null
+++ b/examples/data2vec/scripts/text/glue.py
@@ -0,0 +1,34 @@
+from valids import parser, main as valids_main
+import os.path as osp
+
+
+args = parser.parse_args()
+args.target = "valid_accuracy"
+args.best_biggest = True
+args.best = True
+args.last = 0
+args.path_contains = None
+
+res = valids_main(args, print_output=False)
+
+grouped = {}
+for k, v in res.items():
+ k = osp.dirname(k)
+ run = osp.dirname(k)
+ task = osp.basename(k)
+ val = v["valid_accuracy"]
+
+ if run not in grouped:
+ grouped[run] = {}
+
+ grouped[run][task] = val
+
+for run, tasks in grouped.items():
+ print(run)
+ avg = sum(float(v) for v in tasks.values()) / len(tasks)
+ avg_norte = sum(float(v) for k, v in tasks.items() if k != 'rte') / (len(tasks) - 1)
+ try:
+ print(f"{tasks['cola']}\t{tasks['qnli']}\t{tasks['mrpc']}\t{tasks['rte']}\t{tasks['sst_2']}\t{avg:.2f}\t{avg_norte:.2f}")
+ except KeyError:
+ print(tasks)
+ print()
diff --git a/examples/data2vec/scripts/text/glue_lr.py b/examples/data2vec/scripts/text/glue_lr.py
new file mode 100644
index 000000000..75bdfe036
--- /dev/null
+++ b/examples/data2vec/scripts/text/glue_lr.py
@@ -0,0 +1,143 @@
+import os.path as osp
+import re
+from collections import defaultdict
+
+from valids import parser, main as valids_main
+
+
+TASK_TO_METRIC = {
+ "cola": "mcc",
+ "qnli": "accuracy",
+ "mrpc": "acc_and_f1",
+ "rte": "accuracy",
+ "sst_2": "accuracy",
+ "mnli": "accuracy",
+ "qqp": "acc_and_f1",
+ "sts_b": "pearson_and_spearman",
+}
+TASKS = ["cola", "qnli", "mrpc", "rte", "sst_2", "mnli", "qqp", "sts_b"]
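+
+# Example (run from this directory so `valids` is importable; the path is illustrative):
+#   python glue_lr.py /path/to/sweeps --print_mode best --show_glue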
+
+
+def get_best_stat_str(task_vals, show_subdir):
+ task_to_best_val = {}
+ task_to_best_dir = {}
+ for task, subdir_to_val in task_vals.items():
+ task_to_best_val[task] = max(subdir_to_val.values())
+ task_to_best_dir[task] = max(subdir_to_val.keys(), key=lambda x: subdir_to_val[x])
+
+ N1 = len(task_to_best_val)
+ N2 = len([k for k in task_to_best_val if k != "rte"])
+ avg1 = sum(task_to_best_val.values()) / N1
+ avg2 = sum(v for task, v in task_to_best_val.items() if task != "rte") / N2
+
+ try:
+ msg = ""
+ for task in TASKS:
+ dir = task_to_best_dir.get(task, 'null')
+ val = task_to_best_val.get(task, -100)
+ msg += f"({dir}, {val})\t" if show_subdir else f"{val}\t"
+ msg += f"{avg1:.2f}\t{avg2:.2f}"
+ except Exception as e:
+ msg = str(e)
+ msg += str(sorted(task_vals.items()))
+ return msg
+
+def get_all_stat_str(task_vals):
+ msg = ""
+ for task in [task for task in TASKS if task in task_vals]:
+ msg += f"=== {task}\n"
+ for subdir in sorted(task_vals[task].keys()):
+ msg += f"\t{subdir}\t{task_vals[task][subdir]}\n"
+ return msg
+
+def get_tabular_stat_str(task_vals):
+ """assume subdir is /run_*/0"""
+ msg = ""
+ for task in [task for task in TASKS if task in task_vals]:
+ msg += f"=== {task}\n"
+ param_to_runs = defaultdict(dict)
+ for subdir in task_vals[task]:
+ match = re.match("(.*)/(run_.*)/0", subdir)
+ assert match, subdir
+ param, run = match.groups()
+ param_to_runs[param][run] = task_vals[task][subdir]
+ params = sorted(param_to_runs, key=lambda x: float(x))
+ runs = sorted(set(run for runs in param_to_runs.values() for run in runs))
+ msg += ("runs:" + "\t".join(runs) + "\n")
+ msg += ("params:" + "\t".join(params) + "\n")
+ for param in params:
+ msg += "\t".join([str(param_to_runs[param].get(run, None)) for run in runs])
+ msg += "\n"
+ return msg
+
+
+
+def main():
+ parser.add_argument("--show_glue", action="store_true", help="show glue metric for each task instead of accuracy")
+ parser.add_argument("--print_mode", default="best", help="best|all|tabular")
+ parser.add_argument("--show_subdir", action="store_true", help="print the subdir that has the best results for each run")
+ parser.add_argument("--override_target", default="valid_accuracy", help="override target")
+
+ args = parser.parse_args()
+ args.target = args.override_target
+ args.best_biggest = True
+ args.best = True
+ args.last = 0
+ args.path_contains = None
+
+ res = valids_main(args, print_output=False)
+ grouped_acc = {}
+ grouped_met = {} # use official metric for each task
+ for path, v in res.items():
+ path = "/".join([args.base, path])
+ path = re.sub("//*", "/", path)
+ match = re.match("(.*)finetune[^/]*/([^/]*)/(.*)", path)
+ if not match:
+ continue
+ run, task, subdir = match.groups()
+
+ if run not in grouped_acc:
+ grouped_acc[run] = {}
+ grouped_met[run] = {}
+ if task not in grouped_acc[run]:
+ grouped_acc[run][task] = {}
+ grouped_met[run][task] = {}
+
+ if v is not None:
+ grouped_acc[run][task][subdir] = float(v.get("valid_accuracy", -100))
+ grouped_met[run][task][subdir] = float(v.get(f"valid_{TASK_TO_METRIC[task]}", -100))
+ else:
+ print(f"{path} has None return")
+
+ header = "\t".join(TASKS)
+ for run in sorted(grouped_acc):
+ print(run)
+ if args.print_mode == "all":
+ if args.show_glue:
+ print("===== GLUE =====")
+ print(get_all_stat_str(grouped_met[run]))
+ else:
+ print("===== ACC =====")
+ print(get_all_stat_str(grouped_acc[run]))
+ elif args.print_mode == "best":
+ print(f" {header}")
+ if args.show_glue:
+ print(f"GLEU: {get_best_stat_str(grouped_met[run], args.show_subdir)}")
+ else:
+ print(f"ACC: {get_best_stat_str(grouped_acc[run], args.show_subdir)}")
+ elif args.print_mode == "tabular":
+ if args.show_glue:
+ print("===== GLUE =====")
+ print(get_tabular_stat_str(grouped_met[run]))
+ else:
+ print("===== ACC =====")
+ print(get_tabular_stat_str(grouped_acc[run]))
+ else:
+ raise ValueError(args.print_mode)
+ print()
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/data2vec/scripts/text/unprocess_data.py b/examples/data2vec/scripts/text/unprocess_data.py
new file mode 100644
index 000000000..f1acb624b
--- /dev/null
+++ b/examples/data2vec/scripts/text/unprocess_data.py
@@ -0,0 +1,188 @@
+import json
+import os
+import tqdm
+from fairseq.data import Dictionary, data_utils
+
+
+def load_dictionary(dict_path):
+ return Dictionary.load(dict_path)
+
+def load_dataset(split_path, src_dict):
+ dataset = data_utils.load_indexed_dataset(
+ split_path,
+ src_dict,
+ combine=False, # set to true for loading `train*`
+ )
+ if dataset is None:
+ raise FileNotFoundError(f"Dataset not found: {split_path}")
+ return dataset
+
+def load_bpe(enc_path):
+ with open(enc_path) as f:
+ bpe2idx = json.load(f)
+ idx2bpe = {v: k for k, v in bpe2idx.items()}
+ return bpe2idx, idx2bpe
+
+def detokenize(tokens, src_dict, idx2bpe):
+ raw_inds = map(int, src_dict.string(tokens).split())
+ raw_chrs = "".join([idx2bpe[raw_ind] for raw_ind in raw_inds])
+ raw_chrs = raw_chrs.replace("\u0120", " ")
+ return raw_chrs
+
+def _main(src_root, src_dict_path, src_bpe_path, src_splits, tgt_root, tgt_splits):
+ src_dict = load_dictionary(src_dict_path)
+ bpe2idx, idx2bpe = load_bpe(src_bpe_path)
+
+ assert len(src_splits) == len(tgt_splits)
+ for src_split, tgt_split in zip(src_splits, tgt_splits):
+ src_dataset = load_dataset(f"{src_root}/{src_split}", src_dict)
+ tgt_path = f"{tgt_root}/{tgt_split}.txt"
+ print(f"processing {src_split} (dump to {tgt_path})...")
+ os.makedirs(os.path.dirname(tgt_path), exist_ok=True)
+ with open(tgt_path, "w") as f:
+ for tokens in tqdm.tqdm(src_dataset):
+ raw_str = detokenize(tokens, src_dict, idx2bpe)
+ f.write(raw_str + "\n")
+
+def main_pt():
+ src_root = "/datasets01/bookwiki_CC-NEWS_openwebtext_stories-mmap2-bin/121219/bookwiki_CC-NEWS_openwebtext_stories-mmap2-bin"
+ src_dict_path = f"{src_root}/dict.txt"
+ src_bpe_path = f"{src_root}/encoder.json"
+ src_splits = [
+ "bookwiki_aml-mmap2-bin/shard0/train",
+ "bookwiki_aml-mmap2-bin/shard1/train",
+ "bookwiki_aml-mmap2-bin/shard2/train",
+ "bookwiki_aml-mmap2-bin/shard3/train",
+ "bookwiki_aml-mmap2-bin/shard4/train",
+ "bookwiki_aml-mmap2-bin/valid/valid",
+ ]
+
+ tgt_root = "/checkpoint/wnhsu/data/data2vec2/data/text/bookwiki_aml-full-mmap2-txt"
+ tgt_splits = [
+ "train0",
+ "train1",
+ "train2",
+ "train3",
+ "train4",
+ "valid",
+ ]
+ _main(src_root, src_dict_path, src_bpe_path, src_splits, tgt_root, tgt_splits)
+
+def main_ft():
+ src_root = "/fsx-wav2vec/wnhsu/data/data2vec2/data/text/GLUE"
+ src_dict_path = f"{src_root}/dict.txt"
+ src_bpe_path = f"{src_root}/encoder.json"
+ src_splits = [
+ "CoLA-bin/input0/train",
+ "CoLA-bin/input0/valid",
+ "CoLA-bin/input0/test",
+
+ "MNLI-bin/input0/train",
+ "MNLI-bin/input0/valid",
+ "MNLI-bin/input0/test",
+ "MNLI-bin/input0/test1",
+ "MNLI-bin/input1/train",
+ "MNLI-bin/input1/valid",
+ "MNLI-bin/input1/test",
+ "MNLI-bin/input1/test1",
+
+ "MRPC-bin/input0/train",
+ "MRPC-bin/input0/valid",
+ "MRPC-bin/input0/test",
+ "MRPC-bin/input1/train",
+ "MRPC-bin/input1/valid",
+ "MRPC-bin/input1/test",
+
+ "QNLI-bin/input0/train",
+ "QNLI-bin/input0/valid",
+ "QNLI-bin/input0/test",
+ "QNLI-bin/input1/train",
+ "QNLI-bin/input1/valid",
+ "QNLI-bin/input1/test",
+
+ "QQP-bin/input0/train",
+ "QQP-bin/input0/valid",
+ "QQP-bin/input0/test",
+ "QQP-bin/input1/train",
+ "QQP-bin/input1/valid",
+ "QQP-bin/input1/test",
+
+ "RTE-bin/input0/train",
+ "RTE-bin/input0/valid",
+ "RTE-bin/input0/test",
+ "RTE-bin/input1/train",
+ "RTE-bin/input1/valid",
+ "RTE-bin/input1/test",
+
+ "SST-2-bin/input0/train",
+ "SST-2-bin/input0/valid",
+ "SST-2-bin/input0/test",
+
+ "STS-B-bin/input0/train",
+ "STS-B-bin/input0/valid",
+ "STS-B-bin/input0/test",
+ "STS-B-bin/input1/train",
+ "STS-B-bin/input1/valid",
+ "STS-B-bin/input1/test",
+ ]
+
+ tgt_root = "/fsx-wav2vec/wnhsu/data/data2vec2/data/text/GLUE_chr"
+ tgt_splits = [
+ "CoLA-bin/input0/train",
+ "CoLA-bin/input0/valid",
+ "CoLA-bin/input0/test",
+
+ "MNLI-bin/input0/train",
+ "MNLI-bin/input0/valid",
+ "MNLI-bin/input0/test",
+ "MNLI-bin/input0/test1",
+ "MNLI-bin/input1/train",
+ "MNLI-bin/input1/valid",
+ "MNLI-bin/input1/test",
+ "MNLI-bin/input1/test1",
+
+ "MRPC-bin/input0/train",
+ "MRPC-bin/input0/valid",
+ "MRPC-bin/input0/test",
+ "MRPC-bin/input1/train",
+ "MRPC-bin/input1/valid",
+ "MRPC-bin/input1/test",
+
+ "QNLI-bin/input0/train",
+ "QNLI-bin/input0/valid",
+ "QNLI-bin/input0/test",
+ "QNLI-bin/input1/train",
+ "QNLI-bin/input1/valid",
+ "QNLI-bin/input1/test",
+
+ "QQP-bin/input0/train",
+ "QQP-bin/input0/valid",
+ "QQP-bin/input0/test",
+ "QQP-bin/input1/train",
+ "QQP-bin/input1/valid",
+ "QQP-bin/input1/test",
+
+ "RTE-bin/input0/train",
+ "RTE-bin/input0/valid",
+ "RTE-bin/input0/test",
+ "RTE-bin/input1/train",
+ "RTE-bin/input1/valid",
+ "RTE-bin/input1/test",
+
+ "SST-2-bin/input0/train",
+ "SST-2-bin/input0/valid",
+ "SST-2-bin/input0/test",
+
+ "STS-B-bin/input0/train",
+ "STS-B-bin/input0/valid",
+ "STS-B-bin/input0/test",
+ "STS-B-bin/input1/train",
+ "STS-B-bin/input1/valid",
+ "STS-B-bin/input1/test",
+ ]
+ _main(src_root, src_dict_path, src_bpe_path, src_splits, tgt_root, tgt_splits)
+
+
+if __name__ == "__main__":
+ main_pt()
+ main_ft()
diff --git a/examples/data2vec/scripts/text/valids.py b/examples/data2vec/scripts/text/valids.py
new file mode 100644
index 000000000..b2e5cfb25
--- /dev/null
+++ b/examples/data2vec/scripts/text/valids.py
@@ -0,0 +1,301 @@
+import os, argparse, re, json, copy, math
+from collections import OrderedDict
+import numpy as np
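+
+# Walks a directory tree of fairseq training logs and reports the target
+# validation metric for each run, e.g. (illustrative):
+#   python valids.py /path/to/runs --target valid_loss --best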
+
+parser = argparse.ArgumentParser(description='Scan fairseq training logs and report validation metrics.')
+parser.add_argument('base', help='base log path')
+parser.add_argument('--file_name', default='train.log', help='the log file name')
+parser.add_argument('--target', default='valid_loss', help='target metric')
+parser.add_argument('--last', type=int, default=999999999, help='print last n matches')
+parser.add_argument('--last_files', type=int, default=None, help='print last x files')
+parser.add_argument('--everything', action='store_true', help='print everything instead of only last match')
+parser.add_argument('--path_contains', help='only consider matching file pattern')
+parser.add_argument('--group_on', help='if set, groups by this metric and shows table of differences')
+parser.add_argument('--epoch', help='epoch for comparison', type=int)
+parser.add_argument('--skip_empty', action='store_true', help='skip empty results')
+parser.add_argument('--skip_containing', help='skips entries containing this attribute')
+parser.add_argument('--unique_epochs', action='store_true', help='only consider the last line for each epoch')
+parser.add_argument('--best', action='store_true', help='print the last best result')
+parser.add_argument('--avg_params', help='average these params through entire log')
+parser.add_argument('--extract_prev', help='extracts this metric from previous line')
+
+parser.add_argument('--remove_metric', help='fold this param out of the group key and append its value to the grouped value')
+
+parser.add_argument('--compact', action='store_true', help='if true, just prints checkpoint best val')
+parser.add_argument('--hydra', action='store_true', help='if true, uses hydra param conventions')
+
+parser.add_argument('--best_biggest', action='store_true', help='if true, best is the biggest number, not smallest')
+parser.add_argument('--key_len', type=int, default=10, help='max length of key')
+
+parser.add_argument('--best_only', action='store_true', help='if set, only prints the best value')
+parser.add_argument('--flat', action='store_true', help='just print the best results')
+
+
+def main(args, print_output):
+ ret = {}
+
+ entries = []
+
+ def extract_metric(s, metric):
+ try:
+ j = json.loads(s)
+ except Exception:
+ return None
+ if args.epoch is not None and ('epoch' not in j or j['epoch'] != args.epoch):
+ return None
+ return j[metric] if metric in j else None
+
+
+ def extract_params(s):
+ s = s.replace(args.base, '', 1)
+ if args.path_contains is not None:
+ s = s.replace(args.path_contains, '', 1)
+
+ if args.hydra:
+ num_matches = re.findall(r'(?:/|__)([^/:]+):(\d+\.?\d*)', s)
+ # str_matches = re.findall(r'(?:/|__)([^/:]+):([^\.]*[^\d\.]+)(?:/|__)', s)
+ str_matches = re.findall(r'(?:/|__)?((?:(?!(?:\:|__)).)+):([^\.]*[^\d\.]+\d*)(?:/|__)', s)
+ lr_matches = re.findall(r'optimization.(lr):\[([\d\.,]+)\]', s)
+ task_matches = re.findall(r'.*/(\d+)$', s)
+ else:
+ num_matches = re.findall(r'\.?([^\.]+?)(\d+(e\-\d+)?(?:\.\d+)?)(\.|$)', s)
+ str_matches = re.findall(r'[/\.]([^\.]*[^\d\.]+\d*)(?=\.)', s)
+ lr_matches = []
+ task_matches = []
+
+ cp_matches = re.findall(r'checkpoint(?:_\d+)?_(\d+).pt', s)
+
+ items = OrderedDict()
+ for m in str_matches:
+ if isinstance(m, tuple):
+ if 'checkpoint' not in m[0]:
+ items[m[0]] = m[1]
+ else:
+ items[m] = ''
+
+ for m in num_matches:
+ items[m[0]] = m[1]
+
+ for m in lr_matches:
+ items[m[0]] = m[1]
+
+ for m in task_matches:
+ items["hydra_task"] = m
+
+ for m in cp_matches:
+ items['checkpoint'] = m
+
+ return items
+
+ abs_best = None
+
+ sources = []
+ for root, _, files in os.walk(args.base):
+ if args.path_contains is not None and not args.path_contains in root:
+ continue
+ for f in files:
+ if f.endswith(args.file_name):
+ sources.append((root, f))
+
+ if args.last_files is not None:
+ sources = sources[-args.last_files:]
+
+ for root, file in sources:
+ with open(os.path.join(root, file), 'r') as fin:
+ found = []
+ avg = {}
+ prev = None
+ for line in fin:
+ line = line.rstrip()
+ if line.find(args.target) != -1 and (
+ args.skip_containing is None or line.find(args.skip_containing) == -1):
+ try:
+ idx = line.index("{")
+ line = line[idx:]
+ line_json = json.loads(line)
+ except Exception:
+ continue
+ if prev is not None:
+ try:
+ prev.update(line_json)
+ line_json = prev
+ except Exception:
+ pass
+ if args.target in line_json:
+ found.append(line_json)
+ if args.avg_params:
+ avg_params = args.avg_params.split(',')
+ for p in avg_params:
+ m = extract_metric(line, p)
+ if m is not None:
+ prev_v, prev_c = avg.get(p, (0, 0))
+ avg[p] = prev_v + float(m), prev_c + 1
+ if args.extract_prev:
+ try:
+ prev = json.loads(line)
+ except Exception:
+ pass
+ best = None
+ if args.best:
+ curr_best = None
+ for i in range(len(found)):
+ cand_best = found[i][args.target] if args.target in found[i] else None
+
+ def cmp(a, b):
+ a = float(a)
+ b = float(b)
+ if args.best_biggest:
+ return a > b
+ return a < b
+
+ if cand_best is not None and not math.isnan(float(cand_best)) and (
+ curr_best is None or cmp(cand_best, curr_best)):
+ curr_best = cand_best
+ if abs_best is None or cmp(curr_best, abs_best):
+ abs_best = curr_best
+ best = found[i]
+ if args.unique_epochs or args.epoch:
+ last_found = []
+ last_epoch = None
+ for i in reversed(range(len(found))):
+ epoch = found[i]['epoch']
+ if args.epoch and args.epoch != epoch:
+ continue
+ if epoch != last_epoch:
+ last_epoch = epoch
+ last_found.append(found[i])
+ found = list(reversed(last_found))
+
+ if len(found) == 0:
+ if print_output and (args.last_files is not None or not args.skip_empty):
+ print(root[len(args.base):])
+ print('Nothing')
+ else:
+ if not print_output:
+ ret[root[len(args.base):]] = best
+ continue
+
+ if args.compact:
+ print('{}\t{}'.format(root[len(args.base)+1:], curr_best))
+ continue
+
+ if args.group_on is None and not args.best_only:
+ print(root[len(args.base):])
+ if not args.everything:
+ if best is not None and args.group_on is None and not args.best_only and not args.flat:
+ print(best, '(best)')
+ if args.group_on is None and args.last and not args.best_only and not args.flat:
+ for f in found[-args.last:]:
+ if args.extract_prev is not None:
+ try:
+ print('{}\t{}'.format(f[args.extract_prev], f[args.target]))
+ except Exception as e:
+ print('Exception!', e)
+ else:
+ print(f)
+ try:
+ metric = found[-1][args.target] if not args.best or best is None else best[args.target]
+ except Exception:
+ print(found[-1])
+ raise
+ if metric is not None:
+ entries.append((extract_params(root), metric))
+ else:
+ for f in found:
+ print(f)
+ if not args.group_on and print_output:
+ print()
+
+ if len(avg) > 0:
+ for k, (v, c) in avg.items():
+ print(f'{k}: {v/c}')
+
+ if args.best_only:
+ print(abs_best)
+
+ if args.flat:
+ print("\t".join(m for _, m in entries))
+
+ if args.group_on is not None:
+ by_val = OrderedDict()
+ for e, m in entries:
+ k = args.group_on
+ if k not in e:
+ m_keys = [x for x in e.keys() if x.startswith(k)]
+ if len(m_keys) == 0:
+ val = "False"
+ else:
+ assert len(m_keys) == 1
+ k = m_keys[0]
+ val = m_keys[0]
+ else:
+ val = e[args.group_on]
+ if val == "":
+ val = "True"
+ scrubbed_entry = copy.deepcopy(e)
+ if k in scrubbed_entry:
+ del scrubbed_entry[k]
+ if args.remove_metric and args.remove_metric in scrubbed_entry:
+ val += '_' + scrubbed_entry[args.remove_metric]
+ del scrubbed_entry[args.remove_metric]
+ by_val.setdefault(tuple(scrubbed_entry.items()), dict())[val] = m
+ distinct_vals = set()
+ for v in by_val.values():
+ distinct_vals.update(v.keys())
+ try:
+ distinct_vals = {int(d) for d in distinct_vals}
+ except Exception:
+ print(distinct_vals)
+ print()
+ print("by_val", len(by_val))
+ for k,v in by_val.items():
+ print(k, '=>', v)
+ print()
+
+ raise
+ from natsort import natsorted
+ svals = list(map(str, natsorted(distinct_vals)))
+ print('{}\t{}'.format(args.group_on, '\t'.join(svals)))
+ sums = OrderedDict({n:[] for n in svals})
+ for k, v in by_val.items():
+ kstr = '.'.join(':'.join(x) for x in k)
+ vstr = ''
+ for mv in svals:
+ x = v[mv] if mv in v else ''
+ vstr += '\t{}'.format(round(x, 5) if isinstance(x, float) else x)
+ try:
+ sums[mv].append(float(x))
+ except Exception:
+ pass
+ print('{}{}'.format(kstr[:args.key_len], vstr))
+ if any(len(x) > 0 for x in sums.values()):
+ print('min:', end='')
+ for v in sums.values():
+ min = np.min(v)
+ print(f'\t{round(min, 5)}', end='')
+ print()
+ print('max:', end='')
+ for v in sums.values():
+ max = np.max(v)
+ print(f'\t{round(max, 5)}', end='')
+ print()
+ print('avg:', end='')
+ for v in sums.values():
+ mean = np.mean(v)
+ print(f'\t{round(mean, 5)}', end='')
+ print()
+ print('median:', end='')
+ for v in sums.values():
+ median = np.median(v)
+ print(f'\t{round(median, 5)}', end='')
+ print()
+
+ return ret
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ main(args, print_output=True)
\ No newline at end of file
diff --git a/examples/data2vec/tasks/__init__.py b/examples/data2vec/tasks/__init__.py
new file mode 100644
index 000000000..a7422e4b3
--- /dev/null
+++ b/examples/data2vec/tasks/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .image_pretraining import ImagePretrainingTask, ImagePretrainingConfig
+from .image_classification import ImageClassificationTask, ImageClassificationConfig
+from .mae_image_pretraining import MaeImagePretrainingTask, MaeImagePretrainingConfig
+
+
+__all__ = [
+ "ImageClassificationTask",
+ "ImageClassificationConfig",
+ "ImagePretrainingTask",
+ "ImagePretrainingConfig",
+ "MaeImagePretrainingTask",
+ "MaeImagePretrainingConfig",
+]
\ No newline at end of file
diff --git a/examples/data2vec/tasks/audio_classification.py b/examples/data2vec/tasks/audio_classification.py
new file mode 100644
index 000000000..2925a04cf
--- /dev/null
+++ b/examples/data2vec/tasks/audio_classification.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import logging
+import os
+import numpy as np
+import math
+import torch
+
+from sklearn import metrics as sklearn_metrics
+from dataclasses import dataclass
+
+from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
+from fairseq.tasks import register_task
+from fairseq.logging import metrics
+
+from ..data.add_class_target_dataset import AddClassTargetDataset
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AudioClassificationConfig(AudioPretrainingConfig):
+ label_descriptors: str = "label_descriptors.csv"
+ labels: str = "lbl"
+
+
+@register_task("audio_classification", dataclass=AudioClassificationConfig)
+class AudioClassificationTask(AudioPretrainingTask):
+ """ """
+
+ cfg: AudioClassificationConfig
+
+ def __init__(
+ self,
+ cfg: AudioClassificationConfig,
+ ):
+ super().__init__(cfg)
+
+ self.state.add_factory("labels", self.load_labels)
+
+ def load_labels(self):
+ labels = {}
+ path = os.path.join(self.cfg.data, self.cfg.label_descriptors)
+ with open(path, "r") as ldf:
+ for line in ldf:
+ if line.strip() == "":
+ continue
+ items = line.split(",")
+ idx = items[0]
+ lbl = items[1]
+ assert lbl not in labels, lbl
+ labels[lbl] = idx
+ return labels
+
+ @property
+ def labels(self):
+ return self.state.labels
+
+ def load_dataset(
+ self, split: str, task_cfg: AudioClassificationConfig = None, **kwargs
+ ):
+ super().load_dataset(split, task_cfg, **kwargs)
+
+ task_cfg = task_cfg or self.cfg
+
+ data_path = self.cfg.data
+ label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
+ skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
+ labels = []
+ with open(label_path, "r") as f:
+ for i, line in enumerate(f):
+ if i not in skipped_indices:
+ lbl_items = line.rstrip().split("\t")
+ labels.append([int(x) for x in lbl_items[2].split(",")])
+
+ assert len(labels) == len(self.datasets[split]), (
+ f"labels length ({len(labels)}) and dataset length "
+ f"({len(self.datasets[split])}) do not match"
+ )
+
+ self.datasets[split] = AddClassTargetDataset(
+ self.datasets[split],
+ labels,
+ multi_class=True,
+ add_to_input=True,
+ num_classes=len(self.labels),
+ )
+
+ def calculate_stats(self, output, target):
+
+ classes_num = target.shape[-1]
+ stats = []
+
+ # Accuracy, only used for single-label classification such as ESC-50, not for multi-label datasets such as AudioSet
+ # acc = sklearn_metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))
+
+ # Class-wise statistics
+ for k in range(classes_num):
+ # Average precision
+ avg_precision = sklearn_metrics.average_precision_score(
+ target[:, k], output[:, k], average=None
+ )
+
+ stat_dict = {
+ "AP": avg_precision,
+ }
+
+ # # AUC
+ # try:
+ # auc = sklearn_metrics.roc_auc_score(target[:, k], output[:, k], average=None)
+ # except:
+ # auc = 0
+ #
+ # # Precisions, recalls
+ # (precisions, recalls, thresholds) = sklearn_metrics.precision_recall_curve(
+ # target[:, k], output[:, k]
+ # )
+ #
+ # # FPR, TPR
+ # (fpr, tpr, thresholds) = sklearn_metrics.roc_curve(target[:, k], output[:, k])
+ #
+ # save_every_steps = 1000 # Sample statistics to reduce size
+ # dict = {
+ # "precisions": precisions[0::save_every_steps],
+ # "recalls": recalls[0::save_every_steps],
+ # "AP": avg_precision,
+ # "fpr": fpr[0::save_every_steps],
+ # "fnr": 1.0 - tpr[0::save_every_steps],
+ # "auc": auc,
+ # # note acc is not class-wise, this is just to keep consistent with other metrics
+ # "acc": acc,
+ # }
+ stats.append(stat_dict)
+
+ return stats
+
+ def valid_step(self, sample, model, criterion):
+ loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
+ return loss, sample_size, logging_output
+
+ def reduce_metrics(self, logging_outputs, criterion):
+ super().reduce_metrics(logging_outputs, criterion)
+ if "_predictions" in logging_outputs[0]:
+ metrics.log_concat_tensor(
+ "_predictions",
+ torch.cat([l["_predictions"].cpu() for l in logging_outputs], dim=0),
+ )
+ metrics.log_concat_tensor(
+ "_targets",
+ torch.cat([l["_targets"].cpu() for l in logging_outputs], dim=0),
+ )
+
+ def compute_stats(meters):
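+ # skip mAP until enough predictions have accumulated for stable statistics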
+ if meters["_predictions"].tensor.shape[0] < 100:
+ return 0
+ stats = self.calculate_stats(
+ meters["_predictions"].tensor, meters["_targets"].tensor
+ )
+ return np.nanmean([stat["AP"] for stat in stats])
+
+ metrics.log_derived("mAP", compute_stats)
diff --git a/examples/data2vec/tasks/image_classification.py b/examples/data2vec/tasks/image_classification.py
new file mode 100644
index 000000000..1ea4c2afe
--- /dev/null
+++ b/examples/data2vec/tasks/image_classification.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import os.path as osp
+import logging
+
+from dataclasses import dataclass
+import torch
+from torchvision import transforms
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import register_task
+from fairseq.logging import metrics
+
+try:
+ from ..data import ImageDataset
+except ImportError:
+ import sys
+
+ sys.path.append("..")
+ from data import ImageDataset
+
+from .image_pretraining import (
+ ImagePretrainingConfig,
+ ImagePretrainingTask,
+ IMG_EXTENSIONS,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ImageClassificationConfig(ImagePretrainingConfig):
+ pass
+
+
+@register_task("image_classification", dataclass=ImageClassificationConfig)
+class ImageClassificationTask(ImagePretrainingTask):
+
+ cfg: ImageClassificationConfig
+
+ @classmethod
+ def setup_task(cls, cfg: ImageClassificationConfig, **kwargs):
+ return cls(cfg)
+
+ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
+ data_path = self.cfg.data
+ cfg = task_cfg or self.cfg
+
+ path_with_split = osp.join(data_path, split)
+ if osp.exists(path_with_split):
+ data_path = path_with_split
+
+ from timm.data import create_transform
+
+ if split == "train":
+ # this should always dispatch to transforms_imagenet_train
+ transform = create_transform(
+ input_size=cfg.input_size,
+ is_training=True,
+ auto_augment="rand-m9-mstd0.5-inc1",
+ interpolation="bicubic",
+ re_prob=0.25,
+ re_mode="pixel",
+ re_count=1,
+ mean=cfg.normalization_mean,
+ std=cfg.normalization_std,
+ )
+            if cfg.input_size <= 32:
+ transform.transforms[0] = transforms.RandomCrop(
+ cfg.input_size, padding=4
+ )
+ else:
+ t = []
+ if cfg.input_size > 32:
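+                # standard ImageNet-style eval: resize so that a center crop of
+                # input_size corresponds to the usual 224/256 crop ratio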
+ crop_pct = 1
+ if cfg.input_size < 384:
+ crop_pct = 224 / 256
+ size = int(cfg.input_size / crop_pct)
+                t.append(
+                    transforms.Resize(
+                        size, interpolation=transforms.InterpolationMode.BICUBIC
+                    ),  # maintain the same ratio w.r.t. 224-pixel images
+                )
+ t.append(transforms.CenterCrop(cfg.input_size))
+
+ t.append(transforms.ToTensor())
+ t.append(
+ transforms.Normalize(cfg.normalization_mean, cfg.normalization_std)
+ )
+ transform = transforms.Compose(t)
+ logger.info(transform)
+
+ self.datasets[split] = ImageDataset(
+ root=data_path,
+ extensions=IMG_EXTENSIONS,
+ load_classes=True,
+ transform=transform,
+ )
+ for k in self.datasets.keys():
+ if k != split:
+ assert self.datasets[k].classes == self.datasets[split].classes
+
+ def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
+ model = super().build_model(model_cfg, from_checkpoint)
+
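+        # propagate the loaded checkpoint's args into this task's model config
+        # so downstream code sees the settings that were actually used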
+ actualized_cfg = getattr(model, "cfg", None)
+ if actualized_cfg is not None:
+ if hasattr(actualized_cfg, "pretrained_model_args"):
+ model_cfg.pretrained_model_args = actualized_cfg.pretrained_model_args
+
+ return model
+
+ def reduce_metrics(self, logging_outputs, criterion):
+ super().reduce_metrics(logging_outputs, criterion)
+
+ if "correct" in logging_outputs[0]:
+ zero = torch.scalar_tensor(0.0)
+ correct = sum(log.get("correct", zero) for log in logging_outputs)
+ metrics.log_scalar_sum("_correct", correct)
+
+ metrics.log_derived(
+ "accuracy",
+ lambda meters: 100 * meters["_correct"].sum / meters["sample_size"].sum,
+ )
diff --git a/examples/data2vec/tasks/image_pretraining.py b/examples/data2vec/tasks/image_pretraining.py
new file mode 100644
index 000000000..cd688fd13
--- /dev/null
+++ b/examples/data2vec/tasks/image_pretraining.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import logging
+import sys
+import os.path as osp
+
+from dataclasses import dataclass, field
+from typing import List
+from omegaconf import MISSING
+
+import torch
+from torchvision import transforms
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+
+try:
+ from ..data import ImageDataset
+except ImportError:
+ sys.path.append("..")
+ from data import ImageDataset
+
+logger = logging.getLogger(__name__)
+
+IMG_EXTENSIONS = {
+ ".jpg",
+ ".jpeg",
+ ".png",
+ ".ppm",
+ ".bmp",
+ ".pgm",
+ ".tif",
+ ".tiff",
+ ".webp",
+}
+
+
+@dataclass
+class ImagePretrainingConfig(FairseqDataclass):
+ data: str = field(default=MISSING, metadata={"help": "path to data directory"})
+ input_size: int = 224
+ normalization_mean: List[float] = (0.485, 0.456, 0.406)
+ normalization_std: List[float] = (0.229, 0.224, 0.225)
+
+
+@register_task("image_pretraining", dataclass=ImagePretrainingConfig)
+class ImagePretrainingTask(FairseqTask):
+    """Task for pretraining on unlabeled images."""
+
+ cfg: ImagePretrainingConfig
+
+ @classmethod
+ def setup_task(cls, cfg: ImagePretrainingConfig, **kwargs):
+ """Setup the task (e.g., load dictionaries).
+
+ Args:
+            cfg (ImagePretrainingConfig): configuration of this task
+ """
+
+ return cls(cfg)
+
+ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
+ data_path = self.cfg.data
+ cfg = task_cfg or self.cfg
+
+ path_with_split = osp.join(data_path, split)
+ if osp.exists(path_with_split):
+ data_path = path_with_split
+
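+        # light augmentation for pretraining: color jitter, horizontal flip and
+        # a random resized crop, followed by ImageNet-style normalization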
+ transform = transforms.Compose(
+ [
+ transforms.ColorJitter(0.4, 0.4, 0.4),
+ transforms.RandomHorizontalFlip(p=0.5),
+ transforms.RandomResizedCrop(
+ size=cfg.input_size,
+ interpolation=transforms.InterpolationMode.BICUBIC,
+ ),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=torch.tensor(cfg.normalization_mean),
+ std=torch.tensor(cfg.normalization_std),
+ ),
+ ]
+ )
+
+ logger.info(transform)
+
+ self.datasets[split] = ImageDataset(
+ root=data_path,
+ extensions=IMG_EXTENSIONS,
+ load_classes=False,
+ transform=transform,
+ )
+
+ @property
+ def source_dictionary(self):
+ return None
+
+ @property
+ def target_dictionary(self):
+ return None
+
+ def max_positions(self):
+ """Maximum input length supported by the encoder."""
+ return sys.maxsize, sys.maxsize
diff --git a/examples/data2vec/tasks/mae_image_classification.py b/examples/data2vec/tasks/mae_image_classification.py
new file mode 100644
index 000000000..1bf935879
--- /dev/null
+++ b/examples/data2vec/tasks/mae_image_classification.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import logging
+import sys
+import torch
+
+from typing import Optional
+from dataclasses import dataclass, field
+from omegaconf import MISSING
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+from fairseq.logging import metrics
+
+try:
+ from ..data import MaeFinetuningImageDataset
+except ImportError:
+ sys.path.append("..")
+ from data import MaeFinetuningImageDataset
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class MaeImageClassificationConfig(FairseqDataclass):
+ data: str = field(default=MISSING, metadata={"help": "path to data directory"})
+ input_size: int = 224
+ local_cache_path: Optional[str] = None
+
+ rebuild_batches: bool = True
+
+
+@register_task("mae_image_classification", dataclass=MaeImageClassificationConfig)
+class MaeImageClassificationTask(FairseqTask):
+    """Image classification task for fine-tuning MAE-style pretrained models."""
+
+ cfg: MaeImageClassificationConfig
+
+ @classmethod
+ def setup_task(cls, cfg: MaeImageClassificationConfig, **kwargs):
+ """Setup the task (e.g., load dictionaries).
+
+ Args:
+            cfg (MaeImageClassificationConfig): configuration of this task
+ """
+
+ return cls(cfg)
+
+ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
+ data_path = self.cfg.data
+ cfg = task_cfg or self.cfg
+
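+        # the train split gets MAE finetuning augmentations and shuffling;
+        # other splits are loaded deterministically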
+ self.datasets[split] = MaeFinetuningImageDataset(
+ root=data_path,
+ split=split,
+ is_train=split == "train",
+ input_size=cfg.input_size,
+ local_cache_path=cfg.local_cache_path,
+ shuffle=split == "train",
+ )
+
+ def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
+ model = super().build_model(model_cfg, from_checkpoint)
+
+ actualized_cfg = getattr(model, "cfg", None)
+ if actualized_cfg is not None:
+ if hasattr(actualized_cfg, "pretrained_model_args"):
+ model_cfg.pretrained_model_args = actualized_cfg.pretrained_model_args
+
+ return model
+
+ def reduce_metrics(self, logging_outputs, criterion):
+ super().reduce_metrics(logging_outputs, criterion)
+
+ if "correct" in logging_outputs[0]:
+ zero = torch.scalar_tensor(0.0)
+ correct = sum(log.get("correct", zero) for log in logging_outputs)
+ metrics.log_scalar_sum("_correct", correct)
+
+ metrics.log_derived(
+ "accuracy",
+ lambda meters: 100 * meters["_correct"].sum / meters["sample_size"].sum,
+ )
+
+ @property
+ def source_dictionary(self):
+ return None
+
+ @property
+ def target_dictionary(self):
+ return None
+
+ def max_positions(self):
+ """Maximum input length supported by the encoder."""
+ return sys.maxsize, sys.maxsize
diff --git a/examples/data2vec/tasks/mae_image_pretraining.py b/examples/data2vec/tasks/mae_image_pretraining.py
new file mode 100644
index 000000000..35a14891c
--- /dev/null
+++ b/examples/data2vec/tasks/mae_image_pretraining.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import logging
+import sys
+
+from typing import Optional, List
+from dataclasses import dataclass, field
+from omegaconf import MISSING, II
+
+from fairseq.data import SubsampleDataset
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+
+try:
+ from ..data import MaeImageDataset
+except ImportError:
+ sys.path.append("..")
+ from data import MaeImageDataset
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ImageMaskingConfig:
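+    # most fields are interpolated (II) from the model config so that
+    # precomputed masks always match the model's masking settings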
+ patch_size: int = II("model.modalities.image.patch_size")
+ mask_prob: float = II("model.modalities.image.mask_prob")
+ mask_prob_adjust: float = II("model.modalities.image.mask_prob_adjust")
+ mask_length: int = II("model.modalities.image.mask_length")
+ inverse_mask: bool = II("model.modalities.image.inverse_mask")
+ mask_dropout: float = II("model.modalities.image.mask_dropout")
+ clone_batch: int = II("model.clone_batch")
+ expand_adjacent: bool = False
+ non_overlapping: bool = False
+
+
+@dataclass
+class MaeImagePretrainingConfig(FairseqDataclass):
+ data: str = field(default=MISSING, metadata={"help": "path to data directory"})
+ multi_data: Optional[List[str]] = None
+ input_size: int = 224
+ local_cache_path: Optional[str] = None
+ key: str = "imgs"
+
+ beit_transforms: bool = False
+ target_transform: bool = False
+ no_transform: bool = False
+
+ rebuild_batches: bool = True
+
+ precompute_mask_config: Optional[ImageMaskingConfig] = None
+
+ subsample: float = 1
+ seed: int = II("common.seed")
+ dataset_type: str = "imagefolder"
+
+
+@register_task("mae_image_pretraining", dataclass=MaeImagePretrainingConfig)
+class MaeImagePretrainingTask(FairseqTask):
+    """MAE-style image pretraining task, optionally precomputing masks in the data loader."""
+
+ cfg: MaeImagePretrainingConfig
+
+ @classmethod
+ def setup_task(cls, cfg: MaeImagePretrainingConfig, **kwargs):
+ """Setup the task (e.g., load dictionaries).
+
+ Args:
+            cfg (MaeImagePretrainingConfig): configuration of this task
+ """
+
+ return cls(cfg)
+
+ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
+ data_path = self.cfg.data
+ cfg = task_cfg or self.cfg
+
+ compute_mask = cfg.precompute_mask_config is not None
+ mask_args = {}
+ if compute_mask:
+ mask_args = cfg.precompute_mask_config
+
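+        # when a mask config is given, masks are precomputed in the dataset
+        # instead of being sampled inside the model's forward pass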
+ self.datasets[split] = MaeImageDataset(
+ root=data_path if cfg.multi_data is None else cfg.multi_data,
+ split=split,
+ input_size=cfg.input_size,
+ local_cache_path=cfg.local_cache_path,
+ key=cfg.key,
+ beit_transforms=cfg.beit_transforms,
+ target_transform=cfg.target_transform,
+ no_transform=cfg.no_transform,
+ compute_mask=compute_mask,
+ dataset_type=cfg.dataset_type,
+ **mask_args,
+ )
+
+ if cfg.subsample < 1:
+ self.datasets[split] = SubsampleDataset(
+ self.datasets[split],
+ cfg.subsample,
+ shuffle=True,
+ seed=cfg.seed,
+ )
+
+ @property
+ def source_dictionary(self):
+ return None
+
+ @property
+ def target_dictionary(self):
+ return None
+
+ def max_positions(self):
+ """Maximum input length supported by the encoder."""
+ return sys.maxsize, sys.maxsize
diff --git a/examples/data2vec/tasks/multimodal.py b/examples/data2vec/tasks/multimodal.py
new file mode 100644
index 000000000..74648e918
--- /dev/null
+++ b/examples/data2vec/tasks/multimodal.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import sys
+
+from dataclasses import dataclass
+from typing import Optional, List
+from omegaconf import II
+
+from fairseq.data.iterators import GroupedEpochBatchIterator
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.tasks import FairseqTask, register_task
+from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask
+from fairseq.tasks.masked_lm import MaskedLMConfig, MaskedLMTask
+from .mae_image_pretraining import MaeImagePretrainingConfig, MaeImagePretrainingTask
+from examples.data2vec.data.modality import Modality
+
+from fairseq.data.audio.multi_modality_dataset import (
+ MultiModalityDataset,
+ ModalityDatasetItem,
+)
+
+
+@dataclass
+class MultimodalPretrainingConfig(FairseqDataclass):
+ audio: Optional[AudioPretrainingConfig] = None
+ image: Optional[MaeImagePretrainingConfig] = None
+ text: Optional[MaskedLMConfig] = None
+
+ audio_ratio: float = 1
+ image_ratio: float = 1
+ text_ratio: float = 1
+
+ max_tokens: Optional[int] = II("dataset.max_tokens")
+ batch_size: Optional[int] = II("dataset.batch_size")
+ update_freq: List[int] = II("optimization.update_freq")
+
+ rebuild_batches: bool = True
+
+
+@register_task("multimodal_pretraining", dataclass=MultimodalPretrainingConfig)
+class MultimodalPretrainingTask(FairseqTask):
+    """Joint pretraining task over audio, image and text sub-tasks."""
+
+ cfg: MultimodalPretrainingConfig
+
+ def __init__(self, cfg: MultimodalPretrainingConfig):
+ super().__init__(cfg)
+ self.audio_task = (
+ AudioPretrainingTask(cfg.audio) if cfg.audio is not None else None
+ )
+ self.image_task = (
+ MaeImagePretrainingTask(cfg.image) if cfg.image is not None else None
+ )
+ self.text_task = MaskedLMTask(cfg.text) if cfg.text is not None else None
+
+ self.mult_ratios = []
+
+ @classmethod
+ def setup_task(cls, cfg: MultimodalPretrainingConfig, **kwargs):
+ """Setup the task (e.g., load dictionaries).
+
+ Args:
+            cfg (MultimodalPretrainingConfig): configuration of this task
+ """
+
+ return cls(cfg)
+
+ def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
+ datasets = []
+ self.mult_ratios = []
+
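+        # each modality is loaded through its own sub-task and wrapped in a
+        # ModalityDatasetItem sharing the global token/sentence limits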
+ def load_ds(task, name, ratio):
+ if task is not None:
+ task.load_dataset(split)
+ ds = ModalityDatasetItem(
+ datasetname=name,
+ dataset=task.dataset(split),
+ max_positions=task.max_positions(),
+ max_tokens=self.cfg.max_tokens,
+ max_sentences=self.cfg.batch_size,
+ )
+ datasets.append(ds)
+ self.mult_ratios.append(ratio)
+
+ load_ds(self.audio_task, Modality.AUDIO, self.cfg.audio_ratio)
+ load_ds(self.image_task, Modality.IMAGE, self.cfg.image_ratio)
+ load_ds(self.text_task, Modality.TEXT, self.cfg.text_ratio)
+
+ assert len(datasets) > 0
+
+ self.datasets[split] = MultiModalityDataset(datasets)
+
+ @property
+ def supported_modalities(self):
+ modalities = []
+ if self.cfg.text is not None:
+ modalities.append(Modality.TEXT)
+ if self.cfg.audio is not None:
+ modalities.append(Modality.AUDIO)
+ if self.cfg.image is not None:
+ modalities.append(Modality.IMAGE)
+
+ return modalities
+
+ def get_batch_iterator(
+ self,
+ dataset,
+ max_tokens=None,
+ max_sentences=None,
+ max_positions=None,
+ ignore_invalid_inputs=False,
+ required_batch_size_multiple=1,
+ seed=1,
+ num_shards=1,
+ shard_id=0,
+ num_workers=0,
+ epoch=0,
+ data_buffer_size=0,
+ disable_iterator_cache=False,
+ skip_remainder_batch=False,
+ grouped_shuffling=False,
+ update_epoch_batch_itr=False,
+ ):
+
+ # initialize the dataset with the correct starting epoch
+ dataset.set_epoch(epoch)
+
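+        # one batch sampler per modality; mult_ratios controls how often each
+        # modality is sampled relative to the others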
+ batch_samplers = dataset.get_batch_samplers(
+ self.mult_ratios, required_batch_size_multiple, seed
+ )
+
+ # return a reusable, sharded iterator
+ epoch_iter = GroupedEpochBatchIterator(
+ dataset=dataset,
+ collate_fn=dataset.collater,
+ batch_samplers=batch_samplers,
+ seed=seed,
+ num_shards=num_shards,
+ shard_id=shard_id,
+ num_workers=num_workers,
+ epoch=epoch,
+ mult_rate=max(self.cfg.update_freq),
+ buffer_size=data_buffer_size,
+ skip_remainder_batch=skip_remainder_batch,
+ )
+ self.dataset_to_epoch_iter[dataset] = {} # refresh it every epoch
+ return epoch_iter
+
+ @property
+ def source_dictionary(self):
+ return None
+
+ @property
+ def target_dictionary(self):
+ return None
+
+ def max_positions(self):
+ """Maximum input length supported by the encoder."""
+ return sys.maxsize, sys.maxsize
diff --git a/examples/roberta/config/finetuning/run_config/local.yaml b/examples/roberta/config/finetuning/run_config/local.yaml
new file mode 100644
index 000000000..45595f9ee
--- /dev/null
+++ b/examples/roberta/config/finetuning/run_config/local.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+ sweep:
+ dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+common:
+ log_interval: 1
+
+dataset:
+ num_workers: 0
diff --git a/examples/roberta/config/finetuning/run_config/slurm_1g.yaml b/examples/roberta/config/finetuning/run_config/slurm_1g.yaml
new file mode 100644
index 000000000..8bc21854d
--- /dev/null
+++ b/examples/roberta/config/finetuning/run_config/slurm_1g.yaml
@@ -0,0 +1,28 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '_'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/roberta_ft/${env:PREFIX}/${hydra.job.config_name}/${env:SUFFIX}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}/submitit
+ timeout_min: 1000
+ cpus_per_task: 8
+ gpus_per_node: 1
+ tasks_per_node: 1
+ mem_gb: 60
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
+ exclude: learnfair1381,learnfair5192,learnfair2304
diff --git a/examples/roberta/config/finetuning/run_config/slurm_1g_aws.yaml b/examples/roberta/config/finetuning/run_config/slurm_1g_aws.yaml
new file mode 100644
index 000000000..085391cff
--- /dev/null
+++ b/examples/roberta/config/finetuning/run_config/slurm_1g_aws.yaml
@@ -0,0 +1,25 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '_'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/roberta_ft/${env:PREFIX}/${hydra.job.config_name}/${env:SUFFIX}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}/submitit
+ timeout_min: 1000
+ cpus_per_task: 8
+ gpus_per_node: 1
+ tasks_per_node: 1
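+    # mem_gb: 0 requests the node's full memory from Slurm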
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: learnfair,wav2vec
+ max_num_timeout: 30
diff --git a/examples/roberta/config/pretraining/run_config/local.yaml b/examples/roberta/config/pretraining/run_config/local.yaml
new file mode 100644
index 000000000..45595f9ee
--- /dev/null
+++ b/examples/roberta/config/pretraining/run_config/local.yaml
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+ sweep:
+ dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+ distributed_world_size: 1
+ nprocs_per_node: 1
+ distributed_port: -1
+
+common:
+ log_interval: 1
+
+dataset:
+ num_workers: 0
diff --git a/examples/roberta/config/pretraining/run_config/slurm_2.yaml b/examples/roberta/config/pretraining/run_config/slurm_2.yaml
new file mode 100644
index 000000000..006a0f211
--- /dev/null
+++ b/examples/roberta/config/pretraining/run_config/slurm_2.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/roberta/config/pretraining/run_config/slurm_2_aws.yaml b/examples/roberta/config/pretraining/run_config/slurm_2_aws.yaml
new file mode 100644
index 000000000..a5937ea5a
--- /dev/null
+++ b/examples/roberta/config/pretraining/run_config/slurm_2_aws.yaml
@@ -0,0 +1,39 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.local_cache_path
+ - task.data
+ - task.post_save_script
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ - model.model_path
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 0
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec
+ max_num_timeout: 30
diff --git a/examples/roberta/config/pretraining/run_config/slurm_3.yaml b/examples/roberta/config/pretraining/run_config/slurm_3.yaml
new file mode 100644
index 000000000..0e1555d20
--- /dev/null
+++ b/examples/roberta/config/pretraining/run_config/slurm_3.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 3
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/roberta/config/pretraining/run_config/slurm_4.yaml b/examples/roberta/config/pretraining/run_config/slurm_4.yaml
new file mode 100644
index 000000000..c54d735fb
--- /dev/null
+++ b/examples/roberta/config/pretraining/run_config/slurm_4.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 4
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb,ib4
+ max_num_timeout: 30
diff --git a/examples/roberta/fb_multilingual/README.multilingual.pretraining.md b/examples/roberta/fb_multilingual/README.multilingual.pretraining.md
new file mode 100644
index 000000000..234fd7470
--- /dev/null
+++ b/examples/roberta/fb_multilingual/README.multilingual.pretraining.md
@@ -0,0 +1,26 @@
+# Multilingual pretraining RoBERTa
+
+This tutorial will walk you through pretraining multilingual RoBERTa.
+
+### 1) Preprocess the data
+
+```bash
+DICTIONARY="/private/home/namangoyal/dataset/XLM/wiki/17/175k/vocab"
+DATA_LOCATION="/private/home/namangoyal/dataset/XLM/wiki/17/175k"
+
+for LANG in en es it
+do
+ fairseq-preprocess \
+ --only-source \
+ --srcdict $DICTIONARY \
+ --trainpref "$DATA_LOCATION/train.$LANG" \
+ --validpref "$DATA_LOCATION/valid.$LANG" \
+ --testpref "$DATA_LOCATION/test.$LANG" \
+ --destdir "wiki_17-bin/$LANG" \
+ --workers 60;
+done
+```
+
+### 2) Train RoBERTa base
+
+[COMING UP...]
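+
+Until the official recipe is published, here is a minimal sketch of a possible
+pretraining command. It simply reuses the monolingual RoBERTa base
+hyperparameters, so treat every value below as an assumption rather than the
+recommended multilingual setup; fairseq alternates between colon-separated
+data directories across epochs:
+
+```bash
+DATA_DIR="wiki_17-bin/en:wiki_17-bin/es:wiki_17-bin/it"
+
+fairseq-train $DATA_DIR \
+    --task masked_lm --criterion masked_lm \
+    --arch roberta_base --sample-break-mode complete --tokens-per-sample 512 \
+    --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-6 --clip-norm 0.0 \
+    --lr-scheduler polynomial_decay --lr 0.0005 --warmup-updates 10000 \
+    --total-num-update 125000 --max-update 125000 \
+    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
+    --batch-size 16 --update-freq 16 \
+    --skip-invalid-size-inputs-valid-test --fp16
+```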
diff --git a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py
index 11ef60c94..06d20d8d4 100644
--- a/examples/simultaneous_translation/modules/monotonic_multihead_attention.py
+++ b/examples/simultaneous_translation/modules/monotonic_multihead_attention.py
@@ -396,6 +396,7 @@ class MonotonicAttention(MultiheadAttention):
"p_choose": p_choose,
"alpha": alpha,
"beta": beta,
+ "soft_energy": soft_energy,
}
def _get_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
diff --git a/examples/simultaneous_translation/tests/test_text_models.py b/examples/simultaneous_translation/tests/test_text_models.py
index 127adfa63..19d635630 100644
--- a/examples/simultaneous_translation/tests/test_text_models.py
+++ b/examples/simultaneous_translation/tests/test_text_models.py
@@ -334,7 +334,7 @@ class InfiniteLookbackTestCase(
self.model.decoder.layers[0].encoder_attn,
"chunk_size",
int(1e10)
- )
+ ) or int(1e10)
)
self.assertTrue(
diff --git a/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml b/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml
new file mode 100644
index 000000000..eaaebcf5f
--- /dev/null
+++ b/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml
@@ -0,0 +1,29 @@
+# @package hydra.sweeper
+_target_: hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper
+max_batch_size: null
+ax_config:
+ max_trials: 64
+ early_stop:
+ minimize: true
+ max_epochs_without_improvement: 10
+ epsilon: 0.025
+ experiment:
+ name: ${dataset.gen_subset}
+ objective_name: wer
+ minimize: true
+ parameter_constraints: null
+ outcome_constraints: null
+ status_quo: null
+ client:
+ verbose_logging: false
+ random_seed: null
+ params:
+ decoding.lmweight:
+ type: range
+ bounds: [0.0, 10.0]
+ decoding.wordscore:
+ type: range
+ bounds: [-10.0, 10.0]
+ decoding.silweight:
+ type: range
+ bounds: [ -10.0, 0.0 ]
diff --git a/examples/speech_recognition/new/conf/infer.yaml b/examples/speech_recognition/new/conf/infer.yaml
index 21dd19fad..2d168d06a 100644
--- a/examples/speech_recognition/new/conf/infer.yaml
+++ b/examples/speech_recognition/new/conf/infer.yaml
@@ -10,6 +10,8 @@ hydra:
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${common_eval.results_path}
subdir: ${dataset.gen_subset}
+common:
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
common_eval:
results_path: null
path: null
diff --git a/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml b/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml
new file mode 100644
index 000000000..d0a9b0e58
--- /dev/null
+++ b/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml
@@ -0,0 +1,28 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - common_eval.path
+ sweep:
+ dir: /checkpoint/abaevski/asr/d2v2/decoding/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+# subdir: ${hydra.job.override_dirname}
+ launcher:
+ cpus_per_task: 16
+ gpus_per_node: 1
+ tasks_per_node: 1
+ nodes: 1
+ partition: devlab,learnlab
+ mem_gb: 100
+ timeout_min: 2000
+ max_num_timeout: 10
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ submitit_folder: ${hydra.sweep.dir}/%j
+ constraint: volta32gb
+ exclude: learnfair7598
\ No newline at end of file
diff --git a/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml b/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml
new file mode 100644
index 000000000..c0c442f76
--- /dev/null
+++ b/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml
@@ -0,0 +1,27 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - common_eval.path
+ sweep:
+ dir: /checkpoint/abaevski/asr/d2v2/decoding/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+# subdir: ${hydra.job.override_dirname}
+ launcher:
+ cpus_per_task: 16
+ gpus_per_node: 2
+ tasks_per_node: 2
+ nodes: 1
+ partition: devlab,learnlab
+ mem_gb: 100
+ timeout_min: 2000
+ max_num_timeout: 10
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ submitit_folder: ${hydra.sweep.dir}/%j
+ constraint: volta32gb
\ No newline at end of file
diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_1.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_1.yaml
new file mode 100644
index 000000000..4a848435c
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/run_config/slurm_1.yaml
@@ -0,0 +1,26 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
\ No newline at end of file
diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_16.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_16.yaml
new file mode 100644
index 000000000..041843a9b
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/run_config/slurm_16.yaml
@@ -0,0 +1,27 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 16
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
+ exclude: learnfair1381,learnfair5192,learnfair2304
\ No newline at end of file
diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_1_aws.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_1_aws.yaml
new file mode 100644
index 000000000..b9335df78
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/run_config/slurm_1_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.local_cache_path
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_1_old.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_1_old.yaml
new file mode 100644
index 000000000..a8d2363dc
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/run_config/slurm_1_old.yaml
@@ -0,0 +1,27 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 450
+ nodes: 1
+ name: ${env:PREFIX}_wav2vec3_small_librispeech
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
+ exclude: learnfair1381
\ No newline at end of file
diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_2.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_2.yaml
new file mode 100644
index 000000000..65ec48920
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/run_config/slurm_2.yaml
@@ -0,0 +1,27 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
+ exclude: learnfair7491,learnfair7477,learnfair7487
\ No newline at end of file
diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_2_aws.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_2_aws.yaml
new file mode 100644
index 000000000..e7590efc0
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/run_config/slurm_2_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.local_cache_path
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 8
+ tasks_per_node: 1
+ mem_gb: 0
+ nodes: 2
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_2g.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_2g.yaml
new file mode 100644
index 000000000..aaa20ebd0
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/run_config/slurm_2g.yaml
@@ -0,0 +1,26 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 2
+ tasks_per_node: 2
+ mem_gb: 200
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_3.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_3.yaml
new file mode 100644
index 000000000..9614ececa
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/run_config/slurm_3.yaml
@@ -0,0 +1,27 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 450
+ nodes: 3
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
+ exclude: learnfair7491,learnfair7477,learnfair7487
\ No newline at end of file
diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_4g.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_4g.yaml
new file mode 100644
index 000000000..c0c9f6043
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/run_config/slurm_4g.yaml
@@ -0,0 +1,26 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 4
+ tasks_per_node: 4
+ mem_gb: 200
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_4g_aws.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_4g_aws.yaml
new file mode 100644
index 000000000..6bbbf3b64
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/run_config/slurm_4g_aws.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '/'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ - distributed_training.distributed_world_size
+ - model.pretrained_model_path
+ - model.target_network_path
+ - next_script
+ - task.cache_in_scratch
+ - task.local_cache_path
+ - task.data
+ - checkpoint.save_interval_updates
+ - checkpoint.keep_interval_updates
+ - checkpoint.save_on_overflow
+ - common.log_interval
+ - common.user_dir
+ sweep:
+ dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ''
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 80
+ gpus_per_node: 4
+ tasks_per_node: 1
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab,learnfair
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/run_config/slurm_8.yaml b/examples/wav2vec/config/finetuning/run_config/slurm_8.yaml
new file mode 100644
index 000000000..984f21888
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/run_config/slurm_8.yaml
@@ -0,0 +1,26 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 4320
+ cpus_per_task: 10
+ gpus_per_node: 8
+ tasks_per_node: 8
+ mem_gb: 400
+ nodes: 8
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/vox_100h_2.yaml b/examples/wav2vec/config/finetuning/vox_100h_2.yaml
new file mode 100644
index 000000000..9bf588f58
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_100h_2.yaml
@@ -0,0 +1,106 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1280000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 100
+ validate_interval: 1
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 8
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
+ wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 2.0
+ wer_word_score: 0
+ wer_sil_weight: -2
+
+optimization:
+ max_update: 100000
+ lr: [1e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+  update_freq: [1] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: null
+ warmup_steps: 8000
+ hold_steps: 0
+ decay_steps: 72000
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.4
+ mask_length: 5
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.1
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 100
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 3000
+ cpus_per_task: 10
+ gpus_per_node: 4
+ tasks_per_node: 4
+ mem_gb: 250
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/vox_100h_2_aws.yaml b/examples/wav2vec/config/finetuning/vox_100h_2_aws.yaml
new file mode 100644
index 000000000..3a0d517eb
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_100h_2_aws.yaml
@@ -0,0 +1,82 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /fsx-wav2vec/abaevski/data/libri/100h/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1280000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 100
+ validate_interval: 1
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 8
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
+ wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 2.0
+ wer_word_score: 0
+ wer_sil_weight: -2
+
+optimization:
+ max_update: 100000
+ lr: [1e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+  update_freq: [1] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: null
+ warmup_steps: 8000
+ hold_steps: 0
+ decay_steps: 82000
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.4
+ mask_length: 7
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.1
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 100
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+
diff --git a/examples/wav2vec/config/finetuning/vox_100h_3.yaml b/examples/wav2vec/config/finetuning/vox_100h_3.yaml
new file mode 100644
index 000000000..46778666f
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_100h_3.yaml
@@ -0,0 +1,101 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1000000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 100
+ validate_interval: 1
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 8
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
+ wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 2.0
+ wer_word_score: -1.0
+
+optimization:
+ max_update: 100000
+ lr: [1e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+  update_freq: [1] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 8000
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.4
+ mask_length: 5
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.1
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 100
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 3000
+ cpus_per_task: 10
+ gpus_per_node: 4
+ tasks_per_node: 4
+ mem_gb: 250
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/vox_10h_2.yaml b/examples/wav2vec/config/finetuning/vox_10h_2.yaml
new file mode 100644
index 000000000..05ee76f14
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_10h_2.yaml
@@ -0,0 +1,102 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 10
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+ keep_interval_updates: 1
+
+task:
+ _name: audio_finetuning
+ data: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1280000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 100
+ validate_interval: 10
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 4
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
+ wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 2.0
+ wer_word_score: -1.0
+
+optimization:
+ max_update: 60000
+ lr: [2e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+  update_freq: [1] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 8000
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.5
+ mask_length: 5
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.1
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 100
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 3000
+ cpus_per_task: 10
+ gpus_per_node: 4
+ tasks_per_node: 4
+ mem_gb: 250
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/vox_10h_2_aws.yaml b/examples/wav2vec/config/finetuning/vox_10h_2_aws.yaml
new file mode 100644
index 000000000..a0afc9c5d
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_10h_2_aws.yaml
@@ -0,0 +1,81 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 10
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1280000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 100
+ validate_interval: 10
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 4
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
+ wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 2.0
+ wer_word_score: 4
+ wer_sil_weight: -5
+
+optimization:
+ max_update: 60000
+ lr: [1e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+  update_freq: [1] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: null
+ warmup_steps: 8000
+ hold_steps: 0
+ decay_steps: 72000
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.75
+ mask_length: 5
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.1
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 100
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
diff --git a/examples/wav2vec/config/finetuning/vox_10h_aws.yaml b/examples/wav2vec/config/finetuning/vox_10h_aws.yaml
new file mode 100644
index 000000000..c75437365
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_10h_aws.yaml
@@ -0,0 +1,104 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 10
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1280000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 100
+ validate_interval: 10
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 4
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+# wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
+# wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+# wer_lm_weight: 2.0
+# wer_word_score: -1.0
+
+optimization:
+ max_update: 60000
+ lr: [2e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+  update_freq: [1] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: null
+ warmup_steps: 8000
+ hold_steps: 0
+ decay_steps: 72000
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.4
+ mask_length: 5
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.1
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 100
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 3000
+ cpus_per_task: 10
+ gpus_per_node: 4
+ tasks_per_node: 4
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: wav2vec,learnlab
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/vox_10h_aws_v100.yaml b/examples/wav2vec/config/finetuning/vox_10h_aws_v100.yaml
new file mode 100644
index 000000000..58ad2acf7
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_10h_aws_v100.yaml
@@ -0,0 +1,102 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 10
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /fsx/abaevski/data/libri/10h/wav2vec/raw
+ labels: ltr
+ cache_in_scratch: true
+
+dataset:
+ num_workers: 10
+ max_tokens: 1280000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 100
+ validate_interval: 10
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 4
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_lexicon: /fsx/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 2.0
+ wer_word_score: -1.0
+
+optimization:
+ max_update: 60000
+ lr: [2e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+ update_freq: [1] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: null
+ warmup_steps: 8000
+ hold_steps: 0
+ decay_steps: 72000
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.6
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.1
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 100
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /fsx/${env:USER}/w2v_ft/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 3000
+ cpus_per_task: 10
+ gpus_per_node: 4
+ tasks_per_node: 4
+ mem_gb: 0
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: learnfair
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/vox_10m_2.yaml b/examples/wav2vec/config/finetuning/vox_10m_2.yaml
new file mode 100644
index 000000000..1ac7c1217
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_10m_2.yaml
@@ -0,0 +1,114 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_no_flatten_grads: true
+ log_format: json
+ log_interval: 200
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 500
+ save_interval_updates: 500
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /checkpoint/abaevski/data/speech/libri/10m/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1000000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 100
+ validate_interval: 500
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 4
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
+ wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 5
+ wer_word_score: 2
+ wer_sil_weight: -2
+
+optimization:
+ max_update: 10000
+ lr: [2e-6]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+ update_freq: [4] # base 10h wer -> 2/4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 2e-6
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 1000
+
+lr_scheduler: pass_through
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.4
+ mask_length: 3
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.25
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ freeze_finetune_updates: 100
+
+ zero_mask: true
+ feature_grad_mult: 0.0
+ activation_dropout: 0.1
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+ update_alibi: false
+
+#hydra:
+# job:
+# config:
+# override_dirname:
+# kv_sep: ':'
+# item_sep: '__'
+# exclude_keys:
+# - run_config
+# - distributed_training.distributed_port
+# sweep:
+# dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+# subdir: ${hydra.job.num}
+# launcher:
+# submitit_folder: ${hydra.sweep.dir}
+# timeout_min: 3000
+# cpus_per_task: 10
+# gpus_per_node: 4
+# tasks_per_node: 4
+# mem_gb: 250
+# nodes: 1
+# name: ${env:PREFIX}_${hydra.job.config_name}
+# partition: devlab,learnlab,learnfair,scavenge
+# constraint: volta32gb
+# max_num_timeout: 30
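
Unlike the plain Adam + tri_stage recipe above, the `*_2` configs use fairseq's `composite` optimizer: each parameter group (here just one `default` group) carries its own optimizer and LR scheduler, and the top-level `lr_scheduler: pass_through` defers entirely to those per-group schedulers. Roughly, the `default` group resolves to a warmup-then-cosine schedule like the sketch below (an assumed shape with a single cosine cycle ending at `max_update` and a final LR of 0, not fairseq's exact `CosineLRSchedule`):

```python
import math

def group_cosine_lr(step, peak_lr=2e-6, warmup_updates=1000,
                    max_update=10000, min_lr=0.0):
    if step < warmup_updates:        # linear warmup to lr_float
        return peak_lr * step / warmup_updates
    # single cosine cycle from peak_lr down to min_lr over the rest of training
    t = (step - warmup_updates) / max(1, max_update - warmup_updates)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * min(t, 1.0)))
```
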
diff --git a/examples/wav2vec/config/finetuning/vox_10m_2_aws.yaml b/examples/wav2vec/config/finetuning/vox_10m_2_aws.yaml
new file mode 100644
index 000000000..a9c270855
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_10m_2_aws.yaml
@@ -0,0 +1,114 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_no_flatten_grads: true
+ log_format: json
+ log_interval: 200
+ user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 500
+ save_interval_updates: 500
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /fsx-wav2vec/abaevski/data/libri/10m/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1000000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 100
+ validate_interval: 500
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 4
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
+ wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 5
+ wer_word_score: 2
+ wer_sil_weight: -2
+
+optimization:
+ max_update: 10000
+ lr: [2e-6]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+ update_freq: [4] # base 10h wer -> 2/4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 2e-6
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 1000
+
+lr_scheduler: pass_through
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.4
+ mask_length: 3
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.25
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ freeze_finetune_updates: 100
+
+ zero_mask: true
+ feature_grad_mult: 0.0
+ activation_dropout: 0.1
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+ update_alibi: false
+
+#hydra:
+# job:
+# config:
+# override_dirname:
+# kv_sep: ':'
+# item_sep: '__'
+# exclude_keys:
+# - run_config
+# - distributed_training.distributed_port
+# sweep:
+# dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+# subdir: ${hydra.job.num}
+# launcher:
+# submitit_folder: ${hydra.sweep.dir}
+# timeout_min: 3000
+# cpus_per_task: 10
+# gpus_per_node: 4
+# tasks_per_node: 4
+# mem_gb: 250
+# nodes: 1
+# name: ${env:PREFIX}_${hydra.job.config_name}
+# partition: devlab,learnlab,learnfair,scavenge
+# constraint: volta32gb
+# max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/vox_10m_3.yaml b/examples/wav2vec/config/finetuning/vox_10m_3.yaml
new file mode 100644
index 000000000..b6804126c
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_10m_3.yaml
@@ -0,0 +1,105 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 1000
+ save_interval_updates: 100
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /checkpoint/abaevski/data/speech/libri/10m/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1280000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 10000
+ validate_interval: 500
+ valid_subset: dev_other
+ required_batch_size_multiple: 8
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 4
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
+ wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 8
+ wer_word_score: 5.8
+ wer_sil_weight: -8
+
+optimization:
+ max_update: 13000
+ lr: [2e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+ update_freq: [5] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.65
+ mask_length: 10
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.25
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 10000
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 3000
+ cpus_per_task: 10
+ gpus_per_node: 4
+ tasks_per_node: 4
+ mem_gb: 250
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
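
This variant sizes the tri_stage phases as fractions of `max_update` via `phase_ratio` rather than explicit step counts. With the values above, the phases work out as follows (assuming the ratios are applied directly to `max_update`, which is how fairseq interprets them):

```python
max_update = 13000
phase_ratio = [0.1, 0.4, 0.5]          # warmup, hold, decay fractions
warmup, hold, decay = (int(r * max_update) for r in phase_ratio)
print(warmup, hold, decay)             # 1300 5200 6500
```
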
diff --git a/examples/wav2vec/config/finetuning/vox_1h_2.yaml b/examples/wav2vec/config/finetuning/vox_1h_2.yaml
new file mode 100644
index 000000000..75f4aafd7
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_1h_2.yaml
@@ -0,0 +1,104 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 100
+ save_interval_updates: 500
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1000000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 100
+ validate_interval: 100
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 8
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
+ wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 6
+ wer_word_score: -0.1
+ wer_sil_weight: -4.7
+
+optimization:
+ max_update: 60000
+ lr: [1e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+ update_freq: [1] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 4000
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.65
+ mask_length: 5
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.25
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 100
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 3000
+ cpus_per_task: 10
+ gpus_per_node: 4
+ tasks_per_node: 4
+ mem_gb: 250
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/vox_1h_2_aws.yaml b/examples/wav2vec/config/finetuning/vox_1h_2_aws.yaml
new file mode 100644
index 000000000..cc4d511d1
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_1h_2_aws.yaml
@@ -0,0 +1,114 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_no_flatten_grads: true
+ log_format: json
+ log_interval: 200
+ user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 100
+ save_interval_updates: 500
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /fsx-wav2vec/abaevski/data/libri/1h/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1000000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 100
+ validate_interval: 500
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 4
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
+ wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 5
+ wer_word_score: 0
+ wer_sil_weight: -4
+
+optimization:
+ max_update: 10000
+ lr: [2e-6]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+ update_freq: [4] # base 10h wer -> 2/4
+
+optimizer:
+ _name: composite
+ dynamic_groups: true
+ groups:
+ default:
+ lr_float: 2e-6
+ optimizer:
+ _name: adam
+ adam_betas: [0.9,0.95]
+ lr_scheduler:
+ _name: cosine
+ warmup_updates: 1000
+
+lr_scheduler: pass_through
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.4
+ mask_length: 3
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.25
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ freeze_finetune_updates: 100
+
+ zero_mask: true
+ feature_grad_mult: 0.0
+ activation_dropout: 0.1
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+ update_alibi: false
+
+#hydra:
+# job:
+# config:
+# override_dirname:
+# kv_sep: ':'
+# item_sep: '__'
+# exclude_keys:
+# - run_config
+# - distributed_training.distributed_port
+# sweep:
+# dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+# subdir: ${hydra.job.num}
+# launcher:
+# submitit_folder: ${hydra.sweep.dir}
+# timeout_min: 3000
+# cpus_per_task: 10
+# gpus_per_node: 4
+# tasks_per_node: 4
+# mem_gb: 250
+# nodes: 1
+# name: ${env:PREFIX}_${hydra.job.config_name}
+# partition: devlab,learnlab,learnfair,scavenge
+# constraint: volta32gb
+# max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/vox_1h_3.yaml b/examples/wav2vec/config/finetuning/vox_1h_3.yaml
new file mode 100644
index 000000000..842c89717
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_1h_3.yaml
@@ -0,0 +1,104 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 100
+ save_interval_updates: 500
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 640000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 10000
+ validate_interval: 100
+ valid_subset: dev_other
+ required_batch_size_multiple: 8
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 8
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
+ wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 6
+ wer_word_score: -0.1
+ wer_sil_weight: -4.7
+
+optimization:
+ max_update: 13000
+ lr: [6e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+ update_freq: [5] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 4000
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.3
+ mask_length: 3
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.25
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 10000
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 3000
+ cpus_per_task: 10
+ gpus_per_node: 4
+ tasks_per_node: 4
+ mem_gb: 250
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/vox_1h_4.yaml b/examples/wav2vec/config/finetuning/vox_1h_4.yaml
new file mode 100644
index 000000000..698ed8c4d
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_1h_4.yaml
@@ -0,0 +1,104 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 100
+ save_interval_updates: 1000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 640000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 10000
+ validate_interval: 100
+ valid_subset: dev_other
+ required_batch_size_multiple: 8
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 8
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
+ wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 2.0
+ wer_word_score: -1.0
+
+optimization:
+ max_update: 13000
+ lr: [6e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+ update_freq: [5] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.65
+ mask_length: 10
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.25
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 10000
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 3000
+ cpus_per_task: 10
+ gpus_per_node: 4
+ tasks_per_node: 4
+ mem_gb: 250
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/vox_1h_aws.yaml b/examples/wav2vec/config/finetuning/vox_1h_aws.yaml
new file mode 100644
index 000000000..aa6700415
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_1h_aws.yaml
@@ -0,0 +1,80 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 100
+ save_interval_updates: 500
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /fsx-wav2vec/abaevski/data/libri/10m/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1000000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 10000
+ validate_interval: 100
+ valid_subset: dev_other
+ required_batch_size_multiple: 8
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 8
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
+ wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 5
+ wer_word_score: -0.1
+ wer_sil_weight: -4.7
+
+optimization:
+ max_update: 13000
+ lr: [6e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+ update_freq: [5] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 4000
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.3
+ mask_length: 3
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.25
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 10000
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+ update_alibi: false
diff --git a/examples/wav2vec/config/finetuning/vox_960h_2.yaml b/examples/wav2vec/config/finetuning/vox_960h_2.yaml
new file mode 100644
index 000000000..d96e2325b
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_960h_2.yaml
@@ -0,0 +1,105 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /checkpoint/abaevski/data/speech/libri/960h/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1000000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 100
+ validate_interval: 1
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 16
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
+ wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 2.0
+ wer_word_score: -1.0
+
+optimization:
+ max_update: 200000
+ lr: [1e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+ update_freq: [1] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: null
+ warmup_steps: 8000
+ hold_steps: 0
+ decay_steps: 200000
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.4
+ mask_length: 5
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.1
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 100
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 3000
+ cpus_per_task: 10
+ gpus_per_node: 4
+ tasks_per_node: 4
+ mem_gb: 250
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
diff --git a/examples/wav2vec/config/finetuning/vox_960h_2_aws.yaml b/examples/wav2vec/config/finetuning/vox_960h_2_aws.yaml
new file mode 100644
index 000000000..41d2b38f8
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_960h_2_aws.yaml
@@ -0,0 +1,82 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ user_dir: /data/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /fsx-wav2vec/abaevski/data/librispeech
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1280000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 100
+ validate_interval: 1
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 16
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin
+ wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 1.5
+ wer_word_score: 0
+ wer_sil_weight: -1
+
+optimization:
+ max_update: 200000
+ lr: [2e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+ update_freq: [1] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: null
+ warmup_steps: 8000
+ hold_steps: 0
+ decay_steps: 192000
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.3
+ mask_length: 5
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.1
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 100
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+
diff --git a/examples/wav2vec/config/finetuning/vox_960h_3.yaml b/examples/wav2vec/config/finetuning/vox_960h_3.yaml
new file mode 100644
index 000000000..ef6597aa6
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_960h_3.yaml
@@ -0,0 +1,101 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+ user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
+# tensorboard_logdir: tb
+
+checkpoint:
+ save_interval: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw
+ labels: ltr
+ normalize: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1000000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 100
+ validate_interval: 1
+ valid_subset: dev_other
+ required_batch_size_multiple: 1
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 16
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+ post_process: letter
+ wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin
+ wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst
+ wer_lm_weight: 2.0
+ wer_word_score: -1.0
+
+optimization:
+ max_update: 200000
+ lr: [1e-5]
+# lr: [1e-5] # base 10h wer
+ sentence_avg: true
+ update_freq: [1] # base 10h wer -> 2/4
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: cosine
+ warmup_updates: 8000
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.4
+ mask_length: 5
+# mask_prob: 0.65 # base 10h wer
+ mask_channel_prob: 0.1
+# mask_channel_prob: 0.6 # base 10h wer
+ mask_channel_length: 64
+ layerdrop: 0.1
+# layerdrop: 0.05 # base 10h wer
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 100
+ dropout: 0
+ final_dropout: 0
+ attention_dropout: 0
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname}
+ subdir: ${hydra.job.num}
+ launcher:
+ submitit_folder: ${hydra.sweep.dir}
+ timeout_min: 3000
+ cpus_per_task: 10
+ gpus_per_node: 4
+ tasks_per_node: 4
+ mem_gb: 250
+ nodes: 1
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ partition: devlab,learnlab,learnfair,scavenge
+ constraint: volta32gb
+ max_num_timeout: 30
diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py
index 138b4d1eb..ff1da2553 100644
--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py
@@ -45,14 +45,14 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
save_checkpoint.best = best_function(val_loss, prev_best)
if cfg.no_save:
- return
+ return None
trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state
if not trainer.should_save_checkpoint_on_current_rank:
if trainer.always_call_state_dict_during_save_checkpoint:
trainer.state_dict()
- return
+ return None
write_timer = meters.StopwatchMeter()
write_timer.start()
@@ -111,8 +111,9 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
checkpoints = [
os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
]
+ saved_cp = None
if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
- trainer.save_checkpoint(checkpoints[0], extra_state)
+ saved_cp = trainer.save_checkpoint(checkpoints[0], extra_state)
for cp in checkpoints[1:]:
if cfg.write_checkpoints_asynchronously:
# TODO[ioPath]: Need to implement a delayed asynchronous
@@ -133,7 +134,11 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
)
)
- if not end_of_epoch and cfg.keep_interval_updates > 0:
+ if (
+ not end_of_epoch
+ and cfg.keep_interval_updates > 0
+ and trainer.should_save_checkpoint_on_current_rank
+ ):
# remove old checkpoints; checkpoints are sorted in descending order
if cfg.keep_interval_updates_pattern == -1:
checkpoints = checkpoint_paths(
@@ -157,7 +162,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
elif PathManager.exists(old_chk):
PathManager.rm(old_chk)
- if cfg.keep_last_epochs > 0:
+ if cfg.keep_last_epochs > 0 and trainer.should_save_checkpoint_on_current_rank:
# remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(
cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
@@ -168,7 +173,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
elif PathManager.exists(old_chk):
PathManager.rm(old_chk)
- if cfg.keep_best_checkpoints > 0:
+ if cfg.keep_best_checkpoints > 0 and trainer.should_save_checkpoint_on_current_rank:
# only keep the best N checkpoints according to validation metric
checkpoints = checkpoint_paths(
cfg.save_dir,
@@ -184,6 +189,8 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
elif PathManager.exists(old_chk):
PathManager.rm(old_chk)
+ return saved_cp
+
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
"""
@@ -574,6 +581,8 @@ def _torch_persistent_save(obj, f):
if i == 2:
logger.error(traceback.format_exc())
raise
+ else:
+ time.sleep(2.5)
def _upgrade_state_dict(state):
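
The final `checkpoint_utils.py` hunk only shows the new `time.sleep(2.5)` branch; for context, the retry loop around `_torch_persistent_save` has roughly the shape sketched below (reconstructed from the visible fragment, assuming `time` is imported at module level; the helper name and signature are illustrative):

```python
import logging
import time
import traceback

logger = logging.getLogger(__name__)

def _persistent_save_with_retry(save_fn, obj, f, retries=3):
    # Retry transient filesystem failures, sleeping between attempts and
    # re-raising (with a logged traceback) only on the final attempt.
    for i in range(retries):
        try:
            return save_fn(obj, f)
        except Exception:
            if i == retries - 1:
                logger.error(traceback.format_exc())
                raise
            time.sleep(2.5)
```
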
diff --git a/fairseq/config/fb_run_config/slurm.yaml b/fairseq/config/fb_run_config/slurm.yaml
new file mode 100644
index 000000000..20cf8f520
--- /dev/null
+++ b/fairseq/config/fb_run_config/slurm.yaml
@@ -0,0 +1,29 @@
+# @package _global_
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: ':'
+ item_sep: '__'
+ exclude_keys:
+ - fb_run_config
+ - distributed_training.distributed_port
+ sweep:
+ dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+ launcher:
+ cpus_per_task: 60
+ gpus_per_node: ???
+ tasks_per_node: 1
+ nodes: 1
+ partition: learnfair
+ mem_gb: 400
+ timeout_min: 4320
+ max_num_timeout: 10
+ name: ${env:PREFIX}_${hydra.job.config_name}
+ submitit_folder: ${hydra.sweep.dir}
+
+distributed_training:
+ ddp_backend: c10d
+ distributed_world_size: ???
+ distributed_port: ???
diff --git a/fairseq/criterions/__init__.py b/fairseq/criterions/__init__.py
index 4dbf46a1c..ecd65d34a 100644
--- a/fairseq/criterions/__init__.py
+++ b/fairseq/criterions/__init__.py
@@ -25,8 +25,8 @@ from omegaconf import DictConfig
)
-def build_criterion(cfg: DictConfig, task):
- return build_criterion_(cfg, task)
+def build_criterion(cfg: DictConfig, task, from_checkpoint=False):
+ return build_criterion_(cfg, task, from_checkpoint=from_checkpoint)
# automatically import any Python files in the criterions/ directory
diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py
index 6d53198b0..e55e928b4 100644
--- a/fairseq/criterions/ctc.py
+++ b/fairseq/criterions/ctc.py
@@ -7,18 +7,17 @@
import math
from argparse import Namespace
from dataclasses import dataclass, field
+from omegaconf import II
from typing import Optional
import torch
import torch.nn.functional as F
-from omegaconf import II
-
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
-from fairseq.data.data_utils import post_process
from fairseq.dataclass import FairseqDataclass
-from fairseq.logging.meters import safe_round
+from fairseq.data.data_utils import post_process
from fairseq.tasks import FairseqTask
+from fairseq.logging.meters import safe_round
@dataclass
@@ -54,6 +53,10 @@ class CtcCriterionConfig(FairseqDataclass):
default=-1.0,
metadata={"help": "lm word score to use with wer_kenlm_model"},
)
+ wer_sil_weight: float = field(
+ default=0,
+ metadata={"help": "lm word score to use with wer_kenlm_model"},
+ )
wer_args: Optional[str] = field(
default=None,
@@ -101,6 +104,6 @@ class CtcCriterion(FairseqCriterion):
 dec_args.beam_threshold = min(50, len(task.target_dictionary))
 dec_args.lm_weight = cfg.wer_lm_weight
 dec_args.word_score = cfg.wer_word_score
+ dec_args.sil_weight = cfg.wer_sil_weight
 dec_args.unk_weight = -math.inf
- dec_args.sil_weight = 0
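
The new `wer_sil_weight` option joins `wer_lm_weight` and `wer_word_score` as a tunable term in the beam-search scoring used for validation WER. Schematically, each hypothesis is scored along these lines (a simplified sketch of how lexicon-based KenLM decoding combines the terms, not flashlight's actual decoder code):

```python
def hypothesis_score(acoustic_logprob, lm_logprob, num_words, num_silences,
                     lm_weight=2.0, word_score=-1.0, sil_weight=0.0):
    # Higher lm_weight trusts the language model more; word_score rewards
    # (or penalizes) emitting words; sil_weight does the same for silences.
    return (acoustic_logprob
            + lm_weight * lm_logprob
            + word_score * num_words
            + sil_weight * num_silences)
```
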
diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py
index 257466903..cb43be0ca 100644
--- a/fairseq/criterions/label_smoothed_cross_entropy.py
+++ b/fairseq/criterions/label_smoothed_cross_entropy.py
@@ -7,11 +7,10 @@ import math
from dataclasses import dataclass, field
import torch
-from omegaconf import II
-
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
+from omegaconf import II
@dataclass
diff --git a/fairseq/criterions/model_criterion.py b/fairseq/criterions/model_criterion.py
index f9a810d83..2c9fbb255 100644
--- a/fairseq/criterions/model_criterion.py
+++ b/fairseq/criterions/model_criterion.py
@@ -12,6 +12,7 @@ import torch
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
+from fairseq.logging.meters import safe_round
logger = logging.getLogger(__name__)
@@ -27,6 +28,7 @@ class ModelCriterionConfig(FairseqDataclass):
default_factory=list,
metadata={"help": "additional output keys to log"},
)
+ can_sum: bool = True
@register_criterion("model", dataclass=ModelCriterionConfig)
@@ -43,10 +45,11 @@ class ModelCriterion(FairseqCriterion):
net_output dict can be logged via the log_keys parameter.
"""
- def __init__(self, task, loss_weights=None, log_keys=None):
+ def __init__(self, task, loss_weights=None, log_keys=None, can_sum=True):
super().__init__(task)
self.loss_weights = loss_weights
self.log_keys = log_keys
+ self.can_sum = can_sum
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])
@@ -69,7 +72,7 @@ class ModelCriterion(FairseqCriterion):
)
raise
if coef != 0 and p is not None:
- scaled_losses[lk] = coef * p.float()
+ scaled_losses[lk] = coef * p.float().sum()
loss = sum(scaled_losses.values())
@@ -93,6 +96,8 @@ class ModelCriterion(FairseqCriterion):
if lk in net_output and net_output[lk] is not None:
if not torch.is_tensor(net_output[lk]) or net_output[lk].numel() == 1:
logging_output[lk] = float(net_output[lk])
+ elif lk.startswith("_"):
+ logging_output[lk] = net_output[lk]
else:
for i, v in enumerate(net_output[lk]):
logging_output[f"{lk}_{i}"] = float(v)
@@ -124,6 +129,7 @@ class ModelCriterion(FairseqCriterion):
metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
+ metrics.log_scalar("sample_size", sample_size)
builtin_keys = {
"loss",
@@ -138,18 +144,33 @@ class ModelCriterion(FairseqCriterion):
)
for k in logging_outputs[0]:
- if k not in builtin_keys:
+ if k not in builtin_keys and not k.startswith("_"):
val = sum(log.get(k, 0) for log in logging_outputs)
if k.startswith("loss_"):
metrics.log_scalar(k, val / sample_size, sample_size, round=3)
else:
metrics.log_scalar(k, val / world_size, round=3)
- @staticmethod
- def logging_outputs_can_be_summed() -> bool:
+ correct = sum(log.get("correct", 0) for log in logging_outputs)
+ total = sum(log.get("count", 0) for log in logging_outputs)
+
+ if total > 0:
+ metrics.log_scalar("_correct", correct)
+ metrics.log_scalar("_total", total)
+
+ metrics.log_derived(
+ "accuracy",
+ lambda meters: safe_round(
+ meters["_correct"].sum / meters["_total"].sum, 5
+ )
+ if meters["_total"].sum > 0
+ else float("nan"),
+ )
+
+ def logging_outputs_can_be_summed(self) -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
- return True
+ return self.can_sum
diff --git a/fairseq/criterions/sentence_prediction.py b/fairseq/criterions/sentence_prediction.py
index b402d7603..01c2a2ba6 100644
--- a/fairseq/criterions/sentence_prediction.py
+++ b/fairseq/criterions/sentence_prediction.py
@@ -5,13 +5,49 @@
import math
from dataclasses import dataclass, field
+from itertools import chain
+import numpy as np
import torch
import torch.nn.functional as F
+from sklearn.metrics import f1_score
+from sklearn.metrics import matthews_corrcoef as _matthews_corrcoef
+from scipy.stats import pearsonr, spearmanr
from fairseq import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
+from fairseq.logging.meters import safe_round
+
+
+def simple_accuracy(preds, labels):
+ return (preds == labels).mean()
+
+
+def acc_and_f1(preds, labels):
+ acc = simple_accuracy(preds, labels)
+ f1 = f1_score(y_true=labels, y_pred=preds)
+ return {
+ "acc": acc,
+ "f1": f1,
+ "acc_and_f1": (acc + f1) / 2,
+ }
+
+
+def pearson_and_spearman(preds, labels):
+ pearson_corr = pearsonr(preds, labels)[0]
+ spearman_corr = spearmanr(preds, labels)[0]
+ return {
+ "pearson": pearson_corr,
+ "spearmanr": spearman_corr,
+ "corr": (pearson_corr + spearman_corr) / 2,
+ }
+
+
+def matthews_corrcoef(preds, labels):
+ # make it consistent with other metrics taking (preds, labels) as input
+ mcc = _matthews_corrcoef(labels, preds)
+ return mcc
@dataclass
@@ -23,6 +59,9 @@ class SentencePredictionConfig(FairseqDataclass):
regression_target: bool = field(
default=False,
)
+ report_mcc: bool = False
+ report_acc_and_f1: bool = False
+ report_pearson_and_spearman: bool = False
@register_criterion("sentence_prediction", dataclass=SentencePredictionConfig)
@@ -31,6 +70,13 @@ class SentencePredictionCriterion(FairseqCriterion):
super().__init__(task)
self.classification_head_name = cfg.classification_head_name
self.regression_target = cfg.regression_target
+ self.keep_pred_and_targ = (
+ cfg.report_mcc or cfg.report_acc_and_f1 or cfg.report_pearson_and_spearman
+ )
+ self.report_mcc = cfg.report_mcc
+ self.report_acc_and_f1 = cfg.report_acc_and_f1
+ self.report_pearson_and_spearman = cfg.report_pearson_and_spearman
+ self.label_dict = task.label_dictionary
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
@@ -65,14 +111,16 @@ class SentencePredictionCriterion(FairseqCriterion):
loss = task_loss
# mha & ffn regularization update
if (
- hasattr(model.args, "mha_reg_scale_factor")
+ hasattr(model, "args")
+ and hasattr(model.args, "mha_reg_scale_factor")
and model.args.mha_reg_scale_factor != 0.0
):
mha_reg_loss = model._get_adaptive_head_loss()
loss += mha_reg_loss
logging_output.update({"mha_reg_loss": mha_reg_loss})
if (
- hasattr(model.args, "ffn_reg_scale_factor")
+ hasattr(model, "args")
+ and hasattr(model.args, "ffn_reg_scale_factor")
and model.args.ffn_reg_scale_factor != 0.0
):
ffn_reg_loss = model._get_adaptive_ffn_loss()
@@ -90,6 +138,25 @@ class SentencePredictionCriterion(FairseqCriterion):
if not self.regression_target:
preds = logits.argmax(dim=1)
logging_output["ncorrect"] = (preds == targets).sum()
+ if self.keep_pred_and_targ and not model.training:
+ if self.regression_target:
+ logging_output["pred"] = logits.detach().cpu().tolist()
+ logging_output["targ"] = targets.detach().cpu().tolist()
+ else:
+ # remove offset `self.label_dict.nspecial` from OffsetTokensDataset
+ preds = self.label_dict.string(preds + self.label_dict.nspecial).split()
+ targets = self.label_dict.string(
+ targets + self.label_dict.nspecial
+ ).split()
+ logging_output["pred"] = list(map(int, preds))
+ logging_output["targ"] = list(map(int, targets))
+
+ if self.report_mcc:
+ logging_output["report_mcc"] = True
+ if self.report_acc_and_f1:
+ logging_output["report_acc_and_f1"] = True
+ if self.report_pearson_and_spearman:
+ logging_output["report_pearson_and_spearman"] = True
return loss, sample_size, logging_output
@@ -131,6 +198,86 @@ class SentencePredictionCriterion(FairseqCriterion):
"accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1
)
+ # Metrics used by GLUE
+ pred = np.array(
+ list(chain.from_iterable(log.get("pred", []) for log in logging_outputs))
+ )
+ targ = np.array(
+ list(chain.from_iterable(log.get("targ", []) for log in logging_outputs))
+ )
+ if len(pred):
+ metrics.log_concat_tensor("pred", torch.from_numpy(pred), dim=0)
+ metrics.log_concat_tensor("targ", torch.from_numpy(targ), dim=0)
+ if any("report_mcc" in log for log in logging_outputs):
+ metrics.log_derived(
+ "mcc",
+ lambda meters: safe_round(
+ matthews_corrcoef(
+ meters["pred"].tensor.numpy(),
+ meters["targ"].tensor.numpy(),
+ )
+ * 100,
+ 1,
+ ),
+ )
+ if any("report_acc_and_f1" in log for log in logging_outputs):
+ metrics.log_derived(
+ "acc_and_f1",
+ lambda meters: safe_round(
+ acc_and_f1(
+ meters["pred"].tensor.numpy(),
+ meters["targ"].tensor.numpy(),
+ )["acc_and_f1"]
+ * 100,
+ 1,
+ ),
+ )
+ metrics.log_derived(
+ "f1",
+ lambda meters: safe_round(
+ acc_and_f1(
+ meters["pred"].tensor.numpy(),
+ meters["targ"].tensor.numpy(),
+ )["f1"]
+ * 100,
+ 1,
+ ),
+ )
+ if any("report_pearson_and_spearman" in log for log in logging_outputs):
+ metrics.log_derived(
+ "pearson_and_spearman",
+ lambda meters: safe_round(
+ pearson_and_spearman(
+ meters["pred"].tensor.numpy(),
+ meters["targ"].tensor.numpy(),
+ )["corr"]
+ * 100,
+ 1,
+ ),
+ )
+ metrics.log_derived(
+ "pearson",
+ lambda meters: safe_round(
+ pearson_and_spearman(
+ meters["pred"].tensor.numpy(),
+ meters["targ"].tensor.numpy(),
+ )["pearson"]
+ * 100,
+ 1,
+ ),
+ )
+ metrics.log_derived(
+ "spearman",
+ lambda meters: safe_round(
+ pearson_and_spearman(
+ meters["pred"].tensor.numpy(),
+ meters["targ"].tensor.numpy(),
+ )["spearmanr"]
+ * 100,
+ 1,
+ ),
+ )
+
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py
index 8acf2ca17..a27e3184a 100644
--- a/fairseq/data/__init__.py
+++ b/fairseq/data/__init__.py
@@ -39,6 +39,11 @@ from .noising import NoisingDataset
from .numel_dataset import NumelDataset
from .num_samples_dataset import NumSamplesDataset
from .offset_tokens_dataset import OffsetTokensDataset
+from .padding_mask_dataset import (
+ LeftPaddingMaskDataset,
+ PaddingMaskDataset,
+ RightPaddingMaskDataset,
+)
from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset
from .prepend_dataset import PrependDataset
from .prepend_token_dataset import PrependTokenDataset
diff --git a/fairseq/data/add_class_target_dataset.py b/fairseq/data/add_class_target_dataset.py
new file mode 100644
index 000000000..bf89f2565
--- /dev/null
+++ b/fairseq/data/add_class_target_dataset.py
@@ -0,0 +1,79 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from . import BaseWrapperDataset, data_utils
+from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
+
+
+class AddTargetDataset(BaseWrapperDataset):
+ def __init__(
+ self,
+ dataset,
+ labels,
+ pad,
+ eos,
+ batch_targets,
+ process_label=None,
+ label_len_fn=None,
+ add_to_input=False,
+ text_compression_level=TextCompressionLevel.none,
+ ):
+ super().__init__(dataset)
+ self.labels = labels
+ self.batch_targets = batch_targets
+ self.pad = pad
+ self.eos = eos
+ self.process_label = process_label
+ self.label_len_fn = label_len_fn
+ self.add_to_input = add_to_input
+ self.text_compressor = TextCompressor(level=text_compression_level)
+
+ def get_label(self, index, process_fn=None):
+ lbl = self.labels[index]
+ lbl = self.text_compressor.decompress(lbl)
+ return lbl if process_fn is None else process_fn(lbl)
+
+ def __getitem__(self, index):
+ item = self.dataset[index]
+ item["label"] = self.get_label(index, process_fn=self.process_label)
+ return item
+
+ def size(self, index):
+ sz = self.dataset.size(index)
+ own_sz = self.label_len_fn(self.get_label(index))
+ return sz, own_sz
+
+ def collater(self, samples):
+ collated = self.dataset.collater(samples)
+ if len(collated) == 0:
+ return collated
+ indices = set(collated["id"].tolist())
+ target = [s["label"] for s in samples if s["id"] in indices]
+
+ if self.batch_targets:
+ collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
+ target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
+ collated["ntokens"] = collated["target_lengths"].sum().item()
+ else:
+ collated["ntokens"] = sum([len(t) for t in target])
+
+ collated["target"] = target
+
+ if self.add_to_input:
+ eos = target.new_full((target.size(0), 1), self.eos)
+ collated["target"] = torch.cat([target, eos], dim=-1).long()
+ collated["net_input"]["prev_output_tokens"] = torch.cat(
+ [eos, target], dim=-1
+ ).long()
+ collated["ntokens"] += target.size(0)
+ return collated
+
+ def filter_indices_by_size(self, indices, max_sizes):
+ indices, ignored = data_utils._filter_by_size_dynamic(
+ indices, self.size, max_sizes
+ )
+ return indices, ignored
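
A minimal usage sketch for the new `AddTargetDataset`, wrapping an audio dataset with letter transcriptions for CTC fine-tuning. The `audio_dataset`, `labels`, and `target_dict` arguments are assumed inputs, not objects defined in this patch:

```python
from fairseq.data.add_class_target_dataset import AddTargetDataset

def wrap_with_targets(audio_dataset, labels, target_dict):
    # labels: one transcription string per example;
    # target_dict: a fairseq Dictionary over the label vocabulary.
    def process_label(label):
        return target_dict.encode_line(
            label, append_eos=False, add_if_not_exist=False
        )

    return AddTargetDataset(
        audio_dataset,
        labels,
        pad=target_dict.pad(),
        eos=target_dict.eos(),
        batch_targets=True,              # collate targets into a padded batch
        process_label=process_label,
        label_len_fn=lambda lbl: len(lbl.split()),
    )
```
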
diff --git a/fairseq/data/audio/multi_modality_dataset.py b/fairseq/data/audio/multi_modality_dataset.py
index 2db163734..0a42c1061 100644
--- a/fairseq/data/audio/multi_modality_dataset.py
+++ b/fairseq/data/audio/multi_modality_dataset.py
@@ -10,7 +10,6 @@ import math
from typing import List, Optional, NamedTuple
import numpy as np
-from fairseq.data.resampling_dataset import ResamplingDataset
import torch
from fairseq.data import (
ConcatDataset,
@@ -31,16 +30,6 @@ class ModalityDatasetItem(NamedTuple):
max_sentences: Optional[int] = None
-def resampling_dataset_present(ds):
- if isinstance(ds, ResamplingDataset):
- return True
- if isinstance(ds, ConcatDataset):
- return any(resampling_dataset_present(d) for d in ds.datasets)
- if hasattr(ds, "dataset"):
- return resampling_dataset_present(ds.dataset)
- return False
-
-
# MultiModalityDataset: concatenates multiple datasets with different modalities.
# Compared with ConcatDataset it can 1) sample data given per-dataset ratios and
# 2) add a mode field indicating which modality each sample comes from.
@@ -106,7 +95,7 @@ class MultiModalityDataset(ConcatDataset):
Returns indices sorted by length. So less padding is needed.
"""
if len(self.datasets) == 1:
- return self.datasets[0].ordered_indices()
+ return [self.datasets[0].ordered_indices()]
indices_group = []
for d_idx, ds in enumerate(self.datasets):
sample_num = self.cumulative_sizes[d_idx]
@@ -117,16 +106,13 @@ class MultiModalityDataset(ConcatDataset):
return indices_group
def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
+ if len(self.raw_sub_batch_samplers) > 0:
+ logger.info(" raw_sub_batch_samplers exists. No action is taken")
+ return
with data_utils.numpy_seed(seed):
indices = self.ordered_indices()
+
for i, ds in enumerate(self.datasets):
- # If we have ResamplingDataset, the same id can correpond to a different
- # sample in the next epoch, so we need to rebuild this at every epoch
- if i < len(self.raw_sub_batch_samplers) and not resampling_dataset_present(
- ds
- ):
- logger.info(f"dataset {i} is valid and it is not re-sampled")
- continue
indices[i] = ds.filter_indices_by_size(
indices[i],
self.max_positions[i],
@@ -137,10 +123,7 @@ class MultiModalityDataset(ConcatDataset):
max_sentences=self.max_sentences[i],
required_batch_size_multiple=required_batch_size_multiple,
)
- if i < len(self.raw_sub_batch_samplers):
- self.raw_sub_batch_samplers[i] = sub_batch_sampler
- else:
- self.raw_sub_batch_samplers.append(sub_batch_sampler)
+ self.raw_sub_batch_samplers.append(sub_batch_sampler)
def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
self.get_raw_batch_samplers(required_batch_size_multiple, seed)
diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py
index 181e2bbc9..edb307e68 100644
--- a/fairseq/data/audio/raw_audio_dataset.py
+++ b/fairseq/data/audio/raw_audio_dataset.py
@@ -7,6 +7,7 @@
import logging
import os
import sys
+import time
import io
import numpy as np
@@ -14,7 +15,7 @@ import torch
import torch.nn.functional as F
from .. import FairseqDataset
-from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
+from ..data_utils import compute_block_mask_1d, get_buckets, get_bucketed_sizes
from fairseq.data.audio.audio_utils import (
parse_path,
read_from_stored_zip,
@@ -35,8 +36,17 @@ class RawAudioDataset(FairseqDataset):
shuffle=True,
pad=False,
normalize=False,
- compute_mask_indices=False,
- **mask_compute_kwargs,
+ compute_mask=False,
+ feature_encoder_spec: str = "None",
+ mask_prob: float = 0.75,
+ mask_prob_adjust: float = 0,
+ mask_length: int = 1,
+ inverse_mask: bool = False,
+ require_same_masks: bool = True,
+ clone_batch: int = 1,
+ expand_adjacent: bool = False,
+ mask_dropout: float = 0,
+ non_overlapping: bool = False,
):
super().__init__()
@@ -49,12 +59,19 @@ class RawAudioDataset(FairseqDataset):
self.pad = pad
self.shuffle = shuffle
self.normalize = normalize
- self.compute_mask_indices = compute_mask_indices
- if self.compute_mask_indices:
- self.mask_compute_kwargs = mask_compute_kwargs
- self._features_size_map = {}
- self._C = mask_compute_kwargs["encoder_embed_dim"]
- self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"])
+
+ self.is_compute_mask = compute_mask
+ self.feature_encoder_spec = eval(feature_encoder_spec)
+ self._features_size_map = {}
+ self.mask_prob = mask_prob
+ self.mask_prob_adjust = mask_prob_adjust
+ self.mask_length = mask_length
+ self.inverse_mask = inverse_mask
+ self.require_same_masks = require_same_masks
+ self.clone_batch = clone_batch
+ self.expand_adjacent = expand_adjacent
+ self.mask_dropout = mask_dropout
+ self.non_overlapping = non_overlapping
def __getitem__(self, index):
raise NotImplementedError()
@@ -76,48 +93,21 @@ class RawAudioDataset(FairseqDataset):
feats = F.layer_norm(feats, feats.shape)
return feats
- def crop_to_max_size(self, wav, target_size):
- size = len(wav)
+ def crop_to_max_size(self, t, target_size, dim=0):
+ size = t.size(dim)
diff = size - target_size
if diff <= 0:
- return wav
+ return t
start = np.random.randint(0, diff + 1)
end = size - diff + start
- return wav[start:end]
- def _compute_mask_indices(self, dims, padding_mask):
- B, T, C = dims
- mask_indices, mask_channel_indices = None, None
- if self.mask_compute_kwargs["mask_prob"] > 0:
- mask_indices = compute_mask_indices(
- (B, T),
- padding_mask,
- self.mask_compute_kwargs["mask_prob"],
- self.mask_compute_kwargs["mask_length"],
- self.mask_compute_kwargs["mask_selection"],
- self.mask_compute_kwargs["mask_other"],
- min_masks=2,
- no_overlap=self.mask_compute_kwargs["no_mask_overlap"],
- min_space=self.mask_compute_kwargs["mask_min_space"],
- )
- mask_indices = torch.from_numpy(mask_indices)
- if self.mask_compute_kwargs["mask_channel_prob"] > 0:
- mask_channel_indices = compute_mask_indices(
- (B, C),
- None,
- self.mask_compute_kwargs["mask_channel_prob"],
- self.mask_compute_kwargs["mask_channel_length"],
- self.mask_compute_kwargs["mask_channel_selection"],
- self.mask_compute_kwargs["mask_channel_other"],
- no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"],
- min_space=self.mask_compute_kwargs["mask_channel_min_space"],
- )
- mask_channel_indices = (
- torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1)
- )
+ slices = []
+ for d in range(dim):
+ slices.append(slice(None))
+ slices.append(slice(start, end))
- return mask_indices, mask_channel_indices
+ return t[slices]
@staticmethod
def _bucket_tensor(tensor, num_pad, value):
@@ -166,33 +156,24 @@ class RawAudioDataset(FairseqDataset):
input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True)
- if self.compute_mask_indices:
- B = input["source"].size(0)
- T = self._get_mask_indices_dims(input["source"].size(-1))
- padding_mask_reshaped = input["padding_mask"].clone()
- extra = padding_mask_reshaped.size(1) % T
- if extra > 0:
- padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
- padding_mask_reshaped = padding_mask_reshaped.view(
- padding_mask_reshaped.size(0), T, -1
+ if "precomputed_mask" in samples[0]:
+ target_size = self._get_mask_indices_dims(target_size)
+ collated_mask = torch.cat(
+ [
+ self.crop_to_max_size(s["precomputed_mask"], target_size, dim=1)
+ for s in samples
+ ],
+ dim=0,
)
- padding_mask_reshaped = padding_mask_reshaped.all(-1)
- input["padding_count"] = padding_mask_reshaped.sum(-1).max().item()
- mask_indices, mask_channel_indices = self._compute_mask_indices(
- (B, T, self._C),
- padding_mask_reshaped,
- )
- input["mask_indices"] = mask_indices
- input["mask_channel_indices"] = mask_channel_indices
- out["sample_size"] = mask_indices.sum().item()
+ input["precomputed_mask"] = collated_mask
out["net_input"] = input
return out
def _get_mask_indices_dims(self, size, padding=0, dilation=1):
- if size not in self._features_size_map:
+ if size not in self.feature_encoder_spec:
L_in = size
- for (_, kernel_size, stride) in self._conv_feature_layers:
+ for (_, kernel_size, stride) in self.feature_encoder_spec:
L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
L_out = 1 + L_out // stride
L_in = L_out
@@ -244,6 +225,9 @@ class RawAudioDataset(FairseqDataset):
f"{self.buckets}"
)
+ def filter_indices_by_size(self, indices, max_sizes):
+ return indices, []
+
class FileAudioDataset(RawAudioDataset):
def __init__(
@@ -256,7 +240,7 @@ class FileAudioDataset(RawAudioDataset):
pad=False,
normalize=False,
num_buckets=0,
- compute_mask_indices=False,
+ compute_mask=False,
text_compression_level=TextCompressionLevel.none,
**mask_compute_kwargs,
):
@@ -267,7 +251,7 @@ class FileAudioDataset(RawAudioDataset):
shuffle=shuffle,
pad=pad,
normalize=normalize,
- compute_mask_indices=compute_mask_indices,
+ compute_mask=compute_mask,
**mask_compute_kwargs,
)
@@ -319,11 +303,43 @@ class FileAudioDataset(RawAudioDataset):
assert is_sf_audio_data(byte_data)
path_or_fp = io.BytesIO(byte_data)
- wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")
+ retry = 3
+ wav = None
+ for i in range(retry):
+ try:
+ wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")
+ break
+ except Exception as e:
+ logger.warning(
+ f"Failed to read {path_or_fp}: {e}. Sleeping for {1 * i}"
+ )
+ time.sleep(1 * i)
+
+ if wav is None:
+ raise Exception(f"Failed to load {path_or_fp}")
feats = torch.from_numpy(wav).float()
feats = self.postprocess(feats, curr_sample_rate)
- return {"id": index, "source": feats}
+
+ v = {"id": index, "source": feats}
+
+ if self.is_compute_mask:
+ T = self._get_mask_indices_dims(feats.size(-1))
+ mask = compute_block_mask_1d(
+ shape=(self.clone_batch, T),
+ mask_prob=self.mask_prob,
+ mask_length=self.mask_length,
+ mask_prob_adjust=self.mask_prob_adjust,
+ inverse_mask=self.inverse_mask,
+ require_same_masks=True,
+ expand_adjcent=self.expand_adjacent,
+ mask_dropout=self.mask_dropout,
+ non_overlapping=self.non_overlapping,
+ )
+
+ v["precomputed_mask"] = mask
+
+ return v
class BinarizedAudioDataset(RawAudioDataset):
@@ -338,7 +354,7 @@ class BinarizedAudioDataset(RawAudioDataset):
pad=False,
normalize=False,
num_buckets=0,
- compute_mask_indices=False,
+ compute_mask=False,
**mask_compute_kwargs,
):
super().__init__(
@@ -348,7 +364,7 @@ class BinarizedAudioDataset(RawAudioDataset):
shuffle=shuffle,
pad=pad,
normalize=normalize,
- compute_mask_indices=compute_mask_indices,
+ compute_mask=compute_mask,
**mask_compute_kwargs,
)
@@ -390,4 +406,22 @@ class BinarizedAudioDataset(RawAudioDataset):
wav, curr_sample_rate = sf.read(fname)
feats = torch.from_numpy(wav).float()
feats = self.postprocess(feats, curr_sample_rate)
- return {"id": index, "source": feats}
+ v = {"id": index, "source": feats}
+
+ if self.is_compute_mask:
+ T = self._get_mask_indices_dims(feats.size(-1))
+ mask = compute_block_mask_1d(
+ shape=(self.clone_batch, T),
+ mask_prob=self.mask_prob,
+ mask_length=self.mask_length,
+ mask_prob_adjust=self.mask_prob_adjust,
+ inverse_mask=self.inverse_mask,
+ require_same_masks=True,
+ expand_adjcent=self.expand_adjacent,
+ mask_dropout=self.mask_dropout,
+ non_overlapping=self.non_overlapping,
+ )
+
+ v["precomputed_mask"] = mask
+
+ return v
diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py
index 0372d52b0..9a19cc3c1 100644
--- a/fairseq/data/data_utils.py
+++ b/fairseq/data/data_utils.py
@@ -14,6 +14,7 @@ import re
import warnings
from typing import Optional, Tuple
+import math
import numpy as np
import torch
@@ -337,7 +338,7 @@ def batch_by_size(
if fixed_shapes is None:
if num_tokens_vec is None:
- return batch_by_size_fn(
+ b = batch_by_size_fn(
indices,
num_tokens_fn,
max_tokens,
@@ -345,7 +346,7 @@ def batch_by_size(
bsz_mult,
)
else:
- return batch_by_size_vec(
+ b = batch_by_size_vec(
indices,
num_tokens_vec,
max_tokens,
@@ -353,6 +354,11 @@ def batch_by_size(
bsz_mult,
)
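+        # drop a trailing partial batch whose size is not a multiple of
+        # bsz_mult (required_batch_size_multiple), so all batches respect it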
+ if bsz_mult > 1 and len(b[-1]) % bsz_mult != 0:
+ b = b[:-1]
+
+ return b
+
else:
fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
sort_order = np.lexsort(
@@ -402,6 +408,12 @@ def compute_mask_indices(
min_space: int = 0,
require_same_masks: bool = True,
mask_dropout: float = 0.0,
+ add_masks: bool = False,
+ seed: Optional[int] = None,
+ epoch: Optional[int] = None,
+ indices: Optional[torch.Tensor] = None,
+ idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
+ num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
) -> np.ndarray:
"""
Computes random mask spans for a given shape
@@ -428,49 +440,73 @@ def compute_mask_indices(
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
- all_num_mask = int(
- # add a random number for probabilistic rounding
- mask_prob * all_sz / float(mask_length)
- + np.random.rand()
- )
-
- all_num_mask = max(min_masks, all_num_mask)
+ if num_mask_ver == 1:
+ all_num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * all_sz / float(mask_length)
+ + np.random.rand()
+ )
+ all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
for i in range(bsz):
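+        # derive a deterministic per-sample seed from (seed, epoch, index) so
+        # the mask for a given sample is reproducible within an epoch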
+ if seed is not None and epoch is not None and indices is not None:
+ seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
+ else:
+ seed_i = None
+
+ rng = np.random.default_rng(seed_i)
+
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
+ assert sz >= 0, sz
+ else:
+ sz = all_sz
+
+ if num_mask_ver == 1:
+ if padding_mask is not None:
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * sz / float(mask_length)
+ + np.random.rand()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ num_mask = all_num_mask
+ elif num_mask_ver == 2:
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
- + np.random.rand()
+ + rng.random()
)
num_mask = max(min_masks, num_mask)
else:
- sz = all_sz
- num_mask = all_num_mask
+ raise ValueError()
if mask_type == "static":
lengths = np.full(num_mask, mask_length)
elif mask_type == "uniform":
- lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
+            lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
elif mask_type == "normal":
- lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+ lengths = rng.normal(mask_length, mask_other, size=num_mask)
lengths = [max(1, int(round(x))) for x in lengths]
elif mask_type == "poisson":
- lengths = np.random.poisson(mask_length, size=num_mask)
+ lengths = rng.poisson(mask_length, size=num_mask)
lengths = [int(round(x)) for x in lengths]
else:
raise Exception("unknown mask selection " + mask_type)
if sum(lengths) == 0:
- lengths[0] = min(mask_length, sz - 1)
+ if mask_type == "static":
+ raise ValueError(f"this should never happens")
+ else:
+ lengths = [min(mask_length, sz - 1)]
if no_overlap:
mask_idc = []
def arrange(s, e, length, keep_length):
- span_start = np.random.randint(s, e - length)
+            span_start = rng.integers(s, e - length)
mask_idc.extend(span_start + i for i in range(length))
new_parts = []
@@ -491,16 +527,20 @@ def compute_mask_indices(
if l_sum == 0:
break
probs = lens / np.sum(lens)
- c = np.random.choice(len(parts), p=probs)
+ c = rng.choice(len(parts), p=probs)
s, e = parts.pop(c)
parts.extend(arrange(s, e, length, min_length))
mask_idc = np.asarray(mask_idc)
else:
- min_len = min(lengths)
- if sz - min_len <= num_mask:
- min_len = sz - num_mask - 1
-
- mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+ if idc_select_ver == 1:
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+ mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
+ elif idc_select_ver == 2:
+ mask_idc = rng.choice(sz, num_mask, replace=False)
+ else:
+ raise ValueError()
mask_idc = np.asarray(
[
@@ -510,20 +550,300 @@ def compute_mask_indices(
]
)
- mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
-
- min_len = min([len(m) for m in mask_idcs])
- for i, mask_idc in enumerate(mask_idcs):
- if len(mask_idc) > min_len and require_same_masks:
- mask_idc = np.random.choice(mask_idc, min_len, replace=False)
- if mask_dropout > 0:
- num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
- mask_idc = np.random.choice(
- mask_idc, len(mask_idc) - num_holes, replace=False
+ mask_idc = np.unique(mask_idc[mask_idc < sz])
+ if len(mask_idc) >= sz:
+ raise ValueError(
+ (
+ f"the entire sequence is masked. "
+ f"sz={sz}; mask_idc[mask_idc]; "
+ f"index={indices[i] if indices is not None else None}"
+ )
)
+ mask_idcs.append(mask_idc)
+
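+    # with require_same_masks, equalize the number of masked timesteps across
+    # the batch: trim down to the minimum, or grow to the maximum if add_masks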
+ target_len = None
+ if require_same_masks:
+ if add_masks:
+ target_len = max([len(m) for m in mask_idcs])
+ else:
+ target_len = min([len(m) for m in mask_idcs])
+
+ for i, mask_idc in enumerate(mask_idcs):
+ if target_len is not None and len(mask_idc) > target_len:
+ mask_idc = rng.choice(mask_idc, target_len, replace=False)
mask[i, mask_idc] = True
+ if target_len is not None and len(mask_idc) < target_len:
+ unmasked = np.flatnonzero(~mask[i])
+ to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
+ mask[i, to_mask] = True
+
+ if mask_dropout > 0:
+ masked = np.flatnonzero(mask[i])
+ num_holes = np.rint(len(masked) * mask_dropout).astype(int)
+ to_drop = rng.choice(masked, num_holes, replace=False)
+ mask[i, to_drop] = False
+
+ return mask
+
+
+def compute_block_mask_2d(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ mask_prob_adjust: float = 0,
+ inverse_mask: bool = False,
+ require_same_masks: bool = True,
+ expand_adjcent: bool = False,
+ mask_dropout: float = 0,
+ non_overlapping: bool = False,
+) -> torch.Tensor:
+
+ assert mask_length > 1
+
+ B, L = shape
+
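+    # assumes the L tokens form a square d x d grid (e.g. ViT patches)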
+ d = int(L**0.5)
+
+ if inverse_mask:
+ mask_prob = 1 - mask_prob
+
+ if non_overlapping:
+ sz = math.ceil(d / mask_length)
+ inp_len = sz * sz
+
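+        # sample block positions on a coarse (sz x sz) grid, then expand each
+        # chosen cell into a mask_length x mask_length block with a transposed
+        # conv; blocks from different cells cannot overlap by construction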
+ inp = torch.zeros((B, 1, sz, sz))
+ w = torch.ones((1, 1, mask_length, mask_length))
+
+ mask_inds = torch.multinomial(
+ 1 - inp.view(B, -1),
+ int(inp_len * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
+ replacement=False,
+ )
+ inp.view(B, -1).scatter_(1, mask_inds, 1)
+
+ mask = torch.nn.functional.conv_transpose2d(inp, w, stride=mask_length).squeeze(
+ 1
+ )
+ if mask.size(-1) > d:
+ mask = mask[..., :d, :d]
+ else:
+ mask = torch.zeros((B, d, d))
+ mask_inds = torch.randint(
+ 0,
+ L,
+ size=(
+ B,
+ int(
+ L
+ * ((mask_prob + mask_prob_adjust) / mask_length**2)
+ * (1 + mask_dropout)
+ ),
+ ),
+ )
+ mask.view(B, -1).scatter_(1, mask_inds, 1)
+ centers = mask.nonzero(as_tuple=True)
+
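+        # expand every sampled center into a mask_length x mask_length block by
+        # enumerating all offsets around it (these blocks may overlap)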
+ inds = ([], [], [])
+
+ offset = mask_length // 2
+ for i in range(mask_length):
+ for j in range(mask_length):
+ k1 = i - offset
+ k2 = j - offset
+ inds[0].append(centers[0])
+ inds[1].append(centers[1] + k1)
+ inds[2].append(centers[2] + k2)
+
+ i0 = torch.cat(inds[0])
+ i1 = torch.cat(inds[1]).clamp_(min=0, max=d - 1)
+ i2 = torch.cat(inds[2]).clamp_(min=0, max=d - 1)
+
+ mask[(i0, i1, i2)] = 1
+
+ def get_nbs(b, m, w):
+ all_nbs = torch.nn.functional.conv2d(m.unsqueeze(1), w, padding="same")
+ all_nbs = all_nbs.clamp_max_(1).view(b, -1)
+ return all_nbs
+
+ if require_same_masks and expand_adjcent:
+ w = torch.zeros((1, 1, 3, 3))
+ w[..., 0, 1] = 1
+ w[..., 2, 1] = 1
+ w[..., 1, 0] = 1
+ w[..., 1, 2] = 1
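+        # plus-shaped kernel (4-connectivity): the conv marks cells that have a
+        # masked up/down/left/right neighbor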
+
+ all_nbs = get_nbs(B, mask, w)
+
+ mask = mask.reshape(B, -1)
+
+ if require_same_masks:
+ n_masks = mask.sum(dim=-1)
+ final_target_len = int(L * (mask_prob))
+ target_len = int(final_target_len * (1 + mask_dropout))
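+        # oversample up to target_len (inflated by mask_dropout), then the
+        # adjustments below trim or pad every sample to exactly final_target_len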
+
+ for i in range(len(mask)):
+ n = n_masks[i]
+ m = mask[i]
+ r = 0
+ while expand_adjcent and n < target_len:
+ if r == 0:
+ nbs = all_nbs[i]
+ else:
+ nbs = get_nbs(1, m.view(1, d, d), w).flatten()
+
+ cands = (1 - m + nbs) > 1
+ cand_sz = int(cands.sum().item())
+
+ assert cand_sz > 0, f"{nbs} {cand_sz}"
+
+ to_mask = torch.multinomial(
+ cands.float(), min(cand_sz, int(target_len - n)), replacement=False
+ )
+ m[to_mask] = 1
+ assert to_mask.numel() > 0
+ n += to_mask.numel()
+ r += 1
+
+ if n > final_target_len:
+ to_unmask = torch.multinomial(
+ m, int(n - final_target_len), replacement=False
+ )
+ m[to_unmask] = 0
+ elif n < final_target_len:
+ to_mask = torch.multinomial(
+ (1 - m), int(final_target_len - n), replacement=False
+ )
+ m[to_mask] = 1
+
+ if inverse_mask:
+ mask = 1 - mask
+
+ return mask
+
+
+def compute_block_mask_1d(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ mask_prob_adjust: float = 0,
+ inverse_mask: bool = False,
+ require_same_masks: bool = True,
+ expand_adjcent: bool = False,
+ mask_dropout: float = 0,
+ non_overlapping: bool = False,
+) -> torch.Tensor:
+
+ B, L = shape
+
+ if inverse_mask:
+ mask_prob = 1 - mask_prob
+
+ if non_overlapping:
+ sz = math.ceil(L / mask_length)
+
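+        # same trick as the 2d case: pick block starts on a coarse grid of sz
+        # cells, then expand each start to mask_length consecutive timesteps
+        # with a transposed conv so that blocks never overlap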
+ inp = torch.zeros((B, 1, sz))
+ w = torch.ones((1, 1, mask_length))
+
+ mask_inds = torch.multinomial(
+ 1 - inp.view(B, -1),
+ int(sz * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
+ replacement=False,
+ )
+ inp.view(B, -1).scatter_(1, mask_inds, 1)
+
+ mask = torch.nn.functional.conv_transpose1d(inp, w, stride=mask_length).squeeze(
+ 1
+ )
+ if mask.size(-1) > L:
+ mask = mask[..., :L]
+
+ else:
+ mask = torch.zeros((B, L))
+ mask_inds = torch.randint(
+ 0,
+ L,
+ size=(
+ B,
+ int(
+ L
+ * ((mask_prob + mask_prob_adjust) / mask_length)
+ * (1 + mask_dropout)
+ ),
+ ),
+ )
+
+ mask.view(B, -1).scatter_(1, mask_inds, 1)
+ centers = mask.nonzero(as_tuple=True)
+
+ inds = ([], [])
+
+ offset = mask_length // 2
+ for i in range(mask_length):
+ k1 = i - offset
+ inds[0].append(centers[0])
+ inds[1].append(centers[1] + k1)
+
+ i0 = torch.cat(inds[0])
+ i1 = torch.cat(inds[1]).clamp_(min=0, max=L - 1)
+
+ mask[(i0, i1)] = 1
+
+ def get_nbs(b, m, w):
+ all_nbs = torch.nn.functional.conv1d(m.unsqueeze(1), w, padding="same")
+ all_nbs = all_nbs.clamp_max_(1).view(b, -1)
+ return all_nbs
+
+ if require_same_masks and expand_adjcent:
+ w = torch.ones((1, 1, 3))
+ w[..., 1] = 0
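+        # 3-tap kernel with a zero center: the conv marks timesteps that have a
+        # masked left or right neighbor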
+ all_nbs = get_nbs(B, mask, w)
+
+ mask = mask.view(B, -1)
+
+ if require_same_masks:
+ n_masks = mask.sum(dim=-1)
+ final_target_len = int(L * (mask_prob))
+ target_len = int(final_target_len * (1 + mask_dropout))
+
+ for i in range(len(mask)):
+ n = n_masks[i]
+ m = mask[i]
+ r = 0
+ while expand_adjcent and n < target_len:
+ if r == 0:
+ nbs = all_nbs[i]
+ else:
+ nbs = get_nbs(1, m.unsqueeze(0), w).squeeze(0)
+
+ cands = (1 - m + nbs) > 1
+ cand_sz = int(cands.sum().item())
+
+ assert cand_sz > 0, f"{nbs} {cand_sz}"
+
+ to_mask = torch.multinomial(
+ cands.float(), min(cand_sz, int(target_len - n)), replacement=False
+ )
+ m[to_mask] = 1
+ assert to_mask.numel() > 0
+ n += to_mask.numel()
+ r += 1
+
+ if n > final_target_len:
+ to_unmask = torch.multinomial(
+ m, int(n - final_target_len), replacement=False
+ )
+ m[to_unmask] = 0
+ elif n < final_target_len:
+ to_mask = torch.multinomial(
+ (1 - m), int(final_target_len - n), replacement=False
+ )
+ m[to_mask] = 1
+
+ if inverse_mask:
+ mask = 1 - mask
+
return mask
@@ -602,3 +922,223 @@ def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
raise ValueError(msg)
+
+
+def compute_mask_indices_for_one(
+ sz,
+ mask_prob: float,
+ mask_length: int,
+ seed=None,
+ epoch=None,
+ index=None,
+ min_masks=0,
+):
+ """
+ set seed, epoch, index for deterministic masking
+ """
+    seed = int(hash((seed, epoch, index)) % 1e6) if seed is not None else None
+ rng = np.random.default_rng(seed)
+
+ # decide elements to mask
+ mask = np.full(sz, False)
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * sz / float(mask_length)
+ + rng.random()
+ )
+ num_mask = max(min_masks, num_mask)
+
+ # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
+ mask_idc = rng.choice(sz, num_mask, replace=False)
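+    # expand each sampled start into mask_length consecutive indices; anything
+    # running past the end of the sequence is clipped below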
+ mask_idc = np.concatenate([mask_idc + i for i in range(mask_length)])
+ mask_idc = mask_idc[mask_idc < len(mask)]
+ try:
+ mask[mask_idc] = True
+    except Exception:  # something went wrong
+        print(f"Assigning mask indices {mask_idc} to mask {mask} failed!")
+ raise
+
+ return mask
+
+
+def compute_mask_indices_v2(
+ shape: Tuple[int, int],
+ padding_mask: Optional[torch.Tensor],
+ mask_prob: float,
+ mask_length: int,
+ min_masks: int = 0,
+ require_same_masks: bool = True,
+ seed: Optional[int] = None,
+ epoch: Optional[int] = None,
+ indices: Optional[torch.Tensor] = None,
+) -> np.ndarray:
+ bsz, all_sz = shape
+ mask = np.full((bsz, all_sz), False)
+ for i in range(bsz):
+ if padding_mask is not None:
+ sz = all_sz - padding_mask[i].long().sum().item()
+ else:
+ sz = all_sz
+ index = indices[i].item() if indices is not None else None
+ mask_for_one = compute_mask_indices_for_one(
+ sz, mask_prob, mask_length, seed, epoch, index, min_masks
+ )
+ mask[i, :sz] = mask_for_one
+
+ if require_same_masks:
+ index_sum = indices.sum().item() if indices is not None else None
+        seed = int(hash((seed, epoch, index_sum)) % 1e6) if seed is not None else None
+ rng = np.random.default_rng(seed)
+
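+        # randomly unmask extras so every sample keeps the same number of
+        # masked positions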
+ num_mask = mask.sum(-1).min()
+ for i in range(bsz):
+ extra = mask[i].sum() - num_mask
+ if extra > 0:
+ to_unmask = rng.choice(np.nonzero(mask[i])[0], extra, replace=False)
+ mask[i, to_unmask] = False
+
+ return mask
+
+
+# TODO: a copy of the original compute_mask_indices
+def compute_mask_indices_v3(
+ shape: Tuple[int, int],
+ padding_mask: Optional[torch.Tensor],
+ mask_prob: float,
+ mask_length: int,
+ mask_type: str = "static",
+ mask_other: float = 0.0,
+ min_masks: int = 0,
+ no_overlap: bool = False,
+ min_space: int = 0,
+ require_same_masks: bool = True,
+ mask_dropout: float = 0.0,
+ seed: Optional[int] = None,
+ epoch: Optional[int] = None,
+ indices: Optional[torch.Tensor] = None,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape
+
+ Args:
+        shape: the shape for which to compute masks.
+            should be of size 2, where the first element is the batch size and the second is the number of timesteps
+        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+        mask_prob: probability for each token to be chosen as the start of a span to be masked. this will be multiplied by
+            the number of timesteps divided by the mask span length to mask approximately this percentage of all elements.
+            however, due to overlaps, the actual number will be smaller (unless no_overlap is True)
+        mask_type: how to compute mask lengths
+            static = fixed size
+            uniform = sample from uniform distribution [mask_other, mask_length*2]
+            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is at least 1 element long
+            poisson = sample from poisson distribution with lambda = mask_length
+        min_masks: minimum number of masked spans
+        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
+        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+        require_same_masks: if true, will randomly drop out masks until the same number of masks remains in each sample
+        mask_dropout: randomly drop out this percentage of masks in each example
+ """
+ bsz, all_sz = shape
+ mask = np.full((bsz, all_sz), False)
+
+ all_num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * all_sz / float(mask_length)
+ + np.random.rand()
+ )
+
+ all_num_mask = max(min_masks, all_num_mask)
+
+ mask_idcs = []
+ for i in range(bsz):
+ if seed is not None and epoch is not None and indices is not None:
+ seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
+ else:
+ seed_i = None
+ rng = np.random.default_rng(seed_i)
+
+ if padding_mask is not None:
+ sz = all_sz - padding_mask[i].long().sum().item()
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * sz / float(mask_length)
+ + rng.random()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ sz = all_sz
+ num_mask = all_num_mask
+
+ if mask_type == "static":
+ lengths = np.full(num_mask, mask_length)
+ elif mask_type == "uniform":
+            lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
+ elif mask_type == "normal":
+ lengths = rng.normal(mask_length, mask_other, size=num_mask)
+ lengths = [max(1, int(round(x))) for x in lengths]
+ elif mask_type == "poisson":
+ lengths = rng.poisson(mask_length, size=num_mask)
+ lengths = [int(round(x)) for x in lengths]
+ else:
+ raise Exception("unknown mask selection " + mask_type)
+
+ if sum(lengths) == 0:
+ lengths[0] = min(mask_length, sz - 1)
+
+ if no_overlap:
+ mask_idc = []
+
+ def arrange(s, e, length, keep_length):
+            span_start = rng.integers(s, e - length)
+ mask_idc.extend(span_start + i for i in range(length))
+
+ new_parts = []
+ if span_start - s - min_space >= keep_length:
+ new_parts.append((s, span_start - min_space + 1))
+ if e - span_start - length - min_space > keep_length:
+ new_parts.append((span_start + length + min_space, e))
+ return new_parts
+
+ parts = [(0, sz)]
+ min_length = min(lengths)
+ for length in sorted(lengths, reverse=True):
+ lens = np.fromiter(
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                    np.int64,
+ )
+ l_sum = np.sum(lens)
+ if l_sum == 0:
+ break
+ probs = lens / np.sum(lens)
+ c = rng.choice(len(parts), p=probs)
+ s, e = parts.pop(c)
+ parts.extend(arrange(s, e, length, min_length))
+ mask_idc = np.asarray(mask_idc)
+ else:
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+
+ mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
+
+ mask_idc = np.asarray(
+ [
+ mask_idc[j] + offset
+ for j in range(len(mask_idc))
+ for offset in range(lengths[j])
+ ]
+ )
+
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+
+ min_len = min([len(m) for m in mask_idcs])
+ for i, mask_idc in enumerate(mask_idcs):
+ if len(mask_idc) > min_len and require_same_masks:
+ mask_idc = rng.choice(mask_idc, min_len, replace=False)
+ if mask_dropout > 0:
+ num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
+ mask_idc = rng.choice(mask_idc, len(mask_idc) - num_holes, replace=False)
+
+ mask[i, mask_idc] = True
+
+ return mask
diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py
index 81cba4af6..1947d9940 100644
--- a/fairseq/data/indexed_dataset.py
+++ b/fairseq/data/indexed_dataset.py
@@ -542,6 +542,11 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
data_file_path(path)
)
+ @property
+ def can_reuse_epoch_itr_across_epochs(self):
+ # TODO: a quick fix. make it a child class of FairseqDataset instead?
+ return True
+
def get_indexed_dataset_to_local(path) -> str:
local_index_path = PathManager.get_local_path(index_file_path(path))
diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py
index 45a8c65fa..a48826513 100644
--- a/fairseq/data/iterators.py
+++ b/fairseq/data/iterators.py
@@ -156,7 +156,7 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
num_workers=0,
buffer_size=0,
timeout=0,
- persistent_workers=False,
+ persistent_workers=True,
):
assert isinstance(dataset, torch.utils.data.IterableDataset)
self.dataset = dataset
@@ -164,11 +164,11 @@ class StreamingEpochBatchIterator(EpochBatchIterating):
self.collate_fn = collate_fn
self.epoch = max(epoch, 1) # we use 1-based indexing for epochs
self.num_workers = num_workers
+ self.persistent_workers = persistent_workers and num_workers > 0
# This upper limit here is to prevent people from abusing this feature
# in a shared computing environment.
self.buffer_size = min(buffer_size, 20)
self.timeout = timeout
- self.persistent_workers = persistent_workers
self._current_epoch_iterator = None
@@ -321,7 +321,7 @@ class EpochBatchIterator(EpochBatchIterating):
skip_remainder_batch=False,
grouped_shuffling=False,
reuse_dataloader=False,
- persistent_workers=False,
+ persistent_workers=True,
):
assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
@@ -334,6 +334,7 @@ class EpochBatchIterator(EpochBatchIterating):
self.num_shards = num_shards
self.shard_id = shard_id
self.num_workers = num_workers
+ self.persistent_workers = persistent_workers and num_workers > 0
# This upper limit here is to prevent people from abusing this feature
# in a shared computing environment.
self.buffer_size = min(buffer_size, 20)
@@ -350,7 +351,6 @@ class EpochBatchIterator(EpochBatchIterating):
self.dataloader = None
self.reuse_dataloader = reuse_dataloader
- self.persistent_workers = persistent_workers
@property
def frozen_batches(self):
@@ -777,8 +777,6 @@ class GroupedEpochBatchIterator(EpochBatchIterator):
mult_rate=1,
buffer_size=0,
skip_remainder_batch=False,
- reuse_dataloader=False,
- persistent_workers=False,
):
super().__init__(
dataset,
@@ -791,8 +789,6 @@ class GroupedEpochBatchIterator(EpochBatchIterator):
epoch,
buffer_size,
skip_remainder_batch=skip_remainder_batch,
- reuse_dataloader=reuse_dataloader,
- persistent_workers=persistent_workers,
)
# level 0: sub-samplers 1: batch_idx 2: batches
self._frozen_batches = tuple([tuple(sub_batch) for sub_batch in batch_samplers])
diff --git a/fairseq/data/mask_tokens_dataset.py b/fairseq/data/mask_tokens_dataset.py
index 912323559..0ca9051c9 100644
--- a/fairseq/data/mask_tokens_dataset.py
+++ b/fairseq/data/mask_tokens_dataset.py
@@ -69,6 +69,7 @@ class MaskTokensDataset(BaseWrapperDataset):
mask_whole_words: torch.Tensor = None,
mask_multiple_length: int = 1,
mask_stdev: float = 0.0,
+ skip_masking: bool = False,
):
assert 0.0 < mask_prob < 1.0
assert 0.0 <= random_token_prob <= 1.0
@@ -89,6 +90,7 @@ class MaskTokensDataset(BaseWrapperDataset):
self.mask_whole_words = mask_whole_words
self.mask_multiple_length = mask_multiple_length
self.mask_stdev = mask_stdev
+ self.skip_masking = skip_masking
if random_token_prob > 0.0:
if freq_weighted_replacement:
@@ -113,108 +115,112 @@ class MaskTokensDataset(BaseWrapperDataset):
@lru_cache(maxsize=8)
def __getitem_cached__(self, seed: int, epoch: int, index: int):
- with data_utils.numpy_seed(self.seed, self.epoch, index):
- item = self.dataset[index]
- sz = len(item)
+ seed = int(hash((seed, epoch, index)) % 1e6)
+ rng = np.random.default_rng(seed)
+ item = self.dataset[index]
+ sz = len(item)
- assert (
- self.mask_idx not in item
- ), "Dataset contains mask_idx (={}), this is not expected!".format(
- self.mask_idx,
+ assert (
+ self.mask_idx not in item
+ ), "Dataset contains mask_idx (={}), this is not expected!".format(
+ self.mask_idx,
+ )
+ if self.skip_masking:
+ return torch.from_numpy(np.copy(item))
+
+ if self.mask_whole_words is not None:
+ word_begins_mask = self.mask_whole_words.gather(0, item)
+ word_begins_idx = word_begins_mask.nonzero().view(-1)
+ sz = len(word_begins_idx)
+ words = np.split(word_begins_mask, word_begins_idx)[1:]
+ assert len(words) == sz
+ word_lens = list(map(len, words))
+
+ # decide elements to mask
+ mask = np.full(sz, False)
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ self.mask_prob * sz / float(self.mask_multiple_length)
+ + rng.random()
+ )
+
+ # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
+ mask_idc = rng.choice(sz, num_mask, replace=False)
+ if self.mask_stdev > 0.0:
+ lengths = rng.normal(
+ self.mask_multiple_length, self.mask_stdev, size=num_mask
)
-
- if self.mask_whole_words is not None:
- word_begins_mask = self.mask_whole_words.gather(0, item)
- word_begins_idx = word_begins_mask.nonzero().view(-1)
- sz = len(word_begins_idx)
- words = np.split(word_begins_mask, word_begins_idx)[1:]
- assert len(words) == sz
- word_lens = list(map(len, words))
-
- # decide elements to mask
- mask = np.full(sz, False)
- num_mask = int(
- # add a random number for probabilistic rounding
- self.mask_prob * sz / float(self.mask_multiple_length)
- + np.random.rand()
+ lengths = [max(0, int(round(x))) for x in lengths]
+ mask_idc = np.asarray(
+ [
+ mask_idc[j] + offset
+ for j in range(len(mask_idc))
+ for offset in range(lengths[j])
+ ],
+ dtype=np.int64,
)
+ else:
+ mask_idc = np.concatenate(
+ [mask_idc + i for i in range(self.mask_multiple_length)]
+ )
+ mask_idc = mask_idc[mask_idc < len(mask)]
+ try:
+ mask[mask_idc] = True
+        except Exception:  # something went wrong
+            print("Assigning mask indices {} to mask {} failed!".format(mask_idc, mask))
+ raise
- # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
- mask_idc = np.random.choice(sz, num_mask, replace=False)
- if self.mask_stdev > 0.0:
- lengths = np.random.normal(
- self.mask_multiple_length, self.mask_stdev, size=num_mask
- )
- lengths = [max(0, int(round(x))) for x in lengths]
- mask_idc = np.asarray(
- [
- mask_idc[j] + offset
- for j in range(len(mask_idc))
- for offset in range(lengths[j])
- ],
- dtype=np.int64,
- )
- else:
- mask_idc = np.concatenate(
- [mask_idc + i for i in range(self.mask_multiple_length)]
- )
- mask_idc = mask_idc[mask_idc < len(mask)]
- try:
- mask[mask_idc] = True
- except: # something wrong
- print(
- "Assigning mask indexes {} to mask {} failed!".format(
- mask_idc, mask
- )
- )
- raise
-
- if self.return_masked_tokens:
- # exit early if we're just returning the masked tokens
- # (i.e., the targets for masked LM training)
- if self.mask_whole_words is not None:
- mask = np.repeat(mask, word_lens)
- new_item = np.full(len(mask), self.pad_idx)
- new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
- return torch.from_numpy(new_item)
-
- # decide unmasking and random replacement
- rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
- if rand_or_unmask_prob > 0.0:
- rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
- if self.random_token_prob == 0.0:
- unmask = rand_or_unmask
- rand_mask = None
- elif self.leave_unmasked_prob == 0.0:
- unmask = None
- rand_mask = rand_or_unmask
- else:
- unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
- decision = np.random.rand(sz) < unmask_prob
- unmask = rand_or_unmask & decision
- rand_mask = rand_or_unmask & (~decision)
- else:
- unmask = rand_mask = None
-
- if unmask is not None:
- mask = mask ^ unmask
-
+ # if self.return_masked_tokens:
+ # print((
+ # f"IDX={index}; seed={seed}; epoch={epoch}; is_tgt={self.return_masked_tokens}: "
+ # f"{np.nonzero(mask)[0].sum()}"
+ # ))
+ if self.return_masked_tokens:
+ # exit early if we're just returning the masked tokens
+ # (i.e., the targets for masked LM training)
if self.mask_whole_words is not None:
mask = np.repeat(mask, word_lens)
-
- new_item = np.copy(item)
- new_item[mask] = self.mask_idx
- if rand_mask is not None:
- num_rand = rand_mask.sum()
- if num_rand > 0:
- if self.mask_whole_words is not None:
- rand_mask = np.repeat(rand_mask, word_lens)
- num_rand = rand_mask.sum()
-
- new_item[rand_mask] = np.random.choice(
- len(self.vocab),
- num_rand,
- p=self.weights,
- )
-
+ new_item = np.full(len(mask), self.pad_idx)
+ new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
return torch.from_numpy(new_item)
+
+ # decide unmasking and random replacement
+ rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
+ if rand_or_unmask_prob > 0.0:
+ rand_or_unmask = mask & (rng.random(sz) < rand_or_unmask_prob)
+ if self.random_token_prob == 0.0:
+ unmask = rand_or_unmask
+ rand_mask = None
+ elif self.leave_unmasked_prob == 0.0:
+ unmask = None
+ rand_mask = rand_or_unmask
+ else:
+ unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
+ decision = rng.random(sz) < unmask_prob
+ unmask = rand_or_unmask & decision
+ rand_mask = rand_or_unmask & (~decision)
+ else:
+ unmask = rand_mask = None
+
+ if unmask is not None:
+ mask = mask ^ unmask
+
+ if self.mask_whole_words is not None:
+ mask = np.repeat(mask, word_lens)
+
+ new_item = np.copy(item)
+ new_item[mask] = self.mask_idx
+ if rand_mask is not None:
+ num_rand = rand_mask.sum()
+ if num_rand > 0:
+ if self.mask_whole_words is not None:
+ rand_mask = np.repeat(rand_mask, word_lens)
+ num_rand = rand_mask.sum()
+
+ new_item[rand_mask] = rng.choice(
+ len(self.vocab),
+ num_rand,
+ p=self.weights,
+ )
+
+ return torch.from_numpy(new_item)
diff --git a/fairseq/data/padding_mask_dataset.py b/fairseq/data/padding_mask_dataset.py
new file mode 100644
index 000000000..d7f7b88db
--- /dev/null
+++ b/fairseq/data/padding_mask_dataset.py
@@ -0,0 +1,38 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from fairseq.data import data_utils
+from . import BaseWrapperDataset
+
+
+class PaddingMaskDataset(BaseWrapperDataset):
+ def __init__(self, dataset, left_pad, pad_length=None):
+ super().__init__(dataset)
+ self.left_pad = left_pad
+ self.pad_length = pad_length
+
+ def __getitem__(self, index):
+ item = self.dataset[index]
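+        # all-False mask for real positions; the collater pads with True, so
+        # padded positions are marked in the resulting mask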
+ return torch.zeros_like(item).bool()
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def collater(self, samples):
+ return data_utils.collate_tokens(
+ samples, True, left_pad=self.left_pad, pad_to_length=self.pad_length
+ )
+
+
+class LeftPaddingMaskDataset(PaddingMaskDataset):
+ def __init__(self, dataset):
+ super().__init__(dataset, left_pad=True)
+
+
+class RightPaddingMaskDataset(PaddingMaskDataset):
+ def __init__(self, dataset):
+ super().__init__(dataset, left_pad=False)
diff --git a/fairseq/data/subsample_dataset.py b/fairseq/data/subsample_dataset.py
index 48feaf883..fe5c7e2ac 100644
--- a/fairseq/data/subsample_dataset.py
+++ b/fairseq/data/subsample_dataset.py
@@ -3,9 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+import contextlib
import logging
import numpy as np
+from fairseq.data.data_utils import numpy_seed
from . import BaseWrapperDataset
@@ -21,13 +23,14 @@ class SubsampleDataset(BaseWrapperDataset):
size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive)
"""
- def __init__(self, dataset, size_ratio, shuffle=False):
+ def __init__(self, dataset, size_ratio, shuffle=False, seed=None):
super().__init__(dataset)
assert size_ratio < 1
self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
- self.indices = np.random.choice(
- list(range(len(self.dataset))), self.actual_size, replace=False
- )
+ with numpy_seed(seed) if seed is not None else contextlib.ExitStack():
+ self.indices = np.random.choice(
+ list(range(len(self.dataset))), self.actual_size, replace=False
+ )
self.shuffle = shuffle
logger.info(
"subsampled dataset from {} to {} (ratio={})".format(
diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py
index 5fdfab38d..af957fec6 100644
--- a/fairseq/dataclass/configs.py
+++ b/fairseq/dataclass/configs.py
@@ -636,6 +636,7 @@ class OptimizationConfig(FairseqDataclass):
" (default is to skip it)."
},
)
+ debug_param_names: bool = False
@dataclass
diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py
index 69b77962e..f6467d5f4 100644
--- a/fairseq/dataclass/utils.py
+++ b/fairseq/dataclass/utils.py
@@ -487,15 +487,22 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]):
def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=False):
if remove_missing:
- if is_dataclass(dc):
- target_keys = set(dc.__dataclass_fields__.keys())
- else:
- target_keys = set(dc.keys())
+ def remove_missing_rec(src_keys, target_cfg):
+ if is_dataclass(target_cfg):
+ target_keys = set(target_cfg.__dataclass_fields__.keys())
+ else:
+ target_keys = set(target_cfg.keys())
+
+ for k in list(src_keys.keys()):
+ if k not in target_keys:
+ del src_keys[k]
+ elif OmegaConf.is_config(src_keys[k]):
+ tgt = getattr(target_cfg, k)
+ if tgt is not None and (is_dataclass(tgt) or hasattr(tgt, "keys")):
+ remove_missing_rec(src_keys[k], tgt)
with open_dict(cfg):
- for k in list(cfg.keys()):
- if k not in target_keys:
- del cfg[k]
+ remove_missing_rec(cfg, dc)
merged_cfg = OmegaConf.merge(dc, cfg)
merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py
index 2c52f76aa..968830d58 100644
--- a/fairseq/distributed/utils.py
+++ b/fairseq/distributed/utils.py
@@ -51,18 +51,21 @@ def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False):
if cfg.pipeline_model_parallel:
num_pipeline_devices, num_pipelines_per_node = _pipeline_parallel_pre_init(cfg)
+ if cfg.distributed_world_size == 1:
+ return
if all(
key in os.environ
for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]
):
# support torch.distributed.launch
_infer_torch_distributed_launch_init(cfg)
- elif cfg.distributed_port > 0:
+ else:
# we can determine the init method automatically for Slurm
- _infer_slurm_init(cfg, num_pipelines_per_node)
- elif cfg.distributed_world_size > 1 or force_distributed:
- # fallback for single node with multiple GPUs
- _infer_single_node_init(cfg)
+ if not _infer_slurm_init(cfg, num_pipelines_per_node):
+ if cfg.distributed_port <= 0 or force_distributed:
+ _infer_single_node_init(cfg)
+ elif cfg.distributed_port <= 0:
+ _infer_single_node_init(cfg)
if cfg.pipeline_model_parallel:
_pipeline_parallel_post_init(cfg, num_pipeline_devices, num_pipelines_per_node)
@@ -71,12 +74,21 @@ def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False):
cfg.distributed_num_procs = min(
torch.cuda.device_count(), cfg.distributed_world_size
)
+ else:
+ if cfg.device_id > 0:
+ logger.info(
+ "setting CUDA device={} on rank {}".format(
+ cfg.device_id, cfg.distributed_rank
+ )
+ )
+ torch.cuda.set_device(cfg.device_id)
def _infer_torch_distributed_launch_init(cfg: DistributedTrainingConfig):
cfg.distributed_init_method = "env://"
cfg.distributed_world_size = int(os.environ["WORLD_SIZE"])
cfg.distributed_rank = int(os.environ["RANK"])
+ cfg.device_id = cfg.distributed_rank % torch.cuda.device_count()
# processes are created by torch.distributed.launch
cfg.distributed_no_spawn = True
@@ -127,22 +139,44 @@ def _infer_slurm_init(cfg: DistributedTrainingConfig, num_pipelines_per_node):
# number of pipelines across all nodes.
cfg.distributed_world_size = nnodes * num_pipelines_per_node
else:
- assert ntasks_per_node == cfg.distributed_world_size // nnodes
+ assert (
+ ntasks_per_node == cfg.distributed_world_size // nnodes
+ ), f"{ntasks_per_node}, {cfg.distributed_world_size}, {nnodes}"
cfg.distributed_no_spawn = True
cfg.distributed_rank = int(os.environ.get("SLURM_PROCID"))
cfg.device_id = int(os.environ.get("SLURM_LOCALID"))
+ logger.info(f"Rank {cfg.distributed_rank}, device_id: {cfg.device_id}")
+ return True
except subprocess.CalledProcessError as e: # scontrol failed
raise e
except FileNotFoundError: # Slurm is not installed
pass
+ return False
+
def _infer_single_node_init(cfg: DistributedTrainingConfig):
assert (
cfg.distributed_world_size <= torch.cuda.device_count()
), f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices"
- port = random.randint(10000, 20000)
- cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port)
+
+ if cfg.distributed_port <= 0:
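+        # derive a stable port from the Slurm job (and array task) id so every
+        # rank of the same job computes the same rendezvous port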
+ jobid = os.environ.get("SLURM_JOB_ID")
+ task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
+
+ if jobid is not None:
+ if task_id is not None:
+ jobid += str(task_id)
+ jobid = int(jobid)
+ rng = random.Random(jobid)
+ port = rng.randint(10000, 60000)
+ else:
+ port = random.randint(10000, 60000)
+
+ cfg.distributed_port = port
+ cfg.distributed_init_method = "tcp://localhost:{port}".format(
+ port=cfg.distributed_port
+ )
def _pipeline_parallel_pre_init(cfg: DistributedTrainingConfig):
@@ -341,6 +375,7 @@ def call_main(cfg: FairseqConfig, main, **kwargs):
start_rank = cfg.distributed_training.distributed_rank
cfg.distributed_training.distributed_rank = None # assign automatically
kwargs["start_rank"] = start_rank
+
torch.multiprocessing.spawn(
fn=distributed_main,
args=(main, cfg, kwargs),
diff --git a/fairseq/iterative_refinement_generator.py b/fairseq/iterative_refinement_generator.py
index 4fb0946f4..3d32c6bf4 100644
--- a/fairseq/iterative_refinement_generator.py
+++ b/fairseq/iterative_refinement_generator.py
@@ -235,7 +235,7 @@ class IterativeRefinementGenerator(object):
terminated.fill_(1)
# collect finalized sentences
- finalized_idxs = sent_idxs[terminated]
+ finalized_idxs = sent_idxs[terminated.to(sent_idxs.device)]
finalized_tokens = decoder_out.output_tokens[terminated]
finalized_scores = decoder_out.output_scores[terminated]
finalized_attn = (
@@ -285,7 +285,7 @@ class IterativeRefinementGenerator(object):
encoder_out = model.encoder.reorder_encoder_out(
encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()
)
- sent_idxs = sent_idxs[not_terminated]
+ sent_idxs = sent_idxs[not_terminated.to(sent_idxs.device)]
prev_output_tokens = prev_decoder_out.output_tokens.clone()
if self.beam_size > 1:
diff --git a/fairseq/logging/meters.py b/fairseq/logging/meters.py
index d5f7c775d..495bd0830 100644
--- a/fairseq/logging/meters.py
+++ b/fairseq/logging/meters.py
@@ -139,6 +139,36 @@ class SumMeter(Meter):
return val
+class ConcatTensorMeter(Meter):
+ """Concatenates tensors"""
+
+ def __init__(self, dim=0):
+ super().__init__()
+ self.reset()
+ self.dim = dim
+
+ def reset(self):
+ self.tensor = None
+
+ def update(self, val):
+ if self.tensor is None:
+ self.tensor = val
+ else:
+ self.tensor = torch.cat([self.tensor, val], dim=self.dim)
+
+ def state_dict(self):
+ return {
+ "tensor": self.tensor,
+ }
+
+ def load_state_dict(self, state_dict):
+ self.tensor = state_dict["tensor"]
+
+ @property
+ def smoothed_value(self) -> float:
+ return [] # return a dummy value
+
+
class TimeMeter(Meter):
"""Computes the average occurrence of some event per second"""
diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py
index 892b0ea4d..49301f27f 100644
--- a/fairseq/logging/metrics.py
+++ b/fairseq/logging/metrics.py
@@ -151,6 +151,26 @@ def log_scalar_sum(
agg[key].update(value)
+def log_concat_tensor(
+ key: str,
+ value: torch.Tensor,
+ priority: int = 10,
+ dim: int = 0,
+):
+ """Log a scalar value that is summed for reporting.
+
+ Args:
+ key (str): name of the field to log
+ value (float): value to log
+ priority (int): smaller values are logged earlier in the output
+ round (Optional[int]): number of digits to round to when displaying
+ """
+ for agg in get_active_aggregators():
+ if key not in agg:
+ agg.add_meter(key, ConcatTensorMeter(dim=dim), priority)
+ agg[key].update(value)
+
+
def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20):
"""Log a scalar value derived from other meters.
diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py
index 616e3051e..11cf6ee53 100644
--- a/fairseq/models/__init__.py
+++ b/fairseq/models/__init__.py
@@ -128,7 +128,8 @@ def register_model(name, dataclass=None):
def register_model_cls(cls):
if name in MODEL_REGISTRY:
- raise ValueError("Cannot register duplicate model ({})".format(name))
+ return MODEL_REGISTRY[name]
+
if not issubclass(cls, BaseFairseqModel):
raise ValueError(
"Model ({}: {}) must extend BaseFairseqModel".format(name, cls.__name__)
diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py
index 42f9134a3..65ead9dcf 100644
--- a/fairseq/models/fairseq_model.py
+++ b/fairseq/models/fairseq_model.py
@@ -160,6 +160,11 @@ class BaseFairseqModel(nn.Module):
if hasattr(m, "set_num_updates") and m != self:
m.set_num_updates(num_updates)
+ def set_epoch(self, epoch):
+ for m in self.modules():
+ if hasattr(m, "set_epoch") and m != self:
+ m.set_epoch(epoch)
+
def prepare_for_inference_(self, cfg: DictConfig):
"""Prepare model for inference."""
kwargs = {}
diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py
index 9d3ffa154..bf261624e 100644
--- a/fairseq/models/wav2vec/wav2vec2_asr.py
+++ b/fairseq/models/wav2vec/wav2vec2_asr.py
@@ -47,6 +47,7 @@ class Wav2Vec2AsrConfig(FairseqDataclass):
default=0.0,
metadata={"help": "dropout to apply to the input (after feat extr)"},
)
+
final_dropout: float = field(
default=0.0,
metadata={"help": "dropout after transformer and before final projection"},
@@ -66,19 +67,6 @@ class Wav2Vec2AsrConfig(FairseqDataclass):
"help": "dropout probability after activation in FFN inside wav2vec 2.0 model"
},
)
- conv_feature_layers: Optional[str] = field(
- default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
- metadata={
- "help": (
- "string describing convolutional feature extraction "
- "layers in form of a python list that contains "
- "[(dim, kernel_size, stride), ...]"
- ),
- },
- )
- encoder_embed_dim: Optional[int] = field(
- default=768, metadata={"help": "encoder embedding dimension"}
- )
# masking
apply_mask: bool = field(
@@ -152,12 +140,14 @@ class Wav2Vec2AsrConfig(FairseqDataclass):
layerdrop: float = field(
default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"}
)
+ drop_path: float = 0
mask_channel_min_space: Optional[int] = field(
default=1,
metadata={"help": "min space between spans (if no overlap is enabled)"},
)
mask_channel_before: bool = False
normalize: bool = II("task.normalize")
+ update_alibi: bool = True
data: str = II("task.data")
# this holds the loaded wav2vec args
w2v_args: Any = None
@@ -182,6 +172,11 @@ class Wav2Vec2AsrConfig(FairseqDataclass):
)
ddp_backend: str = II("distributed_training.ddp_backend")
+ zero_mask: bool = False
+ load_ema: bool = False
+
+ layer_decay: float = 1
+
@dataclass
class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig):
@@ -224,6 +219,12 @@ class Wav2VecCtc(BaseFairseqModel):
number_of_classes, device=logits.device
) * float("-inf")
masking_tensor[0] = 0
+
+ if logits.size(0) > net_output["padding_mask"].size(1):
+ net_output["padding_mask"] = F.pad(
+ net_output["padding_mask"], (1, 0), value=False
+ )
+
logits[net_output["padding_mask"].T] = masking_tensor.type_as(logits)
if normalize:
@@ -371,6 +372,19 @@ class Wav2VecEncoder(FairseqEncoder):
"checkpoint_activations": cfg.checkpoint_activations,
"offload_activations": cfg.offload_activations,
"min_params_to_wrap": cfg.min_params_to_wrap,
+ # d2v multi args
+ "encoder_dropout": cfg.dropout,
+ "drop_path": getattr(cfg, "drop_path", 0),
+ "mask_dropout": getattr(cfg, "mask_dropout", 0),
+ "zero_mask": getattr(cfg, "zero_mask", False),
+ "local_grad_mult": cfg.feature_grad_mult,
+ "layerdrop": cfg.layerdrop,
+ "prenet_layerdrop": cfg.layerdrop,
+ "prenet_dropout": cfg.dropout,
+ "post_mlp_drop": cfg.dropout,
+ "encoder_zero_mask": getattr(cfg, "zero_mask", False),
+ "inverse_mask": False,
+ "learned_alibi_scale": getattr(cfg, "update_alibi", True),
}
if cfg.w2v_args is None:
@@ -380,6 +394,7 @@ class Wav2VecEncoder(FairseqEncoder):
w2v_args = convert_namespace_to_omegaconf(state["args"])
w2v_args.criterion = None
w2v_args.lr_scheduler = None
+
cfg.w2v_args = w2v_args
logger.info(w2v_args)
@@ -390,31 +405,52 @@ class Wav2VecEncoder(FairseqEncoder):
if isinstance(w2v_args, Namespace):
cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)
- model_normalized = w2v_args.task.get(
- "normalize", w2v_args.model.get("normalize", False)
- )
- assert cfg.normalize == model_normalized, (
- "Fine-tuning works best when data normalization is the same. "
- "Please check that --normalize is set or unset for both pre-training and here"
- )
+        self.is_d2v_multi = "data2vec_multi" in w2v_args.model.get("_name", "")
- if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations:
- with open_dict(w2v_args):
- w2v_args.model.checkpoint_activations = cfg.checkpoint_activations
+ if not self.is_d2v_multi:
+ model_normalized = w2v_args.task.get(
+ "normalize", w2v_args.model.get("normalize", False)
+ )
+ assert cfg.normalize == model_normalized, (
+ "Fine-tuning works best when data normalization is the same. "
+ "Please check that --normalize is set or unset for both pre-training and here"
+ )
- w2v_args.task.data = cfg.data
- task = tasks.setup_task(w2v_args.task)
- model = task.build_model(w2v_args.model, from_checkpoint=True)
+ if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations:
+ with open_dict(w2v_args):
+ w2v_args.model.checkpoint_activations = cfg.checkpoint_activations
- model.remove_pretraining_modules()
+ w2v_args.task.data = cfg.data
+ task = tasks.setup_task(w2v_args.task, from_checkpoint=True)
+ model = task.build_model(w2v_args.model, from_checkpoint=True)
+
+ model.remove_pretraining_modules()
+ d = w2v_args.model.encoder_embed_dim
+ else:
+ assert cfg.normalize
+
+ if hasattr(w2v_args.task, "audio"):
+ w2v_args.task.audio.data = cfg.data
+ else:
+ w2v_args.task.data = cfg.data
+ task = tasks.setup_task(w2v_args.task, from_checkpoint=True)
+
+ model = task.build_model(w2v_args.model, from_checkpoint=True)
+
+ model.remove_pretraining_modules(modality="audio")
+ d = w2v_args.model.embed_dim
if state is not None and not cfg.no_pretrained_weights:
+ if cfg.load_ema:
+ assert "_ema" in state["model"]
+ for k in state["model"]["_ema"]:
+ mk = "encoder." + k
+ assert mk in state["model"], mk
+ state["model"][mk] = state["model"]["_ema"][k]
self.load_model_weights(state, model, cfg)
super().__init__(task.source_dictionary)
- d = w2v_args.model.encoder_embed_dim
-
self.w2v_model = model
self.final_dropout = nn.Dropout(cfg.final_dropout)
@@ -432,6 +468,29 @@ class Wav2VecEncoder(FairseqEncoder):
if targ_d is not None:
self.proj = Linear(d, targ_d)
+ layer_decay = getattr(cfg, "layer_decay", 1)
+ if layer_decay < 1:
+ mod_encs = list(model.modality_encoders.values())
+ assert len(mod_encs) == 1, len(mod_encs)
+ blocks = list(mod_encs[0].context_encoder.blocks) + list(model.blocks)
+ num_layers = len(blocks) + 1
+ layer_scales = list(
+ layer_decay ** (num_layers - i) for i in range(num_layers + 1)
+ )
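+            # layer-wise lr decay: blocks nearer the output keep a scale close
+            # to 1, earlier blocks are scaled down geometrically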
+
+ for i, b in enumerate(blocks):
+ lid = i + 1
+ if layer_scales[lid] == 1.0:
+ continue
+
+ for n, p in b.named_parameters():
+ optim_override = getattr(p, "optim_overrides", {})
+ if "optimizer" not in optim_override:
+ optim_override["optimizer"] = {}
+
+ optim_override["optimizer"]["lr_scale"] = layer_scales[lid]
+ p.optim_overrides = optim_override
+
def load_model_weights(self, state, model, cfg):
if cfg.ddp_backend == "fully_sharded":
from fairseq.distributed import FullyShardedDataParallel
@@ -461,8 +520,25 @@ class Wav2VecEncoder(FairseqEncoder):
model.load_state_dict(new_big_dict, strict=False)
else:
- if "_ema" in state["model"]:
- del state["model"]["_ema"]
+ to_delete = {"_ema", "target_proj", "decoder"}
+ for k in to_delete:
+ if k in state["model"]:
+ del state["model"][k]
+
+ if hasattr(model, "modality_encoders"):
+ if "modality_encoders.AUDIO.encoder_mask" not in state["model"]:
+ model.modality_encoders["AUDIO"].encoder_mask = None
+ elif not cfg.zero_mask:
+ model.modality_encoders["AUDIO"].encoder_mask = None
+ del state["model"]["modality_encoders.AUDIO.encoder_mask"]
+
+ for k in list(state["model"].keys()):
+ if k.startswith("modality_encoders.") and not k.startswith(
+ "modality_encoders.AUDIO"
+ ):
+ del state["model"][k]
+
+            logger.info(model)
model.load_state_dict(state["model"], strict=True)
def set_num_updates(self, num_updates):
@@ -478,6 +554,9 @@ class Wav2VecEncoder(FairseqEncoder):
"mask": self.apply_mask and self.training,
}
+ if self.is_d2v_multi:
+ w2v_args["mode"] = "AUDIO"
+
ft = self.freeze_finetune_updates <= self.num_updates
with torch.no_grad() if not ft else contextlib.ExitStack():
@@ -626,6 +705,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
+
if type(prev_output_tokens) == list:
max_len = max((len(x) for x in prev_output_tokens))
tmp = torch.zeros(
@@ -634,6 +714,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
for (i, p) in enumerate(prev_output_tokens):
tmp[i, : len(p)] = p
prev_output_tokens = tmp
+
prev_output_tokens = prev_output_tokens.long()
x, extra = self.extract_features(
prev_output_tokens, encoder_out, incremental_state
diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py
index 1df76ef24..dcfda9b82 100644
--- a/fairseq/modules/__init__.py
+++ b/fairseq/modules/__init__.py
@@ -32,7 +32,7 @@ from .location_attention import LocationAttention
from .lstm_cell_with_zoneout import LSTMCellWithZoneOut
from .multihead_attention import MultiheadAttention
from .positional_embedding import PositionalEmbedding
-from .same_pad import SamePad
+from .same_pad import SamePad, SamePad2d
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer
@@ -87,6 +87,7 @@ __all__ = [
"MultiheadAttention",
"PositionalEmbedding",
"SamePad",
+ "SamePad2d",
"ScalarBias",
"SinusoidalPositionalEmbedding",
"TransformerSentenceEncoderLayer",
diff --git a/fairseq/modules/ema_module.py b/fairseq/modules/ema_module.py
index a5b98861d..f0ece842d 100644
--- a/fairseq/modules/ema_module.py
+++ b/fairseq/modules/ema_module.py
@@ -11,8 +11,18 @@ import logging
import torch
+from omegaconf import II
from fairseq.dataclass import FairseqDataclass
+try:
+ from amp_C import multi_tensor_l2norm
+
+ multi_tensor_l2norm_available = True
+except ImportError:
+ multi_tensor_l2norm_available = False
+
+logger = logging.getLogger(__name__)
+
@dataclass
class EMAModuleConfig(FairseqDataclass):
@@ -23,12 +33,21 @@ class EMAModuleConfig(FairseqDataclass):
default=False,
metadata={"help": "If true, store EMA model in fp32 even if model is in fp16"},
)
+ add_missing_params: bool = True
+ log_norms: bool = False
class EMAModule:
"""Exponential Moving Average of Fairseq Models"""
- def __init__(self, model, config: EMAModuleConfig, device=None, skip_keys=None):
+ def __init__(
+ self,
+ model,
+ config: EMAModuleConfig,
+ copy_model=True,
+ device=None,
+ skip_keys=None,
+ ):
"""
@param model model to initialize the EMA with
@param config EMAConfig object with configuration like
@@ -37,11 +56,18 @@ class EMAModule:
Otherwise EMA is in the same device as the model.
"""
- self.decay = config.ema_decay
- self.model = copy.deepcopy(model)
- self.model.requires_grad_(False)
self.config = config
+
+ if copy_model:
+ self.model = copy.deepcopy(model)
+ self.model.requires_grad_(False)
+ else:
+ self.model = model
+
+        self.decay = config.ema_decay
self.skip_keys = skip_keys or set()
+ self.add_missing_params = config.add_missing_params
self.fp32_params = {}
if device is not None:
@@ -51,7 +77,8 @@ class EMAModule:
if self.config.ema_fp32:
self.build_fp32_params()
- self.update_freq_counter = 0
+ self.log_norms = config.log_norms and multi_tensor_l2norm_available
+ self.logs = {}
def build_fp32_params(self, state_dict=None):
"""
@@ -74,9 +101,16 @@ class EMAModule:
for param_key in state_dict:
if param_key in self.fp32_params:
- self.fp32_params[param_key].copy_(state_dict[param_key])
+ if param_key == "__sq_mom":
+ self.fp32_params[param_key] = state_dict[param_key]
+ else:
+ self.fp32_params[param_key].copy_(state_dict[param_key])
else:
self.fp32_params[param_key] = _to_float(state_dict[param_key])
+ if "__sq_mom" in self.fp32_params:
+ self.fp32_params["__sq_mom"][param_key] = torch.zeros_like(
+ self.fp32_params[param_key]
+ )
def restore(self, state_dict, build_fp32_params=False):
"""Load data from a model spec into EMA model"""
@@ -84,8 +118,10 @@ class EMAModule:
if build_fp32_params:
self.build_fp32_params(state_dict)
- def set_decay(self, decay):
+ def set_decay(self, decay, weight_decay=None):
self.decay = decay
+ if weight_decay is not None:
+ self.weight_decay = weight_decay
def get_decay(self):
return self.decay
@@ -98,9 +134,17 @@ class EMAModule:
ema_params = (
self.fp32_params if self.config.ema_fp32 else self.model.state_dict()
)
+
+ new_p = []
+ ema_p = []
+
for key, param in new_model.named_parameters():
if isinstance(param, dict):
continue
+
+ if not self.add_missing_params and key not in ema_params:
+ continue
+
try:
ema_param = ema_params[key]
except KeyError:
@@ -119,18 +163,39 @@ class EMAModule:
# Do not decay a model.version pytorch param
continue
+ lr = 1 - decay
+
if key in self.skip_keys or not param.requires_grad:
ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
ema_param = ema_params[key]
else:
- ema_param.mul_(decay)
- ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - decay)
+ if self.log_norms:
+ new_p.append(param)
+ ema_p.append(ema_param)
+
+ ema_param.mul_(1 - lr)
+ ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=lr)
ema_state_dict[key] = ema_param
for key, param in new_model.named_buffers():
ema_state_dict[key] = param
+ if self.log_norms:
+ if "model_norm" in self.logs:
+ self.prev_model_norm = self.logs["model_norm"]
+
+ chunk_size = 2048 * 32
+ has_inf = torch.zeros(
+ (1, 1), dtype=torch.int, device=next(new_model.parameters()).device
+ )
+
+ new_norm = multi_tensor_l2norm(chunk_size, has_inf, [new_p], False)
+ old_norm = multi_tensor_l2norm(chunk_size, has_inf, [ema_p], False)
+
+ self.logs["model_norm"] = new_norm[0]
+ self.logs["ema_norm"] = old_norm[0]
+
self.restore(ema_state_dict, build_fp32_params=False)
@torch.no_grad()
diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py
index 91655bc5e..867b019f6 100644
--- a/fairseq/modules/gumbel_vector_quantizer.py
+++ b/fairseq/modules/gumbel_vector_quantizer.py
@@ -21,6 +21,8 @@ class GumbelVectorQuantizer(nn.Module):
activation=nn.GELU(),
weight_proj_depth=1,
weight_proj_factor=1,
+ hard=True,
+ std=0,
):
"""Vector quantization using gumbel softmax
@@ -44,6 +46,7 @@ class GumbelVectorQuantizer(nn.Module):
self.input_dim = dim
self.num_vars = num_vars
self.time_first = time_first
+ self.hard = hard
assert (
vq_dim % groups == 0
@@ -53,7 +56,10 @@ class GumbelVectorQuantizer(nn.Module):
num_groups = groups if not combine_groups else 1
self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim))
- nn.init.uniform_(self.vars)
+ if std == 0:
+ nn.init.uniform_(self.vars)
+ else:
+ nn.init.normal_(self.vars, mean=0, std=std)
if weight_proj_depth > 1:
@@ -151,16 +157,17 @@ class GumbelVectorQuantizer(nn.Module):
x = self.weight_proj(x)
x = x.view(bsz * tsz * self.groups, -1)
- _, k = x.max(-1)
- hard_x = (
- x.new_zeros(*x.shape)
- .scatter_(-1, k.view(-1, 1), 1.0)
- .view(bsz * tsz, self.groups, -1)
- )
- hard_probs = torch.mean(hard_x.float(), dim=0)
- result["code_perplexity"] = torch.exp(
- -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1)
- ).sum()
+ with torch.no_grad():
+ _, k = x.max(-1)
+ hard_x = (
+ x.new_zeros(*x.shape)
+ .scatter_(-1, k.view(-1, 1), 1.0)
+ .view(bsz * tsz, self.groups, -1)
+ )
+ hard_probs = torch.mean(hard_x.float(), dim=0)
+ result["code_perplexity"] = torch.exp(
+ -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1)
+ ).sum()
avg_probs = torch.softmax(
x.view(bsz * tsz, self.groups, -1).float(), dim=-1
@@ -172,7 +179,9 @@ class GumbelVectorQuantizer(nn.Module):
result["temp"] = self.curr_temp
if self.training:
- x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=True).type_as(x)
+ x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=self.hard).type_as(
+ x
+ )
else:
x = hard_x
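The new `hard` flag simply forwards to `F.gumbel_softmax`: `hard=True` keeps the original straight-through one-hot codes, while `hard=False` lets the soft probabilities flow through during training. A small sketch with assumed shapes:

```python
# hard=True: one-hot codes with straight-through gradients;
# hard=False: the soft gumbel-softmax probabilities are used directly.
import torch
import torch.nn.functional as F

logits = torch.randn(8, 320, requires_grad=True)  # (bsz*tsz*groups, num_vars)
tau = 2.0

hard_codes = F.gumbel_softmax(logits, tau=tau, hard=True)
soft_codes = F.gumbel_softmax(logits, tau=tau, hard=False)

assert torch.allclose(hard_codes.sum(-1), torch.ones(8))  # one-hot rows
assert torch.allclose(soft_codes.sum(-1), torch.ones(8))  # probability rows
```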
diff --git a/fairseq/modules/kmeans_vector_quantizer.py b/fairseq/modules/kmeans_vector_quantizer.py
index 040db1e83..1015c3899 100644
--- a/fairseq/modules/kmeans_vector_quantizer.py
+++ b/fairseq/modules/kmeans_vector_quantizer.py
@@ -100,15 +100,16 @@ class KmeansVectorQuantizer(nn.Module):
assert ze.shape == zq.shape, (ze.shape, zq.shape)
x = self._pass_grad(ze, zq)
- hard_x = (
- idx.new_zeros(bsz * tsz * self.groups, self.num_vars)
- .scatter_(-1, idx.view(-1, 1), 1.0)
- .view(bsz * tsz, self.groups, -1)
- )
- hard_probs = torch.mean(hard_x.float(), dim=0)
- result["code_perplexity"] = torch.exp(
- -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1)
- ).sum()
+ with torch.no_grad():
+ hard_x = (
+ idx.new_zeros(bsz * tsz * self.groups, self.num_vars)
+ .scatter_(-1, idx.view(-1, 1), 1.0)
+ .view(bsz * tsz, self.groups, -1)
+ )
+ hard_probs = torch.mean(hard_x.float(), dim=0)
+ result["code_perplexity"] = torch.exp(
+ -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1)
+ ).sum()
if produce_targets:
result["targets"] = idx
diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py
index 1cf3abe94..262132dfe 100644
--- a/fairseq/modules/multihead_attention.py
+++ b/fairseq/modules/multihead_attention.py
@@ -551,7 +551,7 @@ class MultiheadAttention(FairseqIncrementalDecoder):
self.out_proj.weight,
self.out_proj.bias,
self.training or self.dropout_module.apply_during_inference,
- key_padding_mask,
+ key_padding_mask.bool() if key_padding_mask is not None else None,
need_weights,
attn_mask,
use_separate_proj_weight=True,
diff --git a/fairseq/modules/same_pad.py b/fairseq/modules/same_pad.py
index 4c04990ea..a3ce4131c 100644
--- a/fairseq/modules/same_pad.py
+++ b/fairseq/modules/same_pad.py
@@ -19,3 +19,15 @@ class SamePad(nn.Module):
if self.remove > 0:
x = x[:, :, : -self.remove]
return x
+
+
+class SamePad2d(nn.Module):
+ def __init__(self, kernel_size):
+ super().__init__()
+ self.remove = 1 if kernel_size % 2 == 0 else 0
+
+ def forward(self, x):
+ assert len(x.size()) == 4
+ if self.remove > 0:
+ x = x[:, :, : -self.remove, : -self.remove]
+ return x
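`SamePad2d` mirrors the 1-D `SamePad`: for an even kernel size, a convolution padded with `kernel_size // 2` produces one extra element per spatial dim, and this module trims it off. A hedged usage sketch (the pairing with `Conv2d` is an assumption, not code from this patch):

```python
# Even-kernel "same" padding in 2-D: the conv output is one element too long
# in each spatial dim; SamePad2d trims it back to the input size.
import torch
import torch.nn as nn
from fairseq.modules.same_pad import SamePad2d  # available after this patch

kernel_size = 4
conv = nn.Conv2d(8, 8, kernel_size, padding=kernel_size // 2)
trim = SamePad2d(kernel_size)

x = torch.randn(1, 8, 16, 16)
y = trim(conv(x))
assert y.shape == x.shape  # (1, 8, 16, 16)
```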
diff --git a/fairseq/modules/transpose_last.py b/fairseq/modules/transpose_last.py
index e578b3ec5..d7cca9a4b 100644
--- a/fairseq/modules/transpose_last.py
+++ b/fairseq/modules/transpose_last.py
@@ -10,11 +10,12 @@ import torch.nn as nn
class TransposeLast(nn.Module):
- def __init__(self, deconstruct_idx=None):
+ def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
super().__init__()
self.deconstruct_idx = deconstruct_idx
+ self.tranpose_dim = tranpose_dim
def forward(self, x):
if self.deconstruct_idx is not None:
x = x[self.deconstruct_idx]
- return x.transpose(-2, -1)
+ return x.transpose(self.tranpose_dim, -1)
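The new argument generalizes `TransposeLast` from always swapping `-2` with `-1` to swapping any chosen dim with the last one; the default preserves the old behavior. A quick sketch (the keyword spelling follows the patch):

```python
import torch
from fairseq.modules.transpose_last import TransposeLast

x = torch.randn(2, 3, 4, 5)
assert TransposeLast()(x).shape == (2, 3, 5, 4)                 # swap -2, -1
assert TransposeLast(tranpose_dim=-3)(x).shape == (2, 5, 4, 3)  # swap -3, -1
```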
diff --git a/fairseq/nan_detector.py b/fairseq/nan_detector.py
index 7d46d766d..bd0f91107 100644
--- a/fairseq/nan_detector.py
+++ b/fairseq/nan_detector.py
@@ -38,7 +38,7 @@ class NanDetector:
for name, param in self.named_parameters:
if param.grad is not None:
grad_norm = torch.norm(param.grad.data.float(), p=2)
- norm[name] = grad_norm.item()
+ norm[name] = param.norm().item()
if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any():
gradients[name] = param.grad.data
if len(gradients) > 0:
diff --git a/fairseq/optim/composite.py b/fairseq/optim/composite.py
index 63701ee8b..1ef0114ed 100644
--- a/fairseq/optim/composite.py
+++ b/fairseq/optim/composite.py
@@ -13,6 +13,7 @@ from fairseq.dataclass import FairseqDataclass
from fairseq.optim import FairseqOptimizer, register_optimizer, _build_optimizer
from fairseq.optim.lr_scheduler import FairseqLRScheduler, build_lr_scheduler
from omegaconf import II, open_dict
+import copy
logger = logging.getLogger(__name__)
@@ -37,6 +38,12 @@ class CompositeOptimizerConfig(FairseqDataclass):
"Configures a different optimizer and (optionally) lr scheduler for each parameter group"
},
)
+ dynamic_groups: bool = field(
+ default=False,
+ metadata={
+ "help": "create groups dynamically based on parameters, if set to False, all parameters needs to have group_names"
+ },
+ )
@register_optimizer("composite", dataclass=CompositeOptimizerConfig)
@@ -54,31 +61,107 @@ class FairseqCompositeOptimizer(FairseqOptimizer):
len(params) > 1
), "Composite optimizer only works when there are multiple parameter groups (try fp16_no_flatten_grads: true)"
+ def dict_hash(dictionary: Dict[str, Any]) -> str:
+ import hashlib
+ import json
+
+ dhash = hashlib.md5()
+ encoded = json.dumps(dictionary, sort_keys=True).encode()
+ dhash.update(encoded)
+ return dhash.hexdigest()
+
groupped_params = defaultdict(list)
- for p in params:
- group = getattr(p, "param_group", "default")
- groupped_params[group].append(p)
-
- assert groupped_params.keys() == cfg.groups.keys(), (
- f"Parameter groups {groupped_params.keys()} and optimizer groups {cfg.groups.keys()} are not the same! "
- "Try setting 'param_group' on your parameters in the model."
- )
-
- for group, group_params in groupped_params.items():
- group_cfg = cfg.groups[group]
- with open_dict(group_cfg):
- if group_cfg.lr_float is not None:
- group_cfg.optimizer.lr = [group_cfg.lr_float]
- group_cfg.lr_scheduler.lr = [group_cfg.lr_float]
+ overrides = defaultdict(dict)
+ if not cfg.dynamic_groups:
+ for p in params:
+ group = getattr(p, "param_group", "default")
+ override_config = getattr(p, "optim_overrides", None)
+ if override_config is not None and bool(override_config):
+ overrides[group] = override_config
else:
- group_cfg.optimizer.lr = group_cfg.lr
- group_cfg.lr_scheduler.lr = group_cfg.lr
- self.optimizers[group] = _build_optimizer(group_cfg.optimizer, group_params)
- if group_cfg.lr_scheduler is not None:
- self.lr_schedulers[group] = build_lr_scheduler(
- group_cfg.lr_scheduler, self.optimizers[group]
- )
+ assert (
+ override_config is None or override_config == overrides[group]
+ ), f"For group {group}, different overrides found: {override_config} vs {overrides[group]}"
+ groupped_params[group].append(p)
+ for group, group_params in groupped_params.items():
+ override_config = getattr(group_params[0], "optim_overrides", None)
+ if override_config is not None:
+ for pp in group_params[1:]:
+ assert override_config == getattr(
+ pp, "optim_overrides", None
+ ), f"{str(override_config)} != {str(getattr(pp, 'optim_overrides', None))}"
+ else:
+ for p in params:
+ group = getattr(p, "param_group", "default")
+ override_config = getattr(p, "optim_overrides", None)
+ if override_config is not None:
+ override_config["group_name"] = group
+ group_name = dict_hash(override_config)
+ overrides[group_name] = override_config
+ else:
+ group_name = group
+ groupped_params[group_name].append(p)
+
+ self.optimizers_config = {}
+ for group, group_params in groupped_params.items():
+ p_group = group
+ if group in overrides and "group_name" in overrides[group]:
+ p_group = overrides[group]["group_name"]
+ if group in cfg.groups:
+ group_cfg = cfg.groups[group]
+ optimizer_config = copy.deepcopy(group_cfg.optimizer)
+ scheduler_config = copy.deepcopy(group_cfg.lr_scheduler)
+ explicit_group_present = True
+ else:
+ group_cfg = cfg.groups[p_group]
+ optimizer_config = copy.deepcopy(group_cfg.optimizer)
+ scheduler_config = copy.deepcopy(group_cfg.lr_scheduler)
+ explicit_group_present = False
+
+ if getattr(group_cfg, "lr_float", None) is not None:
+ with open_dict(optimizer_config):
+ optimizer_config.lr = [group_cfg.lr_float]
+
+ if group in overrides and "optimizer" in overrides[group]:
+ with open_dict(optimizer_config):
+ if "lr_scale" in overrides[group]["optimizer"]:
+ lr_scale = overrides[group]["optimizer"]["lr_scale"]
+ optimizer_config.lr = [
+ lr * lr_scale for lr in optimizer_config.lr
+ ]
+
+ if explicit_group_present:
+ logger.info(
+ f"For group:{group}, config as well as override present for lr"
+ )
+
+ if (
+ "weight_decay_scale" in overrides[group]["optimizer"]
+ and "optimizer_config" in optimizer_config
+ ):
+ weight_decay_scale = overrides[group]["optimizer"][
+ "weight_decay_scale"
+ ]
+ optimizer_config.weight_decay = (
+ optimizer_config.weight_decay * weight_decay_scale
+ )
+ if explicit_group_present:
+ logger.info(
+ f"For group:{group}, config as well as override present for weight_decay"
+ )
+
+ with open_dict(scheduler_config):
+ scheduler_config.lr = optimizer_config.lr
+ self.optimizers[group] = _build_optimizer(optimizer_config, group_params)
+ self.optimizers_config[group] = optimizer_config
+ if scheduler_config is not None:
+ self.lr_schedulers[group] = build_lr_scheduler(
+ scheduler_config, self.optimizers[group]
+ )
+ logger.info("Optimizers for different groups are as below")
+ for group in self.optimizers_config.keys():
+ logger.info(f"Group : {group}:{self.optimizers_config[group]}")
if len(self.lr_schedulers) > 0:
assert len(self.lr_schedulers) == len(self.optimizers), (
f"Please provide an lr scheduler for each optimizer to use pass_through scheduler. "
diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py
index 2c4ee326e..6a4da342c 100644
--- a/fairseq/optim/fp16_optimizer.py
+++ b/fairseq/optim/fp16_optimizer.py
@@ -72,6 +72,8 @@ class _FP16OptimizerMixin(object):
p32.grad = torch.zeros_like(p32.data)
if hasattr(p, "param_group"):
p32.param_group = p.param_group
+ if hasattr(p, "optim_overrides"):
+ p32.optim_overrides = p.optim_overrides
fp32_params.append(p32)
return fp32_params
@@ -194,6 +196,9 @@ class _FP16OptimizerMixin(object):
0, aggregate_norm_fn
)
+ if torch.is_tensor(self._multiply_factor):
+ self._multiply_factor = self._multiply_factor.to(grad_norm.device)
+
if self.scaler is not None:
if grad_norm > max_norm > 0.0:
self._multiply_factor *= max_norm / grad_norm
diff --git a/fairseq/optim/fused_adam.py b/fairseq/optim/fused_adam.py
index 1290ecfdb..39a2a8369 100644
--- a/fairseq/optim/fused_adam.py
+++ b/fairseq/optim/fused_adam.py
@@ -210,6 +210,9 @@ class FusedAdamV1(torch.optim.Optimizer):
exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"]
beta1, beta2 = group["betas"]
+ if "step" not in state:
+ state["step"] = group["step"]
+
state["step"] += 1
with torch.cuda.device(p_data_fp32.device):
diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py
index d01143498..5fcaea25d 100644
--- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py
+++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py
@@ -77,9 +77,8 @@ class CosineLRSchedule(FairseqLRScheduler):
)
self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
- assert (
- self.max_lr > cfg.min_lr
- ), f"max_lr (={cfg.lr}) must be more than min_lr (={cfg.min_lr})"
+ if self.max_lr < cfg.min_lr:
+ cfg.min_lr = self.max_lr
warmup_end_lr = self.max_lr
if cfg.warmup_init_lr < 0:
diff --git a/fairseq/registry.py b/fairseq/registry.py
index f3b940604..904ffcd60 100644
--- a/fairseq/registry.py
+++ b/fairseq/registry.py
@@ -36,8 +36,9 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F
choice = cfg._name
if choice and choice in DATACLASS_REGISTRY:
+ from_checkpoint = extra_kwargs.get("from_checkpoint", False)
dc = DATACLASS_REGISTRY[choice]
- cfg = merge_with_parent(dc(), cfg)
+ cfg = merge_with_parent(dc(), cfg, remove_missing=from_checkpoint)
elif isinstance(cfg, str):
choice = cfg
if choice in DATACLASS_REGISTRY:
@@ -58,6 +59,9 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F
else:
builder = cls
+ if "from_checkpoint" in extra_kwargs:
+ del extra_kwargs["from_checkpoint"]
+
return builder(cfg, *extra_args, **extra_kwargs)
def register_x(name, dataclass=None):
diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py
index 9a46b012c..6da1f001f 100644
--- a/fairseq/tasks/__init__.py
+++ b/fairseq/tasks/__init__.py
@@ -35,8 +35,9 @@ def setup_task(cfg: FairseqDataclass, **kwargs):
task_name = getattr(cfg, "_name", None)
if task_name and task_name in TASK_DATACLASS_REGISTRY:
+ remove_missing = "from_checkpoint" in kwargs and kwargs["from_checkpoint"]
dc = TASK_DATACLASS_REGISTRY[task_name]
- cfg = merge_with_parent(dc(), cfg)
+ cfg = merge_with_parent(dc(), cfg, remove_missing=remove_missing)
task = TASK_REGISTRY[task_name]
assert (
@@ -68,7 +69,8 @@ def register_task(name, dataclass=None):
def register_task_cls(cls):
if name in TASK_REGISTRY:
- raise ValueError("Cannot register duplicate task ({})".format(name))
+ return TASK_REGISTRY[name]
+
if not issubclass(cls, FairseqTask):
raise ValueError(
"Task ({}: {}) must extend FairseqTask".format(name, cls.__name__)
diff --git a/fairseq/tasks/audio_finetuning.py b/fairseq/tasks/audio_finetuning.py
index 5e04a1b79..77634f1ba 100644
--- a/fairseq/tasks/audio_finetuning.py
+++ b/fairseq/tasks/audio_finetuning.py
@@ -100,6 +100,7 @@ class AudioFinetuningConfig(AudioPretrainingConfig):
"adds 'prev_output_tokens' to input and appends eos to target"
},
)
+ rebuild_batches: bool = True
@register_task("audio_finetuning", dataclass=AudioFinetuningConfig)
diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py
index a55c70400..e6de0666a 100644
--- a/fairseq/tasks/audio_pretraining.py
+++ b/fairseq/tasks/audio_pretraining.py
@@ -12,9 +12,9 @@ import sys
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional
-from omegaconf import MISSING, II, OmegaConf
+from omegaconf import MISSING, II
-from fairseq.data import BinarizedAudioDataset, FileAudioDataset
+from fairseq.data import BinarizedAudioDataset, FileAudioDataset, SubsampleDataset
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
from fairseq.data.text_compressor import TextCompressionLevel
@@ -25,24 +25,16 @@ logger = logging.getLogger(__name__)
@dataclass
-class InferredW2vConfig:
- # The following are needed to precompute mask and mask channel indices
- # before model's forward.
- mask_length: Optional[int] = II("model.mask_length")
- mask_prob: Optional[float] = II("model.mask_prob")
- mask_selection: Optional[str] = II("model.mask_selection")
- mask_other: Optional[float] = II("model.mask_other")
- no_mask_overlap: Optional[bool] = II("model.no_mask_overlap")
- mask_min_space: Optional[int] = II("model.mask_min_space")
- mask_channel_length: Optional[int] = II("model.mask_channel_length")
- mask_channel_prob: Optional[float] = II("model.mask_channel_prob")
- mask_channel_selection: Optional[str] = II("model.mask_channel_selection")
- mask_channel_other: Optional[float] = II("model.mask_channel_other")
- no_mask_channel_overlap: Optional[bool] = II("model.no_mask_channel_overlap")
- mask_channel_min_space: Optional[int] = II("model.mask_channel_min_space")
-
- conv_feature_layers: Optional[str] = II("model.conv_feature_layers")
- encoder_embed_dim: Optional[int] = II("model.encoder_embed_dim")
+class AudioMaskingConfig:
+ feature_encoder_spec: str = II("model.modalities.audio.feature_encoder_spec")
+ mask_prob: float = II("model.modalities.audio.mask_prob")
+ mask_prob_adjust: float = II("model.modalities.audio.mask_prob_adjust")
+ mask_length: int = II("model.modalities.audio.mask_length")
+ inverse_mask: bool = II("model.modalities.audio.inverse_mask")
+ mask_dropout: float = II("model.modalities.audio.mask_dropout")
+ clone_batch: int = II("model.clone_batch")
+ expand_adjacent: bool = False
+ non_overlapping: bool = False
@dataclass
@@ -82,20 +74,6 @@ class AudioPretrainingConfig(FairseqDataclass):
default=0,
metadata={"help": "number of buckets"},
)
- precompute_mask_indices: bool = field(
- default=False,
- metadata={
- "help": "flag to compute mask indices in data preparation.",
- },
- )
-
- inferred_w2v_config: Optional[InferredW2vConfig] = field(
- default=None,
- metadata={
- "help": "wav2vec 2.0 masking arguments used to pre-compute masks (required for TPU)",
- },
- )
-
tpu: bool = II("common.tpu")
text_compression_level: ChoiceEnum([x.name for x in TextCompressionLevel]) = field(
default="none",
@@ -105,6 +83,14 @@ class AudioPretrainingConfig(FairseqDataclass):
},
)
+ rebuild_batches: bool = True
+ precompute_mask_config: Optional[AudioMaskingConfig] = None
+
+ post_save_script: Optional[str] = None
+
+ subsample: float = 1
+ seed: int = II("common.seed")
+
@register_task("audio_pretraining", dataclass=AudioPretrainingConfig)
class AudioPretrainingTask(FairseqTask):
@@ -122,17 +108,6 @@ class AudioPretrainingTask(FairseqTask):
return cls(cfg)
- def _get_mask_precompute_kwargs(self, cfg):
- if self.cfg.precompute_mask_indices or self.cfg.tpu:
- assert (
- cfg.inferred_w2v_config is not None
- ), "inferred_w2v_config must be set"
- return OmegaConf.to_container(
- cfg.inferred_w2v_config, resolve=True, enum_to_str=True
- )
- else:
- return {}
-
def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
data_path = self.cfg.data
task_cfg = task_cfg or self.cfg
@@ -145,6 +120,12 @@ class AudioPretrainingTask(FairseqTask):
text_compression_level = getattr(
TextCompressionLevel, str(self.cfg.text_compression_level)
)
+
+ compute_mask = task_cfg.precompute_mask_config is not None
+ mask_args = {}
+ if compute_mask:
+ mask_args = task_cfg.precompute_mask_config
+
if getattr(task_cfg, "binarized_dataset", False):
self.datasets[split] = BinarizedAudioDataset(
data_path,
@@ -155,8 +136,8 @@ class AudioPretrainingTask(FairseqTask):
pad=task_cfg.labels is not None or task_cfg.enable_padding,
normalize=task_cfg.normalize,
num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
- compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu),
- **self._get_mask_precompute_kwargs(task_cfg),
+ compute_mask=compute_mask,
+ **mask_args,
)
else:
manifest_path = os.path.join(data_path, "{}.tsv".format(split))
@@ -169,9 +150,17 @@ class AudioPretrainingTask(FairseqTask):
pad=task_cfg.labels is not None or task_cfg.enable_padding,
normalize=task_cfg.normalize,
num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
- compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu),
text_compression_level=text_compression_level,
- **self._get_mask_precompute_kwargs(task_cfg),
+ compute_mask=compute_mask,
+ **mask_args,
+ )
+
+ if getattr(task_cfg, "subsample", 1) < 1:
+ self.datasets[split] = SubsampleDataset(
+ self.datasets[split],
+ task_cfg.subsample,
+ shuffle=True,
+ seed=task_cfg.seed,
)
if self.cfg.tpu and task_cfg.inferred_w2v_config.mask_channel_prob == 0.0:
@@ -181,14 +170,6 @@ class AudioPretrainingTask(FairseqTask):
"0. You may want to set this to a low value close to 0."
)
- @property
- def source_dictionary(self):
- return None
-
- @property
- def target_dictionary(self):
- return None
-
def max_positions(self):
"""Maximum input length supported by the encoder."""
return sys.maxsize, sys.maxsize
@@ -203,3 +184,24 @@ class AudioPretrainingTask(FairseqTask):
model_cfg.w2v_args = actualized_cfg.w2v_args
return model
+
+ def post_save(self, cp_path, num_updates):
+ if self.cfg.post_save_script is not None:
+ logger.info(f"launching {self.cfg.post_save_script}")
+ import os.path as osp
+ from fairseq.file_io import PathManager
+
+ eval_cp_path = osp.join(
+ osp.dirname(cp_path), f"checkpoint_eval_{num_updates}.pt"
+ )
+
+ logger.info(f"copying {cp_path} to {eval_cp_path}")
+
+ assert PathManager.copy(
+ cp_path, eval_cp_path, overwrite=True
+ ), f"Failed to copy {cp_path} to {eval_cp_path}"
+
+ import subprocess
+ import shlex
+
+ subprocess.call(shlex.split(f"{self.cfg.post_save_script} {eval_cp_path}"))
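`post_save` copies each saved checkpoint to `checkpoint_eval_{num_updates}.pt` in the same directory and then launches `cfg.post_save_script` with that copy as its only argument, e.g. to kick off asynchronous evaluation. A condensed sketch of the mechanics (the helper name is made up; fairseq performs the actual copy via `PathManager.copy`):

```python
import os.path as osp
import shlex
import subprocess

def post_save_sketch(cp_path: str, num_updates: int, post_save_script: str):
    eval_cp_path = osp.join(
        osp.dirname(cp_path), f"checkpoint_eval_{num_updates}.pt"
    )
    # the real code copies cp_path -> eval_cp_path via PathManager first
    subprocess.call(shlex.split(f"{post_save_script} {eval_cp_path}"))
```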
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 3cba8f224..7481e0f4b 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -272,6 +272,7 @@ class FairseqTask(object):
and not update_epoch_batch_itr
and self.can_reuse_epoch_itr(dataset)
)
+ logger.info(f"can_reuse_epoch_itr = {can_reuse_epoch_itr}")
if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch))
return self.dataset_to_epoch_iter[dataset]
@@ -281,26 +282,39 @@ class FairseqTask(object):
# initialize the dataset with the correct starting epoch
dataset.set_epoch(epoch)
- # get indices ordered by example size
- with data_utils.numpy_seed(seed):
- indices = dataset.ordered_indices()
+ def make_batches(dataset, epoch):
+ logger.info(f"creating new batches for epoch {epoch}")
- # filter examples that are too large
- if max_positions is not None:
- indices = self.filter_indices_by_size(
- indices, dataset, max_positions, ignore_invalid_inputs
+ # get indices ordered by example size
+ with data_utils.numpy_seed(seed + epoch):
+ indices = dataset.ordered_indices()
+
+ # filter examples that are too large
+ if max_positions is not None:
+ indices = self.filter_indices_by_size(
+ indices, dataset, max_positions, ignore_invalid_inputs
+ )
+
+ # create mini-batches with given size constraints
+ batches = dataset.batch_by_size(
+ indices,
+ max_tokens=max_tokens,
+ max_sentences=max_sentences,
+ required_batch_size_multiple=required_batch_size_multiple,
)
-
- # create mini-batches with given size constraints
- batch_sampler = dataset.batch_by_size(
- indices,
- max_tokens=max_tokens,
- max_sentences=max_sentences,
- required_batch_size_multiple=required_batch_size_multiple,
- )
+ return batches
reuse_dataloader = getattr(self.cfg, "reuse_dataloader", True)
- persistent_workers = getattr(self.cfg, "persistent_workers", False)
+ persistent_workers = getattr(self.cfg, "persistent_workers", True)
+ rebuild_batches = getattr(self.cfg, "rebuild_batches", False)
+ logger.info(f"reuse_dataloader = {reuse_dataloader}")
+ logger.info(f"rebuild_batches = {rebuild_batches}")
+
+ if rebuild_batches:
+ logger.info("batches will be rebuilt for each epoch")
+ batch_sampler = make_batches
+ else:
+ batch_sampler = make_batches(dataset, epoch)
# return a reusable, sharded iterator
epoch_iter = iterators.EpochBatchIterator(
@@ -341,7 +355,7 @@ class FairseqTask(object):
model = quantization_utils.quantize_model_scalar(model, cfg)
return model
- def build_criterion(self, cfg: DictConfig):
+ def build_criterion(self, cfg: DictConfig, from_checkpoint=False):
"""
Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
this task.
@@ -354,7 +368,7 @@ class FairseqTask(object):
"""
from fairseq import criterions
- return criterions.build_criterion(cfg, self)
+ return criterions.build_criterion(cfg, self, from_checkpoint=from_checkpoint)
def build_generator(
self,
@@ -614,13 +628,13 @@ class FairseqTask(object):
def source_dictionary(self):
"""Return the source :class:`~fairseq.data.Dictionary` (if applicable
for this task)."""
- raise NotImplementedError
+ return None
@property
def target_dictionary(self):
"""Return the target :class:`~fairseq.data.Dictionary` (if applicable
for this task)."""
- raise NotImplementedError
+ return None
def build_tokenizer(self, args):
"""Build the pre-tokenizer for this task."""
diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py
index 6393ee480..b064907a5 100644
--- a/fairseq/tasks/masked_lm.py
+++ b/fairseq/tasks/masked_lm.py
@@ -20,6 +20,7 @@ from fairseq.data import (
NumSamplesDataset,
PrependTokenDataset,
RightPadDataset,
+ RightPaddingMaskDataset,
SortDataset,
TokenBlockDataset,
data_utils,
@@ -106,6 +107,22 @@ class MaskedLMConfig(FairseqDataclass):
"help": "include target tokens in model input. this is used for data2vec"
},
)
+ include_index: bool = field(
+ default=True,
+ metadata={"help": "include index in model input. this is used for data2vec"},
+ )
+ skip_masking: bool = field(
+ default=False,
+ metadata={"help": "skip masking at dataset"},
+ )
+ # subsample_train: float = field(
+ # default=1,
+ # metadata={"help": "shorten training set for debugging"},
+ # )
+ d2v2_multi: bool = field(
+ default=False,
+ metadata={"help": "prepare dataset for data2vec_multi"},
+ )
@register_task("masked_lm", dataclass=MaskedLMConfig)
@@ -115,20 +132,25 @@ class MaskedLMTask(FairseqTask):
"""Task for training masked language models (e.g., BERT, RoBERTa)."""
- def __init__(self, cfg: MaskedLMConfig, dictionary):
+ def __init__(self, cfg: MaskedLMConfig, dictionary=None):
super().__init__(cfg)
- self.dictionary = dictionary
+ self.dictionary = dictionary or self.load_dict(cfg)
# add mask token
- self.mask_idx = dictionary.add_symbol("<mask>")
+ self.mask_idx = self.dictionary.add_symbol("<mask>")
@classmethod
def setup_task(cls, cfg: MaskedLMConfig, **kwargs):
+ dictionary = cls.load_dict(cfg)
+ return cls(cfg, dictionary)
+
+ @classmethod
+ def load_dict(cls, cfg):
paths = utils.split_paths(cfg.data)
assert len(paths) > 0
dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
logger.info("dictionary: {} types".format(len(dictionary)))
- return cls(cfg, dictionary)
+ return dictionary
def _load_dataset_split(self, split, epoch, combine):
paths = utils.split_paths(self.cfg.data)
@@ -197,6 +219,7 @@ class MaskedLMTask(FairseqTask):
mask_whole_words=mask_whole_words,
mask_multiple_length=self.cfg.mask_multiple_length,
mask_stdev=self.cfg.mask_stdev,
+ skip_masking=self.cfg.skip_masking,
)
with data_utils.numpy_seed(self.cfg.seed):
@@ -207,6 +230,16 @@ class MaskedLMTask(FairseqTask):
pad_idx=self.source_dictionary.pad(),
)
+ if self.cfg.d2v2_multi:
+ dataset = self._d2v2_multi_dataset(src_dataset)
+ else:
+ dataset = self._regular_dataset(src_dataset, target_dataset)
+
+ self.datasets[split] = SortDataset(
+ dataset, sort_order=[shuffle, src_dataset.sizes]
+ )
+
+ def _regular_dataset(self, src_dataset, target_dataset):
input_dict = {
"src_tokens": RightPadDataset(
src_dataset,
@@ -216,23 +249,41 @@ class MaskedLMTask(FairseqTask):
}
if self.cfg.include_target_tokens:
input_dict["target_tokens"] = target_dataset
+ if self.cfg.include_index:
+ input_dict["src_id"] = IdDataset()
- self.datasets[split] = SortDataset(
- NestedDictionaryDataset(
- {
- "id": IdDataset(),
- "net_input": input_dict,
- "target": target_dataset,
- "nsentences": NumSamplesDataset(),
- "ntokens": NumelDataset(src_dataset, reduce=True),
- },
- sizes=[src_dataset.sizes],
- ),
- sort_order=[
- shuffle,
- src_dataset.sizes,
- ],
+ dataset = NestedDictionaryDataset(
+ {
+ "id": IdDataset(),
+ "net_input": input_dict,
+ "target": target_dataset,
+ "nsentences": NumSamplesDataset(),
+ "ntokens": NumelDataset(src_dataset, reduce=True),
+ },
+ sizes=[src_dataset.sizes],
)
+ return dataset
+
+ def _d2v2_multi_dataset(self, src_dataset):
+ input_dict = {
+ "source": RightPadDataset(
+ src_dataset,
+ pad_idx=self.source_dictionary.pad(),
+ ),
+ "id": IdDataset(),
+ "padding_mask": RightPaddingMaskDataset(src_dataset),
+ }
+
+ dataset = NestedDictionaryDataset(
+ {
+ "id": IdDataset(),
+ "net_input": input_dict,
+ "nsentences": NumSamplesDataset(),
+ "ntokens": NumelDataset(src_dataset, reduce=True),
+ },
+ sizes=[src_dataset.sizes],
+ )
+ return dataset
def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
src_dataset = RightPadDataset(
@@ -268,3 +319,9 @@ class MaskedLMTask(FairseqTask):
@property
def target_dictionary(self):
return self.dictionary
+
+ def begin_epoch(self, epoch, model):
+ model.set_epoch(epoch)
+
+ def max_positions(self):
+ return self.cfg.tokens_per_sample
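For `d2v2_multi`, the `net_input` carries `source`, `id`, and `padding_mask` (rather than `src_tokens`/`src_lengths`), matching what data2vec_multi expects. A sketch of the padding-mask semantics assumed for `RightPaddingMaskDataset` (True marks right-padding positions):

```python
import torch

pad_idx = 1
source = torch.tensor([[5, 6, 7, pad_idx], [8, 9, pad_idx, pad_idx]])
padding_mask = source.eq(pad_idx)
# tensor([[False, False, False,  True],
#         [False, False,  True,  True]])
```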
diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py
index 52532ff61..de80addaf 100644
--- a/fairseq/tasks/sentence_prediction.py
+++ b/fairseq/tasks/sentence_prediction.py
@@ -23,6 +23,7 @@ from fairseq.data import (
PrependTokenDataset,
RawLabelDataset,
RightPadDataset,
+ RightPaddingMaskDataset,
RollDataset,
SortDataset,
StripTokenDataset,
@@ -83,6 +84,11 @@ class SentencePredictionConfig(FairseqDataclass):
classification_head_name: str = II("criterion.classification_head_name")
seed: int = II("common.seed")
+ d2v2_multi: bool = field(
+ default=False,
+ metadata={"help": "prepare dataset for data2vec_multi"},
+ )
+
@register_task("sentence_prediction", dataclass=SentencePredictionConfig)
class SentencePredictionTask(FairseqTask):
@@ -181,28 +187,39 @@ class SentencePredictionTask(FairseqTask):
self.cfg.seed,
)
- dataset = {
- "id": IdDataset(),
- "net_input": {
+ if self.cfg.d2v2_multi:
+ net_input = {
+ "source": RightPadDataset(
+ src_tokens,
+ pad_idx=self.source_dictionary.pad(),
+ ),
+ "id": IdDataset(),
+ "padding_mask": RightPaddingMaskDataset(src_tokens),
+ }
+ else:
+ net_input = {
"src_tokens": RightPadDataset(
src_tokens,
pad_idx=self.source_dictionary.pad(),
),
"src_lengths": NumelDataset(src_tokens, reduce=False),
- },
+ }
+ if self.cfg.add_prev_output_tokens:
+ prev_tokens_dataset = RightPadDataset(
+ RollDataset(src_tokens, 1),
+ pad_idx=self.dictionary.pad(),
+ )
+ net_input.update(
+ prev_output_tokens=prev_tokens_dataset,
+ )
+
+ dataset = {
+ "id": IdDataset(),
+ "net_input": net_input,
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(src_tokens, reduce=True),
}
- if self.cfg.add_prev_output_tokens:
- prev_tokens_dataset = RightPadDataset(
- RollDataset(src_tokens, 1),
- pad_idx=self.dictionary.pad(),
- )
- dataset["net_input"].update(
- prev_output_tokens=prev_tokens_dataset,
- )
-
if not self.cfg.regression_target:
label_dataset = make_dataset("label", self.label_dictionary)
if label_dataset is not None:
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index da1f94910..16b1b9169 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -30,7 +30,6 @@ from fairseq.nan_detector import NanDetector
from fairseq.optim import lr_scheduler
from fairseq.utils import safe_hasattr
-
logger = logging.getLogger(__name__)
@@ -288,12 +287,27 @@ class Trainer(object):
return self._lr_scheduler
def _build_optimizer(self):
- params = list(
- filter(
- lambda p: p.requires_grad,
- chain(self.model.parameters(), self.criterion.parameters()),
+
+ if (
+ self.cfg.optimization.debug_param_names
+ and self.cfg.common.fp16_no_flatten_grads
+ ):
+ params = []
+ self.param_names = []
+
+ for n, p in chain(
+ self.model.named_parameters(), self.criterion.named_parameters()
+ ):
+ if p.requires_grad:
+ params.append(p)
+ self.param_names.append(n)
+ else:
+ params = list(
+ filter(
+ lambda p: p.requires_grad,
+ chain(self.model.parameters(), self.criterion.parameters()),
+ )
)
- )
if self.is_fsdp and self.cfg.common.fp16:
# FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper,
@@ -432,18 +446,21 @@ class Trainer(object):
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
-
- logger.info(f"Saving checkpoint to {os.path.abspath(filename)}")
- # call state_dict on all ranks in case it needs internal communication
- state_dict = utils.move_to_cpu(self.state_dict())
- state_dict["extra_state"].update(extra_state)
if self.should_save_checkpoint_on_current_rank:
+
+ logger.info(f"Saving checkpoint to {os.path.abspath(filename)}")
+ # call state_dict on all ranks in case it needs internal communication
+ state_dict = utils.move_to_cpu(self.state_dict())
+ state_dict["extra_state"].update(extra_state)
+
checkpoint_utils.torch_persistent_save(
state_dict,
filename,
async_write=self.cfg.checkpoint.write_checkpoints_asynchronously,
)
- logger.info(f"Finished saving checkpoint to {os.path.abspath(filename)}")
+ logger.info(f"Finished saving checkpoint to {os.path.abspath(filename)}")
+ return os.path.abspath(filename)
+ return None
def load_checkpoint(
self,
@@ -793,6 +810,8 @@ class Trainer(object):
if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False):
extra_kwargs["ema_model"] = self.ema.get_model()
+ has_oom = False
+
# forward and backward pass
logging_outputs, sample_size, ooms = [], 0, 0
for i, sample in enumerate(samples): # delayed update loop
@@ -842,17 +861,9 @@ class Trainer(object):
except RuntimeError as e:
if "out of memory" in str(e):
self._log_oom(e)
+ has_oom = True
if raise_oom:
raise e
- logger.warning(
- "attempting to recover from OOM in forward/backward pass"
- )
- ooms += 1
- self.zero_grad()
- if self.cuda:
- torch.cuda.empty_cache()
- if self.cfg.distributed_training.distributed_world_size == 1:
- return None
else:
raise e
except Exception:
@@ -862,6 +873,18 @@ class Trainer(object):
)
raise
+ if has_oom:
+ logger.warning(
+ "attempting to recover from OOM in forward/backward pass"
+ )
+ ooms += 1
+ self.zero_grad()
+ if self.cuda:
+ torch.cuda.empty_cache()
+
+ if self.cfg.distributed_training.distributed_world_size == 1:
+ return None
+
if self.tpu and i < len(samples) - 1:
# tpu-comment: every XLA operation before marking step is
# appended to the IR graph, and processing too many batches
@@ -989,6 +1012,14 @@ class Trainer(object):
logger.info(
f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
)
+
+ if hasattr(self, "param_names") and hasattr(
+ self.optimizer, "fp32_optimizer"
+ ):
+ for p, n in zip(self.optimizer.fp32_optimizer.params, self.param_names):
+ if torch.isinf(p.grad).any() or torch.isnan(p.grad).any():
+ logger.info(f"overflow in param {n}")
+
grad_norm = torch.tensor(0.0).cuda()
self.zero_grad()
except RuntimeError as e:
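OOM handling in the inner update loop is now deferred: the failure is recorded, and recovery (`zero_grad` plus `empty_cache`) runs once after the try/except rather than inside the handler. A pattern sketch with stand-in callables:

```python
import torch

def step_sketch(samples, forward_backward, zero_grad, raise_oom=False):
    for sample in samples:
        has_oom = False
        try:
            forward_backward(sample)
        except RuntimeError as e:
            if "out of memory" in str(e) and not raise_oom:
                has_oom = True
            else:
                raise
        if has_oom:  # recover outside the except block
            zero_grad()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
```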
diff --git a/fairseq/utils.py b/fairseq/utils.py
index e6196863f..e5c77e4b9 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -270,9 +270,9 @@ def strip_pad(tensor, pad):
return tensor[tensor.ne(pad)]
-def buffered_arange(max):
+def buffered_arange(max, device="cpu"):
if not hasattr(buffered_arange, "buf"):
- buffered_arange.buf = torch.LongTensor()
+ buffered_arange.buf = torch.LongTensor().to(device)
if max > buffered_arange.buf.numel():
buffered_arange.buf.resize_(max)
torch.arange(max, out=buffered_arange.buf)
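Note that `buffered_arange` caches its buffer on the function object, so the new `device` argument only takes effect on the first allocation:

```python
from fairseq.utils import buffered_arange

idx = buffered_arange(5)  # tensor([0, 1, 2, 3, 4]) on CPU
# buffered_arange(5, device="cuda") would place the cached buffer on GPU,
# but only if called before the buffer has been allocated elsewhere
```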
diff --git a/fairseq_cli/hydra_validate.py b/fairseq_cli/hydra_validate.py
new file mode 100644
index 000000000..cb6f7612d
--- /dev/null
+++ b/fairseq_cli/hydra_validate.py
@@ -0,0 +1,188 @@
+#!/usr/bin/env python3 -u
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import sys
+from itertools import chain
+
+import torch
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import OmegaConf, open_dict
+import hydra
+
+from fairseq import checkpoint_utils, utils
+from fairseq.dataclass.configs import FairseqConfig
+from fairseq.dataclass.initialize import add_defaults, hydra_init
+from fairseq.dataclass.utils import omegaconf_no_object_check
+from fairseq.distributed import utils as distributed_utils
+from fairseq.logging import metrics, progress_bar
+from fairseq.utils import reset_logging
+
+logging.basicConfig(
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
+ stream=sys.stdout,
+)
+logger = logging.getLogger("fairseq_cli.validate")
+
+
+@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
+def hydra_main(cfg: FairseqConfig) -> float:
+ return _hydra_main(cfg)
+
+
+def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
+ add_defaults(cfg)
+
+ if cfg.common.reset_logging:
+ reset_logging() # Hydra hijacks logging, fix that
+ else:
+ # check if directly called or called through hydra_main
+ if HydraConfig.initialized():
+ with open_dict(cfg):
+ # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
+ cfg.job_logging_cfg = OmegaConf.to_container(
+ HydraConfig.get().job_logging, resolve=True
+ )
+
+ with omegaconf_no_object_check():
+ cfg = OmegaConf.create(
+ OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
+ )
+ OmegaConf.set_struct(cfg, True)
+
+ assert (
+ cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
+ ), "Must specify batch size either with --max-tokens or --batch-size"
+
+ distributed_utils.call_main(cfg, validate, **kwargs)
+
+
+def validate(cfg):
+ utils.import_user_module(cfg.common)
+
+ use_fp16 = cfg.common.fp16
+ use_cuda = torch.cuda.is_available() and not cfg.common.cpu
+
+ if use_cuda:
+ torch.cuda.set_device(cfg.distributed_training.device_id)
+
+ if cfg.distributed_training.distributed_world_size > 1:
+ data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
+ data_parallel_rank = distributed_utils.get_data_parallel_rank()
+ else:
+ data_parallel_world_size = 1
+ data_parallel_rank = 0
+
+ overrides = {"task": {"data": cfg.task.data}}
+
+ # Load ensemble
+ logger.info("loading model(s) from {}".format(cfg.common_eval.path))
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+ [cfg.common_eval.path],
+ arg_overrides=overrides,
+ suffix=cfg.checkpoint.checkpoint_suffix,
+ )
+ model = models[0]
+
+ # Move models to GPU
+ for model in models:
+ model.eval()
+ if use_fp16:
+ model.half()
+ if use_cuda:
+ model.cuda()
+
+ # Print args
+ logger.info(saved_cfg)
+
+ # Build criterion
+ criterion = task.build_criterion(saved_cfg.criterion, from_checkpoint=True)
+ criterion.eval()
+
+ for subset in cfg.dataset.valid_subset.split(","):
+ try:
+ task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task)
+ dataset = task.dataset(subset)
+ except KeyError:
+ raise Exception("Cannot find dataset: " + subset)
+
+ # Initialize data iterator
+ itr = task.get_batch_iterator(
+ dataset=dataset,
+ max_tokens=cfg.dataset.max_tokens,
+ max_sentences=cfg.dataset.batch_size,
+ max_positions=utils.resolve_max_positions(
+ task.max_positions(),
+ *[m.max_positions() for m in models],
+ ),
+ ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
+ required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
+ seed=cfg.common.seed,
+ num_shards=data_parallel_world_size,
+ shard_id=data_parallel_rank,
+ num_workers=cfg.dataset.num_workers,
+ data_buffer_size=cfg.dataset.data_buffer_size,
+ ).next_epoch_itr(shuffle=False)
+ progress = progress_bar.progress_bar(
+ itr,
+ log_format=cfg.common.log_format,
+ log_interval=cfg.common.log_interval,
+ prefix=f"valid on '{subset}' subset",
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
+ )
+
+ def apply_half(t):
+ if t.dtype is torch.float32:
+ return t.to(dtype=torch.half)
+ return t
+
+ log_outputs = []
+ for i, sample in enumerate(progress):
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
+
+ if use_fp16:
+ sample = utils.apply_to_sample(apply_half, sample)
+
+ _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
+ with metrics.aggregate() as agg:
+ task.reduce_metrics([log_output], criterion)
+ progress.log(agg.get_smoothed_values(), step=i)
+ # progress.log(log_output, step=i) from vision
+ log_outputs.append(log_output)
+
+ if data_parallel_world_size > 1:
+ log_outputs = distributed_utils.all_gather_list(
+ log_outputs,
+ max_size=cfg.common.all_gather_list_size,
+ group=distributed_utils.get_data_parallel_group(),
+ )
+ log_outputs = list(chain.from_iterable(log_outputs))
+
+ with metrics.aggregate() as agg:
+ task.reduce_metrics(log_outputs, criterion)
+ log_output = agg.get_smoothed_values()
+
+ progress.print(log_output, tag=subset, step=i)
+
+
+def cli_main():
+ try:
+ from hydra._internal.utils import get_args
+
+ cfg_name = get_args().config_name or "config"
+ except Exception:
+ logger.warning("Failed to get config name from hydra args")
+ cfg_name = "config"
+
+ hydra_init(cfg_name)
+ hydra_main()
+
+
+if __name__ == "__main__":
+ cli_main()
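A hedged invocation example for the new validation entry point (config name and paths are placeholders; any fairseq override such as `dataset.batch_size` can be appended):

```shell script
$ python fairseq_cli/hydra_validate.py --config-dir /path/to/config \
--config-name your_config common_eval.path=/path/to/checkpoint.pt \
task.data=/path/to/data dataset.batch_size=8
```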
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py
index 376bd1d03..f771bff65 100644
--- a/fairseq_cli/train.py
+++ b/fairseq_cli/train.py
@@ -125,12 +125,13 @@ def main(cfg: FairseqConfig) -> None:
# Load valid dataset (we load training data below, based on the latest checkpoint)
# We load the valid dataset AFTER building the model
- data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
- if cfg.dataset.combine_valid_subsets:
- task.load_dataset("valid", combine=True, epoch=1)
- else:
- for valid_sub_split in cfg.dataset.valid_subset.split(","):
- task.load_dataset(valid_sub_split, combine=False, epoch=1)
+ if not cfg.dataset.disable_validation:
+ data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
+ if cfg.dataset.combine_valid_subsets:
+ task.load_dataset("valid", combine=True, epoch=1)
+ else:
+ for valid_sub_split in cfg.dataset.valid_subset.split(","):
+ task.load_dataset(valid_sub_split, combine=False, epoch=1)
# (optionally) Configure quantization
if cfg.common.quantization_config_path is not None:
@@ -175,6 +176,20 @@ def main(cfg: FairseqConfig) -> None:
max_epoch = cfg.optimization.max_epoch or math.inf
lr = trainer.get_lr()
+ # TODO: a dry run on validation set to pin the memory
+ valid_subsets = cfg.dataset.valid_subset.split(",")
+ if not cfg.dataset.disable_validation:
+ for subset in valid_subsets:
+ logger.info('begin dry-run validation on "{}" subset'.format(subset))
+ itr = trainer.get_valid_iterator(subset).next_epoch_itr(
+ shuffle=False, set_dataset_epoch=False # use a fixed valid set
+ )
+ if cfg.common.tpu:
+ itr = utils.tpu_data_loader(itr)
+ for _ in itr:
+ pass
+ # TODO: end of dry run section
+
train_meter = meters.StopwatchMeter()
train_meter.start()
while epoch_itr.next_epoch_idx <= max_epoch:
@@ -424,9 +439,11 @@ def validate_and_save(
# Save checkpoint
if do_save or should_stop:
- checkpoint_utils.save_checkpoint(
+ cp_path = checkpoint_utils.save_checkpoint(
cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
)
+ if cp_path is not None and hasattr(task, "post_save"):
+ task.post_save(cp_path, num_updates)
return valid_losses, should_stop
diff --git a/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/__init__.py b/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/__init__.py
new file mode 100644
index 000000000..4884f5bdc
--- /dev/null
+++ b/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+__version__ = "0.1"
diff --git a/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/config.py b/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/config.py
new file mode 100644
index 000000000..91926c4ab
--- /dev/null
+++ b/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/config.py
@@ -0,0 +1,23 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+from dataclasses import dataclass, field
+
+from hydra.core.config_store import ConfigStore
+
+from hydra_plugins.hydra_submitit_launcher.config import SlurmQueueConf
+
+
+@dataclass
+class DependencySubmititConf(SlurmQueueConf):
+ """Slurm configuration overrides and specific parameters"""
+
+ _target_: str = (
+ "hydra_plugins.dependency_submitit_launcher.launcher.DependencySubmititLauncher"
+ )
+
+
+ConfigStore.instance().store(
+ group="hydra/launcher",
+ name="dependency_submitit_slurm",
+ node=DependencySubmititConf(),
+ provider="dependency_submitit_slurm",
+)
diff --git a/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/launcher.py b/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/launcher.py
new file mode 100644
index 000000000..b3fcf79e1
--- /dev/null
+++ b/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/launcher.py
@@ -0,0 +1,121 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import logging
+import os
+import subprocess
+from pathlib import Path
+from typing import Any, List, Sequence
+
+from hydra.core.singleton import Singleton
+from hydra.core.utils import JobReturn, filter_overrides
+from omegaconf import OmegaConf
+
+log = logging.getLogger(__name__)
+
+from .config import DependencySubmititConf
+from hydra_plugins.hydra_submitit_launcher.submitit_launcher import BaseSubmititLauncher
+
+
+class DependencySubmititLauncher(BaseSubmititLauncher):
+ _EXECUTOR = "slurm"
+
+ def launch(
+ self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int
+ ) -> Sequence[JobReturn]:
+
+ # lazy import to ensure plugin discovery remains fast
+ import submitit
+
+ assert self.config is not None
+
+ num_jobs = len(job_overrides)
+ assert num_jobs > 0
+
+ next_script = None
+
+ for jo in job_overrides:
+ if next_script is None:
+ for item in jo:
+ if "next_script=" in item:
+ next_script = item
+ break
+ assert (
+ next_script is not None
+ ), "job overrides must contain +next_script=path/to/next/script"
+ jo.remove(next_script)
+
+ idx = next_script.find("=")
+ next_script = next_script[idx + 1 :]
+
+ params = self.params
+ # build executor
+ init_params = {"folder": self.params["submitit_folder"]}
+ specific_init_keys = {"max_num_timeout"}
+
+ init_params.update(
+ **{
+ f"{self._EXECUTOR}_{x}": y
+ for x, y in params.items()
+ if x in specific_init_keys
+ }
+ )
+ init_keys = specific_init_keys | {"submitit_folder"}
+ executor = submitit.AutoExecutor(cluster=self._EXECUTOR, **init_params)
+
+ # specify resources/parameters
+ baseparams = set(OmegaConf.structured(DependencySubmititConf).keys())
+ params = {
+ x if x in baseparams else f"{self._EXECUTOR}_{x}": y
+ for x, y in params.items()
+ if x not in init_keys
+ }
+ executor.update_parameters(**params)
+
+ log.info(
+ f"Submitit '{self._EXECUTOR}' sweep output dir : "
+ f"{self.config.hydra.sweep.dir}"
+ )
+ sweep_dir = Path(str(self.config.hydra.sweep.dir))
+ sweep_dir.mkdir(parents=True, exist_ok=True)
+ if "mode" in self.config.hydra.sweep:
+ mode = int(str(self.config.hydra.sweep.mode), 8)
+ os.chmod(sweep_dir, mode=mode)
+
+ job_params: List[Any] = []
+ for idx, overrides in enumerate(job_overrides):
+ idx = initial_job_idx + idx
+ lst = " ".join(filter_overrides(overrides))
+ log.info(f"\t#{idx} : {lst}")
+ job_params.append(
+ (
+ list(overrides),
+ "hydra.sweep.dir",
+ idx,
+ f"job_id_for_{idx}",
+ Singleton.get_state(),
+ )
+ )
+
+ jobs = executor.map_array(self, *zip(*job_params))
+
+ for j, jp in zip(jobs, job_params):
+ job_id = str(j.job_id)
+ task_id = "0" if "_" not in job_id else job_id.split("_")[1]
+ sweep_config = self.config_loader.load_sweep_config(self.config, jp[0])
+ dir = sweep_config.hydra.sweep.dir
+
+ dir = (
+ dir.replace("[", "")
+ .replace("]", "")
+ .replace("{", "")
+ .replace("}", "")
+ .replace(",", "_")
+ .replace("'", "")
+ .replace('"', "")
+ )
+
+ subprocess.call(
+ [next_script, job_id, task_id, dir],
+ shell=False,
+ )
+
+ return [j.results()[0] for j in jobs]
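The launcher registers as `dependency_submitit_slurm` and requires a `+next_script=` override, which it strips from the job overrides and invokes after submission with the job id, task id, and sweep dir as arguments. A hedged multirun example (config name and paths are placeholders):

```shell script
$ python fairseq_cli/hydra_train.py -m --config-dir /path/to/config \
--config-name your_config task.data=/path/to/data \
hydra/launcher=dependency_submitit_slurm +next_script=/path/to/eval.sh
```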
diff --git a/hydra_plugins/dependency_submitit_launcher/setup.py b/hydra_plugins/dependency_submitit_launcher/setup.py
new file mode 100644
index 000000000..bf795462b
--- /dev/null
+++ b/hydra_plugins/dependency_submitit_launcher/setup.py
@@ -0,0 +1,29 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# type: ignore
+from pathlib import Path
+
+from read_version import read_version
+from setuptools import find_namespace_packages, setup
+
+setup(
+ name="dependency-submitit-launcher",
+ version=read_version("hydra_plugins/dependency_submitit_launcher", "__init__.py"),
+ author="Alexei Baevski",
+ author_email="abaevski@fb.com",
+ description="Dependency-supporting Submitit Launcher for Hydra apps",
+ packages=find_namespace_packages(include=["hydra_plugins.*"]),
+ classifiers=[
+ "License :: OSI Approved :: MIT License",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Operating System :: MacOS",
+ "Operating System :: POSIX :: Linux",
+ "Development Status :: 4 - Beta",
+ ],
+ install_requires=[
+ "hydra-core>=1.0.4",
+ "submitit>=1.0.0",
+ ],
+ include_package_data=True,
+)
diff --git a/setup.py b/setup.py
index 8a9b2f977..dae06080c 100644
--- a/setup.py
+++ b/setup.py
@@ -184,10 +184,12 @@ def do_setup(package_data):
"numpy>=1.21.3",
"regex",
"sacrebleu>=1.4.12",
- "torch>=1.10",
+ "torch>=1.13",
"tqdm",
"bitarray",
"torchaudio>=0.8.0",
+ "scikit-learn",
+ "packaging",
],
extras_require={
"dev": ["flake8", "pytest", "black==22.3.0"],
diff --git a/tests/test_binaries.py b/tests/test_binaries.py
index 1ab92f5f7..41d9210e7 100644
--- a/tests/test_binaries.py
+++ b/tests/test_binaries.py
@@ -11,6 +11,7 @@ import random
import sys
import tempfile
import unittest
+from packaging import version
from io import StringIO
from typing import Dict, List
@@ -625,6 +626,10 @@ class TestTranslation(unittest.TestCase):
)
generate_main(data_dir, extra_flags=[])
+ @unittest.skipIf(
+ version.parse(torch.__version__) > version.parse("1.8"),
+ "skip for latest torch versions",
+ )
def test_transformer_pointer_generator(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory(
@@ -1827,6 +1832,8 @@ def train_masked_lm(data_dir, arch, extra_flags=None):
"masked_lm",
"--batch-size",
"500",
+ "--required-batch-size-multiple",
+ "1",
"--save-dir",
data_dir,
"--max-epoch",
diff --git a/tests/test_ema.py b/tests/test_ema.py
index 847316ff4..bd2cf2c78 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -6,6 +6,7 @@
import unittest
from copy import deepcopy
from dataclasses import dataclass
+import pytest
from typing import Optional
from unittest.mock import patch
@@ -160,8 +161,7 @@ class TestEMA(unittest.TestCase):
self._test_ema_start_update(updates=1)
def test_ema_fp32(self):
- # CPU no longer supports Linear in half precision
- dtype = torch.half if torch.cuda.is_available() else torch.float
+ dtype = torch.float
model = DummyModule().to(dtype)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
@@ -213,12 +213,12 @@ class TestEMA(unittest.TestCase):
).to(dtype),
)
+ @pytest.mark.skipif(
+ not torch.cuda.is_available(),
+ reason="CPU no longer supports Linear in half precision",
+ )
def test_ema_fp16(self):
- # CPU no longer supports Linear in half precision
- if not torch.cuda.is_available():
- return
-
- model = DummyModule().half()
+ model = DummyModule().cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
state = deepcopy(model.state_dict())
config = EMAConfig(ema_fp32=False)
@@ -227,7 +227,7 @@ class TestEMA(unittest.TestCase):
# Since fp32 params is not used, it should be of size 0
self.assertEqual(len(ema.fp32_params), 0)
- x = torch.randn(32)
+ x = torch.randn(32).cuda()
y = model(x.half())
loss = y.sum()
loss.backward()
diff --git a/tests/test_multihead_attention.py b/tests/test_multihead_attention.py
index 5301318c2..4a0b430b6 100644
--- a/tests/test_multihead_attention.py
+++ b/tests/test_multihead_attention.py
@@ -101,6 +101,7 @@ def test_mask_for_xformers():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="blocksparse requires gpu")
+@pytest.mark.skip(reason="not part of latest xformers")
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("add_zero_attn", [False])
@pytest.mark.parametrize("batch_size", [20])
diff --git a/tests/test_valid_subset_checks.py b/tests/test_valid_subset_checks.py
index 3e9191bda..c39fb8982 100644
--- a/tests/test_valid_subset_checks.py
+++ b/tests/test_valid_subset_checks.py
@@ -126,13 +126,18 @@ class TestCombineValidSubsets(unittest.TestCase):
return [x.message for x in logs.records]
def test_combined(self):
- flags = ["--combine-valid-subsets"]
+ flags = ["--combine-valid-subsets", "--required-batch-size-multiple", "1"]
logs = self._train(flags)
assert any(["valid1" in x for x in logs]) # loaded 100 examples from valid1
assert not any(["valid1_ppl" in x for x in logs]) # metrics are combined
def test_subsets(self):
- flags = ["--valid-subset", "valid,valid1"]
+ flags = [
+ "--valid-subset",
+ "valid,valid1",
+ "--required-batch-size-multiple",
+ "1",
+ ]
logs = self._train(flags)
assert any(["valid_ppl" in x for x in logs]) # loaded 100 examples from valid1
assert any(["valid1_ppl" in x for x in logs]) # metrics are combined