STOP Dataset Release and Experiment Reproduction (#4578)

* STOP paper release 

Co-authored-by: Paden Tomasello <padentomasello@devfair0417.h2.fair>
Co-authored-by: Paden Tomasello <padentomasello@learnfair5258.h2.fair>
padentomasello 2022-07-18 15:47:30 -07:00 committed by GitHub
parent cba35cdbca
commit c82b7c54df
6 changed files with 955 additions and 12 deletions


@ -0,0 +1,51 @@
# End-to-end NLU
End-to-end spoken language understanding (SLU) predicts intent directly from audio using a single model. It promises to improve the performance of assistant systems by leveraging acoustic information that is lost in an intermediate textual representation and by preventing cascading errors from Automatic Speech Recognition (ASR). Further, having one unified model has efficiency advantages when deploying assistant systems on-device.
This page releases the code for reproducing the results in [STOP: A dataset for Spoken Task Oriented Semantic Parsing](TODO).
The dataset can be downloaded here: [download link](https://dl.fbaipublicfiles.com/stop/stop.tar.gz)
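If you want a quick way to fetch and unpack it, here is a minimal sketch (assuming the archive expands to a top-level `stop/` directory, which is what the commands below expect):
```
export STOP_DOWNLOAD_DIR=~/data/stop_download
mkdir -p $STOP_DOWNLOAD_DIR
wget -P $STOP_DOWNLOAD_DIR https://dl.fbaipublicfiles.com/stop/stop.tar.gz
tar -xzf $STOP_DOWNLOAD_DIR/stop.tar.gz -C $STOP_DOWNLOAD_DIR   # creates $STOP_DOWNLOAD_DIR/stop
```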
## Pretrained End-to-end NLU Models
| Speech Pretraining | ASR Pretraining | Test EM Accuracy | Test EM-Tree Accuracy | Link |
| ----------- | ----------- |----------|----------|----------|
| None | None | 36.54 | 57.01 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-none-none.pt) |
| Wav2Vec | None | 68.05 | 82.53 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-none.pt) |
| HuBERT | None | 68.40 | 82.85 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-none.pt) |
| Wav2Vec | STOP | 68.70 | 82.78 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-stop.pt) |
| HuBERT | STOP | 69.23 | 82.87 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-stop.pt) |
| Wav2Vec | Librispeech | 68.47 | 82.49 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-ls.pt) |
| HuBERT | Librispeech | 68.70 | 82.78 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-ls.pt) |
## Pretrained ASR Models
| Speech Pretraining | ASR Dataset | STOP Eval WER | STOP Test WER | dev\_other WER | dev\_clean WER | test\_clean WER | test\_other WER | Link |
| ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- |
| HuBERT | Librispeech | 8.47 | 2.99 | 3.25 | 8.06 | 25.68 | 26.19 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-ls.pt) |
| Wav2Vec | Librispeech | 9.215 | 3.204 | 3.334 | 9.006 | 27.257 | 27.588 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-ls.pt) |
| HuBERT | STOP | 46.31 | 31.30 | 31.52 | 47.16 | 4.29 | 4.26 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-stop.pt) |
| Wav2Vec | STOP | 43.103 | 27.833 | 28.479 | 28.479 | 4.679 | 4.667 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-stop.pt) |
| HuBERT | Librispeech + STOP | 9.015 | 3.211 | 3.372 | 8.635 | 5.133 | 5.056 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-ls-stop.pt) |
| Wav2Vec | Librispeech + STOP | 9.549 | 3.537 | 3.625 | 9.514 | 5.59 | 5.562 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-ls-stop.pt) |
## Creating the fairseq datasets from STOP
First, create the audio file manifests and label files:
```
python examples/audio_nlp/nlu/generate_manifests.py --stop_root $STOP_DOWNLOAD_DIR/stop --output $FAIRSEQ_DATASET_OUTPUT/
```
Run `./examples/audio_nlp/nlu/create_dict_stop.sh $FAIRSEQ_DATASET_OUTPUT` to generate the fairseq dictionaries.
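The two label types used here are `parse` (the semantic-parse decoder targets) and `ltr` (letter targets used for ASR pretraining), so after both steps `$FAIRSEQ_DATASET_OUTPUT` should contain roughly the following files. This layout is inferred from `generate_manifests.py`, `create_dict_stop.sh`, and the `dict.{labels}.txt` convention of the fine-tuning task, so treat it as a sketch rather than a guarantee:
```
$FAIRSEQ_DATASET_OUTPUT/
  train.tsv    eval.tsv    test.tsv     # audio manifests, one per split
  train.parse  eval.parse  test.parse   # semantic-parse targets
  train.ltr    eval.ltr    test.ltr     # letter targets
  dict.parse.txt  dict.ltr.txt          # dictionaries from create_dict_stop.sh
```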
## Training an End-to-end NLU Model
Download a pretrained HuBERT or wav2vec 2.0 model from the [HuBERT](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert) or [wav2vec 2.0](https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec) example pages.
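For example, a minimal sketch of grabbing the HuBERT base checkpoint (the exact file name below is an assumption; use whichever checkpoint the linked pages provide):
```
export PRETRAINED_MODEL_PATH=./hubert_base_ls960.pt
wget -O $PRETRAINED_MODEL_PATH https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt
```
Then launch fine-tuning with the `nlu_finetuning` config: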
```
python fairseq_cli/hydra_train.py --config-dir examples/audio_nlp/nlu/configs/ --config-name nlu_finetuning task.data=$FAIRSEQ_DATA_OUTPUT model.w2v_path=$PRETRAINED_MODEL_PATH
```
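Any key in the `nlu_finetuning` config can be overridden on the command line in the same `section.key=value` form. A sketch with a few common knobs (the override values are illustrative, not the paper's settings):
```
python fairseq_cli/hydra_train.py --config-dir examples/audio_nlp/nlu/configs/ \
    --config-name nlu_finetuning \
    task.data=$FAIRSEQ_DATA_OUTPUT model.w2v_path=$PRETRAINED_MODEL_PATH \
    optimization.max_update=100000 \
    'optimization.lr=[0.00005]' \
    dataset.max_tokens=800000
```
Note that the config validates on the `eval` and `test` splits and selects checkpoints by `em_error`, so keep those keys in mind if you change the validation setup.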


@ -0,0 +1,59 @@
# @package _group_
common:
  fp16: true
  log_format: json
  log_interval: 10
  tensorboard_logdir: tb

checkpoint:
  no_epoch_checkpoints: true
  best_checkpoint_metric: em_error
  save_interval: 10

task:
  _name: nlu_finetuning
  data: ???
  labels: parse
  eval_wer_parse: true
  autoregressive: true

dataset:
  num_workers: 6
  max_tokens: 1600000
  skip_invalid_size_inputs_valid_test: true
  valid_subset: eval,test
  train_subset: train
  validate_interval: 10

criterion:
  _name: label_smoothed_cross_entropy

optimization:
  max_update: 320000
  lr: [0.0001]
  sentence_avg: true
  update_freq: [1]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-08

lr_scheduler:
  _name: tri_stage
  phase_ratio: [0.1, 0.4, 0.5]
  final_lr_scale: 0.05

model:
  _name: wav2vec_seq2seq
  w2v_path: ???
  autoregressive: true
  apply_mask: true
  mask_prob: 0.5
  mask_channel_prob: 0.5
  mask_channel_length: 64
  layerdrop: 0.1
  activation_dropout: 0.1
  feature_grad_mult: 0.0
  freeze_finetune_updates: 0


@ -0,0 +1,38 @@
#!/bin/bash
### Script handling creation of data binaries
### for model training within fairseq
fairseq_root="."
data_root=$1
train_prefix="${data_root}/train"
valid_prefix="${data_root}/eval"
test_prefix="${data_root}/test"
dest_dir="$data_root/"
#echo "src dict: $src_dict" > "$dest_dir/src_dict.txt"
#echo "trg dict: $tgt_dict" > "$dest_dir/tgt_dict.txt"
#--tgtdict $tgt_dict \
PYTHONPATH=$fairseq_root \
python $fairseq_root/fairseq_cli/preprocess.py \
--source-lang "parse" \
--trainpref $train_prefix \
--validpref $valid_prefix \
--destdir $dest_dir \
--only-source \
--dict-only \
--workers 60;
PYTHONPATH=$fairseq_root \
python $fairseq_root/fairseq_cli/preprocess.py \
--source-lang "ltr" \
--trainpref $train_prefix \
--validpref $valid_prefix \
--destdir $dest_dir \
--only-source \
--dict-only \
--workers 60;


@ -3,20 +3,31 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import copy
import math
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Any, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import II, MISSING, open_dict

from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import (
    BaseFairseqModel,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
)
from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES
from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer
from fairseq.tasks import FairseqTask
@ -51,6 +62,9 @@ class HubertAsrConfig(FairseqDataclass):
"help": "dropout probability after activation in FFN " "inside hubert model"
},
)
encoder_embed_dim: Optional[int] = field(
default=768, metadata={"help": "encoder embedding dimension"}
)
# masking
apply_mask: bool = field(
@ -191,13 +205,12 @@ class HubertSeq2SeqConfig(HubertAsrConfig):
metadata={"help": "use learned positional embeddings in the decoder"},
)
decoder_normalize_before: bool = field(
default=False,
metadata={"help": "apply layernorm before each decoder block"},
default=False, metadata={"help": "apply layernorm before each decoder block"}
)
no_token_positional_embeddings: bool = field(
default=False,
metadata={
"help": "if set, disables positional embeddings " "(outside self attention)"
"help": "if set, disables positional embeddings (outside self attention)"
},
)
decoder_dropout: float = field(
@ -206,22 +219,100 @@ class HubertSeq2SeqConfig(HubertAsrConfig):
    decoder_attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights inside the decoder"
        },
    )
    decoder_activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN inside the decoder"
        },
    )
    max_target_positions: int = field(
        default=2048, metadata={"help": "max target positions"}
    )
    share_decoder_input_output_embed: bool = field(
        default=False, metadata={"help": "share decoder input and output embeddings"}
    )
    autoregressive: bool = II("task.autoregressive")
    seq2seq_path: str = field(
        default="",
        metadata={"help": "path to a pretrained seq2seq checkpoint to load"},
    )
    reset_dict: bool = field(
        default=False,
        metadata={"help": "drop decoder embedding weights when loading the checkpoint"},
    )
@register_model("hubert_seq2seq", dataclass=HubertSeq2SeqConfig)
class HubertSeq2SeqModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@classmethod
def build_model(cls, cfg: HubertSeq2SeqConfig, task: FairseqTask):
"""Build a new model instance."""
assert (
cfg.autoregressive
), "Please set task.autoregressive=true for seq2seq asr models"
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
def build_embedding(dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
emb = Embedding(num_embeddings, embed_dim, padding_idx)
return emb
decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim)
encoder = cls.build_encoder(cfg, task)
decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)
model = HubertSeq2SeqModel(encoder, decoder)
if cfg['seq2seq_path']:
state = checkpoint_utils.load_checkpoint_to_cpu(
cfg.seq2seq_path
)
state = state['model']
if cfg['reset_dict']:
del state['decoder.embed_out']
del state['decoder.embed_tokens.weight']
model.load_state_dict(state, strict=False)
return model
@classmethod
def build_encoder(cls, cfg: HubertAsrConfig, task):
return HubertEncoder(cfg, task)
@classmethod
def build_decoder(cls, cfg: HubertSeq2SeqConfig, tgt_dict, embed_tokens):
return TransformerDecoder(cfg, tgt_dict, embed_tokens)
def forward(self, **kwargs):
encoder_out = self.encoder(**kwargs)
decoder_out = self.decoder(encoder_out=encoder_out, **kwargs)
return decoder_out
def upgrade_state_dict_named(self, state_dict, name):
return state_dict
def load_state_dict(
self,
state_dict,
strict=True,
model_cfg=None,
args: Optional[Namespace] = None,
):
if(model_cfg.reset_dict):
logger.warn("Overriding loading strict state dict!")
del state_dict['decoder.embed_out']
del state_dict['decoder.embed_tokens.weight']
return super().load_state_dict(state_dict, False, model_cfg, args)
return super().load_state_dict(state_dict, strict, model_cfg, args)
class HubertEncoder(FairseqEncoder):
@ -290,7 +381,7 @@ class HubertEncoder(FairseqEncoder):
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        if task.target_dictionary is not None and not cfg.autoregressive:
            self.proj = Linear(d, len(task.target_dictionary))
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, cfg.decoder_embed_dim)
@ -339,6 +430,10 @@ class HubertEncoder(FairseqEncoder):
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(0, new_order)
if encoder_out["padding_mask"] is not None:
encoder_out["padding_mask"] = encoder_out[
"padding_mask"
].index_select(0, new_order)
return encoder_out
def max_positions(self):
@ -348,6 +443,219 @@ class HubertEncoder(FairseqEncoder):
    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict


class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        cfg: HubertSeq2SeqConfig,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
    ):
        super().__init__(dictionary)

        self.dropout = cfg.decoder_dropout
        self.share_input_output_embed = cfg.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = cfg.decoder_embed_dim
        self.output_embed_dim = cfg.decoder_embed_dim

        self.layerdrop = cfg.decoder_layerdrop

        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = cfg.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                cfg.max_target_positions,
                embed_dim,
                self.padding_idx,
                learned=cfg.decoder_learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )

        # TODO: update this when transformer gets converted to dataclass configs
        transformer_cfg = copy.deepcopy(cfg)
        with open_dict(transformer_cfg):
            transformer_cfg.dropout = transformer_cfg.decoder_dropout
            transformer_cfg.attention_dropout = (
                transformer_cfg.decoder_attention_dropout
            )
            transformer_cfg.activation_dropout = (
                transformer_cfg.decoder_activation_dropout
            )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerDecoderLayer(transformer_cfg, no_encoder_attn)
                for _ in range(transformer_cfg.decoder_layers)
            ]
        )

        if not self.share_input_output_embed:
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim**-0.5)

        if transformer_cfg.decoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        if isinstance(prev_output_tokens, list):
            max_len = max(len(x) for x in prev_output_tokens)
            tmp = torch.zeros(
                [len(prev_output_tokens), max_len], device=prev_output_tokens[0].device
            )
            for i, p in enumerate(prev_output_tokens):
                tmp[i, : len(p)] = p
            prev_output_tokens = tmp
        prev_output_tokens = prev_output_tokens.long()
        x, extra = self.extract_features(
            prev_output_tokens, encoder_out, incremental_state
        )
        x = self.output_layer(x)
        return x, extra

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """

        # embed positions
        positions = (
            self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None

        inner_states = [x]

        # decoder layers
        self_attn_padding_mask = None
        if prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
        for layer in self.layers:
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, attn, _ = layer(
                    x,
                    encoder_out["encoder_out"] if encoder_out is not None else None,
                    encoder_out["padding_mask"] if encoder_out is not None else None,
                    incremental_state,
                    self_attn_mask=self.buffered_future_mask(x)
                    if incremental_state is None
                    else None,
                    self_attn_padding_mask=self_attn_padding_mask,
                )
                inner_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, {"attn": attn, "inner_states": inner_states}

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        # project back to size of vocabulary
        if self.share_input_output_embed:
            return F.linear(features, self.embed_tokens.weight)
        else:
            return F.linear(features, self.embed_out)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)


@ -626,6 +626,12 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        if isinstance(prev_output_tokens, list):
            max_len = max(len(x) for x in prev_output_tokens)
            tmp = torch.zeros(
                [len(prev_output_tokens), max_len], device=prev_output_tokens[0].device
            )
            for i, p in enumerate(prev_output_tokens):
                tmp[i, : len(p)] = p
            prev_output_tokens = tmp
        prev_output_tokens = prev_output_tokens.long()
        x, extra = self.extract_features(
            prev_output_tokens, encoder_out, incremental_state


@ -0,0 +1,481 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import os
import torch
import json
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional, Any
from fairseq.data import AddTargetDataset, Dictionary, encoders
from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import GenerationConfig
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
from . import register_task
from .. import utils
from ..logging import metrics
logger = logging.getLogger(__name__)
class LabelEncoder(object):
    def __init__(self, dictionary):
        self.dictionary = dictionary

    def __call__(self, label):
        return self.dictionary.encode_line(
            label, append_eos=False, add_if_not_exist=False
        )


def label_len_fn(label):
    return len(label.split(" "))
@dataclass
class NLUFinetuningConfig(AudioPretrainingConfig):
    # Options for reporting WER metrics during validation. Only applicable to
    # Seq2Seq models during fine-tuning
    eval_wer: bool = field(
        default=False, metadata={"help": "compute WER for Seq2Seq models"}
    )
    eval_wer_parse: bool = field(
        default=False,
        metadata={"help": "compute semantic-parse metrics (EM, tree EM, WER) for Seq2Seq models"},
    )
    eval_wer_config: GenerationConfig = field(
        default_factory=lambda: GenerationConfig(),
        metadata={"help": "beam search config for evaluating wer during training"},
    )
    eval_wer_tokenizer: Any = field(
        default=None,
        metadata={"help": "tokenizer config for evaluating wer during training"},
    )
    eval_wer_post_process: str = field(
        default="letter",
        metadata={
            "help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)"
        },
    )
    eval_bleu: bool = field(
        default=False, metadata={"help": "evaluation with BLEU scores"}
    )
    eval_bleu_detok: Optional[str] = field(
        default=None,
        metadata={
            "help": "detokenize before computing BLEU (e.g., 'moses'); "
            "required if using --eval-bleu; use 'space' to disable "
            "detokenization; see fairseq.data.encoders for other options"
        },
    )
    eval_bleu_detok_args: str = field(
        default="{}", metadata={"help": "args for building the tokenizer, if needed"}
    )
    eval_tokenized_bleu: bool = field(
        default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
    )
    eval_bleu_remove_bpe: Optional[str] = field(
        default=None, metadata={"help": "remove BPE before computing BLEU"}
    )
    eval_bleu_args: str = field(
        default="{}",
        metadata={
            "help": "generation args for BLEU scoring, e.g., "
            "'{\"beam\": 4, \"lenpen\": 0.6}'"
        },
    )
    eval_bleu_print_samples: bool = field(
        default=False, metadata={"help": "print sample generations during validation"}
    )
    autoregressive: bool = field(
        default=False,
        metadata={
            "help": "required for autoregressive decoders (like seq2seq models); "
            "adds 'prev_output_tokens' to input and appends eos to target"
        },
    )
@register_task("nlu_finetuning", dataclass=NLUFinetuningConfig)
class NLUFinetuningTask(AudioPretrainingTask):
""" """
cfg: NLUFinetuningConfig
def __init__(
self,
cfg: NLUFinetuningConfig,
):
super().__init__(cfg)
self.blank_symbol = "<s>"
self.state.add_factory("target_dictionary", self.load_target_dictionary)
def load_target_dictionary(self):
if self.cfg.labels:
dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt")
return Dictionary.load(dict_path)
return None
def load_dataset(self, split: str, task_cfg: NLUFinetuningConfig = None, **kwargs):
super().load_dataset(split, task_cfg, **kwargs)
task_cfg = task_cfg or self.cfg
assert task_cfg.labels is not None
text_compression_level = getattr(
TextCompressionLevel, str(self.cfg.text_compression_level)
)
data_path = self.cfg.data
label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
text_compressor = TextCompressor(level=text_compression_level)
with open(label_path, "r") as f:
labels = [
text_compressor.compress(l)
for i, l in enumerate(f) if i not in skipped_indices
]
assert len(labels) == len(self.datasets[split]), (
f"labels length ({len(labels)}) and dataset length "
f"({len(self.datasets[split])}) do not match"
)
process_label = LabelEncoder(self.target_dictionary)
self.datasets[split] = AddTargetDataset(
self.datasets[split],
labels,
pad=self.target_dictionary.pad(),
eos=self.target_dictionary.eos(),
batch_targets=True,
process_label=process_label,
label_len_fn=label_len_fn,
add_to_input=task_cfg.get("autoregressive", False),
text_compression_level=text_compression_level
)
@property
def target_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return self.state.target_dictionary
    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(
            sample, model, criterion
        )
        if self.cfg.eval_wer_parse and self.cfg.autoregressive:
            metrics = self._inference_with_wer_parse(
                self.sequence_generator, sample, model
            )
            logging_output["_num_char_errors"] = metrics["num_char_errors"]
            logging_output["_num_chars"] = metrics["num_chars"]
            logging_output["_num_word_errors"] = metrics["num_word_errors"]
            logging_output["_num_words"] = metrics["num_words"]
            logging_output["_num_em_errors"] = metrics["num_em_errors"]
            logging_output["_num_ems"] = metrics["num_ems"]
            logging_output["_num_tree_errors"] = metrics["num_tree_errors"]
            logging_output["_num_trees"] = metrics["num_trees"]
        if self.cfg.eval_wer and self.cfg.autoregressive:
            metrics = self._inference_with_wer(self.sequence_generator, sample, model)
            logging_output["_num_char_errors"] = metrics["num_char_errors"]
            logging_output["_num_chars"] = metrics["num_chars"]
            logging_output["_num_word_errors"] = metrics["num_word_errors"]
            logging_output["_num_words"] = metrics["num_words"]
        if self.cfg.eval_bleu and self.cfg.autoregressive:
            metrics = self._inference_with_bleu(self.sequence_generator, sample, model)
            logging_output["_bleu_sys_len"] = metrics.sys_len
            logging_output["_bleu_ref_len"] = metrics.ref_len
            # we split counts into separate entries so that they can be
            # summed efficiently across workers using fast-stat-sync
            assert len(metrics.counts) == 4
            for i in range(4):
                logging_output[f"_bleu_counts_{i}"] = metrics.counts[i]
                logging_output[f"_bleu_totals_{i}"] = metrics.totals[i]
        return loss, sample_size, logging_output

    def build_model(self, model_cfg: FairseqDataclass):
        model = super().build_model(model_cfg)

        if (self.cfg.eval_wer or self.cfg.eval_wer_parse) and self.cfg.autoregressive:
            self.sequence_generator = self.build_generator(
                [model],
                self.cfg.eval_wer_config,
            )
            if self.cfg.eval_wer_tokenizer:
                self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer)
            else:
                self.tokenizer = None
        if self.cfg.eval_bleu and self.cfg.autoregressive:
            assert self.cfg.eval_bleu_detok is not None, (
                "--eval-bleu-detok is required if using --eval-bleu; "
                "try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
                "to disable detokenization, e.g., when using sentencepiece)"
            )
            detok_args = json.loads(self.cfg.eval_bleu_detok_args)
            self.tokenizer = encoders.build_tokenizer(
                Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args)
            )
            gen_args = json.loads(self.cfg.eval_bleu_args)
            gen_args = Namespace(**gen_args)
            self.sequence_generator = self.build_generator([model], gen_args)

        return model
    def _inference_with_wer_parse(self, generator, sample, model):
        import editdistance

        def decode(toks):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.cfg.eval_wer_post_process,
                escape_unk=True,
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        def decode_to_list(toks):
            def token_string(i):
                if i == self.target_dictionary.unk():
                    return self.target_dictionary.unk_string(False)
                else:
                    return self.target_dictionary[i]

            return [token_string(i) for i in toks]

        def is_ont_token(token):
            return "[" in token or "]" in token

        def post_process(l):
            o = []
            for w in l:
                if w == self.target_dictionary.eos_word or w == "|":
                    continue
                if w == "_":
                    o.append(" ")
                else:
                    o.append(w)
                if is_ont_token(w):
                    o.append(" ")
            return o

        num_word_errors, num_char_errors = 0, 0
        num_chars, num_words = 0, 0
        num_em_errors, num_ems = 0, 0
        num_tree_errors, num_trees = 0, 0
        gen_out = self.inference_step(generator, [model], sample, None)
        for i in range(len(gen_out)):
            hyp_tokens = gen_out[i][0]["tokens"]
            # hyp = decode(hyp_tokens)
            ref_tokens = utils.strip_pad(
                sample["target"][i], self.target_dictionary.pad()
            )
            # ref = decode(ref_tokens)
            hyp_list = decode_to_list(hyp_tokens)
            ref_list = decode_to_list(ref_tokens)

            hyp_list = post_process(hyp_list)
            ref_list = post_process(ref_list)

            hyp = "".join(hyp_list).strip()
            ref = "".join(ref_list).strip()
            num_chars += len(ref)
            num_char_errors += editdistance.eval(hyp, ref)
            hyp_words = hyp.split()
            ref_words = ref.split()
            hyp_tree = [word for word in hyp_list if ("[" in word or "]" in word)]
            ref_tree = [word for word in ref_list if ("[" in word or "]" in word)]
            # num_word_errors += editdistance.eval(hyp_words, ref_words)
            hyp_before = decode(hyp_tokens).split()
            ref_before = decode(ref_tokens).split()

            num_word_errors += editdistance.eval(hyp_before, ref_before)
            num_words += len(ref_before)
            if hyp != ref:
                num_em_errors += 1
            if hyp_tree != ref_tree:
                num_tree_errors += 1
            num_ems += 1
            num_trees += 1

        return {
            "num_char_errors": num_char_errors,
            "num_chars": num_chars,
            "num_word_errors": num_word_errors,
            "num_words": num_words,
            "num_ems": num_ems,
            "num_em_errors": num_em_errors,
            "num_trees": num_trees,
            "num_tree_errors": num_tree_errors,
        }
    def _inference_with_wer(self, generator, sample, model):
        import editdistance

        def decode(toks):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.cfg.eval_wer_post_process,
                escape_unk=True,
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        num_word_errors, num_char_errors = 0, 0
        num_chars, num_words = 0, 0
        gen_out = self.inference_step(generator, [model], sample, None)
        for i in range(len(gen_out)):
            hyp = decode(gen_out[i][0]["tokens"])
            ref = decode(
                utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
            )
            num_char_errors += editdistance.eval(hyp, ref)
            num_chars += len(ref)
            hyp_words = hyp.split()
            ref_words = ref.split()
            num_word_errors += editdistance.eval(hyp_words, ref_words)
            num_words += len(ref_words)

        return {
            "num_char_errors": num_char_errors,
            "num_chars": num_chars,
            "num_word_errors": num_word_errors,
            "num_words": num_words,
        }
    def _inference_with_bleu(self, generator, sample, model):
        import sacrebleu

        def decode(toks, is_ref):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.cfg.eval_bleu_remove_bpe,
                # The default unknown string in fairseq is `<unk>`, but
                # this is tokenized by sacrebleu as `< unk >`, inflating
                # BLEU scores. Instead, we use a somewhat more verbose
                # alternative that is unlikely to appear in the real
                # reference, but doesn't get split into multiple tokens.
                unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"),
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample)
        hyps, refs = [], []
        for i in range(len(gen_out)):
            hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False))
            refs.append(
                decode(
                    utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
                    is_ref=True,  # don't count <unk> as matches to the hypo
                )
            )
        if self.cfg.eval_bleu_print_samples:
            logger.info("H-{} {}".format(sample["id"][0], hyps[0]))
            logger.info("T-{} {}".format(sample["id"][0], refs[0]))

        eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a"
        return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization)
    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

        if self.cfg.eval_wer or self.cfg.eval_wer_parse:
            zero = torch.scalar_tensor(0.0)
            num_char_errors = sum(
                log.get("_num_char_errors", zero) for log in logging_outputs
            )
            num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs)
            num_word_errors = sum(
                log.get("_num_word_errors", zero) for log in logging_outputs
            )
            num_words = sum(log.get("_num_words", zero) for log in logging_outputs)
            metrics.log_scalar("_num_char_errors", num_char_errors)
            metrics.log_scalar("_num_chars", num_chars)
            metrics.log_scalar("_num_word_errors", num_word_errors)
            metrics.log_scalar("_num_words", num_words)
            if num_chars > 0:
                metrics.log_derived(
                    "uer",
                    lambda meters: meters["_num_char_errors"].sum
                    * 100.0
                    / meters["_num_chars"].sum
                    if meters["_num_chars"].sum > 0
                    else float("nan"),
                )
            if num_words > 0:
                metrics.log_derived(
                    "wer",
                    lambda meters: meters["_num_word_errors"].sum
                    * 100.0
                    / meters["_num_words"].sum
                    if meters["_num_words"].sum > 0
                    else float("nan"),
                )
            if self.cfg.eval_wer_parse:
                num_em_errors = sum(
                    log.get("_num_em_errors", zero) for log in logging_outputs
                )
                num_ems = sum(log.get("_num_ems", zero) for log in logging_outputs)
                metrics.log_scalar("_num_em_errors", num_em_errors)
                metrics.log_scalar("_num_ems", num_ems)
                num_tree_errors = sum(
                    log.get("_num_tree_errors", zero) for log in logging_outputs
                )
                num_trees = sum(log.get("_num_trees", zero) for log in logging_outputs)
                metrics.log_scalar("_num_tree_errors", num_tree_errors)
                metrics.log_scalar("_num_trees", num_trees)
                if num_ems > 0:
                    metrics.log_derived(
                        "em_error",
                        lambda meters: meters["_num_em_errors"].sum
                        * 100.0
                        / meters["_num_ems"].sum
                        if meters["_num_ems"].sum > 0
                        else float("nan"),
                    )
                if num_trees > 0:
                    metrics.log_derived(
                        "tree_error",
                        lambda meters: meters["_num_tree_errors"].sum
                        * 100.0
                        / meters["_num_trees"].sum
                        if meters["_num_trees"].sum > 0
                        else float("nan"),
                    )

        if self.cfg.eval_bleu:
            len_keys = ["_bleu_sys_len", "_bleu_ref_len"]
            count_keys = [f"_bleu_counts_{i}" for i in range(4)]
            total_keys = [f"_bleu_totals_{i}" for i in range(4)]
            for k in len_keys + count_keys + total_keys:
                metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs))

            import sacrebleu

            metrics.log_derived(
                "bleu",
                lambda meters: sacrebleu.compute_bleu(
                    correct=[meters[k].sum for k in count_keys],
                    total=[meters[k].sum for k in total_keys],
                    sys_len=meters["_bleu_sys_len"].sum,
                    ref_len=meters["_bleu_ref_len"].sum,
                    smooth_method="exp",
                ).score,
            )