diff --git a/examples/audio_nlp/nlu/README.md b/examples/audio_nlp/nlu/README.md new file mode 100644 index 000000000..ec23934e6 --- /dev/null +++ b/examples/audio_nlp/nlu/README.md @@ -0,0 +1,51 @@ +# End-to-end NLU + +End-to-end spoken language understanding (SLU) predicts intent directly from audio using a single model. It promises to improve the performance of assistant systems by leveraging acoustic information lost in the intermediate textual representation and preventing cascading errors from Automatic Speech Recognition (ASR). Further, having one unified model has efficiency advantages when deploying assistant systems on-device. + +This page releases the code for reproducing the results in [STOP: A dataset for Spoken Task Oriented Semantic Parsing](TODO) + +The dataset can be downloaded here: [download link](https://dl.fbaipublicfiles.com/stop/stop.tar.gz) + +## Pretrained models end-to-end NLU Models + +| Speech Pretraining | ASR Pretraining | Test EM Accuracy | Tesst EM-Tree Accuracy | Link | +| ----------- | ----------- |----------|----------|----------| +| None | None | 36.54 | 57.01 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-none-none.pt) | +| Wav2Vec | None | 68.05 | 82.53 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-none.pt) | +| HuBERT | None | 68.40 | 82.85 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-none.pt) | +| Wav2Vec | STOP | 68.70 | 82.78 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-stop.pt) | +| HuBERT | STOP | 69.23 | 82.87 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-stop.pt) | +| Wav2Vec | Librispeech | 68.47 | 82.49 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-ls.pt) | +| HuBERT | Librispeech | 68.70 | 82.78 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-ls.pt) | + +## Pretrained models ASR Models +| Speech Pre-training | ASR Dataset | STOP Eval WER | STOP Test WER | dev\_other WER | dev\_clean WER | test\_clean WER | test\_other WER | Link | +| ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | +| HuBERT | Librispeech | 8.47 | 2.99 | 3.25 | 8.06 | 25.68 | 26.19 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-ls.pt) | +| Wav2Vec | Librispeech | 9.215 | 3.204 | 3.334 | 9.006 | 27.257 | 27.588 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-ls.pt) | +| HuBERT | STOP | 46.31 | 31.30 | 31.52 | 47.16 | 4.29 | 4.26 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-stop.pt) | +| Wav2Vec | STOP | 43.103 | 27.833 | 28.479 | 28.479 | 4.679 | 4.667 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-stop.pt) | +| HuBERT | Librispeech + STOP | 9.015 | 3.211 | 3.372 | 8.635 | 5.133 | 5.056 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-ls-stop.pt) | +| Wav2Vec | Librispeech + STOP | 9.549 | 3.537 | 3.625 | 9.514 | 5.59 | 5.562 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-ls-stop.pt) | + +## Creating the fairseq datasets from STOP + +First, create the audio file manifests and label files: + +``` +python examples/audio_nlp/nlu/generate_manifests.py --stop_root $STOP_DOWNLOAD_DIR/stop --output $FAIRSEQ_DATASET_OUTPUT/ +``` + + +Run `./examples/audio_nlp/nlu/create_dict_stop.sh $FAIRSEQ_DATASET_OUTPUT` to generate the fairseq dictionaries. + + +## Training an End-to-end NLU Model + + +Download a wav2vec or hubert model from [link](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert) or [link](https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec) + + +``` +python fairseq_cli/hydra-train --config-dir examples/audio_nlp/nlu/configs/ --config-name nlu_finetuning task.data=$FAIRSEQ_DATA_OUTPUT model.w2v_path=$PRETRAINED_MODEL_PATH +``` diff --git a/examples/audio_nlp/nlu/configs/nlu_finetuning.yaml b/examples/audio_nlp/nlu/configs/nlu_finetuning.yaml new file mode 100644 index 000000000..bb90f45a3 --- /dev/null +++ b/examples/audio_nlp/nlu/configs/nlu_finetuning.yaml @@ -0,0 +1,59 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 10 + tensorboard_logdir: tb + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: em_error + save_interval: 10 + +task: + _name: nlu_finetuning + data: ??? + labels: parse + eval_wer_parse: true + autoregressive: true + +dataset: + num_workers: 6 + max_tokens: 1600000 + skip_invalid_size_inputs_valid_test: true + valid_subset: eval,test + train_subset: train + validate_interval: 10 + +criterion: + _name: label_smoothed_cross_entropy + +optimization: + max_update: 320000 + lr: [0.0001] + sentence_avg: true + update_freq: [1] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_seq2seq + w2v_path: ??? + autoregressive: true + apply_mask: true + mask_prob: 0.5 + mask_channel_prob: 0.5 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 0 diff --git a/examples/audio_nlp/nlu/create_dict_stop.sh b/examples/audio_nlp/nlu/create_dict_stop.sh new file mode 100755 index 000000000..c78203534 --- /dev/null +++ b/examples/audio_nlp/nlu/create_dict_stop.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +### Script handling creation of data binaries +### for model training within fairseq + + +fairseq_root="." + +data_root=$1 +train_prefix="${data_root}/train" +valid_prefix="${data_root}/eval" +test_prefix="${data_root}/test" + +dest_dir="$data_root/" + +#echo "src dict: $src_dict" > "$dest_dir/src_dict.txt" +#echo "trg dict: $tgt_dict" > "$dest_dir/tgt_dict.txt" + + #--tgtdict $tgt_dict \ +PYTHONPATH=$fairseq_root \ + python $fairseq_root/fairseq_cli/preprocess.py \ + --source-lang "parse" \ + --trainpref $train_prefix \ + --validpref $valid_prefix \ + --destdir $dest_dir \ + --only-source \ + --dict-only \ + --workers 60; + +PYTHONPATH=$fairseq_root \ + python $fairseq_root/fairseq_cli/preprocess.py \ + --source-lang "ltr" \ + --trainpref $train_prefix \ + --validpref $valid_prefix \ + --destdir $dest_dir \ + --only-source \ + --dict-only \ + --workers 60; diff --git a/fairseq/models/hubert/hubert_asr.py b/fairseq/models/hubert/hubert_asr.py index 8e06a2e67..6f5dc46f5 100644 --- a/fairseq/models/hubert/hubert_asr.py +++ b/fairseq/models/hubert/hubert_asr.py @@ -3,20 +3,31 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from typing import Any import contextlib +import copy +import math from argparse import Namespace from dataclasses import dataclass, field -from typing import Any - +from typing import Any, Optional +import numpy as np import torch import torch.nn as nn -from omegaconf import II, MISSING +import torch.nn.functional as F +from omegaconf import II, MISSING, open_dict from fairseq import checkpoint_utils, tasks, utils from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import convert_namespace_to_omegaconf -from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model +from fairseq.models import ( + BaseFairseqModel, + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, +) from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES +from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer from fairseq.tasks import FairseqTask @@ -51,6 +62,9 @@ class HubertAsrConfig(FairseqDataclass): "help": "dropout probability after activation in FFN " "inside hubert model" }, ) + encoder_embed_dim: Optional[int] = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) # masking apply_mask: bool = field( @@ -191,13 +205,12 @@ class HubertSeq2SeqConfig(HubertAsrConfig): metadata={"help": "use learned positional embeddings in the decoder"}, ) decoder_normalize_before: bool = field( - default=False, - metadata={"help": "apply layernorm before each decoder block"}, + default=False, metadata={"help": "apply layernorm before each decoder block"} ) no_token_positional_embeddings: bool = field( default=False, metadata={ - "help": "if set, disables positional embeddings " "(outside self attention)" + "help": "if set, disables positional embeddings (outside self attention)" }, ) decoder_dropout: float = field( @@ -206,22 +219,100 @@ class HubertSeq2SeqConfig(HubertAsrConfig): decoder_attention_dropout: float = field( default=0.0, metadata={ - "help": "dropout probability for attention weights " "inside the decoder" + "help": "dropout probability for attention weights inside the decoder" }, ) decoder_activation_dropout: float = field( default=0.0, metadata={ - "help": "dropout probability after activation in FFN " "inside the decoder" + "help": "dropout probability after activation in FFN inside the decoder" }, ) max_target_positions: int = field( default=2048, metadata={"help": "max target positions"} ) share_decoder_input_output_embed: bool = field( - default=False, - metadata={"help": "share decoder input and output embeddings"}, + default=False, metadata={"help": "share decoder input and output embeddings"} ) + autoregressive: bool = II("task.autoregressive") + seq2seq_path: str = field( + default='', + metadata={"help": "reset_dict"}, + ) + reset_dict: bool = field( + default=False, + metadata={"help": "reset_dict"}, + ) + +@register_model("hubert_seq2seq", dataclass=HubertSeq2SeqConfig) +class HubertSeq2SeqModel(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @classmethod + def build_model(cls, cfg: HubertSeq2SeqConfig, task: FairseqTask): + """Build a new model instance.""" + + assert ( + cfg.autoregressive + ), "Please set task.autoregressive=true for seq2seq asr models" + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + emb = Embedding(num_embeddings, embed_dim, padding_idx) + return emb + + decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim) + + encoder = cls.build_encoder(cfg, task) + decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens) + + + model = HubertSeq2SeqModel(encoder, decoder) + + if cfg['seq2seq_path']: + state = checkpoint_utils.load_checkpoint_to_cpu( + cfg.seq2seq_path + ) + state = state['model'] + if cfg['reset_dict']: + del state['decoder.embed_out'] + del state['decoder.embed_tokens.weight'] + model.load_state_dict(state, strict=False) + return model + + @classmethod + def build_encoder(cls, cfg: HubertAsrConfig, task): + return HubertEncoder(cfg, task) + + @classmethod + def build_decoder(cls, cfg: HubertSeq2SeqConfig, tgt_dict, embed_tokens): + return TransformerDecoder(cfg, tgt_dict, embed_tokens) + + def forward(self, **kwargs): + encoder_out = self.encoder(**kwargs) + decoder_out = self.decoder(encoder_out=encoder_out, **kwargs) + return decoder_out + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg=None, + args: Optional[Namespace] = None, + ): + if(model_cfg.reset_dict): + logger.warn("Overriding loading strict state dict!") + del state_dict['decoder.embed_out'] + del state_dict['decoder.embed_tokens.weight'] + return super().load_state_dict(state_dict, False, model_cfg, args) + return super().load_state_dict(state_dict, strict, model_cfg, args) class HubertEncoder(FairseqEncoder): @@ -290,7 +381,7 @@ class HubertEncoder(FairseqEncoder): self.freeze_finetune_updates = cfg.freeze_finetune_updates self.num_updates = 0 - if task.target_dictionary is not None: + if task.target_dictionary is not None and not cfg.autoregressive : self.proj = Linear(d, len(task.target_dictionary)) elif getattr(cfg, "decoder_embed_dim", d) != d: self.proj = Linear(d, cfg.decoder_embed_dim) @@ -339,6 +430,10 @@ class HubertEncoder(FairseqEncoder): encoder_out["encoder_padding_mask"] = encoder_out[ "encoder_padding_mask" ].index_select(0, new_order) + if encoder_out["padding_mask"] is not None: + encoder_out["padding_mask"] = encoder_out[ + "padding_mask" + ].index_select(0, new_order) return encoder_out def max_positions(self): @@ -348,6 +443,219 @@ class HubertEncoder(FairseqEncoder): def upgrade_state_dict_named(self, state_dict, name): return state_dict +class TransformerDecoder(FairseqIncrementalDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, + cfg: HubertSeq2SeqConfig, + dictionary, + embed_tokens, + no_encoder_attn=False, + ): + super().__init__(dictionary) + + self.dropout = cfg.decoder_dropout + self.share_input_output_embed = cfg.share_decoder_input_output_embed + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = cfg.decoder_embed_dim + self.output_embed_dim = cfg.decoder_embed_dim + + self.layerdrop = cfg.decoder_layerdrop + + self.padding_idx = embed_tokens.padding_idx + self.max_target_positions = cfg.max_target_positions + + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + + self.embed_positions = ( + PositionalEmbedding( + cfg.max_target_positions, + embed_dim, + self.padding_idx, + learned=cfg.decoder_learned_pos, + ) + if not cfg.no_token_positional_embeddings + else None + ) + + # TODO: update this when transformer gets converted to dataclass configs + transformer_cfg = copy.deepcopy(cfg) + with open_dict(transformer_cfg): + transformer_cfg.dropout = transformer_cfg.decoder_dropout + transformer_cfg.attention_dropout = ( + transformer_cfg.decoder_attention_dropout + ) + transformer_cfg.activation_dropout = ( + transformer_cfg.decoder_activation_dropout + ) + + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + TransformerDecoderLayer(transformer_cfg, no_encoder_attn) + for _ in range(transformer_cfg.decoder_layers) + ] + ) + + if not self.share_input_output_embed: + self.embed_out = nn.Parameter( + torch.Tensor(len(dictionary), self.output_embed_dim) + ) + nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim**-0.5) + + if transformer_cfg.decoder_normalize_before: + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (Tensor, optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + if (type(prev_output_tokens) == list): + max_len = max((len(x) for x in prev_output_tokens)) + tmp = torch.zeros([len(prev_output_tokens), max_len], device=prev_output_tokens[0].device) + for (i, p) in enumerate(prev_output_tokens): + tmp[i, :len(p)] = p + prev_output_tokens = tmp + prev_output_tokens = prev_output_tokens.long() + x, extra = self.extract_features( + prev_output_tokens, encoder_out, incremental_state + ) + x = self.output_layer(x) + return x, extra + + def extract_features( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + """ + Similar to *forward* but only return features. + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + + # embed positions + positions = ( + self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + attn = None + + inner_states = [x] + + # decoder layers + self_attn_padding_mask = None + if prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) + for layer in self.layers: + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, attn, _ = layer( + x, + encoder_out["encoder_out"] if encoder_out is not None else None, + encoder_out["padding_mask"] if encoder_out is not None else None, + incremental_state, + self_attn_mask=self.buffered_future_mask(x) + if incremental_state is None + else None, + self_attn_padding_mask=self_attn_padding_mask + ) + inner_states.append(x) + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, {"attn": attn, "inner_states": inner_states} + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + # project back to size of vocabulary + if self.share_input_output_embed: + return F.linear(features, self.embed_tokens.weight) + else: + return F.linear(features, self.embed_out) + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return min(self.max_target_positions, self.embed_positions.max_positions) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + def Embedding(num_embeddings, embedding_dim, padding_idx): m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index 08d875a01..8854bda5c 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -626,6 +626,12 @@ class TransformerDecoder(FairseqIncrementalDecoder): - the decoder's output of shape `(batch, tgt_len, vocab)` - a dictionary with any model-specific outputs """ + if (type(prev_output_tokens) == list): + max_len = max((len(x) for x in prev_output_tokens)) + tmp = torch.zeros([len(prev_output_tokens), max_len], device=prev_output_tokens[0].device) + for (i, p) in enumerate(prev_output_tokens): + tmp[i, :len(p)] = p + prev_output_tokens = tmp prev_output_tokens = prev_output_tokens.long() x, extra = self.extract_features( prev_output_tokens, encoder_out, incremental_state diff --git a/fairseq/tasks/nlu_finetuning.py b/fairseq/tasks/nlu_finetuning.py new file mode 100644 index 000000000..2ea39cdb8 --- /dev/null +++ b/fairseq/tasks/nlu_finetuning.py @@ -0,0 +1,481 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import os +import torch +import json + +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Optional, Any + +from fairseq.data import AddTargetDataset, Dictionary, encoders +from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import GenerationConfig +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel + +from . import register_task +from .. import utils +from ..logging import metrics + + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary): + self.dictionary = dictionary + + def __call__(self, label): + return self.dictionary.encode_line( + label, append_eos=False, add_if_not_exist=False + ) + + +def label_len_fn(label): + return len(label.split(" ")) + + +@dataclass +class NLUFinetuningConfig(AudioPretrainingConfig): + # Options for reporting WER metrics during validation. Only applicable to + # Seq2Seq models during fine-tuning + eval_wer: bool = field( + default=False, metadata={"help": "compute WER for Seq2Seq models"} + ) + eval_wer_parse: bool = field( + default=False, metadata={"help": "compute WER for Seq2Seq models"} + ) + eval_wer_config: GenerationConfig = field( + default_factory=lambda: GenerationConfig(), + metadata={"help": "beam search config for evaluating wer during training"}, + ) + eval_wer_tokenizer: Any = field( + default=None, + metadata={"help": "tokenizer config for evaluating wer during training"}, + ) + eval_wer_post_process: str = field( + default="letter", + metadata={ + "help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)" + }, + ) + eval_bleu: bool = field( + default=False, metadata={"help": "evaluation with BLEU scores"} + ) + eval_bleu_detok: Optional[str] = field( + default=None, metadata={ + "help": "detokenize before computing BLEU (e.g., 'moses'); " + "required if using --eval-bleu; use 'space' to disable " + "detokenization; see fairseq.data.encoders for other options" + } + ) + eval_bleu_detok_args: str = field( + default="{}", + metadata={"help": "args for building the tokenizer, if needed"} + ) + eval_tokenized_bleu: bool = field( + default=False, + metadata={"help": "compute tokenized BLEU instead of sacrebleu"} + ) + eval_bleu_remove_bpe: Optional[str] = field( + default=None, metadata={"help": "remove BPE before computing BLEU"} + ) + eval_bleu_args: str = field( + default="{}", + metadata={"help": "generation args for BLUE scoring, e.g., " + "'{\"beam\": 4, \"lenpen\": 0.6}'"} + ) + eval_bleu_print_samples: bool = field( + default=False, + metadata={"help": "print sample generations during validation"} + ) + autoregressive: bool = field( + default=False, + metadata={ + "help": "required for autoregressive decoders (like seq2seq models); " + "adds 'prev_output_tokens' to input and appends eos to target" + }, + ) + + +@register_task("nlu_finetuning", dataclass=NLUFinetuningConfig) +class NLUFinetuningTask(AudioPretrainingTask): + """ """ + + cfg: NLUFinetuningConfig + + def __init__( + self, + cfg: NLUFinetuningConfig, + ): + super().__init__(cfg) + self.blank_symbol = "" + + self.state.add_factory("target_dictionary", self.load_target_dictionary) + + def load_target_dictionary(self): + if self.cfg.labels: + dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt") + return Dictionary.load(dict_path) + return None + + def load_dataset(self, split: str, task_cfg: NLUFinetuningConfig = None, **kwargs): + super().load_dataset(split, task_cfg, **kwargs) + + task_cfg = task_cfg or self.cfg + assert task_cfg.labels is not None + text_compression_level = getattr( + TextCompressionLevel, str(self.cfg.text_compression_level) + ) + data_path = self.cfg.data + label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") + skipped_indices = getattr(self.datasets[split], "skipped_indices", set()) + text_compressor = TextCompressor(level=text_compression_level) + with open(label_path, "r") as f: + labels = [ + text_compressor.compress(l) + for i, l in enumerate(f) if i not in skipped_indices + ] + + assert len(labels) == len(self.datasets[split]), ( + f"labels length ({len(labels)}) and dataset length " + f"({len(self.datasets[split])}) do not match" + ) + + process_label = LabelEncoder(self.target_dictionary) + + self.datasets[split] = AddTargetDataset( + self.datasets[split], + labels, + pad=self.target_dictionary.pad(), + eos=self.target_dictionary.eos(), + batch_targets=True, + process_label=process_label, + label_len_fn=label_len_fn, + add_to_input=task_cfg.get("autoregressive", False), + text_compression_level=text_compression_level + ) + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.state.target_dictionary + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + if self.cfg.eval_wer_parse and self.cfg.autoregressive: + metrics = self._inference_with_wer_parse(self.sequence_generator, sample, model) + logging_output["_num_char_errors"] = metrics["num_char_errors"] + logging_output["_num_chars"] = metrics["num_chars"] + logging_output["_num_word_errors"] = metrics["num_word_errors"] + logging_output["_num_words"] = metrics["num_words"] + logging_output["_num_em_errors"] = metrics["num_em_errors"] + logging_output["_num_ems"] = metrics["num_ems"] + logging_output["_num_tree_errors"] = metrics["num_tree_errors"] + logging_output["_num_trees"] = metrics["num_trees"] + if self.cfg.eval_wer and self.cfg.autoregressive: + metrics = self._inference_with_wer(self.sequence_generator, sample, model) + logging_output["_num_char_errors"] = metrics["num_char_errors"] + logging_output["_num_chars"] = metrics["num_chars"] + logging_output["_num_word_errors"] = metrics["num_word_errors"] + logging_output["_num_words"] = metrics["num_words"] + if self.cfg.eval_bleu and self.cfg.autoregressive: + metrics = self._inference_with_bleu(self.sequence_generator, sample, model) + logging_output['_bleu_sys_len'] = metrics.sys_len + logging_output['_bleu_ref_len'] = metrics.ref_len + # we split counts into separate entries so that they can be + # summed efficiently across workers using fast-stat-sync + assert len(metrics.counts) == 4 + for i in range(4): + logging_output[f"_bleu_counts_{i}"] = metrics.counts[i] + logging_output[f"_bleu_totals_{i}"] = metrics.totals[i] + return loss, sample_size, logging_output + + def build_model(self, model_cfg: FairseqDataclass): + model = super().build_model(model_cfg) + + if (self.cfg.eval_wer or self.cfg.eval_wer_parse) and self.cfg.autoregressive: + self.sequence_generator = self.build_generator( + [model], + self.cfg.eval_wer_config, + ) + if self.cfg.eval_wer_tokenizer: + self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer) + else: + self.tokenizer = None + if self.cfg.eval_bleu and self.cfg.autoregressive: + assert self.cfg.eval_bleu_detok is not None, ( + '--eval-bleu-detok is required if using --eval-bleu; ' + 'try --eval-bleu-detok=moses (or --eval-bleu-detok=space ' + 'to disable detokenization, e.g., when using sentencepiece)' + ) + detok_args = json.loads(self.cfg.eval_bleu_detok_args) + self.tokenizer = encoders.build_tokenizer( + Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args) + ) + gen_args = json.loads(self.cfg.eval_bleu_args) + gen_args = Namespace(**gen_args) + self.sequence_generator = self.build_generator([model], gen_args) + + return model + + def _inference_with_wer_parse(self, generator, sample, model): + import editdistance + + def decode(toks): + s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_wer_post_process, + escape_unk=True, + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + def decode_to_list(toks): + def token_string(i): + if i == self.target_dictionary.unk(): + return self.target_dictionary.unk_string(False) + else: + return self.target_dictionary[i] + return [ token_string(i) for i in toks ] + + def is_ont_token(token): + return '[' in token or ']' in token + + def post_process(l): + o = [] + for w in l: + if(w == self.target_dictionary.eos_word or w == '|'): + continue + if(w == '_'): + o.append(' ') + else: + o.append(w) + if(is_ont_token(w)): + o.append(' ') + return o + + num_word_errors, num_char_errors = 0, 0 + num_chars, num_words = 0, 0 + num_em_errors, num_ems = 0, 0 + num_tree_errors, num_trees = 0, 0 + gen_out = self.inference_step(generator, [model], sample, None) + for i in range(len(gen_out)): + hyp_tokens = gen_out[i][0]["tokens"] + # hyp = decode(hyp_tokens) + ref_tokens = utils.strip_pad(sample["target"][i], self.target_dictionary.pad()) + # ref = decode(ref_tokens) + hyp_list = decode_to_list(hyp_tokens) + ref_list = decode_to_list(ref_tokens) + + hyp_list = post_process(hyp_list) + ref_list = post_process(ref_list) + + hyp = ''.join(hyp_list).strip() + ref = ''.join(ref_list).strip() + num_chars += len(ref) + num_char_errors += editdistance.eval(hyp, ref) + hyp_words = hyp.split() + ref_words = ref.split() + hyp_tree = [ word for word in hyp_list if ('[' in word or ']' in word) ] + ref_tree = [ word for word in ref_list if ('[' in word or ']' in word) ] + # num_word_errors += editdistance.eval(hyp_words, ref_words) + hyp_before = decode(hyp_tokens).split() + ref_before = decode(ref_tokens).split() + + num_word_errors += editdistance.eval(hyp_before, ref_before) + num_words += len(ref_before) + if(hyp != ref): + num_em_errors += 1 + if(hyp_tree != ref_tree): + num_tree_errors += 1 + num_ems += 1 + num_trees += 1 + + return { + "num_char_errors": num_char_errors, + "num_chars": num_chars, + "num_word_errors": num_word_errors, + "num_words": num_words, + "num_ems": num_ems, + "num_em_errors": num_em_errors, + "num_trees": num_trees, + "num_tree_errors": num_tree_errors + } + + def _inference_with_wer(self, generator, sample, model): + import editdistance + + def decode(toks): + s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_wer_post_process, + escape_unk=True, + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + num_word_errors, num_char_errors = 0, 0 + num_chars, num_words = 0, 0 + gen_out = self.inference_step(generator, [model], sample, None) + for i in range(len(gen_out)): + hyp = decode(gen_out[i][0]["tokens"]) + ref = decode( + utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), + ) + num_char_errors += editdistance.eval(hyp, ref) + num_chars += len(ref) + hyp_words = hyp.split() + ref_words = ref.split() + num_word_errors += editdistance.eval(hyp_words, ref_words) + num_words += len(ref_words) + + return { + "num_char_errors": num_char_errors, + "num_chars": num_chars, + "num_word_errors": num_word_errors, + "num_words": num_words, + } + + def _inference_with_bleu(self, generator, sample, model): + import sacrebleu + + def decode(toks, is_ref): + s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_bleu_remove_bpe, + # The default unknown string in fairseq is ``, but + # this is tokenized by sacrebleu as `< unk >`, inflating + # BLEU scores. Instead, we use a somewhat more verbose + # alternative that is unlikely to appear in the real + # reference, but doesn't get split into multiple tokens. + unk_string=( + "UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP" + ), + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + gen_out = self.inference_step(generator, [model], sample) + hyps, refs = [], [] + for i in range(len(gen_out)): + hyps.append(decode(gen_out[i][0]['tokens'], is_ref=False)) + refs.append( + decode( + utils.strip_pad( + sample['target'][i], + self.target_dictionary.pad() + ), + is_ref=True, # don't count as matches to the hypo + ) + ) + if self.cfg.eval_bleu_print_samples: + logger.info('H-{} {}'.format(sample["id"][0], hyps[0])) + logger.info('T-{} {}'.format(sample["id"][0], refs[0])) + + eval_tokenization = 'none' if self.cfg.eval_tokenized_bleu else '13a' + return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization) + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + if self.cfg.eval_wer or self.cfg.eval_wer_parse: + zero = torch.scalar_tensor(0.0) + num_char_errors = sum( + log.get("_num_char_errors", zero) for log in logging_outputs + ) + num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs) + num_word_errors = sum( + log.get("_num_word_errors", zero) for log in logging_outputs + ) + num_words = sum(log.get("_num_words", zero) for log in logging_outputs) + metrics.log_scalar("_num_char_errors", num_char_errors) + metrics.log_scalar("_num_chars", num_chars) + metrics.log_scalar("_num_word_errors", num_word_errors) + metrics.log_scalar("_num_words", num_words) + if num_chars > 0: + metrics.log_derived( + "uer", + lambda meters: meters["_num_char_errors"].sum + * 100.0 + / meters["_num_chars"].sum + if meters["_num_chars"].sum > 0 + else float("nan"), + ) + if num_words > 0: + metrics.log_derived( + "wer", + lambda meters: meters["_num_word_errors"].sum + * 100.0 + / meters["_num_words"].sum + if meters["_num_words"].sum > 0 + else float("nan"), + ) + if self.cfg.eval_wer_parse: + num_em_errors = sum( + log.get("_num_em_errors", zero) for log in logging_outputs + ) + num_ems = sum( + log.get("_num_ems", zero) for log in logging_outputs + ) + metrics.log_scalar("_num_em_errors", num_em_errors) + metrics.log_scalar("_num_ems", num_ems) + num_tree_errors = sum( + log.get("_num_tree_errors", zero) for log in logging_outputs + ) + num_trees = sum( + log.get("_num_trees", zero) for log in logging_outputs + ) + metrics.log_scalar("_num_tree_errors", num_tree_errors) + metrics.log_scalar("_num_trees", num_trees) + + if num_ems > 0: + metrics.log_derived( + "em_error", + lambda meters: meters["_num_em_errors"].sum + * 100.0 + / meters["_num_ems"].sum + if meters["_num_ems"].sum > 0 + else float("nan"), + ) + if num_trees > 0: + metrics.log_derived( + "tree_error", + lambda meters: meters["_num_tree_errors"].sum + * 100.0 + / meters["_num_trees"].sum + if meters["_num_trees"].sum > 0 + else float("nan"), + ) + + if self.cfg.eval_bleu: + len_keys = ["_bleu_sys_len", "_bleu_ref_len"] + count_keys = [f"_bleu_counts_{i}" for i in range(4)] + total_keys = [f"_bleu_totals_{i}" for i in range(4)] + for k in len_keys + count_keys + total_keys: + metrics.log_scalar( + k, sum(log.get(k, 0) for log in logging_outputs) + ) + + import sacrebleu + metrics.log_derived( + 'bleu', + lambda meters: sacrebleu.compute_bleu( + correct=[meters[k].sum for k in count_keys], + total=[meters[k].sum for k in total_keys], + sys_len=meters['_bleu_sys_len'].sum, + ref_len=meters['_bleu_ref_len'].sum, + smooth_method="exp" + ).score + )