From adbd89fd4be9e68100bf9a4ba9eed1e7fb2e4040 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 20 Aug 2020 06:40:45 -0700 Subject: [PATCH] Misc fixes (#2492) Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2492 Reviewed By: ngoyal2707 Differential Revision: D23177728 Pulled By: myleott fbshipit-source-id: 32424f61cab57f759f87e16e8d5144d3eed5ae36 --- docs/getting_started.rst | 3 ++ examples/wav2vec/README.md | 17 ++++++--- fairseq/__init__.py | 1 + fairseq/clib/libnat_cuda/edit_dist.cu | 36 +++++++++--------- fairseq/criterions/legacy_masked_lm.py | 18 +++------ fairseq/data/data_utils_fast.pyx | 37 ++++++++++--------- fairseq/data/token_block_utils_fast.pyx | 4 +- fairseq/iterative_refinement_generator.py | 2 +- .../transformer_sentence_encoder_layer.py | 6 +-- fairseq/models/nat/levenshtein_utils.py | 2 +- fairseq/modules/adaptive_softmax.py | 2 +- fairseq/scoring/bleu.py | 24 ++++++------ fairseq/scoring/wer.py | 3 +- fairseq/search.py | 4 +- fairseq/sequence_generator.py | 17 +++++---- fairseq/tasks/masked_lm.py | 11 ++---- fairseq/tasks/sentence_prediction.py | 14 ++++--- fairseq/utils.py | 4 +- tests/test_inference_dropout.py | 5 +++ tests/test_train.py | 11 ++++-- 20 files changed, 121 insertions(+), 100 deletions(-) diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 416e29531..fa5971dd3 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -47,6 +47,9 @@ hypothesis along with an average log-likelihood; and *P* is the positional score per token position, including the end-of-sentence marker which is omitted from the text. +Other types of output lines you might see are *D*, the detokenized hypothesis, +*T*, the reference target, *A*, alignment info, and *E*, the history of generation steps. + See the `README `__ for a full list of pre-trained models available. 
diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index ca01e181c..2e59798ea 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -25,16 +25,23 @@ Given a directory containing wav files to be used for pretraining (we recommend ### Prepare training data manifest: +First, install the `soundfile` library: +```shell script +pip install soundfile +``` + +Next, run: + +```shell script +$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid +``` + $ext should be set to flac, wav, or whatever format your dataset happens to use that soundfile can read. $valid should be set to some reasonable percentage (like 0.01) of training data to use for validation. To use a pre-defined validation set (like dev-other from librispeech), set to it 0 and then overwrite valid.tsv with a separately pre-processed manifest file. -```shell script -$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid -``` - ### Train a wav2vec 2.0 base model: This configuration was used for the base model trained on the Librispeech dataset in the wav2vec 2.0 paper @@ -43,7 +50,7 @@ Note that this was tested with pytorch 1.4.0 and the input is expected to be sin ```shell script $ python train.py --distributed-world-size 64 --distributed-port $PORT /manifest/path \ ---save-dir /model/path fp16 --num-workers 6 --task audio_pretraining --criterion wav2vec --arch wav2vec2 \ +--save-dir /model/path --fp16 --num-workers 6 --task audio_pretraining --criterion wav2vec --arch wav2vec2 \ --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \ --conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --final-dim 256 --latent-vars 320 \ --latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce --optimizer adam \ diff --git a/fairseq/__init__.py b/fairseq/__init__.py index 
3dd29637a..1ba63fcaa 100644 --- a/fairseq/__init__.py +++ b/fairseq/__init__.py @@ -20,6 +20,7 @@ import fairseq.modules # noqa import fairseq.optim # noqa import fairseq.optim.lr_scheduler # noqa import fairseq.pdb # noqa +import fairseq.scoring # noqa import fairseq.tasks # noqa import fairseq.benchmark # noqa diff --git a/fairseq/clib/libnat_cuda/edit_dist.cu b/fairseq/clib/libnat_cuda/edit_dist.cu index b6486a8c2..22de16b27 100644 --- a/fairseq/clib/libnat_cuda/edit_dist.cu +++ b/fairseq/clib/libnat_cuda/edit_dist.cu @@ -253,11 +253,11 @@ torch::Tensor GenerateDeletionLabelCuda( AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] { generate_deletion_label_kernel<<>>( - source.data(), + source.data_ptr(), source.size(1), operations.size(1), - operations.data(), - labels.data()); + operations.data_ptr(), + labels.data_ptr()); })); return labels; @@ -276,12 +276,12 @@ auto stream = at::cuda::getCurrentCUDAStream(target.device().index()); AT_DISPATCH_ALL_TYPES(target.scalar_type(), "generate_insertion_labels", ([&] { generate_insertion_label_kernel<<>>( - target.data(), + target.data_ptr(), target.size(1), operations.size(1), - operations.data(), - labels.data(), - masks.data()); + operations.data_ptr(), + labels.data_ptr(), + masks.data_ptr()); })); return std::make_pair(labels, masks); @@ -306,25 +306,25 @@ torch::Tensor LevenshteinDistanceCuda( auto distances = torch::empty({batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options); AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] { levenshtein_distance_kernel<<>>( - source.data(), - target.data(), - source_length.data(), - target_length.data(), + source.data_ptr(), + target.data_ptr(), + source_length.data_ptr(), + target_length.data_ptr(), source.size(1), target.size(1), - operations.data(), - distances.data()); + operations.data_ptr(), + distances.data_ptr()); })); } else { AT_DISPATCH_ALL_TYPES(source.scalar_type(), "faster_levenshtein_distance", ([&] 
{ faster_levenshtein_distance_kernel<<>>( - source.data(), - target.data(), - source_length.data(), - target_length.data(), + source.data_ptr(), + target.data_ptr(), + source_length.data_ptr(), + target_length.data_ptr(), source.size(1), target.size(1), - operations.data()); + operations.data_ptr()); })); } diff --git a/fairseq/criterions/legacy_masked_lm.py b/fairseq/criterions/legacy_masked_lm.py index 10dea76e4..3dbfdfbe4 100644 --- a/fairseq/criterions/legacy_masked_lm.py +++ b/fairseq/criterions/legacy_masked_lm.py @@ -8,7 +8,7 @@ import math import torch import torch.nn.functional as F -from fairseq import utils +from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion @@ -127,7 +127,7 @@ class LegacyMaskedLmLoss(FairseqCriterion): return loss, sample_size, logging_output @staticmethod - def aggregate_logging_outputs(logging_outputs): + def reduce_metrics(logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" lm_loss_sum = sum(log.get('lm_loss', 0) for log in logging_outputs) sentence_loss_sum = sum( @@ -137,16 +137,10 @@ class LegacyMaskedLmLoss(FairseqCriterion): sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) agg_loss = sum(log.get('loss', 0) for log in logging_outputs) - agg_output = { - 'loss': agg_loss / sample_size / math.log(2) if sample_size > 0 else 0., - 'lm_loss': lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0., - 'sentence_loss': sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0., - 'nll_loss': lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0., - 'ntokens': ntokens, - 'nsentences': nsentences, - 'sample_size': sample_size, - } - return agg_output + metrics.log_scalar('loss', agg_loss / sample_size / math.log(2) if sample_size > 0 else 0., sample_size, round=3) + metrics.log_scalar('lm_loss', lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0., ntokens, round=3) + metrics.log_scalar('sentence_loss', 
sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0., nsentences, round=3) + metrics.log_scalar('nll_loss', lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0., ntokens, round=3) @staticmethod def logging_outputs_can_be_summed() -> bool: diff --git a/fairseq/data/data_utils_fast.pyx b/fairseq/data/data_utils_fast.pyx index c1f97bf5b..38b4aa67d 100644 --- a/fairseq/data/data_utils_fast.pyx +++ b/fairseq/data/data_utils_fast.pyx @@ -9,11 +9,12 @@ import numpy as np cimport cython cimport numpy as np -DTYPE = np.int64 -ctypedef np.int64_t DTYPE_t +from libc.stdint cimport int32_t, int64_t + +ctypedef int64_t DTYPE_t -cdef _is_batch_full(long num_sentences, long num_tokens, long max_tokens, long max_sentences): +cdef _is_batch_full(int64_t num_sentences, int64_t num_tokens, int64_t max_tokens, int64_t max_sentences): if num_sentences == 0: return 0 if max_sentences > 0 and num_sentences == max_sentences: @@ -27,18 +28,18 @@ cdef _is_batch_full(long num_sentences, long num_tokens, long max_tokens, long m cpdef list batch_by_size_fast( np.ndarray[DTYPE_t, ndim=1] indices, num_tokens_fn, - long max_tokens, - long max_sentences, - int bsz_mult, + int64_t max_tokens, + int64_t max_sentences, + int32_t bsz_mult, ): - cdef long sample_len = 0 + cdef int64_t sample_len = 0 cdef list sample_lens = [] cdef list batch = [] cdef list batches = [] - cdef long mod_len - cdef long i - cdef long idx - cdef long num_tokens + cdef int64_t mod_len + cdef int64_t i + cdef int64_t idx + cdef int64_t num_tokens cdef DTYPE_t[:] indices_view = indices for i in range(len(indices_view)): @@ -70,8 +71,8 @@ cpdef list batch_by_size_fast( cdef _find_valid_shape( DTYPE_t[:, :] shapes_view, - long num_sentences, - long num_tokens, + int64_t num_sentences, + int64_t num_tokens, ): """Return index of first valid shape of -1 if none is found.""" for i in range(shapes_view.shape[0]): @@ -86,14 +87,14 @@ cpdef list batch_fixed_shapes_fast( num_tokens_fn, np.ndarray[DTYPE_t, 
ndim=2] fixed_shapes_sorted, ): - cdef long sample_len = 0 + cdef int64_t sample_len = 0 cdef list sample_lens = [] cdef list batch = [] cdef list batches = [] - cdef long mod_len - cdef long i - cdef long idx - cdef long num_tokens + cdef int64_t mod_len + cdef int64_t i + cdef int64_t idx + cdef int64_t num_tokens cdef DTYPE_t[:] indices_view = indices cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted diff --git a/fairseq/data/token_block_utils_fast.pyx b/fairseq/data/token_block_utils_fast.pyx index 5563b973e..5a2f16ec3 100644 --- a/fairseq/data/token_block_utils_fast.pyx +++ b/fairseq/data/token_block_utils_fast.pyx @@ -12,8 +12,10 @@ from libc.math cimport ceil cimport cython cimport numpy as np +from libc.stdint cimport int32_t, int64_t + DTYPE = np.int64 -ctypedef np.int64_t DTYPE_t +ctypedef int64_t DTYPE_t @cython.boundscheck(False) diff --git a/fairseq/iterative_refinement_generator.py b/fairseq/iterative_refinement_generator.py index c7a267d25..97e66fabe 100644 --- a/fairseq/iterative_refinement_generator.py +++ b/fairseq/iterative_refinement_generator.py @@ -266,7 +266,7 @@ class IterativeRefinementGenerator(object): if decoder_out.history is not None else None, ) - encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze()) + encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()) sent_idxs = sent_idxs[not_terminated] prev_output_tokens = prev_decoder_out.output_tokens.clone() diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py index 0e1ea2b7d..d09158b7f 100644 --- a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py +++ b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py @@ -66,14 +66,14 @@ class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLay need_weights=False, attn_mask=self_attn_mask, ) - x = 
F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x residual = x x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) - x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.activation_dropout_module(x) x = self.fc2(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) x = residual + x return x, None diff --git a/fairseq/models/nat/levenshtein_utils.py b/fairseq/models/nat/levenshtein_utils.py index e29b1fa27..11fb29578 100644 --- a/fairseq/models/nat/levenshtein_utils.py +++ b/fairseq/models/nat/levenshtein_utils.py @@ -250,7 +250,7 @@ def _skip_encoder_out(encoder, encoder_out, mask): if not mask.any(): return encoder_out else: - return encoder.reorder_encoder_out(encoder_out, mask.nonzero().squeeze()) + return encoder.reorder_encoder_out(encoder_out, mask.nonzero(as_tuple=False).squeeze()) def _fill(x, mask, y, padding_idx): diff --git a/fairseq/modules/adaptive_softmax.py b/fairseq/modules/adaptive_softmax.py index 96f8b75ad..8e47134a7 100644 --- a/fairseq/modules/adaptive_softmax.py +++ b/fairseq/modules/adaptive_softmax.py @@ -144,7 +144,7 @@ class AdaptiveSoftmax(nn.Module): new_target[0][mask] = self.cutoff[0] + i if mask.any(): - target_idxs.append(mask.nonzero().squeeze(1)) + target_idxs.append(mask.nonzero(as_tuple=False).squeeze(1)) new_target.append(target[mask].add(-self.cutoff[i])) else: target_idxs.append(None) diff --git a/fairseq/scoring/bleu.py b/fairseq/scoring/bleu.py index 476c0f047..15275d94c 100644 --- a/fairseq/scoring/bleu.py +++ b/fairseq/scoring/bleu.py @@ -11,15 +11,6 @@ import torch from fairseq.scoring import register_scoring -try: - from fairseq import libbleu -except ImportError as e: - sys.stderr.write("ERROR: missing libbleu.so. 
run `pip install --editable .`\n") - raise e - - -C = ctypes.cdll.LoadLibrary(libbleu.__file__) - class BleuStat(ctypes.Structure): _fields_ = [ @@ -70,13 +61,22 @@ class Scorer(object): self.pad = pad self.eos = eos self.unk = unk + + try: + from fairseq import libbleu + except ImportError as e: + sys.stderr.write("ERROR: missing libbleu.so. run `pip install --editable .`\n") + raise e + + self.C = ctypes.cdll.LoadLibrary(libbleu.__file__) + self.reset() def reset(self, one_init=False): if one_init: - C.bleu_one_init(ctypes.byref(self.stat)) + self.C.bleu_one_init(ctypes.byref(self.stat)) else: - C.bleu_zero_init(ctypes.byref(self.stat)) + self.C.bleu_zero_init(ctypes.byref(self.stat)) def add(self, ref, pred): if not isinstance(ref, torch.IntTensor): @@ -92,7 +92,7 @@ class Scorer(object): rref = rref.contiguous().view(-1) pred = pred.contiguous().view(-1) - C.bleu_add( + self.C.bleu_add( ctypes.byref(self.stat), ctypes.c_size_t(rref.size(0)), ctypes.c_void_p(rref.data_ptr()), diff --git a/fairseq/scoring/wer.py b/fairseq/scoring/wer.py index 4e09e4561..3aee5f69d 100644 --- a/fairseq/scoring/wer.py +++ b/fairseq/scoring/wer.py @@ -3,8 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import editdistance - from fairseq.scoring import register_scoring @@ -18,6 +16,7 @@ class WerScorer(object): self.ref_length = 0 def add_string(self, ref, pred): + import editdistance ref_items = ref.split() pred_items = pred.split() self.distance += editdistance.eval(ref_items, pred_items) diff --git a/fairseq/search.py b/fairseq/search.py index 9e18581a9..8aa196a3c 100644 --- a/fairseq/search.py +++ b/fairseq/search.py @@ -133,7 +133,9 @@ class DiverseBeamSearch(Search): # apply diversity penalty if g > 0: lprobs_g = torch.add( - lprobs_g, self.diversity_strength, diversity_buf.unsqueeze(1) + lprobs_g, + other=diversity_buf.unsqueeze(1), + alpha=self.diversity_strength, ) else: lprobs_g = lprobs_g.contiguous() diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 42012fbbb..26e4c287b 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -183,7 +183,11 @@ class SequenceGenerator(nn.Module): src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) elif 'source' in net_input: src_tokens = net_input['source'] - src_lengths = net_input['padding_mask'].size(-1) - net_input['padding_mask'].sum(-1) if net_input['padding_mask'] is not None else torch.tensor(src_tokens.size(-1)) + src_lengths = ( + net_input['padding_mask'].size(-1) - net_input['padding_mask'].sum(-1) + if net_input['padding_mask'] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) else: raise Exception('expected src_tokens or source in net input') @@ -372,11 +376,10 @@ class SequenceGenerator(nn.Module): new_bsz = bsz - len(finalized_sents) # construct batch_idxs which holds indices of batches to keep for the next pass - batch_mask = torch.ones(bsz).to(cand_indices) - batch_mask[ - torch.tensor(finalized_sents).to(cand_indices) - ] = torch.tensor(0).to(batch_mask) - batch_idxs = batch_mask.nonzero().squeeze(-1) + batch_mask = torch.ones(bsz, dtype=torch.bool, device=cand_indices.device) + 
batch_mask[finalized_sents] = False + # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it + batch_idxs = torch.arange(bsz, device=cand_indices.device).masked_select(batch_mask) eos_mask = eos_mask[batch_idxs] cand_beams = cand_beams[batch_idxs] @@ -665,7 +668,7 @@ class SequenceGenerator(nn.Module): for bbsz_idx in range(bsz * beam_size): lprobs[bbsz_idx][ torch.tensor(banned_tokens[bbsz_idx]).long() - ] = torch.tensor(-math.inf, dtype=torch.float) + ] = torch.tensor(-math.inf).to(lprobs) return lprobs diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py index 4d7ea54b6..4a6e6a2d3 100644 --- a/fairseq/tasks/masked_lm.py +++ b/fairseq/tasks/masked_lm.py @@ -16,7 +16,7 @@ from fairseq.data import ( NestedDictionaryDataset, NumelDataset, NumSamplesDataset, - PadDataset, + RightPadDataset, PrependTokenDataset, SortDataset, TokenBlockDataset, @@ -150,17 +150,15 @@ class MaskedLMTask(FairseqTask): { 'id': IdDataset(), 'net_input': { - 'src_tokens': PadDataset( + 'src_tokens': RightPadDataset( src_dataset, pad_idx=self.source_dictionary.pad(), - left_pad=False, ), 'src_lengths': NumelDataset(src_dataset, reduce=False), }, - 'target': PadDataset( + 'target': RightPadDataset( tgt_dataset, pad_idx=self.source_dictionary.pad(), - left_pad=False, ), 'nsentences': NumSamplesDataset(), 'ntokens': NumelDataset(src_dataset, reduce=True), @@ -174,7 +172,7 @@ class MaskedLMTask(FairseqTask): ) def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): - src_dataset = PadDataset( + src_dataset = RightPadDataset( TokenBlockDataset( src_tokens, src_lengths, @@ -184,7 +182,6 @@ class MaskedLMTask(FairseqTask): break_mode='eos', ), pad_idx=self.source_dictionary.pad(), - left_pad=False, ) src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos()) src_dataset = NestedDictionaryDataset( diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py index b50c9922c..cf5eae38b 100644 --- 
a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -192,16 +192,20 @@ class SentencePredictionTask(FairseqTask): else: label_path = "{0}.label".format(get_path('label', split)) if os.path.exists(label_path): + def parse_regression_target(i, line): values = line.split() assert len(values) == self.args.num_classes, \ f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"' return [float(x) for x in values] - dataset.update( - target=RawLabelDataset([ - parse_regression_target(i, line.strip()) for i, line in enumerate(open(label_path).readlines()) - ]) - ) + + with open(label_path) as h: + dataset.update( + target=RawLabelDataset([ + parse_regression_target(i, line.strip()) + for i, line in enumerate(h.readlines()) + ]) + ) nested_dataset = NestedDictionaryDataset( dataset, diff --git a/fairseq/utils.py b/fairseq/utils.py index 2531896e5..d10ed2f28 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -527,8 +527,8 @@ def get_token_to_word_mapping(tokens, exclude_list): def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): - tgt_valid = ((tgt_sent != pad) & (tgt_sent != eos)).nonzero().squeeze(dim=-1) - src_invalid = ((src_sent == pad) | (src_sent == eos)).nonzero().squeeze(dim=-1) + tgt_valid = ((tgt_sent != pad) & (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1) + src_invalid = ((src_sent == pad) | (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1) src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad]) tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad]) alignment = [] diff --git a/tests/test_inference_dropout.py b/tests/test_inference_dropout.py index 89e05473f..4857bc7a8 100644 --- a/tests/test_inference_dropout.py +++ b/tests/test_inference_dropout.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import logging import unittest from tests.test_sequence_generator import get_dummy_task_and_parser @@ -17,6 +18,10 @@ class TestInferenceDropout(unittest.TestCase): self.args = self.parser.parse_args([]) self.args.encoder_layers = 2 self.args.decoder_layers = 1 + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) def test_sets_inference_dropout_to_true(self): self.args.retain_dropout = True diff --git a/tests/test_train.py b/tests/test_train.py index 5be74e415..048acaca5 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -4,8 +4,9 @@ # LICENSE file in the root directory of this source tree. import contextlib -from io import StringIO +import logging import unittest +from io import StringIO from unittest.mock import MagicMock, patch import torch @@ -74,6 +75,11 @@ class TestLoadCheckpoint(unittest.TestCase): } self.applied_patches = [patch(p, d) for p, d in self.patches.items()] [p.start() for p in self.applied_patches] + logging.disable(logging.CRITICAL) + + def tearDown(self): + patch.stopall() + logging.disable(logging.NOTSET) def test_load_partial_checkpoint(self): with contextlib.redirect_stdout(StringIO()): @@ -192,9 +198,6 @@ class TestLoadCheckpoint(unittest.TestCase): self.assertFalse(reset_lr_scheduler) self.assertFalse(reset_meters) - def tearDown(self): - patch.stopall() - if __name__ == '__main__': unittest.main()