mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-09-21 14:17:25 +03:00
[feat][ust] Noise and data augmentation suite (#4692)
* Implemented data augmentation for concatenation (#3516) * Implemented ConcatAug as setting from config * Switched ConcatAug implementation to sweep script * Added rate and max tokens as ConcatAug params * Kept original fns, pulled concat_attempts as hyperparam * Fixed ConcatAug nits * ConcatAug typing recognizes int and np.int * Implemented waveform transforms and suite of noise augmentation techniques (#3517) * Implemented ConcatAug as setting from config * Switched ConcatAug implementation to sweep script * Kept original fns, pulled concat_attempts as hyperparam * Implemented WaveformTransforms, MusicAug * Removed leftovers from debugging * Separated out feature_ and waveform_transforms, updated constants, formatting cleanup * Added Babble and SporadicNoise augmentations * Fixed zero division error * Adding BackgroundNoiseAugment * Added warning for if using feature transforms with waveform input * warnings, SNR fix * fix for NoneType extension error * fix 2 for NoneType extension error * delete print * Dataset transform, NoisyOverlapAugment, reframe ConcatAugment (#3533) * Dataset transform, NoisyOverlapAugment, reframe ConcatAugment * using np.random instead of python random * fixed np random upper bound bug * cleanup * Changed args & return expressions for waveform transform * Documented new augmentation features * Create augmentation_example.md * Update augmentation_example.md * Update, benchmarking left to do * Move docs to speech_to_speech * Remove docs from speech_to_text * [docs] Updated clean benchmarks * [docs] Add benchmark data
This commit is contained in:
parent
9a00e0336b
commit
eba8a50d2b
435
examples/speech_to_speech/docs/data_augmentation.md
Normal file
435
examples/speech_to_speech/docs/data_augmentation.md
Normal file
@ -0,0 +1,435 @@
|
||||
# Noise and audio augmentation techniques
|
||||
|
||||
The noise and data augmentation techniques were written in an effort to understand how augmenatation can affect model robustness and performance in both clean and noisy settings.
|
||||
|
||||
All transforms discussed in this section are subclasses of `AudioFeatureTransform`, `AudioWaveformTransform`, or `AudioDatasetTransform`. Each `Audio*Transform` has unique interaction with the data. If interested in implemented one's own transforms, it is highly advisable to review the differences (see [Adding your own transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#adding-your-own-transforms)). If only applying the in-built transforms, then one only needs to be mindful that the correct kind of transform is listed in the config (see [Using transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#using-transforms)). These transforms can be applied to instances of `SpeechToTextDataset`.
|
||||
|
||||
### Contents
|
||||
[In-built transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#in-built-transforms)
|
||||
|
||||
[Benchmark studies](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#benchmark-studies)
|
||||
|
||||
[Using transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#using-transforms)
|
||||
|
||||
[Adding your own transforms](https://github.com/facebookresearch/fairseq/blob/main/examples/speech_to_speech/docs/data_augmentation.md#adding-your-own-transforms)
|
||||
|
||||
|
||||
## In-built transforms
|
||||
### 1. Utterance concatenation
|
||||
Utterance concatenation is a data augmenation technique introduced as ConcatAug in [Translatotron 2: High-quality direct speech-to-speech translation
|
||||
with voice preservation](https://arxiv.org/pdf/2107.08661.pdf).
|
||||
With some parameterized probability, samples are concatenated with one other randomly chosen sample from the whole dataset. In the positive (concatenation) case, accessing `dataset[i]` will return a `SpeechToTextDatasetItem` where `source=source[i]+source[j]` and `target=target[i]+target[j]`. In the negative (skip concatenation) case, accessing `dataset[i]` will return a `SpeechToTextDatasetItem` where `source=source[i]` and `target=target[i]` as usual.
|
||||
|
||||
**Usage**: `concataugment` is an `AudioDatasetTransform` and has three configurable hyperparameters:
|
||||
- `rate`: probability that any single access will result in the positive (concatenation) case. Defaults to 0.25.
|
||||
- `max_tokens`: maximum number of tokens allowed for concatenated source sequences. This parameter is meant to limit the length of concatenated samples to avoid out-of-memory errors. Defaults to 300.
|
||||
- `attempts`: maximum number of invalid concatenation attempts before defaulting to the negative (skip concatenation) case. This parameter aims to limit excessive time spent trying to find candidate samples that are short enough to concatenate with. Defaults to 5.
|
||||
|
||||
Please be wary of OOMs while using this augmentation technique; we used smaller batch sizes as a workaround to avoid OOMs. Batch size is determined by update frequency, batch size hyperparameter, and the number of GPU, so you may want to alter these to this end.
|
||||
|
||||
### 2. Noise augmentation suite
|
||||
|
||||
The four noise augmentation methods in this suite adhere to the following principle: with some parameterized probability, samples are overlayed with a noise track. The content of the noise track is specific to the method. Signal-to-noise ratio with which the noise track is overlayed is determined by choosing a value from a random uniform distribution with parameterized endpoints. The first three methods are based off data augmentation methods suggested in Section 3.3 of [X-Vectors: Robust DNN Embeddings for Speaker Recognition](https://danielpovey.com/files/2018_icassp_xvectors.pdf).
|
||||
|
||||
#### 2.1. Music augmentation
|
||||
For music augmentation, the noise track consists of one file uniformly randomly selected from a corpus of music files. The music file is cut to size, including being repeated to fill the original sample length if necessary.
|
||||
|
||||
**Usage**: `musicaugment` is an `AudioWaveformTransform` and has four configurable hyperparameters:
|
||||
- `samples_path`: path where background music files are saved as audios (.wav files). No default.
|
||||
- `rate`: probability that any single access will result in the positive (background music) case. Defaults to 0.25.
|
||||
- `snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
|
||||
- `snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 15.
|
||||
|
||||
#### 2.2. Babble augmentation
|
||||
For babble augmentation, the noise track consists of multiple audios uniformly randomly selected from a corpus of speech files. The number of speech audios in the background track is chosen randomly with equal probability between 3 and 7 audios.
|
||||
|
||||
**Usage**: `babbleaugment` is an `AudioWaveformTransform` and has four configurable hyperparameters:
|
||||
- `samples_path`: path where background speech files are saved as audios (.wav files). No default.
|
||||
- `rate`: probability that any single access will result in the positive (background speech) case. Defaults to 0.25.
|
||||
- `snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
|
||||
- `snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 15.
|
||||
|
||||
#### 2.3. Sporadic noise augmentation
|
||||
For sporadic noise augmentation, the noise track is mostly silent except for intermittent short clips of noise which are added at roughly a parameterized frequency. These clips are randomly chosen and cut from a corpus of noise files to lengths according to a parameterized Gaussian distribution.
|
||||
|
||||
**Usage**: `sporadicnoiseaugment` is an `AudioWaveformTransform` and has seven configurable hyperparameters:
|
||||
- `samples_path`: path where background noise files are saved as audios (.wav files). No default.
|
||||
- `rate`: probability that any single access will result in the positive (add a sporadic noise track) case. Defaults to 0.25.
|
||||
- `snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
|
||||
- `snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 15.
|
||||
- `noise_rate`: rate in noises per second at which noise clip will be added to the original sample
|
||||
- `noise_len_mean`: mean of Gaussian normal distribution from which length of noise clip is chosen
|
||||
- `noise_len_std`: standard deviation of Gaussian normal distribution from which length of noise clip is chosen
|
||||
|
||||
#### 2.4. Background noise augmentation
|
||||
For background noise augmentation, the noise track is a single track uniformly randomly selected from a corpus of noise files. The noise file is cut to size, including being repeated to fill the original sample length if necessary.
|
||||
|
||||
**Usage**: `backgroundnoiseaugment` is an `AudioWaveformTransform` and has four configurable hyperparameters:
|
||||
- `samples_path`: path where background noise files are saved as audios (.wav files). No default.
|
||||
- `rate`: probability that any single access will result in the positive (background noise) case. Defaults to 0.25.
|
||||
- `snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
|
||||
- `snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 15.
|
||||
|
||||
### 3. Mixed babble and background noise augmentation with recognizable source speaker
|
||||
|
||||
This augmentation technique is based on Algorithm 1 in [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) and is similar to the noise augmentation suite techniques in that it has a background noise track. The noise track consists of either (1) another audio sample from the batch or (2) a background noise track. A key difference is the length of the noise track is chosen from a uniform random distribution between 0 and half of the original sample length.
|
||||
|
||||
**Usage**: `noisyoverlapaugment` is an `AudioDatasetTransform` and has seven configurable hyperparameters:
|
||||
- `noises_path`: path where background noise files are saved as audios (.wav files). No default.
|
||||
- `rate`: probability that any single access will result in the positive (background noise) case. Defaults to 0.25.
|
||||
- `mixing_noise_rate`: probability that in a positive (background noise) case, the noise track will consist of background noise (rather than babble from the batch). Defaults to 0.1.
|
||||
- `noise_snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to -5.
|
||||
- `noise_snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add background noise to the original source. Defaults to 5.
|
||||
- `utterance_snr_min`: lower endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add **another audio from the batch** to the original source. Defaults to -5.
|
||||
- `utterance_snr_max`: higher endpoint of the range from which a signal-to-noise ratio is uniformly randomly chosen with which to add **another audio from the batch** to the original source. Defaults to 5.
|
||||
|
||||
## Benchmark studies
|
||||
### Evaluation on clean data
|
||||
Augmentation in training data|Hyperparameters|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
|
||||
---|---|---|---|---|---
|
||||
None||3.954|24.984|23.962|24.448
|
||||
ConcatAugment|rate = 0.25, max_tokens = 3000, attempts = 5|3.940|25.322|26.124|26.19
|
||||
BabbleAugment|rate = 0.25, MUSAN speech, snr_min = (-5), snr_max = 5|3.957|24.226|23.186|22.368|
|
||||
BackgroundNoiseAugment|rate = 0.1, MUSAN noises, snr_min = (-10), snr_max = 10|3.955|24.745|23.513|23.819
|
||||
MusicAugment|rate = 0.25, MUSAN music, snr_min = 0, snr_max = 20|3.954|25.096|24.301|23.341|
|
||||
SporadicNoiseAugment|rate = 0.1, noise_rate = 0.25, MUSAN noises, snr_min = 10, snr_max = 35|3.954|24.924|23.951|23.484|
|
||||
MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|as above, except limited rates to sum to 0.25: music (0.074), background (0.029), babble (0.074), sporadic (0.029)|3.953|24.874|23.675|24.249|
|
||||
NoisyOverlapAugment|rate = 0.25, mixing_noise_rate = 0.5, MUSAN noises, utterance_snr_min = (-10), utterance_snr_max = 0, noise_snr_min = (-5), noise_snr_max = 20|3.954|24.949|24.015|23.768|
|
||||
|
||||
### Evaluation on data with music noise added at SNR = (-5) - 5
|
||||
Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
|
||||
---|---|---|---|---
|
||||
None|3.954|15.785|21.105|16.944
|
||||
ConcatAugment|3.940|17.186|23.255|18.24
|
||||
BabbleAugment|3.957|19.158|22.064|17.116
|
||||
BackgroundNoiseAugment|3.955|17.777|22.0|17.535|
|
||||
MusicAugment|3.954|20.345|23.126|19.433|
|
||||
SporadicNoiseAugment|3.954|15.927|21.382|14.736|
|
||||
MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|19.724|22.659|17.852|
|
||||
NoisyOverlapAugment|3.954|17.49|22.142|17.207|
|
||||
|
||||
### Evaluation on data with babble noise added at SNR = (-5) - 5
|
||||
Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
|
||||
---|---|---|---|---
|
||||
None|3.954|4.092|13.514|5.13
|
||||
ConcatAugment|3.940|5.493|15.835|6.893
|
||||
BabbleAugment|3.957|16.12|21.097|13.996
|
||||
BackgroundNoiseAugment|3.955|4.691|15.784|5.982
|
||||
MusicAugment|3.954|8.06|17.764|9.008
|
||||
SporadicNoiseAugment|3.954|4.009|13.935|4.814
|
||||
MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|14.692|20.882|14.45
|
||||
NoisyOverlapAugment|3.954|4.032|16.434|7.284
|
||||
|
||||
### Evaluation on data with sporadic noise added at SNR = (-5) - 5
|
||||
Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
|
||||
---|---|---|---|---
|
||||
None|3.954|23.778|23.745|22.748
|
||||
ConcatAugment|3.940|24.239|25.907|25.723
|
||||
BabbleAugment|3.957|23.42|23.048|21.076
|
||||
BackgroundNoiseAugment|3.955|23.998|23.467|22.494
|
||||
MusicAugment|3.954|24.142|24.181|19.143
|
||||
SporadicNoiseAugment|3.954|23.97|23.894|22.61
|
||||
MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|24.118|23.59|23.717
|
||||
NoisyOverlapAugment|3.954|24.265|24.103|23.167
|
||||
|
||||
### Evaluation on data with background noise added at SNR = (-5) - 5
|
||||
Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
|
||||
---|---|---|---|---
|
||||
None|3.954|20.201|22.525|19.66
|
||||
ConcatAugment|3.940|20.904|24.706|21.353
|
||||
BabbleAugment|3.957|20.687|22.374|18.907
|
||||
BackgroundNoiseAugment|3.955|21.574|22.998|20.043
|
||||
MusicAugment|3.954|21.65|23.529|19.87
|
||||
SporadicNoiseAugment|3.954|20.578|22.577|19.096
|
||||
MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|21.811|23.144|20.986
|
||||
NoisyOverlapAugment|3.954|21.312|23.153|20.302
|
||||
|
||||
### Evaluation on data with all four types of noises added at SNR = (-5) - 5, each applied with prob 0.5
|
||||
Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
|
||||
---|---|---|---|---
|
||||
None|3.954|10.895|19.319|12.748
|
||||
ConcatAugment|3.940|13.517|21.658|15.428
|
||||
BabbleAugment|3.957|18.09|21.384|16.018
|
||||
BackgroundNoiseAugment|3.955|12.837|20.719|13.933
|
||||
MusicAugment|3.954|16.589|21.823|15.927
|
||||
SporadicNoiseAugment|3.954|11.238|19.91|13.31
|
||||
MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|18.636|21.935|17.845
|
||||
NoisyOverlapAugment|3.954|12.829|20.856|15.048
|
||||
|
||||
### Evaluation on data with noisy overlap augment
|
||||
Augmentation in training data|Training loss|BLEU (covost)|BLEU (epst)|BLEU (mtedx)
|
||||
---|---|---|---|---
|
||||
None|3.954|21.245|22.24|20.994
|
||||
ConcatAugment|3.940|21.611|24.247|23.068
|
||||
BabbleAugment|3.957|21.867|21.987|20.099|
|
||||
BackgroundNoiseAugment|3.955|21.533|21.806|19.717|
|
||||
MusicAugment|3.954|21.823|22.643|20.847|
|
||||
SporadicNoiseAugment|3.954|21.373|22.381|20.672|
|
||||
MusicAugment + BabbleAugment + BackgroundNoiseAugment + SporadicNoiseAugment|3.953|22.206|22.414|21.375|
|
||||
NoisyOverlapAugment|3.954|23.371|23.396|22.627|
|
||||
|
||||
## Using transforms
|
||||
Transforms are configurable.
|
||||
|
||||
1. Please pay careful attention to the type of transform you are applying.
|
||||
- `concataugment` and `noisyoverlapaugment` are instances of `AudioDatasetTransform` and should be listed in the config under `dataset_transforms`.
|
||||
- `musicaugment`, `babbleaugment`, `sporadicnoiseaugment`, and `backgroundnoiseaugment` are instances of `AudioWaveformTransform` and should be listed under `waveform_transforms`.
|
||||
- Instances of `AudioFeatureTransform` should be listed under `feature_transforms`.
|
||||
2. Feel free to apply these augmentations in different contexts, e.g., you may use a `_train` or `_eval` flag to specify when the transform will be applied. If the dataset at hand contains `train` in its name, those transforms under the `_train` flag will be applied; else, the remaining transforms will be applied.
|
||||
|
||||
For example, you would add this to your config to apply the musicaugment transform to a training dataset:
|
||||
```yaml
|
||||
musicaugment:
|
||||
samples_path: ${MUSIC_PATH}
|
||||
snr_min: 10
|
||||
snr_max: 15
|
||||
rate: 0.25
|
||||
waveform_transforms:
|
||||
_train:
|
||||
- musicaugment
|
||||
```
|
||||
or add this to apply the concataugment transform:
|
||||
```yaml
|
||||
concataugment:
|
||||
rate: 0.25
|
||||
max_tokens: 3000
|
||||
attempts: 5
|
||||
dataset_transforms:
|
||||
_train:
|
||||
- concataugment
|
||||
```
|
||||
You may also want to add multiple of one type of transform; here, we add multiple `AudioWaveformTransform`s:
|
||||
```yaml
|
||||
musicaugment:
|
||||
samples_path: ${MUSIC_PATH}
|
||||
snr_min: 5
|
||||
snr_max: 20
|
||||
rate: 0.25
|
||||
backgroundnoiseaugment:
|
||||
samples_path: ${NOISES_PATH}
|
||||
snr_min: 10
|
||||
snr_max: 20
|
||||
rate: 0.1
|
||||
sporadicnoiseaugment:
|
||||
samples_path: ${NOISES_PATH}
|
||||
snr_min: 5
|
||||
snr_max: 15
|
||||
rate: 0.1
|
||||
noise_rate: 0.25
|
||||
waveform_transforms:
|
||||
_train:
|
||||
- musicaugment
|
||||
- backgroundnoiseaugment
|
||||
- sporadicnoiseaugment
|
||||
```
|
||||
|
||||
## Adding your own transforms
|
||||
Note: We store transform implementations in `fairseq/data/audio/*_transforms` directories. You may refer to these as examples while implementing your own transform.
|
||||
|
||||
### Step 1. Picking the right class for your transform
|
||||
The integration into SpeechToTextDataset is quite different for each kind of transform, so it is important to understand which one is best suited to your purposes.
|
||||
|
||||
**Feature transforms**
|
||||
`AudioFeatureTransform` is a base class which allows **some transform to be applied to audio spectrograms** in the data loading step. One thing to note is that the source data is either saved as `np.ndarrays` or as audio files, and is to be returned either as features (spectrogram) or waveform. If and only if the data is to be returned as a spectrogram, then `AudioFeatureTransform`s will be applied.
|
||||
|
||||
**Waveform transforms**
|
||||
`AudioWaveformTransform` is a base class which allows some **transform to be applied to waveforms** in the data loading step. As mentioned above, there are two source and return types to data loading for this dataset. If and only if the data is saved in audio file format, then `AudioWaveformTransform`s will be applied, whichever return type is used.
|
||||
|
||||
**Dataset transforms**
|
||||
`AudioDatasetTransform` is a base class for transforms **based on more than one item in a dataset**, ex. concatenation of two random samples in a dataset. Rather than being applied in a consistent way, i.e., to all features or to all waveforms, the integration of a dataset transform is entirely specific. Adding a dataset transform requires actually editing the `fairseq/data/audio/speech_to_text_dataset.py` file.
|
||||
|
||||
### Step 2. Setting up your transform (generic to all types of transforms)
|
||||
Now that you know which kind of transform you would like to use, we are ready to implement it. This step is generic for all transform types, i.e., `TRANSFORM_TYPE` may be any of `feature`, `waveform`, or `dataset`. We will show how to build utterance concatenation (an `AudioDatasetTransform`) as an example.
|
||||
|
||||
Import the base class and registration function for your transform.
|
||||
```python
|
||||
from fairseq.data.audio.dataset_transforms import (
|
||||
AudioDatasetTransform,
|
||||
register_audio_dataset_transform
|
||||
)
|
||||
```
|
||||
|
||||
Define the class and register the transform. The name passed into the registration function is how your transform should be named in the config.
|
||||
```python
|
||||
@register_audio_dataset_transform("concataugment")
|
||||
class ConcatAugment(AudioDatasetTransform):
|
||||
```
|
||||
|
||||
We are now ready to add the basic important functions to our new class. In this example, `_DEFAULTS` refers to a dictionary with the default hyperparameter values that we defined. `from_config_dict` is called to instantiate the transform given hyperparameters from the config.
|
||||
```python
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return ConcatAugment(
|
||||
_config.get("rate", _DEFAULTS["rate"]),
|
||||
_config.get("max_tokens", _DEFAULTS["max_tokens"]),
|
||||
_config.get("attempts", _DEFAULTS["attempts"]),
|
||||
)
|
||||
```
|
||||
We edit the instantiation function `__init__` to track hyperparameters and do any setup work.
|
||||
```python
|
||||
def __init__(
|
||||
self,
|
||||
rate=_DEFAULTS["rate"],
|
||||
max_tokens=_DEFAULTS["max_tokens"],
|
||||
attempts=_DEFAULTS["attempts"],
|
||||
):
|
||||
self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts
|
||||
```
|
||||
Lastly `__repr__` gives how the transform will be reported in an output log.
|
||||
```python
|
||||
def __repr__(self):
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ "("
|
||||
+ ", ".join(
|
||||
[
|
||||
f"rate={self.rate}",
|
||||
f"max_tokens={self.max_tokens}",
|
||||
f"attempts={self.attempts}",
|
||||
]
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
```
|
||||
|
||||
### Step 3. Adding the transform logic
|
||||
At this point, we are ready to implement the actual transform logic. The flow from here is different for each of the three transforms, so follow the path that is relevant to you.
|
||||
### ...for feature transforms
|
||||
The final step is implementing the `__call__` function, which applies the transform logic and **returns** the spectrogram with transform applied. This supports and should take exactly **two arguments**:
|
||||
- `self`
|
||||
- `x` (np.ndarray): the spectrogram for one source sample. (This is a positional argument, so you can use another parameter name like `spectrogram` instead of `x`.)
|
||||
|
||||
For example, this is the `__call__` function for GlobalCMVN (cepstral mean and variance normalization).
|
||||
```python
|
||||
def __call__(self, x):
|
||||
x = np.subtract(x, self.mean)
|
||||
x = np.divide(x, self.std)
|
||||
return x
|
||||
|
||||
```
|
||||
### ...for waveform transforms
|
||||
The final step is implementing the `__call__` function, which applies the transform logic. This supports and should take exactly **three arguments**:
|
||||
- `self`
|
||||
- `source` (numpy.ndarray or torch.Tensor): source audio 2d waveform (channels x length)
|
||||
- `sample_rate` (optional, defaults to None): sample rate of `source`
|
||||
|
||||
`__call__` **returns**:
|
||||
- transformed audio waveform
|
||||
- sample rate of transformed audio waveform
|
||||
|
||||
For example, this is the `__call__` function for augmentations in the Noise Augmentation Suite.
|
||||
```python
|
||||
def __call__(self, source, sample_rate=None):
|
||||
if np.random.random() > self.rate:
|
||||
return source
|
||||
|
||||
noise = self._get_noise(
|
||||
source.shape, always_2d=True, use_sample_rate=sample_rate
|
||||
)
|
||||
return self._mix(source, noise, rand_uniform(self.snr_min, self.snr_max)), sample_rate
|
||||
```
|
||||
|
||||
### ...for dataset transforms
|
||||
Dataset transforms are extremely flexible, and implementation involves directly integrating them into `fairseq/data/audio/speech_to_text_dataset.py` in transform-specific ways.
|
||||
There are two basic components: (1) check whether or not this transform is part of this dataset instance using `self.dataset_transforms.has_transform(TRANSFORM_CLS)`, and (2) if so, get the transform using `self.dataset_transforms.get_transform(TRANSFORM_CLS)` & apply it.
|
||||
Due to the case-by-case specificity, it is easier to demonstrate this by examples.
|
||||
|
||||
#### Example: NoisyOverlapAugment
|
||||
This transform requires access to multiple items within the same batch at once.
|
||||
|
||||
**Logic**: We still use the transform classes to keep away the transform logic. For example, `__call__` of `NoisyOverlapAugment` class takes a list of source tokens for items in a mini-batch, applies noise/utterance as dictated by the transform, and returns the list of transformed source tokens for items in the mini-batch.
|
||||
|
||||
```python
|
||||
def __call__(self, sources):
|
||||
for i, source in enumerate(sources):
|
||||
if np.random.random() > self.rate:
|
||||
continue
|
||||
|
||||
pri = source.numpy()
|
||||
|
||||
# ... some transform code omitted
|
||||
|
||||
pri[s_source : s_source + l] = np.add(
|
||||
pri[s_source : s_source + l], np.multiply(scl, sec[s_sec : s_sec + l])
|
||||
)
|
||||
sources[i] = torch.from_numpy(pri).float()
|
||||
|
||||
return sources
|
||||
```
|
||||
|
||||
**Integration**: The `collater` function for `SpeechToTextDataset` is responsible for preparing a mini-batch for training, so we integrate NOAug through adding a few lines to the top of this function:
|
||||
```python
|
||||
def collater(
|
||||
self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
|
||||
) -> Dict:
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
indices = torch.tensor([x.index for x in samples], dtype=torch.long)
|
||||
|
||||
sources = [x.source for x in samples]
|
||||
|
||||
# NOAUG INTEGRATION BLOCK
|
||||
# (1) Check whether or not this transform is part of this dataset instance
|
||||
has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
|
||||
# (2) If so, get & apply the transform
|
||||
if has_NOAug and self.cfg.use_audio_input:
|
||||
NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
|
||||
sources = NOAug(sources)
|
||||
|
||||
frames = _collate_frames(sources, self.cfg.use_audio_input)
|
||||
# sort samples by descending number of frames
|
||||
n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
|
||||
n_frames, order = n_frames.sort(descending=True)
|
||||
indices = indices.index_select(0, order)
|
||||
frames = frames.index_select(0, order)
|
||||
|
||||
# ... rest of function
|
||||
```
|
||||
|
||||
#### Example: ConcatAugment
|
||||
This transform requires access to another item within the dataset at once.
|
||||
|
||||
**Logic**: We abstract the logic for picking indices to concatenate by adding a `find_indices` function to the `ConcatAugment` class, which takes one index in the dataset and finds a compatible second index to concatenate source and target tokens.
|
||||
```python
|
||||
def find_indices(self, index: int, n_frames: List[int], n_samples: int):
|
||||
# skip conditions: application rate, max_tokens limit exceeded
|
||||
if np.random.random() > self.rate:
|
||||
return [index]
|
||||
if self.max_tokens and n_frames[index] > self.max_tokens:
|
||||
return [index]
|
||||
|
||||
# pick second sample to concatenate
|
||||
for _ in range(self.attempts):
|
||||
index2 = np.random.randint(0, n_samples)
|
||||
if index2 != index and (
|
||||
not self.max_tokens
|
||||
or n_frames[index] + n_frames[index2] < self.max_tokens
|
||||
):
|
||||
return [index, index2]
|
||||
|
||||
return [index]
|
||||
```
|
||||
|
||||
**Integration**: `SpeechToTextDataset` uses a custom `__getitem__(self, index)` function (called in the background when you write `dataset[i]`). We edited this function (as well as `_get_source_audio` and `get_tokenized_tgt_text`) to achieve the desired transform effect where accessing `dataset[i]` will return a `SpeechToTextDatasetItem` where `source=source[i]+source[j]` and `target=target[i]+target[j]`.
|
||||
```python
|
||||
def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
|
||||
|
||||
# CONCATAUGMENT INTEGRATION BLOCK
|
||||
# (1) Check whether or not this transform is part of this dataset instance
|
||||
has_concat = self.dataset_transforms.has_transform(ConcatAugment)
|
||||
# (2) If so, get & apply the transform
|
||||
if has_concat:
|
||||
concat = self.dataset_transforms.get_transform(ConcatAugment)
|
||||
indices = concat.find_indices(index, self.n_frames, self.n_samples)
|
||||
|
||||
source = self._get_source_audio(indices if has_concat else index)
|
||||
source = self.pack_frames(source)
|
||||
|
||||
target = None
|
||||
if self.tgt_texts is not None:
|
||||
tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
|
||||
target = self.tgt_dict.encode_line(
|
||||
|
||||
# ... rest of function
|
||||
```
|
@ -0,0 +1,93 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
import importlib
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AudioTransform(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config_dict(cls, config: Optional[Dict] = None):
|
||||
pass
|
||||
|
||||
|
||||
class CompositeAudioTransform(AudioTransform):
|
||||
def _from_config_dict(
|
||||
cls,
|
||||
transform_type,
|
||||
get_audio_transform,
|
||||
composite_cls,
|
||||
config=None,
|
||||
return_empty=False,
|
||||
):
|
||||
_config = {} if config is None else config
|
||||
_transforms = _config.get(f"{transform_type}_transforms")
|
||||
|
||||
if _transforms is None:
|
||||
if return_empty:
|
||||
_transforms = []
|
||||
else:
|
||||
return None
|
||||
|
||||
transforms = [
|
||||
get_audio_transform(_t).from_config_dict(_config.get(_t))
|
||||
for _t in _transforms
|
||||
]
|
||||
return composite_cls(transforms)
|
||||
|
||||
def __init__(self, transforms):
|
||||
self.transforms = [t for t in transforms if t is not None]
|
||||
|
||||
def __call__(self, x):
|
||||
for t in self.transforms:
|
||||
x = t(x)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
format_string = (
|
||||
[self.__class__.__name__ + "("]
|
||||
+ [f" {t.__repr__()}" for t in self.transforms]
|
||||
+ [")"]
|
||||
)
|
||||
return "\n".join(format_string)
|
||||
|
||||
|
||||
def register_audio_transform(name, cls_type, registry, class_names):
|
||||
def register_audio_transform_cls(cls):
|
||||
if name in registry:
|
||||
raise ValueError(f"Cannot register duplicate transform ({name})")
|
||||
if not issubclass(cls, cls_type):
|
||||
raise ValueError(
|
||||
f"Transform ({name}: {cls.__name__}) must extend "
|
||||
f"{cls_type.__name__}"
|
||||
)
|
||||
if cls.__name__ in class_names:
|
||||
raise ValueError(
|
||||
f"Cannot register audio transform with duplicate "
|
||||
f"class name ({cls.__name__})"
|
||||
)
|
||||
registry[name] = cls
|
||||
class_names.add(cls.__name__)
|
||||
return cls
|
||||
|
||||
return register_audio_transform_cls
|
||||
|
||||
|
||||
def import_transforms(transforms_dir, transform_type):
|
||||
for file in os.listdir(transforms_dir):
|
||||
path = os.path.join(transforms_dir, file)
|
||||
if (
|
||||
not file.startswith("_")
|
||||
and not file.startswith(".")
|
||||
and (file.endswith(".py") or os.path.isdir(path))
|
||||
):
|
||||
name = file[: file.find(".py")] if file.endswith(".py") else file
|
||||
importlib.import_module(
|
||||
f"fairseq.data.audio.{transform_type}_transforms." + name
|
||||
)
|
||||
|
||||
|
||||
# Utility fn for uniform numbers in transforms
|
||||
def rand_uniform(a, b):
|
||||
return np.random.uniform() * (b - a) + a
|
@ -13,6 +13,8 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
|
||||
|
||||
SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"}
|
||||
FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
|
||||
|
||||
@ -73,6 +75,7 @@ def get_waveform(
|
||||
always_2d: bool = True,
|
||||
output_sample_rate: Optional[int] = None,
|
||||
normalize_volume: bool = False,
|
||||
waveform_transforms: Optional[CompositeAudioWaveformTransform] = None,
|
||||
) -> Tuple[np.ndarray, int]:
|
||||
"""Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.
|
||||
|
||||
@ -113,16 +116,25 @@ def get_waveform(
|
||||
|
||||
if not normalization:
|
||||
waveform *= 2**15 # denormalized to 16-bit signed integers
|
||||
|
||||
if waveform_transforms is not None:
|
||||
waveform, sample_rate = waveform_transforms(waveform, sample_rate)
|
||||
|
||||
if not always_2d:
|
||||
waveform = waveform.squeeze(axis=0)
|
||||
|
||||
return waveform, sample_rate
|
||||
|
||||
|
||||
def get_features_from_npy_or_audio(path):
|
||||
def get_features_from_npy_or_audio(path, waveform_transforms=None):
|
||||
ext = Path(path).suffix
|
||||
if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
|
||||
raise ValueError(f'Unsupported file format for "{path}"')
|
||||
return np.load(path) if ext == ".npy" else get_fbank(path)
|
||||
return (
|
||||
np.load(path)
|
||||
if ext == ".npy"
|
||||
else get_fbank(path, waveform_transforms=waveform_transforms)
|
||||
)
|
||||
|
||||
|
||||
def get_features_or_waveform_from_stored_zip(
|
||||
@ -131,6 +143,7 @@ def get_features_or_waveform_from_stored_zip(
|
||||
byte_size,
|
||||
need_waveform=False,
|
||||
use_sample_rate=None,
|
||||
waveform_transforms=None,
|
||||
):
|
||||
assert path.endswith(".zip")
|
||||
data = read_from_stored_zip(path, byte_offset, byte_size)
|
||||
@ -139,16 +152,23 @@ def get_features_or_waveform_from_stored_zip(
|
||||
features_or_waveform = np.load(f)
|
||||
elif is_sf_audio_data(data):
|
||||
features_or_waveform = (
|
||||
get_waveform(f, always_2d=False, output_sample_rate=use_sample_rate)[0]
|
||||
get_waveform(
|
||||
f,
|
||||
always_2d=False,
|
||||
output_sample_rate=use_sample_rate,
|
||||
waveform_transforms=waveform_transforms,
|
||||
)[0]
|
||||
if need_waveform
|
||||
else get_fbank(f)
|
||||
else get_fbank(f, waveform_transforms=waveform_transforms)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Unknown file format for "{path}"')
|
||||
return features_or_waveform
|
||||
|
||||
|
||||
def get_features_or_waveform(path: str, need_waveform=False, use_sample_rate=None):
|
||||
def get_features_or_waveform(
|
||||
path: str, need_waveform=False, use_sample_rate=None, waveform_transforms=None
|
||||
):
|
||||
"""Get speech features from .npy file or waveform from .wav/.flac file.
|
||||
The file may be inside an uncompressed ZIP file and is accessed via byte
|
||||
offset and length.
|
||||
@ -166,9 +186,14 @@ def get_features_or_waveform(path: str, need_waveform=False, use_sample_rate=Non
|
||||
if len(slice_ptr) == 0:
|
||||
if need_waveform:
|
||||
return get_waveform(
|
||||
_path, always_2d=False, output_sample_rate=use_sample_rate
|
||||
_path,
|
||||
always_2d=False,
|
||||
output_sample_rate=use_sample_rate,
|
||||
waveform_transforms=waveform_transforms,
|
||||
)[0]
|
||||
return get_features_from_npy_or_audio(_path)
|
||||
return get_features_from_npy_or_audio(
|
||||
_path, waveform_transforms=waveform_transforms
|
||||
)
|
||||
elif len(slice_ptr) == 2:
|
||||
features_or_waveform = get_features_or_waveform_from_stored_zip(
|
||||
_path,
|
||||
@ -176,6 +201,7 @@ def get_features_or_waveform(path: str, need_waveform=False, use_sample_rate=Non
|
||||
slice_ptr[1],
|
||||
need_waveform=need_waveform,
|
||||
use_sample_rate=use_sample_rate,
|
||||
waveform_transforms=waveform_transforms,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid path: {path}")
|
||||
@ -223,12 +249,16 @@ def _get_torchaudio_fbank(
|
||||
return None
|
||||
|
||||
|
||||
def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray:
|
||||
def get_fbank(
|
||||
path_or_fp: Union[str, BinaryIO], n_bins=80, waveform_transforms=None
|
||||
) -> np.ndarray:
|
||||
"""Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
|
||||
(faster CPP implementation) to TorchAudio (Python implementation). Note that
|
||||
Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
|
||||
waveform should not be normalized."""
|
||||
waveform, sample_rate = get_waveform(path_or_fp, normalization=False)
|
||||
waveform, sample_rate = get_waveform(
|
||||
path_or_fp, normalization=False, waveform_transforms=waveform_transforms
|
||||
)
|
||||
|
||||
features = _get_kaldi_fbank(waveform, sample_rate, n_bins)
|
||||
if features is None:
|
||||
|
@ -3,12 +3,17 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from argparse import Namespace
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fairseq.data import Dictionary
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_config_from_yaml(yaml_path: Path):
|
||||
try:
|
||||
import yaml
|
||||
@ -128,19 +133,46 @@ class S2TDataConfig(object):
|
||||
the root path. Set this to empty string when using absolute paths."""
|
||||
return self.config.get("audio_root", "")
|
||||
|
||||
def get_feature_transforms(self, split, is_train):
|
||||
def get_transforms(self, transform_type, split, is_train):
|
||||
"""Split-specific feature transforms. Allowing train set
|
||||
wildcard `_train`, evaluation set wildcard `_eval` and general
|
||||
wildcard `*` for matching."""
|
||||
from copy import deepcopy
|
||||
|
||||
cfg = deepcopy(self.config)
|
||||
_cur = cfg.get("transforms", {})
|
||||
_cur = cfg.get(f"{transform_type}transforms", {})
|
||||
cur = _cur.get(split)
|
||||
cur = _cur.get("_train") if cur is None and is_train else cur
|
||||
cur = _cur.get("_eval") if cur is None and not is_train else cur
|
||||
cur = _cur.get("*") if cur is None else cur
|
||||
cfg["transforms"] = cur
|
||||
return cur
|
||||
|
||||
def get_feature_transforms(self, split, is_train):
|
||||
cfg = deepcopy(self.config)
|
||||
# TODO: deprecate transforms
|
||||
cur = self.get_transforms("", split, is_train)
|
||||
if cur is not None:
|
||||
logger.warning(
|
||||
"Auto converting transforms into feature_transforms, "
|
||||
"but transforms will be deprecated in the future. Please "
|
||||
"update this in the config."
|
||||
)
|
||||
ft_transforms = self.get_transforms("feature_", split, is_train)
|
||||
if ft_transforms:
|
||||
cur.extend(ft_transforms)
|
||||
else:
|
||||
cur = self.get_transforms("feature_", split, is_train)
|
||||
cfg["feature_transforms"] = cur
|
||||
return cfg
|
||||
|
||||
def get_waveform_transforms(self, split, is_train):
|
||||
cfg = deepcopy(self.config)
|
||||
cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
|
||||
return cfg
|
||||
|
||||
def get_dataset_transforms(self, split, is_train):
|
||||
cfg = deepcopy(self.config)
|
||||
cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
|
||||
return cfg
|
||||
|
||||
@property
|
||||
@ -178,7 +210,13 @@ class S2SDataConfig(S2TDataConfig):
|
||||
def input_transformed_channels(self):
|
||||
"""The number of channels in the audio after feature transforms"""
|
||||
# TODO: move this into individual transforms
|
||||
# TODO: deprecate transforms
|
||||
_cur = self.config.get("transforms", {})
|
||||
ft_transforms = self.config.get("feature_transforms", {})
|
||||
if _cur and ft_transforms:
|
||||
_cur.update(ft_transforms)
|
||||
else:
|
||||
_cur = self.config.get("feature_transforms", {})
|
||||
cur = _cur.get("_train", [])
|
||||
|
||||
_channels = self.input_channels
|
||||
|
53
fairseq/data/audio/dataset_transforms/__init__.py
Normal file
53
fairseq/data/audio/dataset_transforms/__init__.py
Normal file
@ -0,0 +1,53 @@
|
||||
import os
|
||||
from fairseq.data.audio import (
|
||||
AudioTransform,
|
||||
CompositeAudioTransform,
|
||||
import_transforms,
|
||||
register_audio_transform,
|
||||
)
|
||||
|
||||
|
||||
class AudioDatasetTransform(AudioTransform):
|
||||
pass
|
||||
|
||||
|
||||
AUDIO_DATASET_TRANSFORM_REGISTRY = {}
|
||||
AUDIO_DATASET_TRANSFORM_CLASS_NAMES = set()
|
||||
|
||||
|
||||
def get_audio_dataset_transform(name):
|
||||
return AUDIO_DATASET_TRANSFORM_REGISTRY[name]
|
||||
|
||||
|
||||
def register_audio_dataset_transform(name):
|
||||
return register_audio_transform(
|
||||
name,
|
||||
AudioDatasetTransform,
|
||||
AUDIO_DATASET_TRANSFORM_REGISTRY,
|
||||
AUDIO_DATASET_TRANSFORM_CLASS_NAMES,
|
||||
)
|
||||
|
||||
|
||||
import_transforms(os.path.dirname(__file__), "dataset")
|
||||
|
||||
|
||||
class CompositeAudioDatasetTransform(CompositeAudioTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
return super()._from_config_dict(
|
||||
cls,
|
||||
"dataset",
|
||||
get_audio_dataset_transform,
|
||||
CompositeAudioDatasetTransform,
|
||||
config,
|
||||
return_empty=True,
|
||||
)
|
||||
|
||||
def get_transform(self, cls):
|
||||
for t in self.transforms:
|
||||
if isinstance(t, cls):
|
||||
return t
|
||||
return None
|
||||
|
||||
def has_transform(self, cls):
|
||||
return self.get_transform(cls) is not None
|
61
fairseq/data/audio/dataset_transforms/concataugment.py
Normal file
61
fairseq/data/audio/dataset_transforms/concataugment.py
Normal file
@ -0,0 +1,61 @@
|
||||
from typing import List
|
||||
import numpy as np
|
||||
|
||||
from fairseq.data.audio.dataset_transforms import (
|
||||
AudioDatasetTransform,
|
||||
register_audio_dataset_transform,
|
||||
)
|
||||
|
||||
_DEFAULTS = {"rate": 0.25, "max_tokens": 3000, "attempts": 5}
|
||||
|
||||
|
||||
@register_audio_dataset_transform("concataugment")
|
||||
class ConcatAugment(AudioDatasetTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return ConcatAugment(
|
||||
_config.get("rate", _DEFAULTS["rate"]),
|
||||
_config.get("max_tokens", _DEFAULTS["max_tokens"]),
|
||||
_config.get("attempts", _DEFAULTS["attempts"]),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rate=_DEFAULTS["rate"],
|
||||
max_tokens=_DEFAULTS["max_tokens"],
|
||||
attempts=_DEFAULTS["attempts"],
|
||||
):
|
||||
self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ "("
|
||||
+ ", ".join(
|
||||
[
|
||||
f"rate={self.rate}",
|
||||
f"max_tokens={self.max_tokens}",
|
||||
f"attempts={self.attempts}",
|
||||
]
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
|
||||
def find_indices(self, index: int, n_frames: List[int], n_samples: int):
|
||||
# skip conditions: application rate, max_tokens limit exceeded
|
||||
if np.random.random() > self.rate:
|
||||
return [index]
|
||||
if self.max_tokens and n_frames[index] > self.max_tokens:
|
||||
return [index]
|
||||
|
||||
# pick second sample to concatenate
|
||||
for _ in range(self.attempts):
|
||||
index2 = np.random.randint(0, n_samples)
|
||||
if index2 != index and (
|
||||
not self.max_tokens
|
||||
or n_frames[index] + n_frames[index2] < self.max_tokens
|
||||
):
|
||||
return [index, index2]
|
||||
|
||||
return [index]
|
105
fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py
Normal file
105
fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py
Normal file
@ -0,0 +1,105 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from fairseq.data.audio import rand_uniform
|
||||
from fairseq.data.audio.dataset_transforms import (
|
||||
AudioDatasetTransform,
|
||||
register_audio_dataset_transform,
|
||||
)
|
||||
from fairseq.data.audio.waveform_transforms.noiseaugment import (
|
||||
NoiseAugmentTransform,
|
||||
)
|
||||
|
||||
_DEFAULTS = {
|
||||
"rate": 0.25,
|
||||
"mixing_noise_rate": 0.1,
|
||||
"noise_path": "",
|
||||
"noise_snr_min": -5,
|
||||
"noise_snr_max": 5,
|
||||
"utterance_snr_min": -5,
|
||||
"utterance_snr_max": 5,
|
||||
}
|
||||
|
||||
|
||||
@register_audio_dataset_transform("noisyoverlapaugment")
|
||||
class NoisyOverlapAugment(AudioDatasetTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return NoisyOverlapAugment(
|
||||
_config.get("rate", _DEFAULTS["rate"]),
|
||||
_config.get("mixing_noise_rate", _DEFAULTS["mixing_noise_rate"]),
|
||||
_config.get("noise_path", _DEFAULTS["noise_path"]),
|
||||
_config.get("noise_snr_min", _DEFAULTS["noise_snr_min"]),
|
||||
_config.get("noise_snr_max", _DEFAULTS["noise_snr_max"]),
|
||||
_config.get("utterance_snr_min", _DEFAULTS["utterance_snr_min"]),
|
||||
_config.get("utterance_snr_max", _DEFAULTS["utterance_snr_max"]),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rate=_DEFAULTS["rate"],
|
||||
mixing_noise_rate=_DEFAULTS["mixing_noise_rate"],
|
||||
noise_path=_DEFAULTS["noise_path"],
|
||||
noise_snr_min=_DEFAULTS["noise_snr_min"],
|
||||
noise_snr_max=_DEFAULTS["noise_snr_max"],
|
||||
utterance_snr_min=_DEFAULTS["utterance_snr_min"],
|
||||
utterance_snr_max=_DEFAULTS["utterance_snr_max"],
|
||||
):
|
||||
self.rate = rate
|
||||
self.mixing_noise_rate = mixing_noise_rate
|
||||
self.noise_shaper = NoiseAugmentTransform(noise_path)
|
||||
self.noise_snr_min = noise_snr_min
|
||||
self.noise_snr_max = noise_snr_max
|
||||
self.utterance_snr_min = utterance_snr_min
|
||||
self.utterance_snr_max = utterance_snr_max
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ "("
|
||||
+ ", ".join(
|
||||
[
|
||||
f"rate={self.rate}",
|
||||
f"mixing_noise_rate={self.mixing_noise_rate}",
|
||||
f"noise_snr_min={self.noise_snr_min}",
|
||||
f"noise_snr_max={self.noise_snr_max}",
|
||||
f"utterance_snr_min={self.utterance_snr_min}",
|
||||
f"utterance_snr_max={self.utterance_snr_max}",
|
||||
]
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
|
||||
def __call__(self, sources):
|
||||
for i, source in enumerate(sources):
|
||||
if np.random.random() > self.rate:
|
||||
continue
|
||||
|
||||
pri = source.numpy()
|
||||
|
||||
if np.random.random() > self.mixing_noise_rate:
|
||||
sec = sources[np.random.randint(0, len(sources))].numpy()
|
||||
snr = rand_uniform(self.utterance_snr_min, self.utterance_snr_max)
|
||||
else:
|
||||
sec = self.noise_shaper.pick_sample(source.shape)
|
||||
snr = rand_uniform(self.noise_snr_min, self.noise_snr_max)
|
||||
|
||||
L1 = pri.shape[-1]
|
||||
L2 = sec.shape[-1]
|
||||
l = np.random.randint(0, min(round(L1 / 2), L2)) # mix len
|
||||
s_source = np.random.randint(0, L1 - l)
|
||||
s_sec = np.random.randint(0, L2 - l)
|
||||
|
||||
get_power = lambda x: np.mean(x**2)
|
||||
if get_power(sec) == 0:
|
||||
continue
|
||||
|
||||
scl = np.sqrt(get_power(pri) / (np.power(10, snr / 10) * get_power(sec)))
|
||||
|
||||
pri[s_source : s_source + l] = np.add(
|
||||
pri[s_source : s_source + l], np.multiply(scl, sec[s_sec : s_sec + l])
|
||||
)
|
||||
sources[i] = torch.from_numpy(pri).float()
|
||||
|
||||
return sources
|
@ -1,82 +1,43 @@
|
||||
import importlib
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
from fairseq.data.audio import (
|
||||
AudioTransform,
|
||||
CompositeAudioTransform,
|
||||
import_transforms,
|
||||
register_audio_transform,
|
||||
)
|
||||
|
||||
|
||||
class AudioFeatureTransform(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config_dict(cls, config: Optional[Dict] = None):
|
||||
pass
|
||||
class AudioFeatureTransform(AudioTransform):
|
||||
pass
|
||||
|
||||
|
||||
AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
|
||||
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()
|
||||
|
||||
|
||||
def register_audio_feature_transform(name):
|
||||
def register_audio_feature_transform_cls(cls):
|
||||
if name in AUDIO_FEATURE_TRANSFORM_REGISTRY:
|
||||
raise ValueError(f"Cannot register duplicate transform ({name})")
|
||||
if not issubclass(cls, AudioFeatureTransform):
|
||||
raise ValueError(
|
||||
f"Transform ({name}: {cls.__name__}) must extend "
|
||||
"AudioFeatureTransform"
|
||||
)
|
||||
if cls.__name__ in AUDIO_FEATURE_TRANSFORM_CLASS_NAMES:
|
||||
raise ValueError(
|
||||
f"Cannot register audio feature transform with duplicate "
|
||||
f"class name ({cls.__name__})"
|
||||
)
|
||||
AUDIO_FEATURE_TRANSFORM_REGISTRY[name] = cls
|
||||
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES.add(cls.__name__)
|
||||
return cls
|
||||
|
||||
return register_audio_feature_transform_cls
|
||||
|
||||
|
||||
def get_audio_feature_transform(name):
|
||||
return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]
|
||||
|
||||
|
||||
transforms_dir = os.path.dirname(__file__)
|
||||
for file in os.listdir(transforms_dir):
|
||||
path = os.path.join(transforms_dir, file)
|
||||
if (
|
||||
not file.startswith("_")
|
||||
and not file.startswith(".")
|
||||
and (file.endswith(".py") or os.path.isdir(path))
|
||||
):
|
||||
name = file[: file.find(".py")] if file.endswith(".py") else file
|
||||
importlib.import_module("fairseq.data.audio.feature_transforms." + name)
|
||||
def register_audio_feature_transform(name):
|
||||
return register_audio_transform(
|
||||
name,
|
||||
AudioFeatureTransform,
|
||||
AUDIO_FEATURE_TRANSFORM_REGISTRY,
|
||||
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES,
|
||||
)
|
||||
|
||||
|
||||
class CompositeAudioFeatureTransform(AudioFeatureTransform):
|
||||
import_transforms(os.path.dirname(__file__), "feature")
|
||||
|
||||
|
||||
class CompositeAudioFeatureTransform(CompositeAudioTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
_transforms = _config.get("transforms")
|
||||
if _transforms is None:
|
||||
return None
|
||||
transforms = [
|
||||
get_audio_feature_transform(_t).from_config_dict(_config.get(_t))
|
||||
for _t in _transforms
|
||||
]
|
||||
return CompositeAudioFeatureTransform(transforms)
|
||||
|
||||
def __init__(self, transforms):
|
||||
self.transforms = [t for t in transforms if t is not None]
|
||||
|
||||
def __call__(self, x):
|
||||
for t in self.transforms:
|
||||
x = t(x)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
format_string = (
|
||||
[self.__class__.__name__ + "("]
|
||||
+ [f" {t.__repr__()}" for t in self.transforms]
|
||||
+ [")"]
|
||||
return super()._from_config_dict(
|
||||
cls,
|
||||
"feature",
|
||||
get_audio_feature_transform,
|
||||
CompositeAudioFeatureTransform,
|
||||
config,
|
||||
)
|
||||
return "\n".join(format_string)
|
||||
|
@ -10,7 +10,7 @@ import re
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -21,6 +21,12 @@ from fairseq.data import data_utils as fairseq_data_utils
|
||||
from fairseq.data.audio.audio_utils import get_features_or_waveform
|
||||
from fairseq.data.audio.data_cfg import S2TDataConfig
|
||||
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
|
||||
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
|
||||
from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform
|
||||
from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment
|
||||
from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
|
||||
NoisyOverlapAugment,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -46,6 +52,12 @@ def _collate_frames(
|
||||
return out
|
||||
|
||||
|
||||
def _is_int_or_np_int(n):
|
||||
return isinstance(n, int) or (
|
||||
isinstance(n, np.generic) and isinstance(n.item(), int)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeechToTextDatasetItem(object):
|
||||
index: int
|
||||
@ -102,6 +114,20 @@ class SpeechToTextDataset(FairseqDataset):
|
||||
self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
|
||||
self.cfg.get_feature_transforms(split, is_train_split)
|
||||
)
|
||||
self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict(
|
||||
self.cfg.get_waveform_transforms(split, is_train_split)
|
||||
)
|
||||
# TODO: add these to data_cfg.py
|
||||
self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict(
|
||||
self.cfg.get_dataset_transforms(split, is_train_split)
|
||||
)
|
||||
|
||||
# check proper usage of transforms
|
||||
if self.feature_transforms and self.cfg.use_audio_input:
|
||||
logger.warning(
|
||||
"Feature transforms will not be applied. To use feature transforms, "
|
||||
"set use_audio_input as False in config."
|
||||
)
|
||||
|
||||
self.pre_tokenizer = pre_tokenizer
|
||||
self.bpe_tokenizer = bpe_tokenizer
|
||||
@ -136,8 +162,11 @@ class SpeechToTextDataset(FairseqDataset):
|
||||
self.__class__.__name__
|
||||
+ f'(split="{self.split}", n_samples={self.n_samples:_}, '
|
||||
f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, "
|
||||
f"shuffle={self.shuffle}, transforms={self.feature_transforms}, "
|
||||
f"n_frames_per_step={self.n_frames_per_step}"
|
||||
f"n_frames_per_step={self.n_frames_per_step}, "
|
||||
f"shuffle={self.shuffle}, "
|
||||
f"feature_transforms={self.feature_transforms}, "
|
||||
f"waveform_transforms={self.waveform_transforms}, "
|
||||
f"dataset_transforms={self.dataset_transforms})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -157,8 +186,13 @@ class SpeechToTextDataset(FairseqDataset):
|
||||
def tokenize(cls, tokenizer, text: str):
|
||||
return text if tokenizer is None else tokenizer.encode(text)
|
||||
|
||||
def get_tokenized_tgt_text(self, index: int):
|
||||
text = self.tokenize(self.pre_tokenizer, self.tgt_texts[index])
|
||||
def get_tokenized_tgt_text(self, index: Union[int, List[int]]):
|
||||
if _is_int_or_np_int(index):
|
||||
text = self.tgt_texts[index]
|
||||
else:
|
||||
text = " ".join([self.tgt_texts[i] for i in index])
|
||||
|
||||
text = self.tokenize(self.pre_tokenizer, text)
|
||||
text = self.tokenize(self.bpe_tokenizer, text)
|
||||
return text
|
||||
|
||||
@ -175,12 +209,37 @@ class SpeechToTextDataset(FairseqDataset):
|
||||
assert lang_tag_idx != dictionary.unk()
|
||||
return lang_tag_idx
|
||||
|
||||
def _get_source_audio(self, index: int) -> torch.Tensor:
|
||||
source = get_features_or_waveform(
|
||||
self.audio_paths[index],
|
||||
need_waveform=self.cfg.use_audio_input,
|
||||
use_sample_rate=self.cfg.use_sample_rate,
|
||||
)
|
||||
def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor:
|
||||
"""
|
||||
Gives source audio for given index with any relevant transforms
|
||||
applied. For ConcatAug, source audios for given indices are
|
||||
concatenated in given order.
|
||||
Args:
|
||||
index (int or List[int]): index—or in the case of ConcatAug,
|
||||
indices—to pull the source audio for
|
||||
Returns:
|
||||
source audios concatenated for given indices with
|
||||
relevant transforms appplied
|
||||
"""
|
||||
if _is_int_or_np_int(index):
|
||||
source = get_features_or_waveform(
|
||||
self.audio_paths[index],
|
||||
need_waveform=self.cfg.use_audio_input,
|
||||
use_sample_rate=self.cfg.use_sample_rate,
|
||||
waveform_transforms=self.waveform_transforms,
|
||||
)
|
||||
else:
|
||||
source = np.concatenate(
|
||||
[
|
||||
get_features_or_waveform(
|
||||
self.audio_paths[i],
|
||||
need_waveform=self.cfg.use_audio_input,
|
||||
use_sample_rate=self.cfg.use_sample_rate,
|
||||
waveform_transforms=self.waveform_transforms,
|
||||
)
|
||||
for i in index
|
||||
]
|
||||
)
|
||||
if self.cfg.use_audio_input:
|
||||
source = torch.from_numpy(source).float()
|
||||
if self.cfg.standardize_audio:
|
||||
@ -193,12 +252,17 @@ class SpeechToTextDataset(FairseqDataset):
|
||||
return source
|
||||
|
||||
def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
|
||||
source = self._get_source_audio(index)
|
||||
has_concat = self.dataset_transforms.has_transform(ConcatAugment)
|
||||
if has_concat:
|
||||
concat = self.dataset_transforms.get_transform(ConcatAugment)
|
||||
indices = concat.find_indices(index, self.n_frames, self.n_samples)
|
||||
|
||||
source = self._get_source_audio(indices if has_concat else index)
|
||||
source = self.pack_frames(source)
|
||||
|
||||
target = None
|
||||
if self.tgt_texts is not None:
|
||||
tokenized = self.get_tokenized_tgt_text(index)
|
||||
tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
|
||||
target = self.tgt_dict.encode_line(
|
||||
tokenized, add_if_not_exist=False, append_eos=self.append_eos
|
||||
).long()
|
||||
@ -231,9 +295,16 @@ class SpeechToTextDataset(FairseqDataset):
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
indices = torch.tensor([x.index for x in samples], dtype=torch.long)
|
||||
frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input)
|
||||
|
||||
sources = [x.source for x in samples]
|
||||
has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
|
||||
if has_NOAug and self.cfg.use_audio_input:
|
||||
NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
|
||||
sources = NOAug(sources)
|
||||
|
||||
frames = _collate_frames(sources, self.cfg.use_audio_input)
|
||||
# sort samples by descending number of frames
|
||||
n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long)
|
||||
n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
|
||||
n_frames, order = n_frames.sort(descending=True)
|
||||
indices = indices.index_select(0, order)
|
||||
frames = frames.index_select(0, order)
|
||||
|
48
fairseq/data/audio/waveform_transforms/__init__.py
Normal file
48
fairseq/data/audio/waveform_transforms/__init__.py
Normal file
@ -0,0 +1,48 @@
|
||||
import os
|
||||
from fairseq.data.audio import (
|
||||
AudioTransform,
|
||||
CompositeAudioTransform,
|
||||
import_transforms,
|
||||
register_audio_transform,
|
||||
)
|
||||
|
||||
|
||||
class AudioWaveformTransform(AudioTransform):
|
||||
pass
|
||||
|
||||
|
||||
AUDIO_WAVEFORM_TRANSFORM_REGISTRY = {}
|
||||
AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES = set()
|
||||
|
||||
|
||||
def get_audio_waveform_transform(name):
|
||||
return AUDIO_WAVEFORM_TRANSFORM_REGISTRY[name]
|
||||
|
||||
|
||||
def register_audio_waveform_transform(name):
|
||||
return register_audio_transform(
|
||||
name,
|
||||
AudioWaveformTransform,
|
||||
AUDIO_WAVEFORM_TRANSFORM_REGISTRY,
|
||||
AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES,
|
||||
)
|
||||
|
||||
|
||||
import_transforms(os.path.dirname(__file__), "waveform")
|
||||
|
||||
|
||||
class CompositeAudioWaveformTransform(CompositeAudioTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
return super()._from_config_dict(
|
||||
cls,
|
||||
"waveform",
|
||||
get_audio_waveform_transform,
|
||||
CompositeAudioWaveformTransform,
|
||||
config,
|
||||
)
|
||||
|
||||
def __call__(self, x, sample_rate):
|
||||
for t in self.transforms:
|
||||
x, sample_rate = t(x, sample_rate)
|
||||
return x, sample_rate
|
201
fairseq/data/audio/waveform_transforms/noiseaugment.py
Normal file
201
fairseq/data/audio/waveform_transforms/noiseaugment.py
Normal file
@ -0,0 +1,201 @@
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from math import ceil
|
||||
|
||||
from fairseq.data.audio import rand_uniform
|
||||
from fairseq.data.audio.waveform_transforms import (
|
||||
AudioWaveformTransform,
|
||||
register_audio_waveform_transform,
|
||||
)
|
||||
|
||||
SNR_MIN = 5.0
|
||||
SNR_MAX = 15.0
|
||||
RATE = 0.25
|
||||
|
||||
NOISE_RATE = 1.0
|
||||
NOISE_LEN_MEAN = 0.2
|
||||
NOISE_LEN_STD = 0.05
|
||||
|
||||
|
||||
class NoiseAugmentTransform(AudioWaveformTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return cls(
|
||||
_config.get("samples_path", None),
|
||||
_config.get("snr_min", SNR_MIN),
|
||||
_config.get("snr_max", SNR_MAX),
|
||||
_config.get("rate", RATE),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
samples_path: str,
|
||||
snr_min: float = SNR_MIN,
|
||||
snr_max: float = SNR_MAX,
|
||||
rate: float = RATE,
|
||||
):
|
||||
# Sanity checks
|
||||
assert (
|
||||
samples_path
|
||||
), "need to provide path to audio samples for noise augmentation"
|
||||
assert snr_max >= snr_min, f"empty signal-to-noise range ({snr_min}, {snr_max})"
|
||||
assert rate >= 0 and rate <= 1, "rate should be a float between 0 to 1"
|
||||
|
||||
self.paths = list(Path(samples_path).glob("**/*.wav")) # load music
|
||||
self.n_samples = len(self.paths)
|
||||
assert self.n_samples > 0, f"no audio files found in {samples_path}"
|
||||
|
||||
self.snr_min = snr_min
|
||||
self.snr_max = snr_max
|
||||
self.rate = rate
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ "("
|
||||
+ ", ".join(
|
||||
[
|
||||
f"n_samples={self.n_samples}",
|
||||
f"snr={self.snr_min}-{self.snr_max}dB",
|
||||
f"rate={self.rate}",
|
||||
]
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
|
||||
def pick_sample(self, goal_shape, always_2d=False, use_sample_rate=None):
|
||||
from fairseq.data.audio.audio_utils import get_waveform
|
||||
|
||||
path = self.paths[np.random.randint(0, self.n_samples)]
|
||||
sample = get_waveform(
|
||||
path, always_2d=always_2d, output_sample_rate=use_sample_rate
|
||||
)[0]
|
||||
|
||||
# Check dimensions match, else silently skip adding noise to sample
|
||||
# NOTE: SHOULD THIS QUIT WITH AN ERROR?
|
||||
is_2d = len(goal_shape) == 2
|
||||
if len(goal_shape) != sample.ndim or (
|
||||
is_2d and goal_shape[0] != sample.shape[0]
|
||||
):
|
||||
return np.zeros(goal_shape)
|
||||
|
||||
# Cut/repeat sample to size
|
||||
len_dim = len(goal_shape) - 1
|
||||
n_repeat = ceil(goal_shape[len_dim] / sample.shape[len_dim])
|
||||
repeated = np.tile(sample, [1, n_repeat] if is_2d else n_repeat)
|
||||
start = np.random.randint(0, repeated.shape[len_dim] - goal_shape[len_dim] + 1)
|
||||
return (
|
||||
repeated[:, start : start + goal_shape[len_dim]]
|
||||
if is_2d
|
||||
else repeated[start : start + goal_shape[len_dim]]
|
||||
)
|
||||
|
||||
def _mix(self, source, noise, snr):
|
||||
get_power = lambda x: np.mean(x**2)
|
||||
if get_power(noise):
|
||||
scl = np.sqrt(
|
||||
get_power(source) / (np.power(10, snr / 10) * get_power(noise))
|
||||
)
|
||||
else:
|
||||
scl = 0
|
||||
return 1 * source + scl * noise
|
||||
|
||||
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
|
||||
return self.pick_sample(goal_shape, always_2d, use_sample_rate)
|
||||
|
||||
def __call__(self, source, sample_rate):
|
||||
if np.random.random() > self.rate:
|
||||
return source, sample_rate
|
||||
|
||||
noise = self._get_noise(
|
||||
source.shape, always_2d=True, use_sample_rate=sample_rate
|
||||
)
|
||||
|
||||
return (
|
||||
self._mix(source, noise, rand_uniform(self.snr_min, self.snr_max)),
|
||||
sample_rate,
|
||||
)
|
||||
|
||||
|
||||
@register_audio_waveform_transform("musicaugment")
|
||||
class MusicAugmentTransform(NoiseAugmentTransform):
|
||||
pass
|
||||
|
||||
|
||||
@register_audio_waveform_transform("backgroundnoiseaugment")
|
||||
class BackgroundNoiseAugmentTransform(NoiseAugmentTransform):
|
||||
pass
|
||||
|
||||
|
||||
@register_audio_waveform_transform("babbleaugment")
|
||||
class BabbleAugmentTransform(NoiseAugmentTransform):
|
||||
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
|
||||
for i in range(np.random.randint(3, 8)):
|
||||
speech = self.pick_sample(goal_shape, always_2d, use_sample_rate)
|
||||
if i == 0:
|
||||
agg_noise = speech
|
||||
else: # SNR scaled by i (how many noise signals already in agg_noise)
|
||||
agg_noise = self._mix(agg_noise, speech, i)
|
||||
return agg_noise
|
||||
|
||||
|
||||
@register_audio_waveform_transform("sporadicnoiseaugment")
|
||||
class SporadicNoiseAugmentTransform(NoiseAugmentTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return cls(
|
||||
_config.get("samples_path", None),
|
||||
_config.get("snr_min", SNR_MIN),
|
||||
_config.get("snr_max", SNR_MAX),
|
||||
_config.get("rate", RATE),
|
||||
_config.get("noise_rate", NOISE_RATE),
|
||||
_config.get("noise_len_mean", NOISE_LEN_MEAN),
|
||||
_config.get("noise_len_std", NOISE_LEN_STD),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
samples_path: str,
|
||||
snr_min: float = SNR_MIN,
|
||||
snr_max: float = SNR_MAX,
|
||||
rate: float = RATE,
|
||||
noise_rate: float = NOISE_RATE, # noises per second
|
||||
noise_len_mean: float = NOISE_LEN_MEAN, # length of noises in seconds
|
||||
noise_len_std: float = NOISE_LEN_STD,
|
||||
):
|
||||
super().__init__(samples_path, snr_min, snr_max, rate)
|
||||
self.noise_rate = noise_rate
|
||||
self.noise_len_mean = noise_len_mean
|
||||
self.noise_len_std = noise_len_std
|
||||
|
||||
def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
|
||||
agg_noise = np.zeros(goal_shape)
|
||||
len_dim = len(goal_shape) - 1
|
||||
is_2d = len(goal_shape) == 2
|
||||
|
||||
n_noises = round(self.noise_rate * goal_shape[len_dim] / use_sample_rate)
|
||||
start_pointers = [
|
||||
round(rand_uniform(0, goal_shape[len_dim])) for _ in range(n_noises)
|
||||
]
|
||||
|
||||
for start_pointer in start_pointers:
|
||||
noise_shape = list(goal_shape)
|
||||
len_seconds = np.random.normal(self.noise_len_mean, self.noise_len_std)
|
||||
noise_shape[len_dim] = round(max(0, len_seconds) * use_sample_rate)
|
||||
end_pointer = start_pointer + noise_shape[len_dim]
|
||||
if end_pointer >= goal_shape[len_dim]:
|
||||
continue
|
||||
|
||||
noise = self.pick_sample(noise_shape, always_2d, use_sample_rate)
|
||||
if is_2d:
|
||||
agg_noise[:, start_pointer:end_pointer] = (
|
||||
agg_noise[:, start_pointer:end_pointer] + noise
|
||||
)
|
||||
else:
|
||||
agg_noise[start_pointer:end_pointer] = (
|
||||
agg_noise[start_pointer:end_pointer] + noise
|
||||
)
|
||||
|
||||
return agg_noise
|
Loading…
Reference in New Issue
Block a user