data2vec v2.0 (#4903)

data2vec 2.0
Co-authored-by: Arun Babu <arbabu@fb.com>
Co-authored-by: Wei-Ning Hsu <wnhsu@csail.mit.edu>
Alexei Baevski 2022-12-12 08:53:56 -08:00 committed by GitHub
parent 0f33ccf7cf
commit d871f6169f
236 changed files with 17327 additions and 522 deletions

@ -1,3 +1,125 @@
# data2vec 2.0
data2vec 2.0 improves the training efficiency of the original data2vec algorithm. We make the following changes for efficiency: we forward only the unmasked timesteps through the encoder, we use a convolutional decoder, and we use multi-masking to amortize the compute overhead of the teacher model. You can find details in [Efficient Self-supervised Learning with Contextualized Target Representations for Vision, Speech and Language](https://ai.facebook.com/research/xyz)
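
The core ideas can be illustrated with a short, self-contained sketch (this is *not* the fairseq implementation, and the EMA teacher update is omitted): the teacher encodes the full input once, the student encodes only the unmasked timesteps of several masked views of the same sample, and a small convolutional decoder predicts the teacher targets at the masked positions. All module and function names below are hypothetical.

```python
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Stand-in for the transformer encoder (illustrative only)."""
    def __init__(self, dim=16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):  # x: (B, T, C)
        return self.proj(x)

class TinyConvDecoder(nn.Module):
    """Stand-in for the lightweight convolutional decoder."""
    def __init__(self, dim=16):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1, groups=4)

    def forward(self, h, mask):
        # Scatter the visible-timestep features back to full length (zeros at masked
        # slots), then let the convolution predict masked positions from neighbours.
        full = h.new_zeros(mask.size(0), mask.size(1), h.size(-1))
        full[~mask] = h.reshape(-1, h.size(-1))
        return self.conv(full.transpose(1, 2)).transpose(1, 2)

def sample_mask(batch, length, n_masked):
    # Mask the same number of timesteps in every sample so unmasked tokens batch cleanly.
    idx = torch.rand(batch, length).argsort(dim=1)[:, :n_masked]
    return torch.zeros(batch, length, dtype=torch.bool).scatter_(1, idx, True)

def d2v2_step(student, teacher, decoder, x, masks):
    with torch.no_grad():
        targets = teacher(x)                              # one teacher forward, reused for every mask
    loss = 0.0
    for mask in masks:                                    # multi-masking: several masked views per sample
        kept = x[~mask].view(x.size(0), -1, x.size(-1))   # encode only the unmasked timesteps
        pred = decoder(student(kept), mask)               # conv decoder fills in the masked positions
        loss = loss + (pred[mask] - targets[mask]).pow(2).mean()
    return loss / len(masks)

x = torch.randn(2, 10, 16)
masks = [sample_mask(2, 10, 5) for _ in range(4)]         # several masks amortize the teacher forward
print(d2v2_step(TinyEncoder(), TinyEncoder(), TinyConvDecoder(), x, masks).item())
```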
## Pretrained and finetuned models
### Vision
| Model | Finetuning split | Link
|---|---|---
data2vec ViT-B | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_imagenet.pt)
data2vec ViT-B | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_imagenet_ft.pt)
data2vec ViT-L | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_imagenet.pt)
data2vec ViT-L | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_imagenet_ft.pt)
data2vec ViT-H | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/huge_imagenet.pt)
data2vec ViT-H | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/huge_imagenet_ft.pt)
The vision models only are licensed under CC-BY-NC.
### Speech
| Model | Finetuning split | Dataset | Link
|---|---|---|---
data2vec Base | No fine-tuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_libri.pt)
data2vec Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_libri_960h.pt)
data2vec Large | No fine-tuning | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_vox.pt)
data2vec Large | 960 hours | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_vox_960h.pt)
### NLP
| Model | Finetuning split | Dataset | Link
|---|---|---|---
data2vec Base | No fine-tuning | Books + Wiki | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/nlp_base.pt)
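
For a quick sanity check, the checkpoints above can be loaded with fairseq's standard checkpoint utilities. A minimal sketch (the path is a placeholder; depending on the checkpoint, the data2vec code under `examples/data2vec` must be importable when the model is constructed, e.g. by running from the fairseq repository root):

```python
from fairseq import checkpoint_utils

# Placeholder path to one of the checkpoints from the tables above.
ckpt_path = "/path/to/base_libri.pt"

# Loads the model(s), the saved training config, and the pretraining task.
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
model = models[0].eval()
print(type(model).__name__, sum(p.numel() for p in model.parameters()))
```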
[//]: # (## Data Preparation)
[//]: # ()
[//]: # (### Vision)
[//]: # (add details)
[//]: # (### Speech)
[//]: # (add details)
[//]: # ()
[//]: # (### NLP)
[//]: # (add details)
## Commands to train different models using data2vec 2.0
### Vision
Commands to pretrain different model configurations
```shell script
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
--config-name base_images_only_task task.data=/path/to/dir
```
```shell script
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
--config-name large_images_only_task task.data=/path/to/dir
```
```shell script
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
--config-name huge_images14_only_task task.data=/path/to/dir
```
Commands to finetune different model configurations
```shell script
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \
--config-name mae_imagenet_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model
```
```shell script
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \
--config-name mae_imagenet_large_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model
```
```shell script
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \
--config-name mae_imagenet_huge_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model
```
### Speech
```shell script
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
--config-name base_audio_only_task task.data=/path/to/manifests
```
```shell script
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
--config-name large_audio_only_task task.data=/path/to/manifests
```
Finetuning:
```shell script
$ python fairseq_cli/hydra_train.py -m --config-dir examples/wav2vec/config/finetuning --config-name vox_10h \
task.data=/path/to/manifests model.w2v_path=/path/to/pretrained/model common.user_dir=examples/data2vec
```
Replace `vox_10h` with the appropriate config for your model and fine-tuning split.
See `examples/wav2vec/config/finetuning` for all available configs.
### NLP
Commands to pretrain
```shell script
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
--config-name base_text_only_task task.data=/path/to/file
```
Commands to fine-tune all GLUE tasks
```shell script
$ task=cola # choose from [cola|qnli|mrpc|rte|sst_2|mnli|qqp|sts_b]
$ lr=1e-5 # sweep [1e-5|2e-5|4e-5|6e-5] for each task
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2/text_finetuning \
--config-name $task task.data=/path/to/file model.model_path=/path/to/pretrained/model "optimization.lr=[${lr}]"
```
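
The learning-rate sweep above can also be scripted locally; a small launcher sketch (paths are placeholders, and runs are sequential rather than a parallel Slurm sweep):

```python
import subprocess

task = "cola"                              # choose from: cola qnli mrpc rte sst_2 mnli qqp sts_b
data = "/path/to/file"                     # placeholder paths, as in the command above
model_path = "/path/to/pretrained/model"

for lr in ["1e-5", "2e-5", "4e-5", "6e-5"]:
    subprocess.run(
        [
            "python", "fairseq_cli/hydra_train.py", "-m",
            "--config-dir", "examples/data2vec/config/v2/text_finetuning",
            "--config-name", task,
            f"task.data={data}",
            f"model.model_path={model_path}",
            f"optimization.lr=[{lr}]",
        ],
        check=True,
    )
```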
# data2vec

@ -0,0 +1,70 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
all_gather_list_size: 70000
tensorboard_logdir: tb
min_loss_scale: 1e-6
checkpoint:
save_interval: 1
no_epoch_checkpoints: true
best_checkpoint_metric: mAP
maximize_best_checkpoint_metric: true
task:
_name: audio_classification
data: ???
normalize: true
labels: lbl
dataset:
num_workers: 6
max_tokens: 2560000
skip_invalid_size_inputs_valid_test: true
valid_subset: eval
validate_interval: 5
distributed_training:
ddp_backend: legacy_ddp
distributed_world_size: 8
criterion:
_name: model
can_sum: false
log_keys:
- _predictions
- _targets
optimization:
max_update: 30000
lr: [0.00006] # scratch 53-5
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-08
lr_scheduler:
_name: cosine
warmup_updates: 5000
model:
_name: audio_classification
model_path: ???
apply_mask: true
mask_prob: 0.6
mask_length: 5 # scratch 1
mask_channel_prob: 0
mask_channel_length: 64
layerdrop: 0.1
dropout: 0.1
activation_dropout: 0.1
attention_dropout: 0.2
feature_grad_mult: 0 # scratch 1
label_mixup: true
source_mixup: 0.5
prediction_mode: lin_softmax # scratch average_sigmoid

@ -0,0 +1,35 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 450
nodes: 1
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

@ -0,0 +1,35 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 1
tasks_per_node: 1
mem_gb: 100
nodes: 1
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb
max_num_timeout: 30

@ -0,0 +1,35 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 450
nodes: 2
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

@ -0,0 +1,91 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
min_loss_scale: 1e-6
user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
checkpoint:
save_interval: 1
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
task:
_name: audio_pretraining
data: /private/home/abaevski/data/audioset
max_sample_size: 320000
min_sample_size: 32000
normalize: true
dataset:
num_workers: 6
max_tokens: 3400000
skip_invalid_size_inputs_valid_test: true
validate_interval: 5
required_batch_size_multiple: 1
disable_validation: true
distributed_training:
distributed_world_size: 24
ddp_backend: legacy_ddp
criterion:
_name: model
log_keys:
- ema_decay
- target_var
- pred_var
# - avg_self_attn
# - weights
optimization:
max_update: 200000
lr: [0.0005]
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: cosine
warmup_updates: 10000
model:
_name: data2vec_audio
extractor_mode: layer_norm
encoder_layerdrop: 0.05
dropout_input: 0.0
dropout_features: 0.0
feature_grad_mult: 1.0
encoder_embed_dim: 768
mask_prob: 0.65
mask_length: 10
loss_beta: 0
loss_scale: null
instance_norm_target_layer: true
layer_norm_targets: true
average_top_k_layers: 12
self_attn_norm_type: deepnorm
final_norm_type: deepnorm
pos_conv_depth: 5
conv_pos: 95
ema_decay: 0.999
ema_end_decay: 0.9999
ema_anneal_end_step: 30000
ema_transformer_only: true
ema_layers_only: false
require_same_masks: true
mask_dropout: 0

@ -0,0 +1,15 @@
# @package _global_
hydra:
sweep:
dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
distributed_training:
distributed_world_size: 1
nprocs_per_node: 1
distributed_port: -1
common:
log_interval: 1
dataset:
num_workers: 0

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 80
gpus_per_node: 8
tasks_per_node: 1
mem_gb: 450
nodes: 1
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 80
gpus_per_node: 8
tasks_per_node: 1
mem_gb: 0
nodes: 1
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 450
nodes: 2
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- task.post_save_script
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 2
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 80
gpus_per_node: 8
tasks_per_node: 1
mem_gb: 450
nodes: 3
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 450
nodes: 4
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- task.post_save_script
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 4
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 6
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 8
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

@ -0,0 +1,15 @@
# @package _global_
hydra:
sweep:
dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
distributed_training:
distributed_world_size: 1
nprocs_per_node: 1
distributed_port: -1
common:
log_interval: 1
dataset:
num_workers: 0

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: '_'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}/submitit
timeout_min: 4320
cpus_per_task: 80
gpus_per_node: 8
tasks_per_node: 1
mem_gb: 0
nodes: 1
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec
max_num_timeout: 30
exclude: a100-st-p4d24xlarge-471

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 80
gpus_per_node: 8
tasks_per_node: 1
mem_gb: 450
nodes: 2
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: '_'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}/submitit
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 2
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec
max_num_timeout: 30
exclude: a100-st-p4d24xlarge-471

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 450
nodes: 3
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 450
nodes: 4
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

@ -0,0 +1,41 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: '_'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}/submitit
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 4
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec
max_num_timeout: 30
exclude: a100-st-p4d24xlarge-471
distributed_training:
distributed_world_size: 32
ddp_backend: legacy_ddp

@ -0,0 +1,41 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: '_'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}/submitit
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 8
name: pt
partition: wav2vec
max_num_timeout: 30
exclude: a100-st-p4d24xlarge-471
distributed_training:
distributed_world_size: 64
ddp_backend: legacy_ddp

@ -0,0 +1,113 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
min_loss_scale: 1e-6
fp16_no_flatten_grads: false
user_dir: ${env:PWD}/examples/data2vec
checkpoint:
save_interval: 1
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
task:
_name: audio_pretraining
data: /private/home/abaevski/data/librispeech/full
max_sample_size: 320000
min_sample_size: 32000
normalize: true
precompute_mask_config: {}
dataset:
num_workers: 6
max_tokens: 1000000
skip_invalid_size_inputs_valid_test: true
validate_interval: 5
required_batch_size_multiple: 1
disable_validation: true
distributed_training:
distributed_world_size: 8
ddp_backend: legacy_ddp
criterion:
_name: model
log_keys:
- ema_decay
- target_var
- pred_var
- model_norm
- ema_norm
- masked_pct
optimization:
max_update: 400000
lr: [0.00075]
debug_param_names: true
optimizer:
_name: adam
adam_betas: [ 0.9,0.98 ]
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: cosine
warmup_updates: 8000
model:
_name: data2vec_multi
loss_beta: 0
loss_scale: null
depth: 12
embed_dim: 768
clone_batch: 8
ema_decay: 0.999
ema_end_decay: 0.99999
ema_anneal_end_step: 75000
ema_encoder_only: false
average_top_k_layers: 8
instance_norm_target_layer: true
layer_norm_target_layer: false
layer_norm_targets: false
layerdrop: 0.05
norm_eps: 1e-5
supported_modality: AUDIO
modalities:
audio:
feature_encoder_spec: '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]'
conv_pos_depth: 5
conv_pos_width: 95
conv_pos_groups: 16
prenet_depth: 0
mask_prob: 0.5
mask_prob_adjust: 0.05
inverse_mask: false
mask_length: 5
mask_noise_std: 0.01
mask_dropout: 0
add_masks: false
ema_local_encoder: false
use_alibi_encoder: true
prenet_layerdrop: 0.05
prenet_dropout: 0.1
learned_alibi_scale: true
learned_alibi_scale_per_head: true
decoder:
input_dropout: 0.1
decoder_dim: 384
decoder_groups: 16
decoder_kernel: 7
decoder_layers: 4

@ -0,0 +1,116 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
min_loss_scale: 1e-6
fp16_no_flatten_grads: true
user_dir: ${env:PWD}/examples/data2vec
checkpoint:
save_interval: 5
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
task:
_name: mae_image_pretraining
data: /datasets01/imagenet_full_size/061417/
rebuild_batches: true
local_cache_path: /scratch/cache_abaevski/imagenet
key: source
precompute_mask_config: {}
dataset:
num_workers: 10
batch_size: 16
skip_invalid_size_inputs_valid_test: true
required_batch_size_multiple: 1
disable_validation: true
distributed_training:
distributed_world_size: 16
ddp_backend: c10d
criterion:
_name: model
log_keys:
- ema_decay
- target_var
- pred_var
- model_norm
- ema_norm
- masked_pct
optimization:
max_update: 375300
lr: [ 0.001 ]
debug_param_names: true
clip_norm: 4
optimizer:
_name: composite
dynamic_groups: true
groups:
default:
lr_float: 1e-3
optimizer:
_name: adam
adam_betas: [0.9,0.95]
weight_decay: 0.05
lr_scheduler:
_name: cosine
warmup_updates: 50040
lr_scheduler: pass_through
model:
_name: data2vec_multi
ema_decay: 0.9998
ema_end_decay: 0.99999
ema_anneal_end_step: 100000
instance_norm_target_layer: true
layer_norm_target_layer: false
layer_norm_targets: true
end_of_block_targets: false
depth: 10
average_top_k_layers: 10
clone_batch: 16
norm_eps: 1e-6
min_target_var: 0
min_pred_var: 0
encoder_dropout: 0
post_mlp_drop: 0
attention_dropout: 0
activation_dropout: 0
supported_modality: IMAGE
cls_loss: 0.01
ema_encoder_only: false
modalities:
image:
inverse_mask: true
mask_prob: 0.8
mask_prob_adjust: 0.07
mask_length: 3
mask_noise_std: 0.01
prenet_depth: 2
ema_local_encoder: true
num_extra_tokens: 1
init_extra_token_zero: false
use_alibi_encoder: false
decoder:
decoder_dim: 768
decoder_groups: 16
decoder_kernel: 3
decoder_layers: 6
input_dropout: 0

@ -0,0 +1,112 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
fp16_no_flatten_grads: true
user_dir: ${env:PWD}/examples/data2vec
checkpoint:
no_epoch_checkpoints: true
save_interval_updates: 50000
keep_interval_updates: 1
distributed_training:
distributed_world_size: 16
ddp_backend: legacy_ddp
task:
_name: masked_lm
data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin
sample_break_mode: none
tokens_per_sample: 512
include_target_tokens: true
random_token_prob: 0
leave_unmasked_prob: 0
include_index: True
skip_masking: True
d2v2_multi: True
criterion:
_name: model
log_keys:
- ema_decay
- target_var
- pred_var
- model_norm
- ema_norm
- masked_pct
dataset:
batch_size: 4
ignore_unused_valid_subsets: true
skip_invalid_size_inputs_valid_test: true
disable_validation: true
optimization:
clip_norm: 1
lr: [0.0002]
max_update: 1000000
update_freq: [1]
optimizer:
_name: composite
dynamic_groups: true
groups:
default:
lr_float: 0.0002
optimizer:
_name: adam
adam_betas: [0.9,0.98]
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: cosine
warmup_updates: 4000
lr_scheduler: pass_through
model:
_name: data2vec_multi
loss_beta: 0
loss_scale: 1
depth: 12
embed_dim: 768
clone_batch: 8
ema_decay: 0.9999
ema_end_decay: 0.99999
ema_anneal_end_step: 100000
ema_encoder_only: true
average_top_k_layers: 12
layer_norm_target_layer: false
instance_norm_target_layer: true
batch_norm_target_layer: false
instance_norm_targets: false
layer_norm_targets: false
layerdrop: 0
norm_eps: 1e-5
supported_modality: TEXT
modalities:
text:
mask_prob: 0.48
mask_length: 1
mask_noise_std: 0.01
prenet_depth: 0
decoder:
input_dropout: 0.1
decoder_dim: 768
decoder_groups: 1
decoder_kernel: 9
decoder_layers: 5
decoder_residual: false
projection_layers: 2
projection_ratio: 2.0

@ -0,0 +1,122 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
min_loss_scale: 1e-6
fp16_no_flatten_grads: true
user_dir: ${env:PWD}/examples/data2vec
checkpoint:
save_interval: 5
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
task:
_name: mae_image_pretraining
data: /datasets01/imagenet_full_size/061417/
rebuild_batches: true
local_cache_path: /scratch/cache_abaevski/imagenet
key: source
precompute_mask_config: {}
dataset:
num_workers: 10
batch_size: 8
skip_invalid_size_inputs_valid_test: true
required_batch_size_multiple: 1
disable_validation: true
distributed_training:
distributed_world_size: 32
ddp_backend: c10d
criterion:
_name: model
log_keys:
- ema_decay
- target_var
- pred_var
- model_norm
- ema_norm
- masked_pct
optimization:
max_update: 500000
lr: [ 0.0004 ]
debug_param_names: true
clip_norm: 4
optimizer:
_name: composite
dynamic_groups: true
groups:
default:
lr_float: 4e-4
optimizer:
_name: adam
adam_betas: [0.9,0.95]
weight_decay: 0.05
lr_scheduler:
_name: cosine
warmup_updates: 50040
lr_scheduler: pass_through
model:
_name: data2vec_multi
ema_decay: 0.9998
ema_end_decay: 1
ema_anneal_end_step: 300000
instance_norm_target_layer: true
layer_norm_target_layer: false
layer_norm_targets: true
end_of_block_targets: false
depth: 32
embed_dim: 1280
num_heads: 16
average_top_k_layers: 24
clone_batch: 16
norm_eps: 1e-6
min_target_var: 0
min_pred_var: 0
encoder_dropout: 0
post_mlp_drop: 0
attention_dropout: 0
activation_dropout: 0
supported_modality: IMAGE
cls_loss: 0.01
ema_encoder_only: false
modalities:
image:
patch_size: 14
inverse_mask: true
mask_prob: 0.75
mask_prob_adjust: 0.1
mask_length: 3
mask_noise_std: 0.01
prenet_depth: 0
ema_local_encoder: true
num_extra_tokens: 1
init_extra_token_zero: false
use_alibi_encoder: false
embed_dim: 1280
decoder:
decoder_dim: 1024
decoder_groups: 16
decoder_kernel: 5
decoder_layers: 3
final_layer_norm: false
input_dropout: 0

@ -0,0 +1,120 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
min_loss_scale: 1e-6
fp16_no_flatten_grads: true
user_dir: ${env:PWD}/examples/data2vec
checkpoint:
save_interval: 5
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
task:
_name: mae_image_pretraining
data: /datasets01/imagenet_full_size/061417/
rebuild_batches: true
local_cache_path: /scratch/cache_abaevski/imagenet
key: source
precompute_mask_config: {}
dataset:
num_workers: 10
batch_size: 8
skip_invalid_size_inputs_valid_test: true
required_batch_size_multiple: 1
disable_validation: true
distributed_training:
distributed_world_size: 16
ddp_backend: c10d
criterion:
_name: model
log_keys:
- ema_decay
- target_var
- pred_var
- model_norm
- ema_norm
- masked_pct
optimization:
max_update: 375300
lr: [ 0.0004 ]
debug_param_names: true
clip_norm: 4
optimizer:
_name: composite
dynamic_groups: true
groups:
default:
lr_float: 4e-4
optimizer:
_name: adam
adam_betas: [0.9,0.95]
weight_decay: 0.05
lr_scheduler:
_name: cosine
warmup_updates: 50040
lr_scheduler: pass_through
model:
_name: data2vec_multi
ema_decay: 0.9998
ema_end_decay: 0.99995
ema_anneal_end_step: 150000
instance_norm_target_layer: true
layer_norm_target_layer: false
layer_norm_targets: true
end_of_block_targets: false
depth: 32
embed_dim: 1280
num_heads: 16
average_top_k_layers: 24
clone_batch: 16
norm_eps: 1e-6
min_target_var: 0
min_pred_var: 0
encoder_dropout: 0
post_mlp_drop: 0
attention_dropout: 0
activation_dropout: 0
supported_modality: IMAGE
cls_loss: 0.01
ema_encoder_only: false
modalities:
image:
inverse_mask: true
mask_prob: 0.75
mask_prob_adjust: 0.1
mask_length: 3
mask_noise_std: 0.01
prenet_depth: 0
ema_local_encoder: true
num_extra_tokens: 1
init_extra_token_zero: false
use_alibi_encoder: false
embed_dim: 1280
decoder:
decoder_dim: 1024
decoder_groups: 16
decoder_kernel: 5
decoder_layers: 3
input_dropout: 0

@ -0,0 +1,122 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
min_loss_scale: 1e-6
fp16_no_flatten_grads: true
user_dir: ${env:PWD}/examples/data2vec
checkpoint:
save_interval: 1
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
task:
_name: audio_pretraining
data: /fsx-wav2vec/abaevski/data/librivox/no_silence
max_sample_size: 320000
min_sample_size: 32000
normalize: true
precompute_mask_config: {}
dataset:
num_workers: 8
max_tokens: 320000
skip_invalid_size_inputs_valid_test: true
validate_interval: 5
required_batch_size_multiple: 1
disable_validation: true
distributed_training:
distributed_world_size: 48
ddp_backend: c10d
criterion:
_name: model
log_keys:
- ema_decay
- target_var
- pred_var
- model_norm
- ema_norm
- masked_pct
optimization:
max_update: 600000
debug_param_names: true
clip_norm: 1
optimizer:
_name: composite
dynamic_groups: true
groups:
default:
lr_float: 0.0004
optimizer:
_name: adam
adam_betas: [0.9,0.98]
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: cosine
warmup_updates: 10000
lr_scheduler: pass_through
model:
_name: data2vec_multi
loss_beta: 0
loss_scale: null
depth: 16
embed_dim: 1024
num_heads: 16
clone_batch: 12
ema_decay: 0.9997
ema_end_decay: 1
ema_anneal_end_step: 300000
ema_encoder_only: false
average_top_k_layers: 16
instance_norm_target_layer: true
layer_norm_target_layer: false
layer_norm_targets: false
layerdrop: 0
norm_eps: 1e-5
supported_modality: AUDIO
modalities:
audio:
feature_encoder_spec: '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]'
conv_pos_depth: 5
conv_pos_width: 95
conv_pos_groups: 16
prenet_depth: 8
mask_prob: 0.55
mask_prob_adjust: 0.1
inverse_mask: false
mask_length: 5
mask_noise_std: 0.01
mask_dropout: 0
add_masks: false
ema_local_encoder: false
use_alibi_encoder: true
prenet_layerdrop: 0
prenet_dropout: 0.1
learned_alibi_scale: true
learned_alibi_scale_per_head: true
decoder:
input_dropout: 0.1
decoder_dim: 768
decoder_groups: 16
decoder_kernel: 7
decoder_layers: 4

@ -0,0 +1,120 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
min_loss_scale: 1e-6
fp16_no_flatten_grads: true
user_dir: ${env:PWD}/examples/data2vec
checkpoint:
save_interval: 5
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
task:
_name: mae_image_pretraining
data: /datasets01/imagenet_full_size/061417/
rebuild_batches: true
local_cache_path: /scratch/cache_abaevski/imagenet
key: source
precompute_mask_config: {}
dataset:
num_workers: 10
batch_size: 8
skip_invalid_size_inputs_valid_test: true
required_batch_size_multiple: 1
disable_validation: true
distributed_training:
distributed_world_size: 16
ddp_backend: c10d
criterion:
_name: model
log_keys:
- ema_decay
- target_var
- pred_var
- model_norm
- ema_norm
- masked_pct
optimization:
max_update: 375300
lr: [ 0.0004 ]
debug_param_names: true
clip_norm: 4
optimizer:
_name: composite
dynamic_groups: true
groups:
default:
lr_float: 4e-4
optimizer:
_name: adam
adam_betas: [0.9,0.95]
weight_decay: 0.05
lr_scheduler:
_name: cosine
warmup_updates: 50040
lr_scheduler: pass_through
model:
_name: data2vec_multi
ema_decay: 0.9998
ema_end_decay: 0.99999
ema_anneal_end_step: 150000
instance_norm_target_layer: true
layer_norm_target_layer: false
layer_norm_targets: true
end_of_block_targets: false
depth: 24
embed_dim: 1024
num_heads: 16
average_top_k_layers: 18
clone_batch: 16
norm_eps: 1e-6
min_target_var: 0
min_pred_var: 0
encoder_dropout: 0
post_mlp_drop: 0
attention_dropout: 0
activation_dropout: 0
supported_modality: IMAGE
cls_loss: 0.01
ema_encoder_only: false
modalities:
image:
inverse_mask: true
mask_prob: 0.75
mask_prob_adjust: 0.1
mask_length: 3
mask_noise_std: 0.01
prenet_depth: 0
ema_local_encoder: true
num_extra_tokens: 1
init_extra_token_zero: false
use_alibi_encoder: false
embed_dim: 1024
decoder:
decoder_dim: 1024
decoder_groups: 16
decoder_kernel: 5
decoder_layers: 3
input_dropout: 0

@ -0,0 +1,112 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
min_loss_scale: 1e-6
fp16_no_flatten_grads: true
user_dir: ${env:PWD}/examples/data2vec
checkpoint:
save_interval_updates: 50000
keep_interval_updates: 1
no_epoch_checkpoints: true
task:
_name: masked_lm
data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin
sample_break_mode: none
tokens_per_sample: 512
include_target_tokens: true
random_token_prob: 0
leave_unmasked_prob: 0
include_index: True
skip_masking: True
d2v2_multi: True
dataset:
batch_size: 2
ignore_unused_valid_subsets: true
skip_invalid_size_inputs_valid_test: true
disable_validation: true
distributed_training:
distributed_world_size: 32
ddp_backend: c10d
criterion:
_name: model
log_keys:
- ema_decay
- target_var
- pred_var
- model_norm
- ema_norm
- masked_pct
optimization:
max_update: 600000
clip_norm: 1
optimizer:
_name: composite
dynamic_groups: true
groups:
default:
lr_float: 0.0001
optimizer:
_name: adam
adam_betas: [0.9,0.98]
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: cosine
warmup_updates: 4000
lr_scheduler: pass_through
model:
_name: data2vec_multi
loss_beta: 0
loss_scale: 1
depth: 24
num_heads: 16
embed_dim: 1024
clone_batch: 8
ema_decay: 0.9999
ema_end_decay: 0.99999
ema_anneal_end_step: 100000
ema_encoder_only: true
average_top_k_layers: 24
layer_norm_target_layer: true
instance_norm_target_layer: false
batch_norm_target_layer: false
instance_norm_targets: true
layer_norm_targets: false
layerdrop: 0
norm_eps: 1e-5
supported_modality: TEXT
modalities:
text:
mask_prob: 0.5
mask_length: 1
mask_noise_std: 0.01
prenet_depth: 0
decoder:
input_dropout: 0.1
decoder_dim: 768
decoder_groups: 1
decoder_kernel: 9
decoder_layers: 5
decoder_residual: false
projection_layers: 2
projection_ratio: 2.0

@ -0,0 +1,123 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
fp16_no_flatten_grads: true
user_dir: ${env:PWD}/examples/data2vec
checkpoint:
no_epoch_checkpoints: true
save_interval_updates: 50000
keep_interval_updates: 1
distributed_training:
distributed_world_size: 32
ddp_backend: legacy_ddp
task:
_name: masked_lm
data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin
sample_break_mode: none
tokens_per_sample: 512
include_target_tokens: true
random_token_prob: 0
leave_unmasked_prob: 0
include_index: True
skip_masking: True
d2v2_multi: True
criterion:
_name: model
log_keys:
- ema_decay
- target_var
- pred_var
- model_norm
- ema_norm
- masked_pct
dataset:
batch_size: 2
ignore_unused_valid_subsets: true
skip_invalid_size_inputs_valid_test: true
disable_validation: true
optimization:
clip_norm: 1
lr: [3e-4]
max_update: 1000000
update_freq: [1]
optimizer:
_name: composite
groups:
default:
lr_float: 1e-4
optimizer:
_name: adam
adam_betas: [0.9,0.98]
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: cosine
warmup_updates: 4000
decoder:
lr_float: 1e-4
optimizer:
_name: adam
adam_betas: [0.9,0.98]
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: cosine
warmup_updates: 4000
lr_scheduler: pass_through
model:
_name: data2vec_multi
loss_beta: 4
loss_scale: 1
depth: 24
num_heads: 16
embed_dim: 1024
clone_batch: 8
ema_decay: 0.9999
ema_end_decay: 0.99999
ema_anneal_end_step: 100000
ema_encoder_only: true
average_top_k_layers: 24
layer_norm_target_layer: true
instance_norm_target_layer: false
batch_norm_target_layer: false
instance_norm_targets: true
layer_norm_targets: false
layerdrop: 0
norm_eps: 1e-5
supported_modality: TEXT
decoder_group: true
modalities:
text:
mask_prob: 0.5
mask_length: 1
mask_noise_std: 0.01
prenet_depth: 0
decoder:
input_dropout: 0.1
decoder_dim: 768
decoder_groups: 1
decoder_kernel: 9
decoder_layers: 5
decoder_residual: false
projection_layers: 2
projection_ratio: 2.0

@ -0,0 +1,15 @@
# @package _global_
hydra:
sweep:
dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
distributed_training:
distributed_world_size: 1
nprocs_per_node: 1
distributed_port: -1
common:
log_interval: 1
dataset:
num_workers: 0

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 80
gpus_per_node: 8
tasks_per_node: 1
mem_gb: 450
nodes: 1
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.local_cache_path
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 80
gpus_per_node: 8
tasks_per_node: 1
mem_gb: 0
nodes: 1
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 450
nodes: 2
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

@ -0,0 +1,39 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.local_cache_path
- task.data
- task.post_save_script
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
- model.model_path
sweep:
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 12
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 2
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec
max_num_timeout: 30

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 80
gpus_per_node: 8
tasks_per_node: 1
mem_gb: 450
nodes: 3
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 450
nodes: 4
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- task.post_save_script
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 12
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 4
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec
max_num_timeout: 30

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 12
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 6
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 450
nodes: 8
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 12
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 8
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

@ -0,0 +1,60 @@
# @package _group_
common:
fp16: true
fp16_init_scale: 4
threshold_loss_scale: 1
fp16_scale_window: 128
log_format: json
log_interval: 200
user_dir: ${env:PWD}/examples/data2vec
task:
_name: sentence_prediction
data: ???
init_token: 0
separator_token: 2
num_classes: 2
max_positions: 512
d2v2_multi: True
checkpoint:
best_checkpoint_metric: mcc
maximize_best_checkpoint_metric: true
no_epoch_checkpoints: true
distributed_training:
find_unused_parameters: true
distributed_world_size: 1
nprocs_per_node: 1
distributed_port: -1
criterion:
_name: sentence_prediction
report_mcc: True
dataset:
batch_size: 16
required_batch_size_multiple: 1
max_tokens: 4400
num_workers: 1
optimizer:
_name: adam
weight_decay: 0.1
adam_betas: (0.9,0.98)
adam_eps: 1e-06
lr_scheduler:
_name: polynomial_decay
warmup_updates: 320
optimization:
clip_norm: 0.0
lr: [2e-05]
max_update: 5336
max_epoch: 10
model:
_name: data2vec_text_classification
model_path: ???

@ -0,0 +1,60 @@
# @package _group_
common:
fp16: true
fp16_init_scale: 4
threshold_loss_scale: 1
fp16_scale_window: 128
log_format: json
log_interval: 200
user_dir: ${env:PWD}/examples/data2vec
task:
_name: sentence_prediction
data: ???
init_token: 0
separator_token: 2
num_classes: 3
max_positions: 512
d2v2_multi: True
checkpoint:
best_checkpoint_metric: accuracy
maximize_best_checkpoint_metric: true
no_epoch_checkpoints: true
distributed_training:
find_unused_parameters: true
distributed_world_size: 1
nprocs_per_node: 1
distributed_port: -1
criterion:
_name: sentence_prediction
dataset:
batch_size: 32
required_batch_size_multiple: 1
max_tokens: 4400
valid_subset: valid,valid1
num_workers: 1
optimizer:
_name: adam
weight_decay: 0.1
adam_betas: (0.9,0.98)
adam_eps: 1e-06
lr_scheduler:
_name: polynomial_decay
warmup_updates: 7432
optimization:
clip_norm: 0.0
lr: [2e-05]
max_update: 123873
max_epoch: 10
model:
_name: data2vec_text_classification
model_path: ???

@ -0,0 +1,60 @@
# @package _group_
common:
fp16: true
fp16_init_scale: 4
threshold_loss_scale: 1
fp16_scale_window: 128
log_format: json
log_interval: 200
user_dir: ${env:PWD}/examples/data2vec
task:
_name: sentence_prediction
data: ???
init_token: 0
separator_token: 2
num_classes: 2
max_positions: 512
d2v2_multi: True
checkpoint:
best_checkpoint_metric: acc_and_f1
maximize_best_checkpoint_metric: true
no_epoch_checkpoints: true
distributed_training:
find_unused_parameters: true
distributed_world_size: 1
nprocs_per_node: 1
distributed_port: -1
criterion:
_name: sentence_prediction
report_acc_and_f1: True
dataset:
batch_size: 16
required_batch_size_multiple: 1
max_tokens: 4400
num_workers: 1
optimizer:
_name: adam
weight_decay: 0.1
adam_betas: (0.9,0.98)
adam_eps: 1e-06
lr_scheduler:
_name: polynomial_decay
warmup_updates: 137
optimization:
clip_norm: 0.0
lr: [2e-05]
max_update: 2296
max_epoch: 10
model:
_name: data2vec_text_classification
model_path: ???

@ -0,0 +1,59 @@
# @package _group_
common:
fp16: true
fp16_init_scale: 4
threshold_loss_scale: 1
fp16_scale_window: 128
log_format: json
log_interval: 200
user_dir: ${env:PWD}/examples/data2vec
task:
_name: sentence_prediction
data: ???
init_token: 0
separator_token: 2
num_classes: 2
max_positions: 512
d2v2_multi: True
checkpoint:
best_checkpoint_metric: accuracy
maximize_best_checkpoint_metric: true
no_epoch_checkpoints: true
distributed_training:
find_unused_parameters: true
distributed_world_size: 1
nprocs_per_node: 1
distributed_port: -1
criterion:
_name: sentence_prediction
dataset:
batch_size: 32
required_batch_size_multiple: 1
max_tokens: 4400
num_workers: 1
optimizer:
_name: adam
weight_decay: 0.1
adam_betas: (0.9,0.98)
adam_eps: 1e-06
lr_scheduler:
_name: polynomial_decay
warmup_updates: 1986
optimization:
clip_norm: 0.0
lr: [2e-05]
max_update: 33112
max_epoch: 10
model:
_name: data2vec_text_classification
model_path: ???

@ -0,0 +1,60 @@
# @package _group_
common:
fp16: true
fp16_init_scale: 4
threshold_loss_scale: 1
fp16_scale_window: 128
log_format: json
log_interval: 200
user_dir: ${env:PWD}/examples/data2vec
task:
_name: sentence_prediction
data: ???
init_token: 0
separator_token: 2
num_classes: 2
max_positions: 512
d2v2_multi: True
checkpoint:
best_checkpoint_metric: acc_and_f1
maximize_best_checkpoint_metric: true
no_epoch_checkpoints: true
distributed_training:
find_unused_parameters: true
distributed_world_size: 1
nprocs_per_node: 1
distributed_port: -1
criterion:
_name: sentence_prediction
report_acc_and_f1: True
dataset:
batch_size: 32
required_batch_size_multiple: 1
max_tokens: 4400
num_workers: 1
optimizer:
_name: adam
weight_decay: 0.1
adam_betas: (0.9,0.98)
adam_eps: 1e-06
lr_scheduler:
_name: polynomial_decay
warmup_updates: 28318
optimization:
clip_norm: 0.0
lr: [2e-05]
max_update: 113272
max_epoch: 10
model:
_name: data2vec_text_classification
model_path: ???

View File

@ -0,0 +1,59 @@
# @package _group_
common:
fp16: true
fp16_init_scale: 4
threshold_loss_scale: 1
fp16_scale_window: 128
log_format: json
log_interval: 200
user_dir: ${env:PWD}/examples/data2vec
task:
_name: sentence_prediction
data: ???
init_token: 0
separator_token: 2
num_classes: 2
max_positions: 512
d2v2_multi: True
checkpoint:
best_checkpoint_metric: accuracy
maximize_best_checkpoint_metric: true
no_epoch_checkpoints: true
distributed_training:
find_unused_parameters: true
distributed_world_size: 1
nprocs_per_node: 1
distributed_port: -1
criterion:
_name: sentence_prediction
dataset:
batch_size: 16
required_batch_size_multiple: 1
max_tokens: 4400
num_workers: 1
optimizer:
_name: adam
weight_decay: 0.1
adam_betas: (0.9,0.98)
adam_eps: 1e-06
lr_scheduler:
_name: polynomial_decay
warmup_updates: 122
optimization:
clip_norm: 0.0
lr: [2e-05]
max_update: 2036
max_epoch: 10
model:
_name: data2vec_text_classification
model_path: ???

View File

@ -0,0 +1,15 @@
# @package _global_
hydra:
sweep:
dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
distributed_training:
distributed_world_size: 1
nprocs_per_node: 1
distributed_port: -1
common:
log_interval: 1
dataset:
num_workers: 0

View File

@ -0,0 +1,59 @@
# @package _group_
common:
fp16: true
fp16_init_scale: 4
threshold_loss_scale: 1
fp16_scale_window: 128
log_format: json
log_interval: 200
user_dir: ${env:PWD}/examples/data2vec
task:
_name: sentence_prediction
data: ???
init_token: 0
separator_token: 2
num_classes: 2
max_positions: 512
d2v2_multi: True
checkpoint:
best_checkpoint_metric: accuracy
maximize_best_checkpoint_metric: true
no_epoch_checkpoints: true
distributed_training:
find_unused_parameters: true
distributed_world_size: 1
nprocs_per_node: 1
distributed_port: -1
criterion:
_name: sentence_prediction
dataset:
batch_size: 32
required_batch_size_multiple: 1
max_tokens: 4400
num_workers: 1
optimizer:
_name: adam
weight_decay: 0.1
adam_betas: (0.9,0.98)
adam_eps: 1e-06
lr_scheduler:
_name: polynomial_decay
warmup_updates: 1256
optimization:
clip_norm: 0.0
lr: [2e-05]
max_update: 20935
max_epoch: 10
model:
_name: data2vec_text_classification
model_path: ???

View File

@ -0,0 +1,61 @@
# @package _group_
common:
fp16: true
fp16_init_scale: 4
threshold_loss_scale: 1
fp16_scale_window: 128
log_format: json
log_interval: 200
user_dir: ${env:PWD}/examples/data2vec
task:
_name: sentence_prediction
data: ???
init_token: 0
separator_token: 2
num_classes: 1
max_positions: 512
d2v2_multi: True
checkpoint:
best_checkpoint_metric: pearson_and_spearman
maximize_best_checkpoint_metric: true
no_epoch_checkpoints: true
distributed_training:
find_unused_parameters: true
distributed_world_size: 1
nprocs_per_node: 1
distributed_port: -1
criterion:
_name: sentence_prediction
regression_target: true
report_pearson_and_spearman: True
dataset:
batch_size: 16
required_batch_size_multiple: 1
max_tokens: 4400
num_workers: 1
optimizer:
_name: adam
weight_decay: 0.1
adam_betas: (0.9,0.98)
adam_eps: 1e-06
lr_scheduler:
_name: polynomial_decay
warmup_updates: 214
optimization:
clip_norm: 0.0
lr: [4e-05]
max_update: 3598
max_epoch: 10
model:
_name: data2vec_text_classification
model_path: ???
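Here checkpoints are selected by `pearson_and_spearman`, the correlation pair standard for STS-B-style regression (note `regression_target: true` and `num_classes: 1`). A purely illustrative computation of that metric on toy scores (using scipy, which is not a fairseq dependency):

```python
# Illustrative only: pearson_and_spearman on toy similarity scores.
from scipy.stats import pearsonr, spearmanr

preds = [2.5, 4.0, 1.0, 3.5]
gold = [3.0, 4.5, 0.5, 3.0]
p = pearsonr(preds, gold)[0]
s = spearmanr(preds, gold)[0]
print(p, s, (p + s) / 2)
```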

View File

@ -0,0 +1,52 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
checkpoint:
save_interval: 1
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
best_checkpoint_metric: accuracy
task:
_name: image_classification
data: /datasets01/imagenet_full_size/061417
dataset:
num_workers: 6
batch_size: 64
skip_invalid_size_inputs_valid_test: true
required_batch_size_multiple: 1
valid_subset: val
distributed_training:
distributed_world_size: 8
ddp_backend: c10d
criterion:
_name: model
log_keys:
- correct
optimization:
max_update: 100000
lr: [0.0005]
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: cosine
warmup_updates: 10000
model:
_name: data2vec_image_classification
model_path: ???

View File

@ -0,0 +1,65 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
fp16_no_flatten_grads: true
checkpoint:
save_interval: 1
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
best_checkpoint_metric: accuracy
maximize_best_checkpoint_metric: true
task:
_name: mae_image_classification
data: /datasets01/imagenet_full_size/061417
dataset:
num_workers: 6
batch_size: 32
skip_invalid_size_inputs_valid_test: true
required_batch_size_multiple: 2
valid_subset: val
distributed_training:
distributed_world_size: 16
ddp_backend: c10d
criterion:
_name: model
log_keys:
- correct
optimization:
max_update: 250200
lr: [0.001]
optimizer:
_name: composite
dynamic_groups: true
groups:
default:
lr_float: 0.001
optimizer:
_name: adam
adam_betas: [0.9,0.95]
weight_decay: 0.05
lr_scheduler:
_name: cosine
warmup_updates: 16000
min_lr: 1e-6
lr_scheduler: pass_through
model:
_name: mae_image_classification
mixup: 0.7
mixup_prob: 0.9
model_path: ???
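The `mixup` and `mixup_prob` options refer to interpolating pairs of images and their labels during fine-tuning. A hedged sketch of that interpolation, illustrating the idea only rather than the exact fairseq implementation:

```python
# Hedged illustration of the mixup interpolation the mixup/mixup_prob options refer to
# (not the exact fairseq implementation): two images and their one-hot labels are
# blended with a Beta-distributed coefficient.
import torch

def mixup_batch(imgs, one_hot_labels, alpha=0.7):
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(imgs.size(0))
    mixed_imgs = lam * imgs + (1.0 - lam) * imgs[perm]
    mixed_labels = lam * one_hot_labels + (1.0 - lam) * one_hot_labels[perm]
    return mixed_imgs, mixed_labels

imgs = torch.randn(8, 3, 224, 224)
labels = torch.nn.functional.one_hot(torch.randint(0, 1000, (8,)), 1000).float()
mixed_imgs, mixed_labels = mixup_batch(imgs, labels, alpha=0.7)
```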

View File

@ -0,0 +1,68 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
fp16_no_flatten_grads: true
checkpoint:
save_interval: 1
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
best_checkpoint_metric: accuracy
maximize_best_checkpoint_metric: true
task:
_name: mae_image_classification
data: /datasets01/imagenet_full_size/061417
dataset:
num_workers: 6
batch_size: 32
skip_invalid_size_inputs_valid_test: true
required_batch_size_multiple: 2
valid_subset: val
distributed_training:
distributed_world_size: 16
ddp_backend: c10d
criterion:
_name: model
log_keys:
- correct
optimization:
max_update: 125200
lr: [0.0005]
clip_norm: 4
optimizer:
_name: composite
dynamic_groups: true
groups:
default:
lr_float: 0.0005
optimizer:
_name: adam
adam_betas: [0.9,0.95]
weight_decay: 0.05
lr_scheduler:
_name: cosine
warmup_updates: 16000
min_lr: 1e-20
lr_scheduler: pass_through
model:
_name: mae_image_classification
mixup: 0.7
mixup_prob: 0.9
layer_decay: 0.75
drop_path_rate: 0.2
model_path: ???

View File

@ -0,0 +1,68 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
fp16_no_flatten_grads: true
checkpoint:
save_interval: 1
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
best_checkpoint_metric: accuracy
maximize_best_checkpoint_metric: true
task:
_name: mae_image_classification
data: /datasets01/imagenet_full_size/061417
dataset:
num_workers: 6
batch_size: 32
skip_invalid_size_inputs_valid_test: true
required_batch_size_multiple: 2
valid_subset: val
distributed_training:
distributed_world_size: 16
ddp_backend: c10d
criterion:
_name: model
log_keys:
- correct
optimization:
max_update: 125200
lr: [0.0005]
clip_norm: 4
optimizer:
_name: composite
dynamic_groups: true
groups:
default:
lr_float: 0.0005
optimizer:
_name: adam
adam_betas: [0.9,0.95]
weight_decay: 0.05
lr_scheduler:
_name: cosine
warmup_updates: 16000
min_lr: 1e-7
lr_scheduler: pass_through
model:
_name: mae_image_classification
mixup: 0.7
mixup_prob: 0.9
layer_decay: 0.75
drop_path_rate: 0.2
model_path: ???

View File

@ -0,0 +1,15 @@
# @package _global_
hydra:
sweep:
dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
distributed_training:
distributed_world_size: 1
nprocs_per_node: 1
distributed_port: -1
common:
log_interval: 1
dataset:
num_workers: 0

View File

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 80
gpus_per_node: 8
tasks_per_node: 1
mem_gb: 450
nodes: 1
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

View File

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 80
gpus_per_node: 8
tasks_per_node: 1
mem_gb: 0
nodes: 1
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

View File

@ -0,0 +1,38 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
- task.local_cache_path
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 450
nodes: 2
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

View File

@ -0,0 +1,38 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
- task.local_cache_path
- model.model_path
sweep:
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 2
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

View File

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 80
gpus_per_node: 8
tasks_per_node: 1
mem_gb: 450
nodes: 3
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

View File

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 450
nodes: 4
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

View File

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 4
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

View File

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 6
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

View File

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 8
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

View File

@ -0,0 +1,52 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
checkpoint:
save_interval: 5
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
task:
_name: image_pretraining
data: /datasets01/imagenet_full_size/061417/
dataset:
num_workers: 6
batch_size: 64
skip_invalid_size_inputs_valid_test: true
required_batch_size_multiple: 1
disable_validation: true
distributed_training:
distributed_world_size: 16
ddp_backend: c10d
criterion:
_name: model
log_keys:
- ema_decay
- target_var
- pred_var
optimization:
max_update: 400000
lr: [0.0005]
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: cosine
warmup_updates: 10000
model:
_name: data2vec_vision

View File

@ -0,0 +1,64 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
checkpoint:
save_interval: 5
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
task:
_name: image_pretraining
data: /datasets01/imagenet_full_size/061417
dataset:
num_workers: 6
batch_size: 128
skip_invalid_size_inputs_valid_test: true
required_batch_size_multiple: 2
disable_validation: true
distributed_training:
distributed_world_size: 16
ddp_backend: legacy_ddp
criterion:
_name: model
log_keys:
- ema_decay
- target_var
- pred_var
optimization:
max_update: 375300 #300*1251
lr: [0.0005]
clip_norm: 3.0
optimizer:
_name: adam
adam_betas: (0.9,0.999)
adam_eps: 1e-08
weight_decay: 0.05
lr_scheduler:
_name: cosine
warmup_updates: 12510 # it should be 10 epochs
model:
_name: data2vec_vision
attention_dropout: 0.05
ema_decay: 0.999
ema_end_decay: 0.9998
layer_norm_targets: True
average_top_k_layers: 6
loss_beta: 2.0
drop_path: 0.25
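The `ema_decay` and `ema_end_decay` options control the exponential-moving-average teacher whose representations serve as targets. A hedged sketch of the update rule with a simple annealed decay schedule (illustrative only, not fairseq's exact EMA module or schedule):

```python
# Hedged sketch of the EMA teacher update that ema_decay / ema_end_decay control: the
# teacher weights track the student with a decay annealed from 0.999 to 0.9998.
import torch

@torch.no_grad()
def ema_step(teacher_params, student_params, decay):
    for t, s in zip(teacher_params, student_params):
        t.mul_(decay).add_(s, alpha=1.0 - decay)

def annealed_decay(step, total_steps, start=0.999, end=0.9998):
    frac = min(step / max(total_steps, 1), 1.0)
    return start + frac * (end - start)

# example: decay used at update 100k of ~375k total updates
print(annealed_decay(100_000, 375_300))
```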

View File

@ -0,0 +1,64 @@
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tb
fp16_no_flatten_grads: true
checkpoint:
save_interval: 5
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
task:
_name: mae_image_pretraining
data: /datasets01/imagenet_full_size/061417/
rebuild_batches: true
dataset:
num_workers: 6
batch_size: 64
skip_invalid_size_inputs_valid_test: true
required_batch_size_multiple: 1
disable_validation: true
distributed_training:
distributed_world_size: 16
ddp_backend: c10d
criterion:
_name: model
optimization:
max_update: 375300
lr: [0.0006]
optimizer:
_name: composite
groups:
with_decay:
lr_float: 6e-4
optimizer:
_name: adam
adam_betas: [0.9,0.95]
weight_decay: 0.05
lr_scheduler:
_name: cosine
warmup_updates: 50040
no_decay:
lr_float: 6e-4
optimizer:
_name: adam
adam_betas: [0.9,0.95]
weight_decay: 0
lr_scheduler:
_name: cosine
warmup_updates: 50040
lr_scheduler: pass_through
model:
_name: mae
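This pretraining config splits parameters into `with_decay` and `no_decay` optimizer groups. A hedged sketch of the typical grouping rule behind such a split (weight matrices decayed, biases and normalization parameters not), shown with a plain AdamW rather than fairseq's composite optimizer:

```python
# Hedged sketch of the idea behind the with_decay / no_decay groups above. This
# illustrates the grouping concept only, not fairseq's composite optimizer.
import torch
import torch.nn as nn

def split_decay_groups(model: nn.Module, weight_decay: float = 0.05):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim == 1 or name.endswith(".bias"):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
opt = torch.optim.AdamW(split_decay_groups(toy), lr=6e-4, betas=(0.9, 0.95))
```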

View File

@ -0,0 +1,15 @@
# @package _global_
hydra:
sweep:
dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
distributed_training:
distributed_world_size: 1
nprocs_per_node: 1
distributed_port: -1
common:
log_interval: 1
dataset:
num_workers: 0

View File

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 80
gpus_per_node: 8
tasks_per_node: 1
mem_gb: 450
nodes: 1
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

View File

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 80
gpus_per_node: 8
tasks_per_node: 1
mem_gb: 0
nodes: 1
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

View File

@ -0,0 +1,38 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
- task.local_cache_path
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 450
nodes: 2
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

View File

@ -0,0 +1,37 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
- task.local_cache_path
sweep:
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 2
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

View File

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 80
gpus_per_node: 8
tasks_per_node: 1
mem_gb: 450
nodes: 3
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

View File

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 450
nodes: 4
name: ${env:PREFIX}_${hydra.job.config_name}
partition: devlab,learnlab,learnfair,scavenge
constraint: volta32gb,ib4
max_num_timeout: 30

View File

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 4
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

View File

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 6
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

View File

@ -0,0 +1,36 @@
# @package _global_
hydra:
job:
config:
override_dirname:
kv_sep: ':'
item_sep: '/'
exclude_keys:
- run_config
- distributed_training.distributed_port
- distributed_training.distributed_world_size
- model.pretrained_model_path
- model.target_network_path
- next_script
- task.cache_in_scratch
- task.data
- checkpoint.save_interval_updates
- checkpoint.keep_interval_updates
- checkpoint.save_on_overflow
- common.log_interval
- common.user_dir
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
subdir: ''
launcher:
submitit_folder: ${hydra.sweep.dir}
timeout_min: 4320
cpus_per_task: 10
gpus_per_node: 8
tasks_per_node: 8
mem_gb: 0
nodes: 8
name: ${env:PREFIX}_${hydra.job.config_name}
partition: wav2vec,learnlab,learnfair
max_num_timeout: 30

View File

@ -0,0 +1,17 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .image_dataset import ImageDataset
from .path_dataset import PathDataset
from .mae_image_dataset import MaeImageDataset
from .mae_finetuning_image_dataset import MaeFinetuningImageDataset
__all__ = [
"ImageDataset",
"MaeImageDataset",
"MaeFinetuningImageDataset",
"PathDataset",
]

View File

@ -0,0 +1,63 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq.data import BaseWrapperDataset, data_utils
class AddClassTargetDataset(BaseWrapperDataset):
def __init__(
self,
dataset,
labels,
multi_class,
num_classes=None,
label_indices=None,
add_to_input=True,
):
super().__init__(dataset)
self.label_indices = label_indices
self.labels = labels
self.multi_class = multi_class
self.add_to_input = add_to_input
if num_classes is None and multi_class:
assert self.label_indices is not None
num_classes = len(self.label_indices)
self.num_classes = num_classes
def __getitem__(self, index):
item = self.dataset[index]
item_labels = self.labels[index]
if self.multi_class:
item["label"] = torch.zeros(self.num_classes)
for il in item_labels:
if self.label_indices is not None:
il = self.label_indices[il]
item["label"][il] = 1.0
else:
item["label"] = torch.tensor(
self.labels[index]
if self.label_indices is None
else self.label_indices[self.labels[index]]
)
return item
def collater(self, samples):
collated = self.dataset.collater(samples)
if len(collated) == 0:
return collated
indices = set(collated["id"].tolist())
target = [s["label"] for s in samples if s["id"] in indices]
collated["label"] = torch.stack(target, dim=0)
if self.add_to_input:
collated["net_input"]["label"] = collated["label"]
return collated
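A hedged usage sketch for `AddClassTargetDataset` with a toy wrapped dataset: the toy dataset and label lists are made up, and it assumes the class defined above is importable in the current scope.

```python
# Hedged usage sketch: wrap a toy dataset so collated batches also carry a multi-hot
# "label" tensor. ToyDataset is illustrative; AddClassTargetDataset (defined above) is
# assumed to be in scope.
import torch
from fairseq.data import FairseqDataset

class ToyDataset(FairseqDataset):
    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        return {"id": index, "x": torch.randn(4)}

    def __len__(self):
        return self.n

    def collater(self, samples):
        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": {"x": torch.stack([s["x"] for s in samples])},
        }

labels = [[0], [1, 2], [2]]  # per-example label lists
ds = AddClassTargetDataset(ToyDataset(3), labels, multi_class=True, num_classes=3)
batch = ds.collater([ds[i] for i in range(3)])
print(batch["label"])  # 3 x 3 multi-hot matrix, also copied into net_input
```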

View File

@ -0,0 +1,127 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
import os
from typing import Optional, Callable, Set
import torch
from torchvision.datasets.vision import VisionDataset
from torchvision.transforms import ToTensor
from fairseq.data import FairseqDataset
logger = logging.getLogger(__name__)
class ImageDataset(FairseqDataset, VisionDataset):
def __init__(
self,
root: str,
extensions: Set[str],
load_classes: bool,
transform: Optional[Callable] = None,
shuffle=True,
):
FairseqDataset.__init__(self)
VisionDataset.__init__(self, root=root, transform=transform)
self.shuffle = shuffle
self.tensor_transform = ToTensor()
self.classes = None
self.labels = None
if load_classes:
classes = [d.name for d in os.scandir(root) if d.is_dir()]
classes.sort()
self.classes = {cls_name: i for i, cls_name in enumerate(classes)}
logger.info(f"loaded {len(self.classes)} classes")
self.labels = []
def walk_path(root_path):
for root, _, fnames in sorted(os.walk(root_path, followlinks=True)):
for fname in sorted(fnames):
fname_ext = os.path.splitext(fname)
if fname_ext[-1].lower() not in extensions:
continue
path = os.path.join(root, fname)
yield path
logger.info(f"finding images in {root}")
if self.classes is not None:
self.files = []
self.labels = []
for c, i in self.classes.items():
for f in walk_path(os.path.join(root, c)):
self.files.append(f)
self.labels.append(i)
else:
self.files = [f for f in walk_path(root)]
logger.info(f"loaded {len(self.files)} examples")
def __getitem__(self, index):
from PIL import Image
fpath = self.files[index]
with open(fpath, "rb") as f:
img = Image.open(f).convert("RGB")
if self.transform is None:
img = self.tensor_transform(img)
else:
img = self.transform(img)
assert torch.is_tensor(img)
res = {"id": index, "img": img}
if self.labels is not None:
res["label"] = self.labels[index]
return res
def __len__(self):
return len(self.files)
def collater(self, samples):
if len(samples) == 0:
return {}
collated_img = torch.stack([s["img"] for s in samples], dim=0)
res = {
"id": torch.LongTensor([s["id"] for s in samples]),
"net_input": {
"img": collated_img,
},
}
if "label" in samples[0]:
res["net_input"]["label"] = torch.LongTensor([s["label"] for s in samples])
return res
def num_tokens(self, index):
return 1
def size(self, index):
return 1
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
return order[0]
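A hedged usage sketch for `ImageDataset` over an ImageNet-style directory of per-class subfolders. The path is hypothetical, and the class above is assumed to be importable in the current scope.

```python
# Hedged usage sketch (directory layout and path are hypothetical): point ImageDataset
# at a folder of per-class subdirectories and collate a small batch.
from torchvision import transforms

transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)
ds = ImageDataset(
    root="/path/to/imagenet/train",        # hypothetical path
    extensions={".jpg", ".jpeg", ".png"},
    load_classes=True,
    transform=transform,
)
batch = ds.collater([ds[i] for i in range(4)])
print(batch["net_input"]["img"].shape)      # expected: torch.Size([4, 3, 224, 224])
print(batch["net_input"]["label"])          # class indices for the 4 images
```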

View File

@ -0,0 +1,135 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
import os
import torch
from torchvision import datasets, transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import PIL
from fairseq.data import FairseqDataset
from .mae_image_dataset import caching_loader
logger = logging.getLogger(__name__)
def build_transform(is_train, input_size, color_jitter, aa, reprob, remode, recount):
mean = IMAGENET_DEFAULT_MEAN
std = IMAGENET_DEFAULT_STD
# train transform
if is_train:
# this should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=input_size,
is_training=True,
color_jitter=color_jitter,
auto_augment=aa,
interpolation="bicubic",
re_prob=reprob,
re_mode=remode,
re_count=recount,
mean=mean,
std=std,
)
return transform
# eval transform
t = []
if input_size <= 224:
crop_pct = 224 / 256
else:
crop_pct = 1.0
size = int(input_size / crop_pct)
t.append(
transforms.Resize(
size, interpolation=PIL.Image.BICUBIC
), # to maintain same ratio w.r.t. 224 images
)
t.append(transforms.CenterCrop(input_size))
t.append(transforms.ToTensor())
t.append(transforms.Normalize(mean, std))
return transforms.Compose(t)
class MaeFinetuningImageDataset(FairseqDataset):
def __init__(
self,
root: str,
split: str,
is_train: bool,
input_size,
color_jitter=None,
aa="rand-m9-mstd0.5-inc1",
reprob=0.25,
remode="pixel",
recount=1,
local_cache_path=None,
shuffle=True,
):
FairseqDataset.__init__(self)
self.shuffle = shuffle
transform = build_transform(
is_train, input_size, color_jitter, aa, reprob, remode, recount
)
path = os.path.join(root, split)
loader = caching_loader(local_cache_path, datasets.folder.default_loader)
self.dataset = datasets.ImageFolder(path, loader=loader, transform=transform)
logger.info(f"loaded {len(self.dataset)} examples")
def __getitem__(self, index):
img, label = self.dataset[index]
return {"id": index, "img": img, "label": label}
def __len__(self):
return len(self.dataset)
def collater(self, samples):
if len(samples) == 0:
return {}
collated_img = torch.stack([s["img"] for s in samples], dim=0)
res = {
"id": torch.LongTensor([s["id"] for s in samples]),
"net_input": {
"imgs": collated_img,
},
}
if "label" in samples[0]:
res["net_input"]["labels"] = torch.LongTensor([s["label"] for s in samples])
return res
def num_tokens(self, index):
return 1
def size(self, index):
return 1
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
return order[0]

View File

@ -0,0 +1,418 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
import logging
import math
import random
import time
import numpy as np
import os
import torch
from torchvision import datasets, transforms
from .path_dataset import PathDataset
from fairseq.data import FairseqDataset
from fairseq.data.data_utils import compute_block_mask_1d, compute_block_mask_2d
from shutil import copyfile
logger = logging.getLogger(__name__)
def load(path, loader, cache):
if hasattr(caching_loader, "cache_root"):
cache = caching_loader.cache_root
cached_path = cache + path
num_tries = 3
for curr_try in range(num_tries):
try:
if curr_try == 2:
return loader(path)
if not os.path.exists(cached_path) or curr_try > 0:
os.makedirs(os.path.dirname(cached_path), exist_ok=True)
copyfile(path, cached_path)
os.chmod(cached_path, 0o777)
return loader(cached_path)
except Exception as e:
logger.warning(str(e))
if "Errno 13" in str(e):
caching_loader.cache_root = f"/scratch/{random.randint(0, 69420)}"
logger.warning(f"setting cache root to {caching_loader.cache_root}")
cached_path = caching_loader.cache_root + path
if curr_try == (num_tries - 1):
raise
time.sleep(2)
def caching_loader(cache_root: str, loader):
if cache_root is None:
return loader
if cache_root == "slurm_tmpdir":
cache_root = os.environ["SLURM_TMPDIR"]
assert len(cache_root) > 0
if not cache_root.endswith("/"):
cache_root += "/"
return partial(load, loader=loader, cache=cache_root)
class RandomResizedCropAndInterpolationWithTwoPic:
"""Crop the given PIL Image to random size and aspect ratio with random interpolation.
A crop of random size (default: of 0.08 to 1.0) of the original size and a random
aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
is finally resized to given size.
This is popularly used to train the Inception networks.
Args:
size: expected output size of each edge
scale: range of size of the origin size cropped
ratio: range of aspect ratio of the origin aspect ratio cropped
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(
self,
size,
second_size=None,
scale=(0.08, 1.0),
ratio=(3.0 / 4.0, 4.0 / 3.0),
interpolation="bilinear",
second_interpolation="lanczos",
):
if isinstance(size, tuple):
self.size = size
else:
self.size = (size, size)
if second_size is not None:
if isinstance(second_size, tuple):
self.second_size = second_size
else:
self.second_size = (second_size, second_size)
else:
self.second_size = None
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
logger.warning("range should be of kind (min, max)")
if interpolation == "random":
from PIL import Image
self.interpolation = (Image.BILINEAR, Image.BICUBIC)
else:
self.interpolation = self._pil_interp(interpolation)
self.second_interpolation = (
self._pil_interp(second_interpolation)
if second_interpolation is not None
else None
)
self.scale = scale
self.ratio = ratio
def _pil_interp(self, method):
from PIL import Image
if method == "bicubic":
return Image.BICUBIC
elif method == "lanczos":
return Image.LANCZOS
elif method == "hamming":
return Image.HAMMING
else:
# default bilinear, do we want to allow nearest?
return Image.BILINEAR
@staticmethod
def get_params(img, scale, ratio):
"""Get parameters for ``crop`` for a random sized crop.
Args:
img (PIL Image): Image to be cropped.
scale (tuple): range of size of the origin size cropped
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
area = img.size[0] * img.size[1]
for attempt in range(10):
target_area = random.uniform(*scale) * area
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= img.size[0] and h <= img.size[1]:
i = random.randint(0, img.size[1] - h)
j = random.randint(0, img.size[0] - w)
return i, j, h, w
# Fallback to central crop
in_ratio = img.size[0] / img.size[1]
if in_ratio < min(ratio):
w = img.size[0]
h = int(round(w / min(ratio)))
elif in_ratio > max(ratio):
h = img.size[1]
w = int(round(h * max(ratio)))
else: # whole image
w = img.size[0]
h = img.size[1]
i = (img.size[1] - h) // 2
j = (img.size[0] - w) // 2
return i, j, h, w
def __call__(self, img):
import torchvision.transforms.functional as F
"""
Args:
img (PIL Image): Image to be cropped and resized.
Returns:
PIL Image: Randomly cropped and resized image.
"""
i, j, h, w = self.get_params(img, self.scale, self.ratio)
if isinstance(self.interpolation, (tuple, list)):
interpolation = random.choice(self.interpolation)
else:
interpolation = self.interpolation
if self.second_size is None:
return F.resized_crop(img, i, j, h, w, self.size, interpolation)
else:
return F.resized_crop(
img, i, j, h, w, self.size, interpolation
), F.resized_crop(
img, i, j, h, w, self.second_size, self.second_interpolation
)
class MaeImageDataset(FairseqDataset):
def __init__(
self,
root: str,
split: str,
input_size,
local_cache_path=None,
shuffle=True,
key="imgs",
beit_transforms=False,
target_transform=False,
no_transform=False,
compute_mask=False,
patch_size: int = 16,
mask_prob: float = 0.75,
mask_prob_adjust: float = 0,
mask_length: int = 1,
inverse_mask: bool = False,
expand_adjacent: bool = False,
mask_dropout: float = 0,
non_overlapping: bool = False,
require_same_masks: bool = True,
clone_batch: int = 1,
dataset_type: str = "imagefolder",
):
FairseqDataset.__init__(self)
self.shuffle = shuffle
self.key = key
loader = caching_loader(local_cache_path, datasets.folder.default_loader)
self.transform_source = None
self.transform_target = None
if target_transform:
self.transform_source = transforms.ColorJitter(0.4, 0.4, 0.4)
self.transform_target = transforms.ColorJitter(0.4, 0.4, 0.4)
if no_transform:
if input_size <= 224:
crop_pct = 224 / 256
else:
crop_pct = 1.0
size = int(input_size / crop_pct)
self.transform_train = transforms.Compose(
[
transforms.Resize(size, interpolation=3),
transforms.CenterCrop(input_size),
]
)
self.transform_train = transforms.Resize((input_size, input_size))
elif beit_transforms:
beit_transform_list = []
if not target_transform:
beit_transform_list.append(transforms.ColorJitter(0.4, 0.4, 0.4))
beit_transform_list.extend(
[
transforms.RandomHorizontalFlip(p=0.5),
RandomResizedCropAndInterpolationWithTwoPic(
size=input_size,
second_size=None,
interpolation="bicubic",
second_interpolation=None,
),
]
)
self.transform_train = transforms.Compose(beit_transform_list)
else:
self.transform_train = transforms.Compose(
[
transforms.RandomResizedCrop(
input_size, scale=(0.2, 1.0), interpolation=3
), # 3 is bicubic
transforms.RandomHorizontalFlip(),
]
)
self.final_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)
if dataset_type == "imagefolder":
self.dataset = datasets.ImageFolder(
os.path.join(root, split), loader=loader
)
elif dataset_type == "path":
self.dataset = PathDataset(
root,
loader,
None,
None,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
)
else:
raise Exception(f"invalid dataset type {dataset_type}")
logger.info(
f"initial transform: {self.transform_train}, "
f"source transform: {self.transform_source}, "
f"target transform: {self.transform_target}, "
f"final transform: {self.final_transform}"
)
logger.info(f"loaded {len(self.dataset)} examples")
self.is_compute_mask = compute_mask
self.patches = (input_size // patch_size) ** 2
self.mask_prob = mask_prob
self.mask_prob_adjust = mask_prob_adjust
self.mask_length = mask_length
self.inverse_mask = inverse_mask
self.expand_adjacent = expand_adjacent
self.mask_dropout = mask_dropout
self.non_overlapping = non_overlapping
self.require_same_masks = require_same_masks
self.clone_batch = clone_batch
def __getitem__(self, index):
img, _ = self.dataset[index]
img = self.transform_train(img)
source = None
target = None
if self.transform_source is not None:
source = self.final_transform(self.transform_source(img))
if self.transform_target is not None:
target = self.final_transform(self.transform_target(img))
if source is None:
img = self.final_transform(img)
v = {"id": index, self.key: source if source is not None else img}
if target is not None:
v["target"] = target
if self.is_compute_mask:
if self.mask_length == 1:
mask = compute_block_mask_1d(
shape=(self.clone_batch, self.patches),
mask_prob=self.mask_prob,
mask_length=self.mask_length,
mask_prob_adjust=self.mask_prob_adjust,
inverse_mask=self.inverse_mask,
require_same_masks=True,
)
else:
mask = compute_block_mask_2d(
shape=(self.clone_batch, self.patches),
mask_prob=self.mask_prob,
mask_length=self.mask_length,
mask_prob_adjust=self.mask_prob_adjust,
inverse_mask=self.inverse_mask,
require_same_masks=True,
expand_adjcent=self.expand_adjacent,
mask_dropout=self.mask_dropout,
non_overlapping=self.non_overlapping,
)
v["precomputed_mask"] = mask
return v
def __len__(self):
return len(self.dataset)
def collater(self, samples):
if len(samples) == 0:
return {}
collated_img = torch.stack([s[self.key] for s in samples], dim=0)
res = {
"id": torch.LongTensor([s["id"] for s in samples]),
"net_input": {
self.key: collated_img,
},
}
if "target" in samples[0]:
collated_target = torch.stack([s["target"] for s in samples], dim=0)
res["net_input"]["target"] = collated_target
if "precomputed_mask" in samples[0]:
collated_mask = torch.cat([s["precomputed_mask"] for s in samples], dim=0)
res["net_input"]["precomputed_mask"] = collated_mask
return res
def num_tokens(self, index):
return 1
def size(self, index):
return 1
@property
def sizes(self):
return np.full((len(self),), 1)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
return order[0]
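When `compute_mask` is enabled, `MaeImageDataset` precomputes one block mask per `clone_batch` copy of each image. A hedged sketch of that call with illustrative values; the keyword names mirror the call in the class above, and the remaining behavior is assumed from that usage.

```python
# Hedged sketch of the precomputed masking path: for 224x224 inputs with 16x16 patches
# there are 196 patch positions, and one mask row is generated per clone_batch copy.
from fairseq.data.data_utils import compute_block_mask_2d

clone_batch = 16
patches = (224 // 16) ** 2
mask = compute_block_mask_2d(
    shape=(clone_batch, patches),
    mask_prob=0.8,
    mask_length=3,
    mask_prob_adjust=0.07,
    inverse_mask=True,
    require_same_masks=True,
    expand_adjcent=False,
    mask_dropout=0.0,
    non_overlapping=False,
)
print(mask.shape)  # expected: torch.Size([16, 196])
```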

View File

@ -0,0 +1,14 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from enum import Enum, auto
class Modality(Enum):
AUDIO = auto()
IMAGE = auto()
TEXT = auto()

View File

@ -0,0 +1,64 @@
import glob
import os
from typing import Callable, List, Optional, Tuple
import logging
import numpy as np
import torchvision.transforms.functional as TF
import PIL
from PIL import Image
from torchvision.datasets import VisionDataset
logger = logging.getLogger(__name__)
class PathDataset(VisionDataset):
def __init__(
self,
root: List[str],
loader: Optional[Callable] = None,
transform: Optional[Callable] = None,
extra_transform: Optional[Callable] = None,
mean: Optional[List[float]] = None,
std: Optional[List[float]] = None,
):
super().__init__(root=root)
PIL.Image.MAX_IMAGE_PIXELS = 256000001
self.files = []
for folder in self.root:
self.files.extend(
sorted(glob.glob(os.path.join(folder, "**", "*.jpg"), recursive=True))
)
self.files.extend(
sorted(glob.glob(os.path.join(folder, "**", "*.png"), recursive=True))
)
self.transform = transform
self.extra_transform = extra_transform
self.mean = mean
self.std = std
self.loader = loader
logger.info(f"loaded {len(self.files)} samples from {root}")
assert (mean is None) == (std is None)
def __len__(self) -> int:
return len(self.files)
def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
path = self.files[idx]
if self.loader is not None:
return self.loader(path), None
img = Image.open(path).convert("RGB")
if self.transform is not None:
img = self.transform(img)
img = TF.to_tensor(img)
if self.mean is not None and self.std is not None:
img = TF.normalize(img, self.mean, self.std)
return img, None

View File

@ -0,0 +1,165 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import torch
from omegaconf import OmegaConf
from fairseq.criterions.model_criterion import ModelCriterionConfig
from fairseq.dataclass.configs import FairseqConfig
from tasks import ImageClassificationConfig, ImagePretrainingConfig
from models.data2vec_image_classification import (
Data2VecImageClassificationConfig,
Data2VecImageClassificationModel,
)
from models.data2vec_vision import Data2VecVisionConfig, Data2VecVisionModel
def get_parser():
parser = argparse.ArgumentParser(
description="convert beit checkpoint into data2vec - vision checkpoint"
)
# fmt: off
parser.add_argument('checkpoint', help='checkpoint to convert')
parser.add_argument('--output', required=True, metavar='PATH', help='where to output converted checkpoint')
parser.add_argument('--type', type=str, choices=['vision', 'image_classification'], default='image_classification', help='type of model to upgrade')
parser.add_argument('--inception_norms', action='store_true', default=False)
# fmt: on
return parser
def update_checkpoint(model_dict, prefix, is_nested):
replace_paths = {
"cls_token": "model.cls_emb" if is_nested else "cls_emb",
"patch_embed": "model.patch_embed" if is_nested else "patch_embed",
"mask_token": "mask_emb",
}
starts_with = {
"patch_embed.proj": "model.patch_embed.conv"
if is_nested
else "patch_embed.conv",
"lm_head": "final_proj",
"fc_norm": "fc_norm",
"head": "head",
}
partial = {
"mlp.fc1": "mlp.0",
"mlp.fc2": "mlp.2",
}
for k in list(model_dict.keys()):
for sw, r in starts_with.items():
if k.startswith(sw):
replace_paths[k] = k.replace(sw, r)
for p, r in partial.items():
if p in k:
replace_paths[k] = prefix + k.replace(p, r)
if prefix != "":
for k in list(model_dict.keys()):
if k not in replace_paths:
replace_paths[k] = prefix + k
for k in list(model_dict.keys()):
if k in replace_paths:
model_dict[replace_paths[k]] = model_dict[k]
if k != replace_paths[k]:
del model_dict[k]
return model_dict
def main():
parser = get_parser()
args = parser.parse_args()
cp = torch.load(args.checkpoint, map_location="cpu")
cfg = FairseqConfig(
criterion=ModelCriterionConfig(_name="model", log_keys=["correct"]),
)
if args.type == "image_classification":
cfg.task = ImageClassificationConfig(
_name="image_classification",
data=".",
)
if args.inception_norms:
cfg.task.normalization_mean = [0.5, 0.5, 0.5]
cfg.task.normalization_std = [0.5, 0.5, 0.5]
cfg.model = Data2VecImageClassificationConfig(
_name="data2vec_image_classification",
)
cfg.model.pretrained_model_args = FairseqConfig(
model=Data2VecVisionConfig(
_name="data2vec_vision", shared_rel_pos_bias=False
),
task=ImagePretrainingConfig(
_name="image_pretraining",
),
)
cfg = OmegaConf.create(cfg)
state = {
"cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True),
"model": cp["module"],
"best_loss": None,
"optimizer": None,
"extra_state": {},
}
model = Data2VecImageClassificationModel(cfg.model)
model.load_state_dict(
update_checkpoint(state["model"], prefix="model.encoder.", is_nested=True),
strict=True,
)
elif args.type == "vision":
cfg.task = ImagePretrainingConfig(
_name="image_pretraining",
data=".",
)
if args.inception_norms:
cfg.task.normalization_mean = [0.5, 0.5, 0.5]
cfg.task.normalization_std = [0.5, 0.5, 0.5]
cfg.model = Data2VecVisionConfig(
_name="data2vec_vision",
)
cfg = OmegaConf.create(cfg)
state = {
"cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True),
"model": cp["model"],
"best_loss": None,
"optimizer": None,
"extra_state": {},
}
model = Data2VecVisionModel(cfg.model)
model.load_state_dict(
update_checkpoint(state["model"], prefix="encoder.", is_nested=False),
strict=True,
)
else:
raise Exception("unsupported type " + args.type)
print(state["cfg"], state.keys())
torch.save(state, args.output)
if __name__ == "__main__":
main()
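A hedged illustration of the key remapping `update_checkpoint` performs, applied to a toy state dict. Key names and values are made up, and the function above is assumed to be in scope (e.g. when run inside the same module).

```python
# Hedged illustration of update_checkpoint's key remapping on a toy state dict.
import torch

toy = {
    "cls_token": torch.zeros(1),
    "patch_embed.proj.weight": torch.zeros(1),
    "blocks.0.mlp.fc1.weight": torch.zeros(1),
}
remapped = update_checkpoint(dict(toy), prefix="model.encoder.", is_nested=True)
print(sorted(remapped.keys()))
# ['model.cls_emb', 'model.encoder.blocks.0.mlp.0.weight', 'model.patch_embed.conv.weight']
```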

View File

@ -0,0 +1,614 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from omegaconf import II, MISSING, open_dict
from fairseq import checkpoint_utils, tasks
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import (
BaseFairseqModel,
register_model,
)
from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
from fairseq.modules import TransposeLast
from fairseq.tasks import FairseqTask
logger = logging.getLogger(__name__)
@dataclass
class AudioClassificationConfig(FairseqDataclass):
model_path: str = field(
default=MISSING, metadata={"help": "path to wav2vec 2.0 model"}
)
no_pretrained_weights: bool = field(
default=False, metadata={"help": "if true, does not load pretrained weights"}
)
dropout_input: float = field(
default=0.0,
metadata={"help": "dropout to apply to the input (after feat extr)"},
)
final_dropout: float = field(
default=0.0,
metadata={"help": "dropout after transformer and before final projection"},
)
dropout: float = field(
default=0.0, metadata={"help": "dropout probability inside wav2vec 2.0 model"}
)
attention_dropout: float = field(
default=0.0,
metadata={
"help": "dropout probability for attention weights inside wav2vec 2.0 model"
},
)
activation_dropout: float = field(
default=0.0,
metadata={
"help": "dropout probability after activation in FFN inside wav2vec 2.0 model"
},
)
# masking
apply_mask: bool = field(
default=False, metadata={"help": "apply masking during fine-tuning"}
)
mask_length: int = field(
default=10, metadata={"help": "repeat the mask indices multiple times"}
)
mask_prob: float = field(
default=0.5,
metadata={
"help": "probability of replacing a token with mask (normalized by length)"
},
)
mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
default="static", metadata={"help": "how to choose masks"}
)
mask_other: float = field(
default=0,
metadata={
"help": "secondary mask argument (used for more complex distributions), "
"see help in compute_mask_indices"
},
)
no_mask_overlap: bool = field(
default=False, metadata={"help": "whether to allow masks to overlap"}
)
mask_min_space: Optional[int] = field(
default=1,
metadata={"help": "min space between spans (if no overlap is enabled)"},
)
require_same_masks: bool = field(
default=True,
metadata={
"help": "whether to number of masked timesteps must be the same across all "
"examples in a batch"
},
)
mask_dropout: float = field(
default=0.0,
metadata={"help": "percent of masks to unmask for each sample"},
)
# channel masking
mask_channel_length: int = field(
default=10, metadata={"help": "length of the mask for features (channels)"}
)
mask_channel_prob: float = field(
default=0.0, metadata={"help": "probability of replacing a feature with 0"}
)
mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
default="static",
metadata={"help": "how to choose mask length for channel masking"},
)
mask_channel_other: float = field(
default=0,
metadata={
"help": "secondary mask argument (used for more complex distributions), "
"see help in compute_mask_indicesh"
},
)
no_mask_channel_overlap: bool = field(
default=False, metadata={"help": "whether to allow channel masks to overlap"}
)
freeze_finetune_updates: int = field(
default=0, metadata={"help": "dont finetune wav2vec for this many updates"}
)
feature_grad_mult: float = field(
default=0.0, metadata={"help": "reset feature grad mult in wav2vec 2.0 to this"}
)
layerdrop: float = field(
default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"}
)
mask_channel_min_space: Optional[int] = field(
default=1,
metadata={"help": "min space between spans (if no overlap is enabled)"},
)
mask_channel_before: bool = False
normalize: bool = II("task.normalize")
data: str = II("task.data")
# this holds the loaded wav2vec args
d2v_args: Any = None
offload_activations: bool = field(
default=False, metadata={"help": "offload_activations"}
)
min_params_to_wrap: int = field(
default=int(1e8),
metadata={
"help": "minimum number of params for a layer to be wrapped with FSDP() when "
"training with --ddp-backend=fully_sharded. Smaller values will "
"improve memory efficiency, but may make torch.distributed "
"communication less efficient due to smaller input sizes. This option "
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
},
)
checkpoint_activations: bool = field(
default=False,
metadata={"help": "recompute activations and save memory for extra compute"},
)
ddp_backend: str = II("distributed_training.ddp_backend")
prediction_mode: str = "lin_softmax"
eval_prediction_mode: Optional[str] = None
conv_kernel: int = -1
conv_stride: int = 1
two_convs: bool = False
extreme_factor: float = 1.0
conv_feature_layers: Optional[str] = field(
default=None,
metadata={
"help": "string describing convolutional feature extraction layers in form of a python list that contains "
"[(dim, kernel_size, stride), ...]"
},
)
mixup_prob: float = 1.0
source_mixup: float = -1
same_mixup: bool = True
label_mixup: bool = False
gain_mode: str = "none"
@register_model("audio_classification", dataclass=AudioClassificationConfig)
class AudioClassificationModel(BaseFairseqModel):
def __init__(self, cfg: AudioClassificationConfig, num_classes):
super().__init__()
self.apply_mask = cfg.apply_mask
self.cfg = cfg
arg_overrides = {
"dropout": cfg.dropout,
"activation_dropout": cfg.activation_dropout,
"dropout_input": cfg.dropout_input,
"attention_dropout": cfg.attention_dropout,
"mask_length": cfg.mask_length,
"mask_prob": cfg.mask_prob,
"require_same_masks": getattr(cfg, "require_same_masks", True),
"mask_dropout": getattr(cfg, "mask_dropout", 0),
"mask_selection": cfg.mask_selection,
"mask_other": cfg.mask_other,
"no_mask_overlap": cfg.no_mask_overlap,
"mask_channel_length": cfg.mask_channel_length,
"mask_channel_prob": cfg.mask_channel_prob,
"mask_channel_before": cfg.mask_channel_before,
"mask_channel_selection": cfg.mask_channel_selection,
"mask_channel_other": cfg.mask_channel_other,
"no_mask_channel_overlap": cfg.no_mask_channel_overlap,
"encoder_layerdrop": cfg.layerdrop,
"feature_grad_mult": cfg.feature_grad_mult,
"checkpoint_activations": cfg.checkpoint_activations,
"offload_activations": cfg.offload_activations,
"min_params_to_wrap": cfg.min_params_to_wrap,
"mixup": -1,
}
if cfg.conv_feature_layers is not None:
arg_overrides["conv_feature_layers"] = cfg.conv_feature_layers
if cfg.d2v_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu(
cfg.model_path, arg_overrides
)
d2v_args = state.get("cfg", None)
if d2v_args is None:
d2v_args = convert_namespace_to_omegaconf(state["args"])
d2v_args.criterion = None
d2v_args.lr_scheduler = None
cfg.d2v_args = d2v_args
logger.info(d2v_args)
else:
state = None
d2v_args = cfg.d2v_args
model_normalized = d2v_args.task.get(
"normalize", d2v_args.model.get("normalize", False)
)
assert cfg.normalize == model_normalized, (
"Fine-tuning works best when data normalization is the same. "
"Please check that --normalize is set or unset for both pre-training and here"
)
if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations:
with open_dict(d2v_args):
d2v_args.model.checkpoint_activations = cfg.checkpoint_activations
d2v_args.task.data = cfg.data
task = tasks.setup_task(d2v_args.task)
model = task.build_model(d2v_args.model, from_checkpoint=True)
model.remove_pretraining_modules()
if state is not None and not cfg.no_pretrained_weights:
self.load_model_weights(state, model, cfg)
d = d2v_args.model.encoder_embed_dim
self.d2v_model = model
self.final_dropout = nn.Dropout(cfg.final_dropout)
self.freeze_finetune_updates = cfg.freeze_finetune_updates
self.num_updates = 0
for p in self.parameters():
p.param_group = "pretrained"
if cfg.prediction_mode == "proj_avg_proj":
self.proj = nn.Linear(d, d * 2)
self.proj2 = nn.Linear(d * 2, num_classes)
for p in self.proj.parameters():
p.param_group = "projection"
for p in self.proj2.parameters():
p.param_group = "projection"
elif self.cfg.prediction_mode == "summary_proj":
self.proj = nn.Linear(d // 3, num_classes)
for p in self.proj.parameters():
p.param_group = "projection"
elif self.cfg.conv_kernel > 1 and not self.cfg.two_convs:
self.proj = nn.Sequential(
TransposeLast(),
nn.Conv1d(d, num_classes, kernel_size=self.cfg.conv_kernel, stride=self.cfg.conv_stride),
TransposeLast(),
)
for p in self.proj.parameters():
p.param_group = "projection"
elif self.cfg.conv_kernel > 0 and self.cfg.two_convs:
self.proj = nn.Sequential(
TransposeLast(),
nn.Conv1d(d, d, kernel_size=self.cfg.conv_kernel, stride=self.cfg.conv_stride),
TransposeLast(),
nn.GELU(),
nn.Linear(d, num_classes),
)
for p in self.proj.parameters():
p.param_group = "projection"
else:
self.proj = nn.Linear(d, num_classes)
for p in self.proj.parameters():
p.param_group = "projection"
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, cfg: AudioClassificationConfig, task: FairseqTask):
"""Build a new model instance."""
assert hasattr(task, "labels"), f"Task {task} must have an attribute 'labels'"
return cls(cfg, len(task.labels))
def load_model_weights(self, state, model, cfg):
if cfg.ddp_backend == "fully_sharded":
from fairseq.distributed import FullyShardedDataParallel
for name, module in model.named_modules():
if "encoder.layers" in name and len(name.split(".")) == 3:
# Only for layers, we do special handling and load the weights one by one.
# We don't load all weights together as that won't be memory efficient and may
# cause OOM
new_dict = {
k.replace(name + ".", ""): v
for (k, v) in state["model"].items()
if name + "." in k
}
assert isinstance(module, FullyShardedDataParallel)
with module.summon_full_params():
module.load_state_dict(new_dict, strict=True)
module._reset_lazy_init()
# Once layers are loaded, filter them out and load everything else.
r = re.compile(r"encoder.layers.\d.")
filtered_list = list(filter(r.match, state["model"].keys()))
new_big_dict = {
k: v for (k, v) in state["model"].items() if k not in filtered_list
}
model.load_state_dict(new_big_dict, strict=False)
else:
if "_ema" in state["model"]:
del state["model"]["_ema"]
model.load_state_dict(state["model"], strict=False)
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
super().set_num_updates(num_updates)
self.num_updates = num_updates
def compute_gain(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"):
if fs == 16000:
n_fft = 2048
elif fs == 44100:
n_fft = 4096
else:
raise Exception("Invalid fs {}".format(fs))
stride = n_fft // 2
def a_weight(fs, n_fft, min_db=-80.0):
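# standard IEC A-weighting curve evaluated at the FFT bin frequencies, returned
# in dB and floored at min_db; 20.6, 107.7, 737.9 and 12194 Hz are the usual
# A-weighting pole/zero frequencies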
freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
freq_sq = np.power(freq, 2)
freq_sq[0] = 1.0
weight = 2.0 + 20.0 * (
2 * np.log10(12194)
+ 2 * np.log10(freq_sq)
- np.log10(freq_sq + 12194 ** 2)
- np.log10(freq_sq + 20.6 ** 2)
- 0.5 * np.log10(freq_sq + 107.7 ** 2)
- 0.5 * np.log10(freq_sq + 737.9 ** 2)
)
weight = np.maximum(weight, min_db)
return weight
gain = []
for i in range(0, len(sound) - n_fft + 1, stride):
if mode == "RMSE":
g = np.mean(sound[i : i + n_fft] ** 2)
elif mode == "A_weighting":
spec = np.fft.rfft(np.hanning(n_fft + 1)[:-1] * sound[i : i + n_fft])
power_spec = np.abs(spec) ** 2
a_weighted_spec = power_spec * np.power(10, a_weight(fs, n_fft) / 10)
g = np.sum(a_weighted_spec)
else:
raise Exception("Invalid mode {}".format(mode))
gain.append(g)
gain = np.array(gain)
gain = np.maximum(gain, np.power(10, min_db / 10))
gain_db = 10 * np.log10(gain)
return gain_db
# adapted from https://github.com/mil-tokyo/bc_learning_sound/blob/master/utils.py
def compute_gain_torch(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"):
if fs == 16000:
n_fft = 2048
elif fs == 44100:
n_fft = 4096
else:
raise Exception("Invalid fs {}".format(fs))
if mode == "A_weighting":
if not hasattr(self, f"a_weight"):
self.a_weight = {}
if fs not in self.a_weight:
def a_weight(fs, n_fft, min_db=-80.0):
freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
freq_sq = freq ** 2
freq_sq[0] = 1.0
weight = 2.0 + 20.0 * (
2 * np.log10(12194)
+ 2 * np.log10(freq_sq)
- np.log10(freq_sq + 12194 ** 2)
- np.log10(freq_sq + 20.6 ** 2)
- 0.5 * np.log10(freq_sq + 107.7 ** 2)
- 0.5 * np.log10(freq_sq + 737.9 ** 2)
)
weight = np.maximum(weight, min_db)
return weight
self.a_weight[fs] = torch.from_numpy(
np.power(10, a_weight(fs, n_fft, min_db) / 10)
).to(device=sound.device)
sound = sound.unfold(-1, n_fft, n_fft // 2)
if mode == "RMSE":
sound = sound ** 2
g = sound.mean(-1)
elif mode == "A_weighting":
w = torch.hann_window(n_fft, device=sound.device) * sound
spec = torch.fft.rfft(w)
power_spec = spec.abs() ** 2
a_weighted_spec = power_spec * self.a_weight[fs]
g = a_weighted_spec.sum(-1)
else:
raise Exception("Invalid mode {}".format(mode))
gain = torch.maximum(g, torch.tensor(10 ** (min_db / 10), device=g.device))
gain_db = 10 * torch.log10(gain)
return gain_db
def forward(self, source, padding_mask, label=None, **kwargs):
if self.cfg.source_mixup >= 0 and self.training and self.cfg.mixup_prob > 0:
with torch.no_grad():
mixed_source = source
mix_mask = None
if self.cfg.mixup_prob < 1:
mix_mask = (
torch.empty((source.size(0),), device=source.device)
.bernoulli_(self.cfg.mixup_prob)
.bool()
)
mixed_source = source[mix_mask]
r = (
torch.FloatTensor(
1 if self.cfg.same_mixup else mixed_source.size(0)
)
.uniform_(max(1e-6, self.cfg.source_mixup), 1)
.to(dtype=source.dtype, device=source.device)
)
mixup_perm = torch.randperm(source.size(0))
s2 = source[mixup_perm]
if self.cfg.gain_mode == "none":
p = r.unsqueeze(-1)
if mix_mask is not None:
s2 = s2[mix_mask]
else:
if self.cfg.gain_mode == "naive_rms":
G1 = source.pow(2).mean(dim=-1).sqrt()
else:
G1, _ = self.compute_gain_torch(
source, mode=self.cfg.gain_mode
).max(-1)
G1 = G1.to(dtype=source.dtype)
G2 = G1[mixup_perm]
if mix_mask is not None:
G1 = G1[mix_mask]
G2 = G2[mix_mask]
s2 = s2[mix_mask]
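# BC-learning style mixing: convert the gain difference G1 - G2 (in dB) to a
# linear amplitude ratio and solve for the coefficient p that mixes the two
# sources at the requested loudness ratio r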
p = 1 / (1 + 10 ** ((G1 - G2) / 20) * (1 - r) / r)
p = p.unsqueeze(-1)
mixed = (p * mixed_source) + (1 - p) * s2
if mix_mask is None:
source = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)
else:
source[mix_mask] = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)
if label is not None and self.cfg.label_mixup:
r = r.unsqueeze(-1)
if mix_mask is None:
label = label * r + (1 - r) * label[mixup_perm]
else:
label[mix_mask] = (
label[mix_mask] * r + (1 - r) * label[mixup_perm][mix_mask]
)
d2v_args = {
"source": source,
"padding_mask": padding_mask,
"mask": self.apply_mask and self.training,
}
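# the pretrained encoder only receives gradients after freeze_finetune_updates
# steps; before that, feature extraction runs under torch.no_grad()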
ft = self.freeze_finetune_updates <= self.num_updates
with torch.no_grad() if not ft else contextlib.ExitStack():
res = self.d2v_model.extract_features(**d2v_args)
x = res["x"]
padding_mask = res["padding_mask"]
if padding_mask is not None:
x[padding_mask] = 0
x = self.final_dropout(x)
if self.training or (
self.cfg.eval_prediction_mode is None or self.cfg.eval_prediction_mode == ""
):
prediction_mode = self.cfg.prediction_mode
else:
prediction_mode = self.cfg.eval_prediction_mode
if prediction_mode == "average_before":
x = x.mean(dim=1)
if prediction_mode != "summary_mha" and prediction_mode != "summary_proj" and prediction_mode != "cls":
x = self.proj(x)
logits = True
if prediction_mode == "lin_softmax":
x = F.logsigmoid(x.float())
x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x, dim=1)
x = x.clamp(max=0)
x = x - torch.log(-(torch.expm1(x)))
elif prediction_mode == "extremized_odds":
x = x.float().sum(dim=1)
x = x * self.cfg.extreme_factor
elif prediction_mode == "average_before":
x = x.float()
elif prediction_mode == "average":
x = x.float().mean(dim=1)
elif prediction_mode == "average_sigmoid":
x = torch.sigmoid(x.float())
x = x.mean(dim=1)
logits = False
elif prediction_mode == "max":
x, _ = x.float().max(dim=1)
elif prediction_mode == "max_sigmoid":
x = torch.sigmoid(x.float())
x, _ = x.float().max(dim=1)
logits = False
elif prediction_mode == "proj_avg_proj":
x = x.mean(dim=1)
x = self.proj2(x)
elif prediction_mode == "summary_mha" or prediction_mode == "summary_proj":
x = self.d2v_model.summary(
x, padding_mask, proj=prediction_mode == "summary_proj"
)
x = x.type_as(source)
x = self.proj(x)
elif prediction_mode == "cls":
x = x[:,0]
x = self.proj(x)
else:
raise Exception(f"unknown prediction mode {prediction_mode}")
if label is None:
return torch.sigmoid(x) if logits else x
x = torch.nan_to_num(x)
if logits:
loss = F.binary_cross_entropy_with_logits(
x, label.float(), reduction="none"
)
else:
loss = F.binary_cross_entropy(x, label.float(), reduction="none")
result = {
"losses": {
"main": loss,
},
"sample_size": label.sum(),
}
if not self.training:
result["_predictions"] = torch.sigmoid(x) if logits else x
result["_targets"] = label
return result

View File

@ -0,0 +1,813 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from dataclasses import dataclass, field
from typing import Optional, Callable
from functools import partial
import numpy as np
from omegaconf import II
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from examples.data2vec.data.modality import Modality
from examples.data2vec.models.modalities.base import (
MaskSeed,
D2vModalityConfig,
ModalitySpecificEncoder,
get_annealed_rate,
)
from examples.data2vec.models.modalities.modules import (
D2vDecoderConfig,
AltBlock,
Decoder1d,
)
from examples.data2vec.models.modalities.audio import (
D2vAudioConfig,
AudioEncoder,
)
from examples.data2vec.models.modalities.images import (
D2vImageConfig,
ImageEncoder,
)
from examples.data2vec.models.modalities.text import (
D2vTextConfig,
TextEncoder,
)
logger = logging.getLogger(__name__)
@dataclass
class D2vModalitiesConfig(FairseqDataclass):
audio: D2vAudioConfig = D2vAudioConfig()
image: D2vImageConfig = D2vImageConfig()
text: D2vTextConfig = D2vTextConfig()
@dataclass
class Data2VecMultiConfig(FairseqDataclass):
loss_beta: float = field(
default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
)
loss_scale: Optional[float] = field(
default=None,
metadata={
"help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
},
)
depth: int = 8
start_drop_path_rate: float = 0
end_drop_path_rate: float = 0
num_heads: int = 12
norm_eps: float = 1e-6
norm_affine: bool = True
encoder_dropout: float = 0.1
post_mlp_drop: float = 0.1
attention_dropout: float = 0.1
activation_dropout: float = 0.0
dropout_input: float = 0.0
layerdrop: float = 0.0
embed_dim: int = 768
mlp_ratio: float = 4
layer_norm_first: bool = False
average_top_k_layers: int = field(
default=8, metadata={"help": "how many layers to average"}
)
end_of_block_targets: bool = False
clone_batch: int = 1
layer_norm_target_layer: bool = False
batch_norm_target_layer: bool = False
instance_norm_target_layer: bool = False
instance_norm_targets: bool = False
layer_norm_targets: bool = False
ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
ema_same_dtype: bool = True
log_norms: bool = True
ema_end_decay: float = field(
default=0.9999, metadata={"help": "final ema decay rate"}
)
# when to finish annealing ema decay rate
ema_anneal_end_step: int = II("optimization.max_update")
ema_encoder_only: bool = field(
default=True,
metadata={
"help": "whether to momentum update only the shared transformer encoder"
},
)
max_update: int = II("optimization.max_update")
modalities: D2vModalitiesConfig = D2vModalitiesConfig()
shared_decoder: Optional[D2vDecoderConfig] = None
min_target_var: float = field(
default=0.1, metadata={"help": "stop training if target var falls below this"}
)
min_pred_var: float = field(
default=0.01,
metadata={"help": "stop training if prediction var falls below this"},
)
supported_modality: Optional[Modality] = None
mae_init: bool = False
seed: int = II("common.seed")
skip_ema: bool = False
cls_loss: float = 0
recon_loss: float = 0
d2v_loss: float = 1
decoder_group: bool = False
@register_model("data2vec_multi", dataclass=Data2VecMultiConfig)
class Data2VecMultiModel(BaseFairseqModel):
def make_modality_encoder(
self,
cfg: D2vModalityConfig,
embed_dim: int,
make_block: Callable[[float], nn.ModuleList],
norm_layer: Callable[[int], nn.LayerNorm],
layer_norm_first: bool,
alibi_biases,
task,
) -> ModalitySpecificEncoder:
if cfg.type == Modality.AUDIO:
enc_cls = AudioEncoder
elif cfg.type == Modality.IMAGE:
enc_cls = ImageEncoder
elif cfg.type == Modality.TEXT:
enc_cls = TextEncoder
if hasattr(task, "text_task"):
task = task.text_task
else:
raise Exception(f"unsupported modality {cfg.type}")
return enc_cls(
cfg,
embed_dim,
make_block,
norm_layer,
layer_norm_first,
alibi_biases,
task,
)
def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None):
super().__init__()
self.cfg = cfg
self.modalities = modalities
self.task = task
make_layer_norm = partial(
nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
)
def make_block(drop_path, dim=None, heads=None):
return AltBlock(
cfg.embed_dim if dim is None else dim,
cfg.num_heads if heads is None else heads,
cfg.mlp_ratio,
qkv_bias=True,
drop=cfg.encoder_dropout,
attn_drop=cfg.attention_dropout,
mlp_drop=cfg.activation_dropout,
post_mlp_drop=cfg.post_mlp_drop,
drop_path=drop_path,
norm_layer=make_layer_norm,
layer_norm_first=cfg.layer_norm_first,
ffn_targets=not cfg.end_of_block_targets,
)
self.alibi_biases = {}
self.modality_encoders = nn.ModuleDict()
for mod in self.modalities:
mod_cfg = getattr(cfg.modalities, mod.name.lower())
enc = self.make_modality_encoder(
mod_cfg,
cfg.embed_dim,
make_block,
make_layer_norm,
cfg.layer_norm_first,
self.alibi_biases,
task,
)
self.modality_encoders[mod.name] = enc
self.ema = None
self.average_top_k_layers = cfg.average_top_k_layers
self.loss_beta = cfg.loss_beta
self.loss_scale = cfg.loss_scale
self.dropout_input = nn.Dropout(cfg.dropout_input)
dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
self.norm = None
if cfg.layer_norm_first:
self.norm = make_layer_norm(cfg.embed_dim)
if self.cfg.mae_init:
self.apply(self._init_weights)
else:
from fairseq.modules.transformer_sentence_encoder import init_bert_params
self.apply(init_bert_params)
for mod_enc in self.modality_encoders.values():
mod_enc.reset_parameters()
if not skip_ema:
self.ema = self.make_ema_teacher(cfg.ema_decay)
self.shared_decoder = (
Decoder1d(cfg.shared_decoder, cfg.embed_dim)
if self.cfg.shared_decoder is not None
else None
)
if self.shared_decoder is not None:
self.shared_decoder.apply(self._init_weights)
self.recon_proj = None
if cfg.recon_loss > 0:
self.recon_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
for pn, p in self.named_parameters():
if len(p.shape) == 1 or pn.endswith(".bias") or "alibi_scale" in pn:
p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
if cfg.decoder_group and "decoder" in pn:
p.param_group = "decoder"
self.num_updates = 0
def _init_weights(self, m):
try:
from apex.normalization import FusedLayerNorm
fn = FusedLayerNorm
except ImportError:
fn = nn.LayerNorm
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm) or isinstance(m, fn):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
@torch.no_grad()
def make_ema_teacher(self, ema_decay):
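# the teacher is an EMA-tracked copy of the student; with ema_encoder_only=True
# only the shared transformer blocks are tracked (see make_target_model below)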
ema_config = EMAModuleConfig(
ema_decay=ema_decay,
ema_fp32=True,
log_norms=self.cfg.log_norms,
add_missing_params=False,
)
model_copy = self.make_target_model()
return EMAModule(
model_copy,
ema_config,
copy_model=False,
)
def make_target_model(self):
logger.info("making target model")
model_copy = Data2VecMultiModel(
self.cfg, self.modalities, skip_ema=True, task=self.task
)
if self.cfg.ema_encoder_only:
model_copy = model_copy.blocks
for p_s, p_t in zip(self.blocks.parameters(), model_copy.parameters()):
p_t.data.copy_(p_s.data)
else:
for p_s, p_t in zip(self.parameters(), model_copy.parameters()):
p_t.data.copy_(p_s.data)
for mod_enc in model_copy.modality_encoders.values():
mod_enc.decoder = None
if not mod_enc.modality_cfg.ema_local_encoder:
mod_enc.local_encoder = None
mod_enc.project_features = None
model_copy.requires_grad_(False)
return model_copy
def set_num_updates(self, num_updates):
super().set_num_updates(num_updates)
if self.ema is not None and (
(self.num_updates == 0 and num_updates > 1)
or self.num_updates >= num_updates
):
pass
elif self.training and self.ema is not None:
ema_weight_decay = None
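# linearly anneal the EMA decay from ema_decay to ema_end_decay over
# ema_anneal_end_step updates, then keep it fixed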
if self.cfg.ema_decay != self.cfg.ema_end_decay:
if num_updates >= self.cfg.ema_anneal_end_step:
decay = self.cfg.ema_end_decay
else:
decay = get_annealed_rate(
self.cfg.ema_decay,
self.cfg.ema_end_decay,
num_updates,
self.cfg.ema_anneal_end_step,
)
self.ema.set_decay(decay, weight_decay=ema_weight_decay)
if self.ema.get_decay() < 1:
self.ema.step(self.blocks if self.cfg.ema_encoder_only else self)
self.num_updates = num_updates
def state_dict(self, destination=None, prefix="", keep_vars=False):
state = super().state_dict(destination, prefix, keep_vars)
if self.ema is not None:
state[prefix + "_ema"] = self.ema.fp32_params
return state
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
k = prefix + "_ema"
if self.ema is not None:
assert k in state_dict
self.ema.restore(state_dict[k], True)
del state_dict[k]
elif k in state_dict:
del state_dict[k]
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
@classmethod
def build_model(cls, cfg: Data2VecMultiConfig, task=None):
"""Build a new model instance."""
if task is None or not hasattr(task, "supported_modalities"):
modalities = (
[cfg.supported_modality]
if cfg.supported_modality is not None
else [
Modality.AUDIO,
Modality.IMAGE,
Modality.TEXT,
]
)
else:
modalities = task.supported_modalities
return cls(cfg, modalities, task=task, skip_ema=cfg.skip_ema)
def forward(
self,
source,
target=None,
id=None,
mode=None,
padding_mask=None,
mask=True,
features_only=False,
force_remove_masked=False,
remove_extra_tokens=True,
precomputed_mask=None,
):
if mode is None:
assert self.cfg.supported_modality is not None
mode = self.cfg.supported_modality
if isinstance(mode, Modality):
mode = mode.name
feature_extractor = self.modality_encoders[mode]
mask_seeds = None
if id is not None:
mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)
extractor_out = feature_extractor(
source,
padding_mask,
mask,
remove_masked=not features_only or force_remove_masked,
clone_batch=self.cfg.clone_batch if not features_only else 1,
mask_seeds=mask_seeds,
precomputed_mask=precomputed_mask,
)
x = extractor_out["x"]
encoder_mask = extractor_out["encoder_mask"]
masked_padding_mask = extractor_out["padding_mask"]
masked_alibi_bias = extractor_out.get("alibi_bias", None)
alibi_scale = extractor_out.get("alibi_scale", None)
if self.dropout_input is not None:
x = self.dropout_input(x)
layer_results = []
for i, blk in enumerate(self.blocks):
if (
not self.training
or self.cfg.layerdrop == 0
or (np.random.random() > self.cfg.layerdrop)
):
ab = masked_alibi_bias
if ab is not None and alibi_scale is not None:
scale = (
alibi_scale[i]
if alibi_scale.size(0) > 1
else alibi_scale.squeeze(0)
)
ab = ab * scale.type_as(ab)
x, lr = blk(
x,
padding_mask=masked_padding_mask,
alibi_bias=ab,
)
if features_only:
layer_results.append(lr)
if self.norm is not None:
x = self.norm(x)
if features_only:
if remove_extra_tokens:
x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
if masked_padding_mask is not None:
masked_padding_mask = masked_padding_mask[
:, feature_extractor.modality_cfg.num_extra_tokens :
]
return {
"x": x,
"padding_mask": masked_padding_mask,
"layer_results": layer_results,
"mask": encoder_mask,
}
xs = []
if self.shared_decoder is not None:
dx = self.forward_decoder(
x,
feature_extractor,
self.shared_decoder,
encoder_mask,
)
xs.append(dx)
if feature_extractor.decoder is not None:
dx = self.forward_decoder(
x,
feature_extractor,
feature_extractor.decoder,
encoder_mask,
)
xs.append(dx)
orig_x = x
assert len(xs) > 0
p = next(self.ema.model.parameters())
device = x.device
dtype = x.dtype
ema_device = p.device
ema_dtype = p.dtype
if not self.cfg.ema_same_dtype:
dtype = ema_dtype
if ema_device != device or ema_dtype != dtype:
logger.info(f"adjusting ema dtype to {dtype} and device to {device}")
self.ema.model = self.ema.model.to(dtype=dtype, device=device)
ema_dtype = dtype
def to_device(d):
for k, p in d.items():
if isinstance(d[k], dict):
to_device(d[k])
else:
d[k] = p.to(device=device)
to_device(self.ema.fp32_params)
tm = self.ema.model
with torch.no_grad():
tm.eval()
if self.cfg.ema_encoder_only:
assert target is None
ema_input = extractor_out["local_features"]
ema_input = feature_extractor.contextualized_features(
ema_input.to(dtype=ema_dtype),
padding_mask,
mask=False,
remove_masked=False,
)
ema_blocks = tm
else:
ema_blocks = tm.blocks
if feature_extractor.modality_cfg.ema_local_encoder:
inp = (
target.to(dtype=ema_dtype)
if target is not None
else source.to(dtype=ema_dtype)
)
ema_input = tm.modality_encoders[mode](
inp,
padding_mask,
mask=False,
remove_masked=False,
)
else:
assert target is None
ema_input = extractor_out["local_features"]
ema_feature_enc = tm.modality_encoders[mode]
ema_input = ema_feature_enc.contextualized_features(
ema_input.to(dtype=ema_dtype),
padding_mask,
mask=False,
remove_masked=False,
)
ema_padding_mask = ema_input["padding_mask"]
ema_alibi_bias = ema_input.get("alibi_bias", None)
ema_alibi_scale = ema_input.get("alibi_scale", None)
ema_input = ema_input["x"]
y = []
ema_x = []
extra_tokens = feature_extractor.modality_cfg.num_extra_tokens
for i, blk in enumerate(ema_blocks):
ab = ema_alibi_bias
if ab is not None and alibi_scale is not None:
scale = (
ema_alibi_scale[i]
if ema_alibi_scale.size(0) > 1
else ema_alibi_scale.squeeze(0)
)
ab = ab * scale.type_as(ab)
ema_input, lr = blk(
ema_input,
padding_mask=ema_padding_mask,
alibi_bias=ab,
)
y.append(lr[:, extra_tokens:])
ema_x.append(ema_input[:, extra_tokens:])
y = self.make_targets(y, self.average_top_k_layers)
orig_targets = y
if self.cfg.clone_batch > 1:
y = y.repeat_interleave(self.cfg.clone_batch, 0)
masked = encoder_mask.mask.unsqueeze(-1)
masked_b = encoder_mask.mask.bool()
y = y[masked_b]
if xs[0].size(1) == masked_b.size(1):
xs = [x[masked_b] for x in xs]
else:
xs = [x.reshape(-1, x.size(-1)) for x in xs]
sample_size = masked.sum().long()
result = {
"losses": {},
"sample_size": sample_size,
}
sample_size = result["sample_size"]
if self.cfg.cls_loss > 0:
assert extra_tokens > 0
cls_target = orig_targets.mean(dim=1)
if self.cfg.clone_batch > 1:
cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0)
cls_pred = x[:, extra_tokens - 1]
result["losses"]["cls"] = self.d2v_loss(cls_pred, cls_target) * (
self.cfg.cls_loss * sample_size
)
if self.cfg.recon_loss > 0:
with torch.no_grad():
target = feature_extractor.patchify(source)
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.0e-6) ** 0.5
if self.cfg.clone_batch > 1:
target = target.repeat_interleave(self.cfg.clone_batch, 0)
if masked_b is not None:
target = target[masked_b]
recon = xs[0]
if self.recon_proj is not None:
recon = self.recon_proj(recon)
result["losses"]["recon"] = (
self.d2v_loss(recon, target.float()) * self.cfg.recon_loss
)
if self.cfg.d2v_loss > 0:
for i, x in enumerate(xs):
reg_loss = self.d2v_loss(x, y)
n = f"{mode}_regression_{i}" if len(xs) > 1 else f"{mode}_regression"
result["losses"][n] = reg_loss * self.cfg.d2v_loss
suffix = "" if len(self.modalities) == 1 else f"_{mode}"
with torch.no_grad():
if encoder_mask is not None:
result["masked_pct"] = 1 - (
encoder_mask.ids_keep.size(1) / encoder_mask.ids_restore.size(1)
)
for i, x in enumerate(xs):
n = f"pred_var{suffix}_{i}" if len(xs) > 1 else f"pred_var{suffix}"
result[n] = self.compute_var(x.float())
if self.ema is not None:
for k, v in self.ema.logs.items():
result[k] = v
y = y.float()
result[f"target_var{suffix}"] = self.compute_var(y)
if self.num_updates > 5000:
if result[f"target_var{suffix}"] < self.cfg.min_target_var:
logger.error(
f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
)
raise Exception(
f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
)
for k in result.keys():
if k.startswith("pred_var") and result[k] < self.cfg.min_pred_var:
logger.error(
f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
)
raise Exception(
f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
)
result["ema_decay"] = self.ema.get_decay() * 1000
return result
def forward_decoder(
self,
x,
feature_extractor,
decoder,
mask_info,
):
x = feature_extractor.decoder_input(x, mask_info)
x = decoder(*x)
return x
def d2v_loss(self, x, y):
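# per-element regression loss between student predictions and teacher targets:
# L2 (or smooth L1 if loss_beta > 0), scaled by loss_scale or 1/sqrt(dim)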
x = x.view(-1, x.size(-1)).float()
y = y.view(-1, x.size(-1))
if self.loss_beta == 0:
loss = F.mse_loss(x, y, reduction="none")
else:
loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta)
if self.loss_scale is not None:
scale = self.loss_scale
else:
scale = 1 / math.sqrt(x.size(-1))
reg_loss = loss * scale
return reg_loss
def make_targets(self, y, num_layers):
with torch.no_grad():
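# targets are the average of the top num_layers teacher layer outputs, each
# optionally batch/instance/layer normalized before averaging; the averaged
# target can be normalized again via layer_norm_targets / instance_norm_targets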
target_layer_results = y[-num_layers:]
permuted = False
if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
target_layer_results = [
tl.transpose(1, 2) for tl in target_layer_results # BTC -> BCT
]
permuted = True
if self.cfg.batch_norm_target_layer:
target_layer_results = [
F.batch_norm(
tl.float(), running_mean=None, running_var=None, training=True
)
for tl in target_layer_results
]
if self.cfg.instance_norm_target_layer:
target_layer_results = [
F.instance_norm(tl.float()) for tl in target_layer_results
]
if permuted:
target_layer_results = [
tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
]
if self.cfg.layer_norm_target_layer:
target_layer_results = [
F.layer_norm(tl.float(), tl.shape[-1:])
for tl in target_layer_results
]
y = target_layer_results[0].float()
for tl in target_layer_results[1:]:
y.add_(tl.float())
y = y.div_(len(target_layer_results))
if self.cfg.layer_norm_targets:
y = F.layer_norm(y, y.shape[-1:])
if self.cfg.instance_norm_targets:
y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)
return y
@staticmethod
def compute_var(y):
y = y.view(-1, y.size(-1))
if dist.is_initialized():
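# aggregate count, sum and sum of squares across workers, then compute the
# unbiased variance as (sum_sq - sum^2 / n) / (n - 1) per dimension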
zc = torch.tensor(y.size(0)).cuda()
zs = y.sum(dim=0)
zss = (y**2).sum(dim=0)
dist.all_reduce(zc)
dist.all_reduce(zs)
dist.all_reduce(zss)
var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
return torch.sqrt(var + 1e-6).mean()
else:
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
def extract_features(
self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
):
res = self.forward(
source,
mode=mode,
padding_mask=padding_mask,
mask=mask,
features_only=True,
remove_extra_tokens=remove_extra_tokens,
)
return res
def remove_pretraining_modules(self, modality=None, keep_decoder=False):
self.ema = None
self.cfg.clone_batch = 1
self.recon_proj = None
if not keep_decoder:
self.shared_decoder = None
modality = modality.lower() if modality is not None else None
for k in list(self.modality_encoders.keys()):
if modality is not None and k.lower() != modality:
del self.modality_encoders[k]
else:
self.modality_encoders[k].remove_pretraining_modules(
keep_decoder=keep_decoder
)
if not keep_decoder:
self.modality_encoders[k].decoder = None

View File

@ -0,0 +1,143 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit
import logging
from dataclasses import dataclass
from typing import Any
from omegaconf import II, MISSING
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, tasks
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
logger = logging.getLogger(__name__)
@dataclass
class Data2VecImageClassificationConfig(FairseqDataclass):
model_path: str = MISSING
no_pretrained_weights: bool = False
num_classes: int = 1000
mixup: float = 0.8
cutmix: float = 1.0
label_smoothing: float = 0.1
pretrained_model_args: Any = None
data: str = II("task.data")
@register_model(
"data2vec_image_classification", dataclass=Data2VecImageClassificationConfig
)
class Data2VecImageClassificationModel(BaseFairseqModel):
def __init__(self, cfg: Data2VecImageClassificationConfig):
super().__init__()
self.cfg = cfg
if cfg.pretrained_model_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
pretrained_args = state.get("cfg", None)
pretrained_args.criterion = None
pretrained_args.lr_scheduler = None
cfg.pretrained_model_args = pretrained_args
logger.info(pretrained_args)
else:
state = None
pretrained_args = cfg.pretrained_model_args
pretrained_args.task.data = cfg.data
task = tasks.setup_task(pretrained_args.task)
model = task.build_model(pretrained_args.model, from_checkpoint=True)
model.remove_pretraining_modules()
self.model = model
if state is not None and not cfg.no_pretrained_weights:
self.load_model_weights(state, model, cfg)
self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim)
self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)
self.head.weight.data.mul_(1e-3)
self.head.bias.data.mul_(1e-3)
self.mixup_fn = None
if cfg.mixup > 0 or cfg.cutmix > 0:
from timm.data import Mixup
self.mixup_fn = Mixup(
mixup_alpha=cfg.mixup,
cutmix_alpha=cfg.cutmix,
cutmix_minmax=None,
prob=1.0,
switch_prob=0.5,
mode="batch",
label_smoothing=cfg.label_smoothing,
num_classes=cfg.num_classes,
)
def load_model_weights(self, state, model, cfg):
if "_ema" in state["model"]:
del state["model"]["_ema"]
model.load_state_dict(state["model"], strict=True)
@classmethod
def build_model(cls, cfg: Data2VecImageClassificationConfig, task=None):
"""Build a new model instance."""
return cls(cfg)
def forward(
self,
img,
label=None,
):
if self.training and self.mixup_fn is not None and label is not None:
img, label = self.mixup_fn(img, label)
x = self.model(img, mask=False)
x = x[:, 1:]
x = self.fc_norm(x.mean(1))
x = self.head(x)
if label is None:
return x
if self.training and self.mixup_fn is not None:
loss = -label * F.log_softmax(x.float(), dim=-1)
else:
loss = F.cross_entropy(
x.float(),
label,
label_smoothing=self.cfg.label_smoothing if self.training else 0,
reduction="none",
)
result = {
"losses": {"regression": loss},
"sample_size": img.size(0),
}
if not self.training:
with torch.no_grad():
pred = x.argmax(-1)
correct = (pred == label).sum()
result["correct"] = correct
return result

View File

@ -0,0 +1,141 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit
import logging
from dataclasses import dataclass
from typing import Any
from omegaconf import II, MISSING
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, tasks
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.roberta.model import RobertaClassificationHead
from examples.data2vec.data.modality import Modality
logger = logging.getLogger(__name__)
@dataclass
class Data2VecTextClassificationConfig(FairseqDataclass):
pooler_dropout: float = 0.0
pooler_activation_fn: str = "tanh"
quant_noise_pq: int = 0
quant_noise_pq_block_size: int = 8
spectral_norm_classification_head: bool = False
model_path: str = MISSING
no_pretrained_weights: bool = False
pretrained_model_args: Any = None
@register_model(
"data2vec_text_classification", dataclass=Data2VecTextClassificationConfig
)
class Data2VecTextClassificationModel(BaseFairseqModel):
def __init__(self, cfg: Data2VecTextClassificationConfig):
super().__init__()
self.cfg = cfg
if cfg.pretrained_model_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
pretrained_args = state.get("cfg", None)
pretrained_args.criterion = None
pretrained_args.lr_scheduler = None
cfg.pretrained_model_args = pretrained_args
logger.info(pretrained_args)
else:
state = None
pretrained_args = cfg.pretrained_model_args
task = tasks.setup_task(pretrained_args.task)
model = task.build_model(pretrained_args.model, from_checkpoint=True)
model.remove_pretraining_modules()
self.model = model
if state is not None and not cfg.no_pretrained_weights:
self.load_model_weights(state, model, cfg)
self.classification_heads = nn.ModuleDict()
def load_model_weights(self, state, model, cfg):
for k in list(state["model"].keys()):
if (
k.startswith("shared_decoder") or
k.startswith("_ema") or
"decoder" in k
):
logger.info(f"Deleting {k} from checkpoint")
del state["model"][k]
model.load_state_dict(state["model"], strict=True)
@classmethod
def build_model(cls, cfg: Data2VecTextClassificationConfig, task=None):
"""Build a new model instance."""
return cls(cfg)
def register_classification_head(
self, name, num_classes=None, inner_dim=None, **kwargs
):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
prev_inner_dim = self.classification_heads[name].dense.out_features
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
"and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
embed_dim = self.cfg.pretrained_model_args.model.embed_dim
self.classification_heads[name] = RobertaClassificationHead(
input_dim=embed_dim,
inner_dim=inner_dim or embed_dim,
num_classes=num_classes,
activation_fn=self.cfg.pooler_activation_fn,
pooler_dropout=self.cfg.pooler_dropout,
q_noise=self.cfg.quant_noise_pq,
qn_block_size=self.cfg.quant_noise_pq_block_size,
do_spectral_norm=self.cfg.spectral_norm_classification_head,
)
def forward(
self,
source,
id,
padding_mask,
features_only=True,
remove_extra_tokens=True,
classification_head_name=None,
):
encoder_out = self.model(
source,
id=id,
mode=Modality.TEXT,
padding_mask=padding_mask,
mask=False,
features_only=features_only,
remove_extra_tokens=remove_extra_tokens
)
logits = self.classification_heads[classification_head_name](encoder_out["x"])
return logits, encoder_out

View File

@ -0,0 +1,727 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit
import logging
import math
import numpy as np
import random
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import II
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
logger = logging.getLogger(__name__)
@dataclass
class Data2VecVisionConfig(FairseqDataclass):
layer_scale_init_value: float = field(
default=1e-4, metadata={"help": "rescale layer outputs, 0 to disable"}
)
num_mask_patches: int = field(
default=75,
metadata={"help": "number of the visual tokens/patches need be masked"},
)
min_mask_patches_per_block: int = 16
max_mask_patches_per_block: int = 196
image_size: int = 224
patch_size: int = 16
in_channels: int = 3
shared_rel_pos_bias: bool = True
drop_path: float = 0.1
attention_dropout: float = 0.0
depth: int = 12
embed_dim: int = 768
num_heads: int = 12
mlp_ratio: int = 4
loss_beta: float = field(
default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
)
loss_scale: Optional[float] = field(
default=None,
metadata={
"help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
},
)
average_top_k_layers: int = field(
default=8, metadata={"help": "how many layers to average"}
)
end_of_block_targets: bool = True
layer_norm_target_layer: bool = False
instance_norm_target_layer: bool = False
batch_norm_target_layer: bool = False
instance_norm_targets: bool = False
layer_norm_targets: bool = False
ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
ema_end_decay: float = field(
default=0.9999, metadata={"help": "final ema decay rate"}
)
# when to finish annealing ema decay rate
ema_anneal_end_step: int = II("optimization.max_update")
ema_transformer_only: bool = field(
default=True,
metadata={"help": "whether to momentum update only the transformer layers"},
)
def get_annealed_rate(start, end, curr_step, total_steps):
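# linear interpolation between start and end as curr_step goes from 0 to total_steps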
r = end - start
pct_remaining = 1 - curr_step / total_steps
return end - r * pct_remaining
@register_model("data2vec_vision", dataclass=Data2VecVisionConfig)
class Data2VecVisionModel(BaseFairseqModel):
def __init__(self, cfg: Data2VecVisionConfig):
super().__init__()
self.cfg = cfg
self.ema = None
self.average_top_k_layers = cfg.average_top_k_layers
self.loss_beta = cfg.loss_beta
self.loss_scale = (
cfg.loss_scale
if cfg.loss_scale is not None
else 1 / math.sqrt(cfg.embed_dim)
)
self.patch_embed = PatchEmbed(
img_size=cfg.image_size,
patch_size=cfg.patch_size,
in_chans=cfg.in_channels,
embed_dim=cfg.embed_dim,
)
patch_size = self.patch_embed.patch_size
self.window_size = (
cfg.image_size // patch_size[0],
cfg.image_size // patch_size[1],
)
self.cls_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim))
self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim))
nn.init.trunc_normal_(self.cls_emb, std=0.02)
nn.init.trunc_normal_(self.mask_emb, std=0.02)
self.encoder = TransformerEncoder(cfg, self.patch_embed.patch_shape)
self.final_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
self.num_updates = 0
def make_ema_teacher(self):
ema_config = EMAModuleConfig(
ema_decay=self.cfg.ema_decay,
ema_fp32=True,
)
self.ema = EMAModule(
self.encoder if self.cfg.ema_transformer_only else self,
ema_config,
)
def set_num_updates(self, num_updates):
super().set_num_updates(num_updates)
if self.ema is None and self.final_proj is not None:
logger.info(f"making ema teacher")
self.make_ema_teacher()
elif self.training and self.ema is not None:
if self.cfg.ema_decay != self.cfg.ema_end_decay:
if num_updates >= self.cfg.ema_anneal_end_step:
decay = self.cfg.ema_end_decay
else:
decay = get_annealed_rate(
self.cfg.ema_decay,
self.cfg.ema_end_decay,
num_updates,
self.cfg.ema_anneal_end_step,
)
self.ema.set_decay(decay)
if self.ema.get_decay() < 1:
self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
self.num_updates = num_updates
def state_dict(self, destination=None, prefix="", keep_vars=False):
state = super().state_dict(destination, prefix, keep_vars)
if self.ema is not None:
state[prefix + "_ema"] = self.ema.fp32_params
return state
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
if self.ema is not None:
k = prefix + "_ema"
assert k in state_dict
self.ema.restore(state_dict[k], True)
del state_dict[k]
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
@classmethod
def build_model(cls, cfg: Data2VecVisionConfig, task=None):
"""Build a new model instance."""
return cls(cfg)
def make_mask(self, bsz, num_masks, min_masks, max_masks):
height, width = self.window_size
masks = np.zeros(shape=(bsz, height, width), dtype=int)
for i in range(bsz):
mask = masks[i]
mask_count = 0
min_aspect = 0.3
max_aspect = 1 / min_aspect
log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
def _mask(mask, max_mask_patches):
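# BEiT-style block-wise masking: sample a rectangle with random area and aspect
# ratio and mark any not-yet-masked patches inside it, up to max_mask_patches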
delta = 0
for attempt in range(10):
target_area = random.uniform(min_masks, max_mask_patches)
aspect_ratio = math.exp(random.uniform(*log_aspect_ratio))
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < width and h < height:
top = random.randint(0, height - h)
left = random.randint(0, width - w)
num_masked = mask[top : top + h, left : left + w].sum()
# Overlap
if 0 < h * w - num_masked <= max_mask_patches:
for i in range(top, top + h):
for j in range(left, left + w):
if mask[i, j] == 0:
mask[i, j] = 1
delta += 1
if delta > 0:
break
return delta
while mask_count < num_masks:
max_mask_patches = min(num_masks - mask_count, max_masks)
delta = _mask(mask, max_mask_patches)
if delta == 0:
break
else:
mask_count += delta
return torch.from_numpy(masks)
def forward(
self,
img,
mask: bool = True,
layer_results: bool = False,
):
x = self.patch_embed(img)
batch_size, seq_len, _ = x.size()
if mask:
mask_indices = self.make_mask(
img.size(0),
self.cfg.num_mask_patches,
self.cfg.min_mask_patches_per_block,
self.cfg.max_mask_patches_per_block,
)
bool_mask = mask_indices.view(mask_indices.size(0), -1).bool()
else:
mask_indices = bool_mask = None
cls_tokens = self.cls_emb.expand(batch_size, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
if self.ema is not None:
with torch.no_grad():
self.ema.model.eval()
if self.cfg.ema_transformer_only:
y = self.ema.model(
x,
layer_results="end" if self.cfg.end_of_block_targets else "fc",
)
else:
y = self.ema.model(
img,
mask=False,
layer_results=True,
)
y = y[-self.cfg.average_top_k_layers :]
permuted = False
if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
y = [tl.transpose(1, 2) for tl in y] # BTC -> BCT
permuted = True
if self.cfg.batch_norm_target_layer:
y = [
F.batch_norm(
tl.float(), running_mean=None, running_var=None, training=True
)
for tl in y
]
if self.cfg.instance_norm_target_layer:
y = [F.instance_norm(tl.float()) for tl in y]
if permuted:
y = [tl.transpose(1, 2) for tl in y] # BCT -> BTC
if self.cfg.layer_norm_target_layer:
y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
y = sum(y) / len(y)
if self.cfg.layer_norm_targets:
y = F.layer_norm(y.float(), y.shape[-1:])
if self.cfg.instance_norm_targets:
y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
y = y[bool_mask].float()
if mask_indices is not None:
mask_token = self.mask_emb.expand(batch_size, seq_len, -1)
w = mask_indices.view(mask_indices.size(0), -1, 1).type_as(mask_token)
x[:, 1:] = x[:, 1:] * (1 - w) + mask_token * w
if layer_results:
enc_layer_results = "end" if self.cfg.end_of_block_targets else "fc"
else:
enc_layer_results = None
x = self.encoder(x, layer_results=enc_layer_results)
if layer_results or mask_indices is None:
return x
x = x[bool_mask].float()
if self.loss_beta == 0:
loss = F.mse_loss(x, y, reduction="none").sum(dim=-1)
else:
loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta).sum(
dim=-1
)
if self.loss_scale > 0:
loss = loss * self.loss_scale
result = {
"losses": {"regression": loss.sum()},
"sample_size": loss.numel(),
"target_var": self.compute_var(y),
"pred_var": self.compute_var(x),
"ema_decay": self.ema.get_decay() * 1000,
}
return result
@staticmethod
def compute_var(y):
y = y.view(-1, y.size(-1))
if dist.is_initialized():
zc = torch.tensor(y.size(0)).cuda()
zs = y.sum(dim=0)
zss = (y ** 2).sum(dim=0)
dist.all_reduce(zc)
dist.all_reduce(zs)
dist.all_reduce(zss)
var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
return torch.sqrt(var + 1e-6).mean()
else:
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
def remove_pretraining_modules(self, last_layer=None):
self.final_proj = None
self.ema = None
self.encoder.norm = nn.Identity()
self.mask_emb = None
if last_layer is not None:
self.encoder.layers = nn.ModuleList(
l for i, l in enumerate(self.encoder.layers) if i <= last_layer
)
class PatchEmbed(nn.Module):
"""Image to Patch Embedding"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
if isinstance(img_size, int):
img_size = img_size, img_size
if isinstance(patch_size, int):
patch_size = patch_size, patch_size
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.conv = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
# BCHW -> BTC
x = self.conv(x).flatten(2).transpose(1, 2)
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
attn_drop=0.0,
proj_drop=0.0,
window_size=None,
attn_head_dim=None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1
) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2,
dtype=relative_coords.dtype,
)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat(
(
self.q_bias,
torch.zeros_like(self.v_bias, requires_grad=False),
self.v_bias,
)
)
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if self.relative_position_bias_table is not None:
assert 1==2
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1,
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
print("attn.size() :", attn.size())
print("rel_pos_bias.size() :", rel_pos_bias.size())
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1
) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)
)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
def forward(self):
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1,
-1,
) # Wh*Ww,Wh*Ww,nH
print("self.window_size :", self.window_size)
print("self.num_relative_distance :", self.num_relative_distance)
print("self.relative_position_index :", self.relative_position_index.size(), self.relative_position_index)
print("relative_position_bias.size(), relative_position_bias :",relative_position_bias.size(), relative_position_bias)
print("self.relative_position_bias_table.size(), self.relative_position_bias_table :",self.relative_position_bias_table.size(), self.relative_position_bias_table)
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
output = x.div(keep_prob) * random_tensor
return output
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
init_values=None,
window_size=None,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
attn_drop=attn_drop,
proj_drop=drop,
window_size=window_size,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = nn.LayerNorm(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, dim),
nn.Dropout(drop),
)
if init_values > 0:
self.gamma_1 = nn.Parameter(
init_values * torch.ones((dim)), requires_grad=True
)
self.gamma_2 = nn.Parameter(
init_values * torch.ones((dim)), requires_grad=True
)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, rel_pos_bias=None):
print("inside block :", x.size())
if self.gamma_1 is None:
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
fc_feature = self.drop_path(self.mlp(self.norm2(x)))
x = x + fc_feature
else:
x = x + self.drop_path(
self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
)
fc_feature = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
x = x + fc_feature
return x, fc_feature
class TransformerEncoder(nn.Module):
def __init__(self, cfg: Data2VecVisionConfig, patch_shape):
super().__init__()
self.rel_pos_bias = None
if cfg.shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(
window_size=patch_shape, num_heads=cfg.num_heads
)
dpr = [
x.item() for x in torch.linspace(0, cfg.drop_path, cfg.depth)
] # stochastic depth decay rule
print("TransformerEncoder > patch_shape :", patch_shape)
self.blocks = nn.ModuleList(
Block(
dim=cfg.embed_dim,
num_heads=cfg.num_heads,
attn_drop=cfg.attention_dropout,
drop_path=dpr[i],
init_values=cfg.layer_scale_init_value,
window_size=patch_shape if not cfg.shared_rel_pos_bias else None,
)
for i in range(cfg.depth)
)
self.norm = nn.LayerNorm(cfg.embed_dim)
self.apply(self.init_weights)
self.fix_init_weight()
def init_weights(self, m):
std = 0.02
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=std)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=std)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp[2].weight.data, layer_id + 1)
def extract_features(self, x, layer_results):
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
z = []
for i, blk in enumerate(self.blocks):
x, fc_feature = blk(x, rel_pos_bias=rel_pos_bias)
if layer_results == "end":
z.append(x)
elif layer_results == "fc":
z.append(fc_feature)
return z if layer_results else self.norm(x)
def forward(self, x, layer_results=None):
x = self.extract_features(x, layer_results=layer_results)
if layer_results:
return [z[:, 1:] for z in x]
x = x[:, 1:]
return x

View File

@ -0,0 +1,825 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit
import logging
from dataclasses import dataclass
from functools import partial
from timm.models.vision_transformer import PatchEmbed, Block
import torch
import torch.nn as nn
import numpy as np
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer
from apex.normalization import FusedLayerNorm
import torch.nn.functional as F
logger = logging.getLogger(__name__)
@dataclass
class MaeConfig(FairseqDataclass):
input_size: int = 224
in_chans: int = 3
patch_size: int = 16
embed_dim: int = 768
depth: int = 12
num_heads: int = 12
decoder_embed_dim: int = 512
decoder_depth: int = 8
decoder_num_heads: int = 16
mlp_ratio: int = 4
norm_eps: float = 1e-6
drop_path_rate: float = 0.0
mask_ratio: float = 0.75
norm_pix_loss: bool = True
w2v_block: bool = False
alt_block: bool = False
alt_block2: bool = False
alt_attention: bool = False
block_dropout: float = 0
attention_dropout: float = 0
activation_dropout: float = 0
layer_norm_first: bool = False
fused_ln: bool = True
end_of_block_targets: bool = True
no_decoder_embed: bool = False
no_decoder_pos_embed: bool = False
mask_noise_std: float = 0
single_qkv: bool = False
use_rel_pos_bias: bool = False
shared_rel_pos_bias: bool = False  # referenced below when choosing per-block window sizes
no_cls: bool = False
def modify_relative_position_bias(orig_bias, bsz, mask):
if mask is None:
return orig_bias.unsqueeze(0).repeat(
bsz, 1, 1, 1
) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
heads, max_seq_len, max_seq_len = orig_bias.shape # includes CLS token
mask_for_rel_pos_bias = torch.cat(
(torch.zeros(bsz, 1, dtype=mask.dtype, device=mask.device), mask), dim=1
).bool() # bsz x seqlen (add CLS token)
unmasked_for_rel_pos_bias = ~mask_for_rel_pos_bias
unmasked_for_rel_pos_bias = unmasked_for_rel_pos_bias.unsqueeze(1).repeat(
1, heads, 1
) # bsz x seq_len => bsz x heads x seq_len
b_t_t_rel_pos_bias = orig_bias.unsqueeze(0).repeat(
bsz, 1, 1, 1
) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
unmasked_for_rel_pos_bias.unsqueeze(-1)
)
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, -1, max_seq_len)
new_len = b_t_t_rel_pos_bias.size(-2)
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
unmasked_for_rel_pos_bias.unsqueeze(-2)
)
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, new_len, new_len)
return b_t_t_rel_pos_bias
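# Illustrative sketch (not part of the original file): pruning a relative
# position bias down to the unmasked timesteps. The shapes are assumptions for
# a 2x2 patch grid plus a CLS token (5 positions, 4 heads, batch of 2); every
# sample must mask the same number of patches for the reshapes above to hold.
def _example_modify_relative_position_bias():
    bias = torch.randn(4, 5, 5)  # heads x seq_len x seq_len (includes CLS)
    mask = torch.tensor([[0, 1, 0, 1], [1, 0, 1, 0]])  # bsz x patches, 1 = masked
    pruned = modify_relative_position_bias(bias, 2, mask)
    assert pruned.shape == (2, 4, 3, 3)  # CLS plus the two kept patches per sample
    return pruned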
class AltBlock(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
layer_norm_first=True,
ffn_targets=False,
use_rel_pos_bias=False,
window_size=None,
alt_attention=False,
):
super().__init__()
self.layer_norm_first = layer_norm_first
self.ffn_targets = ffn_targets
from timm.models.vision_transformer import Attention, DropPath, Mlp
self.norm1 = norm_layer(dim)
self.use_rel_pos_bias = use_rel_pos_bias
if use_rel_pos_bias:
self.attn = AltAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
window_size=window_size,
)
else:
if alt_attention:
from .multi.modules import AltAttention as AltAttention2
self.attn = AltAttention2(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
else:
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
def forward(self, x, rel_pos_bias=None, pos_mask=None):
if self.layer_norm_first:
if self.use_rel_pos_bias:
x = x + self.drop_path(
self.attn(
self.norm1(x), rel_pos_bias=rel_pos_bias, pos_mask=pos_mask
)
)
else:
x = x + self.drop_path(self.attn(self.norm1(x)))
t = self.mlp(self.norm2(x))
x = x + self.drop_path(t)
if not self.ffn_targets:
t = x
return x, t
else:
if self.use_rel_pos_bias:
x = x + self.drop_path(
self.attn(x, rel_pos_bias=rel_pos_bias, pos_mask=pos_mask)
)
else:
x = x + self.drop_path(self.attn(x))
r = x = self.norm1(x)
x = self.mlp(x)
t = x
x = self.norm2(r + self.drop_path(x))
if not self.ffn_targets:
t = x
return x, t
class AltAttention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
window_size=None,
attn_head_dim=None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1
) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# cls to token, token to cls, and cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2,
dtype=relative_coords.dtype,
)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None, pos_mask=None):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat(
(
self.q_bias,
torch.zeros_like(self.v_bias, requires_grad=False),
self.v_bias,
)
)
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if self.relative_position_bias_table is not None:
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1,
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + modify_relative_position_bias(
relative_position_bias, x.size(0), pos_mask
)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1
) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)
)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
def forward(self):
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1,
-1,
) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float is an alias removed in newer numpy
omega /= embed_dim / 2.0
omega = 1.0 / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
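# Illustrative sketch (not part of the original file): building the fixed
# sin-cos table for a ViT-B/16 encoder on 224x224 inputs, i.e. a 14x14 patch
# grid with a prepended CLS row. The sizes are assumptions for illustration.
def _example_sincos_pos_embed():
    pos_embed = get_2d_sincos_pos_embed(768, 14, cls_token=True)
    assert pos_embed.shape == (1 + 14 * 14, 768)  # CLS row is all zeros
    return pos_embed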
def interpolate_pos_embed(model, checkpoint_model):
if "pos_embed" in checkpoint_model:
pos_embed_checkpoint = checkpoint_model["pos_embed"]
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print(
"Position interpolate from %dx%d to %dx%d"
% (orig_size, orig_size, new_size, new_size)
)
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(
-1, orig_size, orig_size, embedding_size
).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens,
size=(new_size, new_size),
mode="bicubic",
align_corners=False,
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model["pos_embed"] = new_pos_embed
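# Illustrative sketch (not part of the original file): typical use when loading
# a checkpoint pretrained at a different resolution than the current model.
# The checkpoint path is a placeholder.
def _example_interpolate_pos_embed(model):
    state = torch.load("/path/to/checkpoint.pt", map_location="cpu")
    interpolate_pos_embed(model, state["model"])  # resizes "pos_embed" in place
    model.load_state_dict(state["model"], strict=False)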
@register_model("mae", dataclass=MaeConfig)
class MaeModel(BaseFairseqModel):
def __init__(self, cfg: MaeConfig):
super().__init__()
self.cfg = cfg
self.mask_ratio = cfg.mask_ratio
# --------------------------------------------------------------------------
# MAE encoder specifics
self.patch_embed = PatchEmbed(
cfg.input_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg.embed_dim)) if not cfg.no_cls else None
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + int(not cfg.no_cls), cfg.embed_dim), requires_grad=False
) # fixed sin-cos embedding
norm_layer = partial(nn.LayerNorm, eps=cfg.norm_eps)
dpr = [
x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)
] # stochastic depth decay rule
def make_block(drop_path):
if cfg.w2v_block:
return TransformerSentenceEncoderLayer(
embedding_dim=cfg.embed_dim,
ffn_embedding_dim=cfg.embed_dim * cfg.mlp_ratio,
num_attention_heads=cfg.num_heads,
dropout=cfg.block_dropout,
attention_dropout=cfg.attention_dropout,
activation_dropout=cfg.activation_dropout,
activation_fn="gelu",
layer_norm_first=cfg.layer_norm_first,
drop_path=drop_path,
norm_eps=1e-6,
single_qkv=cfg.single_qkv,
fused_ln=cfg.fused_ln,
)
elif cfg.alt_block:
window_size = (
cfg.input_size // self.patch_embed.patch_size[0],
cfg.input_size // self.patch_embed.patch_size[1],
)
return AltBlock(
cfg.embed_dim,
cfg.num_heads,
cfg.mlp_ratio,
qkv_bias=True,
qk_scale=None,
norm_layer=norm_layer,
drop_path=drop_path,
layer_norm_first=cfg.layer_norm_first,
ffn_targets=not cfg.end_of_block_targets,
use_rel_pos_bias=cfg.use_rel_pos_bias,
window_size=window_size
if (self.cfg.use_rel_pos_bias and not self.cfg.shared_rel_pos_bias)
else None,
alt_attention=cfg.alt_attention,
)
elif cfg.alt_block2:
from .multi.modules import AltBlock as AltBlock2
return AltBlock2(
cfg.embed_dim,
cfg.num_heads,
cfg.mlp_ratio,
qkv_bias=True,
qk_scale=None,
norm_layer=norm_layer,
drop_path=drop_path,
layer_norm_first=cfg.layer_norm_first,
ffn_targets=not cfg.end_of_block_targets,
)
else:
return Block(
cfg.embed_dim,
cfg.num_heads,
cfg.mlp_ratio,
qkv_bias=True,
qk_scale=None,
norm_layer=norm_layer,
drop_path=drop_path,
)
self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
self.norm = norm_layer(cfg.embed_dim)
# --------------------------------------------------------------------------
# --------------------------------------------------------------------------
# MAE decoder specifics
self.decoder_embed = (
nn.Linear(cfg.embed_dim, cfg.decoder_embed_dim, bias=True)
if not cfg.no_decoder_embed
else None
)
self.mask_token = (
nn.Parameter(
torch.zeros(
1,
1,
cfg.decoder_embed_dim
if not cfg.no_decoder_embed
else cfg.embed_dim,
)
)
if cfg.mask_noise_std <= 0
else None
)
self.decoder_pos_embed = (
nn.Parameter(
torch.zeros(
1,
num_patches + 1,
cfg.decoder_embed_dim
if not cfg.no_decoder_embed
else cfg.embed_dim,
),
requires_grad=False,
)
if not cfg.no_decoder_pos_embed
else None
)
self.decoder_blocks = nn.ModuleList(
[
Block(
cfg.decoder_embed_dim,
cfg.decoder_num_heads,
cfg.mlp_ratio,
qkv_bias=True,
qk_scale=None,
norm_layer=norm_layer,
)
for _ in range(cfg.decoder_depth)
]
)
self.decoder_norm = norm_layer(cfg.decoder_embed_dim)
self.decoder_pred = nn.Linear(
cfg.decoder_embed_dim, cfg.patch_size ** 2 * cfg.in_chans, bias=True
) # decoder to patch
# --------------------------------------------------------------------------
self.norm_pix_loss = cfg.norm_pix_loss
self.initialize_weights()
for pn, p in self.named_parameters():
if len(p.shape) == 1 or pn.endswith(".bias"):
p.param_group = "no_decay"
else:
p.param_group = "with_decay"
def initialize_weights(self):
# initialization
# initialize (and freeze) pos_embed by sin-cos embedding
pos_embed = get_2d_sincos_pos_embed(
self.pos_embed.shape[-1],
int(self.patch_embed.num_patches ** 0.5),
cls_token=not self.cfg.no_cls,
)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
if self.decoder_pos_embed is not None:
decoder_pos_embed = get_2d_sincos_pos_embed(
self.decoder_pos_embed.shape[-1],
int(self.patch_embed.num_patches ** 0.5),
cls_token=not self.cfg.no_cls,
)
self.decoder_pos_embed.data.copy_(
torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
)
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
if self.cls_token is not None:
torch.nn.init.normal_(self.cls_token, std=0.02)
if self.mask_token is not None:
torch.nn.init.normal_(self.mask_token, std=0.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def patchify(self, imgs):
"""
imgs: (N, 3, H, W)
x: (N, L, patch_size**2 *3)
"""
p = self.patch_embed.patch_size[0]
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum("nchpwq->nhwpqc", x)
x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
return x
def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3)
imgs: (N, 3, H, W)
"""
p = self.patch_embed.patch_size[0]
h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
return imgs
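# Illustrative note (not part of the original code): patchify and unpatchify
# are exact inverses for square inputs whose side is divisible by the patch
# size, e.g. with p=16 a (2, 3, 224, 224) batch maps to (2, 196, 768) patch
# vectors and back:
#   patches = model.patchify(imgs)
#   assert torch.allclose(model.unpatchify(patches), imgs)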
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(
noise, dim=1
) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore  # x_masked contains only the kept (unmasked) tokens
@classmethod
def build_model(cls, cfg: MaeConfig, task=None):
"""Build a new model instance."""
return cls(cfg)
def forward_encoder(self, x, mask_ratio):
# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
# if self.cls_token is not None:
# x = x + self.pos_embed
# else:
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
if mask_ratio > 0:
x, mask, ids_restore = self.random_masking(x, mask_ratio)
else:
mask = ids_restore = None
# append cls token
if self.cls_token is not None:
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
if self.norm is not None:
x = self.norm(x)
return x, mask, ids_restore
def forward_decoder(self, x, ids_restore):
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
)
if self.cls_token is not None:
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
else:
x_ = torch.cat([x, mask_tokens], dim=1) # no cls token
x_ = torch.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
) # unshuffle
if self.cls_token is not None:
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
if self.cls_token is not None:
# remove cls token
x = x[:, 1:, :]
return x
def forward_loss(self, imgs, pred, mask):
"""
imgs: [N, 3, H, W]
pred: [N, L, p*p*3]
mask: [N, L], 0 is keep, 1 is remove,
"""
target = self.patchify(imgs)
if self.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.0e-6) ** 0.5
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
loss = (loss * mask).sum()
return loss, mask.sum()
def forward(self, imgs, predictions_only=False):
latent, mask, ids_restore = self.forward_encoder(
imgs, self.mask_ratio if not predictions_only else 0
)
if predictions_only:
return latent
pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
loss, sample_size = self.forward_loss(imgs, pred, mask)
result = {
"losses": {"regression": loss},
"sample_size": sample_size,
}
return result
def remove_pretraining_modules(self):
self.decoder_embed = None
self.decoder_blocks = None
self.decoder_norm = None
self.decoder_pos_embed = None
self.decoder_pred = None
self.mask_token = None
if self.cfg.layer_norm_first:
self.norm = None

View File

@ -0,0 +1,386 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit
import logging
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Optional
import numpy as np
from omegaconf import II, MISSING
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, tasks
from omegaconf import open_dict
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from .mae import interpolate_pos_embed
logger = logging.getLogger(__name__)
class PredictionMode(Enum):
MEAN_POOLING = auto()
CLS_TOKEN = auto()
LIN_SOFTMAX = auto()
@dataclass
class MaeImageClassificationConfig(FairseqDataclass):
model_path: str = MISSING
no_pretrained_weights: bool = False
linear_classifier: bool = False
num_classes: int = 1000
mixup: float = 0.8
cutmix: float = 1.0
label_smoothing: float = 0.1
drop_path_rate: float = 0.1
layer_decay: float = 0.65
mixup_prob: float = 1.0
mixup_switch_prob: float = 0.5
mixup_mode: str = "batch"
pretrained_model_args: Any = None
data: str = II("task.data")
norm_eps: Optional[float] = None
remove_alibi: bool = False
# regularization overwrites
encoder_dropout: float = 0
post_mlp_drop: float = 0
attention_dropout: float = 0
activation_dropout: float = 0.0
dropout_input: float = 0.0
layerdrop: float = 0.0
prenet_layerdrop: float = 0
prenet_dropout: float = 0
use_fc_norm: bool = True
prediction_mode: PredictionMode = PredictionMode.MEAN_POOLING
no_decay_blocks: bool = True
def get_layer_id_for_vit(name, num_layers):
"""
Assign a parameter with its layer id
Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
"""
if name in ["cls_token", "pos_embed"]:
return 0
elif name.startswith("patch_embed"):
return 0
elif name.startswith("rel_pos_bias"):
return num_layers - 1
elif name.startswith("blocks"):
return int(name.split(".")[1]) + 1
else:
return num_layers
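# Illustrative sketch (not part of the original file): how the BEiT-style layer
# ids translate into per-parameter lr scales, mirroring the layer_scales logic
# used below. The parameter name and the 12-block depth are assumptions.
def _example_layer_lr_scale(name="blocks.3.attn.qkv.weight", layer_decay=0.65, num_layers=13):
    layer_id = get_layer_id_for_vit(name, num_layers)  # 4 for the name above
    return layer_decay ** (num_layers - layer_id)  # 0.65 ** 9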
@register_model("mae_image_classification", dataclass=MaeImageClassificationConfig)
class MaeImageClassificationModel(BaseFairseqModel):
def __init__(self, cfg: MaeImageClassificationConfig):
super().__init__()
self.cfg = cfg
if cfg.pretrained_model_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
pretrained_args = state.get("cfg", None)
pretrained_args.criterion = None
pretrained_args.lr_scheduler = None
logger.info(pretrained_args.model)
with open_dict(pretrained_args.model):
pretrained_args.model.drop_path_rate = cfg.drop_path_rate
if cfg.norm_eps is not None:
pretrained_args.model.norm_eps = cfg.norm_eps
cfg.pretrained_model_args = pretrained_args
logger.info(pretrained_args)
else:
state = None
pretrained_args = cfg.pretrained_model_args
if "data" in pretrained_args.task:
pretrained_args.task.data = cfg.data
elif "image" in pretrained_args.task:
pretrained_args.task.image.data = cfg.data
if "modalities" in pretrained_args.model:
prenet_blocks = pretrained_args.model["modalities"]["image"]["prenet_depth"]
model_blocks = pretrained_args.model["depth"]
with open_dict(pretrained_args):
dpr = np.linspace(0, cfg.drop_path_rate, model_blocks).tolist()
pretrained_args.model["modalities"]["image"][
"start_drop_path_rate"
] = dpr[0]
pretrained_args.model["modalities"]["image"][
"end_drop_path_rate"
] = max(0, dpr[prenet_blocks - 1])
pretrained_args.model["start_drop_path_rate"] = dpr[prenet_blocks]
pretrained_args.model["end_drop_path_rate"] = dpr[-1]
if "mae_masking" in pretrained_args.model["modalities"]["image"]:
del pretrained_args.model["modalities"]["image"]["mae_masking"]
if cfg.remove_alibi:
pretrained_args.model["modalities"]["image"][
"use_alibi_encoder"
] = False
if (
state is not None
and "modality_encoders.IMAGE.alibi_bias" in state["model"]
):
del state["model"]["modality_encoders.IMAGE.alibi_bias"]
pretrained_args.model["encoder_dropout"] = cfg.encoder_dropout
pretrained_args.model["post_mlp_drop"] = cfg.post_mlp_drop
pretrained_args.model["attention_dropout"] = cfg.attention_dropout
pretrained_args.model["activation_dropout"] = cfg.activation_dropout
pretrained_args.model["dropout_input"] = cfg.dropout_input
pretrained_args.model["layerdrop"] = cfg.layerdrop
pretrained_args.model["modalities"]["image"][
"prenet_layerdrop"
] = cfg.prenet_layerdrop
pretrained_args.model["modalities"]["image"][
"prenet_dropout"
] = cfg.prenet_dropout
else:
# not d2v multi
with open_dict(pretrained_args):
pretrained_args.model["drop_path_rate"] = cfg.drop_path_rate
pretrained_args.model["block_dropout"] = cfg.encoder_dropout
pretrained_args.model["attention_dropout"] = cfg.attention_dropout
pretrained_args.model["activation_dropout"] = cfg.activation_dropout
task = tasks.setup_task(pretrained_args.task)
model = task.build_model(pretrained_args.model, from_checkpoint=True)
self.d2v_multi = "data2vec_multi" in pretrained_args.model._name
self.linear_classifier = cfg.linear_classifier
self.model = model
if state is not None and not cfg.no_pretrained_weights:
interpolate_pos_embed(model, state)
if "modality_encoders.IMAGE.positional_encoder.pos_embed" in state["model"]:
state["model"][
"modality_encoders.IMAGE.positional_encoder.positions"
] = state["model"][
"modality_encoders.IMAGE.positional_encoder.pos_embed"
]
del state["model"][
"modality_encoders.IMAGE.positional_encoder.pos_embed"
]
if "modality_encoders.IMAGE.encoder_mask" in state["model"]:
del state["model"]["modality_encoders.IMAGE.encoder_mask"]
model.load_state_dict(state["model"], strict=True)
if self.d2v_multi:
model.remove_pretraining_modules(modality="image")
else:
model.remove_pretraining_modules()
if self.linear_classifier:
model.requires_grad_(False)
self.fc_norm = None
if self.cfg.use_fc_norm:
self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim, eps=1e-6)
nn.init.constant_(self.fc_norm.bias, 0)
nn.init.constant_(self.fc_norm.weight, 1.0)
self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)
nn.init.trunc_normal_(self.head.weight, std=0.02)
nn.init.constant_(self.head.bias, 0)
self.mixup_fn = None
if cfg.mixup > 0 or cfg.cutmix > 0:
from timm.data import Mixup
self.mixup_fn = Mixup(
mixup_alpha=cfg.mixup,
cutmix_alpha=cfg.cutmix,
cutmix_minmax=None,
prob=cfg.mixup_prob,
switch_prob=cfg.mixup_switch_prob,
mode=cfg.mixup_mode,
label_smoothing=cfg.label_smoothing,
num_classes=cfg.num_classes,
)
if self.model.norm is not None:
for pn, p in self.model.norm.named_parameters():
if len(p.shape) == 1 or pn.endswith(".bias"):
p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
if self.fc_norm is not None:
for pn, p in self.fc_norm.named_parameters():
if len(p.shape) == 1 or pn.endswith(".bias"):
p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
for pn, p in self.head.named_parameters():
if len(p.shape) == 1 or pn.endswith(".bias"):
p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
if self.d2v_multi:
mod_encs = list(model.modality_encoders.values())
assert len(mod_encs) == 1, len(mod_encs)
blocks = list(mod_encs[0].context_encoder.blocks) + list(model.blocks)
else:
blocks = model.blocks
num_layers = len(blocks) + 1
layer_scales = list(
cfg.layer_decay ** (num_layers - i) for i in range(num_layers + 1)
)
if self.d2v_multi:
for n, p in self.model.named_parameters():
optimizer_override_dict = {}
if len(p.shape) == 1 or n.endswith(".bias"):
optimizer_override_dict["weight_decay_scale"] = 0
p.optim_overrides = {"optimizer": optimizer_override_dict}
if cfg.layer_decay > 0:
for i, b in enumerate(blocks):
lid = i + 1
if layer_scales[lid] == 1.0:
continue
for n, p in b.named_parameters():
optim_override = getattr(p, "optim_overrides", {})
if "optimizer" not in optim_override:
optim_override["optimizer"] = {}
if cfg.no_decay_blocks:
optim_override["optimizer"]["lr_scale"] = layer_scales[lid]
p.optim_overrides = optim_override
else:
optim_override["optimizer"] = {
"lr_scale": layer_scales[lid]
}
p.optim_overrides = optim_override
else:
for n, p in self.model.named_parameters():
optimizer_override_dict = {}
layer_id = get_layer_id_for_vit(n, num_layers)
if len(p.shape) == 1 or n.endswith(".bias"):
optimizer_override_dict["weight_decay_scale"] = 0
if cfg.layer_decay > 0:
optimizer_override_dict["lr_scale"] = layer_scales[layer_id]
p.optim_overrides = {"optimizer": optimizer_override_dict}
@classmethod
def build_model(cls, cfg: MaeImageClassificationConfig, task=None):
"""Build a new model instance."""
return cls(cfg)
def forward(
self,
imgs,
labels=None,
):
if self.training and self.mixup_fn is not None and labels is not None:
imgs, labels = self.mixup_fn(imgs, labels)
if self.linear_classifier:
with torch.no_grad():
x = self.model_forward(imgs)
else:
x = self.model_forward(imgs)
if self.cfg.prediction_mode == PredictionMode.MEAN_POOLING:
x = x.mean(dim=1)
elif self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
x = x[:, 0]
elif self.cfg.prediction_mode == PredictionMode.LIN_SOFTMAX:
dtype = x.dtype
x = F.logsigmoid(x.float())
x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x + 1e-6, dim=1)
x = x.clamp(max=0)
x = x - torch.log(-(torch.expm1(x)))
x = torch.nan_to_num(x, nan=0, posinf=0, neginf=0)
x = x.to(dtype=dtype)
else:
raise Exception(f"unknown prediction mode {self.cfg.prediction_mode.name}")
if self.fc_norm is not None:
x = self.fc_norm(x)
x = self.head(x)
if labels is None:
return x
if self.training and self.mixup_fn is not None:
loss = -labels * F.log_softmax(x.float(), dim=-1)
else:
loss = F.cross_entropy(
x.float(),
labels,
label_smoothing=self.cfg.label_smoothing if self.training else 0,
reduction="none",
)
result = {
"losses": {"regression": loss},
"sample_size": imgs.size(0),
}
if not self.training:
with torch.no_grad():
pred = x.argmax(-1)
correct = (pred == labels).sum()
result["correct"] = correct
return result
def model_forward(self, imgs):
if self.d2v_multi:
x = self.model.extract_features(
imgs,
mode="IMAGE",
mask=False,
remove_extra_tokens=(
self.cfg.prediction_mode != PredictionMode.CLS_TOKEN
),
)["x"]
else:
x = self.model(imgs, predictions_only=True)
if (
"no_cls" not in self.model.cfg or not self.model.cfg.no_cls
) and not self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
x = x[:, 1:]
return x

View File

@ -0,0 +1,192 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
import torch
import torch.nn as nn
import numpy as np
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional
from fairseq.models.wav2vec import ConvFeatureExtractionModel
from fairseq.modules import (
LayerNorm,
SamePad,
TransposeLast,
)
from fairseq.tasks import FairseqTask
from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
from .modules import BlockEncoder, Decoder1d
from examples.data2vec.data.modality import Modality
@dataclass
class D2vAudioConfig(D2vModalityConfig):
type: Modality = Modality.AUDIO
extractor_mode: str = "layer_norm"
feature_encoder_spec: str = field(
default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
metadata={
"help": "string describing convolutional feature extraction layers in form of a python list that contains "
"[(dim, kernel_size, stride), ...]"
},
)
conv_pos_width: int = field(
default=95,
metadata={"help": "number of filters for convolutional positional embeddings"},
)
conv_pos_groups: int = field(
default=16,
metadata={"help": "number of groups for convolutional positional embedding"},
)
conv_pos_depth: int = field(
default=5,
metadata={"help": "depth of positional encoder network"},
)
conv_pos_pre_ln: bool = False
class AudioEncoder(ModalitySpecificEncoder):
modality_cfg: D2vAudioConfig
def __init__(
self,
modality_cfg: D2vAudioConfig,
embed_dim: int,
make_block: Callable[[float], nn.ModuleList],
norm_layer: Callable[[int], nn.LayerNorm],
layer_norm_first: bool,
alibi_biases: Dict,
task: Optional[FairseqTask],
):
self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
feature_embed_dim = self.feature_enc_layers[-1][0]
local_encoder = ConvFeatureExtractionModel(
conv_layers=self.feature_enc_layers,
dropout=0.0,
mode=modality_cfg.extractor_mode,
conv_bias=False,
)
project_features = nn.Sequential(
TransposeLast(),
nn.LayerNorm(feature_embed_dim),
nn.Linear(feature_embed_dim, embed_dim),
)
num_pos_layers = modality_cfg.conv_pos_depth
k = max(3, modality_cfg.conv_pos_width // num_pos_layers)
positional_encoder = nn.Sequential(
TransposeLast(),
*[
nn.Sequential(
nn.Conv1d(
embed_dim,
embed_dim,
kernel_size=k,
padding=k // 2,
groups=modality_cfg.conv_pos_groups,
),
SamePad(k),
TransposeLast(),
LayerNorm(embed_dim, elementwise_affine=False),
TransposeLast(),
nn.GELU(),
)
for _ in range(num_pos_layers)
],
TransposeLast(),
)
if modality_cfg.conv_pos_pre_ln:
positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)
dpr = np.linspace(
modality_cfg.start_drop_path_rate,
modality_cfg.end_drop_path_rate,
modality_cfg.prenet_depth,
)
context_encoder = BlockEncoder(
nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
norm_layer(embed_dim) if not layer_norm_first else None,
layer_norm_first,
modality_cfg.prenet_layerdrop,
modality_cfg.prenet_dropout,
)
decoder = (
Decoder1d(modality_cfg.decoder, embed_dim)
if modality_cfg.decoder is not None
else None
)
alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
super().__init__(
modality_cfg=modality_cfg,
embed_dim=embed_dim,
local_encoder=local_encoder,
project_features=project_features,
fixed_positional_encoder=None,
relative_positional_encoder=positional_encoder,
context_encoder=context_encoder,
decoder=decoder,
get_alibi_bias=alibi_bias_fn,
)
def convert_padding_mask(self, x, padding_mask):
def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
return torch.floor((input_length - kernel_size) / stride + 1)
for i in range(len(self.feature_enc_layers)):
input_lengths = _conv_out_length(
input_lengths,
self.feature_enc_layers[i][1],
self.feature_enc_layers[i][2],
)
return input_lengths.to(torch.long)
if padding_mask is not None:
input_lengths = (1 - padding_mask.long()).sum(-1)
# apply conv formula to get real output_lengths
output_lengths = get_feat_extract_output_lengths(input_lengths)
if padding_mask.any():
padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)
# these two operations make sure that all values
# before the output length indices are attended to
padding_mask[
(
torch.arange(padding_mask.shape[0], device=padding_mask.device),
output_lengths - 1,
)
] = 1
padding_mask = (
1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
).bool()
else:
padding_mask = torch.zeros(
x.shape[:2], dtype=torch.bool, device=x.device
)
return padding_mask
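# Illustrative note (not part of the original code): with the default
# feature_encoder_spec [(512,10,5)] + [(512,3,2)]*4 + [(512,2,2)]*2, the
# per-layer formula floor((L - kernel) / stride + 1) maps a 1-second 16 kHz
# waveform (16000 samples) through 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49
# timesteps, so the padding mask is downsampled to roughly 50 frames per second.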
def reset_parameters(self):
super().reset_parameters()
for mod in self.project_features.children():
if isinstance(mod, nn.Linear):
mod.reset_parameters()
if self.decoder is not None:
self.decoder.reset_parameters()

View File

@ -0,0 +1,684 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple
from dataclasses import dataclass
from functools import partial
from omegaconf import MISSING, II
from typing import Optional, Callable
from fairseq.data.data_utils import compute_mask_indices
from fairseq.modules import GradMultiply
from fairseq.utils import index_put
from examples.data2vec.data.modality import Modality
from .modules import D2vDecoderConfig
logger = logging.getLogger(__name__)
@dataclass
class D2vModalityConfig:
type: Modality = MISSING
prenet_depth: int = 4
prenet_layerdrop: float = 0
prenet_dropout: float = 0
start_drop_path_rate: float = 0
end_drop_path_rate: float = 0
num_extra_tokens: int = 0
init_extra_token_zero: bool = True
mask_noise_std: float = 0.01
mask_prob_min: Optional[float] = None
mask_prob: float = 0.7
inverse_mask: bool = False
mask_prob_adjust: float = 0
keep_masked_pct: float = 0
mask_length: int = 5
add_masks: bool = False
remove_masks: bool = False
mask_dropout: float = 0.0
encoder_zero_mask: bool = True
mask_channel_prob: float = 0.0
mask_channel_length: int = 64
ema_local_encoder: bool = False # used in data2vec_multi
local_grad_mult: float = 1.0
use_alibi_encoder: bool = False
alibi_scale: float = 1.0
learned_alibi: bool = False
alibi_max_pos: Optional[int] = None
learned_alibi_scale: bool = False
learned_alibi_scale_per_head: bool = False
learned_alibi_scale_per_layer: bool = False
num_alibi_heads: int = II("model.num_heads")
model_depth: int = II("model.depth")
decoder: Optional[D2vDecoderConfig] = D2vDecoderConfig()
MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
class ModalitySpecificEncoder(nn.Module):
def __init__(
self,
modality_cfg: D2vModalityConfig,
embed_dim: int,
local_encoder: nn.Module,
project_features: nn.Module,
fixed_positional_encoder: Optional[nn.Module],
relative_positional_encoder: Optional[nn.Module],
context_encoder: nn.Module,
decoder: nn.Module,
get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
):
super().__init__()
self.modality_cfg = modality_cfg
self.local_encoder = local_encoder
self.project_features = project_features
self.fixed_positional_encoder = fixed_positional_encoder
self.relative_positional_encoder = relative_positional_encoder
self.context_encoder = context_encoder
self.decoder = decoder
self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None
self.local_grad_mult = self.modality_cfg.local_grad_mult
self.extra_tokens = None
if modality_cfg.num_extra_tokens > 0:
self.extra_tokens = nn.Parameter(
torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
)
if not modality_cfg.init_extra_token_zero:
nn.init.normal_(self.extra_tokens)
elif self.extra_tokens.size(1) > 1:
nn.init.normal_(self.extra_tokens[:, 1:])
self.alibi_scale = None
if self.get_alibi_bias is not None:
self.alibi_scale = nn.Parameter(
torch.full(
(
(modality_cfg.prenet_depth + modality_cfg.model_depth)
if modality_cfg.learned_alibi_scale_per_layer
else 1,
1,
self.modality_cfg.num_alibi_heads
if modality_cfg.learned_alibi_scale_per_head
else 1,
1,
1,
),
modality_cfg.alibi_scale,
dtype=torch.float,
),
requires_grad=modality_cfg.learned_alibi_scale,
)
if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
assert modality_cfg.alibi_max_pos is not None
alibi_bias = self.get_alibi_bias(
batch_size=1,
time_steps=modality_cfg.alibi_max_pos,
heads=modality_cfg.num_alibi_heads,
scale=1.0,
dtype=torch.float,
device="cpu",
)
self.alibi_bias = nn.Parameter(alibi_bias)
self.get_alibi_bias = partial(
_learned_alibi_bias, alibi_bias=self.alibi_bias
)
def upgrade_state_dict_named(self, state_dict, name):
k = f"{name}.alibi_scale"
if k in state_dict and state_dict[k].dim() == 4:
state_dict[k] = state_dict[k].unsqueeze(0)
return state_dict
def convert_padding_mask(self, x, padding_mask):
return padding_mask
def decoder_input(self, x, mask_info: MaskInfo):
inp_drop = self.modality_cfg.decoder.input_dropout
if inp_drop > 0:
x = F.dropout(x, inp_drop, training=self.training, inplace=True)
num_extra = self.modality_cfg.num_extra_tokens
if mask_info is not None:
num_masked = mask_info.ids_restore.shape[1] - x.shape[1] + num_extra
mask_tokens = x.new_empty(
x.size(0),
num_masked,
x.size(-1),
).normal_(0, self.modality_cfg.mask_noise_std)
x_ = torch.cat([x[:, num_extra:], mask_tokens], dim=1)
x = torch.gather(x_, dim=1, index=mask_info.ids_restore)
if self.modality_cfg.decoder.add_positions_masked:
assert self.fixed_positional_encoder is not None
pos = self.fixed_positional_encoder(x, None)
x = x + (pos * mask_info.mask.unsqueeze(-1))
else:
x = x[:, num_extra:]
if self.modality_cfg.decoder.add_positions_all:
assert self.fixed_positional_encoder is not None
x = x + self.fixed_positional_encoder(x, None)
return x, mask_info
def local_features(self, features):
if self.local_grad_mult > 0:
if self.local_grad_mult == 1.0:
x = self.local_encoder(features)
else:
x = GradMultiply.apply(
self.local_encoder(features), self.local_grad_mult
)
else:
with torch.no_grad():
x = self.local_encoder(features)
x = self.project_features(x)
return x
def contextualized_features(
self,
x,
padding_mask,
mask,
remove_masked,
clone_batch: int = 1,
mask_seeds: Optional[torch.Tensor] = None,
precomputed_mask=None,
):
if padding_mask is not None:
padding_mask = self.convert_padding_mask(x, padding_mask)
local_features = x
if mask and clone_batch == 1:
local_features = local_features.clone()
orig_B, orig_T, _ = x.shape
pre_mask_B = orig_B
mask_info = None
x_pos = None
if self.fixed_positional_encoder is not None:
x = x + self.fixed_positional_encoder(x, padding_mask)
if mask:
if clone_batch > 1:
x = x.repeat_interleave(clone_batch, 0)
if mask_seeds is not None:
clone_hash = [
int(hash((mask_seeds.seed, ind)) % 1e10)
for ind in range(clone_batch - 1)
]
clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)
id = mask_seeds.ids
id = id.repeat_interleave(clone_batch, 0)
id = id.view(-1, clone_batch) + clone_hash.to(id)
id = id.view(-1)
mask_seeds = MaskSeed(
seed=mask_seeds.seed, update=mask_seeds.update, ids=id
)
if padding_mask is not None:
padding_mask = padding_mask.repeat_interleave(clone_batch, 0)
x, mask_info = self.compute_mask(
x,
padding_mask,
mask_seed=mask_seeds,
apply=self.relative_positional_encoder is not None or not remove_masked,
precomputed_mask=precomputed_mask,
)
if self.relative_positional_encoder is not None:
x_pos = self.relative_positional_encoder(x)
masked_padding_mask = padding_mask
if mask and remove_masked:
x = mask_info.x_unmasked
if x_pos is not None:
x = x + gather_unmasked(x_pos, mask_info)
if padding_mask is not None and padding_mask.any():
masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info)
if not masked_padding_mask.any():
masked_padding_mask = None
else:
masked_padding_mask = None
elif x_pos is not None:
x = x + x_pos
alibi_bias = None
alibi_scale = self.alibi_scale
if self.get_alibi_bias is not None:
alibi_bias = self.get_alibi_bias(
batch_size=pre_mask_B,
time_steps=orig_T,
heads=self.modality_cfg.num_alibi_heads,
dtype=torch.float32,
device=x.device,
)
if alibi_scale is not None:
alibi_scale = alibi_scale.clamp_min(0)
if alibi_scale.size(0) == 1:
alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
alibi_scale = None
if clone_batch > 1:
alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)
if mask_info is not None and remove_masked:
alibi_bias = masked_alibi(alibi_bias, mask_info)
if self.extra_tokens is not None:
num = self.extra_tokens.size(1)
x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
if masked_padding_mask is not None:
# B x T
masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
if alibi_bias is not None:
# B x H x T x T
alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))
x = self.context_encoder(
x,
masked_padding_mask,
alibi_bias,
alibi_scale[: self.modality_cfg.prenet_depth]
if alibi_scale is not None
else None,
)
return {
"x": x,
"local_features": local_features,
"padding_mask": masked_padding_mask,
"alibi_bias": alibi_bias,
"alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
if alibi_scale is not None and alibi_scale.size(0) > 1
else alibi_scale,
"encoder_mask": mask_info,
}
def forward(
self,
features,
padding_mask,
mask: bool,
remove_masked: bool,
clone_batch: int = 1,
mask_seeds: Optional[torch.Tensor] = None,
precomputed_mask=None,
):
x = self.local_features(features)
return self.contextualized_features(
x,
padding_mask,
mask,
remove_masked,
clone_batch,
mask_seeds,
precomputed_mask,
)
def reset_parameters(self):
pass
def compute_mask(
self,
x,
padding_mask,
mask_seed: Optional[MaskSeed],
apply,
precomputed_mask,
):
if precomputed_mask is not None:
mask = precomputed_mask
mask_info = self.make_maskinfo(x, mask)
else:
B, T, C = x.shape
cfg = self.modality_cfg
mask_prob = cfg.mask_prob
if (
cfg.mask_prob_min is not None
and cfg.mask_prob_min >= 0
and cfg.mask_prob_min < mask_prob
):
mask_prob = np.random.uniform(cfg.mask_prob_min, mask_prob)
if mask_prob > 0:
if cfg.mask_length == 1:
mask_info = random_masking(x, mask_prob, mask_seed)
else:
if self.modality_cfg.inverse_mask:
mask_prob = 1 - mask_prob
mask = compute_mask_indices(
(B, T),
padding_mask,
mask_prob,
cfg.mask_length,
min_masks=1,
require_same_masks=True,
mask_dropout=cfg.mask_dropout,
add_masks=cfg.add_masks,
seed=mask_seed.seed if mask_seed is not None else None,
epoch=mask_seed.update if mask_seed is not None else None,
indices=mask_seed.ids if mask_seed is not None else None,
)
mask = torch.from_numpy(mask).to(device=x.device)
if self.modality_cfg.inverse_mask:
mask = 1 - mask
mask_info = self.make_maskinfo(x, mask)
else:
mask_info = None
if apply:
x = self.apply_mask(x, mask_info)
return x, mask_info
def make_maskinfo(self, x, mask, shape=None):
if shape is None:
B, T, D = x.shape
else:
B, T, D = shape
mask = mask.to(torch.uint8)
ids_shuffle = mask.argsort(dim=1)
ids_restore = ids_shuffle.argsort(dim=1).unsqueeze(-1).expand(-1, -1, D)
len_keep = T - mask[0].sum()
if self.modality_cfg.keep_masked_pct > 0:
len_keep += round((T - int(len_keep)) * self.modality_cfg.keep_masked_pct)
ids_keep = ids_shuffle[:, :len_keep]
if shape is not None:
x_unmasked = None
else:
ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
x_unmasked = torch.gather(x, dim=1, index=ids_keep)
mask_info = MaskInfo(
x_unmasked=x_unmasked,
mask=mask,
ids_restore=ids_restore,
ids_keep=ids_keep,
)
return mask_info
def apply_mask(self, x, mask_info):
cfg = self.modality_cfg
B, T, C = x.shape
if mask_info is not None:
mask = mask_info.mask
if cfg.encoder_zero_mask:
x = x * (1 - mask.type_as(x).unsqueeze(-1))
else:
num_masks = mask.sum().item()
masks = x.new_empty(num_masks, x.size(-1)).normal_(
0, cfg.mask_noise_std
)
x = index_put(x, mask, masks)
if cfg.mask_channel_prob > 0:
mask_channel = compute_mask_indices(
(B, C),
None,
cfg.mask_channel_prob,
cfg.mask_channel_length,
)
mask_channel = (
torch.from_numpy(mask_channel)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x = index_put(x, mask_channel, 0)
return x
def remove_pretraining_modules(self, keep_decoder=False):
if not keep_decoder:
self.decoder = None
def get_annealed_rate(start, end, curr_step, total_steps):
if curr_step >= total_steps:
return end
r = end - start
pct_remaining = 1 - curr_step / total_steps
return end - r * pct_remaining
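# Illustrative sketch (not part of the original file): the rate anneals linearly
# from `start` to `end`, e.g. an EMA decay warming up from 0.999 to 0.9999 over
# 30k updates is 0.99945 halfway through (the numbers are assumptions).
def _example_annealed_rate():
    return get_annealed_rate(0.999, 0.9999, curr_step=15000, total_steps=30000)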
# adapted from MAE
def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]):
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
generator = None
if mask_seed is not None:
seed = int(
hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6
)
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
noise = torch.rand(N, L, generator=generator, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = noise.argsort(dim=1) # ascend: small is keep, large is remove
ids_restore = ids_shuffle.argsort(dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
x_unmasked = torch.gather(x, dim=1, index=ids_keep)
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], dtype=x.dtype, device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D)
return MaskInfo(
x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep
)
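# Illustrative sketch (not part of the original file): per-sample random masking
# with a 75% ratio keeps the first quarter of the shuffled timesteps. Shapes are
# assumptions for illustration.
def _example_random_masking():
    x = torch.randn(2, 8, 16)
    info = random_masking(x, mask_ratio=0.75, mask_seed=None)
    # info.x_unmasked: (2, 2, 16) kept timesteps; info.mask: (2, 8) with 1 = masked
    return info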
def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
return torch.gather(
x,
dim=1,
index=mask_info.ids_keep,
)
def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
return torch.gather(
x,
dim=1,
index=mask_info.ids_keep[..., 0], # ignore the feature dimension
)
def get_alibi(
max_positions: int,
attention_heads: int,
dims: int = 1,
distance: str = "manhattan",
):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
# In the paper, we only train models that have 2^a heads for some
# a. This function has some good properties that only occur when
# the input is a power of 2. To maintain that even when the number
# of heads is not a power of 2, we use this workaround.
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
maxpos = max_positions
attn_heads = attention_heads
slopes = torch.Tensor(get_slopes(attn_heads))
if dims == 1:
# prepare the alibi linear position bias. Note that wav2vec2 is a non-
# autoregressive model, so we want a symmetric bias with 0 on the
# diagonal and otherwise linearly decreasing values
pos_bias = (
torch.abs(
torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
)
* -1
)
elif dims == 2:
if distance == "manhattan":
df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
elif distance == "euclidean":
df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
n = math.sqrt(max_positions)
assert n.is_integer(), n
n = int(n)
pos_bias = torch.zeros((max_positions, max_positions))
for i in range(n):
for j in range(n):
for k in range(n):
for l in range(n):
new_x = i * n + j
new_y = k * n + l
pos_bias[new_x, new_y] = -df(i, j, k, l)
else:
raise Exception(f"unsupported number of alibi dims: {dims}")
alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
attn_heads, -1, -1
)
return alibi_bias
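# Illustrative sketch (not part of the original file): a symmetric 1-d ALiBi
# bias for 4 heads over 6 positions; every head shares the same negative
# distance matrix (zeros on the diagonal) scaled by a head-specific slope.
def _example_get_alibi():
    bias = get_alibi(max_positions=6, attention_heads=4, dims=1)
    assert bias.shape == (4, 6, 6)
    return bias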
def get_alibi_bias(
alibi_biases,
batch_size,
time_steps,
heads,
dtype,
device,
dims=1,
distance="manhattan",
):
cache_key = f"{dims}_{heads}_{distance}"
buffered = alibi_biases.get(cache_key, None)
target_size = heads * batch_size
if (
buffered is None
or buffered.size(0) < target_size
or buffered.size(1) < time_steps
or buffered.dtype != dtype
or buffered.device != device
):
bt = max(time_steps, buffered.size(1) if buffered is not None else 0)
bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads
buffered = (
get_alibi(bt, heads, dims=dims, distance=distance)
.to(dtype=dtype, device=device)
.repeat(bn, 1, 1)
)
alibi_biases[cache_key] = buffered
b = buffered[:target_size, :time_steps, :time_steps]
b = b.view(batch_size, heads, time_steps, time_steps)
return b
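# Illustrative sketch (not part of the original file): biases are cached in a
# plain dict keyed by "{dims}_{heads}_{distance}" and sliced/reshaped per call.
def _example_get_alibi_bias():
    cache = {}
    b = get_alibi_bias(
        cache, batch_size=2, time_steps=10, heads=4,
        dtype=torch.float32, device=torch.device("cpu"),
    )
    assert b.shape == (2, 4, 10, 10) and "1_4_manhattan" in cache
    return b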
def _learned_alibi_bias(
alibi_bias,
batch_size,
time_steps,
heads,
scale,
dtype,
device,
):
assert alibi_bias.size(1) == heads, alibi_bias.shape
assert alibi_bias.dtype == dtype, alibi_bias.dtype
assert alibi_bias.device == device, alibi_bias.device
if alibi_bias.size(-1) < time_steps:
psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")
alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
return alibi_bias[..., :time_steps, :time_steps]
def masked_alibi(alibi_bias, mask_info):
H = alibi_bias.size(1)
orig_bias = alibi_bias
index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)
alibi_bias = torch.gather(
orig_bias,
dim=-2,
index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)),
)
alibi_bias = torch.gather(
alibi_bias,
dim=-1,
index=index.transpose(-1, -2).expand(-1, H, alibi_bias.size(-2), -1),
)
return alibi_bias

Some files were not shown because too many files have changed in this diff.