Merge branch 'pytorch:main' into recipe

Jiatong 2022-04-18 14:40:40 -04:00 committed by GitHub
commit 231a6cd32e
11 changed files with 80 additions and 85 deletions

.github/CODEOWNERS (vendored, new file)
View File

@@ -0,0 +1,18 @@
# Setting up CODEOWNERS for UST related codebase
# Documentation for open sourced models relevant to UST
examples/speech_to_text @kahne @sravyapopuri388 @jmp84
examples/speech_to_speech @an918tw @sravyapopuri388 @jmp84
examples/speech_synthesis @kahne @jmp84
examples/simultaneous_translation @kahne @jmp84
examples/speech_text_joint_to_text @yuntang @jmp84
# Speech related models relevant to UST
fairseq/models/speech_to_speech @sravyapopuri388 @jmp84
fairseq/models/speech_to_text @kahne @sravyapopuri388 @jmp84
fairseq/models/text_to_speech @kahne @jmp84
# CONFORMER IMPLEMENTATION
fairseq/modules/conformer_layer.py @sravyapopuri388 @jmp84
fairseq/modules/espnet_multihead_attention.py @sravyapopuri388 @jmp84
fairseq/modules/rotary_positional_embedding.py @sravyapopuri388 @jmp84
fairseq/modules/positional_encoding.py @sravyapopuri388 @jmp84

View File

@@ -1,5 +1,7 @@
# Speech to speech translation (S2ST)
We provide the implementation and resources for the following work on speech-to-speech translation (S2ST):
* [Direct speech-to-speech translation with discrete units (Lee et al. 2021)](docs/direct_s2st_discrete_units.md)
* [Textless Speech-to-Speech Translation on Real Data (Lee et al. 2021)](docs/textless_s2st_real_data.md)
* [Direct speech-to-speech translation with discrete units (Lee et al. 2021)](docs/direct_s2st_discrete_units.md)
* [Textless Speech-to-Speech Translation on Real Data (Lee et al. 2021)](docs/textless_s2st_real_data.md)
* [Enhanced Direct Speech-to-Speech Translation Using Self-supervised Pre-training and Data Augmentation](docs/enhanced_direct_s2st_discrete_units.md)

View File

@@ -1,31 +1,34 @@
# Speech to speech translation (S2ST)
We provide the implementation for speech-to-unit translation (S2UT) proposed in [_Enhanced Direct Speech-to-Speech Translation Using Self-supervised Pre-training and Data Augmentation_](arxiv link) and the various pretrained models used.
We provide the implementation for speech-to-unit translation (S2UT) proposed in [Enhanced Direct Speech-to-Speech Translation Using Self-supervised Pre-training and Data Augmentation (Popuri et al. 2022)](https://arxiv.org/abs/2204.02967) and the various pretrained models used.
## Pretrained Models
### Unit extraction
### Unit extraction
We used the multilingual HuBERT model open-sourced in [Textless S2ST with Real Data](textless_s2st_real_data.md).
### Wav2vec 2.0
Language | Block type | Model size | Dataset | Model |
--- | --- | --- | --- | --- |
Es | Transformer | BASE | Voxpopuli | [ckpt ](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/transformer_B.pt) |
Es | Transformer | LARGE | Voxpopuli | [ckpt ](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/transformer_L.pt) |
Es | Conformer | LARGE | Voxpopuli | [ckpt ](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/conformer_L.pt) |
En | Transformer | BASE | Librilight| [ckpt ](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/en/transformer_B.pt) |
En | Conformer | LARGE | Librilight | [ckpt ](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/en/conformer_L.pt) |
Es | Transformer | BASE | Voxpopuli | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/transformer_B.pt) |
Es | Transformer | LARGE | Voxpopuli | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/transformer_L.pt) |
Es | Conformer | LARGE | Voxpopuli | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/es/conformer_L.pt) |
En | Transformer | BASE | Librilight| [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/en/transformer_B.pt) |
En | Conformer | LARGE | Librilight | [ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/w2v2/en/conformer_L.pt) |
### Unit mBART
Unit size | Dataset | Unit config | Model |
--- | --- | --- | --- |
1000 | [Voxpopuli](https://aclanthology.org/2021.acl-long.80) En, Es unlabelled speech | [mbart_large](https://github.com/pytorch/fairseq/blob/f591cc94caa85098ccf125a4782f91125b6a086d/fairseq/models/bart/model.py#L368) |[ckpt](htpps://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/unit_mBART/checkpoint.pt) |
--- | --- | --- | --- |
1000 | [Voxpopuli](https://aclanthology.org/2021.acl-long.80) En, Es unlabelled speech | [mbart_large](https://github.com/pytorch/fairseq/blob/f591cc94caa85098ccf125a4782f91125b6a086d/fairseq/models/bart/model.py#L368) |[ckpt](https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/s2st_finetuning/unit_mBART/checkpoint.pt) |
## Data preparation
## Data preparation
1. To prepare data for S2UT finetuning, follow the steps in [Direct S2ST with Discrete Units](./direct_s2st_discrete_units.md) and format the data in the _S2UT_ format. Note that we instead use 1000 units extracted from the eleventh layer (`--layer 11`) of the multilingual HuBERT model linked above.
2. Run
```
var="id\taudio\tn_frames\ttgt_text\ttgt_n_frames"
sed -i "1s/.*/$var/" ${SPLIT}.tsv
@@ -36,6 +39,7 @@ sed -i "1s/.*/$var/" ${SPLIT}.tsv
**Speech-to-unit translation (S2UT)**
Here's an example of finetuning S2UT models with 1000 discrete units as the target. You can download the config file and vocabulary from here [add links]:
```
fairseq-train $DATA_ROOT \
--config-yaml config.yaml \
@@ -54,14 +58,14 @@ fairseq-train $DATA_ROOT \
--max-target-positions 4000 --update-freq 120 \
--seed 1 --fp16 --num-workers 1
```
* Adjust `--update-freq` according to the number of GPUs used. In the above we set `--update-freq 15` to simulate training with 120 GPUs.
* In the above setting we finetune the model end to end, corresponding to the full setup in the paper.
* To apply LNA-E partial finetuning, add `--finetune-w2v-params layer_norm,self_attn`.
* For LNA-D partial finetuning, add `--finetune-decoder-params encoder_attn,layer_norm,self_attn`. To optionally freeze the encoder for the first ${K} updates, use `--freeze-finetune-updates ${K}`.
* For LNA-E,D partial finetuning, add both of the above options.
**Unit-based HiFi-GAN vocoder**
**Unit-based HiFi-GAN vocoder**
We apply the open-sourced unit-based HiFi-GAN vocoders to convert the predicted unit sequences to waveform. They are open sourced in [Textless S2ST with Real Data](textless_s2st_real_data.md)
@@ -70,6 +74,7 @@ We apply the open-sourced unit-based HiFi-GAN vocoders to convert the predicted
**Speech-to-unit translation (S2UT)**
1. Follow the same inference process as in [fairseq-S2T](https://github.com/pytorch/fairseq/tree/main/examples/speech_to_text) to generate unit sequences (`${RESULTS_PATH}/generate-${GEN_SUBSET}.txt`).
```
fairseq-generate $DATA_ROOT \
--config-yaml config.yaml \
@@ -81,6 +86,7 @@ fairseq-generate $DATA_ROOT \
```
2. Convert unit sequences to waveform.
```
grep "^D\-" ${RESULTS_PATH}/generate-${GEN_SUBSET}.txt | \
sed 's/^D-//ig' | sort -nk1 | cut -f3 \
@@ -92,12 +98,11 @@ python examples/speech_to_speech/generate_waveform_from_code.py \
--results-path ${RESULTS_PATH} --dur-prediction
```
## Evaluation
To evaluate speech translation output, we first apply ASR to the speech output and then compute the BLEU score between the ASR-decoded text and the references using sacreBLEU.
* Text normalization: We use the text cleaner at [https://github.com/keithito/tacotron](https://github.com/keithito/tacotron) for pre-processing reference English text for ASR BLEU evaluation. The text cleaner used for Spanish text normalization will be updated here shortly.
* En ASR: We use the "[Wav2Vec 2.0 Large (LV-60) + Self Training / 960 hours / Libri-Light + Librispeech](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt)" En ASR model open-sourced by the [wav2vec](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec) project. The model is also available on [Hugging Face](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).
* Es ASR: We use the [Wav2Vec2-Large-XLSR-53-Spanish](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) finetuned on spanish Common Voice Es ASR model open-sourced by Jonatasgrosman(https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-spanish) on [Hugging Face](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-spanish).
* Es ASR: We use [Wav2Vec2-Large-XLSR-53-Spanish](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-spanish), a [wav2vec 2.0 XLSR-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) model finetuned on Spanish Common Voice and open-sourced by jonatasgrosman on [Hugging Face](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-spanish).
* See [instructions](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec#evaluating-a-ctc-model) on how to run inference with a wav2vec-based ASR model.
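
Once ASR transcripts are available, the BLEU computation itself is only a few lines. The sketch below is illustrative and not part of this repository: it transcribes generated English waveforms with the Hugging Face checkpoint of the En ASR model mentioned above and scores them against pre-cleaned references with sacreBLEU. The file lists are hypothetical placeholders.

```python
# Illustrative ASR-BLEU sketch (not fairseq code). Assumes 16 kHz mono wav files
# produced by the vocoder and references already normalized with the tacotron cleaner.
import sacrebleu
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

model_name = "facebook/wav2vec2-large-960h-lv60-self"  # En ASR model referenced above
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).eval()

def transcribe(wav_path: str) -> str:
    """Greedy CTC decoding of a single 16 kHz waveform."""
    speech, sr = sf.read(wav_path)
    assert sr == 16_000, "the ASR model expects 16 kHz audio"
    inputs = processor(speech, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        logits = model(inputs.input_values).logits
    ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(ids)[0].lower()

wav_files = ["pred_0.wav", "pred_1.wav"]          # hypothetical vocoder outputs
references = ["reference one", "reference two"]   # hypothetical cleaned references

hypotheses = [transcribe(p) for p in wav_files]
print(sacrebleu.corpus_bleu(hypotheses, [references]).score)
```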

View File

@@ -23,15 +23,15 @@ Wav2Vec 2.0 Large | 10 minutes | [Librispeech](http://www.openslr.org/12) | [do
Wav2Vec 2.0 Large | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_100h.pt)
Wav2Vec 2.0 Large | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_960h.pt)
Wav2Vec 2.0 Large (LV-60)* | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_new.pt)
Wav2Vec 2.0 Large conformer - rel_pos (LV-60)* | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_relpos_PT_no_FT )
Wav2Vec 2.0 Large conformer - rope (LV-60)* | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_relpos_PT_no_FT )
Wav2Vec 2.0 Large conformer - rel_pos (LV-60)* | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_relpos_PT_no_FT)
Wav2Vec 2.0 Large conformer - rope (LV-60)* | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_rope_PT_no_FT)
Wav2Vec 2.0 Large (LV-60)* | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_new.pt)
Wav2Vec 2.0 Large (LV-60)* | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_new.pt)
Wav2Vec 2.0 Large conformer - rel_pos (LV-60)* | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_relpos_PT_100h_FT.pt )
Wav2Vec 2.0 Large conformer - rope (LV-60)* | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_rope_PT_100h_FT.pt )
Wav2Vec 2.0 Large conformer - rel_pos (LV-60)* | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_relpos_PT_100h_FT.pt)
Wav2Vec 2.0 Large conformer - rope (LV-60)* | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_rope_PT_100h_FT.pt)
Wav2Vec 2.0 Large (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt)
Wav2Vec 2.0 Large conformer - rel_pos (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_relpos_PT_960h_FT.pt )
Wav2Vec 2.0 Large conformer - rope (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_rope_PT_960h_FT.pt )
Wav2Vec 2.0 Large conformer - rel_pos (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_relpos_PT_960h_FT.pt)
Wav2Vec 2.0 Large conformer - rope (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_rope_PT_960h_FT.pt)
Wav2Vec 2.0 Large (LV-60) + Self Training * | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_pl.pt)
Wav2Vec 2.0 Large (LV-60) + Self Training * | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_pl.pt)
Wav2Vec 2.0 Large (LV-60) + Self Training * | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt)

View File

@@ -9,13 +9,13 @@ XGLM models are trained on a new multilingual corpus extracted from CommonCrawl
## Pre-trained models
Model | Layers | Model Dim | Languages | Download
---|---|---|---|---
`XGLM 564M` | 24 | 1024 | trained on 30 languages| [xglm.564M.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.564M.tar.gz)
`XGLM 1.7B` | 24 | 2048 | trained on 30 languages| [xglm.1.7B.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.1.7B.tar.gz)
`XGLM 2.9B` | 48 | 2048 | trained on 30 languages| [xglm.2.9B.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.2.9B.tar.gz)
`XGLM 7.5B` | 32 | 4096 | trained on 30 languages| [xglm.7.5B.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.7.5B.tar.gz)
`XGLM 4.5B` | 48 | 2048 | trained on 134 languages| [xglm.4.5B.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.4.5B.tar.gz)
Model | Layers | Model Dim | FFN Dim | Languages | Download
---|---|---|---|---|---
`XGLM 564M` | 24 | 1024 | 4096 | trained on 30 languages| [xglm.564M.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.564M.tar.gz)
`XGLM 1.7B` | 24 | 2048 | 8192 | trained on 30 languages| [xglm.1.7B.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.1.7B.tar.gz)
`XGLM 2.9B` | 48 | 2048 | 8192 | trained on 30 languages| [xglm.2.9B.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.2.9B.tar.gz)
`XGLM 7.5B` | 32 | 4096 | 16384 | trained on 30 languages| [xglm.7.5B.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.7.5B.tar.gz)
`XGLM 4.5B` | 48 | 2048 | 16384 | trained on 134 languages| [xglm.4.5B.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xglm/xglm.4.5B.tar.gz)
## Pre-training Data Format
Our models were pre-trained on data in the following format (i.e., paragraphs are separated by new lines and documents by double new lines).
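
As a concrete illustration of that layout (an assumption based on the description above, not an excerpt from the training corpus), the snippet below parses such a file back into documents and paragraphs:

```python
# Illustrative only: blank lines separate documents, single newlines separate
# paragraphs within a document.
raw = (
    "First paragraph of document one.\n"
    "Second paragraph of document one.\n"
    "\n"
    "Only paragraph of document two.\n"
)
documents = [block.split("\n") for block in raw.strip().split("\n\n")]
assert documents == [
    ["First paragraph of document one.", "Second paragraph of document one."],
    ["Only paragraph of document two."],
]
```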

View File

@@ -232,7 +232,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
)
else:
raise ValueError(
f"--funetune-from-model {cfg.finetune_from_model} does not exist"
f"--finetune-from-model {cfg.finetune_from_model} does not exist"
)
elif suffix is not None:
checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")

View File

@@ -141,7 +141,7 @@ class HubertDataset(FairseqDataset):
self.single_target = single_target
self.label_rates = (
[label_rates for _ in range(len(label_paths))]
if isinstance(label_rates, int)
if isinstance(label_rates, float)
else label_rates
)
self.store_labels = store_labels
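
Editorial note on the `int` to `float` change above (not part of this commit): `label_rate` is the number of labels per second of audio, so allowing a float rate matters whenever the rate is fractional. A minimal sketch, assuming a 50 Hz unit-label rate, of how the expected label count for a clip follows from the rate:

```python
# Hypothetical illustration; the 50.0 Hz value is an assumption (typical for
# HuBERT unit labels), not taken from this commit.
sample_rate = 16_000   # audio samples per second
label_rate = 50.0      # labels per second, now allowed to be fractional
num_samples = 48_000   # a 3-second clip
expected_num_labels = int(num_samples / sample_rate * label_rate)
assert expected_num_labels == 150
```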
@@ -302,7 +302,7 @@ class HubertDataset(FairseqDataset):
targets_list, lengths_list, ntokens_list = [], [], []
itr = zip(targets_by_label, self.label_rates, self.pad_list)
for targets, label_rate, pad in itr:
if label_rate == -1:
if label_rate == -1.0:
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
else:
targets, lengths, ntokens = self.collater_frm_label(

View File

@@ -476,7 +476,7 @@ def compute_mask_indices(
new_parts = []
if span_start - s - min_space >= keep_length:
new_parts.append((s, span_start - min_space + 1))
if e - span_start - keep_length - min_space > keep_length:
if e - span_start - length - min_space > keep_length:
new_parts.append((span_start + length + min_space, e))
return new_parts
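
Editorial aside on the fix above (a worked example, not repository code): after a mask span of `length` is placed at `span_start` inside the free interval `(s, e)`, the left and right remainders are kept only if they are still at least `keep_length` long. The old right-hand check subtracted `keep_length` instead of the actual span `length`, so it overestimated the remaining space whenever the span was longer than `keep_length`:

```python
# Hypothetical numbers to show where the old and new conditions disagree.
s, e = 0, 100                  # free interval before placing the span
span_start, length = 40, 30    # the span occupies [40, 70)
min_space, keep_length = 1, 10

new_check = e - span_start - length - min_space        # 100 - 40 - 30 - 1 = 29
old_check = e - span_start - keep_length - min_space   # 100 - 40 - 10 - 1 = 49
assert (new_check, old_check) == (29, 49)
# Only the corrected expression measures the space actually left to the right
# of the span; the old one ignored how long the placed span really was.
```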

View File

@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
@dataclass
class HubertConfig(FairseqDataclass):
label_rate: int = II("task.label_rate")
label_rate: float = II("task.label_rate")
extractor_mode: EXTRACTOR_MODE_CHOICES = field(
default="default",

View File

@@ -4,7 +4,7 @@
""" Wrapper for ngram_repeat_block cuda extension """
import math
import warnings
from typing import Dict, List, Optional
from typing import List
import torch
from torch import nn
@@ -95,56 +95,26 @@ class NGramRepeatBlock(nn.Module):
def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int):
"""For each hypothesis generate a list of previous ngrams and set associated lprobs to -inf"""
gen_ngrams: List[Dict[str, List[int]]] = [
torch.jit.annotate(Dict[str, List[int]], {})
for bbsz_idx in range(bsz * beam_size)
banned_tokens = [
torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size)
]
cpu_tokens = tokens.cpu()
for bbsz_idx in range(bsz * beam_size):
gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist()
for ngram in self.transpose_list(
[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]
):
key = ",".join([str(x) for x in ngram[:-1]])
gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get(
key, torch.jit.annotate(List[int], [])
) + [ngram[-1]]
if step + 2 - self.no_repeat_ngram_size >= 0:
# no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
banned_tokens = [
self.calculate_banned_tokens(
tokens, step, gen_ngrams, self.no_repeat_ngram_size, bbsz_idx
)
for bbsz_idx in range(bsz * beam_size)
]
else:
banned_tokens = [
torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size)
]
cpu_tokens: List[List[int]] = tokens.cpu().tolist()
check_start_pos = step + 2 - self.no_repeat_ngram_size
for bbsz_idx in range(bsz * beam_size):
ngram_to_check = cpu_tokens[bbsz_idx][
-(self.no_repeat_ngram_size - 1) :
]
for i in range(check_start_pos):
if (
ngram_to_check
== cpu_tokens[bbsz_idx][i : i + self.no_repeat_ngram_size - 1]
):
banned_tokens[bbsz_idx].append(
cpu_tokens[bbsz_idx][i + self.no_repeat_ngram_size - 1]
)
for bbsz_idx in range(bsz * beam_size):
lprobs[bbsz_idx][
torch.tensor(banned_tokens[bbsz_idx], dtype=torch.int64)
] = torch.tensor(-math.inf).to(lprobs)
return lprobs
@staticmethod
def calculate_banned_tokens(
tokens,
step: int,
gen_ngrams: List[Dict[str, List[int]]],
no_repeat_ngram_size: int,
bbsz_idx: int,
):
tokens_list: List[int] = tokens[
bbsz_idx, step + 2 - no_repeat_ngram_size : step + 1
].tolist()
# before decoding the next token, prevent decoding of ngrams that have already appeared
ngram_index = ",".join([str(x) for x in tokens_list])
return gen_ngrams[bbsz_idx].get(ngram_index, torch.jit.annotate(List[int], []))
@staticmethod
def transpose_list(l: List[List[int]]):
# GeneratorExp aren't supported in TS so ignoring the lint
min_len = min([len(x) for x in l]) # noqa
l2 = [[row[i] for row in l] for i in range(min_len)]
return l2
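
The rewritten `_no_repeat_ngram` above drops the per-hypothesis n-gram dictionaries and instead scans the generated tokens directly. A toy, framework-free sketch of the same blocking rule for a single hypothesis (a hypothetical helper, not the fairseq implementation):

```python
from typing import List

def banned_next_tokens(tokens: List[int], no_repeat_ngram_size: int) -> List[int]:
    """Return tokens that would complete an n-gram already present in `tokens`."""
    n = no_repeat_ngram_size
    if len(tokens) < n - 1:
        return []
    prefix = tokens[-(n - 1):]               # last n-1 generated tokens
    banned: List[int] = []
    for i in range(len(tokens) - n + 1):
        if tokens[i : i + n - 1] == prefix:  # an earlier occurrence of the prefix
            banned.append(tokens[i + n - 1])
    return banned

# With a trigram block, the history [5, 7, 5, 7] already contains the prefix
# (5, 7) followed by 5, so generating 5 next would repeat that trigram.
assert banned_next_tokens([5, 7, 5, 7], no_repeat_ngram_size=3) == [5]
```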

View File

@@ -55,9 +55,9 @@ class HubertPretrainingConfig(FairseqDataclass):
"help": "if set, looks for labels in this directory instead",
},
)
label_rate: int = field(
default=-1,
metadata={"help": "label frame rate. -1 for sequence label"},
label_rate: float = field(
default=-1.0,
metadata={"help": "label frame rate. -1.0 for sequence label"},
)
sample_rate: int = field(
default=16_000,