mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-05 13:17:39 +03:00
data2vec v2.0 (#4903)
data2v2c 2.0 Co-authored-by: Arun Babu <arbabu@fb.com> Co-authored-by: Wei-Ning Hsu <wnhsu@csail.mit.edu>
This commit is contained in:
parent
0f33ccf7cf
commit
d871f6169f
@ -1,3 +1,125 @@
|
||||
# data2vec 2.0
|
||||
|
||||
data2vec 2.0 improves the training efficiency of the original data2vec algorithm. We make the following improvements for efficiency considerations - we forward only the unmasked timesteps through the encoder, we use convolutional decoder and we use multimasking to amortize the compute overhead of the teacher model. You can find details in [Efficient Self-supervised Learning with Contextualized Target Representations for Vision, Speech and Language](https://ai.facebook.com/research/xyz)
|
||||
|
||||
## Pretrained and finetuned models
|
||||
### Vision
|
||||
| Model | Finetuning split | Link
|
||||
|---|---|---
|
||||
data2vec ViT-B | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_imagenet.pt)
|
||||
data2vec ViT-B | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_imagenet_ft.pt)
|
||||
data2vec ViT-L | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_imagenet.pt)
|
||||
data2vec ViT-L | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_imagenet_ft.pt)
|
||||
data2vec ViT-H | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/huge_imagenet.pt)
|
||||
data2vec ViT-H | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/huge_imagenet_ft.pt)
|
||||
|
||||
Vision models only are license under CC-BY-NC.
|
||||
### Speech
|
||||
|
||||
| Model | Finetuning split | Dataset | Link
|
||||
|---|---|---|---
|
||||
data2vec Base | No fine-tuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_libri.pt)
|
||||
data2vec Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_libri_960h.pt)
|
||||
data2vec Large | No fine-tuning | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_vox.pt)
|
||||
data2vec Large | 960 hours | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_vox_960h.pt)
|
||||
|
||||
### NLP
|
||||
|
||||
Model | Fine-tuning data | Dataset | Link
|
||||
|---|---|---|---|
|
||||
data2vec Base | No fine-tuning | Books + Wiki | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/nlp_base.pt)
|
||||
|
||||
[//]: # (## Data Preparation)
|
||||
|
||||
[//]: # ()
|
||||
[//]: # (### Vision)
|
||||
|
||||
[//]: # (add details)
|
||||
|
||||
[//]: # (### Speech)
|
||||
|
||||
[//]: # (add details)
|
||||
|
||||
[//]: # ()
|
||||
[//]: # (### NLP)
|
||||
|
||||
[//]: # (add details)
|
||||
|
||||
|
||||
## Commands to train different models using data2vec 2.0
|
||||
|
||||
### Vision
|
||||
|
||||
Commands to pretrain different model configurations
|
||||
```shell script
|
||||
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
|
||||
--config-name base_images_only_task task.data=/path/to/dir
|
||||
```
|
||||
|
||||
```shell script
|
||||
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
|
||||
--config-name large_images_only_task task.data=/path/to/dir
|
||||
```
|
||||
|
||||
```shell script
|
||||
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
|
||||
--config-name huge_images14_only_task task.data=/path/to/dir
|
||||
```
|
||||
|
||||
Commands to finetune different model configurations
|
||||
|
||||
```shell script
|
||||
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \
|
||||
--config-name mae_imagenet_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model
|
||||
```
|
||||
|
||||
```shell script
|
||||
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \
|
||||
--config-name mae_imagenet_large_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model
|
||||
```
|
||||
|
||||
```shell script
|
||||
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \
|
||||
--config-name mae_imagenet_huge_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model
|
||||
```
|
||||
|
||||
### Speech
|
||||
|
||||
```shell script
|
||||
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
|
||||
--config-name base_audio_only_task task.data=/path/to/manifests
|
||||
```
|
||||
|
||||
```shell script
|
||||
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
|
||||
--config-name large_audio_only_task task.data=/path/to/manifests
|
||||
```
|
||||
|
||||
Finetuning:
|
||||
|
||||
```shell script
|
||||
$ python fairseq_cli/hydra_train.py -m --config-dir examples/wav2vec/config/finetuning --config-name vox_10h \
|
||||
task.data=/path/to/manifests model.w2v_path=/path/to/pretrained/model common.user_dir=examples/data2vec
|
||||
```
|
||||
|
||||
Replace vox_10h with the right config depending on your model and fine-tuning split.
|
||||
See examples/wav2vec/config/finetuning for all available configs.
|
||||
|
||||
### NLP
|
||||
|
||||
Commands to pretrain
|
||||
```shell script
|
||||
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \
|
||||
--config-name base_text_only_task task.data=/path/to/file
|
||||
```
|
||||
|
||||
Commands to fine-tune all GLUE tasks
|
||||
```shell script
|
||||
$ task=cola # choose from [cola|qnli|mrpc|rte|sst_2|mnli|qqp|sts_b]
|
||||
$ lr=1e-5 # sweep [1e-5|2e-5|4e-5|6e-5] for each task
|
||||
$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2/text_finetuning \
|
||||
--config-name $task task.data=/path/to/file model.model_path=/path/to/pretrained/model "optimization.lr=[${lr}]"
|
||||
```
|
||||
|
||||
# data2vec
|
||||
|
||||
|
0
examples/data2vec/__init__.py
Normal file
0
examples/data2vec/__init__.py
Normal file
@ -0,0 +1,70 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
all_gather_list_size: 70000
|
||||
tensorboard_logdir: tb
|
||||
min_loss_scale: 1e-6
|
||||
|
||||
checkpoint:
|
||||
save_interval: 1
|
||||
no_epoch_checkpoints: true
|
||||
best_checkpoint_metric: mAP
|
||||
maximize_best_checkpoint_metric: true
|
||||
|
||||
task:
|
||||
_name: audio_classification
|
||||
data: ???
|
||||
normalize: true
|
||||
labels: lbl
|
||||
|
||||
dataset:
|
||||
num_workers: 6
|
||||
max_tokens: 2560000
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
valid_subset: eval
|
||||
validate_interval: 5
|
||||
|
||||
distributed_training:
|
||||
ddp_backend: legacy_ddp
|
||||
distributed_world_size: 8
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
can_sum: false
|
||||
log_keys:
|
||||
- _predictions
|
||||
- _targets
|
||||
|
||||
optimization:
|
||||
max_update: 30000
|
||||
lr: [0.00006] # scratch 53-5
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-08
|
||||
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 5000
|
||||
|
||||
model:
|
||||
_name: audio_classification
|
||||
model_path: ???
|
||||
apply_mask: true
|
||||
mask_prob: 0.6
|
||||
mask_length: 5 # scratch 1
|
||||
mask_channel_prob: 0
|
||||
mask_channel_length: 64
|
||||
layerdrop: 0.1
|
||||
dropout: 0.1
|
||||
activation_dropout: 0.1
|
||||
attention_dropout: 0.2
|
||||
feature_grad_mult: 0 # scratch 1
|
||||
label_mixup: true
|
||||
source_mixup: 0.5
|
||||
prediction_mode: lin_softmax # scratch average_sigmoid
|
||||
|
@ -0,0 +1,35 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 450
|
||||
nodes: 1
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,35 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 1
|
||||
tasks_per_node: 1
|
||||
mem_gb: 100
|
||||
nodes: 1
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb
|
||||
max_num_timeout: 30
|
@ -0,0 +1,35 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 450
|
||||
nodes: 2
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
91
examples/data2vec/config/audio/pretraining/audioset.yaml
Normal file
91
examples/data2vec/config/audio/pretraining/audioset.yaml
Normal file
@ -0,0 +1,91 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
min_loss_scale: 1e-6
|
||||
user_dir: /private/home/abaevski/fairseq-py/examples/data2vec
|
||||
|
||||
checkpoint:
|
||||
save_interval: 1
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
task:
|
||||
_name: audio_pretraining
|
||||
data: /private/home/abaevski/data/audioset
|
||||
max_sample_size: 320000
|
||||
min_sample_size: 32000
|
||||
normalize: true
|
||||
|
||||
dataset:
|
||||
num_workers: 6
|
||||
max_tokens: 3400000
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
validate_interval: 5
|
||||
required_batch_size_multiple: 1
|
||||
disable_validation: true
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 24
|
||||
ddp_backend: legacy_ddp
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- ema_decay
|
||||
- target_var
|
||||
- pred_var
|
||||
# - avg_self_attn
|
||||
# - weights
|
||||
|
||||
optimization:
|
||||
max_update: 200000
|
||||
lr: [0.0005]
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-06
|
||||
weight_decay: 0.01
|
||||
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 10000
|
||||
|
||||
model:
|
||||
_name: data2vec_audio
|
||||
extractor_mode: layer_norm
|
||||
encoder_layerdrop: 0.05
|
||||
dropout_input: 0.0
|
||||
dropout_features: 0.0
|
||||
feature_grad_mult: 1.0
|
||||
encoder_embed_dim: 768
|
||||
|
||||
mask_prob: 0.65
|
||||
mask_length: 10
|
||||
|
||||
loss_beta: 0
|
||||
loss_scale: null
|
||||
|
||||
instance_norm_target_layer: true
|
||||
layer_norm_targets: true
|
||||
average_top_k_layers: 12
|
||||
|
||||
self_attn_norm_type: deepnorm
|
||||
final_norm_type: deepnorm
|
||||
|
||||
pos_conv_depth: 5
|
||||
conv_pos: 95
|
||||
|
||||
ema_decay: 0.999
|
||||
ema_end_decay: 0.9999
|
||||
ema_anneal_end_step: 30000
|
||||
ema_transformer_only: true
|
||||
ema_layers_only: false
|
||||
|
||||
require_same_masks: true
|
||||
mask_dropout: 0
|
@ -0,0 +1,15 @@
|
||||
# @package _global_
|
||||
hydra:
|
||||
sweep:
|
||||
dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 1
|
||||
nprocs_per_node: 1
|
||||
distributed_port: -1
|
||||
|
||||
common:
|
||||
log_interval: 1
|
||||
|
||||
dataset:
|
||||
num_workers: 0
|
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 80
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 1
|
||||
mem_gb: 450
|
||||
nodes: 1
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 80
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 1
|
||||
mem_gb: 0
|
||||
nodes: 1
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 450
|
||||
nodes: 2
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- task.post_save_script
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 2
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 80
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 1
|
||||
mem_gb: 450
|
||||
nodes: 3
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 450
|
||||
nodes: 4
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- task.post_save_script
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 4
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 6
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 8
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
@ -0,0 +1,15 @@
|
||||
# @package _global_
|
||||
hydra:
|
||||
sweep:
|
||||
dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 1
|
||||
nprocs_per_node: 1
|
||||
distributed_port: -1
|
||||
|
||||
common:
|
||||
log_interval: 1
|
||||
|
||||
dataset:
|
||||
num_workers: 0
|
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: '_'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}/submitit
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 80
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 1
|
||||
mem_gb: 0
|
||||
nodes: 1
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec
|
||||
max_num_timeout: 30
|
||||
exclude: a100-st-p4d24xlarge-471
|
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 80
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 1
|
||||
mem_gb: 450
|
||||
nodes: 2
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: '_'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}/submitit
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 2
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec
|
||||
max_num_timeout: 30
|
||||
exclude: a100-st-p4d24xlarge-471
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 450
|
||||
nodes: 3
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 450
|
||||
nodes: 4
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,41 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: '_'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}/submitit
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 4
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec
|
||||
max_num_timeout: 30
|
||||
exclude: a100-st-p4d24xlarge-471
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 32
|
||||
ddp_backend: legacy_ddp
|
@ -0,0 +1,41 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: '_'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}/submitit
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 8
|
||||
name: pt
|
||||
partition: wav2vec
|
||||
max_num_timeout: 30
|
||||
exclude: a100-st-p4d24xlarge-471
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 64
|
||||
ddp_backend: legacy_ddp
|
113
examples/data2vec/config/v2/base_audio_only_task.yaml
Normal file
113
examples/data2vec/config/v2/base_audio_only_task.yaml
Normal file
@ -0,0 +1,113 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
min_loss_scale: 1e-6
|
||||
fp16_no_flatten_grads: false
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
checkpoint:
|
||||
save_interval: 1
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
task:
|
||||
_name: audio_pretraining
|
||||
data: /private/home/abaevski/data/librispeech/full
|
||||
max_sample_size: 320000
|
||||
min_sample_size: 32000
|
||||
normalize: true
|
||||
precompute_mask_config: {}
|
||||
|
||||
dataset:
|
||||
num_workers: 6
|
||||
max_tokens: 1000000
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
validate_interval: 5
|
||||
required_batch_size_multiple: 1
|
||||
disable_validation: true
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 8
|
||||
ddp_backend: legacy_ddp
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- ema_decay
|
||||
- target_var
|
||||
- pred_var
|
||||
- model_norm
|
||||
- ema_norm
|
||||
- masked_pct
|
||||
|
||||
optimization:
|
||||
max_update: 400000
|
||||
lr: [0.00075]
|
||||
debug_param_names: true
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [ 0.9,0.98 ]
|
||||
adam_eps: 1e-06
|
||||
weight_decay: 0.01
|
||||
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 8000
|
||||
|
||||
model:
|
||||
_name: data2vec_multi
|
||||
|
||||
loss_beta: 0
|
||||
loss_scale: null
|
||||
|
||||
depth: 12
|
||||
embed_dim: 768
|
||||
clone_batch: 8
|
||||
|
||||
ema_decay: 0.999
|
||||
ema_end_decay: 0.99999
|
||||
ema_anneal_end_step: 75000
|
||||
ema_encoder_only: false
|
||||
|
||||
average_top_k_layers: 8
|
||||
instance_norm_target_layer: true
|
||||
layer_norm_target_layer: false
|
||||
layer_norm_targets: false
|
||||
|
||||
layerdrop: 0.05
|
||||
norm_eps: 1e-5
|
||||
|
||||
supported_modality: AUDIO
|
||||
|
||||
modalities:
|
||||
audio:
|
||||
feature_encoder_spec: '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]'
|
||||
conv_pos_depth: 5
|
||||
conv_pos_width: 95
|
||||
conv_pos_groups: 16
|
||||
prenet_depth: 0
|
||||
mask_prob: 0.5
|
||||
mask_prob_adjust: 0.05
|
||||
inverse_mask: false
|
||||
mask_length: 5
|
||||
mask_noise_std: 0.01
|
||||
mask_dropout: 0
|
||||
add_masks: false
|
||||
ema_local_encoder: false
|
||||
use_alibi_encoder: true
|
||||
prenet_layerdrop: 0.05
|
||||
prenet_dropout: 0.1
|
||||
learned_alibi_scale: true
|
||||
learned_alibi_scale_per_head: true
|
||||
decoder:
|
||||
input_dropout: 0.1
|
||||
decoder_dim: 384
|
||||
decoder_groups: 16
|
||||
decoder_kernel: 7
|
||||
decoder_layers: 4
|
116
examples/data2vec/config/v2/base_images_only_task.yaml
Normal file
116
examples/data2vec/config/v2/base_images_only_task.yaml
Normal file
@ -0,0 +1,116 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
min_loss_scale: 1e-6
|
||||
fp16_no_flatten_grads: true
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
checkpoint:
|
||||
save_interval: 5
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
task:
|
||||
_name: mae_image_pretraining
|
||||
data: /datasets01/imagenet_full_size/061417/
|
||||
rebuild_batches: true
|
||||
local_cache_path: /scratch/cache_abaevski/imagenet
|
||||
key: source
|
||||
precompute_mask_config: {}
|
||||
|
||||
dataset:
|
||||
num_workers: 10
|
||||
batch_size: 16
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
required_batch_size_multiple: 1
|
||||
disable_validation: true
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 16
|
||||
ddp_backend: c10d
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- ema_decay
|
||||
- target_var
|
||||
- pred_var
|
||||
- model_norm
|
||||
- ema_norm
|
||||
- masked_pct
|
||||
|
||||
optimization:
|
||||
max_update: 375300
|
||||
lr: [ 0.001 ]
|
||||
debug_param_names: true
|
||||
clip_norm: 4
|
||||
|
||||
optimizer:
|
||||
_name: composite
|
||||
dynamic_groups: true
|
||||
groups:
|
||||
default:
|
||||
lr_float: 1e-3
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [0.9,0.95]
|
||||
weight_decay: 0.05
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 50040
|
||||
|
||||
lr_scheduler: pass_through
|
||||
|
||||
model:
|
||||
_name: data2vec_multi
|
||||
|
||||
ema_decay: 0.9998
|
||||
ema_end_decay: 0.99999
|
||||
ema_anneal_end_step: 100000
|
||||
instance_norm_target_layer: true
|
||||
layer_norm_target_layer: false
|
||||
layer_norm_targets: true
|
||||
end_of_block_targets: false
|
||||
|
||||
depth: 10
|
||||
average_top_k_layers: 10
|
||||
clone_batch: 16
|
||||
|
||||
norm_eps: 1e-6
|
||||
|
||||
min_target_var: 0
|
||||
min_pred_var: 0
|
||||
|
||||
encoder_dropout: 0
|
||||
post_mlp_drop: 0
|
||||
attention_dropout: 0
|
||||
activation_dropout: 0
|
||||
|
||||
supported_modality: IMAGE
|
||||
cls_loss: 0.01
|
||||
|
||||
ema_encoder_only: false
|
||||
|
||||
modalities:
|
||||
image:
|
||||
inverse_mask: true
|
||||
mask_prob: 0.8
|
||||
mask_prob_adjust: 0.07
|
||||
mask_length: 3
|
||||
mask_noise_std: 0.01
|
||||
prenet_depth: 2
|
||||
ema_local_encoder: true
|
||||
num_extra_tokens: 1
|
||||
init_extra_token_zero: false
|
||||
use_alibi_encoder: false
|
||||
decoder:
|
||||
decoder_dim: 768
|
||||
decoder_groups: 16
|
||||
decoder_kernel: 3
|
||||
decoder_layers: 6
|
||||
input_dropout: 0
|
112
examples/data2vec/config/v2/base_text_only_task.yaml
Normal file
112
examples/data2vec/config/v2/base_text_only_task.yaml
Normal file
@ -0,0 +1,112 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
fp16_no_flatten_grads: true
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
checkpoint:
|
||||
no_epoch_checkpoints: true
|
||||
save_interval_updates: 50000
|
||||
keep_interval_updates: 1
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 16
|
||||
ddp_backend: legacy_ddp
|
||||
|
||||
task:
|
||||
_name: masked_lm
|
||||
data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin
|
||||
sample_break_mode: none
|
||||
tokens_per_sample: 512
|
||||
include_target_tokens: true
|
||||
random_token_prob: 0
|
||||
leave_unmasked_prob: 0
|
||||
include_index: True
|
||||
skip_masking: True
|
||||
d2v2_multi: True
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- ema_decay
|
||||
- target_var
|
||||
- pred_var
|
||||
- model_norm
|
||||
- ema_norm
|
||||
- masked_pct
|
||||
|
||||
dataset:
|
||||
batch_size: 4
|
||||
ignore_unused_valid_subsets: true
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
disable_validation: true
|
||||
|
||||
optimization:
|
||||
clip_norm: 1
|
||||
lr: [0.0002]
|
||||
max_update: 1000000
|
||||
update_freq: [1]
|
||||
|
||||
optimizer:
|
||||
_name: composite
|
||||
dynamic_groups: true
|
||||
groups:
|
||||
default:
|
||||
lr_float: 0.0002
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [0.9,0.98]
|
||||
adam_eps: 1e-06
|
||||
weight_decay: 0.01
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 4000
|
||||
|
||||
lr_scheduler: pass_through
|
||||
|
||||
model:
|
||||
_name: data2vec_multi
|
||||
|
||||
loss_beta: 0
|
||||
loss_scale: 1
|
||||
|
||||
depth: 12
|
||||
embed_dim: 768
|
||||
clone_batch: 8
|
||||
|
||||
ema_decay: 0.9999
|
||||
ema_end_decay: 0.99999
|
||||
ema_anneal_end_step: 100000
|
||||
ema_encoder_only: true
|
||||
|
||||
average_top_k_layers: 12
|
||||
layer_norm_target_layer: false
|
||||
instance_norm_target_layer: true
|
||||
batch_norm_target_layer: false
|
||||
instance_norm_targets: false
|
||||
layer_norm_targets: false
|
||||
|
||||
layerdrop: 0
|
||||
norm_eps: 1e-5
|
||||
|
||||
supported_modality: TEXT
|
||||
|
||||
modalities:
|
||||
text:
|
||||
mask_prob: 0.48
|
||||
mask_length: 1
|
||||
mask_noise_std: 0.01
|
||||
prenet_depth: 0
|
||||
decoder:
|
||||
input_dropout: 0.1
|
||||
decoder_dim: 768
|
||||
decoder_groups: 1
|
||||
decoder_kernel: 9
|
||||
decoder_layers: 5
|
||||
decoder_residual: false
|
||||
projection_layers: 2
|
||||
projection_ratio: 2.0
|
122
examples/data2vec/config/v2/huge_images14_only_task.yaml
Normal file
122
examples/data2vec/config/v2/huge_images14_only_task.yaml
Normal file
@ -0,0 +1,122 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
min_loss_scale: 1e-6
|
||||
fp16_no_flatten_grads: true
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
checkpoint:
|
||||
save_interval: 5
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
task:
|
||||
_name: mae_image_pretraining
|
||||
data: /datasets01/imagenet_full_size/061417/
|
||||
rebuild_batches: true
|
||||
local_cache_path: /scratch/cache_abaevski/imagenet
|
||||
key: source
|
||||
precompute_mask_config: {}
|
||||
|
||||
dataset:
|
||||
num_workers: 10
|
||||
batch_size: 8
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
required_batch_size_multiple: 1
|
||||
disable_validation: true
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 32
|
||||
ddp_backend: c10d
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- ema_decay
|
||||
- target_var
|
||||
- pred_var
|
||||
- model_norm
|
||||
- ema_norm
|
||||
- masked_pct
|
||||
|
||||
optimization:
|
||||
max_update: 500000
|
||||
lr: [ 0.0004 ]
|
||||
debug_param_names: true
|
||||
clip_norm: 4
|
||||
|
||||
optimizer:
|
||||
_name: composite
|
||||
dynamic_groups: true
|
||||
groups:
|
||||
default:
|
||||
lr_float: 4e-4
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [0.9,0.95]
|
||||
weight_decay: 0.05
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 50040
|
||||
|
||||
lr_scheduler: pass_through
|
||||
|
||||
model:
|
||||
_name: data2vec_multi
|
||||
|
||||
ema_decay: 0.9998
|
||||
ema_end_decay: 1
|
||||
ema_anneal_end_step: 300000
|
||||
instance_norm_target_layer: true
|
||||
layer_norm_target_layer: false
|
||||
layer_norm_targets: true
|
||||
end_of_block_targets: false
|
||||
|
||||
depth: 32
|
||||
embed_dim: 1280
|
||||
num_heads: 16
|
||||
|
||||
average_top_k_layers: 24
|
||||
clone_batch: 16
|
||||
|
||||
norm_eps: 1e-6
|
||||
|
||||
min_target_var: 0
|
||||
min_pred_var: 0
|
||||
|
||||
encoder_dropout: 0
|
||||
post_mlp_drop: 0
|
||||
attention_dropout: 0
|
||||
activation_dropout: 0
|
||||
|
||||
supported_modality: IMAGE
|
||||
cls_loss: 0.01
|
||||
|
||||
ema_encoder_only: false
|
||||
|
||||
modalities:
|
||||
image:
|
||||
patch_size: 14
|
||||
inverse_mask: true
|
||||
mask_prob: 0.75
|
||||
mask_prob_adjust: 0.1
|
||||
mask_length: 3
|
||||
mask_noise_std: 0.01
|
||||
prenet_depth: 0
|
||||
ema_local_encoder: true
|
||||
num_extra_tokens: 1
|
||||
init_extra_token_zero: false
|
||||
use_alibi_encoder: false
|
||||
embed_dim: 1280
|
||||
decoder:
|
||||
decoder_dim: 1024
|
||||
decoder_groups: 16
|
||||
decoder_kernel: 5
|
||||
decoder_layers: 3
|
||||
final_layer_norm: false
|
||||
input_dropout: 0
|
120
examples/data2vec/config/v2/huge_images_only_task.yaml
Normal file
120
examples/data2vec/config/v2/huge_images_only_task.yaml
Normal file
@ -0,0 +1,120 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
min_loss_scale: 1e-6
|
||||
fp16_no_flatten_grads: true
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
checkpoint:
|
||||
save_interval: 5
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
task:
|
||||
_name: mae_image_pretraining
|
||||
data: /datasets01/imagenet_full_size/061417/
|
||||
rebuild_batches: true
|
||||
local_cache_path: /scratch/cache_abaevski/imagenet
|
||||
key: source
|
||||
precompute_mask_config: {}
|
||||
|
||||
dataset:
|
||||
num_workers: 10
|
||||
batch_size: 8
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
required_batch_size_multiple: 1
|
||||
disable_validation: true
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 16
|
||||
ddp_backend: c10d
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- ema_decay
|
||||
- target_var
|
||||
- pred_var
|
||||
- model_norm
|
||||
- ema_norm
|
||||
- masked_pct
|
||||
|
||||
optimization:
|
||||
max_update: 375300
|
||||
lr: [ 0.0004 ]
|
||||
debug_param_names: true
|
||||
clip_norm: 4
|
||||
|
||||
optimizer:
|
||||
_name: composite
|
||||
dynamic_groups: true
|
||||
groups:
|
||||
default:
|
||||
lr_float: 4e-4
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [0.9,0.95]
|
||||
weight_decay: 0.05
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 50040
|
||||
|
||||
lr_scheduler: pass_through
|
||||
|
||||
model:
|
||||
_name: data2vec_multi
|
||||
|
||||
ema_decay: 0.9998
|
||||
ema_end_decay: 0.99995
|
||||
ema_anneal_end_step: 150000
|
||||
instance_norm_target_layer: true
|
||||
layer_norm_target_layer: false
|
||||
layer_norm_targets: true
|
||||
end_of_block_targets: false
|
||||
|
||||
depth: 32
|
||||
embed_dim: 1280
|
||||
num_heads: 16
|
||||
|
||||
average_top_k_layers: 24
|
||||
clone_batch: 16
|
||||
|
||||
norm_eps: 1e-6
|
||||
|
||||
min_target_var: 0
|
||||
min_pred_var: 0
|
||||
|
||||
encoder_dropout: 0
|
||||
post_mlp_drop: 0
|
||||
attention_dropout: 0
|
||||
activation_dropout: 0
|
||||
|
||||
supported_modality: IMAGE
|
||||
cls_loss: 0.01
|
||||
|
||||
ema_encoder_only: false
|
||||
|
||||
modalities:
|
||||
image:
|
||||
inverse_mask: true
|
||||
mask_prob: 0.75
|
||||
mask_prob_adjust: 0.1
|
||||
mask_length: 3
|
||||
mask_noise_std: 0.01
|
||||
prenet_depth: 0
|
||||
ema_local_encoder: true
|
||||
num_extra_tokens: 1
|
||||
init_extra_token_zero: false
|
||||
use_alibi_encoder: false
|
||||
embed_dim: 1280
|
||||
decoder:
|
||||
decoder_dim: 1024
|
||||
decoder_groups: 16
|
||||
decoder_kernel: 5
|
||||
decoder_layers: 3
|
||||
input_dropout: 0
|
122
examples/data2vec/config/v2/large_audio_only_task.yaml
Normal file
122
examples/data2vec/config/v2/large_audio_only_task.yaml
Normal file
@ -0,0 +1,122 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
min_loss_scale: 1e-6
|
||||
fp16_no_flatten_grads: true
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
checkpoint:
|
||||
save_interval: 1
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
task:
|
||||
_name: audio_pretraining
|
||||
data: /fsx-wav2vec/abaevski/data/librivox/no_silence
|
||||
max_sample_size: 320000
|
||||
min_sample_size: 32000
|
||||
normalize: true
|
||||
precompute_mask_config: {}
|
||||
|
||||
dataset:
|
||||
num_workers: 8
|
||||
max_tokens: 320000
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
validate_interval: 5
|
||||
required_batch_size_multiple: 1
|
||||
disable_validation: true
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 48
|
||||
ddp_backend: c10d
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- ema_decay
|
||||
- target_var
|
||||
- pred_var
|
||||
- model_norm
|
||||
- ema_norm
|
||||
- masked_pct
|
||||
|
||||
optimization:
|
||||
max_update: 600000
|
||||
debug_param_names: true
|
||||
clip_norm: 1
|
||||
|
||||
optimizer:
|
||||
_name: composite
|
||||
dynamic_groups: true
|
||||
groups:
|
||||
default:
|
||||
lr_float: 0.0004
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [0.9,0.98]
|
||||
adam_eps: 1e-06
|
||||
weight_decay: 0.01
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 10000
|
||||
|
||||
lr_scheduler: pass_through
|
||||
|
||||
model:
|
||||
_name: data2vec_multi
|
||||
|
||||
loss_beta: 0
|
||||
loss_scale: null
|
||||
|
||||
depth: 16
|
||||
embed_dim: 1024
|
||||
num_heads: 16
|
||||
|
||||
clone_batch: 12
|
||||
|
||||
ema_decay: 0.9997
|
||||
ema_end_decay: 1
|
||||
ema_anneal_end_step: 300000
|
||||
ema_encoder_only: false
|
||||
|
||||
average_top_k_layers: 16
|
||||
instance_norm_target_layer: true
|
||||
layer_norm_target_layer: false
|
||||
layer_norm_targets: false
|
||||
|
||||
layerdrop: 0
|
||||
norm_eps: 1e-5
|
||||
|
||||
supported_modality: AUDIO
|
||||
|
||||
modalities:
|
||||
audio:
|
||||
feature_encoder_spec: '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]'
|
||||
conv_pos_depth: 5
|
||||
conv_pos_width: 95
|
||||
conv_pos_groups: 16
|
||||
prenet_depth: 8
|
||||
mask_prob: 0.55
|
||||
mask_prob_adjust: 0.1
|
||||
inverse_mask: false
|
||||
mask_length: 5
|
||||
mask_noise_std: 0.01
|
||||
mask_dropout: 0
|
||||
add_masks: false
|
||||
ema_local_encoder: false
|
||||
use_alibi_encoder: true
|
||||
prenet_layerdrop: 0
|
||||
prenet_dropout: 0.1
|
||||
learned_alibi_scale: true
|
||||
learned_alibi_scale_per_head: true
|
||||
decoder:
|
||||
input_dropout: 0.1
|
||||
decoder_dim: 768
|
||||
decoder_groups: 16
|
||||
decoder_kernel: 7
|
||||
decoder_layers: 4
|
120
examples/data2vec/config/v2/large_images_only_task.yaml
Normal file
120
examples/data2vec/config/v2/large_images_only_task.yaml
Normal file
@ -0,0 +1,120 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
min_loss_scale: 1e-6
|
||||
fp16_no_flatten_grads: true
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
checkpoint:
|
||||
save_interval: 5
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
task:
|
||||
_name: mae_image_pretraining
|
||||
data: /datasets01/imagenet_full_size/061417/
|
||||
rebuild_batches: true
|
||||
local_cache_path: /scratch/cache_abaevski/imagenet
|
||||
key: source
|
||||
precompute_mask_config: {}
|
||||
|
||||
dataset:
|
||||
num_workers: 10
|
||||
batch_size: 8
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
required_batch_size_multiple: 1
|
||||
disable_validation: true
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 16
|
||||
ddp_backend: c10d
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- ema_decay
|
||||
- target_var
|
||||
- pred_var
|
||||
- model_norm
|
||||
- ema_norm
|
||||
- masked_pct
|
||||
|
||||
optimization:
|
||||
max_update: 375300
|
||||
lr: [ 0.0004 ]
|
||||
debug_param_names: true
|
||||
clip_norm: 4
|
||||
|
||||
optimizer:
|
||||
_name: composite
|
||||
dynamic_groups: true
|
||||
groups:
|
||||
default:
|
||||
lr_float: 4e-4
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [0.9,0.95]
|
||||
weight_decay: 0.05
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 50040
|
||||
|
||||
lr_scheduler: pass_through
|
||||
|
||||
model:
|
||||
_name: data2vec_multi
|
||||
|
||||
ema_decay: 0.9998
|
||||
ema_end_decay: 0.99999
|
||||
ema_anneal_end_step: 150000
|
||||
instance_norm_target_layer: true
|
||||
layer_norm_target_layer: false
|
||||
layer_norm_targets: true
|
||||
end_of_block_targets: false
|
||||
|
||||
depth: 24
|
||||
embed_dim: 1024
|
||||
num_heads: 16
|
||||
|
||||
average_top_k_layers: 18
|
||||
clone_batch: 16
|
||||
|
||||
norm_eps: 1e-6
|
||||
|
||||
min_target_var: 0
|
||||
min_pred_var: 0
|
||||
|
||||
encoder_dropout: 0
|
||||
post_mlp_drop: 0
|
||||
attention_dropout: 0
|
||||
activation_dropout: 0
|
||||
|
||||
supported_modality: IMAGE
|
||||
cls_loss: 0.01
|
||||
|
||||
ema_encoder_only: false
|
||||
|
||||
modalities:
|
||||
image:
|
||||
inverse_mask: true
|
||||
mask_prob: 0.75
|
||||
mask_prob_adjust: 0.1
|
||||
mask_length: 3
|
||||
mask_noise_std: 0.01
|
||||
prenet_depth: 0
|
||||
ema_local_encoder: true
|
||||
num_extra_tokens: 1
|
||||
init_extra_token_zero: false
|
||||
use_alibi_encoder: false
|
||||
embed_dim: 1024
|
||||
decoder:
|
||||
decoder_dim: 1024
|
||||
decoder_groups: 16
|
||||
decoder_kernel: 5
|
||||
decoder_layers: 3
|
||||
input_dropout: 0
|
112
examples/data2vec/config/v2/large_text_only_task.yaml
Normal file
112
examples/data2vec/config/v2/large_text_only_task.yaml
Normal file
@ -0,0 +1,112 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
min_loss_scale: 1e-6
|
||||
fp16_no_flatten_grads: true
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
checkpoint:
|
||||
save_interval_updates: 50000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
task:
|
||||
_name: masked_lm
|
||||
data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin
|
||||
sample_break_mode: none
|
||||
tokens_per_sample: 512
|
||||
include_target_tokens: true
|
||||
random_token_prob: 0
|
||||
leave_unmasked_prob: 0
|
||||
include_index: True
|
||||
skip_masking: True
|
||||
d2v2_multi: True
|
||||
|
||||
dataset:
|
||||
batch_size: 2
|
||||
ignore_unused_valid_subsets: true
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
disable_validation: true
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 32
|
||||
ddp_backend: c10d
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- ema_decay
|
||||
- target_var
|
||||
- pred_var
|
||||
- model_norm
|
||||
- ema_norm
|
||||
- masked_pct
|
||||
|
||||
optimization:
|
||||
max_update: 600000
|
||||
clip_norm: 1
|
||||
|
||||
optimizer:
|
||||
_name: composite
|
||||
dynamic_groups: true
|
||||
groups:
|
||||
default:
|
||||
lr_float: 0.0001
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [0.9,0.98]
|
||||
adam_eps: 1e-06
|
||||
weight_decay: 0.01
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 4000
|
||||
|
||||
lr_scheduler: pass_through
|
||||
|
||||
model:
|
||||
_name: data2vec_multi
|
||||
|
||||
loss_beta: 0
|
||||
loss_scale: 1
|
||||
|
||||
depth: 24
|
||||
num_heads: 16
|
||||
embed_dim: 1024
|
||||
clone_batch: 8
|
||||
|
||||
ema_decay: 0.9999
|
||||
ema_end_decay: 0.99999
|
||||
ema_anneal_end_step: 100000
|
||||
ema_encoder_only: true
|
||||
|
||||
average_top_k_layers: 24
|
||||
layer_norm_target_layer: true
|
||||
instance_norm_target_layer: false
|
||||
batch_norm_target_layer: false
|
||||
instance_norm_targets: true
|
||||
layer_norm_targets: false
|
||||
|
||||
layerdrop: 0
|
||||
norm_eps: 1e-5
|
||||
|
||||
supported_modality: TEXT
|
||||
|
||||
modalities:
|
||||
text:
|
||||
mask_prob: 0.5
|
||||
mask_length: 1
|
||||
mask_noise_std: 0.01
|
||||
prenet_depth: 0
|
||||
decoder:
|
||||
input_dropout: 0.1
|
||||
decoder_dim: 768
|
||||
decoder_groups: 1
|
||||
decoder_kernel: 9
|
||||
decoder_layers: 5
|
||||
decoder_residual: false
|
||||
projection_layers: 2
|
||||
projection_ratio: 2.0
|
123
examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml
Normal file
123
examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml
Normal file
@ -0,0 +1,123 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
fp16_no_flatten_grads: true
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
checkpoint:
|
||||
no_epoch_checkpoints: true
|
||||
save_interval_updates: 50000
|
||||
keep_interval_updates: 1
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 32
|
||||
ddp_backend: legacy_ddp
|
||||
|
||||
task:
|
||||
_name: masked_lm
|
||||
data: /fsx-wav2vec/abaevski/data/nlp/bookwiki_aml-full-mmap2-bin
|
||||
sample_break_mode: none
|
||||
tokens_per_sample: 512
|
||||
include_target_tokens: true
|
||||
random_token_prob: 0
|
||||
leave_unmasked_prob: 0
|
||||
include_index: True
|
||||
skip_masking: True
|
||||
d2v2_multi: True
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- ema_decay
|
||||
- target_var
|
||||
- pred_var
|
||||
- model_norm
|
||||
- ema_norm
|
||||
- masked_pct
|
||||
|
||||
dataset:
|
||||
batch_size: 2
|
||||
ignore_unused_valid_subsets: true
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
disable_validation: true
|
||||
|
||||
optimization:
|
||||
clip_norm: 1
|
||||
lr: [3e-4]
|
||||
max_update: 1000000
|
||||
update_freq: [1]
|
||||
|
||||
optimizer:
|
||||
_name: composite
|
||||
groups:
|
||||
default:
|
||||
lr_float: 1e-4
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [0.9,0.98]
|
||||
adam_eps: 1e-06
|
||||
weight_decay: 0.01
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 4000
|
||||
decoder:
|
||||
lr_float: 1e-4
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [0.9,0.98]
|
||||
adam_eps: 1e-06
|
||||
weight_decay: 0.01
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 4000
|
||||
|
||||
lr_scheduler: pass_through
|
||||
|
||||
model:
|
||||
_name: data2vec_multi
|
||||
|
||||
loss_beta: 4
|
||||
loss_scale: 1
|
||||
|
||||
depth: 24
|
||||
num_heads: 16
|
||||
embed_dim: 1024
|
||||
clone_batch: 8
|
||||
|
||||
ema_decay: 0.9999
|
||||
ema_end_decay: 0.99999
|
||||
ema_anneal_end_step: 100000
|
||||
ema_encoder_only: true
|
||||
|
||||
average_top_k_layers: 24
|
||||
layer_norm_target_layer: true
|
||||
instance_norm_target_layer: false
|
||||
batch_norm_target_layer: false
|
||||
instance_norm_targets: true
|
||||
layer_norm_targets: false
|
||||
|
||||
layerdrop: 0
|
||||
norm_eps: 1e-5
|
||||
|
||||
supported_modality: TEXT
|
||||
decoder_group: true
|
||||
|
||||
modalities:
|
||||
text:
|
||||
mask_prob: 0.5
|
||||
mask_length: 1
|
||||
mask_noise_std: 0.01
|
||||
prenet_depth: 0
|
||||
decoder:
|
||||
input_dropout: 0.1
|
||||
decoder_dim: 768
|
||||
decoder_groups: 1
|
||||
decoder_kernel: 9
|
||||
decoder_layers: 5
|
||||
decoder_residual: false
|
||||
projection_layers: 2
|
||||
projection_ratio: 2.0
|
15
examples/data2vec/config/v2/run_config/local.yaml
Normal file
15
examples/data2vec/config/v2/run_config/local.yaml
Normal file
@ -0,0 +1,15 @@
|
||||
# @package _global_
|
||||
hydra:
|
||||
sweep:
|
||||
dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 1
|
||||
nprocs_per_node: 1
|
||||
distributed_port: -1
|
||||
|
||||
common:
|
||||
log_interval: 1
|
||||
|
||||
dataset:
|
||||
num_workers: 0
|
37
examples/data2vec/config/v2/run_config/slurm_1.yaml
Normal file
37
examples/data2vec/config/v2/run_config/slurm_1.yaml
Normal file
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 80
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 1
|
||||
mem_gb: 450
|
||||
nodes: 1
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
37
examples/data2vec/config/v2/run_config/slurm_1_aws.yaml
Normal file
37
examples/data2vec/config/v2/run_config/slurm_1_aws.yaml
Normal file
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.local_cache_path
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 80
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 1
|
||||
mem_gb: 0
|
||||
nodes: 1
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
37
examples/data2vec/config/v2/run_config/slurm_2.yaml
Normal file
37
examples/data2vec/config/v2/run_config/slurm_2.yaml
Normal file
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 450
|
||||
nodes: 2
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
39
examples/data2vec/config/v2/run_config/slurm_2_aws.yaml
Normal file
39
examples/data2vec/config/v2/run_config/slurm_2_aws.yaml
Normal file
@ -0,0 +1,39 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.local_cache_path
|
||||
- task.data
|
||||
- task.post_save_script
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
- model.model_path
|
||||
sweep:
|
||||
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 12
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 2
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec
|
||||
max_num_timeout: 30
|
36
examples/data2vec/config/v2/run_config/slurm_3.yaml
Normal file
36
examples/data2vec/config/v2/run_config/slurm_3.yaml
Normal file
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 80
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 1
|
||||
mem_gb: 450
|
||||
nodes: 3
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
36
examples/data2vec/config/v2/run_config/slurm_4.yaml
Normal file
36
examples/data2vec/config/v2/run_config/slurm_4.yaml
Normal file
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 450
|
||||
nodes: 4
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
37
examples/data2vec/config/v2/run_config/slurm_4_aws.yaml
Normal file
37
examples/data2vec/config/v2/run_config/slurm_4_aws.yaml
Normal file
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- task.post_save_script
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 12
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 4
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec
|
||||
max_num_timeout: 30
|
36
examples/data2vec/config/v2/run_config/slurm_6_aws.yaml
Normal file
36
examples/data2vec/config/v2/run_config/slurm_6_aws.yaml
Normal file
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 12
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 6
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
37
examples/data2vec/config/v2/run_config/slurm_8.yaml
Normal file
37
examples/data2vec/config/v2/run_config/slurm_8.yaml
Normal file
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 450
|
||||
nodes: 8
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
36
examples/data2vec/config/v2/run_config/slurm_8_aws.yaml
Normal file
36
examples/data2vec/config/v2/run_config/slurm_8_aws.yaml
Normal file
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 12
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 8
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
60
examples/data2vec/config/v2/text_finetuning/cola.yaml
Normal file
60
examples/data2vec/config/v2/text_finetuning/cola.yaml
Normal file
@ -0,0 +1,60 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
fp16_init_scale: 4
|
||||
threshold_loss_scale: 1
|
||||
fp16_scale_window: 128
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
task:
|
||||
_name: sentence_prediction
|
||||
data: ???
|
||||
init_token: 0
|
||||
separator_token: 2
|
||||
num_classes: 2
|
||||
max_positions: 512
|
||||
d2v2_multi: True
|
||||
|
||||
checkpoint:
|
||||
best_checkpoint_metric: mcc
|
||||
maximize_best_checkpoint_metric: true
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
distributed_training:
|
||||
find_unused_parameters: true
|
||||
distributed_world_size: 1
|
||||
nprocs_per_node: 1
|
||||
distributed_port: -1
|
||||
|
||||
criterion:
|
||||
_name: sentence_prediction
|
||||
report_mcc: True
|
||||
|
||||
dataset:
|
||||
batch_size: 16
|
||||
required_batch_size_multiple: 1
|
||||
max_tokens: 4400
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
weight_decay: 0.1
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-06
|
||||
|
||||
lr_scheduler:
|
||||
_name: polynomial_decay
|
||||
warmup_updates: 320
|
||||
|
||||
optimization:
|
||||
clip_norm: 0.0
|
||||
lr: [2e-05]
|
||||
max_update: 5336
|
||||
max_epoch: 10
|
||||
|
||||
model:
|
||||
_name: data2vec_text_classification
|
||||
model_path: ???
|
60
examples/data2vec/config/v2/text_finetuning/mnli.yaml
Normal file
60
examples/data2vec/config/v2/text_finetuning/mnli.yaml
Normal file
@ -0,0 +1,60 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
fp16_init_scale: 4
|
||||
threshold_loss_scale: 1
|
||||
fp16_scale_window: 128
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
task:
|
||||
_name: sentence_prediction
|
||||
data: ???
|
||||
init_token: 0
|
||||
separator_token: 2
|
||||
num_classes: 3
|
||||
max_positions: 512
|
||||
d2v2_multi: True
|
||||
|
||||
checkpoint:
|
||||
best_checkpoint_metric: accuracy
|
||||
maximize_best_checkpoint_metric: true
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
distributed_training:
|
||||
find_unused_parameters: true
|
||||
distributed_world_size: 1
|
||||
nprocs_per_node: 1
|
||||
distributed_port: -1
|
||||
|
||||
criterion:
|
||||
_name: sentence_prediction
|
||||
|
||||
dataset:
|
||||
batch_size: 32
|
||||
required_batch_size_multiple: 1
|
||||
max_tokens: 4400
|
||||
valid_subset: valid,valid1
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
weight_decay: 0.1
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-06
|
||||
|
||||
lr_scheduler:
|
||||
_name: polynomial_decay
|
||||
warmup_updates: 7432
|
||||
|
||||
optimization:
|
||||
clip_norm: 0.0
|
||||
lr: [2e-05]
|
||||
max_update: 123873
|
||||
max_epoch: 10
|
||||
|
||||
model:
|
||||
_name: data2vec_text_classification
|
||||
model_path: ???
|
60
examples/data2vec/config/v2/text_finetuning/mrpc.yaml
Normal file
60
examples/data2vec/config/v2/text_finetuning/mrpc.yaml
Normal file
@ -0,0 +1,60 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
fp16_init_scale: 4
|
||||
threshold_loss_scale: 1
|
||||
fp16_scale_window: 128
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
task:
|
||||
_name: sentence_prediction
|
||||
data: ???
|
||||
init_token: 0
|
||||
separator_token: 2
|
||||
num_classes: 2
|
||||
max_positions: 512
|
||||
d2v2_multi: True
|
||||
|
||||
checkpoint:
|
||||
best_checkpoint_metric: acc_and_f1
|
||||
maximize_best_checkpoint_metric: true
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
distributed_training:
|
||||
find_unused_parameters: true
|
||||
distributed_world_size: 1
|
||||
nprocs_per_node: 1
|
||||
distributed_port: -1
|
||||
|
||||
criterion:
|
||||
_name: sentence_prediction
|
||||
report_acc_and_f1: True
|
||||
|
||||
dataset:
|
||||
batch_size: 16
|
||||
required_batch_size_multiple: 1
|
||||
max_tokens: 4400
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
weight_decay: 0.1
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-06
|
||||
|
||||
lr_scheduler:
|
||||
_name: polynomial_decay
|
||||
warmup_updates: 137
|
||||
|
||||
optimization:
|
||||
clip_norm: 0.0
|
||||
lr: [2e-05]
|
||||
max_update: 2296
|
||||
max_epoch: 10
|
||||
|
||||
model:
|
||||
_name: data2vec_text_classification
|
||||
model_path: ???
|
59
examples/data2vec/config/v2/text_finetuning/qnli.yaml
Normal file
59
examples/data2vec/config/v2/text_finetuning/qnli.yaml
Normal file
@ -0,0 +1,59 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
fp16_init_scale: 4
|
||||
threshold_loss_scale: 1
|
||||
fp16_scale_window: 128
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
task:
|
||||
_name: sentence_prediction
|
||||
data: ???
|
||||
init_token: 0
|
||||
separator_token: 2
|
||||
num_classes: 2
|
||||
max_positions: 512
|
||||
d2v2_multi: True
|
||||
|
||||
checkpoint:
|
||||
best_checkpoint_metric: accuracy
|
||||
maximize_best_checkpoint_metric: true
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
distributed_training:
|
||||
find_unused_parameters: true
|
||||
distributed_world_size: 1
|
||||
nprocs_per_node: 1
|
||||
distributed_port: -1
|
||||
|
||||
criterion:
|
||||
_name: sentence_prediction
|
||||
|
||||
dataset:
|
||||
batch_size: 32
|
||||
required_batch_size_multiple: 1
|
||||
max_tokens: 4400
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
weight_decay: 0.1
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-06
|
||||
|
||||
lr_scheduler:
|
||||
_name: polynomial_decay
|
||||
warmup_updates: 1986
|
||||
|
||||
optimization:
|
||||
clip_norm: 0.0
|
||||
lr: [2e-05]
|
||||
max_update: 33112
|
||||
max_epoch: 10
|
||||
|
||||
model:
|
||||
_name: data2vec_text_classification
|
||||
model_path: ???
|
60
examples/data2vec/config/v2/text_finetuning/qqp.yaml
Normal file
60
examples/data2vec/config/v2/text_finetuning/qqp.yaml
Normal file
@ -0,0 +1,60 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
fp16_init_scale: 4
|
||||
threshold_loss_scale: 1
|
||||
fp16_scale_window: 128
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
task:
|
||||
_name: sentence_prediction
|
||||
data: ???
|
||||
init_token: 0
|
||||
separator_token: 2
|
||||
num_classes: 2
|
||||
max_positions: 512
|
||||
d2v2_multi: True
|
||||
|
||||
checkpoint:
|
||||
best_checkpoint_metric: acc_and_f1
|
||||
maximize_best_checkpoint_metric: true
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
distributed_training:
|
||||
find_unused_parameters: true
|
||||
distributed_world_size: 1
|
||||
nprocs_per_node: 1
|
||||
distributed_port: -1
|
||||
|
||||
criterion:
|
||||
_name: sentence_prediction
|
||||
report_acc_and_f1: True
|
||||
|
||||
dataset:
|
||||
batch_size: 32
|
||||
required_batch_size_multiple: 1
|
||||
max_tokens: 4400
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
weight_decay: 0.1
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-06
|
||||
|
||||
lr_scheduler:
|
||||
_name: polynomial_decay
|
||||
warmup_updates: 28318
|
||||
|
||||
optimization:
|
||||
clip_norm: 0.0
|
||||
lr: [2e-05]
|
||||
max_update: 113272
|
||||
max_epoch: 10
|
||||
|
||||
model:
|
||||
_name: data2vec_text_classification
|
||||
model_path: ???
|
59
examples/data2vec/config/v2/text_finetuning/rte.yaml
Normal file
59
examples/data2vec/config/v2/text_finetuning/rte.yaml
Normal file
@ -0,0 +1,59 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
fp16_init_scale: 4
|
||||
threshold_loss_scale: 1
|
||||
fp16_scale_window: 128
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
task:
|
||||
_name: sentence_prediction
|
||||
data: ???
|
||||
init_token: 0
|
||||
separator_token: 2
|
||||
num_classes: 2
|
||||
max_positions: 512
|
||||
d2v2_multi: True
|
||||
|
||||
checkpoint:
|
||||
best_checkpoint_metric: accuracy
|
||||
maximize_best_checkpoint_metric: true
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
distributed_training:
|
||||
find_unused_parameters: true
|
||||
distributed_world_size: 1
|
||||
nprocs_per_node: 1
|
||||
distributed_port: -1
|
||||
|
||||
criterion:
|
||||
_name: sentence_prediction
|
||||
|
||||
dataset:
|
||||
batch_size: 16
|
||||
required_batch_size_multiple: 1
|
||||
max_tokens: 4400
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
weight_decay: 0.1
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-06
|
||||
|
||||
lr_scheduler:
|
||||
_name: polynomial_decay
|
||||
warmup_updates: 122
|
||||
|
||||
optimization:
|
||||
clip_norm: 0.0
|
||||
lr: [2e-05]
|
||||
max_update: 2036
|
||||
max_epoch: 10
|
||||
|
||||
model:
|
||||
_name: data2vec_text_classification
|
||||
model_path: ???
|
@ -0,0 +1,15 @@
|
||||
# @package _global_
|
||||
hydra:
|
||||
sweep:
|
||||
dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 1
|
||||
nprocs_per_node: 1
|
||||
distributed_port: -1
|
||||
|
||||
common:
|
||||
log_interval: 1
|
||||
|
||||
dataset:
|
||||
num_workers: 0
|
59
examples/data2vec/config/v2/text_finetuning/sst_2.yaml
Normal file
59
examples/data2vec/config/v2/text_finetuning/sst_2.yaml
Normal file
@ -0,0 +1,59 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
fp16_init_scale: 4
|
||||
threshold_loss_scale: 1
|
||||
fp16_scale_window: 128
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
task:
|
||||
_name: sentence_prediction
|
||||
data: ???
|
||||
init_token: 0
|
||||
separator_token: 2
|
||||
num_classes: 2
|
||||
max_positions: 512
|
||||
d2v2_multi: True
|
||||
|
||||
checkpoint:
|
||||
best_checkpoint_metric: accuracy
|
||||
maximize_best_checkpoint_metric: true
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
distributed_training:
|
||||
find_unused_parameters: true
|
||||
distributed_world_size: 1
|
||||
nprocs_per_node: 1
|
||||
distributed_port: -1
|
||||
|
||||
criterion:
|
||||
_name: sentence_prediction
|
||||
|
||||
dataset:
|
||||
batch_size: 32
|
||||
required_batch_size_multiple: 1
|
||||
max_tokens: 4400
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
weight_decay: 0.1
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-06
|
||||
|
||||
lr_scheduler:
|
||||
_name: polynomial_decay
|
||||
warmup_updates: 1256
|
||||
|
||||
optimization:
|
||||
clip_norm: 0.0
|
||||
lr: [2e-05]
|
||||
max_update: 20935
|
||||
max_epoch: 10
|
||||
|
||||
model:
|
||||
_name: data2vec_text_classification
|
||||
model_path: ???
|
61
examples/data2vec/config/v2/text_finetuning/sts_b.yaml
Normal file
61
examples/data2vec/config/v2/text_finetuning/sts_b.yaml
Normal file
@ -0,0 +1,61 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
fp16_init_scale: 4
|
||||
threshold_loss_scale: 1
|
||||
fp16_scale_window: 128
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
user_dir: ${env:PWD}/examples/data2vec
|
||||
|
||||
task:
|
||||
_name: sentence_prediction
|
||||
data: ???
|
||||
init_token: 0
|
||||
separator_token: 2
|
||||
num_classes: 1
|
||||
max_positions: 512
|
||||
d2v2_multi: True
|
||||
|
||||
checkpoint:
|
||||
best_checkpoint_metric: pearson_and_spearman
|
||||
maximize_best_checkpoint_metric: true
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
distributed_training:
|
||||
find_unused_parameters: true
|
||||
distributed_world_size: 1
|
||||
nprocs_per_node: 1
|
||||
distributed_port: -1
|
||||
|
||||
criterion:
|
||||
_name: sentence_prediction
|
||||
regression_target: true
|
||||
report_pearson_and_spearman: True
|
||||
|
||||
dataset:
|
||||
batch_size: 16
|
||||
required_batch_size_multiple: 1
|
||||
max_tokens: 4400
|
||||
num_workers: 1
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
weight_decay: 0.1
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-06
|
||||
|
||||
lr_scheduler:
|
||||
_name: polynomial_decay
|
||||
warmup_updates: 214
|
||||
|
||||
optimization:
|
||||
clip_norm: 0.0
|
||||
lr: [4e-05]
|
||||
max_update: 3598
|
||||
max_epoch: 10
|
||||
|
||||
model:
|
||||
_name: data2vec_text_classification
|
||||
model_path: ???
|
52
examples/data2vec/config/vision/finetuning/imagenet.yaml
Normal file
52
examples/data2vec/config/vision/finetuning/imagenet.yaml
Normal file
@ -0,0 +1,52 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
|
||||
checkpoint:
|
||||
save_interval: 1
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
best_checkpoint_metric: accuracy
|
||||
|
||||
task:
|
||||
_name: image_classification
|
||||
data: /datasets01/imagenet_full_size/061417
|
||||
|
||||
dataset:
|
||||
num_workers: 6
|
||||
batch_size: 64
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
required_batch_size_multiple: 1
|
||||
valid_subset: val
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 8
|
||||
ddp_backend: c10d
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- correct
|
||||
|
||||
optimization:
|
||||
max_update: 100000
|
||||
lr: [0.0005]
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-06
|
||||
weight_decay: 0.01
|
||||
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 10000
|
||||
|
||||
model:
|
||||
_name: data2vec_image_classification
|
||||
model_path: ???
|
@ -0,0 +1,65 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
fp16_no_flatten_grads: true
|
||||
|
||||
checkpoint:
|
||||
save_interval: 1
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
best_checkpoint_metric: accuracy
|
||||
maximize_best_checkpoint_metric: true
|
||||
|
||||
task:
|
||||
_name: mae_image_classification
|
||||
data: /datasets01/imagenet_full_size/061417
|
||||
|
||||
dataset:
|
||||
num_workers: 6
|
||||
batch_size: 32
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
required_batch_size_multiple: 2
|
||||
valid_subset: val
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 16
|
||||
ddp_backend: c10d
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- correct
|
||||
|
||||
optimization:
|
||||
max_update: 250200
|
||||
lr: [0.001]
|
||||
|
||||
optimizer:
|
||||
_name: composite
|
||||
dynamic_groups: true
|
||||
groups:
|
||||
default:
|
||||
lr_float: 0.001
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [0.9,0.95]
|
||||
weight_decay: 0.05
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 16000
|
||||
min_lr: 1e-6
|
||||
|
||||
|
||||
lr_scheduler: pass_through
|
||||
|
||||
model:
|
||||
_name: mae_image_classification
|
||||
mixup: 0.7
|
||||
mixup_prob: 0.9
|
||||
|
||||
model_path: ???
|
@ -0,0 +1,68 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
fp16_no_flatten_grads: true
|
||||
|
||||
checkpoint:
|
||||
save_interval: 1
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
best_checkpoint_metric: accuracy
|
||||
maximize_best_checkpoint_metric: true
|
||||
|
||||
task:
|
||||
_name: mae_image_classification
|
||||
data: /datasets01/imagenet_full_size/061417
|
||||
|
||||
dataset:
|
||||
num_workers: 6
|
||||
batch_size: 32
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
required_batch_size_multiple: 2
|
||||
valid_subset: val
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 16
|
||||
ddp_backend: c10d
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- correct
|
||||
|
||||
optimization:
|
||||
max_update: 125200
|
||||
lr: [0.0005]
|
||||
clip_norm: 4
|
||||
|
||||
optimizer:
|
||||
_name: composite
|
||||
dynamic_groups: true
|
||||
groups:
|
||||
default:
|
||||
lr_float: 0.0005
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [0.9,0.95]
|
||||
weight_decay: 0.05
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 16000
|
||||
min_lr: 1e-20
|
||||
|
||||
|
||||
lr_scheduler: pass_through
|
||||
|
||||
model:
|
||||
_name: mae_image_classification
|
||||
mixup: 0.7
|
||||
mixup_prob: 0.9
|
||||
layer_decay: 0.75
|
||||
drop_path_rate: 0.2
|
||||
|
||||
model_path: ???
|
@ -0,0 +1,68 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
fp16_no_flatten_grads: true
|
||||
|
||||
checkpoint:
|
||||
save_interval: 1
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
best_checkpoint_metric: accuracy
|
||||
maximize_best_checkpoint_metric: true
|
||||
|
||||
task:
|
||||
_name: mae_image_classification
|
||||
data: /datasets01/imagenet_full_size/061417
|
||||
|
||||
dataset:
|
||||
num_workers: 6
|
||||
batch_size: 32
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
required_batch_size_multiple: 2
|
||||
valid_subset: val
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 16
|
||||
ddp_backend: c10d
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- correct
|
||||
|
||||
optimization:
|
||||
max_update: 125200
|
||||
lr: [0.0005]
|
||||
clip_norm: 4
|
||||
|
||||
optimizer:
|
||||
_name: composite
|
||||
dynamic_groups: true
|
||||
groups:
|
||||
default:
|
||||
lr_float: 0.0005
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [0.9,0.95]
|
||||
weight_decay: 0.05
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 16000
|
||||
min_lr: 1e-7
|
||||
|
||||
|
||||
lr_scheduler: pass_through
|
||||
|
||||
model:
|
||||
_name: mae_image_classification
|
||||
mixup: 0.7
|
||||
mixup_prob: 0.9
|
||||
layer_decay: 0.75
|
||||
drop_path_rate: 0.2
|
||||
|
||||
model_path: ???
|
@ -0,0 +1,15 @@
|
||||
# @package _global_
|
||||
hydra:
|
||||
sweep:
|
||||
dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 1
|
||||
nprocs_per_node: 1
|
||||
distributed_port: -1
|
||||
|
||||
common:
|
||||
log_interval: 1
|
||||
|
||||
dataset:
|
||||
num_workers: 0
|
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 80
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 1
|
||||
mem_gb: 450
|
||||
nodes: 1
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 80
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 1
|
||||
mem_gb: 0
|
||||
nodes: 1
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
@ -0,0 +1,38 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
- task.local_cache_path
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 450
|
||||
nodes: 2
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,38 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
- task.local_cache_path
|
||||
- model.model_path
|
||||
sweep:
|
||||
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 2
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 80
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 1
|
||||
mem_gb: 450
|
||||
nodes: 3
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 450
|
||||
nodes: 4
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 4
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 6
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 8
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
@ -0,0 +1,52 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
|
||||
checkpoint:
|
||||
save_interval: 5
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
task:
|
||||
_name: image_pretraining
|
||||
data: /datasets01/imagenet_full_size/061417/
|
||||
|
||||
dataset:
|
||||
num_workers: 6
|
||||
batch_size: 64
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
required_batch_size_multiple: 1
|
||||
disable_validation: true
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 16
|
||||
ddp_backend: c10d
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- ema_decay
|
||||
- target_var
|
||||
- pred_var
|
||||
|
||||
optimization:
|
||||
max_update: 400000
|
||||
lr: [0.0005]
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: (0.9,0.98)
|
||||
adam_eps: 1e-06
|
||||
weight_decay: 0.01
|
||||
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 10000
|
||||
|
||||
model:
|
||||
_name: data2vec_vision
|
@ -0,0 +1,64 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
|
||||
checkpoint:
|
||||
save_interval: 5
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
task:
|
||||
_name: image_pretraining
|
||||
data: /datasets01/imagenet_full_size/061417
|
||||
|
||||
dataset:
|
||||
num_workers: 6
|
||||
batch_size: 128
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
required_batch_size_multiple: 2
|
||||
disable_validation: true
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 16
|
||||
ddp_backend: legacy_ddp
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
log_keys:
|
||||
- ema_decay
|
||||
- target_var
|
||||
- pred_var
|
||||
|
||||
optimization:
|
||||
max_update: 375300 #300*1251
|
||||
lr: [0.0005]
|
||||
clip_norm: 3.0
|
||||
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: (0.9,0.999)
|
||||
adam_eps: 1e-08
|
||||
weight_decay: 0.05
|
||||
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 12510 # it should be 10 epochs
|
||||
|
||||
model:
|
||||
_name: data2vec_vision
|
||||
|
||||
attention_dropout: 0.05
|
||||
|
||||
ema_decay: 0.999
|
||||
ema_end_decay: 0.9998
|
||||
layer_norm_targets: True
|
||||
average_top_k_layers: 6
|
||||
|
||||
loss_beta: 2.0
|
||||
|
||||
drop_path: 0.25
|
@ -0,0 +1,64 @@
|
||||
# @package _group_
|
||||
|
||||
common:
|
||||
fp16: true
|
||||
log_format: json
|
||||
log_interval: 200
|
||||
tensorboard_logdir: tb
|
||||
fp16_no_flatten_grads: true
|
||||
|
||||
checkpoint:
|
||||
save_interval: 5
|
||||
save_interval_updates: 25000
|
||||
keep_interval_updates: 1
|
||||
no_epoch_checkpoints: true
|
||||
|
||||
task:
|
||||
_name: mae_image_pretraining
|
||||
data: /datasets01/imagenet_full_size/061417/
|
||||
rebuild_batches: true
|
||||
|
||||
dataset:
|
||||
num_workers: 6
|
||||
batch_size: 64
|
||||
skip_invalid_size_inputs_valid_test: true
|
||||
required_batch_size_multiple: 1
|
||||
disable_validation: true
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 16
|
||||
ddp_backend: c10d
|
||||
|
||||
criterion:
|
||||
_name: model
|
||||
|
||||
optimization:
|
||||
max_update: 375300
|
||||
lr: [0.0006]
|
||||
|
||||
optimizer:
|
||||
_name: composite
|
||||
groups:
|
||||
with_decay:
|
||||
lr_float: 6e-4
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [0.9,0.95]
|
||||
weight_decay: 0.05
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 50040
|
||||
no_decay:
|
||||
lr_float: 6e-4
|
||||
optimizer:
|
||||
_name: adam
|
||||
adam_betas: [0.9,0.95]
|
||||
weight_decay: 0
|
||||
lr_scheduler:
|
||||
_name: cosine
|
||||
warmup_updates: 50040
|
||||
|
||||
lr_scheduler: pass_through
|
||||
|
||||
model:
|
||||
_name: mae
|
@ -0,0 +1,15 @@
|
||||
# @package _global_
|
||||
hydra:
|
||||
sweep:
|
||||
dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
|
||||
|
||||
distributed_training:
|
||||
distributed_world_size: 1
|
||||
nprocs_per_node: 1
|
||||
distributed_port: -1
|
||||
|
||||
common:
|
||||
log_interval: 1
|
||||
|
||||
dataset:
|
||||
num_workers: 0
|
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 80
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 1
|
||||
mem_gb: 450
|
||||
nodes: 1
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 80
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 1
|
||||
mem_gb: 0
|
||||
nodes: 1
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
@ -0,0 +1,38 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
- task.local_cache_path
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 450
|
||||
nodes: 2
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
- task.local_cache_path
|
||||
sweep:
|
||||
dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 2
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 80
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 1
|
||||
mem_gb: 450
|
||||
nodes: 3
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 450
|
||||
nodes: 4
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: devlab,learnlab,learnfair,scavenge
|
||||
constraint: volta32gb,ib4
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 4
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 6
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
@ -0,0 +1,36 @@
|
||||
# @package _global_
|
||||
|
||||
hydra:
|
||||
job:
|
||||
config:
|
||||
override_dirname:
|
||||
kv_sep: ':'
|
||||
item_sep: '/'
|
||||
exclude_keys:
|
||||
- run_config
|
||||
- distributed_training.distributed_port
|
||||
- distributed_training.distributed_world_size
|
||||
- model.pretrained_model_path
|
||||
- model.target_network_path
|
||||
- next_script
|
||||
- task.cache_in_scratch
|
||||
- task.data
|
||||
- checkpoint.save_interval_updates
|
||||
- checkpoint.keep_interval_updates
|
||||
- checkpoint.save_on_overflow
|
||||
- common.log_interval
|
||||
- common.user_dir
|
||||
sweep:
|
||||
dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
|
||||
subdir: ''
|
||||
launcher:
|
||||
submitit_folder: ${hydra.sweep.dir}
|
||||
timeout_min: 4320
|
||||
cpus_per_task: 10
|
||||
gpus_per_node: 8
|
||||
tasks_per_node: 8
|
||||
mem_gb: 0
|
||||
nodes: 8
|
||||
name: ${env:PREFIX}_${hydra.job.config_name}
|
||||
partition: wav2vec,learnlab,learnfair
|
||||
max_num_timeout: 30
|
17
examples/data2vec/data/__init__.py
Normal file
17
examples/data2vec/data/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .image_dataset import ImageDataset
|
||||
from .path_dataset import PathDataset
|
||||
from .mae_image_dataset import MaeImageDataset
|
||||
from .mae_finetuning_image_dataset import MaeFinetuningImageDataset
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ImageDataset",
|
||||
"MaeImageDataset",
|
||||
"MaeFinetuningImageDataset",
|
||||
"PathDataset",
|
||||
]
|
63
examples/data2vec/data/add_class_target_dataset.py
Normal file
63
examples/data2vec/data/add_class_target_dataset.py
Normal file
@ -0,0 +1,63 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from fairseq.data import BaseWrapperDataset, data_utils
|
||||
|
||||
|
||||
class AddClassTargetDataset(BaseWrapperDataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset,
|
||||
labels,
|
||||
multi_class,
|
||||
num_classes=None,
|
||||
label_indices=None,
|
||||
add_to_input=True,
|
||||
):
|
||||
super().__init__(dataset)
|
||||
|
||||
self.label_indices = label_indices
|
||||
self.labels = labels
|
||||
self.multi_class = multi_class
|
||||
self.add_to_input = add_to_input
|
||||
if num_classes is None and multi_class:
|
||||
assert self.label_indices is not None
|
||||
num_classes = len(self.label_indices)
|
||||
|
||||
self.num_classes = num_classes
|
||||
|
||||
def __getitem__(self, index):
|
||||
item = self.dataset[index]
|
||||
item_labels = self.labels[index]
|
||||
if self.multi_class:
|
||||
item["label"] = torch.zeros(self.num_classes)
|
||||
for il in item_labels:
|
||||
if self.label_indices is not None:
|
||||
il = self.label_indices[il]
|
||||
item["label"][il] = 1.0
|
||||
else:
|
||||
item["label"] = torch.tensor(
|
||||
self.labels[index]
|
||||
if self.label_indices is None
|
||||
else self.label_indices[self.labels[index]]
|
||||
)
|
||||
|
||||
return item
|
||||
|
||||
def collater(self, samples):
|
||||
collated = self.dataset.collater(samples)
|
||||
if len(collated) == 0:
|
||||
return collated
|
||||
|
||||
indices = set(collated["id"].tolist())
|
||||
target = [s["label"] for s in samples if s["id"] in indices]
|
||||
collated["label"] = torch.stack(target, dim=0)
|
||||
|
||||
if self.add_to_input:
|
||||
collated["net_input"]["label"] = collated["label"]
|
||||
|
||||
return collated
|
127
examples/data2vec/data/image_dataset.py
Normal file
127
examples/data2vec/data/image_dataset.py
Normal file
@ -0,0 +1,127 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
from typing import Optional, Callable, Set
|
||||
|
||||
import torch
|
||||
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
from torchvision.transforms import ToTensor
|
||||
|
||||
from fairseq.data import FairseqDataset
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImageDataset(FairseqDataset, VisionDataset):
|
||||
def __init__(
|
||||
self,
|
||||
root: str,
|
||||
extensions: Set[str],
|
||||
load_classes: bool,
|
||||
transform: Optional[Callable] = None,
|
||||
shuffle=True,
|
||||
):
|
||||
FairseqDataset.__init__(self)
|
||||
VisionDataset.__init__(self, root=root, transform=transform)
|
||||
|
||||
self.shuffle = shuffle
|
||||
self.tensor_transform = ToTensor()
|
||||
|
||||
self.classes = None
|
||||
self.labels = None
|
||||
if load_classes:
|
||||
classes = [d.name for d in os.scandir(root) if d.is_dir()]
|
||||
classes.sort()
|
||||
self.classes = {cls_name: i for i, cls_name in enumerate(classes)}
|
||||
logger.info(f"loaded {len(self.classes)} classes")
|
||||
self.labels = []
|
||||
|
||||
def walk_path(root_path):
|
||||
for root, _, fnames in sorted(os.walk(root_path, followlinks=True)):
|
||||
for fname in sorted(fnames):
|
||||
fname_ext = os.path.splitext(fname)
|
||||
if fname_ext[-1].lower() not in extensions:
|
||||
continue
|
||||
|
||||
path = os.path.join(root, fname)
|
||||
yield path
|
||||
|
||||
logger.info(f"finding images in {root}")
|
||||
if self.classes is not None:
|
||||
self.files = []
|
||||
self.labels = []
|
||||
for c, i in self.classes.items():
|
||||
for f in walk_path(os.path.join(root, c)):
|
||||
self.files.append(f)
|
||||
self.labels.append(i)
|
||||
else:
|
||||
self.files = [f for f in walk_path(root)]
|
||||
|
||||
logger.info(f"loaded {len(self.files)} examples")
|
||||
|
||||
def __getitem__(self, index):
|
||||
from PIL import Image
|
||||
|
||||
fpath = self.files[index]
|
||||
|
||||
with open(fpath, "rb") as f:
|
||||
img = Image.open(f).convert("RGB")
|
||||
|
||||
if self.transform is None:
|
||||
img = self.tensor_transform(img)
|
||||
else:
|
||||
img = self.transform(img)
|
||||
assert torch.is_tensor(img)
|
||||
|
||||
res = {"id": index, "img": img}
|
||||
|
||||
if self.labels is not None:
|
||||
res["label"] = self.labels[index]
|
||||
|
||||
return res
|
||||
|
||||
def __len__(self):
|
||||
return len(self.files)
|
||||
|
||||
def collater(self, samples):
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
|
||||
collated_img = torch.stack([s["img"] for s in samples], dim=0)
|
||||
|
||||
res = {
|
||||
"id": torch.LongTensor([s["id"] for s in samples]),
|
||||
"net_input": {
|
||||
"img": collated_img,
|
||||
},
|
||||
}
|
||||
|
||||
if "label" in samples[0]:
|
||||
res["net_input"]["label"] = torch.LongTensor([s["label"] for s in samples])
|
||||
|
||||
return res
|
||||
|
||||
def num_tokens(self, index):
|
||||
return 1
|
||||
|
||||
def size(self, index):
|
||||
return 1
|
||||
|
||||
def ordered_indices(self):
|
||||
"""Return an ordered list of indices. Batches will be constructed based
|
||||
on this order."""
|
||||
if self.shuffle:
|
||||
order = [np.random.permutation(len(self))]
|
||||
else:
|
||||
order = [np.arange(len(self))]
|
||||
|
||||
return order[0]
|
135
examples/data2vec/data/mae_finetuning_image_dataset.py
Normal file
135
examples/data2vec/data/mae_finetuning_image_dataset.py
Normal file
@ -0,0 +1,135 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
from timm.data import create_transform
|
||||
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
import PIL
|
||||
|
||||
from fairseq.data import FairseqDataset
|
||||
from .mae_image_dataset import caching_loader
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_transform(is_train, input_size, color_jitter, aa, reprob, remode, recount):
|
||||
mean = IMAGENET_DEFAULT_MEAN
|
||||
std = IMAGENET_DEFAULT_STD
|
||||
# train transform
|
||||
if is_train:
|
||||
# this should always dispatch to transforms_imagenet_train
|
||||
transform = create_transform(
|
||||
input_size=input_size,
|
||||
is_training=True,
|
||||
color_jitter=color_jitter,
|
||||
auto_augment=aa,
|
||||
interpolation="bicubic",
|
||||
re_prob=reprob,
|
||||
re_mode=remode,
|
||||
re_count=recount,
|
||||
mean=mean,
|
||||
std=std,
|
||||
)
|
||||
return transform
|
||||
|
||||
# eval transform
|
||||
t = []
|
||||
if input_size <= 224:
|
||||
crop_pct = 224 / 256
|
||||
else:
|
||||
crop_pct = 1.0
|
||||
size = int(input_size / crop_pct)
|
||||
t.append(
|
||||
transforms.Resize(
|
||||
size, interpolation=PIL.Image.BICUBIC
|
||||
), # to maintain same ratio w.r.t. 224 images
|
||||
)
|
||||
t.append(transforms.CenterCrop(input_size))
|
||||
|
||||
t.append(transforms.ToTensor())
|
||||
t.append(transforms.Normalize(mean, std))
|
||||
return transforms.Compose(t)
|
||||
|
||||
|
||||
class MaeFinetuningImageDataset(FairseqDataset):
|
||||
def __init__(
|
||||
self,
|
||||
root: str,
|
||||
split: str,
|
||||
is_train: bool,
|
||||
input_size,
|
||||
color_jitter=None,
|
||||
aa="rand-m9-mstd0.5-inc1",
|
||||
reprob=0.25,
|
||||
remode="pixel",
|
||||
recount=1,
|
||||
local_cache_path=None,
|
||||
shuffle=True,
|
||||
):
|
||||
FairseqDataset.__init__(self)
|
||||
|
||||
self.shuffle = shuffle
|
||||
|
||||
transform = build_transform(
|
||||
is_train, input_size, color_jitter, aa, reprob, remode, recount
|
||||
)
|
||||
|
||||
path = os.path.join(root, split)
|
||||
loader = caching_loader(local_cache_path, datasets.folder.default_loader)
|
||||
|
||||
self.dataset = datasets.ImageFolder(path, loader=loader, transform=transform)
|
||||
|
||||
logger.info(f"loaded {len(self.dataset)} examples")
|
||||
|
||||
def __getitem__(self, index):
|
||||
img, label = self.dataset[index]
|
||||
return {"id": index, "img": img, "label": label}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def collater(self, samples):
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
|
||||
collated_img = torch.stack([s["img"] for s in samples], dim=0)
|
||||
|
||||
res = {
|
||||
"id": torch.LongTensor([s["id"] for s in samples]),
|
||||
"net_input": {
|
||||
"imgs": collated_img,
|
||||
},
|
||||
}
|
||||
|
||||
if "label" in samples[0]:
|
||||
res["net_input"]["labels"] = torch.LongTensor([s["label"] for s in samples])
|
||||
|
||||
return res
|
||||
|
||||
def num_tokens(self, index):
|
||||
return 1
|
||||
|
||||
def size(self, index):
|
||||
return 1
|
||||
|
||||
def ordered_indices(self):
|
||||
"""Return an ordered list of indices. Batches will be constructed based
|
||||
on this order."""
|
||||
if self.shuffle:
|
||||
order = [np.random.permutation(len(self))]
|
||||
else:
|
||||
order = [np.arange(len(self))]
|
||||
|
||||
return order[0]
|
418
examples/data2vec/data/mae_image_dataset.py
Normal file
418
examples/data2vec/data/mae_image_dataset.py
Normal file
@ -0,0 +1,418 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from functools import partial
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from torchvision import datasets, transforms
|
||||
from .path_dataset import PathDataset
|
||||
|
||||
from fairseq.data import FairseqDataset
|
||||
from fairseq.data.data_utils import compute_block_mask_1d, compute_block_mask_2d
|
||||
|
||||
from shutil import copyfile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load(path, loader, cache):
|
||||
if hasattr(caching_loader, "cache_root"):
|
||||
cache = caching_loader.cache_root
|
||||
|
||||
cached_path = cache + path
|
||||
|
||||
num_tries = 3
|
||||
for curr_try in range(num_tries):
|
||||
try:
|
||||
if curr_try == 2:
|
||||
return loader(path)
|
||||
if not os.path.exists(cached_path) or curr_try > 0:
|
||||
os.makedirs(os.path.dirname(cached_path), exist_ok=True)
|
||||
copyfile(path, cached_path)
|
||||
os.chmod(cached_path, 0o777)
|
||||
return loader(cached_path)
|
||||
except Exception as e:
|
||||
logger.warning(str(e))
|
||||
if "Errno 13" in str(e):
|
||||
caching_loader.cache_root = f"/scratch/{random.randint(0, 69420)}"
|
||||
logger.warning(f"setting cache root to {caching_loader.cache_root}")
|
||||
cached_path = caching_loader.cache_root + path
|
||||
if curr_try == (num_tries - 1):
|
||||
raise
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def caching_loader(cache_root: str, loader):
|
||||
if cache_root is None:
|
||||
return loader
|
||||
|
||||
if cache_root == "slurm_tmpdir":
|
||||
cache_root = os.environ["SLURM_TMPDIR"]
|
||||
assert len(cache_root) > 0
|
||||
|
||||
if not cache_root.endswith("/"):
|
||||
cache_root += "/"
|
||||
|
||||
return partial(load, loader=loader, cache=cache_root)
|
||||
|
||||
|
||||
class RandomResizedCropAndInterpolationWithTwoPic:
|
||||
"""Crop the given PIL Image to random size and aspect ratio with random interpolation.
|
||||
|
||||
A crop of random size (default: of 0.08 to 1.0) of the original size and a random
|
||||
aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
|
||||
is finally resized to given size.
|
||||
This is popularly used to train the Inception networks.
|
||||
|
||||
Args:
|
||||
size: expected output size of each edge
|
||||
scale: range of size of the origin size cropped
|
||||
ratio: range of aspect ratio of the origin aspect ratio cropped
|
||||
interpolation: Default: PIL.Image.BILINEAR
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
second_size=None,
|
||||
scale=(0.08, 1.0),
|
||||
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
||||
interpolation="bilinear",
|
||||
second_interpolation="lanczos",
|
||||
):
|
||||
if isinstance(size, tuple):
|
||||
self.size = size
|
||||
else:
|
||||
self.size = (size, size)
|
||||
if second_size is not None:
|
||||
if isinstance(second_size, tuple):
|
||||
self.second_size = second_size
|
||||
else:
|
||||
self.second_size = (second_size, second_size)
|
||||
else:
|
||||
self.second_size = None
|
||||
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
|
||||
logger.warning("range should be of kind (min, max)")
|
||||
|
||||
if interpolation == "random":
|
||||
from PIL import Image
|
||||
|
||||
self.interpolation = (Image.BILINEAR, Image.BICUBIC)
|
||||
else:
|
||||
self.interpolation = self._pil_interp(interpolation)
|
||||
|
||||
self.second_interpolation = (
|
||||
self._pil_interp(second_interpolation)
|
||||
if second_interpolation is not None
|
||||
else None
|
||||
)
|
||||
self.scale = scale
|
||||
self.ratio = ratio
|
||||
|
||||
def _pil_interp(self, method):
|
||||
from PIL import Image
|
||||
|
||||
if method == "bicubic":
|
||||
return Image.BICUBIC
|
||||
elif method == "lanczos":
|
||||
return Image.LANCZOS
|
||||
elif method == "hamming":
|
||||
return Image.HAMMING
|
||||
else:
|
||||
# default bilinear, do we want to allow nearest?
|
||||
return Image.BILINEAR
|
||||
|
||||
@staticmethod
|
||||
def get_params(img, scale, ratio):
|
||||
"""Get parameters for ``crop`` for a random sized crop.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be cropped.
|
||||
scale (tuple): range of size of the origin size cropped
|
||||
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
|
||||
|
||||
Returns:
|
||||
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
|
||||
sized crop.
|
||||
"""
|
||||
area = img.size[0] * img.size[1]
|
||||
|
||||
for attempt in range(10):
|
||||
target_area = random.uniform(*scale) * area
|
||||
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
|
||||
aspect_ratio = math.exp(random.uniform(*log_ratio))
|
||||
|
||||
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
|
||||
if w <= img.size[0] and h <= img.size[1]:
|
||||
i = random.randint(0, img.size[1] - h)
|
||||
j = random.randint(0, img.size[0] - w)
|
||||
return i, j, h, w
|
||||
|
||||
# Fallback to central crop
|
||||
in_ratio = img.size[0] / img.size[1]
|
||||
if in_ratio < min(ratio):
|
||||
w = img.size[0]
|
||||
h = int(round(w / min(ratio)))
|
||||
elif in_ratio > max(ratio):
|
||||
h = img.size[1]
|
||||
w = int(round(h * max(ratio)))
|
||||
else: # whole image
|
||||
w = img.size[0]
|
||||
h = img.size[1]
|
||||
i = (img.size[1] - h) // 2
|
||||
j = (img.size[0] - w) // 2
|
||||
return i, j, h, w
|
||||
|
||||
def __call__(self, img):
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
"""
|
||||
Args:
|
||||
img (PIL Image): Image to be cropped and resized.
|
||||
|
||||
Returns:
|
||||
PIL Image: Randomly cropped and resized image.
|
||||
"""
|
||||
i, j, h, w = self.get_params(img, self.scale, self.ratio)
|
||||
if isinstance(self.interpolation, (tuple, list)):
|
||||
interpolation = random.choice(self.interpolation)
|
||||
else:
|
||||
interpolation = self.interpolation
|
||||
if self.second_size is None:
|
||||
return F.resized_crop(img, i, j, h, w, self.size, interpolation)
|
||||
else:
|
||||
return F.resized_crop(
|
||||
img, i, j, h, w, self.size, interpolation
|
||||
), F.resized_crop(
|
||||
img, i, j, h, w, self.second_size, self.second_interpolation
|
||||
)
|
||||
|
||||
|
||||
class MaeImageDataset(FairseqDataset):
|
||||
def __init__(
|
||||
self,
|
||||
root: str,
|
||||
split: str,
|
||||
input_size,
|
||||
local_cache_path=None,
|
||||
shuffle=True,
|
||||
key="imgs",
|
||||
beit_transforms=False,
|
||||
target_transform=False,
|
||||
no_transform=False,
|
||||
compute_mask=False,
|
||||
patch_size: int = 16,
|
||||
mask_prob: float = 0.75,
|
||||
mask_prob_adjust: float = 0,
|
||||
mask_length: int = 1,
|
||||
inverse_mask: bool = False,
|
||||
expand_adjacent: bool = False,
|
||||
mask_dropout: float = 0,
|
||||
non_overlapping: bool = False,
|
||||
require_same_masks: bool = True,
|
||||
clone_batch: int = 1,
|
||||
dataset_type: str = "imagefolder",
|
||||
):
|
||||
FairseqDataset.__init__(self)
|
||||
|
||||
self.shuffle = shuffle
|
||||
self.key = key
|
||||
|
||||
loader = caching_loader(local_cache_path, datasets.folder.default_loader)
|
||||
|
||||
self.transform_source = None
|
||||
self.transform_target = None
|
||||
|
||||
if target_transform:
|
||||
self.transform_source = transforms.ColorJitter(0.4, 0.4, 0.4)
|
||||
self.transform_target = transforms.ColorJitter(0.4, 0.4, 0.4)
|
||||
|
||||
if no_transform:
|
||||
if input_size <= 224:
|
||||
crop_pct = 224 / 256
|
||||
else:
|
||||
crop_pct = 1.0
|
||||
size = int(input_size / crop_pct)
|
||||
|
||||
self.transform_train = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(size, interpolation=3),
|
||||
transforms.CenterCrop(input_size),
|
||||
]
|
||||
)
|
||||
|
||||
self.transform_train = transforms.Resize((input_size, input_size))
|
||||
elif beit_transforms:
|
||||
beit_transform_list = []
|
||||
if not target_transform:
|
||||
beit_transform_list.append(transforms.ColorJitter(0.4, 0.4, 0.4))
|
||||
beit_transform_list.extend(
|
||||
[
|
||||
transforms.RandomHorizontalFlip(p=0.5),
|
||||
RandomResizedCropAndInterpolationWithTwoPic(
|
||||
size=input_size,
|
||||
second_size=None,
|
||||
interpolation="bicubic",
|
||||
second_interpolation=None,
|
||||
),
|
||||
]
|
||||
)
|
||||
self.transform_train = transforms.Compose(beit_transform_list)
|
||||
else:
|
||||
self.transform_train = transforms.Compose(
|
||||
[
|
||||
transforms.RandomResizedCrop(
|
||||
input_size, scale=(0.2, 1.0), interpolation=3
|
||||
), # 3 is bicubic
|
||||
transforms.RandomHorizontalFlip(),
|
||||
]
|
||||
)
|
||||
self.final_transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if dataset_type == "imagefolder":
|
||||
self.dataset = datasets.ImageFolder(
|
||||
os.path.join(root, split), loader=loader
|
||||
)
|
||||
elif dataset_type == "path":
|
||||
self.dataset = PathDataset(
|
||||
root,
|
||||
loader,
|
||||
None,
|
||||
None,
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
)
|
||||
else:
|
||||
raise Exception(f"invalid dataset type {dataset_type}")
|
||||
|
||||
logger.info(
|
||||
f"initial transform: {self.transform_train}, "
|
||||
f"source transform: {self.transform_source}, "
|
||||
f"target transform: {self.transform_target}, "
|
||||
f"final transform: {self.final_transform}"
|
||||
)
|
||||
logger.info(f"loaded {len(self.dataset)} examples")
|
||||
|
||||
self.is_compute_mask = compute_mask
|
||||
self.patches = (input_size // patch_size) ** 2
|
||||
self.mask_prob = mask_prob
|
||||
self.mask_prob_adjust = mask_prob_adjust
|
||||
self.mask_length = mask_length
|
||||
self.inverse_mask = inverse_mask
|
||||
self.expand_adjacent = expand_adjacent
|
||||
self.mask_dropout = mask_dropout
|
||||
self.non_overlapping = non_overlapping
|
||||
self.require_same_masks = require_same_masks
|
||||
self.clone_batch = clone_batch
|
||||
|
||||
def __getitem__(self, index):
|
||||
img, _ = self.dataset[index]
|
||||
|
||||
img = self.transform_train(img)
|
||||
|
||||
source = None
|
||||
target = None
|
||||
if self.transform_source is not None:
|
||||
source = self.final_transform(self.transform_source(img))
|
||||
if self.transform_target is not None:
|
||||
target = self.final_transform(self.transform_target(img))
|
||||
|
||||
if source is None:
|
||||
img = self.final_transform(img)
|
||||
|
||||
v = {"id": index, self.key: source if source is not None else img}
|
||||
if target is not None:
|
||||
v["target"] = target
|
||||
|
||||
if self.is_compute_mask:
|
||||
if self.mask_length == 1:
|
||||
mask = compute_block_mask_1d(
|
||||
shape=(self.clone_batch, self.patches),
|
||||
mask_prob=self.mask_prob,
|
||||
mask_length=self.mask_length,
|
||||
mask_prob_adjust=self.mask_prob_adjust,
|
||||
inverse_mask=self.inverse_mask,
|
||||
require_same_masks=True,
|
||||
)
|
||||
else:
|
||||
mask = compute_block_mask_2d(
|
||||
shape=(self.clone_batch, self.patches),
|
||||
mask_prob=self.mask_prob,
|
||||
mask_length=self.mask_length,
|
||||
mask_prob_adjust=self.mask_prob_adjust,
|
||||
inverse_mask=self.inverse_mask,
|
||||
require_same_masks=True,
|
||||
expand_adjcent=self.expand_adjacent,
|
||||
mask_dropout=self.mask_dropout,
|
||||
non_overlapping=self.non_overlapping,
|
||||
)
|
||||
|
||||
v["precomputed_mask"] = mask
|
||||
|
||||
return v
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def collater(self, samples):
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
|
||||
collated_img = torch.stack([s[self.key] for s in samples], dim=0)
|
||||
|
||||
res = {
|
||||
"id": torch.LongTensor([s["id"] for s in samples]),
|
||||
"net_input": {
|
||||
self.key: collated_img,
|
||||
},
|
||||
}
|
||||
|
||||
if "target" in samples[0]:
|
||||
collated_target = torch.stack([s["target"] for s in samples], dim=0)
|
||||
res["net_input"]["target"] = collated_target
|
||||
|
||||
if "precomputed_mask" in samples[0]:
|
||||
collated_mask = torch.cat([s["precomputed_mask"] for s in samples], dim=0)
|
||||
res["net_input"]["precomputed_mask"] = collated_mask
|
||||
|
||||
return res
|
||||
|
||||
def num_tokens(self, index):
|
||||
return 1
|
||||
|
||||
def size(self, index):
|
||||
return 1
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return np.full((len(self),), 1)
|
||||
|
||||
def ordered_indices(self):
|
||||
"""Return an ordered list of indices. Batches will be constructed based
|
||||
on this order."""
|
||||
if self.shuffle:
|
||||
order = [np.random.permutation(len(self))]
|
||||
else:
|
||||
order = [np.arange(len(self))]
|
||||
|
||||
return order[0]
|
14
examples/data2vec/data/modality.py
Normal file
14
examples/data2vec/data/modality.py
Normal file
@ -0,0 +1,14 @@
|
||||
# Copyright (c) 2017-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the LICENSE file in
|
||||
# the root directory of this source tree. An additional grant of patent rights
|
||||
# can be found in the PATENTS file in the same directory.
|
||||
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
class Modality(Enum):
|
||||
AUDIO = auto()
|
||||
IMAGE = auto()
|
||||
TEXT = auto()
|
64
examples/data2vec/data/path_dataset.py
Normal file
64
examples/data2vec/data/path_dataset.py
Normal file
@ -0,0 +1,64 @@
|
||||
import glob
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import torchvision.transforms.functional as TF
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from torchvision.datasets import VisionDataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PathDataset(VisionDataset):
|
||||
def __init__(
|
||||
self,
|
||||
root: List[str],
|
||||
loader: None = None,
|
||||
transform: Optional[str] = None,
|
||||
extra_transform: Optional[str] = None,
|
||||
mean: Optional[List[float]] = None,
|
||||
std: Optional[List[float]] = None,
|
||||
):
|
||||
super().__init__(root=root)
|
||||
|
||||
PIL.Image.MAX_IMAGE_PIXELS = 256000001
|
||||
|
||||
self.files = []
|
||||
for folder in self.root:
|
||||
self.files.extend(
|
||||
sorted(glob.glob(os.path.join(folder, "**", "*.jpg"), recursive=True))
|
||||
)
|
||||
self.files.extend(
|
||||
sorted(glob.glob(os.path.join(folder, "**", "*.png"), recursive=True))
|
||||
)
|
||||
|
||||
self.transform = transform
|
||||
self.extra_transform = extra_transform
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
self.loader = loader
|
||||
|
||||
logger.info(f"loaded {len(self.files)} samples from {root}")
|
||||
|
||||
assert (mean is None) == (std is None)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.files)
|
||||
|
||||
def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
|
||||
path = self.files[idx]
|
||||
|
||||
if self.loader is not None:
|
||||
return self.loader(path), None
|
||||
|
||||
img = Image.open(path).convert("RGB")
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
img = TF.to_tensor(img)
|
||||
if self.mean is not None and self.std is not None:
|
||||
img = TF.normalize(img, self.mean, self.std)
|
||||
return img, None
|
165
examples/data2vec/fb_convert_beit_cp.py
Normal file
165
examples/data2vec/fb_convert_beit_cp.py
Normal file
@ -0,0 +1,165 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from fairseq.criterions.model_criterion import ModelCriterionConfig
|
||||
from fairseq.dataclass.configs import FairseqConfig
|
||||
|
||||
from tasks import ImageClassificationConfig, ImagePretrainingConfig
|
||||
from models.data2vec_image_classification import (
|
||||
Data2VecImageClassificationConfig,
|
||||
Data2VecImageClassificationModel,
|
||||
)
|
||||
from models.data2vec_vision import Data2VecVisionConfig, Data2VecVisionModel
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="convert beit checkpoint into data2vec - vision checkpoint"
|
||||
)
|
||||
# fmt: off
|
||||
parser.add_argument('checkpoint', help='checkpoint to convert')
|
||||
parser.add_argument('--output', required=True, metavar='PATH', help='where to output converted checkpoint')
|
||||
parser.add_argument('--type', type=str, choices=['vision', 'image_classification'], default='image_classification', help='type of model to upgrade')
|
||||
parser.add_argument('--inception_norms', action='store_true', default=False)
|
||||
# fmt: on
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def update_checkpoint(model_dict, prefix, is_nested):
|
||||
|
||||
replace_paths = {
|
||||
"cls_token": "model.cls_emb" if is_nested else "cls_emb",
|
||||
"patch_embed": "model.patch_embed" if is_nested else "patch_embed",
|
||||
"mask_token": "mask_emb",
|
||||
}
|
||||
|
||||
starts_with = {
|
||||
"patch_embed.proj": "model.patch_embed.conv"
|
||||
if is_nested
|
||||
else "patch_embed.conv",
|
||||
"lm_head": "final_proj",
|
||||
"fc_norm": "fc_norm",
|
||||
"head": "head",
|
||||
}
|
||||
|
||||
partial = {
|
||||
"mlp.fc1": "mlp.0",
|
||||
"mlp.fc2": "mlp.2",
|
||||
}
|
||||
|
||||
for k in list(model_dict.keys()):
|
||||
for sw, r in starts_with.items():
|
||||
if k.startswith(sw):
|
||||
replace_paths[k] = k.replace(sw, r)
|
||||
for p, r in partial.items():
|
||||
if p in k:
|
||||
replace_paths[k] = prefix + k.replace(p, r)
|
||||
|
||||
if prefix != "":
|
||||
for k in list(model_dict.keys()):
|
||||
if k not in replace_paths:
|
||||
replace_paths[k] = prefix + k
|
||||
|
||||
for k in list(model_dict.keys()):
|
||||
if k in replace_paths:
|
||||
model_dict[replace_paths[k]] = model_dict[k]
|
||||
if k != replace_paths[k]:
|
||||
del model_dict[k]
|
||||
|
||||
return model_dict
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
cp = torch.load(args.checkpoint, map_location="cpu")
|
||||
|
||||
cfg = FairseqConfig(
|
||||
criterion=ModelCriterionConfig(_name="model", log_keys=["correct"]),
|
||||
)
|
||||
|
||||
if args.type == "image_classification":
|
||||
|
||||
cfg.task = ImageClassificationConfig(
|
||||
_name="image_classification",
|
||||
data=".",
|
||||
)
|
||||
|
||||
if args.inception_norms:
|
||||
cfg.task.normalization_mean = [0.5, 0.5, 0.5]
|
||||
cfg.task.normalization_std = [0.5, 0.5, 0.5]
|
||||
|
||||
cfg.model = Data2VecImageClassificationConfig(
|
||||
_name="data2vec_image_classification",
|
||||
)
|
||||
cfg.model.pretrained_model_args = FairseqConfig(
|
||||
model=Data2VecVisionConfig(
|
||||
_name="data2vec_vision", shared_rel_pos_bias=False
|
||||
),
|
||||
task=ImagePretrainingConfig(
|
||||
_name="image_pretraining",
|
||||
),
|
||||
)
|
||||
|
||||
cfg = OmegaConf.create(cfg)
|
||||
|
||||
state = {
|
||||
"cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True),
|
||||
"model": cp["module"],
|
||||
"best_loss": None,
|
||||
"optimizer": None,
|
||||
"extra_state": {},
|
||||
}
|
||||
|
||||
model = Data2VecImageClassificationModel(cfg.model)
|
||||
model.load_state_dict(
|
||||
update_checkpoint(state["model"], prefix="model.encoder.", is_nested=True),
|
||||
strict=True,
|
||||
)
|
||||
elif args.type == "vision":
|
||||
cfg.task = ImagePretrainingConfig(
|
||||
_name="image_pretraining",
|
||||
data=".",
|
||||
)
|
||||
|
||||
if args.inception_norms:
|
||||
cfg.task.normalization_mean = [0.5, 0.5, 0.5]
|
||||
cfg.task.normalization_std = [0.5, 0.5, 0.5]
|
||||
|
||||
cfg.model = Data2VecVisionConfig(
|
||||
_name="data2vec_vision",
|
||||
)
|
||||
cfg = OmegaConf.create(cfg)
|
||||
|
||||
state = {
|
||||
"cfg": OmegaConf.to_container(cfg, resolve=True, enum_to_str=True),
|
||||
"model": cp["model"],
|
||||
"best_loss": None,
|
||||
"optimizer": None,
|
||||
"extra_state": {},
|
||||
}
|
||||
|
||||
model = Data2VecVisionModel(cfg.model)
|
||||
model.load_state_dict(
|
||||
update_checkpoint(state["model"], prefix="encoder.", is_nested=False),
|
||||
strict=True,
|
||||
)
|
||||
else:
|
||||
raise Exception("unsupported type " + args.type)
|
||||
|
||||
print(state["cfg"], state.keys())
|
||||
torch.save(state, args.output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
0
examples/data2vec/models/__init__.py
Normal file
0
examples/data2vec/models/__init__.py
Normal file
614
examples/data2vec/models/audio_classification.py
Normal file
614
examples/data2vec/models/audio_classification.py
Normal file
@ -0,0 +1,614 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from omegaconf import II, MISSING, open_dict
|
||||
|
||||
from fairseq import checkpoint_utils, tasks
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||
from fairseq.models import (
|
||||
BaseFairseqModel,
|
||||
register_model,
|
||||
)
|
||||
from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
|
||||
from fairseq.modules import TransposeLast
|
||||
from fairseq.tasks import FairseqTask
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioClassificationConfig(FairseqDataclass):
|
||||
model_path: str = field(
|
||||
default=MISSING, metadata={"help": "path to wav2vec 2.0 model"}
|
||||
)
|
||||
no_pretrained_weights: bool = field(
|
||||
default=False, metadata={"help": "if true, does not load pretrained weights"}
|
||||
)
|
||||
dropout_input: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "dropout to apply to the input (after feat extr)"},
|
||||
)
|
||||
final_dropout: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "dropout after transformer and before final projection"},
|
||||
)
|
||||
dropout: float = field(
|
||||
default=0.0, metadata={"help": "dropout probability inside wav2vec 2.0 model"}
|
||||
)
|
||||
attention_dropout: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": "dropout probability for attention weights inside wav2vec 2.0 model"
|
||||
},
|
||||
)
|
||||
activation_dropout: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": "dropout probability after activation in FFN inside wav2vec 2.0 model"
|
||||
},
|
||||
)
|
||||
|
||||
# masking
|
||||
apply_mask: bool = field(
|
||||
default=False, metadata={"help": "apply masking during fine-tuning"}
|
||||
)
|
||||
mask_length: int = field(
|
||||
default=10, metadata={"help": "repeat the mask indices multiple times"}
|
||||
)
|
||||
mask_prob: float = field(
|
||||
default=0.5,
|
||||
metadata={
|
||||
"help": "probability of replacing a token with mask (normalized by length)"
|
||||
},
|
||||
)
|
||||
mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
|
||||
default="static", metadata={"help": "how to choose masks"}
|
||||
)
|
||||
mask_other: float = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "secondary mask argument (used for more complex distributions), "
|
||||
"see help in compute_mask_indices"
|
||||
},
|
||||
)
|
||||
no_mask_overlap: bool = field(
|
||||
default=False, metadata={"help": "whether to allow masks to overlap"}
|
||||
)
|
||||
mask_min_space: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "min space between spans (if no overlap is enabled)"},
|
||||
)
|
||||
require_same_masks: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "whether to number of masked timesteps must be the same across all "
|
||||
"examples in a batch"
|
||||
},
|
||||
)
|
||||
mask_dropout: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "percent of masks to unmask for each sample"},
|
||||
)
|
||||
|
||||
# channel masking
|
||||
mask_channel_length: int = field(
|
||||
default=10, metadata={"help": "length of the mask for features (channels)"}
|
||||
)
|
||||
mask_channel_prob: float = field(
|
||||
default=0.0, metadata={"help": "probability of replacing a feature with 0"}
|
||||
)
|
||||
mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
|
||||
default="static",
|
||||
metadata={"help": "how to choose mask length for channel masking"},
|
||||
)
|
||||
mask_channel_other: float = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "secondary mask argument (used for more complex distributions), "
|
||||
"see help in compute_mask_indicesh"
|
||||
},
|
||||
)
|
||||
no_mask_channel_overlap: bool = field(
|
||||
default=False, metadata={"help": "whether to allow channel masks to overlap"}
|
||||
)
|
||||
freeze_finetune_updates: int = field(
|
||||
default=0, metadata={"help": "dont finetune wav2vec for this many updates"}
|
||||
)
|
||||
feature_grad_mult: float = field(
|
||||
default=0.0, metadata={"help": "reset feature grad mult in wav2vec 2.0 to this"}
|
||||
)
|
||||
layerdrop: float = field(
|
||||
default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"}
|
||||
)
|
||||
mask_channel_min_space: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "min space between spans (if no overlap is enabled)"},
|
||||
)
|
||||
mask_channel_before: bool = False
|
||||
normalize: bool = II("task.normalize")
|
||||
data: str = II("task.data")
|
||||
# this holds the loaded wav2vec args
|
||||
d2v_args: Any = None
|
||||
offload_activations: bool = field(
|
||||
default=False, metadata={"help": "offload_activations"}
|
||||
)
|
||||
min_params_to_wrap: int = field(
|
||||
default=int(1e8),
|
||||
metadata={
|
||||
"help": "minimum number of params for a layer to be wrapped with FSDP() when "
|
||||
"training with --ddp-backend=fully_sharded. Smaller values will "
|
||||
"improve memory efficiency, but may make torch.distributed "
|
||||
"communication less efficient due to smaller input sizes. This option "
|
||||
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
|
||||
"--offload-activations are passed."
|
||||
},
|
||||
)
|
||||
|
||||
checkpoint_activations: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "recompute activations and save memory for extra compute"},
|
||||
)
|
||||
ddp_backend: str = II("distributed_training.ddp_backend")
|
||||
|
||||
prediction_mode: str = "lin_softmax"
|
||||
eval_prediction_mode: Optional[str] = None
|
||||
conv_kernel: int = -1
|
||||
conv_stride: int = 1
|
||||
two_convs: bool = False
|
||||
extreme_factor: float = 1.0
|
||||
|
||||
conv_feature_layers: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "string describing convolutional feature extraction layers in form of a python list that contains "
|
||||
"[(dim, kernel_size, stride), ...]"
|
||||
},
|
||||
)
|
||||
|
||||
mixup_prob: float = 1.0
|
||||
source_mixup: float = -1
|
||||
same_mixup: bool = True
|
||||
label_mixup: bool = False
|
||||
|
||||
gain_mode: str = "none"
|
||||
|
||||
|
||||
@register_model("audio_classification", dataclass=AudioClassificationConfig)
|
||||
class AudioClassificationModel(BaseFairseqModel):
|
||||
def __init__(self, cfg: AudioClassificationConfig, num_classes):
|
||||
super().__init__()
|
||||
|
||||
self.apply_mask = cfg.apply_mask
|
||||
self.cfg = cfg
|
||||
|
||||
arg_overrides = {
|
||||
"dropout": cfg.dropout,
|
||||
"activation_dropout": cfg.activation_dropout,
|
||||
"dropout_input": cfg.dropout_input,
|
||||
"attention_dropout": cfg.attention_dropout,
|
||||
"mask_length": cfg.mask_length,
|
||||
"mask_prob": cfg.mask_prob,
|
||||
"require_same_masks": getattr(cfg, "require_same_masks", True),
|
||||
"mask_dropout": getattr(cfg, "mask_dropout", 0),
|
||||
"mask_selection": cfg.mask_selection,
|
||||
"mask_other": cfg.mask_other,
|
||||
"no_mask_overlap": cfg.no_mask_overlap,
|
||||
"mask_channel_length": cfg.mask_channel_length,
|
||||
"mask_channel_prob": cfg.mask_channel_prob,
|
||||
"mask_channel_before": cfg.mask_channel_before,
|
||||
"mask_channel_selection": cfg.mask_channel_selection,
|
||||
"mask_channel_other": cfg.mask_channel_other,
|
||||
"no_mask_channel_overlap": cfg.no_mask_channel_overlap,
|
||||
"encoder_layerdrop": cfg.layerdrop,
|
||||
"feature_grad_mult": cfg.feature_grad_mult,
|
||||
"checkpoint_activations": cfg.checkpoint_activations,
|
||||
"offload_activations": cfg.offload_activations,
|
||||
"min_params_to_wrap": cfg.min_params_to_wrap,
|
||||
"mixup": -1,
|
||||
}
|
||||
|
||||
if cfg.conv_feature_layers is not None:
|
||||
arg_overrides["conv_feature_layers"] = cfg.conv_feature_layers
|
||||
|
||||
if cfg.d2v_args is None:
|
||||
state = checkpoint_utils.load_checkpoint_to_cpu(
|
||||
cfg.model_path, arg_overrides
|
||||
)
|
||||
d2v_args = state.get("cfg", None)
|
||||
if d2v_args is None:
|
||||
d2v_args = convert_namespace_to_omegaconf(state["args"])
|
||||
d2v_args.criterion = None
|
||||
d2v_args.lr_scheduler = None
|
||||
cfg.d2v_args = d2v_args
|
||||
|
||||
logger.info(d2v_args)
|
||||
|
||||
else:
|
||||
state = None
|
||||
d2v_args = cfg.d2v_args
|
||||
|
||||
model_normalized = d2v_args.task.get(
|
||||
"normalize", d2v_args.model.get("normalize", False)
|
||||
)
|
||||
assert cfg.normalize == model_normalized, (
|
||||
"Fine-tuning works best when data normalization is the same. "
|
||||
"Please check that --normalize is set or unset for both pre-training and here"
|
||||
)
|
||||
|
||||
if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations:
|
||||
with open_dict(d2v_args):
|
||||
d2v_args.model.checkpoint_activations = cfg.checkpoint_activations
|
||||
|
||||
d2v_args.task.data = cfg.data
|
||||
task = tasks.setup_task(d2v_args.task)
|
||||
model = task.build_model(d2v_args.model, from_checkpoint=True)
|
||||
|
||||
model.remove_pretraining_modules()
|
||||
|
||||
if state is not None and not cfg.no_pretrained_weights:
|
||||
self.load_model_weights(state, model, cfg)
|
||||
|
||||
d = d2v_args.model.encoder_embed_dim
|
||||
|
||||
self.d2v_model = model
|
||||
|
||||
self.final_dropout = nn.Dropout(cfg.final_dropout)
|
||||
self.freeze_finetune_updates = cfg.freeze_finetune_updates
|
||||
self.num_updates = 0
|
||||
|
||||
for p in self.parameters():
|
||||
p.param_group = "pretrained"
|
||||
|
||||
if cfg.prediction_mode == "proj_avg_proj":
|
||||
self.proj = nn.Linear(d, d * 2)
|
||||
self.proj2 = nn.Linear(d * 2, num_classes)
|
||||
|
||||
for p in self.proj.parameters():
|
||||
p.param_group = "projection"
|
||||
for p in self.proj2.parameters():
|
||||
p.param_group = "projection"
|
||||
elif self.cfg.prediction_mode == "summary_proj":
|
||||
self.proj = nn.Linear(d // 3, num_classes)
|
||||
for p in self.proj.parameters():
|
||||
p.param_group = "projection"
|
||||
elif self.cfg.conv_kernel > 1 and not self.cfg.two_convs:
|
||||
self.proj = nn.Sequential(
|
||||
TransposeLast(),
|
||||
nn.Conv1d(d, num_classes, kernel_size=self.cfg.conv_kernel, stride=self.cfg.conv_stride),
|
||||
TransposeLast(),
|
||||
)
|
||||
for p in self.proj.parameters():
|
||||
p.param_group = "projection"
|
||||
elif self.cfg.conv_kernel > 0 and self.cfg.two_convs:
|
||||
self.proj = nn.Sequential(
|
||||
TransposeLast(),
|
||||
nn.Conv1d(d, d, kernel_size=self.cfg.conv_kernel, stride=self.cfg.conv_stride),
|
||||
TransposeLast(),
|
||||
nn.GELU(),
|
||||
nn.Linear(d, num_classes),
|
||||
)
|
||||
for p in self.proj.parameters():
|
||||
p.param_group = "projection"
|
||||
else:
|
||||
self.proj = nn.Linear(d, num_classes)
|
||||
for p in self.proj.parameters():
|
||||
p.param_group = "projection"
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
super().upgrade_state_dict_named(state_dict, name)
|
||||
return state_dict
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, cfg: AudioClassificationConfig, task: FairseqTask):
|
||||
"""Build a new model instance."""
|
||||
|
||||
assert hasattr(task, "labels"), f"Task {task} must have an attribute 'labels'"
|
||||
|
||||
return cls(cfg, len(task.labels))
|
||||
|
||||
def load_model_weights(self, state, model, cfg):
|
||||
if cfg.ddp_backend == "fully_sharded":
|
||||
from fairseq.distributed import FullyShardedDataParallel
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if "encoder.layers" in name and len(name.split(".")) == 3:
|
||||
# Only for layers, we do a special handling and load the weights one by one
|
||||
# We dont load all weights together as that wont be memory efficient and may
|
||||
# cause oom
|
||||
new_dict = {
|
||||
k.replace(name + ".", ""): v
|
||||
for (k, v) in state["model"].items()
|
||||
if name + "." in k
|
||||
}
|
||||
assert isinstance(module, FullyShardedDataParallel)
|
||||
with module.summon_full_params():
|
||||
module.load_state_dict(new_dict, strict=True)
|
||||
module._reset_lazy_init()
|
||||
|
||||
# Once layers are loaded, filter them out and load everything else.
|
||||
r = re.compile("encoder.layers.\d.")
|
||||
filtered_list = list(filter(r.match, state["model"].keys()))
|
||||
|
||||
new_big_dict = {
|
||||
k: v for (k, v) in state["model"].items() if k not in filtered_list
|
||||
}
|
||||
|
||||
model.load_state_dict(new_big_dict, strict=False)
|
||||
else:
|
||||
if "_ema" in state["model"]:
|
||||
del state["model"]["_ema"]
|
||||
model.load_state_dict(state["model"], strict=False)
|
||||
|
||||
def set_num_updates(self, num_updates):
|
||||
"""Set the number of parameters updates."""
|
||||
super().set_num_updates(num_updates)
|
||||
self.num_updates = num_updates
|
||||
|
||||
def compute_gain(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"):
|
||||
if fs == 16000:
|
||||
n_fft = 2048
|
||||
elif fs == 44100:
|
||||
n_fft = 4096
|
||||
else:
|
||||
raise Exception("Invalid fs {}".format(fs))
|
||||
stride = n_fft // 2
|
||||
|
||||
def a_weight(fs, n_fft, min_db=-80.0):
|
||||
freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
|
||||
freq_sq = np.power(freq, 2)
|
||||
freq_sq[0] = 1.0
|
||||
weight = 2.0 + 20.0 * (
|
||||
2 * np.log10(12194)
|
||||
+ 2 * np.log10(freq_sq)
|
||||
- np.log10(freq_sq + 12194 ** 2)
|
||||
- np.log10(freq_sq + 20.6 ** 2)
|
||||
- 0.5 * np.log10(freq_sq + 107.7 ** 2)
|
||||
- 0.5 * np.log10(freq_sq + 737.9 ** 2)
|
||||
)
|
||||
weight = np.maximum(weight, min_db)
|
||||
|
||||
return weight
|
||||
|
||||
gain = []
|
||||
for i in range(0, len(sound) - n_fft + 1, stride):
|
||||
if mode == "RMSE":
|
||||
g = np.mean(sound[i : i + n_fft] ** 2)
|
||||
elif mode == "A_weighting":
|
||||
spec = np.fft.rfft(np.hanning(n_fft + 1)[:-1] * sound[i : i + n_fft])
|
||||
power_spec = np.abs(spec) ** 2
|
||||
a_weighted_spec = power_spec * np.power(10, a_weight(fs, n_fft) / 10)
|
||||
g = np.sum(a_weighted_spec)
|
||||
else:
|
||||
raise Exception("Invalid mode {}".format(mode))
|
||||
gain.append(g)
|
||||
|
||||
gain = np.array(gain)
|
||||
gain = np.maximum(gain, np.power(10, min_db / 10))
|
||||
gain_db = 10 * np.log10(gain)
|
||||
|
||||
return gain_db
|
||||
|
||||
# adapted from https://github.com/mil-tokyo/bc_learning_sound/blob/master/utils.py
|
||||
def compute_gain_torch(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"):
|
||||
if fs == 16000:
|
||||
n_fft = 2048
|
||||
elif fs == 44100:
|
||||
n_fft = 4096
|
||||
else:
|
||||
raise Exception("Invalid fs {}".format(fs))
|
||||
|
||||
if mode == "A_weighting":
|
||||
if not hasattr(self, f"a_weight"):
|
||||
self.a_weight = {}
|
||||
|
||||
if fs not in self.a_weight:
|
||||
|
||||
def a_weight(fs, n_fft, min_db=-80.0):
|
||||
freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
|
||||
freq_sq = freq ** 2
|
||||
freq_sq[0] = 1.0
|
||||
weight = 2.0 + 20.0 * (
|
||||
2 * np.log10(12194)
|
||||
+ 2 * np.log10(freq_sq)
|
||||
- np.log10(freq_sq + 12194 ** 2)
|
||||
- np.log10(freq_sq + 20.6 ** 2)
|
||||
- 0.5 * np.log10(freq_sq + 107.7 ** 2)
|
||||
- 0.5 * np.log10(freq_sq + 737.9 ** 2)
|
||||
)
|
||||
weight = np.maximum(weight, min_db)
|
||||
|
||||
return weight
|
||||
|
||||
self.a_weight[fs] = torch.from_numpy(
|
||||
np.power(10, a_weight(fs, n_fft, min_db) / 10)
|
||||
).to(device=sound.device)
|
||||
|
||||
sound = sound.unfold(-1, n_fft, n_fft // 2)
|
||||
|
||||
if mode == "RMSE":
|
||||
sound = sound ** 2
|
||||
g = sound.mean(-1)
|
||||
elif mode == "A_weighting":
|
||||
w = torch.hann_window(n_fft, device=sound.device) * sound
|
||||
spec = torch.fft.rfft(w)
|
||||
power_spec = spec.abs() ** 2
|
||||
a_weighted_spec = power_spec * self.a_weight[fs]
|
||||
g = a_weighted_spec.sum(-1)
|
||||
else:
|
||||
raise Exception("Invalid mode {}".format(mode))
|
||||
|
||||
gain = torch.maximum(g, torch.tensor(10 ** (min_db / 10), device=g.device))
|
||||
gain_db = 10 * torch.log10(gain)
|
||||
|
||||
return gain_db
|
||||
|
||||
def forward(self, source, padding_mask, label=None, **kwargs):
|
||||
|
||||
if self.cfg.source_mixup >= 0 and self.training and self.cfg.mixup_prob > 0:
|
||||
with torch.no_grad():
|
||||
mixed_source = source
|
||||
mix_mask = None
|
||||
if self.cfg.mixup_prob < 1:
|
||||
mix_mask = (
|
||||
torch.empty((source.size(0),), device=source.device)
|
||||
.bernoulli_(self.cfg.mixup_prob)
|
||||
.bool()
|
||||
)
|
||||
mixed_source = source[mix_mask]
|
||||
|
||||
r = (
|
||||
torch.FloatTensor(
|
||||
1 if self.cfg.same_mixup else mixed_source.size(0)
|
||||
)
|
||||
.uniform_(max(1e-6, self.cfg.source_mixup), 1)
|
||||
.to(dtype=source.dtype, device=source.device)
|
||||
)
|
||||
|
||||
mixup_perm = torch.randperm(source.size(0))
|
||||
s2 = source[mixup_perm]
|
||||
|
||||
if self.cfg.gain_mode == "none":
|
||||
p = r.unsqueeze(-1)
|
||||
if mix_mask is not None:
|
||||
s2 = s2[mix_mask]
|
||||
else:
|
||||
if self.cfg.gain_mode == "naive_rms":
|
||||
G1 = source.pow(2).mean(dim=-1).sqrt()
|
||||
else:
|
||||
G1, _ = self.compute_gain_torch(
|
||||
source, mode=self.cfg.gain_mode
|
||||
).max(-1)
|
||||
G1 = G1.to(dtype=source.dtype)
|
||||
|
||||
G2 = G1[mixup_perm]
|
||||
|
||||
if mix_mask is not None:
|
||||
G1 = G1[mix_mask]
|
||||
G2 = G2[mix_mask]
|
||||
s2 = s2[mix_mask]
|
||||
|
||||
p = 1 / (1 + 10 ** ((G1 - G2) / 20) * (1 - r) / r)
|
||||
p = p.unsqueeze(-1)
|
||||
|
||||
mixed = (p * mixed_source) + (1 - p) * s2
|
||||
|
||||
if mix_mask is None:
|
||||
source = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)
|
||||
else:
|
||||
source[mix_mask] = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)
|
||||
|
||||
if label is not None and self.cfg.label_mixup:
|
||||
r = r.unsqueeze(-1)
|
||||
if mix_mask is None:
|
||||
label = label * r + (1 - r) * label[mixup_perm]
|
||||
else:
|
||||
label[mix_mask] = (
|
||||
label[mix_mask] * r + (1 - r) * label[mixup_perm][mix_mask]
|
||||
)
|
||||
|
||||
d2v_args = {
|
||||
"source": source,
|
||||
"padding_mask": padding_mask,
|
||||
"mask": self.apply_mask and self.training,
|
||||
}
|
||||
|
||||
ft = self.freeze_finetune_updates <= self.num_updates
|
||||
|
||||
with torch.no_grad() if not ft else contextlib.ExitStack():
|
||||
res = self.d2v_model.extract_features(**d2v_args)
|
||||
|
||||
x = res["x"]
|
||||
padding_mask = res["padding_mask"]
|
||||
if padding_mask is not None:
|
||||
x[padding_mask] = 0
|
||||
|
||||
x = self.final_dropout(x)
|
||||
|
||||
if self.training or (
|
||||
self.cfg.eval_prediction_mode is None or self.cfg.eval_prediction_mode == ""
|
||||
):
|
||||
prediction_mode = self.cfg.prediction_mode
|
||||
else:
|
||||
prediction_mode = self.cfg.eval_prediction_mode
|
||||
|
||||
if prediction_mode == "average_before":
|
||||
x = x.mean(dim=1)
|
||||
|
||||
if prediction_mode != "summary_mha" and prediction_mode != "summary_proj" and prediction_mode != "cls":
|
||||
x = self.proj(x)
|
||||
|
||||
logits = True
|
||||
if prediction_mode == "lin_softmax":
|
||||
x = F.logsigmoid(x.float())
|
||||
x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x, dim=1)
|
||||
x = x.clamp(max=0)
|
||||
x = x - torch.log(-(torch.expm1(x)))
|
||||
elif prediction_mode == "extremized_odds":
|
||||
x = x.float().sum(dim=1)
|
||||
x = x * self.cfg.extreme_factor
|
||||
elif prediction_mode == "average_before":
|
||||
x = x.float()
|
||||
elif prediction_mode == "average":
|
||||
x = x.float().mean(dim=1)
|
||||
elif prediction_mode == "average_sigmoid":
|
||||
x = torch.sigmoid(x.float())
|
||||
x = x.mean(dim=1)
|
||||
logits = False
|
||||
elif prediction_mode == "max":
|
||||
x, _ = x.float().max(dim=1)
|
||||
elif prediction_mode == "max_sigmoid":
|
||||
x = torch.sigmoid(x.float())
|
||||
x, _ = x.float().max(dim=1)
|
||||
logits = False
|
||||
elif prediction_mode == "proj_avg_proj":
|
||||
x = x.mean(dim=1)
|
||||
x = self.proj2(x)
|
||||
elif prediction_mode == "summary_mha" or prediction_mode == "summary_proj":
|
||||
x = self.d2v_model.summary(
|
||||
x, padding_mask, proj=prediction_mode == "summary_proj"
|
||||
)
|
||||
x = x.type_as(source)
|
||||
x = self.proj(x)
|
||||
elif prediction_mode == "cls":
|
||||
x = x[:,0]
|
||||
x = self.proj(x)
|
||||
else:
|
||||
raise Exception(f"unknown prediction mode {prediction_mode}")
|
||||
|
||||
if label is None:
|
||||
return torch.sigmoid(x) if logits else x
|
||||
|
||||
x = torch.nan_to_num(x)
|
||||
|
||||
if logits:
|
||||
loss = F.binary_cross_entropy_with_logits(
|
||||
x, label.float(), reduction="none"
|
||||
)
|
||||
else:
|
||||
loss = F.binary_cross_entropy(x, label.float(), reduction="none")
|
||||
|
||||
result = {
|
||||
"losses": {
|
||||
"main": loss,
|
||||
},
|
||||
"sample_size": label.sum(),
|
||||
}
|
||||
|
||||
if not self.training:
|
||||
result["_predictions"] = torch.sigmoid(x) if logits else x
|
||||
result["_targets"] = label
|
||||
|
||||
return result
|
813
examples/data2vec/models/data2vec2.py
Normal file
813
examples/data2vec/models/data2vec2.py
Normal file
@ -0,0 +1,813 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Callable
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
||||
from omegaconf import II
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
from fairseq.modules import EMAModule, EMAModuleConfig
|
||||
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.models import BaseFairseqModel, register_model
|
||||
|
||||
from examples.data2vec.data.modality import Modality
|
||||
|
||||
from examples.data2vec.models.modalities.base import (
|
||||
MaskSeed,
|
||||
D2vModalityConfig,
|
||||
ModalitySpecificEncoder,
|
||||
get_annealed_rate,
|
||||
)
|
||||
from examples.data2vec.models.modalities.modules import (
|
||||
D2vDecoderConfig,
|
||||
AltBlock,
|
||||
Decoder1d,
|
||||
)
|
||||
|
||||
from examples.data2vec.models.modalities.audio import (
|
||||
D2vAudioConfig,
|
||||
AudioEncoder,
|
||||
)
|
||||
from examples.data2vec.models.modalities.images import (
|
||||
D2vImageConfig,
|
||||
ImageEncoder,
|
||||
)
|
||||
from examples.data2vec.models.modalities.text import (
|
||||
D2vTextConfig,
|
||||
TextEncoder,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class D2vModalitiesConfig(FairseqDataclass):
|
||||
audio: D2vAudioConfig = D2vAudioConfig()
|
||||
image: D2vImageConfig = D2vImageConfig()
|
||||
text: D2vTextConfig = D2vTextConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Data2VecMultiConfig(FairseqDataclass):
|
||||
|
||||
loss_beta: float = field(
|
||||
default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
|
||||
)
|
||||
loss_scale: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
|
||||
},
|
||||
)
|
||||
|
||||
depth: int = 8
|
||||
start_drop_path_rate: float = 0
|
||||
end_drop_path_rate: float = 0
|
||||
num_heads: int = 12
|
||||
norm_eps: float = 1e-6
|
||||
norm_affine: bool = True
|
||||
encoder_dropout: float = 0.1
|
||||
post_mlp_drop: float = 0.1
|
||||
attention_dropout: float = 0.1
|
||||
activation_dropout: float = 0.0
|
||||
dropout_input: float = 0.0
|
||||
layerdrop: float = 0.0
|
||||
embed_dim: int = 768
|
||||
mlp_ratio: float = 4
|
||||
layer_norm_first: bool = False
|
||||
|
||||
average_top_k_layers: int = field(
|
||||
default=8, metadata={"help": "how many layers to average"}
|
||||
)
|
||||
|
||||
end_of_block_targets: bool = False
|
||||
|
||||
clone_batch: int = 1
|
||||
|
||||
layer_norm_target_layer: bool = False
|
||||
batch_norm_target_layer: bool = False
|
||||
instance_norm_target_layer: bool = False
|
||||
instance_norm_targets: bool = False
|
||||
layer_norm_targets: bool = False
|
||||
|
||||
ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
|
||||
ema_same_dtype: bool = True
|
||||
log_norms: bool = True
|
||||
ema_end_decay: float = field(
|
||||
default=0.9999, metadata={"help": "final ema decay rate"}
|
||||
)
|
||||
|
||||
# when to finish annealing ema decay rate
|
||||
ema_anneal_end_step: int = II("optimization.max_update")
|
||||
|
||||
ema_encoder_only: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "whether to momentum update only the shared transformer encoder"
|
||||
},
|
||||
)
|
||||
|
||||
max_update: int = II("optimization.max_update")
|
||||
|
||||
modalities: D2vModalitiesConfig = D2vModalitiesConfig()
|
||||
|
||||
shared_decoder: Optional[D2vDecoderConfig] = None
|
||||
|
||||
min_target_var: float = field(
|
||||
default=0.1, metadata={"help": "stop training if target var falls below this"}
|
||||
)
|
||||
min_pred_var: float = field(
|
||||
default=0.01,
|
||||
metadata={"help": "stop training if prediction var falls below this"},
|
||||
)
|
||||
|
||||
supported_modality: Optional[Modality] = None
|
||||
mae_init: bool = False
|
||||
|
||||
seed: int = II("common.seed")
|
||||
|
||||
skip_ema: bool = False
|
||||
|
||||
cls_loss: float = 0
|
||||
recon_loss: float = 0
|
||||
d2v_loss: float = 1
|
||||
|
||||
decoder_group: bool = False
|
||||
|
||||
|
||||
@register_model("data2vec_multi", dataclass=Data2VecMultiConfig)
|
||||
class Data2VecMultiModel(BaseFairseqModel):
|
||||
def make_modality_encoder(
|
||||
self,
|
||||
cfg: D2vModalityConfig,
|
||||
embed_dim: int,
|
||||
make_block: Callable[[float], nn.ModuleList],
|
||||
norm_layer: Callable[[int], nn.LayerNorm],
|
||||
layer_norm_first: bool,
|
||||
alibi_biases,
|
||||
task,
|
||||
) -> ModalitySpecificEncoder:
|
||||
if cfg.type == Modality.AUDIO:
|
||||
enc_cls = AudioEncoder
|
||||
elif cfg.type == Modality.IMAGE:
|
||||
enc_cls = ImageEncoder
|
||||
elif cfg.type == Modality.TEXT:
|
||||
enc_cls = TextEncoder
|
||||
if hasattr(task, "text_task"):
|
||||
task = task.text_task
|
||||
else:
|
||||
raise Exception(f"unsupported modality {cfg.type}")
|
||||
|
||||
return enc_cls(
|
||||
cfg,
|
||||
embed_dim,
|
||||
make_block,
|
||||
norm_layer,
|
||||
layer_norm_first,
|
||||
alibi_biases,
|
||||
task,
|
||||
)
|
||||
|
||||
def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.modalities = modalities
|
||||
self.task = task
|
||||
|
||||
make_layer_norm = partial(
|
||||
nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
|
||||
)
|
||||
|
||||
def make_block(drop_path, dim=None, heads=None):
|
||||
return AltBlock(
|
||||
cfg.embed_dim if dim is None else dim,
|
||||
cfg.num_heads if heads is None else heads,
|
||||
cfg.mlp_ratio,
|
||||
qkv_bias=True,
|
||||
drop=cfg.encoder_dropout,
|
||||
attn_drop=cfg.attention_dropout,
|
||||
mlp_drop=cfg.activation_dropout,
|
||||
post_mlp_drop=cfg.post_mlp_drop,
|
||||
drop_path=drop_path,
|
||||
norm_layer=make_layer_norm,
|
||||
layer_norm_first=cfg.layer_norm_first,
|
||||
ffn_targets=not cfg.end_of_block_targets,
|
||||
)
|
||||
|
||||
self.alibi_biases = {}
|
||||
self.modality_encoders = nn.ModuleDict()
|
||||
for mod in self.modalities:
|
||||
mod_cfg = getattr(cfg.modalities, mod.name.lower())
|
||||
enc = self.make_modality_encoder(
|
||||
mod_cfg,
|
||||
cfg.embed_dim,
|
||||
make_block,
|
||||
make_layer_norm,
|
||||
cfg.layer_norm_first,
|
||||
self.alibi_biases,
|
||||
task,
|
||||
)
|
||||
self.modality_encoders[mod.name] = enc
|
||||
|
||||
self.ema = None
|
||||
|
||||
self.average_top_k_layers = cfg.average_top_k_layers
|
||||
self.loss_beta = cfg.loss_beta
|
||||
self.loss_scale = cfg.loss_scale
|
||||
|
||||
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
||||
|
||||
dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
|
||||
|
||||
self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
|
||||
|
||||
self.norm = None
|
||||
if cfg.layer_norm_first:
|
||||
self.norm = make_layer_norm(cfg.embed_dim)
|
||||
|
||||
if self.cfg.mae_init:
|
||||
self.apply(self._init_weights)
|
||||
else:
|
||||
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
||||
|
||||
self.apply(init_bert_params)
|
||||
|
||||
for mod_enc in self.modality_encoders.values():
|
||||
mod_enc.reset_parameters()
|
||||
|
||||
if not skip_ema:
|
||||
self.ema = self.make_ema_teacher(cfg.ema_decay)
|
||||
self.shared_decoder = (
|
||||
Decoder1d(cfg.shared_decoder, cfg.embed_dim)
|
||||
if self.cfg.shared_decoder is not None
|
||||
else None
|
||||
)
|
||||
if self.shared_decoder is not None:
|
||||
self.shared_decoder.apply(self._init_weights)
|
||||
|
||||
self.recon_proj = None
|
||||
if cfg.recon_loss > 0:
|
||||
self.recon_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
|
||||
|
||||
for pn, p in self.named_parameters():
|
||||
if len(p.shape) == 1 or pn.endswith(".bias") or "alibi_scale" in pn:
|
||||
p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
|
||||
if cfg.decoder_group and "decoder" in pn:
|
||||
p.param_group = "decoder"
|
||||
|
||||
self.num_updates = 0
|
||||
|
||||
def _init_weights(self, m):
|
||||
|
||||
try:
|
||||
from apex.normalization import FusedLayerNorm
|
||||
|
||||
fn = FusedLayerNorm
|
||||
except:
|
||||
fn = nn.LayerNorm
|
||||
|
||||
if isinstance(m, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(m.weight)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm) or isinstance(m, fn):
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
if m.weight is not None:
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
@torch.no_grad()
|
||||
def make_ema_teacher(self, ema_decay):
|
||||
ema_config = EMAModuleConfig(
|
||||
ema_decay=ema_decay,
|
||||
ema_fp32=True,
|
||||
log_norms=self.cfg.log_norms,
|
||||
add_missing_params=False,
|
||||
)
|
||||
|
||||
model_copy = self.make_target_model()
|
||||
|
||||
return EMAModule(
|
||||
model_copy,
|
||||
ema_config,
|
||||
copy_model=False,
|
||||
)
|
||||
|
||||
def make_target_model(self):
|
||||
logger.info("making target model")
|
||||
|
||||
model_copy = Data2VecMultiModel(
|
||||
self.cfg, self.modalities, skip_ema=True, task=self.task
|
||||
)
|
||||
|
||||
if self.cfg.ema_encoder_only:
|
||||
model_copy = model_copy.blocks
|
||||
for p_s, p_t in zip(self.blocks.parameters(), model_copy.parameters()):
|
||||
p_t.data.copy_(p_s.data)
|
||||
else:
|
||||
for p_s, p_t in zip(self.parameters(), model_copy.parameters()):
|
||||
p_t.data.copy_(p_s.data)
|
||||
|
||||
for mod_enc in model_copy.modality_encoders.values():
|
||||
mod_enc.decoder = None
|
||||
if not mod_enc.modality_cfg.ema_local_encoder:
|
||||
mod_enc.local_encoder = None
|
||||
mod_enc.project_features = None
|
||||
|
||||
model_copy.requires_grad_(False)
|
||||
return model_copy
|
||||
|
||||
def set_num_updates(self, num_updates):
|
||||
super().set_num_updates(num_updates)
|
||||
|
||||
if self.ema is not None and (
|
||||
(self.num_updates == 0 and num_updates > 1)
|
||||
or self.num_updates >= num_updates
|
||||
):
|
||||
pass
|
||||
elif self.training and self.ema is not None:
|
||||
ema_weight_decay = None
|
||||
if self.cfg.ema_decay != self.cfg.ema_end_decay:
|
||||
if num_updates >= self.cfg.ema_anneal_end_step:
|
||||
decay = self.cfg.ema_end_decay
|
||||
else:
|
||||
decay = get_annealed_rate(
|
||||
self.cfg.ema_decay,
|
||||
self.cfg.ema_end_decay,
|
||||
num_updates,
|
||||
self.cfg.ema_anneal_end_step,
|
||||
)
|
||||
self.ema.set_decay(decay, weight_decay=ema_weight_decay)
|
||||
if self.ema.get_decay() < 1:
|
||||
self.ema.step(self.blocks if self.cfg.ema_encoder_only else self)
|
||||
|
||||
self.num_updates = num_updates
|
||||
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
state = super().state_dict(destination, prefix, keep_vars)
|
||||
|
||||
if self.ema is not None:
|
||||
state[prefix + "_ema"] = self.ema.fp32_params
|
||||
|
||||
return state
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
k = prefix + "_ema"
|
||||
if self.ema is not None:
|
||||
assert k in state_dict
|
||||
self.ema.restore(state_dict[k], True)
|
||||
del state_dict[k]
|
||||
elif k in state_dict:
|
||||
del state_dict[k]
|
||||
|
||||
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, cfg: Data2VecMultiConfig, task=None):
|
||||
"""Build a new model instance."""
|
||||
if task is None or not hasattr(task, "supported_modalities"):
|
||||
modalities = (
|
||||
[cfg.supported_modality]
|
||||
if cfg.supported_modality is not None
|
||||
else [
|
||||
Modality.AUDIO,
|
||||
Modality.IMAGE,
|
||||
Modality.TEXT,
|
||||
]
|
||||
)
|
||||
else:
|
||||
modalities = task.supported_modalities
|
||||
return cls(cfg, modalities, task=task, skip_ema=cfg.skip_ema)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
source,
|
||||
target=None,
|
||||
id=None,
|
||||
mode=None,
|
||||
padding_mask=None,
|
||||
mask=True,
|
||||
features_only=False,
|
||||
force_remove_masked=False,
|
||||
remove_extra_tokens=True,
|
||||
precomputed_mask=None,
|
||||
):
|
||||
if mode is None:
|
||||
assert self.cfg.supported_modality is not None
|
||||
mode = self.cfg.supported_modality
|
||||
|
||||
if isinstance(mode, Modality):
|
||||
mode = mode.name
|
||||
|
||||
feature_extractor = self.modality_encoders[mode]
|
||||
|
||||
mask_seeds = None
|
||||
if id is not None:
|
||||
mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)
|
||||
|
||||
extractor_out = feature_extractor(
|
||||
source,
|
||||
padding_mask,
|
||||
mask,
|
||||
remove_masked=not features_only or force_remove_masked,
|
||||
clone_batch=self.cfg.clone_batch if not features_only else 1,
|
||||
mask_seeds=mask_seeds,
|
||||
precomputed_mask=precomputed_mask,
|
||||
)
|
||||
|
||||
x = extractor_out["x"]
|
||||
encoder_mask = extractor_out["encoder_mask"]
|
||||
masked_padding_mask = extractor_out["padding_mask"]
|
||||
masked_alibi_bias = extractor_out.get("alibi_bias", None)
|
||||
alibi_scale = extractor_out.get("alibi_scale", None)
|
||||
|
||||
if self.dropout_input is not None:
|
||||
x = self.dropout_input(x)
|
||||
|
||||
layer_results = []
|
||||
for i, blk in enumerate(self.blocks):
|
||||
if (
|
||||
not self.training
|
||||
or self.cfg.layerdrop == 0
|
||||
or (np.random.random() > self.cfg.layerdrop)
|
||||
):
|
||||
ab = masked_alibi_bias
|
||||
if ab is not None and alibi_scale is not None:
|
||||
scale = (
|
||||
alibi_scale[i]
|
||||
if alibi_scale.size(0) > 1
|
||||
else alibi_scale.squeeze(0)
|
||||
)
|
||||
ab = ab * scale.type_as(ab)
|
||||
|
||||
x, lr = blk(
|
||||
x,
|
||||
padding_mask=masked_padding_mask,
|
||||
alibi_bias=ab,
|
||||
)
|
||||
if features_only:
|
||||
layer_results.append(lr)
|
||||
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
if features_only:
|
||||
if remove_extra_tokens:
|
||||
x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
|
||||
if masked_padding_mask is not None:
|
||||
masked_padding_mask = masked_padding_mask[
|
||||
:, feature_extractor.modality_cfg.num_extra_tokens :
|
||||
]
|
||||
|
||||
return {
|
||||
"x": x,
|
||||
"padding_mask": masked_padding_mask,
|
||||
"layer_results": layer_results,
|
||||
"mask": encoder_mask,
|
||||
}
|
||||
|
||||
xs = []
|
||||
|
||||
if self.shared_decoder is not None:
|
||||
dx = self.forward_decoder(
|
||||
x,
|
||||
feature_extractor,
|
||||
self.shared_decoder,
|
||||
encoder_mask,
|
||||
)
|
||||
xs.append(dx)
|
||||
if feature_extractor.decoder is not None:
|
||||
dx = self.forward_decoder(
|
||||
x,
|
||||
feature_extractor,
|
||||
feature_extractor.decoder,
|
||||
encoder_mask,
|
||||
)
|
||||
xs.append(dx)
|
||||
orig_x = x
|
||||
|
||||
assert len(xs) > 0
|
||||
|
||||
p = next(self.ema.model.parameters())
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
ema_device = p.device
|
||||
ema_dtype = p.dtype
|
||||
|
||||
if not self.cfg.ema_same_dtype:
|
||||
dtype = ema_dtype
|
||||
|
||||
if ema_device != device or ema_dtype != dtype:
|
||||
logger.info(f"adjusting ema dtype to {dtype} and device to {device}")
|
||||
self.ema.model = self.ema.model.to(dtype=dtype, device=device)
|
||||
ema_dtype = dtype
|
||||
|
||||
def to_device(d):
|
||||
for k, p in d.items():
|
||||
if isinstance(d[k], dict):
|
||||
to_device(d[k])
|
||||
else:
|
||||
d[k] = p.to(device=device)
|
||||
|
||||
to_device(self.ema.fp32_params)
|
||||
tm = self.ema.model
|
||||
|
||||
with torch.no_grad():
|
||||
tm.eval()
|
||||
|
||||
if self.cfg.ema_encoder_only:
|
||||
assert target is None
|
||||
ema_input = extractor_out["local_features"]
|
||||
ema_input = feature_extractor.contextualized_features(
|
||||
ema_input.to(dtype=ema_dtype),
|
||||
padding_mask,
|
||||
mask=False,
|
||||
remove_masked=False,
|
||||
)
|
||||
ema_blocks = tm
|
||||
else:
|
||||
ema_blocks = tm.blocks
|
||||
if feature_extractor.modality_cfg.ema_local_encoder:
|
||||
inp = (
|
||||
target.to(dtype=ema_dtype)
|
||||
if target is not None
|
||||
else source.to(dtype=ema_dtype)
|
||||
)
|
||||
ema_input = tm.modality_encoders[mode](
|
||||
inp,
|
||||
padding_mask,
|
||||
mask=False,
|
||||
remove_masked=False,
|
||||
)
|
||||
else:
|
||||
assert target is None
|
||||
ema_input = extractor_out["local_features"]
|
||||
ema_feature_enc = tm.modality_encoders[mode]
|
||||
ema_input = ema_feature_enc.contextualized_features(
|
||||
ema_input.to(dtype=ema_dtype),
|
||||
padding_mask,
|
||||
mask=False,
|
||||
remove_masked=False,
|
||||
)
|
||||
|
||||
ema_padding_mask = ema_input["padding_mask"]
|
||||
ema_alibi_bias = ema_input.get("alibi_bias", None)
|
||||
ema_alibi_scale = ema_input.get("alibi_scale", None)
|
||||
ema_input = ema_input["x"]
|
||||
|
||||
y = []
|
||||
ema_x = []
|
||||
extra_tokens = feature_extractor.modality_cfg.num_extra_tokens
|
||||
for i, blk in enumerate(ema_blocks):
|
||||
ab = ema_alibi_bias
|
||||
if ab is not None and alibi_scale is not None:
|
||||
scale = (
|
||||
ema_alibi_scale[i]
|
||||
if ema_alibi_scale.size(0) > 1
|
||||
else ema_alibi_scale.squeeze(0)
|
||||
)
|
||||
ab = ab * scale.type_as(ab)
|
||||
|
||||
ema_input, lr = blk(
|
||||
ema_input,
|
||||
padding_mask=ema_padding_mask,
|
||||
alibi_bias=ab,
|
||||
)
|
||||
y.append(lr[:, extra_tokens:])
|
||||
ema_x.append(ema_input[:, extra_tokens:])
|
||||
|
||||
y = self.make_targets(y, self.average_top_k_layers)
|
||||
orig_targets = y
|
||||
|
||||
if self.cfg.clone_batch > 1:
|
||||
y = y.repeat_interleave(self.cfg.clone_batch, 0)
|
||||
|
||||
masked = encoder_mask.mask.unsqueeze(-1)
|
||||
masked_b = encoder_mask.mask.bool()
|
||||
y = y[masked_b]
|
||||
|
||||
if xs[0].size(1) == masked_b.size(1):
|
||||
xs = [x[masked_b] for x in xs]
|
||||
else:
|
||||
xs = [x.reshape(-1, x.size(-1)) for x in xs]
|
||||
|
||||
sample_size = masked.sum().long()
|
||||
|
||||
result = {
|
||||
"losses": {},
|
||||
"sample_size": sample_size,
|
||||
}
|
||||
|
||||
sample_size = result["sample_size"]
|
||||
|
||||
if self.cfg.cls_loss > 0:
|
||||
assert extra_tokens > 0
|
||||
cls_target = orig_targets.mean(dim=1)
|
||||
if self.cfg.clone_batch > 1:
|
||||
cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0)
|
||||
cls_pred = x[:, extra_tokens - 1]
|
||||
result["losses"]["cls"] = self.d2v_loss(cls_pred, cls_target) * (
|
||||
self.cfg.cls_loss * sample_size
|
||||
)
|
||||
|
||||
if self.cfg.recon_loss > 0:
|
||||
|
||||
with torch.no_grad():
|
||||
target = feature_extractor.patchify(source)
|
||||
mean = target.mean(dim=-1, keepdim=True)
|
||||
var = target.var(dim=-1, keepdim=True)
|
||||
target = (target - mean) / (var + 1.0e-6) ** 0.5
|
||||
|
||||
if self.cfg.clone_batch > 1:
|
||||
target = target.repeat_interleave(self.cfg.clone_batch, 0)
|
||||
|
||||
if masked_b is not None:
|
||||
target = target[masked_b]
|
||||
|
||||
recon = xs[0]
|
||||
if self.recon_proj is not None:
|
||||
recon = self.recon_proj(recon)
|
||||
|
||||
result["losses"]["recon"] = (
|
||||
self.d2v_loss(recon, target.float()) * self.cfg.recon_loss
|
||||
)
|
||||
|
||||
if self.cfg.d2v_loss > 0:
|
||||
for i, x in enumerate(xs):
|
||||
reg_loss = self.d2v_loss(x, y)
|
||||
n = f"{mode}_regression_{i}" if len(xs) > 1 else f"{mode}_regression"
|
||||
result["losses"][n] = reg_loss * self.cfg.d2v_loss
|
||||
|
||||
suffix = "" if len(self.modalities) == 1 else f"_{mode}"
|
||||
with torch.no_grad():
|
||||
if encoder_mask is not None:
|
||||
result["masked_pct"] = 1 - (
|
||||
encoder_mask.ids_keep.size(1) / encoder_mask.ids_restore.size(1)
|
||||
)
|
||||
for i, x in enumerate(xs):
|
||||
n = f"pred_var{suffix}_{i}" if len(xs) > 1 else f"pred_var{suffix}"
|
||||
result[n] = self.compute_var(x.float())
|
||||
if self.ema is not None:
|
||||
for k, v in self.ema.logs.items():
|
||||
result[k] = v
|
||||
|
||||
y = y.float()
|
||||
result[f"target_var{suffix}"] = self.compute_var(y)
|
||||
|
||||
if self.num_updates > 5000:
|
||||
if result[f"target_var{suffix}"] < self.cfg.min_target_var:
|
||||
logger.error(
|
||||
f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
|
||||
)
|
||||
raise Exception(
|
||||
f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
|
||||
)
|
||||
|
||||
for k in result.keys():
|
||||
if k.startswith("pred_var") and result[k] < self.cfg.min_pred_var:
|
||||
logger.error(
|
||||
f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
|
||||
)
|
||||
raise Exception(
|
||||
f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
|
||||
)
|
||||
|
||||
result["ema_decay"] = self.ema.get_decay() * 1000
|
||||
|
||||
return result
|
||||
|
||||
def forward_decoder(
|
||||
self,
|
||||
x,
|
||||
feature_extractor,
|
||||
decoder,
|
||||
mask_info,
|
||||
):
|
||||
x = feature_extractor.decoder_input(x, mask_info)
|
||||
x = decoder(*x)
|
||||
|
||||
return x
|
||||
|
||||
def d2v_loss(self, x, y):
|
||||
x = x.view(-1, x.size(-1)).float()
|
||||
y = y.view(-1, x.size(-1))
|
||||
|
||||
if self.loss_beta == 0:
|
||||
loss = F.mse_loss(x, y, reduction="none")
|
||||
else:
|
||||
loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta)
|
||||
|
||||
if self.loss_scale is not None:
|
||||
scale = self.loss_scale
|
||||
else:
|
||||
scale = 1 / math.sqrt(x.size(-1))
|
||||
|
||||
reg_loss = loss * scale
|
||||
|
||||
return reg_loss
|
||||
|
||||
def make_targets(self, y, num_layers):
|
||||
|
||||
with torch.no_grad():
|
||||
target_layer_results = y[-num_layers:]
|
||||
|
||||
permuted = False
|
||||
if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
|
||||
target_layer_results = [
|
||||
tl.transpose(1, 2) for tl in target_layer_results # BTC -> BCT
|
||||
]
|
||||
permuted = True
|
||||
if self.cfg.batch_norm_target_layer:
|
||||
target_layer_results = [
|
||||
F.batch_norm(
|
||||
tl.float(), running_mean=None, running_var=None, training=True
|
||||
)
|
||||
for tl in target_layer_results
|
||||
]
|
||||
if self.cfg.instance_norm_target_layer:
|
||||
target_layer_results = [
|
||||
F.instance_norm(tl.float()) for tl in target_layer_results
|
||||
]
|
||||
if permuted:
|
||||
target_layer_results = [
|
||||
tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
|
||||
]
|
||||
if self.cfg.layer_norm_target_layer:
|
||||
target_layer_results = [
|
||||
F.layer_norm(tl.float(), tl.shape[-1:])
|
||||
for tl in target_layer_results
|
||||
]
|
||||
|
||||
y = target_layer_results[0].float()
|
||||
for tl in target_layer_results[1:]:
|
||||
y.add_(tl.float())
|
||||
y = y.div_(len(target_layer_results))
|
||||
|
||||
if self.cfg.layer_norm_targets:
|
||||
y = F.layer_norm(y, y.shape[-1:])
|
||||
|
||||
if self.cfg.instance_norm_targets:
|
||||
y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def compute_var(y):
|
||||
y = y.view(-1, y.size(-1))
|
||||
if dist.is_initialized():
|
||||
zc = torch.tensor(y.size(0)).cuda()
|
||||
zs = y.sum(dim=0)
|
||||
zss = (y**2).sum(dim=0)
|
||||
|
||||
dist.all_reduce(zc)
|
||||
dist.all_reduce(zs)
|
||||
dist.all_reduce(zss)
|
||||
|
||||
var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
|
||||
return torch.sqrt(var + 1e-6).mean()
|
||||
else:
|
||||
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
|
||||
|
||||
def extract_features(
|
||||
self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
|
||||
):
|
||||
res = self.forward(
|
||||
source,
|
||||
mode=mode,
|
||||
padding_mask=padding_mask,
|
||||
mask=mask,
|
||||
features_only=True,
|
||||
remove_extra_tokens=remove_extra_tokens,
|
||||
)
|
||||
return res
|
||||
|
||||
def remove_pretraining_modules(self, modality=None, keep_decoder=False):
|
||||
self.ema = None
|
||||
self.cfg.clone_batch = 1
|
||||
self.recon_proj = None
|
||||
|
||||
if not keep_decoder:
|
||||
self.shared_decoder = None
|
||||
|
||||
modality = modality.lower() if modality is not None else None
|
||||
for k in list(self.modality_encoders.keys()):
|
||||
if modality is not None and k.lower() != modality:
|
||||
del self.modality_encoders[k]
|
||||
else:
|
||||
self.modality_encoders[k].remove_pretraining_modules(
|
||||
keep_decoder=keep_decoder
|
||||
)
|
||||
if not keep_decoder:
|
||||
self.modality_encoders[k].decoder = None
|
143
examples/data2vec/models/data2vec_image_classification.py
Normal file
143
examples/data2vec/models/data2vec_image_classification.py
Normal file
@ -0,0 +1,143 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# The code in this file is adapted from the BeiT implementation which can be found here:
|
||||
# https://github.com/microsoft/unilm/tree/master/beit
|
||||
|
||||
import logging
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from omegaconf import II, MISSING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fairseq import checkpoint_utils, tasks
|
||||
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.models import BaseFairseqModel, register_model
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Data2VecImageClassificationConfig(FairseqDataclass):
|
||||
model_path: str = MISSING
|
||||
no_pretrained_weights: bool = False
|
||||
num_classes: int = 1000
|
||||
mixup: float = 0.8
|
||||
cutmix: float = 1.0
|
||||
label_smoothing: float = 0.1
|
||||
|
||||
pretrained_model_args: Any = None
|
||||
data: str = II("task.data")
|
||||
|
||||
|
||||
@register_model(
|
||||
"data2vec_image_classification", dataclass=Data2VecImageClassificationConfig
|
||||
)
|
||||
class Data2VecImageClassificationModel(BaseFairseqModel):
|
||||
def __init__(self, cfg: Data2VecImageClassificationConfig):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
if cfg.pretrained_model_args is None:
|
||||
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
|
||||
pretrained_args = state.get("cfg", None)
|
||||
pretrained_args.criterion = None
|
||||
pretrained_args.lr_scheduler = None
|
||||
cfg.pretrained_model_args = pretrained_args
|
||||
|
||||
logger.info(pretrained_args)
|
||||
else:
|
||||
state = None
|
||||
pretrained_args = cfg.pretrained_model_args
|
||||
|
||||
pretrained_args.task.data = cfg.data
|
||||
task = tasks.setup_task(pretrained_args.task)
|
||||
model = task.build_model(pretrained_args.model, from_checkpoint=True)
|
||||
|
||||
model.remove_pretraining_modules()
|
||||
|
||||
self.model = model
|
||||
|
||||
if state is not None and not cfg.no_pretrained_weights:
|
||||
self.load_model_weights(state, model, cfg)
|
||||
|
||||
self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim)
|
||||
self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)
|
||||
|
||||
self.head.weight.data.mul_(1e-3)
|
||||
self.head.bias.data.mul_(1e-3)
|
||||
|
||||
self.mixup_fn = None
|
||||
|
||||
if cfg.mixup > 0 or cfg.cutmix > 0:
|
||||
from timm.data import Mixup
|
||||
|
||||
self.mixup_fn = Mixup(
|
||||
mixup_alpha=cfg.mixup,
|
||||
cutmix_alpha=cfg.cutmix,
|
||||
cutmix_minmax=None,
|
||||
prob=1.0,
|
||||
switch_prob=0.5,
|
||||
mode="batch",
|
||||
label_smoothing=cfg.label_smoothing,
|
||||
num_classes=cfg.num_classes,
|
||||
)
|
||||
|
||||
def load_model_weights(self, state, model, cfg):
|
||||
if "_ema" in state["model"]:
|
||||
del state["model"]["_ema"]
|
||||
model.load_state_dict(state["model"], strict=True)
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, cfg: Data2VecImageClassificationConfig, task=None):
|
||||
"""Build a new model instance."""
|
||||
|
||||
return cls(cfg)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img,
|
||||
label=None,
|
||||
):
|
||||
if self.training and self.mixup_fn is not None and label is not None:
|
||||
img, label = self.mixup_fn(img, label)
|
||||
|
||||
x = self.model(img, mask=False)
|
||||
x = x[:, 1:]
|
||||
x = self.fc_norm(x.mean(1))
|
||||
x = self.head(x)
|
||||
|
||||
if label is None:
|
||||
return x
|
||||
|
||||
if self.training and self.mixup_fn is not None:
|
||||
loss = -label * F.log_softmax(x.float(), dim=-1)
|
||||
else:
|
||||
loss = F.cross_entropy(
|
||||
x.float(),
|
||||
label,
|
||||
label_smoothing=self.cfg.label_smoothing if self.training else 0,
|
||||
reduction="none",
|
||||
)
|
||||
|
||||
result = {
|
||||
"losses": {"regression": loss},
|
||||
"sample_size": img.size(0),
|
||||
}
|
||||
|
||||
if not self.training:
|
||||
with torch.no_grad():
|
||||
pred = x.argmax(-1)
|
||||
correct = (pred == label).sum()
|
||||
result["correct"] = correct
|
||||
|
||||
return result
|
141
examples/data2vec/models/data2vec_text_classification.py
Normal file
141
examples/data2vec/models/data2vec_text_classification.py
Normal file
@ -0,0 +1,141 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# The code in this file is adapted from the BeiT implementation which can be found here:
|
||||
# https://github.com/microsoft/unilm/tree/master/beit
|
||||
|
||||
import logging
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from omegaconf import II, MISSING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fairseq import checkpoint_utils, tasks
|
||||
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.models import BaseFairseqModel, register_model
|
||||
from fairseq.models.roberta.model import RobertaClassificationHead
|
||||
|
||||
from examples.data2vec.data.modality import Modality
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Data2VecTextClassificationConfig(FairseqDataclass):
|
||||
pooler_dropout: float = 0.0
|
||||
pooler_activation_fn: str = "tanh"
|
||||
quant_noise_pq: int = 0
|
||||
quant_noise_pq_block_size: int = 8
|
||||
spectral_norm_classification_head: bool = False
|
||||
|
||||
model_path: str = MISSING
|
||||
no_pretrained_weights: bool = False
|
||||
|
||||
pretrained_model_args: Any = None
|
||||
|
||||
|
||||
@register_model(
|
||||
"data2vec_text_classification", dataclass=Data2VecTextClassificationConfig
|
||||
)
|
||||
class Data2VecTextClassificationModel(BaseFairseqModel):
|
||||
def __init__(self, cfg: Data2VecTextClassificationConfig):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
if cfg.pretrained_model_args is None:
|
||||
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
|
||||
pretrained_args = state.get("cfg", None)
|
||||
pretrained_args.criterion = None
|
||||
pretrained_args.lr_scheduler = None
|
||||
cfg.pretrained_model_args = pretrained_args
|
||||
|
||||
logger.info(pretrained_args)
|
||||
else:
|
||||
state = None
|
||||
pretrained_args = cfg.pretrained_model_args
|
||||
|
||||
task = tasks.setup_task(pretrained_args.task)
|
||||
model = task.build_model(pretrained_args.model, from_checkpoint=True)
|
||||
|
||||
model.remove_pretraining_modules()
|
||||
|
||||
self.model = model
|
||||
|
||||
if state is not None and not cfg.no_pretrained_weights:
|
||||
self.load_model_weights(state, model, cfg)
|
||||
|
||||
self.classification_heads = nn.ModuleDict()
|
||||
|
||||
|
||||
def load_model_weights(self, state, model, cfg):
|
||||
for k in list(state["model"].keys()):
|
||||
if (
|
||||
k.startswith("shared_decoder") or
|
||||
k.startswith("_ema") or
|
||||
"decoder" in k
|
||||
):
|
||||
logger.info(f"Deleting {k} from checkpoint")
|
||||
del state["model"][k]
|
||||
model.load_state_dict(state["model"], strict=True)
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, cfg: Data2VecTextClassificationConfig, task=None):
|
||||
"""Build a new model instance."""
|
||||
|
||||
return cls(cfg)
|
||||
|
||||
def register_classification_head(
|
||||
self, name, num_classes=None, inner_dim=None, **kwargs
|
||||
):
|
||||
"""Register a classification head."""
|
||||
if name in self.classification_heads:
|
||||
prev_num_classes = self.classification_heads[name].out_proj.out_features
|
||||
prev_inner_dim = self.classification_heads[name].dense.out_features
|
||||
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
|
||||
logger.warning(
|
||||
're-registering head "{}" with num_classes {} (prev: {}) '
|
||||
"and inner_dim {} (prev: {})".format(
|
||||
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
|
||||
)
|
||||
)
|
||||
embed_dim = self.cfg.pretrained_model_args.model.embed_dim
|
||||
self.classification_heads[name] = RobertaClassificationHead(
|
||||
input_dim=embed_dim,
|
||||
inner_dim=inner_dim or embed_dim,
|
||||
num_classes=num_classes,
|
||||
activation_fn=self.cfg.pooler_activation_fn,
|
||||
pooler_dropout=self.cfg.pooler_dropout,
|
||||
q_noise=self.cfg.quant_noise_pq,
|
||||
qn_block_size=self.cfg.quant_noise_pq_block_size,
|
||||
do_spectral_norm=self.cfg.spectral_norm_classification_head,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
source,
|
||||
id,
|
||||
padding_mask,
|
||||
features_only=True,
|
||||
remove_extra_tokens=True,
|
||||
classification_head_name=None,
|
||||
):
|
||||
encoder_out = self.model(
|
||||
source,
|
||||
id=id,
|
||||
mode=Modality.TEXT,
|
||||
padding_mask=padding_mask,
|
||||
mask=False,
|
||||
features_only=features_only,
|
||||
remove_extra_tokens=remove_extra_tokens
|
||||
)
|
||||
logits = self.classification_heads[classification_head_name](encoder_out["x"])
|
||||
return logits, encoder_out
|
727
examples/data2vec/models/data2vec_vision.py
Normal file
727
examples/data2vec/models/data2vec_vision.py
Normal file
@ -0,0 +1,727 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# The code in this file is adapted from the BeiT implementation which can be found here:
|
||||
# https://github.com/microsoft/unilm/tree/master/beit
|
||||
|
||||
import logging
|
||||
import math
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from omegaconf import II
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
from fairseq.modules import EMAModule, EMAModuleConfig
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.models import BaseFairseqModel, register_model
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Data2VecVisionConfig(FairseqDataclass):
|
||||
layer_scale_init_value: float = field(
|
||||
default=1e-4, metadata={"help": "rescale layer outputs, 0 to disable"}
|
||||
)
|
||||
num_mask_patches: int = field(
|
||||
default=75,
|
||||
metadata={"help": "number of the visual tokens/patches need be masked"},
|
||||
)
|
||||
min_mask_patches_per_block: int = 16
|
||||
max_mask_patches_per_block: int = 196
|
||||
image_size: int = 224
|
||||
patch_size: int = 16
|
||||
in_channels: int = 3
|
||||
|
||||
shared_rel_pos_bias: bool = True
|
||||
|
||||
drop_path: float = 0.1
|
||||
attention_dropout: float = 0.0
|
||||
|
||||
depth: int = 12
|
||||
embed_dim: int = 768
|
||||
num_heads: int = 12
|
||||
mlp_ratio: int = 4
|
||||
|
||||
loss_beta: float = field(
|
||||
default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
|
||||
)
|
||||
loss_scale: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
|
||||
},
|
||||
)
|
||||
average_top_k_layers: int = field(
|
||||
default=8, metadata={"help": "how many layers to average"}
|
||||
)
|
||||
|
||||
end_of_block_targets: bool = True
|
||||
layer_norm_target_layer: bool = False
|
||||
instance_norm_target_layer: bool = False
|
||||
batch_norm_target_layer: bool = False
|
||||
instance_norm_targets: bool = False
|
||||
layer_norm_targets: bool = False
|
||||
|
||||
ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
|
||||
ema_end_decay: float = field(
|
||||
default=0.9999, metadata={"help": "final ema decay rate"}
|
||||
)
|
||||
|
||||
# when to finish annealing ema decay rate
|
||||
ema_anneal_end_step: int = II("optimization.max_update")
|
||||
|
||||
ema_transformer_only: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "whether to momentum update only the transformer layers"},
|
||||
)
|
||||
|
||||
|
||||
def get_annealed_rate(start, end, curr_step, total_steps):
|
||||
r = end - start
|
||||
pct_remaining = 1 - curr_step / total_steps
|
||||
return end - r * pct_remaining
|
||||
|
||||
|
||||
@register_model("data2vec_vision", dataclass=Data2VecVisionConfig)
|
||||
class Data2VecVisionModel(BaseFairseqModel):
|
||||
def __init__(self, cfg: Data2VecVisionConfig):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
self.ema = None
|
||||
|
||||
self.average_top_k_layers = cfg.average_top_k_layers
|
||||
self.loss_beta = cfg.loss_beta
|
||||
self.loss_scale = (
|
||||
cfg.loss_scale
|
||||
if cfg.loss_scale is not None
|
||||
else 1 / math.sqrt(cfg.embed_dim)
|
||||
)
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=cfg.image_size,
|
||||
patch_size=cfg.patch_size,
|
||||
in_chans=cfg.in_channels,
|
||||
embed_dim=cfg.embed_dim,
|
||||
)
|
||||
|
||||
patch_size = self.patch_embed.patch_size
|
||||
self.window_size = (
|
||||
cfg.image_size // patch_size[0],
|
||||
cfg.image_size // patch_size[1],
|
||||
)
|
||||
|
||||
self.cls_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim))
|
||||
self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim))
|
||||
|
||||
nn.init.trunc_normal_(self.cls_emb, 0.02)
|
||||
nn.init.trunc_normal_(self.mask_emb, 0.02)
|
||||
|
||||
self.encoder = TransformerEncoder(cfg, self.patch_embed.patch_shape)
|
||||
|
||||
self.final_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
|
||||
self.num_updates = 0
|
||||
|
||||
def make_ema_teacher(self):
|
||||
ema_config = EMAModuleConfig(
|
||||
ema_decay=self.cfg.ema_decay,
|
||||
ema_fp32=True,
|
||||
)
|
||||
self.ema = EMAModule(
|
||||
self.encoder if self.cfg.ema_transformer_only else self,
|
||||
ema_config,
|
||||
)
|
||||
|
||||
def set_num_updates(self, num_updates):
|
||||
super().set_num_updates(num_updates)
|
||||
|
||||
if self.ema is None and self.final_proj is not None:
|
||||
logger.info(f"making ema teacher")
|
||||
self.make_ema_teacher()
|
||||
elif self.training and self.ema is not None:
|
||||
if self.cfg.ema_decay != self.cfg.ema_end_decay:
|
||||
if num_updates >= self.cfg.ema_anneal_end_step:
|
||||
decay = self.cfg.ema_end_decay
|
||||
else:
|
||||
decay = get_annealed_rate(
|
||||
self.cfg.ema_decay,
|
||||
self.cfg.ema_end_decay,
|
||||
num_updates,
|
||||
self.cfg.ema_anneal_end_step,
|
||||
)
|
||||
self.ema.set_decay(decay)
|
||||
if self.ema.get_decay() < 1:
|
||||
self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
|
||||
|
||||
self.num_updates = num_updates
|
||||
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
state = super().state_dict(destination, prefix, keep_vars)
|
||||
|
||||
if self.ema is not None:
|
||||
state[prefix + "_ema"] = self.ema.fp32_params
|
||||
|
||||
return state
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
if self.ema is not None:
|
||||
k = prefix + "_ema"
|
||||
assert k in state_dict
|
||||
self.ema.restore(state_dict[k], True)
|
||||
del state_dict[k]
|
||||
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, cfg: Data2VecVisionConfig, task=None):
|
||||
"""Build a new model instance."""
|
||||
|
||||
return cls(cfg)
|
||||
|
||||
def make_mask(self, bsz, num_masks, min_masks, max_masks):
|
||||
height, width = self.window_size
|
||||
|
||||
masks = np.zeros(shape=(bsz, height, width), dtype=np.int)
|
||||
|
||||
for i in range(bsz):
|
||||
mask = masks[i]
|
||||
mask_count = 0
|
||||
|
||||
min_aspect = 0.3
|
||||
max_aspect = 1 / min_aspect
|
||||
log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
|
||||
|
||||
def _mask(mask, max_mask_patches):
|
||||
delta = 0
|
||||
for attempt in range(10):
|
||||
target_area = random.uniform(min_masks, max_mask_patches)
|
||||
aspect_ratio = math.exp(random.uniform(*log_aspect_ratio))
|
||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
if w < width and h < height:
|
||||
top = random.randint(0, height - h)
|
||||
left = random.randint(0, width - w)
|
||||
|
||||
num_masked = mask[top : top + h, left : left + w].sum()
|
||||
# Overlap
|
||||
if 0 < h * w - num_masked <= max_mask_patches:
|
||||
for i in range(top, top + h):
|
||||
for j in range(left, left + w):
|
||||
if mask[i, j] == 0:
|
||||
mask[i, j] = 1
|
||||
delta += 1
|
||||
|
||||
if delta > 0:
|
||||
break
|
||||
return delta
|
||||
|
||||
while mask_count < num_masks:
|
||||
max_mask_patches = min(num_masks - mask_count, max_masks)
|
||||
|
||||
delta = _mask(mask, max_mask_patches)
|
||||
if delta == 0:
|
||||
break
|
||||
else:
|
||||
mask_count += delta
|
||||
|
||||
return torch.from_numpy(masks)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img,
|
||||
mask: bool = True,
|
||||
layer_results: bool = False,
|
||||
):
|
||||
x = self.patch_embed(img)
|
||||
batch_size, seq_len, _ = x.size()
|
||||
|
||||
if mask:
|
||||
mask_indices = self.make_mask(
|
||||
img.size(0),
|
||||
self.cfg.num_mask_patches,
|
||||
self.cfg.min_mask_patches_per_block,
|
||||
self.cfg.max_mask_patches_per_block,
|
||||
)
|
||||
bool_mask = mask_indices.view(mask_indices.size(0), -1).bool()
|
||||
else:
|
||||
mask_indices = bool_mask = None
|
||||
|
||||
cls_tokens = self.cls_emb.expand(batch_size, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
if self.ema is not None:
|
||||
with torch.no_grad():
|
||||
self.ema.model.eval()
|
||||
|
||||
if self.cfg.ema_transformer_only:
|
||||
y = self.ema.model(
|
||||
x,
|
||||
layer_results="end" if self.cfg.end_of_block_targets else "fc",
|
||||
)
|
||||
else:
|
||||
y = self.ema.model(
|
||||
img,
|
||||
mask=False,
|
||||
layer_results=True,
|
||||
)
|
||||
|
||||
y = y[-self.cfg.average_top_k_layers :]
|
||||
|
||||
permuted = False
|
||||
if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
|
||||
y = [tl.transpose(1, 2) for tl in y] # BTC -> BCT
|
||||
permuted = True
|
||||
|
||||
if self.cfg.batch_norm_target_layer:
|
||||
y = [
|
||||
F.batch_norm(
|
||||
tl.float(), running_mean=None, running_var=None, training=True
|
||||
)
|
||||
for tl in y
|
||||
]
|
||||
|
||||
if self.cfg.instance_norm_target_layer:
|
||||
y = [F.instance_norm(tl.float()) for tl in y]
|
||||
|
||||
if permuted:
|
||||
y = [tl.transpose(1, 2) for tl in y] # BCT -> BTC
|
||||
|
||||
if self.cfg.layer_norm_target_layer:
|
||||
y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
|
||||
|
||||
y = sum(y) / len(y)
|
||||
|
||||
if self.cfg.layer_norm_targets:
|
||||
y = F.layer_norm(y.float(), y.shape[-1:])
|
||||
|
||||
if self.cfg.instance_norm_targets:
|
||||
y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
y = y[bool_mask].float()
|
||||
|
||||
if mask_indices is not None:
|
||||
mask_token = self.mask_emb.expand(batch_size, seq_len, -1)
|
||||
w = mask_indices.view(mask_indices.size(0), -1, 1).type_as(mask_token)
|
||||
x[:, 1:] = x[:, 1:] * (1 - w) + mask_token * w
|
||||
|
||||
if layer_results:
|
||||
enc_layer_results = "end" if self.cfg.end_of_block_targets else "fc"
|
||||
else:
|
||||
enc_layer_results = None
|
||||
|
||||
x = self.encoder(x, layer_results=enc_layer_results)
|
||||
if layer_results or mask_indices is None:
|
||||
return x
|
||||
|
||||
x = x[bool_mask].float()
|
||||
|
||||
if self.loss_beta == 0:
|
||||
loss = F.mse_loss(x, y, reduction="none").sum(dim=-1)
|
||||
else:
|
||||
loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta).sum(
|
||||
dim=-1
|
||||
)
|
||||
|
||||
if self.loss_scale > 0:
|
||||
loss = loss * self.loss_scale
|
||||
|
||||
result = {
|
||||
"losses": {"regression": loss.sum()},
|
||||
"sample_size": loss.numel(),
|
||||
"target_var": self.compute_var(y),
|
||||
"pred_var": self.compute_var(x),
|
||||
"ema_decay": self.ema.get_decay() * 1000,
|
||||
}
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def compute_var(y):
|
||||
y = y.view(-1, y.size(-1))
|
||||
if dist.is_initialized():
|
||||
zc = torch.tensor(y.size(0)).cuda()
|
||||
zs = y.sum(dim=0)
|
||||
zss = (y ** 2).sum(dim=0)
|
||||
|
||||
dist.all_reduce(zc)
|
||||
dist.all_reduce(zs)
|
||||
dist.all_reduce(zss)
|
||||
|
||||
var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
|
||||
return torch.sqrt(var + 1e-6).mean()
|
||||
else:
|
||||
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
|
||||
|
||||
def remove_pretraining_modules(self, last_layer=None):
|
||||
self.final_proj = None
|
||||
self.ema = None
|
||||
self.encoder.norm = nn.Identity()
|
||||
self.mask_emb = None
|
||||
if last_layer is not None:
|
||||
self.encoder.layers = nn.ModuleList(
|
||||
l for i, l in enumerate(self.encoder.layers) if i <= last_layer
|
||||
)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Image to Patch Embedding"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
if isinstance(img_size, int):
|
||||
img_size = img_size, img_size
|
||||
if isinstance(patch_size, int):
|
||||
patch_size = patch_size, patch_size
|
||||
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
||||
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_patches = num_patches
|
||||
|
||||
self.conv = nn.Conv2d(
|
||||
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# BCHW -> BTC
|
||||
x = self.conv(x).flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=True,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
window_size=None,
|
||||
attn_head_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
if attn_head_dim is not None:
|
||||
head_dim = attn_head_dim
|
||||
all_head_dim = head_dim * self.num_heads
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||
if qkv_bias:
|
||||
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
else:
|
||||
self.q_bias = None
|
||||
self.v_bias = None
|
||||
|
||||
if window_size:
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
||||
2 * window_size[1] - 1
|
||||
) + 3
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, num_heads)
|
||||
) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
) # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(
|
||||
1, 2, 0
|
||||
).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(window_size[0] * window_size[1] + 1,) * 2,
|
||||
dtype=relative_coords.dtype,
|
||||
)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
else:
|
||||
self.window_size = None
|
||||
self.relative_position_bias_table = None
|
||||
self.relative_position_index = None
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(all_head_dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
B, N, C = x.shape
|
||||
qkv_bias = None
|
||||
if self.q_bias is not None:
|
||||
qkv_bias = torch.cat(
|
||||
(
|
||||
self.q_bias,
|
||||
torch.zeros_like(self.v_bias, requires_grad=False),
|
||||
self.v_bias,
|
||||
)
|
||||
)
|
||||
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = (
|
||||
qkv[0],
|
||||
qkv[1],
|
||||
qkv[2],
|
||||
) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
|
||||
if self.relative_position_bias_table is not None:
|
||||
assert 1==2
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)
|
||||
].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
-1,
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1
|
||||
).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
print("attn.size() :", attn.size())
|
||||
print("rel_pos_bias.size() :", rel_pos_bias.size())
|
||||
if rel_pos_bias is not None:
|
||||
attn = attn + rel_pos_bias
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Module):
|
||||
def __init__(self, window_size, num_heads):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
||||
2 * window_size[1] - 1
|
||||
) + 3
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, num_heads)
|
||||
)
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
) # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(
|
||||
1, 2, 0
|
||||
).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
|
||||
)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
|
||||
def forward(self):
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)
|
||||
].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
-1,
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
print("self.window_size :", self.window_size)
|
||||
print("self.num_relative_distance :", self.num_relative_distance)
|
||||
print("self.relative_position_index :", self.relative_position_index.size(), self.relative_position_index)
|
||||
print("relative_position_bias.size(), relative_position_bias :",relative_position_bias.size(), relative_position_bias)
|
||||
print("self.relative_position_bias_table.size(), self.relative_position_bias_table :",self.relative_position_bias_table.size(), self.relative_position_bias_table)
|
||||
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
if self.drop_prob == 0.0 or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (
|
||||
x.ndim - 1
|
||||
) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_()
|
||||
output = x.div(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "p={}".format(self.drop_prob)
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.0,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
init_values=None,
|
||||
window_size=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
window_size=window_size,
|
||||
)
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(dim, mlp_hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(mlp_hidden_dim, dim),
|
||||
nn.Dropout(drop),
|
||||
)
|
||||
|
||||
if init_values > 0:
|
||||
self.gamma_1 = nn.Parameter(
|
||||
init_values * torch.ones((dim)), requires_grad=True
|
||||
)
|
||||
self.gamma_2 = nn.Parameter(
|
||||
init_values * torch.ones((dim)), requires_grad=True
|
||||
)
|
||||
else:
|
||||
self.gamma_1, self.gamma_2 = None, None
|
||||
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
print("inside block :", x.size())
|
||||
if self.gamma_1 is None:
|
||||
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
||||
fc_feature = self.drop_path(self.mlp(self.norm2(x)))
|
||||
x = x + fc_feature
|
||||
else:
|
||||
x = x + self.drop_path(
|
||||
self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
|
||||
)
|
||||
fc_feature = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
||||
x = x + fc_feature
|
||||
return x, fc_feature
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, cfg: Data2VecVisionConfig, patch_shape):
|
||||
super().__init__()
|
||||
|
||||
self.rel_pos_bias = None
|
||||
if cfg.shared_rel_pos_bias:
|
||||
self.rel_pos_bias = RelativePositionBias(
|
||||
window_size=patch_shape, num_heads=cfg.num_heads
|
||||
)
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, cfg.drop_path, cfg.depth)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
print("TransformerEncoder > patch_shape :", patch_shape)
|
||||
self.blocks = nn.ModuleList(
|
||||
Block(
|
||||
dim=cfg.embed_dim,
|
||||
num_heads=cfg.num_heads,
|
||||
attn_drop=cfg.attention_dropout,
|
||||
drop_path=dpr[i],
|
||||
init_values=cfg.layer_scale_init_value,
|
||||
window_size=patch_shape if not cfg.shared_rel_pos_bias else None,
|
||||
)
|
||||
for i in range(cfg.depth)
|
||||
)
|
||||
|
||||
self.norm = nn.LayerNorm(cfg.embed_dim)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.fix_init_weight()
|
||||
|
||||
def init_weights(self, m):
|
||||
std = 0.02
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.trunc_normal_(m.weight, std=std)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
nn.init.trunc_normal_(m.weight, std=std)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def fix_init_weight(self):
|
||||
def rescale(param, layer_id):
|
||||
param.div_(math.sqrt(2.0 * layer_id))
|
||||
|
||||
for layer_id, layer in enumerate(self.blocks):
|
||||
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
||||
rescale(layer.mlp[2].weight.data, layer_id + 1)
|
||||
|
||||
def extract_features(self, x, layer_results):
|
||||
|
||||
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
||||
|
||||
z = []
|
||||
for i, blk in enumerate(self.blocks):
|
||||
x, fc_feature = blk(x, rel_pos_bias=rel_pos_bias)
|
||||
if layer_results == "end":
|
||||
z.append(x)
|
||||
elif layer_results == "fc":
|
||||
z.append(fc_feature)
|
||||
|
||||
return z if layer_results else self.norm(x)
|
||||
|
||||
def forward(self, x, layer_results=None):
|
||||
x = self.extract_features(x, layer_results=layer_results)
|
||||
if layer_results:
|
||||
return [z[:, 1:] for z in x]
|
||||
|
||||
x = x[:, 1:]
|
||||
return x
|
825
examples/data2vec/models/mae.py
Normal file
825
examples/data2vec/models/mae.py
Normal file
@ -0,0 +1,825 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# The code in this file is adapted from the BeiT implementation which can be found here:
|
||||
# https://github.com/microsoft/unilm/tree/master/beit
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
|
||||
from timm.models.vision_transformer import PatchEmbed, Block
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.models import BaseFairseqModel, register_model
|
||||
from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer
|
||||
|
||||
from apex.normalization import FusedLayerNorm
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaeConfig(FairseqDataclass):
|
||||
input_size: int = 224
|
||||
in_chans: int = 3
|
||||
patch_size: int = 16
|
||||
embed_dim: int = 768
|
||||
depth: int = 12
|
||||
num_heads: int = 12
|
||||
decoder_embed_dim: int = 512
|
||||
decoder_depth: int = 8
|
||||
decoder_num_heads: int = 16
|
||||
mlp_ratio: int = 4
|
||||
norm_eps: float = 1e-6
|
||||
|
||||
drop_path_rate: float = 0.0
|
||||
|
||||
mask_ratio: float = 0.75
|
||||
norm_pix_loss: bool = True
|
||||
|
||||
w2v_block: bool = False
|
||||
alt_block: bool = False
|
||||
alt_block2: bool = False
|
||||
alt_attention: bool = False
|
||||
block_dropout: float = 0
|
||||
attention_dropout: float = 0
|
||||
activation_dropout: float = 0
|
||||
layer_norm_first: bool = False
|
||||
|
||||
fused_ln: bool = True
|
||||
end_of_block_targets: bool = True
|
||||
|
||||
no_decoder_embed: bool = False
|
||||
no_decoder_pos_embed: bool = False
|
||||
mask_noise_std: float = 0
|
||||
|
||||
single_qkv: bool = False
|
||||
use_rel_pos_bias: bool = False
|
||||
no_cls: bool = False
|
||||
|
||||
|
||||
def modify_relative_position_bias(orig_bias, bsz, mask):
|
||||
if mask is None:
|
||||
return orig_bias.unsqueeze(0).repeat(
|
||||
bsz, 1, 1, 1
|
||||
) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
|
||||
heads, max_seq_len, max_seq_len = orig_bias.shape # includes CLS token
|
||||
mask_for_rel_pos_bias = torch.cat(
|
||||
(torch.zeros(bsz, 1, dtype=mask.dtype, device=mask.device), mask), dim=1
|
||||
).bool() # bsz x seqlen (add CLS token)
|
||||
unmasked_for_rel_pos_bias = ~mask_for_rel_pos_bias
|
||||
unmasked_for_rel_pos_bias = unmasked_for_rel_pos_bias.unsqueeze(1).repeat(
|
||||
1, heads, 1
|
||||
) # bsz x seq_len => bsz x heads x seq_len
|
||||
b_t_t_rel_pos_bias = orig_bias.unsqueeze(0).repeat(
|
||||
bsz, 1, 1, 1
|
||||
) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
|
||||
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
|
||||
unmasked_for_rel_pos_bias.unsqueeze(-1)
|
||||
)
|
||||
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, -1, max_seq_len)
|
||||
new_len = b_t_t_rel_pos_bias.size(-2)
|
||||
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
|
||||
unmasked_for_rel_pos_bias.unsqueeze(-2)
|
||||
)
|
||||
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, new_len, new_len)
|
||||
return b_t_t_rel_pos_bias
|
||||
|
||||
|
||||
class AltBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
layer_norm_first=True,
|
||||
ffn_targets=False,
|
||||
use_rel_pos_bias=False,
|
||||
window_size=None,
|
||||
alt_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.layer_norm_first = layer_norm_first
|
||||
self.ffn_targets = ffn_targets
|
||||
|
||||
from timm.models.vision_transformer import Attention, DropPath, Mlp
|
||||
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.use_rel_pos_bias = use_rel_pos_bias
|
||||
if use_rel_pos_bias:
|
||||
self.attn = AltAttention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
window_size=window_size,
|
||||
)
|
||||
else:
|
||||
if alt_attention:
|
||||
from .multi.modules import AltAttention as AltAttention2
|
||||
self.attn = AltAttention2(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
else:
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
def forward(self, x, rel_pos_bias=None, pos_mask=None):
|
||||
if self.layer_norm_first:
|
||||
if self.use_rel_pos_bias:
|
||||
x = x + self.drop_path(
|
||||
self.attn(
|
||||
self.norm1(x), rel_pos_bias=rel_pos_bias, pos_mask=pos_mask
|
||||
)
|
||||
)
|
||||
else:
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
t = self.mlp(self.norm2(x))
|
||||
x = x + self.drop_path(t)
|
||||
if not self.ffn_targets:
|
||||
t = x
|
||||
return x, t
|
||||
else:
|
||||
if self.use_rel_pos_bias:
|
||||
x = x + self.drop_path(
|
||||
self.attn(x, rel_pos_bias=rel_pos_bias, pos_mask=pos_mask)
|
||||
)
|
||||
else:
|
||||
x = x + self.drop_path(self.attn(x))
|
||||
r = x = self.norm1(x)
|
||||
x = self.mlp(x)
|
||||
t = x
|
||||
x = self.norm2(r + self.drop_path(x))
|
||||
if not self.ffn_targets:
|
||||
t = x
|
||||
return x, t
|
||||
|
||||
|
||||
class AltAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
window_size=None,
|
||||
attn_head_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
if attn_head_dim is not None:
|
||||
head_dim = attn_head_dim
|
||||
all_head_dim = head_dim * self.num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||
if qkv_bias:
|
||||
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
else:
|
||||
self.q_bias = None
|
||||
self.v_bias = None
|
||||
|
||||
if window_size:
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
||||
2 * window_size[1] - 1
|
||||
) + 3
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, num_heads)
|
||||
) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
) # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(
|
||||
1, 2, 0
|
||||
).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(window_size[0] * window_size[1] + 1,) * 2,
|
||||
dtype=relative_coords.dtype,
|
||||
)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
else:
|
||||
self.window_size = None
|
||||
self.relative_position_bias_table = None
|
||||
self.relative_position_index = None
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(all_head_dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, rel_pos_bias=None, pos_mask=None):
|
||||
B, N, C = x.shape
|
||||
qkv_bias = None
|
||||
if self.q_bias is not None:
|
||||
qkv_bias = torch.cat(
|
||||
(
|
||||
self.q_bias,
|
||||
torch.zeros_like(self.v_bias, requires_grad=False),
|
||||
self.v_bias,
|
||||
)
|
||||
)
|
||||
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = (
|
||||
qkv[0],
|
||||
qkv[1],
|
||||
qkv[2],
|
||||
) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
|
||||
if self.relative_position_bias_table is not None:
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)
|
||||
].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
-1,
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1
|
||||
).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + modify_relative_position_bias(
|
||||
relative_position_bias, x.size(0), pos_mask
|
||||
)
|
||||
|
||||
if rel_pos_bias is not None:
|
||||
attn = attn + rel_pos_bias
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Module):
|
||||
def __init__(self, window_size, num_heads):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
||||
2 * window_size[1] - 1
|
||||
) + 3
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, num_heads)
|
||||
)
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
) # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(
|
||||
1, 2, 0
|
||||
).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
|
||||
)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
|
||||
def forward(self):
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)
|
||||
].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
-1,
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token:
|
||||
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000 ** omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
def interpolate_pos_embed(model, checkpoint_model):
|
||||
if "pos_embed" in checkpoint_model:
|
||||
pos_embed_checkpoint = checkpoint_model["pos_embed"]
|
||||
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||
num_patches = model.patch_embed.num_patches
|
||||
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
||||
# height (== width) for the checkpoint position embedding
|
||||
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
||||
# height (== width) for the new position embedding
|
||||
new_size = int(num_patches ** 0.5)
|
||||
# class_token and dist_token are kept unchanged
|
||||
if orig_size != new_size:
|
||||
print(
|
||||
"Position interpolate from %dx%d to %dx%d"
|
||||
% (orig_size, orig_size, new_size, new_size)
|
||||
)
|
||||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||
# only the position tokens are interpolated
|
||||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||
pos_tokens = pos_tokens.reshape(
|
||||
-1, orig_size, orig_size, embedding_size
|
||||
).permute(0, 3, 1, 2)
|
||||
pos_tokens = torch.nn.functional.interpolate(
|
||||
pos_tokens,
|
||||
size=(new_size, new_size),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||
checkpoint_model["pos_embed"] = new_pos_embed
|
||||
|
||||
|
||||
@register_model("mae", dataclass=MaeConfig)
|
||||
class MaeModel(BaseFairseqModel):
|
||||
def __init__(self, cfg: MaeConfig):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
self.mask_ratio = cfg.mask_ratio
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# MAE encoder specifics
|
||||
self.patch_embed = PatchEmbed(
|
||||
cfg.input_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg.embed_dim)) if not cfg.no_cls else None
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches + int(not cfg.no_cls), cfg.embed_dim), requires_grad=False
|
||||
) # fixed sin-cos embedding
|
||||
|
||||
norm_layer = partial(nn.LayerNorm, eps=cfg.norm_eps)
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
def make_block(drop_path):
|
||||
if cfg.w2v_block:
|
||||
return TransformerSentenceEncoderLayer(
|
||||
embedding_dim=cfg.embed_dim,
|
||||
ffn_embedding_dim=cfg.embed_dim * cfg.mlp_ratio,
|
||||
num_attention_heads=cfg.num_heads,
|
||||
dropout=cfg.block_dropout,
|
||||
attention_dropout=cfg.attention_dropout,
|
||||
activation_dropout=cfg.activation_dropout,
|
||||
activation_fn="gelu",
|
||||
layer_norm_first=cfg.layer_norm_first,
|
||||
drop_path=drop_path,
|
||||
norm_eps=1e-6,
|
||||
single_qkv=cfg.single_qkv,
|
||||
fused_ln=cfg.fused_ln,
|
||||
)
|
||||
elif cfg.alt_block:
|
||||
window_size = (
|
||||
cfg.input_size // self.patch_embed.patch_size[0],
|
||||
cfg.input_size // self.patch_embed.patch_size[1],
|
||||
)
|
||||
return AltBlock(
|
||||
cfg.embed_dim,
|
||||
cfg.num_heads,
|
||||
cfg.mlp_ratio,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
norm_layer=norm_layer,
|
||||
drop_path=drop_path,
|
||||
layer_norm_first=cfg.layer_norm_first,
|
||||
ffn_targets=not cfg.end_of_block_targets,
|
||||
use_rel_pos_bias=cfg.use_rel_pos_bias,
|
||||
window_size=window_size
|
||||
if (self.cfg.use_rel_pos_bias and not self.cfg.shared_rel_pos_bias)
|
||||
else None,
|
||||
alt_attention=cfg.alt_attention,
|
||||
)
|
||||
elif cfg.alt_block2:
|
||||
from .multi.modules import AltBlock as AltBlock2
|
||||
return AltBlock2(
|
||||
cfg.embed_dim,
|
||||
cfg.num_heads,
|
||||
cfg.mlp_ratio,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
norm_layer=norm_layer,
|
||||
drop_path=drop_path,
|
||||
layer_norm_first=cfg.layer_norm_first,
|
||||
ffn_targets=not cfg.end_of_block_targets,
|
||||
)
|
||||
else:
|
||||
return Block(
|
||||
cfg.embed_dim,
|
||||
cfg.num_heads,
|
||||
cfg.mlp_ratio,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
norm_layer=norm_layer,
|
||||
drop_path=drop_path,
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
|
||||
self.norm = norm_layer(cfg.embed_dim)
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# MAE decoder specifics
|
||||
self.decoder_embed = (
|
||||
nn.Linear(cfg.embed_dim, cfg.decoder_embed_dim, bias=True)
|
||||
if not cfg.no_decoder_embed
|
||||
else None
|
||||
)
|
||||
|
||||
self.mask_token = (
|
||||
nn.Parameter(
|
||||
torch.zeros(
|
||||
1,
|
||||
1,
|
||||
cfg.decoder_embed_dim
|
||||
if not cfg.no_decoder_embed
|
||||
else cfg.embed_dim,
|
||||
)
|
||||
)
|
||||
if cfg.mask_noise_std <= 0
|
||||
else None
|
||||
)
|
||||
|
||||
self.decoder_pos_embed = (
|
||||
nn.Parameter(
|
||||
torch.zeros(
|
||||
1,
|
||||
num_patches + 1,
|
||||
cfg.decoder_embed_dim
|
||||
if not cfg.no_decoder_embed
|
||||
else cfg.embed_dim,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
if not cfg.no_decoder_pos_embed
|
||||
else None
|
||||
)
|
||||
|
||||
self.decoder_blocks = nn.ModuleList(
|
||||
[
|
||||
Block(
|
||||
cfg.decoder_embed_dim,
|
||||
cfg.decoder_num_heads,
|
||||
cfg.mlp_ratio,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
for _ in range(cfg.decoder_depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.decoder_norm = norm_layer(cfg.decoder_embed_dim)
|
||||
self.decoder_pred = nn.Linear(
|
||||
cfg.decoder_embed_dim, cfg.patch_size ** 2 * cfg.in_chans, bias=True
|
||||
) # decoder to patch
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
self.norm_pix_loss = cfg.norm_pix_loss
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
for pn, p in self.named_parameters():
|
||||
if len(p.shape) == 1 or pn.endswith(".bias"):
|
||||
p.param_group = "no_decay"
|
||||
else:
|
||||
p.param_group = "with_decay"
|
||||
|
||||
def initialize_weights(self):
|
||||
# initialization
|
||||
# initialize (and freeze) pos_embed by sin-cos embedding
|
||||
pos_embed = get_2d_sincos_pos_embed(
|
||||
self.pos_embed.shape[-1],
|
||||
int(self.patch_embed.num_patches ** 0.5),
|
||||
cls_token=not self.cfg.no_cls,
|
||||
)
|
||||
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
||||
|
||||
if self.decoder_pos_embed is not None:
|
||||
decoder_pos_embed = get_2d_sincos_pos_embed(
|
||||
self.decoder_pos_embed.shape[-1],
|
||||
int(self.patch_embed.num_patches ** 0.5),
|
||||
cls_token=not self.cfg.no_cls,
|
||||
)
|
||||
self.decoder_pos_embed.data.copy_(
|
||||
torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
|
||||
)
|
||||
|
||||
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
|
||||
w = self.patch_embed.proj.weight.data
|
||||
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
|
||||
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
||||
if self.cls_token is not None:
|
||||
torch.nn.init.normal_(self.cls_token, std=0.02)
|
||||
|
||||
if self.mask_token is not None:
|
||||
torch.nn.init.normal_(self.mask_token, std=0.02)
|
||||
|
||||
# initialize nn.Linear and nn.LayerNorm
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
# we use xavier_uniform following official JAX ViT:
|
||||
torch.nn.init.xavier_uniform_(m.weight)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def patchify(self, imgs):
|
||||
"""
|
||||
imgs: (N, 3, H, W)
|
||||
x: (N, L, patch_size**2 *3)
|
||||
"""
|
||||
p = self.patch_embed.patch_size[0]
|
||||
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
|
||||
|
||||
h = w = imgs.shape[2] // p
|
||||
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
|
||||
x = torch.einsum("nchpwq->nhwpqc", x)
|
||||
x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
|
||||
return x
|
||||
|
||||
def unpatchify(self, x):
|
||||
"""
|
||||
x: (N, L, patch_size**2 *3)
|
||||
imgs: (N, 3, H, W)
|
||||
"""
|
||||
p = self.patch_embed.patch_size[0]
|
||||
h = w = int(x.shape[1] ** 0.5)
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
|
||||
x = torch.einsum("nhwpqc->nchpwq", x)
|
||||
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
|
||||
return imgs
|
||||
|
||||
def random_masking(self, x, mask_ratio):
|
||||
"""
|
||||
Perform per-sample random masking by per-sample shuffling.
|
||||
Per-sample shuffling is done by argsort random noise.
|
||||
x: [N, L, D], sequence
|
||||
"""
|
||||
N, L, D = x.shape # batch, length, dim
|
||||
len_keep = int(L * (1 - mask_ratio))
|
||||
|
||||
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
|
||||
|
||||
# sort noise for each sample
|
||||
ids_shuffle = torch.argsort(
|
||||
noise, dim=1
|
||||
) # ascend: small is keep, large is remove
|
||||
ids_restore = torch.argsort(ids_shuffle, dim=1)
|
||||
|
||||
# keep the first subset
|
||||
ids_keep = ids_shuffle[:, :len_keep]
|
||||
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
|
||||
|
||||
# generate the binary mask: 0 is keep, 1 is remove
|
||||
mask = torch.ones([N, L], device=x.device)
|
||||
mask[:, :len_keep] = 0
|
||||
# unshuffle to get the binary mask
|
||||
mask = torch.gather(mask, dim=1, index=ids_restore)
|
||||
|
||||
return x_masked, mask, ids_restore # x_masked is actually unmasked x
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, cfg: MaeConfig, task=None):
|
||||
"""Build a new model instance."""
|
||||
|
||||
return cls(cfg)
|
||||
|
||||
def forward_encoder(self, x, mask_ratio):
|
||||
# embed patches
|
||||
x = self.patch_embed(x)
|
||||
|
||||
# add pos embed w/o cls token
|
||||
# if self.cls_token is not None:
|
||||
# x = x + self.pos_embed
|
||||
# else:
|
||||
x = x + self.pos_embed[:, 1:, :]
|
||||
|
||||
# masking: length -> length * mask_ratio
|
||||
if mask_ratio > 0:
|
||||
x, mask, ids_restore = self.random_masking(x, mask_ratio)
|
||||
else:
|
||||
mask = ids_restore = None
|
||||
|
||||
# append cls token
|
||||
if self.cls_token is not None:
|
||||
cls_token = self.cls_token + self.pos_embed[:, :1, :]
|
||||
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
# apply Transformer blocks
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, mask, ids_restore
|
||||
|
||||
def forward_decoder(self, x, ids_restore):
|
||||
# embed tokens
|
||||
x = self.decoder_embed(x)
|
||||
|
||||
# append mask tokens to sequence
|
||||
mask_tokens = self.mask_token.repeat(
|
||||
x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
|
||||
)
|
||||
if self.cls_token is not None:
|
||||
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
|
||||
else:
|
||||
x_ = torch.cat([x, mask_tokens], dim=1) # no cls token
|
||||
|
||||
x_ = torch.gather(
|
||||
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
|
||||
) # unshuffle
|
||||
|
||||
if self.cls_token is not None:
|
||||
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
|
||||
|
||||
# add pos embed
|
||||
x = x + self.decoder_pos_embed
|
||||
|
||||
# apply Transformer blocks
|
||||
for blk in self.decoder_blocks:
|
||||
x = blk(x)
|
||||
x = self.decoder_norm(x)
|
||||
|
||||
# predictor projection
|
||||
x = self.decoder_pred(x)
|
||||
|
||||
if self.cls_token is not None:
|
||||
# remove cls token
|
||||
x = x[:, 1:, :]
|
||||
|
||||
return x
|
||||
|
||||
def forward_loss(self, imgs, pred, mask):
|
||||
"""
|
||||
imgs: [N, 3, H, W]
|
||||
pred: [N, L, p*p*3]
|
||||
mask: [N, L], 0 is keep, 1 is remove,
|
||||
"""
|
||||
target = self.patchify(imgs)
|
||||
if self.norm_pix_loss:
|
||||
mean = target.mean(dim=-1, keepdim=True)
|
||||
var = target.var(dim=-1, keepdim=True)
|
||||
target = (target - mean) / (var + 1.0e-6) ** 0.5
|
||||
|
||||
loss = (pred - target) ** 2
|
||||
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
|
||||
|
||||
loss = (loss * mask).sum()
|
||||
return loss, mask.sum()
|
||||
|
||||
def forward(self, imgs, predictions_only=False):
|
||||
latent, mask, ids_restore = self.forward_encoder(
|
||||
imgs, self.mask_ratio if not predictions_only else 0
|
||||
)
|
||||
|
||||
if predictions_only:
|
||||
return latent
|
||||
|
||||
pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
|
||||
loss, sample_size = self.forward_loss(imgs, pred, mask)
|
||||
|
||||
result = {
|
||||
"losses": {"regression": loss},
|
||||
"sample_size": sample_size,
|
||||
}
|
||||
return result
|
||||
|
||||
def remove_pretraining_modules(self):
|
||||
self.decoder_embed = None
|
||||
self.decoder_blocks = None
|
||||
self.decoder_norm = None
|
||||
self.decoder_pos_embed = None
|
||||
self.decoder_pred = None
|
||||
self.mask_token = None
|
||||
if self.cfg.layer_norm_first:
|
||||
self.norm = None
|
386
examples/data2vec/models/mae_image_classification.py
Normal file
386
examples/data2vec/models/mae_image_classification.py
Normal file
@ -0,0 +1,386 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# The code in this file is adapted from the BeiT implementation which can be found here:
|
||||
# https://github.com/microsoft/unilm/tree/master/beit
|
||||
|
||||
import logging
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from omegaconf import II, MISSING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fairseq import checkpoint_utils, tasks
|
||||
from omegaconf import open_dict
|
||||
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.models import BaseFairseqModel, register_model
|
||||
from .mae import interpolate_pos_embed
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PredictionMode(Enum):
|
||||
MEAN_POOLING = auto()
|
||||
CLS_TOKEN = auto()
|
||||
LIN_SOFTMAX = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaeImageClassificationConfig(FairseqDataclass):
|
||||
model_path: str = MISSING
|
||||
no_pretrained_weights: bool = False
|
||||
linear_classifier: bool = False
|
||||
num_classes: int = 1000
|
||||
mixup: float = 0.8
|
||||
cutmix: float = 1.0
|
||||
label_smoothing: float = 0.1
|
||||
|
||||
drop_path_rate: float = 0.1
|
||||
layer_decay: float = 0.65
|
||||
|
||||
mixup_prob: float = 1.0
|
||||
mixup_switch_prob: float = 0.5
|
||||
mixup_mode: str = "batch"
|
||||
|
||||
pretrained_model_args: Any = None
|
||||
data: str = II("task.data")
|
||||
|
||||
norm_eps: Optional[float] = None
|
||||
|
||||
remove_alibi: bool = False
|
||||
|
||||
# regularization overwrites
|
||||
encoder_dropout: float = 0
|
||||
post_mlp_drop: float = 0
|
||||
attention_dropout: float = 0
|
||||
activation_dropout: float = 0.0
|
||||
dropout_input: float = 0.0
|
||||
layerdrop: float = 0.0
|
||||
|
||||
prenet_layerdrop: float = 0
|
||||
prenet_dropout: float = 0
|
||||
|
||||
use_fc_norm: bool = True
|
||||
prediction_mode: PredictionMode = PredictionMode.MEAN_POOLING
|
||||
|
||||
no_decay_blocks: bool = True
|
||||
|
||||
|
||||
def get_layer_id_for_vit(name, num_layers):
|
||||
"""
|
||||
Assign a parameter with its layer id
|
||||
Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
|
||||
"""
|
||||
if name in ["cls_token", "pos_embed"]:
|
||||
return 0
|
||||
elif name.startswith("patch_embed"):
|
||||
return 0
|
||||
elif name.startswith("rel_pos_bias"):
|
||||
return num_layers - 1
|
||||
elif name.startswith("blocks"):
|
||||
return int(name.split(".")[1]) + 1
|
||||
else:
|
||||
return num_layers
|
||||
|
||||
|
||||
@register_model("mae_image_classification", dataclass=MaeImageClassificationConfig)
|
||||
class MaeImageClassificationModel(BaseFairseqModel):
|
||||
def __init__(self, cfg: MaeImageClassificationConfig):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
if cfg.pretrained_model_args is None:
|
||||
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
|
||||
pretrained_args = state.get("cfg", None)
|
||||
|
||||
pretrained_args.criterion = None
|
||||
pretrained_args.lr_scheduler = None
|
||||
|
||||
logger.info(pretrained_args.model)
|
||||
|
||||
with open_dict(pretrained_args.model):
|
||||
pretrained_args.model.drop_path_rate = cfg.drop_path_rate
|
||||
if cfg.norm_eps is not None:
|
||||
pretrained_args.model.norm_eps = cfg.norm_eps
|
||||
|
||||
cfg.pretrained_model_args = pretrained_args
|
||||
|
||||
logger.info(pretrained_args)
|
||||
else:
|
||||
state = None
|
||||
pretrained_args = cfg.pretrained_model_args
|
||||
|
||||
if "data" in pretrained_args.task:
|
||||
pretrained_args.task.data = cfg.data
|
||||
elif "image" in pretrained_args.task:
|
||||
pretrained_args.task.image.data = cfg.data
|
||||
|
||||
if "modalities" in pretrained_args.model:
|
||||
prenet_blocks = pretrained_args.model["modalities"]["image"]["prenet_depth"]
|
||||
model_blocks = pretrained_args.model["depth"]
|
||||
with open_dict(pretrained_args):
|
||||
dpr = np.linspace(0, cfg.drop_path_rate, model_blocks).tolist()
|
||||
pretrained_args.model["modalities"]["image"][
|
||||
"start_drop_path_rate"
|
||||
] = dpr[0]
|
||||
pretrained_args.model["modalities"]["image"][
|
||||
"end_drop_path_rate"
|
||||
] = max(0, dpr[prenet_blocks - 1])
|
||||
pretrained_args.model["start_drop_path_rate"] = dpr[prenet_blocks]
|
||||
pretrained_args.model["end_drop_path_rate"] = dpr[-1]
|
||||
|
||||
if "mae_masking" in pretrained_args.model["modalities"]["image"]:
|
||||
del pretrained_args.model["modalities"]["image"]["mae_masking"]
|
||||
|
||||
if cfg.remove_alibi:
|
||||
pretrained_args.model["modalities"]["image"][
|
||||
"use_alibi_encoder"
|
||||
] = False
|
||||
if (
|
||||
state is not None
|
||||
and "modality_encoders.IMAGE.alibi_bias" in state["model"]
|
||||
):
|
||||
del state["model"]["modality_encoders.IMAGE.alibi_bias"]
|
||||
|
||||
pretrained_args.model["encoder_dropout"] = cfg.encoder_dropout
|
||||
pretrained_args.model["post_mlp_drop"] = cfg.post_mlp_drop
|
||||
pretrained_args.model["attention_dropout"] = cfg.attention_dropout
|
||||
pretrained_args.model["activation_dropout"] = cfg.activation_dropout
|
||||
pretrained_args.model["dropout_input"] = cfg.dropout_input
|
||||
pretrained_args.model["layerdrop"] = cfg.layerdrop
|
||||
|
||||
pretrained_args.model["modalities"]["image"][
|
||||
"prenet_layerdrop"
|
||||
] = cfg.prenet_layerdrop
|
||||
pretrained_args.model["modalities"]["image"][
|
||||
"prenet_dropout"
|
||||
] = cfg.prenet_dropout
|
||||
else:
|
||||
# not d2v multi
|
||||
with open_dict(pretrained_args):
|
||||
pretrained_args.model["drop_path_rate"] = cfg.drop_path_rate
|
||||
pretrained_args.model["block_dropout"] = cfg.encoder_dropout
|
||||
pretrained_args.model["attention_dropout"] = cfg.attention_dropout
|
||||
pretrained_args.model["activation_dropout"] = cfg.activation_dropout
|
||||
|
||||
task = tasks.setup_task(pretrained_args.task)
|
||||
model = task.build_model(pretrained_args.model, from_checkpoint=True)
|
||||
|
||||
self.d2v_multi = "data2vec_multi" in pretrained_args.model._name
|
||||
self.linear_classifier = cfg.linear_classifier
|
||||
|
||||
self.model = model
|
||||
|
||||
if state is not None and not cfg.no_pretrained_weights:
|
||||
interpolate_pos_embed(model, state)
|
||||
|
||||
if "modality_encoders.IMAGE.positional_encoder.pos_embed" in state["model"]:
|
||||
state["model"][
|
||||
"modality_encoders.IMAGE.positional_encoder.positions"
|
||||
] = state["model"][
|
||||
"modality_encoders.IMAGE.positional_encoder.pos_embed"
|
||||
]
|
||||
del state["model"][
|
||||
"modality_encoders.IMAGE.positional_encoder.pos_embed"
|
||||
]
|
||||
if "modality_encoders.IMAGE.encoder_mask" in state["model"]:
|
||||
del state["model"]["modality_encoders.IMAGE.encoder_mask"]
|
||||
|
||||
model.load_state_dict(state["model"], strict=True)
|
||||
|
||||
if self.d2v_multi:
|
||||
model.remove_pretraining_modules(modality="image")
|
||||
else:
|
||||
model.remove_pretraining_modules()
|
||||
|
||||
if self.linear_classifier:
|
||||
model.requires_grad_(False)
|
||||
|
||||
self.fc_norm = None
|
||||
if self.cfg.use_fc_norm:
|
||||
self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim, eps=1e-6)
|
||||
nn.init.constant_(self.fc_norm.bias, 0)
|
||||
nn.init.constant_(self.fc_norm.weight, 1.0)
|
||||
|
||||
self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)
|
||||
|
||||
nn.init.trunc_normal_(self.head.weight, std=0.02)
|
||||
nn.init.constant_(self.head.bias, 0)
|
||||
|
||||
self.mixup_fn = None
|
||||
|
||||
if cfg.mixup > 0 or cfg.cutmix > 0:
|
||||
from timm.data import Mixup
|
||||
|
||||
self.mixup_fn = Mixup(
|
||||
mixup_alpha=cfg.mixup,
|
||||
cutmix_alpha=cfg.cutmix,
|
||||
cutmix_minmax=None,
|
||||
prob=cfg.mixup_prob,
|
||||
switch_prob=cfg.mixup_switch_prob,
|
||||
mode=cfg.mixup_mode,
|
||||
label_smoothing=cfg.label_smoothing,
|
||||
num_classes=cfg.num_classes,
|
||||
)
|
||||
|
||||
if self.model.norm is not None:
|
||||
for pn, p in self.model.norm.named_parameters():
|
||||
if len(p.shape) == 1 or pn.endswith(".bias"):
|
||||
p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
|
||||
|
||||
if self.fc_norm is not None:
|
||||
for pn, p in self.fc_norm.named_parameters():
|
||||
if len(p.shape) == 1 or pn.endswith(".bias"):
|
||||
p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
|
||||
|
||||
for pn, p in self.head.named_parameters():
|
||||
if len(p.shape) == 1 or pn.endswith(".bias"):
|
||||
p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
|
||||
|
||||
if self.d2v_multi:
|
||||
mod_encs = list(model.modality_encoders.values())
|
||||
assert len(mod_encs) == 1, len(mod_encs)
|
||||
blocks = list(mod_encs[0].context_encoder.blocks) + list(model.blocks)
|
||||
else:
|
||||
blocks = model.blocks
|
||||
|
||||
num_layers = len(blocks) + 1
|
||||
layer_scales = list(
|
||||
cfg.layer_decay ** (num_layers - i) for i in range(num_layers + 1)
|
||||
)
|
||||
|
||||
if self.d2v_multi:
|
||||
for n, p in self.model.named_parameters():
|
||||
optimizer_override_dict = {}
|
||||
|
||||
if len(p.shape) == 1 or n.endswith(".bias"):
|
||||
optimizer_override_dict["weight_decay_scale"] = 0
|
||||
|
||||
p.optim_overrides = {"optimizer": optimizer_override_dict}
|
||||
|
||||
if cfg.layer_decay > 0:
|
||||
for i, b in enumerate(blocks):
|
||||
lid = i + 1
|
||||
if layer_scales[lid] == 1.0:
|
||||
continue
|
||||
|
||||
for n, p in b.named_parameters():
|
||||
optim_override = getattr(p, "optim_overrides", {})
|
||||
if "optimizer" not in optim_override:
|
||||
optim_override["optimizer"] = {}
|
||||
|
||||
if cfg.no_decay_blocks:
|
||||
optim_override["optimizer"]["lr_scale"] = layer_scales[lid]
|
||||
p.optim_overrides = optim_override
|
||||
else:
|
||||
optim_override["optimizer"] = {
|
||||
"lr_scale": layer_scales[lid]
|
||||
}
|
||||
p.optim_overrides = optim_override
|
||||
|
||||
else:
|
||||
for n, p in self.model.named_parameters():
|
||||
optimizer_override_dict = {}
|
||||
layer_id = get_layer_id_for_vit(n, num_layers)
|
||||
|
||||
if len(p.shape) == 1 or n.endswith(".bias"):
|
||||
optimizer_override_dict["weight_decay_scale"] = 0
|
||||
|
||||
if cfg.layer_decay > 0:
|
||||
optimizer_override_dict["lr_scale"] = layer_scales[layer_id]
|
||||
p.optim_overrides = {"optimizer": optimizer_override_dict}
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, cfg: MaeImageClassificationConfig, task=None):
|
||||
"""Build a new model instance."""
|
||||
|
||||
return cls(cfg)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
imgs,
|
||||
labels=None,
|
||||
):
|
||||
if self.training and self.mixup_fn is not None and labels is not None:
|
||||
imgs, labels = self.mixup_fn(imgs, labels)
|
||||
|
||||
if self.linear_classifier:
|
||||
with torch.no_grad():
|
||||
x = self.model_forward(imgs)
|
||||
else:
|
||||
x = self.model_forward(imgs)
|
||||
|
||||
if self.cfg.prediction_mode == PredictionMode.MEAN_POOLING:
|
||||
x = x.mean(dim=1)
|
||||
elif self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
|
||||
x = x[:, 0]
|
||||
elif self.cfg.prediction_mode == PredictionMode.LIN_SOFTMAX:
|
||||
dtype = x.dtype
|
||||
x = F.logsigmoid(x.float())
|
||||
x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x + 1e-6, dim=1)
|
||||
x = x.clamp(max=0)
|
||||
x = x - torch.log(-(torch.expm1(x)))
|
||||
x = torch.nan_to_num(x, nan=0, posinf=0, neginf=0)
|
||||
x = x.to(dtype=dtype)
|
||||
else:
|
||||
raise Exception(f"unknown prediction mode {self.cfg.prediction_mode.name}")
|
||||
|
||||
if self.fc_norm is not None:
|
||||
x = self.fc_norm(x)
|
||||
|
||||
x = self.head(x)
|
||||
|
||||
if labels is None:
|
||||
return x
|
||||
|
||||
if self.training and self.mixup_fn is not None:
|
||||
loss = -labels * F.log_softmax(x.float(), dim=-1)
|
||||
else:
|
||||
loss = F.cross_entropy(
|
||||
x.float(),
|
||||
labels,
|
||||
label_smoothing=self.cfg.label_smoothing if self.training else 0,
|
||||
reduction="none",
|
||||
)
|
||||
|
||||
result = {
|
||||
"losses": {"regression": loss},
|
||||
"sample_size": imgs.size(0),
|
||||
}
|
||||
|
||||
if not self.training:
|
||||
with torch.no_grad():
|
||||
pred = x.argmax(-1)
|
||||
correct = (pred == labels).sum()
|
||||
result["correct"] = correct
|
||||
|
||||
return result
|
||||
|
||||
def model_forward(self, imgs):
|
||||
if self.d2v_multi:
|
||||
x = self.model.extract_features(
|
||||
imgs,
|
||||
mode="IMAGE",
|
||||
mask=False,
|
||||
remove_extra_tokens=(
|
||||
self.cfg.prediction_mode != PredictionMode.CLS_TOKEN
|
||||
),
|
||||
)["x"]
|
||||
else:
|
||||
x = self.model(imgs, predictions_only=True)
|
||||
if (
|
||||
"no_cls" not in self.model.cfg or not self.model.cfg.no_cls
|
||||
) and not self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
|
||||
x = x[:, 1:]
|
||||
return x
|
0
examples/data2vec/models/modalities/__init__.py
Normal file
0
examples/data2vec/models/modalities/__init__.py
Normal file
192
examples/data2vec/models/modalities/audio.py
Normal file
192
examples/data2vec/models/modalities/audio.py
Normal file
@ -0,0 +1,192 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from functools import partial
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, Optional
|
||||
from fairseq.models.wav2vec import ConvFeatureExtractionModel
|
||||
from fairseq.modules import (
|
||||
LayerNorm,
|
||||
SamePad,
|
||||
TransposeLast,
|
||||
)
|
||||
from fairseq.tasks import FairseqTask
|
||||
from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
|
||||
from .modules import BlockEncoder, Decoder1d
|
||||
from examples.data2vec.data.modality import Modality
|
||||
|
||||
|
||||
@dataclass
|
||||
class D2vAudioConfig(D2vModalityConfig):
|
||||
type: Modality = Modality.AUDIO
|
||||
extractor_mode: str = "layer_norm"
|
||||
feature_encoder_spec: str = field(
|
||||
default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
|
||||
metadata={
|
||||
"help": "string describing convolutional feature extraction layers in form of a python list that contains "
|
||||
"[(dim, kernel_size, stride), ...]"
|
||||
},
|
||||
)
|
||||
conv_pos_width: int = field(
|
||||
default=95,
|
||||
metadata={"help": "number of filters for convolutional positional embeddings"},
|
||||
)
|
||||
conv_pos_groups: int = field(
|
||||
default=16,
|
||||
metadata={"help": "number of groups for convolutional positional embedding"},
|
||||
)
|
||||
conv_pos_depth: int = field(
|
||||
default=5,
|
||||
metadata={"help": "depth of positional encoder network"},
|
||||
)
|
||||
conv_pos_pre_ln: bool = False
|
||||
|
||||
|
||||
class AudioEncoder(ModalitySpecificEncoder):
|
||||
|
||||
modality_cfg: D2vAudioConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
modality_cfg: D2vAudioConfig,
|
||||
embed_dim: int,
|
||||
make_block: Callable[[float], nn.ModuleList],
|
||||
norm_layer: Callable[[int], nn.LayerNorm],
|
||||
layer_norm_first: bool,
|
||||
alibi_biases: Dict,
|
||||
task: Optional[FairseqTask],
|
||||
):
|
||||
|
||||
self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
|
||||
feature_embed_dim = self.feature_enc_layers[-1][0]
|
||||
|
||||
local_encoder = ConvFeatureExtractionModel(
|
||||
conv_layers=self.feature_enc_layers,
|
||||
dropout=0.0,
|
||||
mode=modality_cfg.extractor_mode,
|
||||
conv_bias=False,
|
||||
)
|
||||
|
||||
project_features = nn.Sequential(
|
||||
TransposeLast(),
|
||||
nn.LayerNorm(feature_embed_dim),
|
||||
nn.Linear(feature_embed_dim, embed_dim),
|
||||
)
|
||||
|
||||
num_pos_layers = modality_cfg.conv_pos_depth
|
||||
k = max(3, modality_cfg.conv_pos_width // num_pos_layers)
|
||||
|
||||
positional_encoder = nn.Sequential(
|
||||
TransposeLast(),
|
||||
*[
|
||||
nn.Sequential(
|
||||
nn.Conv1d(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
kernel_size=k,
|
||||
padding=k // 2,
|
||||
groups=modality_cfg.conv_pos_groups,
|
||||
),
|
||||
SamePad(k),
|
||||
TransposeLast(),
|
||||
LayerNorm(embed_dim, elementwise_affine=False),
|
||||
TransposeLast(),
|
||||
nn.GELU(),
|
||||
)
|
||||
for _ in range(num_pos_layers)
|
||||
],
|
||||
TransposeLast(),
|
||||
)
|
||||
|
||||
if modality_cfg.conv_pos_pre_ln:
|
||||
positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)
|
||||
|
||||
dpr = np.linspace(
|
||||
modality_cfg.start_drop_path_rate,
|
||||
modality_cfg.end_drop_path_rate,
|
||||
modality_cfg.prenet_depth,
|
||||
)
|
||||
context_encoder = BlockEncoder(
|
||||
nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
|
||||
norm_layer(embed_dim) if not layer_norm_first else None,
|
||||
layer_norm_first,
|
||||
modality_cfg.prenet_layerdrop,
|
||||
modality_cfg.prenet_dropout,
|
||||
)
|
||||
|
||||
decoder = (
|
||||
Decoder1d(modality_cfg.decoder, embed_dim)
|
||||
if modality_cfg.decoder is not None
|
||||
else None
|
||||
)
|
||||
|
||||
alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
|
||||
|
||||
super().__init__(
|
||||
modality_cfg=modality_cfg,
|
||||
embed_dim=embed_dim,
|
||||
local_encoder=local_encoder,
|
||||
project_features=project_features,
|
||||
fixed_positional_encoder=None,
|
||||
relative_positional_encoder=positional_encoder,
|
||||
context_encoder=context_encoder,
|
||||
decoder=decoder,
|
||||
get_alibi_bias=alibi_bias_fn,
|
||||
)
|
||||
|
||||
def convert_padding_mask(self, x, padding_mask):
|
||||
def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
|
||||
"""
|
||||
Computes the output length of the convolutional layers
|
||||
"""
|
||||
|
||||
def _conv_out_length(input_length, kernel_size, stride):
|
||||
return torch.floor((input_length - kernel_size) / stride + 1)
|
||||
|
||||
for i in range(len(self.feature_enc_layers)):
|
||||
input_lengths = _conv_out_length(
|
||||
input_lengths,
|
||||
self.feature_enc_layers[i][1],
|
||||
self.feature_enc_layers[i][2],
|
||||
)
|
||||
|
||||
return input_lengths.to(torch.long)
|
||||
|
||||
if padding_mask is not None:
|
||||
input_lengths = (1 - padding_mask.long()).sum(-1)
|
||||
# apply conv formula to get real output_lengths
|
||||
output_lengths = get_feat_extract_output_lengths(input_lengths)
|
||||
|
||||
if padding_mask.any():
|
||||
padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)
|
||||
|
||||
# these two operations makes sure that all values
|
||||
# before the output lengths indices are attended to
|
||||
padding_mask[
|
||||
(
|
||||
torch.arange(padding_mask.shape[0], device=padding_mask.device),
|
||||
output_lengths - 1,
|
||||
)
|
||||
] = 1
|
||||
padding_mask = (
|
||||
1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
|
||||
).bool()
|
||||
else:
|
||||
padding_mask = torch.zeros(
|
||||
x.shape[:2], dtype=torch.bool, device=x.device
|
||||
)
|
||||
|
||||
return padding_mask
|
||||
|
||||
def reset_parameters(self):
|
||||
super().reset_parameters()
|
||||
for mod in self.project_features.children():
|
||||
if isinstance(mod, nn.Linear):
|
||||
mod.reset_parameters()
|
||||
if self.decoder is not None:
|
||||
self.decoder.reset_parameters()
|
684
examples/data2vec/models/modalities/base.py
Normal file
684
examples/data2vec/models/modalities/base.py
Normal file
@ -0,0 +1,684 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from omegaconf import MISSING, II
|
||||
from typing import Optional, Callable
|
||||
from fairseq.data.data_utils import compute_mask_indices
|
||||
from fairseq.modules import GradMultiply
|
||||
from fairseq.utils import index_put
|
||||
from examples.data2vec.data.modality import Modality
|
||||
from .modules import D2vDecoderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class D2vModalityConfig:
|
||||
type: Modality = MISSING
|
||||
prenet_depth: int = 4
|
||||
prenet_layerdrop: float = 0
|
||||
prenet_dropout: float = 0
|
||||
start_drop_path_rate: float = 0
|
||||
end_drop_path_rate: float = 0
|
||||
|
||||
num_extra_tokens: int = 0
|
||||
init_extra_token_zero: bool = True
|
||||
|
||||
mask_noise_std: float = 0.01
|
||||
mask_prob_min: Optional[float] = None
|
||||
mask_prob: float = 0.7
|
||||
inverse_mask: bool = False
|
||||
mask_prob_adjust: float = 0
|
||||
keep_masked_pct: float = 0
|
||||
|
||||
mask_length: int = 5
|
||||
add_masks: bool = False
|
||||
remove_masks: bool = False
|
||||
mask_dropout: float = 0.0
|
||||
encoder_zero_mask: bool = True
|
||||
|
||||
mask_channel_prob: float = 0.0
|
||||
mask_channel_length: int = 64
|
||||
|
||||
ema_local_encoder: bool = False # used in data2vec_multi
|
||||
local_grad_mult: float = 1.0
|
||||
|
||||
use_alibi_encoder: bool = False
|
||||
alibi_scale: float = 1.0
|
||||
learned_alibi: bool = False
|
||||
alibi_max_pos: Optional[int] = None
|
||||
learned_alibi_scale: bool = False
|
||||
learned_alibi_scale_per_head: bool = False
|
||||
learned_alibi_scale_per_layer: bool = False
|
||||
|
||||
num_alibi_heads: int = II("model.num_heads")
|
||||
model_depth: int = II("model.depth")
|
||||
|
||||
decoder: Optional[D2vDecoderConfig] = D2vDecoderConfig()
|
||||
|
||||
|
||||
MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
|
||||
MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
|
||||
|
||||
|
||||
class ModalitySpecificEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
modality_cfg: D2vModalityConfig,
|
||||
embed_dim: int,
|
||||
local_encoder: nn.Module,
|
||||
project_features: nn.Module,
|
||||
fixed_positional_encoder: Optional[nn.Module],
|
||||
relative_positional_encoder: Optional[nn.Module],
|
||||
context_encoder: nn.Module,
|
||||
decoder: nn.Module,
|
||||
get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.modality_cfg = modality_cfg
|
||||
self.local_encoder = local_encoder
|
||||
self.project_features = project_features
|
||||
self.fixed_positional_encoder = fixed_positional_encoder
|
||||
self.relative_positional_encoder = relative_positional_encoder
|
||||
self.context_encoder = context_encoder
|
||||
|
||||
self.decoder = decoder
|
||||
self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None
|
||||
|
||||
self.local_grad_mult = self.modality_cfg.local_grad_mult
|
||||
|
||||
self.extra_tokens = None
|
||||
if modality_cfg.num_extra_tokens > 0:
|
||||
self.extra_tokens = nn.Parameter(
|
||||
torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
|
||||
)
|
||||
if not modality_cfg.init_extra_token_zero:
|
||||
nn.init.normal_(self.extra_tokens)
|
||||
elif self.extra_tokens.size(1) > 1:
|
||||
nn.init.normal_(self.extra_tokens[:, 1:])
|
||||
|
||||
self.alibi_scale = None
|
||||
if self.get_alibi_bias is not None:
|
||||
self.alibi_scale = nn.Parameter(
|
||||
torch.full(
|
||||
(
|
||||
(modality_cfg.prenet_depth + modality_cfg.model_depth)
|
||||
if modality_cfg.learned_alibi_scale_per_layer
|
||||
else 1,
|
||||
1,
|
||||
self.modality_cfg.num_alibi_heads
|
||||
if modality_cfg.learned_alibi_scale_per_head
|
||||
else 1,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
modality_cfg.alibi_scale,
|
||||
dtype=torch.float,
|
||||
),
|
||||
requires_grad=modality_cfg.learned_alibi_scale,
|
||||
)
|
||||
|
||||
if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
|
||||
assert modality_cfg.alibi_max_pos is not None
|
||||
alibi_bias = self.get_alibi_bias(
|
||||
batch_size=1,
|
||||
time_steps=modality_cfg.alibi_max_pos,
|
||||
heads=modality_cfg.num_alibi_heads,
|
||||
scale=1.0,
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
)
|
||||
self.alibi_bias = nn.Parameter(alibi_bias)
|
||||
self.get_alibi_bias = partial(
|
||||
_learned_alibi_bias, alibi_bias=self.alibi_bias
|
||||
)
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
k = f"{name}.alibi_scale"
|
||||
if k in state_dict and state_dict[k].dim() == 4:
|
||||
state_dict[k] = state_dict[k].unsqueeze(0)
|
||||
|
||||
return state_dict
|
||||
|
||||
def convert_padding_mask(self, x, padding_mask):
|
||||
return padding_mask
|
||||
|
||||
def decoder_input(self, x, mask_info: MaskInfo):
|
||||
inp_drop = self.modality_cfg.decoder.input_dropout
|
||||
if inp_drop > 0:
|
||||
x = F.dropout(x, inp_drop, training=self.training, inplace=True)
|
||||
|
||||
num_extra = self.modality_cfg.num_extra_tokens
|
||||
|
||||
if mask_info is not None:
|
||||
num_masked = mask_info.ids_restore.shape[1] - x.shape[1] + num_extra
|
||||
|
||||
mask_tokens = x.new_empty(
|
||||
x.size(0),
|
||||
num_masked,
|
||||
x.size(-1),
|
||||
).normal_(0, self.modality_cfg.mask_noise_std)
|
||||
|
||||
x_ = torch.cat([x[:, num_extra:], mask_tokens], dim=1)
|
||||
x = torch.gather(x_, dim=1, index=mask_info.ids_restore)
|
||||
|
||||
if self.modality_cfg.decoder.add_positions_masked:
|
||||
assert self.fixed_positional_encoder is not None
|
||||
pos = self.fixed_positional_encoder(x, None)
|
||||
x = x + (pos * mask_info.mask.unsqueeze(-1))
|
||||
else:
|
||||
x = x[:, num_extra:]
|
||||
|
||||
if self.modality_cfg.decoder.add_positions_all:
|
||||
assert self.fixed_positional_encoder is not None
|
||||
x = x + self.fixed_positional_encoder(x, None)
|
||||
|
||||
return x, mask_info
|
||||
|
||||
def local_features(self, features):
|
||||
if self.local_grad_mult > 0:
|
||||
if self.local_grad_mult == 1.0:
|
||||
x = self.local_encoder(features)
|
||||
else:
|
||||
x = GradMultiply.apply(
|
||||
self.local_encoder(features), self.local_grad_mult
|
||||
)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
x = self.local_encoder(features)
|
||||
|
||||
x = self.project_features(x)
|
||||
return x
|
||||
|
||||
def contextualized_features(
|
||||
self,
|
||||
x,
|
||||
padding_mask,
|
||||
mask,
|
||||
remove_masked,
|
||||
clone_batch: int = 1,
|
||||
mask_seeds: Optional[torch.Tensor] = None,
|
||||
precomputed_mask=None,
|
||||
):
|
||||
|
||||
if padding_mask is not None:
|
||||
padding_mask = self.convert_padding_mask(x, padding_mask)
|
||||
|
||||
local_features = x
|
||||
if mask and clone_batch == 1:
|
||||
local_features = local_features.clone()
|
||||
|
||||
orig_B, orig_T, _ = x.shape
|
||||
pre_mask_B = orig_B
|
||||
mask_info = None
|
||||
|
||||
x_pos = None
|
||||
if self.fixed_positional_encoder is not None:
|
||||
x = x + self.fixed_positional_encoder(x, padding_mask)
|
||||
|
||||
if mask:
|
||||
if clone_batch > 1:
|
||||
x = x.repeat_interleave(clone_batch, 0)
|
||||
if mask_seeds is not None:
|
||||
clone_hash = [
|
||||
int(hash((mask_seeds.seed, ind)) % 1e10)
|
||||
for ind in range(clone_batch - 1)
|
||||
]
|
||||
clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)
|
||||
|
||||
id = mask_seeds.ids
|
||||
id = id.repeat_interleave(clone_batch, 0)
|
||||
id = id.view(-1, clone_batch) + clone_hash.to(id)
|
||||
id = id.view(-1)
|
||||
mask_seeds = MaskSeed(
|
||||
seed=mask_seeds.seed, update=mask_seeds.update, ids=id
|
||||
)
|
||||
if padding_mask is not None:
|
||||
padding_mask = padding_mask.repeat_interleave(clone_batch, 0)
|
||||
|
||||
x, mask_info = self.compute_mask(
|
||||
x,
|
||||
padding_mask,
|
||||
mask_seed=mask_seeds,
|
||||
apply=self.relative_positional_encoder is not None or not remove_masked,
|
||||
precomputed_mask=precomputed_mask,
|
||||
)
|
||||
|
||||
if self.relative_positional_encoder is not None:
|
||||
x_pos = self.relative_positional_encoder(x)
|
||||
|
||||
masked_padding_mask = padding_mask
|
||||
if mask and remove_masked:
|
||||
x = mask_info.x_unmasked
|
||||
if x_pos is not None:
|
||||
x = x + gather_unmasked(x_pos, mask_info)
|
||||
|
||||
if padding_mask is not None and padding_mask.any():
|
||||
masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info)
|
||||
if not masked_padding_mask.any():
|
||||
masked_padding_mask = None
|
||||
else:
|
||||
masked_padding_mask = None
|
||||
|
||||
elif x_pos is not None:
|
||||
x = x + x_pos
|
||||
|
||||
alibi_bias = None
|
||||
alibi_scale = self.alibi_scale
|
||||
|
||||
if self.get_alibi_bias is not None:
|
||||
alibi_bias = self.get_alibi_bias(
|
||||
batch_size=pre_mask_B,
|
||||
time_steps=orig_T,
|
||||
heads=self.modality_cfg.num_alibi_heads,
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
if alibi_scale is not None:
|
||||
alibi_scale = alibi_scale.clamp_min(0)
|
||||
if alibi_scale.size(0) == 1:
|
||||
alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
|
||||
alibi_scale = None
|
||||
|
||||
if clone_batch > 1:
|
||||
alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)
|
||||
|
||||
if mask_info is not None and remove_masked:
|
||||
alibi_bias = masked_alibi(alibi_bias, mask_info)
|
||||
|
||||
if self.extra_tokens is not None:
|
||||
num = self.extra_tokens.size(1)
|
||||
x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
|
||||
if masked_padding_mask is not None:
|
||||
# B x T
|
||||
masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
|
||||
if alibi_bias is not None:
|
||||
# B x H x T x T
|
||||
alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))
|
||||
|
||||
x = self.context_encoder(
|
||||
x,
|
||||
masked_padding_mask,
|
||||
alibi_bias,
|
||||
alibi_scale[: self.modality_cfg.prenet_depth]
|
||||
if alibi_scale is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
return {
|
||||
"x": x,
|
||||
"local_features": local_features,
|
||||
"padding_mask": masked_padding_mask,
|
||||
"alibi_bias": alibi_bias,
|
||||
"alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
|
||||
if alibi_scale is not None and alibi_scale.size(0) > 1
|
||||
else alibi_scale,
|
||||
"encoder_mask": mask_info,
|
||||
}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
features,
|
||||
padding_mask,
|
||||
mask: bool,
|
||||
remove_masked: bool,
|
||||
clone_batch: int = 1,
|
||||
mask_seeds: Optional[torch.Tensor] = None,
|
||||
precomputed_mask=None,
|
||||
):
|
||||
x = self.local_features(features)
|
||||
return self.contextualized_features(
|
||||
x,
|
||||
padding_mask,
|
||||
mask,
|
||||
remove_masked,
|
||||
clone_batch,
|
||||
mask_seeds,
|
||||
precomputed_mask,
|
||||
)
|
||||
|
||||
def reset_parameters(self):
|
||||
pass
|
||||
|
||||
def compute_mask(
|
||||
self,
|
||||
x,
|
||||
padding_mask,
|
||||
mask_seed: Optional[MaskSeed],
|
||||
apply,
|
||||
precomputed_mask,
|
||||
):
|
||||
if precomputed_mask is not None:
|
||||
mask = precomputed_mask
|
||||
mask_info = self.make_maskinfo(x, mask)
|
||||
else:
|
||||
B, T, C = x.shape
|
||||
cfg = self.modality_cfg
|
||||
|
||||
mask_prob = cfg.mask_prob
|
||||
|
||||
if (
|
||||
cfg.mask_prob_min is not None
|
||||
and cfg.mask_prob_min >= 0
|
||||
and cfg.mask_prob_min < mask_prob
|
||||
):
|
||||
mask_prob = np.random.uniform(cfg.mask_prob_min, mask_prob)
|
||||
|
||||
if mask_prob > 0:
|
||||
if cfg.mask_length == 1:
|
||||
mask_info = random_masking(x, mask_prob, mask_seed)
|
||||
else:
|
||||
if self.modality_cfg.inverse_mask:
|
||||
mask_prob = 1 - mask_prob
|
||||
|
||||
mask = compute_mask_indices(
|
||||
(B, T),
|
||||
padding_mask,
|
||||
mask_prob,
|
||||
cfg.mask_length,
|
||||
min_masks=1,
|
||||
require_same_masks=True,
|
||||
mask_dropout=cfg.mask_dropout,
|
||||
add_masks=cfg.add_masks,
|
||||
seed=mask_seed.seed if mask_seed is not None else None,
|
||||
epoch=mask_seed.update if mask_seed is not None else None,
|
||||
indices=mask_seed.ids if mask_seed is not None else None,
|
||||
)
|
||||
|
||||
mask = torch.from_numpy(mask).to(device=x.device)
|
||||
if self.modality_cfg.inverse_mask:
|
||||
mask = 1 - mask
|
||||
mask_info = self.make_maskinfo(x, mask)
|
||||
else:
|
||||
mask_info = None
|
||||
|
||||
if apply:
|
||||
x = self.apply_mask(x, mask_info)
|
||||
|
||||
return x, mask_info
|
||||
|
||||
def make_maskinfo(self, x, mask, shape=None):
|
||||
if shape is None:
|
||||
B, T, D = x.shape
|
||||
else:
|
||||
B, T, D = shape
|
||||
|
||||
mask = mask.to(torch.uint8)
|
||||
ids_shuffle = mask.argsort(dim=1)
|
||||
ids_restore = ids_shuffle.argsort(dim=1).unsqueeze(-1).expand(-1, -1, D)
|
||||
|
||||
len_keep = T - mask[0].sum()
|
||||
if self.modality_cfg.keep_masked_pct > 0:
|
||||
len_keep += round((T - int(len_keep)) * self.modality_cfg.keep_masked_pct)
|
||||
|
||||
ids_keep = ids_shuffle[:, :len_keep]
|
||||
|
||||
if shape is not None:
|
||||
x_unmasked = None
|
||||
else:
|
||||
ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
|
||||
x_unmasked = torch.gather(x, dim=1, index=ids_keep)
|
||||
|
||||
mask_info = MaskInfo(
|
||||
x_unmasked=x_unmasked,
|
||||
mask=mask,
|
||||
ids_restore=ids_restore,
|
||||
ids_keep=ids_keep,
|
||||
)
|
||||
return mask_info
|
||||
|
||||
def apply_mask(self, x, mask_info):
|
||||
cfg = self.modality_cfg
|
||||
B, T, C = x.shape
|
||||
|
||||
if mask_info is not None:
|
||||
mask = mask_info.mask
|
||||
if cfg.encoder_zero_mask:
|
||||
x = x * (1 - mask.type_as(x).unsqueeze(-1))
|
||||
else:
|
||||
num_masks = mask.sum().item()
|
||||
masks = x.new_empty(num_masks, x.size(-1)).normal_(
|
||||
0, cfg.mask_noise_std
|
||||
)
|
||||
x = index_put(x, mask, masks)
|
||||
if cfg.mask_channel_prob > 0:
|
||||
mask_channel = compute_mask_indices(
|
||||
(B, C),
|
||||
None,
|
||||
cfg.mask_channel_prob,
|
||||
cfg.mask_channel_length,
|
||||
)
|
||||
mask_channel = (
|
||||
torch.from_numpy(mask_channel)
|
||||
.to(x.device)
|
||||
.unsqueeze(1)
|
||||
.expand(-1, T, -1)
|
||||
)
|
||||
x = index_put(x, mask_channel, 0)
|
||||
return x
|
||||
|
||||
def remove_pretraining_modules(self, keep_decoder=False):
|
||||
if not keep_decoder:
|
||||
self.decoder = None
|
||||
|
||||
|
||||
def get_annealed_rate(start, end, curr_step, total_steps):
|
||||
if curr_step >= total_steps:
|
||||
return end
|
||||
r = end - start
|
||||
pct_remaining = 1 - curr_step / total_steps
|
||||
return end - r * pct_remaining
|
||||
|
||||
|
||||
# adapted from MAE
|
||||
def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]):
|
||||
N, L, D = x.shape # batch, length, dim
|
||||
len_keep = int(L * (1 - mask_ratio))
|
||||
|
||||
generator = None
|
||||
if mask_seed is not None:
|
||||
seed = int(
|
||||
hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6
|
||||
)
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
noise = torch.rand(N, L, generator=generator, device=x.device) # noise in [0, 1]
|
||||
|
||||
# sort noise for each sample
|
||||
ids_shuffle = noise.argsort(dim=1) # ascend: small is keep, large is remove
|
||||
ids_restore = ids_shuffle.argsort(dim=1)
|
||||
|
||||
# keep the first subset
|
||||
ids_keep = ids_shuffle[:, :len_keep]
|
||||
ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
|
||||
x_unmasked = torch.gather(x, dim=1, index=ids_keep)
|
||||
|
||||
# generate the binary mask: 0 is keep, 1 is remove
|
||||
mask = torch.ones([N, L], dtype=x.dtype, device=x.device)
|
||||
mask[:, :len_keep] = 0
|
||||
# unshuffle to get the binary mask
|
||||
mask = torch.gather(mask, dim=1, index=ids_restore)
|
||||
|
||||
ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D)
|
||||
|
||||
return MaskInfo(
|
||||
x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep
|
||||
)
|
||||
|
||||
|
||||
def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
|
||||
return torch.gather(
|
||||
x,
|
||||
dim=1,
|
||||
index=mask_info.ids_keep,
|
||||
)
|
||||
|
||||
|
||||
def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
|
||||
return torch.gather(
|
||||
x,
|
||||
dim=1,
|
||||
index=mask_info.ids_keep[..., 0], # ignore the feature dimension
|
||||
)
|
||||
|
||||
|
||||
def get_alibi(
|
||||
max_positions: int,
|
||||
attention_heads: int,
|
||||
dims: int = 1,
|
||||
distance: str = "manhattan",
|
||||
):
|
||||
def get_slopes(n):
|
||||
def get_slopes_power_of_2(n):
|
||||
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
|
||||
ratio = start
|
||||
return [start * ratio**i for i in range(n)]
|
||||
|
||||
# In the paper, we only train models that have 2^a heads for some
|
||||
# a. This function has some good properties that only occur when
|
||||
# the input is a power of 2. To maintain that even when the number
|
||||
# of heads is not a power of 2, we use this workaround.
|
||||
if math.log2(n).is_integer():
|
||||
return get_slopes_power_of_2(n)
|
||||
else:
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(n))
|
||||
return (
|
||||
get_slopes_power_of_2(closest_power_of_2)
|
||||
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
|
||||
)
|
||||
|
||||
maxpos = max_positions
|
||||
attn_heads = attention_heads
|
||||
slopes = torch.Tensor(get_slopes(attn_heads))
|
||||
|
||||
if dims == 1:
|
||||
# prepare alibi position linear bias. Note that wav2vec2 is non
|
||||
# autoregressive model so we want a symmetric mask with 0 on the
|
||||
# diagonal and other wise linear decreasing valuees
|
||||
pos_bias = (
|
||||
torch.abs(
|
||||
torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
|
||||
)
|
||||
* -1
|
||||
)
|
||||
elif dims == 2:
|
||||
if distance == "manhattan":
|
||||
df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
|
||||
elif distance == "euclidean":
|
||||
df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
|
||||
|
||||
n = math.sqrt(max_positions)
|
||||
assert n.is_integer(), n
|
||||
n = int(n)
|
||||
|
||||
pos_bias = torch.zeros((max_positions, max_positions))
|
||||
|
||||
for i in range(n):
|
||||
for j in range(n):
|
||||
for k in range(n):
|
||||
for l in range(n):
|
||||
new_x = i * n + j
|
||||
new_y = k * n + l
|
||||
pos_bias[new_x, new_y] = -df(i, j, k, l)
|
||||
|
||||
else:
|
||||
raise Exception(f"unsupported number of alibi dims: {dims}")
|
||||
|
||||
alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
|
||||
attn_heads, -1, -1
|
||||
)
|
||||
|
||||
return alibi_bias
|
||||
|
||||
|
||||
def get_alibi_bias(
|
||||
alibi_biases,
|
||||
batch_size,
|
||||
time_steps,
|
||||
heads,
|
||||
dtype,
|
||||
device,
|
||||
dims=1,
|
||||
distance="manhattan",
|
||||
):
|
||||
cache_key = f"{dims}_{heads}_{distance}"
|
||||
|
||||
buffered = alibi_biases.get(cache_key, None)
|
||||
|
||||
target_size = heads * batch_size
|
||||
if (
|
||||
buffered is None
|
||||
or buffered.size(0) < target_size
|
||||
or buffered.size(1) < time_steps
|
||||
or buffered.dtype != dtype
|
||||
or buffered.device != device
|
||||
):
|
||||
bt = max(time_steps, buffered.size(1) if buffered is not None else 0)
|
||||
bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads
|
||||
|
||||
buffered = (
|
||||
get_alibi(bt, heads, dims=dims, distance=distance)
|
||||
.to(dtype=dtype, device=device)
|
||||
.repeat(bn, 1, 1)
|
||||
)
|
||||
|
||||
alibi_biases[cache_key] = buffered
|
||||
|
||||
b = buffered[:target_size, :time_steps, :time_steps]
|
||||
b = b.view(batch_size, heads, time_steps, time_steps)
|
||||
return b
|
||||
|
||||
|
||||
def _learned_alibi_bias(
|
||||
alibi_bias,
|
||||
batch_size,
|
||||
time_steps,
|
||||
heads,
|
||||
scale,
|
||||
dtype,
|
||||
device,
|
||||
):
|
||||
assert alibi_bias.size(1) == heads, alibi_bias.shape
|
||||
assert alibi_bias.dtype == dtype, alibi_bias.dtype
|
||||
assert alibi_bias.device == device, alibi_bias.device
|
||||
|
||||
if alibi_bias.size(-1) < time_steps:
|
||||
psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
|
||||
alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")
|
||||
|
||||
alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
|
||||
return alibi_bias[..., :time_steps, :time_steps]
|
||||
|
||||
|
||||
def masked_alibi(alibi_bias, mask_info):
|
||||
H = alibi_bias.size(1)
|
||||
|
||||
orig_bias = alibi_bias
|
||||
|
||||
index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)
|
||||
alibi_bias = torch.gather(
|
||||
orig_bias,
|
||||
dim=-2,
|
||||
index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)),
|
||||
)
|
||||
alibi_bias = torch.gather(
|
||||
alibi_bias,
|
||||
dim=-1,
|
||||
index=index.transpose(-1, -2).expand(-1, H, alibi_bias.size(-2), -1),
|
||||
)
|
||||
|
||||
return alibi_bias
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user