Misc fixes (#2492)

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2492

Reviewed By: ngoyal2707

Differential Revision: D23177728

Pulled By: myleott

fbshipit-source-id: 32424f61cab57f759f87e16e8d5144d3eed5ae36
Myle Ott 2020-08-20 06:40:45 -07:00 committed by Facebook GitHub Bot
parent 54b934417d
commit adbd89fd4b
20 changed files with 121 additions and 100 deletions

View File

@@ -47,6 +47,9 @@ hypothesis along with an average log-likelihood; and *P* is the
positional score per token position, including the
end-of-sentence marker which is omitted from the text.
Other types of output lines you might see are *D*, the detokenized hypothesis,
*T*, the reference target, *A*, alignment info, and *E*, the history of generation steps.
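For example, the output for a single translated sentence might look like
this (with purely illustrative scores)::

    H-0    -0.2146    Hello world !
    P-0    -0.3004 -0.1951 -0.1288 -0.2341
    D-0    -0.2146    Hello world!
    T-0    Hello world !
    A-0    0-0 1-1 2-2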
See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
full list of pre-trained models available.

View File

@@ -25,16 +25,23 @@ Given a directory containing wav files to be used for pretraining (we recommend
### Prepare training data manifest:
First, install the `soundfile` library:
```shell script
pip install soundfile
```
Next, run:
```shell script
$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid
```
$ext should be set to flac, wav, or whatever format your dataset happens to use that soundfile can read.
$valid should be set to some reasonable percentage (like 0.01) of training data to use for validation.
To use a pre-defined validation set (like dev-other from Librispeech), set it to 0 and then overwrite valid.tsv with a
separately pre-processed manifest file.
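For reference, here is a minimal sketch of building such a manifest by hand, assuming the layout `wav2vec_manifest.py` emits (root directory on the first line, then one tab-separated relative-path/frame-count row per file; verify against the script in your checkout):
```python
import os
import soundfile

root = '/path/to/waves'
with open('/manifest/path/train.tsv', 'w') as out:
    print(root, file=out)
    for fname in sorted(os.listdir(root)):
        if fname.endswith('.flac'):
            # frame count read from the audio header, not the raw file size
            frames = soundfile.info(os.path.join(root, fname)).frames
            print('{}\t{}'.format(fname, frames), file=out)
```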
### Train a wav2vec 2.0 base model:
This configuration was used for the base model trained on the Librispeech dataset in the wav2vec 2.0 paper
@@ -43,7 +50,7 @@ Note that this was tested with pytorch 1.4.0 and the input is expected to be sin
```shell script
$ python train.py --distributed-world-size 64 --distributed-port $PORT /manifest/path \
--save-dir /model/path fp16 --num-workers 6 --task audio_pretraining --criterion wav2vec --arch wav2vec2 \
--save-dir /model/path --fp16 --num-workers 6 --task audio_pretraining --criterion wav2vec --arch wav2vec2 \
--log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \
--conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --final-dim 256 --latent-vars 320 \
--latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce --optimizer adam \

View File

@@ -20,6 +20,7 @@ import fairseq.modules # noqa
import fairseq.optim # noqa
import fairseq.optim.lr_scheduler # noqa
import fairseq.pdb # noqa
import fairseq.scoring # noqa
import fairseq.tasks # noqa
import fairseq.benchmark # noqa

View File

@@ -253,11 +253,11 @@ torch::Tensor GenerateDeletionLabelCuda(
AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] {
generate_deletion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
source.data<scalar_t>(),
source.data_ptr<scalar_t>(),
source.size(1),
operations.size(1),
operations.data<int>(),
labels.data<int>());
operations.data_ptr<int>(),
labels.data_ptr<int>());
}));
return labels;
@@ -276,12 +276,12 @@ auto stream = at::cuda::getCurrentCUDAStream(target.device().index());
AT_DISPATCH_ALL_TYPES(target.scalar_type(), "generate_insertion_labels", ([&] {
generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
target.data<scalar_t>(),
target.data_ptr<scalar_t>(),
target.size(1),
operations.size(1),
operations.data<int>(),
labels.data<int>(),
masks.data<int>());
operations.data_ptr<int>(),
labels.data_ptr<int>(),
masks.data_ptr<int>());
}));
return std::make_pair(labels, masks);
@@ -306,25 +306,25 @@ torch::Tensor LevenshteinDistanceCuda(
auto distances = torch::empty({batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options);
AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] {
levenshtein_distance_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
source.data<scalar_t>(),
target.data<scalar_t>(),
source_length.data<int>(),
target_length.data<int>(),
source.data_ptr<scalar_t>(),
target.data_ptr<scalar_t>(),
source_length.data_ptr<int>(),
target_length.data_ptr<int>(),
source.size(1),
target.size(1),
operations.data<int>(),
distances.data<int>());
operations.data_ptr<int>(),
distances.data_ptr<int>());
}));
} else {
AT_DISPATCH_ALL_TYPES(source.scalar_type(), "faster_levenshtein_distance", ([&] {
faster_levenshtein_distance_kernel<scalar_t><<<batch_size, 1, shared_size, stream>>>(
source.data<scalar_t>(),
target.data<scalar_t>(),
source_length.data<int>(),
target_length.data<int>(),
source.data_ptr<scalar_t>(),
target.data_ptr<scalar_t>(),
source_length.data_ptr<int>(),
target_length.data_ptr<int>(),
source.size(1),
target.size(1),
operations.data<int>());
operations.data_ptr<int>());
}));
}

View File

@@ -8,7 +8,7 @@ import math
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
@@ -127,7 +127,7 @@ class LegacyMaskedLmLoss(FairseqCriterion):
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
lm_loss_sum = sum(log.get('lm_loss', 0) for log in logging_outputs)
sentence_loss_sum = sum(
@@ -137,16 +137,10 @@ class LegacyMaskedLmLoss(FairseqCriterion):
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_loss = sum(log.get('loss', 0) for log in logging_outputs)
agg_output = {
'loss': agg_loss / sample_size / math.log(2) if sample_size > 0 else 0.,
'lm_loss': lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.,
'sentence_loss': sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0.,
'nll_loss': lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.,
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
return agg_output
metrics.log_scalar('loss', agg_loss / sample_size / math.log(2) if sample_size > 0 else 0., sample_size, round=3)
metrics.log_scalar('lm_loss', lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0., ntokens, round=3)
metrics.log_scalar('sentence_loss', sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0., nsentences, round=3)
metrics.log_scalar('nll_loss', lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0., ntokens, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
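For orientation, a minimal standalone sketch of the `metrics.log_scalar` pattern this criterion moves to, with a hypothetical criterion that logs only `loss` (the third argument is a weight that lets fairseq re-average correctly across workers):
```python
import math
from fairseq import metrics

def reduce_metrics(logging_outputs) -> None:
    # sum the raw values across all workers' logging outputs, then
    # log a single weighted scalar (converted to a base-2 log loss)
    loss = sum(log.get('loss', 0) for log in logging_outputs)
    sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
    if sample_size > 0:
        metrics.log_scalar('loss', loss / sample_size / math.log(2), sample_size, round=3)
```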

View File

@@ -9,11 +9,12 @@ import numpy as np
cimport cython
cimport numpy as np
DTYPE = np.int64
ctypedef np.int64_t DTYPE_t
from libc.stdint cimport int32_t, int64_t
ctypedef int64_t DTYPE_t
cdef _is_batch_full(long num_sentences, long num_tokens, long max_tokens, long max_sentences):
cdef _is_batch_full(int64_t num_sentences, int64_t num_tokens, int64_t max_tokens, int64_t max_sentences):
if num_sentences == 0:
return 0
if max_sentences > 0 and num_sentences == max_sentences:
@@ -27,18 +28,18 @@ cdef _is_batch_full(long num_sentences, long num_tokens, long max_tokens, long m
cpdef list batch_by_size_fast(
np.ndarray[DTYPE_t, ndim=1] indices,
num_tokens_fn,
long max_tokens,
long max_sentences,
int bsz_mult,
int64_t max_tokens,
int64_t max_sentences,
int32_t bsz_mult,
):
cdef long sample_len = 0
cdef int64_t sample_len = 0
cdef list sample_lens = []
cdef list batch = []
cdef list batches = []
cdef long mod_len
cdef long i
cdef long idx
cdef long num_tokens
cdef int64_t mod_len
cdef int64_t i
cdef int64_t idx
cdef int64_t num_tokens
cdef DTYPE_t[:] indices_view = indices
for i in range(len(indices_view)):
@@ -70,8 +71,8 @@ cpdef list batch_by_size_fast(
cdef _find_valid_shape(
DTYPE_t[:, :] shapes_view,
long num_sentences,
long num_tokens,
int64_t num_sentences,
int64_t num_tokens,
):
"""Return index of first valid shape of -1 if none is found."""
for i in range(shapes_view.shape[0]):
@@ -86,14 +87,14 @@ cpdef list batch_fixed_shapes_fast(
num_tokens_fn,
np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
cdef long sample_len = 0
cdef int64_t sample_len = 0
cdef list sample_lens = []
cdef list batch = []
cdef list batches = []
cdef long mod_len
cdef long i
cdef long idx
cdef long num_tokens
cdef int64_t mod_len
cdef int64_t i
cdef int64_t idx
cdef int64_t num_tokens
cdef DTYPE_t[:] indices_view = indices
cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted

View File

@@ -12,8 +12,10 @@ from libc.math cimport ceil
cimport cython
cimport numpy as np
from libc.stdint cimport int32_t, int64_t
DTYPE = np.int64
ctypedef np.int64_t DTYPE_t
ctypedef int64_t DTYPE_t
@cython.boundscheck(False)

View File

@@ -266,7 +266,7 @@ class IterativeRefinementGenerator(object):
if decoder_out.history is not None
else None,
)
encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze())
encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero(as_tuple=False).squeeze())
sent_idxs = sent_idxs[not_terminated]
prev_output_tokens = prev_decoder_out.output_tokens.clone()
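The `as_tuple=False` changes here and in later hunks pin down the return type of `nonzero`; a small sketch of the behavior, assuming a PyTorch recent enough to accept the flag:
```python
import torch

mask = torch.tensor([True, False, True])
# as_tuple=False returns a 2-D index tensor (one row per nonzero entry);
# passing it explicitly silences the deprecation warning for bare nonzero()
idx = mask.nonzero(as_tuple=False).squeeze(-1)  # tensor([0, 2])
```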

View File

@@ -66,14 +66,14 @@ class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLay
need_weights=False,
attn_mask=self_attn_mask,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.dropout_module(x)
x = residual + x
residual = x
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.dropout_module(x)
x = residual + x
return x, None
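A minimal sketch of the module-based dropout pattern adopted above, with plain `nn.Dropout` standing in for fairseq's dropout module:
```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim=8, dropout=0.1):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        # rate configured once at construction; respects self.training
        self.dropout_module = nn.Dropout(p=dropout)

    def forward(self, x):
        return self.dropout_module(self.fc(x))
```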

View File

@@ -250,7 +250,7 @@ def _skip_encoder_out(encoder, encoder_out, mask):
if not mask.any():
return encoder_out
else:
return encoder.reorder_encoder_out(encoder_out, mask.nonzero().squeeze())
return encoder.reorder_encoder_out(encoder_out, mask.nonzero(as_tuple=False).squeeze())
def _fill(x, mask, y, padding_idx):

View File

@@ -144,7 +144,7 @@ class AdaptiveSoftmax(nn.Module):
new_target[0][mask] = self.cutoff[0] + i
if mask.any():
target_idxs.append(mask.nonzero().squeeze(1))
target_idxs.append(mask.nonzero(as_tuple=False).squeeze(1))
new_target.append(target[mask].add(-self.cutoff[i]))
else:
target_idxs.append(None)

View File

@@ -11,15 +11,6 @@ import torch
from fairseq.scoring import register_scoring
try:
from fairseq import libbleu
except ImportError as e:
sys.stderr.write("ERROR: missing libbleu.so. run `pip install --editable .`\n")
raise e
C = ctypes.cdll.LoadLibrary(libbleu.__file__)
class BleuStat(ctypes.Structure):
_fields_ = [
@@ -70,13 +61,22 @@ class Scorer(object):
self.pad = pad
self.eos = eos
self.unk = unk
try:
from fairseq import libbleu
except ImportError as e:
sys.stderr.write("ERROR: missing libbleu.so. run `pip install --editable .`\n")
raise e
self.C = ctypes.cdll.LoadLibrary(libbleu.__file__)
self.reset()
def reset(self, one_init=False):
if one_init:
C.bleu_one_init(ctypes.byref(self.stat))
self.C.bleu_one_init(ctypes.byref(self.stat))
else:
C.bleu_zero_init(ctypes.byref(self.stat))
self.C.bleu_zero_init(ctypes.byref(self.stat))
def add(self, ref, pred):
if not isinstance(ref, torch.IntTensor):
@@ -92,7 +92,7 @@ class Scorer(object):
rref = rref.contiguous().view(-1)
pred = pred.contiguous().view(-1)
C.bleu_add(
self.C.bleu_add(
ctypes.byref(self.stat),
ctypes.c_size_t(rref.size(0)),
ctypes.c_void_p(rref.data_ptr()),

View File

@@ -3,8 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import editdistance
from fairseq.scoring import register_scoring
@@ -18,6 +16,7 @@ class WerScorer(object):
self.ref_length = 0
def add_string(self, ref, pred):
import editdistance
ref_items = ref.split()
pred_items = pred.split()
self.distance += editdistance.eval(ref_items, pred_items)
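Both scoring changes in this commit use the same deferred-import idiom; a rough self-contained sketch (hypothetical class, not fairseq's API):
```python
class LazyWerScorer:
    def __init__(self):
        try:
            # importing inside __init__ keeps module import cheap and only
            # surfaces a clear error when a scorer is actually constructed
            import editdistance
        except ImportError as e:
            raise ImportError('please install editdistance') from e
        self._ed = editdistance

    def score(self, ref, pred):
        return self._ed.eval(ref.split(), pred.split())
```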

View File

@@ -133,7 +133,9 @@ class DiverseBeamSearch(Search):
# apply diversity penalty
if g > 0:
lprobs_g = torch.add(
lprobs_g, self.diversity_strength, diversity_buf.unsqueeze(1)
lprobs_g,
other=diversity_buf.unsqueeze(1),
alpha=self.diversity_strength,
)
else:
lprobs_g = lprobs_g.contiguous()
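The keyword form spells out what the old positional overload computed; a quick sketch of the equivalence:
```python
import torch

x, y = torch.ones(3), torch.arange(3.)
# torch.add(input, other, alpha=s) computes input + s * other; the old
# positional form torch.add(input, s, other) is deprecated
assert torch.equal(torch.add(x, y, alpha=2.0), x + 2.0 * y)
```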

View File

@@ -183,7 +183,11 @@ class SequenceGenerator(nn.Module):
src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
elif 'source' in net_input:
src_tokens = net_input['source']
src_lengths = net_input['padding_mask'].size(-1) - net_input['padding_mask'].sum(-1) if net_input['padding_mask'] is not None else torch.tensor(src_tokens.size(-1))
src_lengths = (
net_input['padding_mask'].size(-1) - net_input['padding_mask'].sum(-1)
if net_input['padding_mask'] is not None
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
)
else:
raise Exception('expected src_tokens or source in net input')
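The padding-mask arithmetic above, in isolation (`padding_mask` is True at padded positions):
```python
import torch

padding_mask = torch.tensor([[False, False, True],
                             [False, True,  True]])
# per-sentence source length = total width minus number of padded steps
src_lengths = padding_mask.size(-1) - padding_mask.sum(-1)  # tensor([2, 1])
```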
@@ -372,11 +376,10 @@ class SequenceGenerator(nn.Module):
new_bsz = bsz - len(finalized_sents)
# construct batch_idxs which holds indices of batches to keep for the next pass
batch_mask = torch.ones(bsz).to(cand_indices)
batch_mask[
torch.tensor(finalized_sents).to(cand_indices)
] = torch.tensor(0).to(batch_mask)
batch_idxs = batch_mask.nonzero().squeeze(-1)
batch_mask = torch.ones(bsz, dtype=torch.bool, device=cand_indices.device)
batch_mask[finalized_sents] = False
# TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
batch_idxs = torch.arange(bsz, device=cand_indices.device).masked_select(batch_mask)
eos_mask = eos_mask[batch_idxs]
cand_beams = cand_beams[batch_idxs]
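A compact sketch showing that the boolean-mask rewrite selects the same surviving batch indices as the old float-mask indexing:
```python
import torch

bsz, finalized_sents = 4, [1, 3]
batch_mask = torch.ones(bsz, dtype=torch.bool)
batch_mask[finalized_sents] = False
# masked_select over arange sidesteps nonzero(), which TorchScript did not
# yet support with as_tuple=False at the time of this commit
batch_idxs = torch.arange(bsz).masked_select(batch_mask)  # tensor([0, 2])
```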
@@ -665,7 +668,7 @@ class SequenceGenerator(nn.Module):
for bbsz_idx in range(bsz * beam_size):
lprobs[bbsz_idx][
torch.tensor(banned_tokens[bbsz_idx]).long()
] = torch.tensor(-math.inf, dtype=torch.float)
] = torch.tensor(-math.inf).to(lprobs)
return lprobs

View File

@@ -16,7 +16,7 @@ from fairseq.data import (
NestedDictionaryDataset,
NumelDataset,
NumSamplesDataset,
PadDataset,
RightPadDataset,
PrependTokenDataset,
SortDataset,
TokenBlockDataset,
@@ -150,17 +150,15 @@ class MaskedLMTask(FairseqTask):
{
'id': IdDataset(),
'net_input': {
'src_tokens': PadDataset(
'src_tokens': RightPadDataset(
src_dataset,
pad_idx=self.source_dictionary.pad(),
left_pad=False,
),
'src_lengths': NumelDataset(src_dataset, reduce=False),
},
'target': PadDataset(
'target': RightPadDataset(
tgt_dataset,
pad_idx=self.source_dictionary.pad(),
left_pad=False,
),
'nsentences': NumSamplesDataset(),
'ntokens': NumelDataset(src_dataset, reduce=True),
@@ -174,7 +172,7 @@ class MaskedLMTask(FairseqTask):
)
def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
src_dataset = PadDataset(
src_dataset = RightPadDataset(
TokenBlockDataset(
src_tokens,
src_lengths,
@@ -184,7 +182,6 @@ class MaskedLMTask(FairseqTask):
break_mode='eos',
),
pad_idx=self.source_dictionary.pad(),
left_pad=False,
)
src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos())
src_dataset = NestedDictionaryDataset(

View File

@@ -192,16 +192,20 @@ class SentencePredictionTask(FairseqTask):
else:
label_path = "{0}.label".format(get_path('label', split))
if os.path.exists(label_path):
def parse_regression_target(i, line):
values = line.split()
assert len(values) == self.args.num_classes, \
f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"'
return [float(x) for x in values]
dataset.update(
target=RawLabelDataset([
parse_regression_target(i, line.strip()) for i, line in enumerate(open(label_path).readlines())
])
)
with open(label_path) as h:
dataset.update(
target=RawLabelDataset([
parse_regression_target(i, line.strip())
for i, line in enumerate(h.readlines())
])
)
nested_dataset = NestedDictionaryDataset(
dataset,

View File

@@ -527,8 +527,8 @@ def get_token_to_word_mapping(tokens, exclude_list):
def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos):
tgt_valid = ((tgt_sent != pad) & (tgt_sent != eos)).nonzero().squeeze(dim=-1)
src_invalid = ((src_sent == pad) | (src_sent == eos)).nonzero().squeeze(dim=-1)
tgt_valid = ((tgt_sent != pad) & (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1)
src_invalid = ((src_sent == pad) | (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1)
src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad])
tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad])
alignment = []

View File

@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import unittest
from tests.test_sequence_generator import get_dummy_task_and_parser
@@ -17,6 +18,10 @@ class TestInferenceDropout(unittest.TestCase):
self.args = self.parser.parse_args([])
self.args.encoder_layers = 2
self.args.decoder_layers = 1
logging.disable(logging.CRITICAL)
def tearDown(self):
logging.disable(logging.NOTSET)
def test_sets_inference_dropout_to_true(self):
self.args.retain_dropout = True
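A self-contained sketch of the setUp/tearDown logging pattern both test files adopt here:
```python
import logging
import unittest

class QuietTest(unittest.TestCase):
    def setUp(self):
        # suppress all log records for the duration of each test
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        # restore normal logging so later tests are unaffected
        logging.disable(logging.NOTSET)

    def test_quiet(self):
        logging.getLogger(__name__).error('this is not printed')
```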

View File

@@ -4,8 +4,9 @@
# LICENSE file in the root directory of this source tree.
import contextlib
from io import StringIO
import logging
import unittest
from io import StringIO
from unittest.mock import MagicMock, patch
import torch
@@ -74,6 +75,11 @@ class TestLoadCheckpoint(unittest.TestCase):
}
self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
[p.start() for p in self.applied_patches]
logging.disable(logging.CRITICAL)
def tearDown(self):
patch.stopall()
logging.disable(logging.NOTSET)
def test_load_partial_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
@@ -192,9 +198,6 @@ class TestLoadCheckpoint(unittest.TestCase):
self.assertFalse(reset_lr_scheduler)
self.assertFalse(reset_meters)
def tearDown(self):
patch.stopall()
if __name__ == '__main__':
unittest.main()