# Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling (Gong et al., 2021)

https://arxiv.org/pdf/2106.10840.pdf

## Introduction

We present attention head selection strategies for multilingual and multi-domain sequence modeling, covering text translation, speech recognition and speech translation tasks.
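
At a high level, the method places a learned gate over the heads of each multi-head attention layer: every task (a language or a domain) learns how useful each head is to it, and a Gumbel-style relaxation keeps the selection differentiable during training. The sketch below is our own minimal illustration of that gating step; the names and details are hypothetical, not fairseq's implementation (see `examples/attention_head_selection/src` for the real one).

```python
# Illustrative sketch of per-task attention head gating (hypothetical names,
# not the fairseq implementation). Each task learns one logit per head; a
# Gumbel-sigmoid relaxation yields a near-binary mask that scales each head's
# output, so head selection stays differentiable during training.
import torch
import torch.nn as nn


class TaskHeadGate(nn.Module):
    def __init__(self, num_tasks: int, num_heads: int, tau: float = 1.0):
        super().__init__()
        # One logit per (task, head) pair: how useful head h is for task t.
        self.head_logits = nn.Parameter(torch.zeros(num_tasks, num_heads))
        self.tau = tau

    def forward(self, task_id: torch.Tensor) -> torch.Tensor:
        logits = self.head_logits[task_id]  # (batch, num_heads)
        if self.training:
            # Gumbel-sigmoid: logistic noise plus a temperature-scaled
            # sigmoid gives a soft, differentiable 0/1-like mask.
            u = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)
            noise = torch.log(u) - torch.log(1 - u)
            return torch.sigmoid((logits + noise) / self.tau)
        # At inference, threshold to a hard selection mask.
        return (logits > 0).float()


# Scale each head's output by its gate before the output projection.
gate = TaskHeadGate(num_tasks=8, num_heads=8)
attn = torch.randn(4, 8, 16, 64)      # (batch, heads, tgt_len, head_dim)
task_id = torch.tensor([0, 1, 2, 3])  # one task index per example
mask = gate(task_id)                  # (batch, heads)
out = attn * mask[:, :, None, None]   # unselected heads are zeroed out
```

At inference the mask is thresholded to hard 0/1 values, so heads a task never selected contribute nothing to that task's output.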

Below are examples of training multilingual and multi-domain speech recognition models.

## Data Preparation

Prepare mTEDx data as in the mTEDx example and CoVoST data as in the CoVoST example. Prepare EuroParl data in the same way.
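
The commands below assume the standard fairseq `speech_to_text` layout inside each data directory: a `config_asr.yaml` plus one TSV manifest per split, with manifest names matching the subset lists (e.g. `train_ar_ar_tedx.tsv`). A small sanity check under that assumption:

```python
# Our own helper (not a fairseq script): verify that data_dir contains the
# config and the per-split TSV manifests the training commands expect.
from pathlib import Path

data_dir = Path("<path to mtedx data>")
splits = "train_ar_ar_tedx,valid_ar_ar_tedx,test_ar_ar_tedx".split(",")

expected = ["config_asr.yaml", *[f"{s}.tsv" for s in splits]]
missing = [name for name in expected if not (data_dir / name).exists()]
print("missing files:", missing or "none")
```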

## Training a multilingual ASR model with attention head selection

```bash
data_dir=<path to mtedx data>
train_subset="train_ar_ar_tedx,train_de_de_tedx,train_el_el_tedx,train_es_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_pt_tedx,train_ru_ru_tedx"
valid_subset="valid_ar_ar_tedx,valid_de_de_tedx,valid_el_el_tedx,valid_es_es_tedx,valid_fr_fr_tedx,valid_it_it_tedx,valid_pt_pt_tedx,valid_ru_ru_tedx"
strategy=<subset or group>

fairseq-train ${data_dir} \
    --user-dir examples/attention_head_selection/src \
    --train-subset "${train_subset}" \
    --valid-subset "${valid_subset}" \
    --config-yaml 'config_asr.yaml' \
    --arch 'head_selection_s2t_transformer_s' \
    --task 'speech_to_text_head_selection' \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --lr-scheduler 'inverse_sqrt' --stop-min-lr -1.0 --warmup-updates 10000 \
    --lr 5e-4 \
    --clip-norm 10.0 \
    --seed 1 \
    --max-epoch 400 \
    --max-tokens 32000 \
    --ignore-prefix-size 1 \
    --dropout 0.3 \
    --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
    --skip-invalid-size-inputs-valid-test \
    --encoder-attn-head-select \
    --total-encoder-attention-heads 8 \
    --decoder-self-attn-head-select \
    --total-decoder-attention-heads 8 \
    --attn-head-select-strategy ${strategy} \
    --task-type lang
```
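
The last six flags are what set this run apart from a vanilla `s2t_transformer_s` one: they enable selection over the 8 encoder heads and the 8 decoder self-attention heads, choose the selection strategy, and make each language a separate task (`--task-type lang`). Our reading of the paper is that `subset` lets every task select its own head subset while `group` shares a selection across groups of tasks; treat that gloss as an assumption and consult the paper for the exact definitions.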

## Training a multi-domain ASR model with attention head selection

```bash
data_dir=<path to multi-domain data>
train_subset="train_es_es_tedx,train_fr_fr_tedx,train_pt_pt_tedx,train_it_it_tedx,train_ru_ru_tedx,train_el_el_tedx,train_ar_ar_tedx,train_de_de_tedx,train_ar_ar_cv,train_de_de_cv,train_es_es_cv,train_fr_fr_cv,train_it_it_cv,train_pt_pt_cv,train_ru_ru_cv,train_de_de_ep,train_es_es_ep,train_fr_fr_ep,train_it_it_ep,train_pt_pt_ep"
valid_subset="dev_es_es_tedx,dev_fr_fr_tedx,dev_pt_pt_tedx,dev_it_it_tedx,dev_ru_ru_tedx,dev_el_el_tedx,dev_ar_ar_tedx,dev_de_de_tedx,dev_ar_ar_cv,dev_de_de_cv,dev_es_es_cv,dev_fr_fr_cv,dev_it_it_cv,dev_pt_pt_cv,dev_ru_ru_cv,dev_de_de_ep,dev_es_es_ep,dev_fr_fr_ep,dev_it_it_ep,dev_pt_pt_ep"
strategy=<subset or group>

fairseq-train ${data_dir} \
    --user-dir examples/attention_head_selection/src \
    --train-subset "${train_subset}" \
    --valid-subset "${valid_subset}" \
    --config-yaml 'config_asr.yaml' \
    --arch head_selection_s2t_transformer_s \
    --task speech_to_text_head_selection \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --lr-scheduler 'inverse_sqrt' --stop-min-lr -1.0 --warmup-updates 10000 \
    --lr 5e-4 \
    --clip-norm 10.0 \
    --seed 1 \
    --max-epoch 400 \
    --max-tokens 32000 \
    --ignore-prefix-size 1 \
    --dropout 0.3 \
    --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
    --skip-invalid-size-inputs-valid-test \
    --encoder-attn-head-select \
    --total-encoder-attention-heads 8 \
    --decoder-self-attn-head-select \
    --total-decoder-attention-heads 8 \
    --attn-head-select-strategy ${strategy} \
    --task-type domain
```
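
This command is identical to the multilingual one except for the subset lists, which now span all three corpora (the `_tedx`, `_cv` and `_ep` suffixes mark mTEDx, CoVoST and EuroParl), and `--task-type domain`, so that the corpus, rather than the language, defines the task for head selection.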

## Inference in multilingual setting

```bash
MODEL_DIR=<checkpoint directory>
data_dir=<path to mtedx data>
gen_subset=<data to test, e.g., test_ar_ar_tedx>
train_subset="train_ar_ar_tedx,train_de_de_tedx,train_el_el_tedx,train_es_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_pt_tedx,train_ru_ru_tedx"
last_n=10
CHECKPOINT_FILENAME="avg_last_${last_n}_checkpoint.pt"
CHECKPOINT="_avg"
RESULTS="${MODEL_DIR}/ckpt${CHECKPOINT}"
if [ ! -d $RESULTS ]; then
    mkdir -p $RESULTS
fi
```

```bash
python scripts/average_checkpoints.py \
  --inputs ${MODEL_DIR} --num-epoch-checkpoints ${last_n} \
  --output "${MODEL_DIR}/${CHECKPOINT_FILENAME}"
```
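
Checkpoint averaging is the usual fairseq recipe for speech-to-text models: `--num-epoch-checkpoints ${last_n}` averages the parameters of the last `last_n` epoch checkpoints in `MODEL_DIR` into a single model file, which typically generates better than any individual checkpoint.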

```bash
fairseq-generate ${data_dir} \
    --user-dir examples/attention_head_selection/src \
    --arch 'head_selection_s2t_transformer_s' \
    --task 'speech_to_text_head_selection' \
    --train-subset ${train_subset} \
    --gen-subset ${gen_subset} \
    --path "${MODEL_DIR}/${CHECKPOINT_FILENAME}" \
    --config-yaml 'config_asr.yaml' \
    --prefix-size 1 \
    --max-tokens 40000 --beam 5 \
    --skip-invalid-size-inputs-valid-test \
    --results-path ${RESULTS} \
    --scoring wer --wer-tokenizer 13a \
    --wer-lowercase --wer-remove-punct --remove-bpe
```
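
With `--results-path`, fairseq-generate writes its output to a file in that directory rather than to stdout, with the WER summary near the end. A small helper to pull it out, with the file name (`generate-<gen_subset>.txt`, fairseq's usual convention) as our assumption:

```python
# Our own helper: print the WER summary line(s) from the generate output.
# Assumes fairseq-generate named the file generate-<gen_subset>.txt under
# the --results-path directory; adjust if your fairseq version differs.
from pathlib import Path

results_dir = Path("<checkpoint directory>") / "ckpt_avg"  # ${RESULTS} above
gen_subset = "test_ar_ar_tedx"

for line in (results_dir / f"generate-{gen_subset}.txt").read_text().splitlines():
    if "WER" in line:
        print(line)
```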

## Inference in multi-domain setting

```bash
MODEL_DIR=<checkpoint directory>
data_dir=<path to multi-domain data>
gen_subset=<data to test, e.g., test_pt_pt_cv>
train_subset="train_es_es_tedx,train_fr_fr_tedx,train_pt_pt_tedx,train_it_it_tedx,train_ru_ru_tedx,train_el_el_tedx,train_ar_ar_tedx,train_de_de_tedx,train_ar_ar_cv,train_de_de_cv,train_es_es_cv,train_fr_fr_cv,train_it_it_cv,train_pt_pt_cv,train_ru_ru_cv,train_de_de_ep,train_es_es_ep,train_fr_fr_ep,train_it_it_ep,train_pt_pt_ep"
last_n=10
CHECKPOINT_FILENAME="avg_last_${last_n}_checkpoint.pt"
CHECKPOINT="_avg"
RESULTS="${MODEL_DIR}/ckpt${CHECKPOINT}"
if [ ! -d $RESULTS ]; then
    mkdir -p $RESULTS
fi

python scripts/average_checkpoints.py \
  --inputs ${MODEL_DIR} --num-epoch-checkpoints ${last_n} \
  --output "${MODEL_DIR}/${CHECKPOINT_FILENAME}"

fairseq-generate ${data_dir} \
    --user-dir examples/attention_head_selection/src \
    --arch 'head_selection_s2t_transformer_s' \
    --task 'speech_to_text_head_selection' \
    --train-subset ${train_subset} \
    --gen-subset ${gen_subset} \
    --path "${MODEL_DIR}/${CHECKPOINT_FILENAME}" \
    --config-yaml 'config_asr.yaml' \
    --prefix-size 1 \
    --max-tokens 40000 --beam 5 \
    --skip-invalid-size-inputs-valid-test \
    --results-path ${RESULTS} \
    --scoring wer --wer-tokenizer 13a \
    --wer-lowercase --wer-remove-punct --remove-bpe
```

## Citation

```bibtex
@article{gong2021pay,
  title={Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling},
  author={Gong, Hongyu and Tang, Yun and Pino, Juan and Li, Xian},
  journal={arXiv preprint arXiv:2106.10840},
  year={2021}
}
```