diff --git a/docs/optim.rst b/docs/optim.rst index 67370b335..c3326456b 100644 --- a/docs/optim.rst +++ b/docs/optim.rst @@ -15,9 +15,15 @@ Optimizers update the Model parameters based on the gradients. :members: :undoc-members: +.. autoclass:: fairseq.optim.adadelta.Adadelta + :members: + :undoc-members: .. autoclass:: fairseq.optim.adagrad.Adagrad :members: :undoc-members: +.. autoclass:: fairseq.optim.adafactor.FairseqAdafactor + :members: + :undoc-members: .. autoclass:: fairseq.optim.adam.FairseqAdam :members: :undoc-members: diff --git a/docs/overview.rst b/docs/overview.rst index 2da90ef4e..e8b7aaaf9 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -28,11 +28,12 @@ fairseq implements the following high-level training flow:: lr_scheduler.step_update(num_updates) lr_scheduler.step(epoch) -where the default implementation for ``train.train_step`` is roughly:: +where the default implementation for ``task.train_step`` is roughly:: def train_step(self, batch, model, criterion, optimizer): loss = criterion(model, batch) optimizer.backward(loss) + return loss **Registering new plug-ins** diff --git a/docs/tutorial_classifying_names.rst b/docs/tutorial_classifying_names.rst index cdec9b0a1..eb974133d 100644 --- a/docs/tutorial_classifying_names.rst +++ b/docs/tutorial_classifying_names.rst @@ -354,7 +354,7 @@ The model files should appear in the :file:`checkpoints/` directory. Finally we can write a short script to evaluate our model on new inputs. Create a new file named :file:`eval_classifier.py` with the following contents:: - from fairseq import data, options, tasks, utils + from fairseq import checkpoint_utils, data, options, tasks # Parse command-line arguments for generation parser = options.get_generation_parser(default_task='simple_classification') @@ -365,7 +365,7 @@ a new file named :file:`eval_classifier.py` with the following contents:: # Load model print('| loading model from {}'.format(args.path)) - models, _model_args = utils.load_ensemble_for_inference([args.path], task) + models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task) model = models[0] while True: diff --git a/eval_lm.py b/eval_lm.py index 726dfa074..34c9a5485 100644 --- a/eval_lm.py +++ b/eval_lm.py @@ -13,11 +13,10 @@ Evaluate the perplexity of a trained language model. import numpy as np import torch -from fairseq import options, progress_bar, tasks, utils +from fairseq import checkpoint_utils, options, progress_bar, tasks, utils from fairseq.data import LMContextWindowDataset from fairseq.meters import StopwatchMeter, TimeMeter from fairseq.sequence_scorer import SequenceScorer -from fairseq.utils import import_user_module class WordStat(object): @@ -49,7 +48,7 @@ class WordStat(object): def main(parsed_args): assert parsed_args.path is not None, '--path required for evaluation!' 
- import_user_module(parsed_args) + utils.import_user_module(parsed_args) print(parsed_args) @@ -59,12 +58,17 @@ def main(parsed_args): # Load ensemble print('| loading model(s) from {}'.format(parsed_args.path)) - models, args = utils.load_ensemble_for_inference( - parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides), + models, args = checkpoint_utils.load_model_ensemble( + parsed_args.path.split(':'), + arg_overrides=eval(parsed_args.model_overrides), + task=task, ) for arg in vars(parsed_args).keys(): - if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}: + if arg not in { + 'self_target', 'future_target', 'past_target', 'tokens_per_sample', + 'output_size_dictionary', 'add_bos_token', + }: setattr(args, arg, getattr(parsed_args, arg)) # reduce tokens per sample by the required context window size @@ -151,6 +155,11 @@ def main(parsed_args): tgt_len = tokens.numel() pos_scores = hypo['positional_scores'].float() + if args.add_bos_token: + assert hypo['tokens'][0].item() == task.target_dictionary.bos() + tokens = tokens[1:] + pos_scores = pos_scores[1:] + skipped_toks = 0 if bpe_toks is not None: for i in range(tgt_len - 1): diff --git a/examples/language_model/README.md b/examples/language_model/README.md index e785f86a9..4929e98f9 100644 --- a/examples/language_model/README.md +++ b/examples/language_model/README.md @@ -39,7 +39,7 @@ $ fairseq-train --task language_modeling data-bin/wikitext-103 \ --save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm_wiki103 \ --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \ --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \ - --criterion adaptive_loss --max-tokens 3072 --update-freq 4 --tokens-per-sample 3072 --seed 1 \ + --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \ --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d # Evaluate: diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py new file mode 100644 index 000000000..4f8237c0d --- /dev/null +++ b/fairseq/checkpoint_utils.py @@ -0,0 +1,177 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from collections import OrderedDict +import logging +import os +import re +import traceback + +import torch +from torch.serialization import default_restore_location + +from fairseq import tasks + + +def load_checkpoint_to_cpu(path): + """Loads a checkpoint to CPU (with upgrading for backward compatibility).""" + state = torch.load( + path, map_location=lambda s, l: default_restore_location(s, 'cpu'), + ) + state = _upgrade_state_dict(state) + return state + + +def load_model_ensemble(filenames, arg_overrides=None, task=None): + """Loads an ensemble of models. 
+ + Args: + filenames (List[str]): checkpoint files to load + arg_overrides (Dict[str,Any], optional): override model args that + were used during model training + task (fairseq.tasks.FairseqTask, optional): task to use for loading + """ + ensemble = [] + for filename in filenames: + if not os.path.exists(filename): + raise IOError('Model file not found: {}'.format(filename)) + state = load_checkpoint_to_cpu(filename) + + args = state['args'] + if arg_overrides is not None: + for arg_name, arg_val in arg_overrides.items(): + setattr(args, arg_name, arg_val) + + if task is None: + task = tasks.setup_task(args) + + # build model for ensemble + model = task.build_model(args) + model.load_state_dict(state['model'], strict=True) + ensemble.append(model) + + return ensemble, args + + +def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'): + """Retrieves all checkpoints found in `path` directory. + + Checkpoints are identified by matching filename to the specified pattern. If + the pattern contains groups, the result will be sorted by the first group in + descending order. + """ + pt_regexp = re.compile(pattern) + files = os.listdir(path) + + entries = [] + for i, f in enumerate(files): + m = pt_regexp.fullmatch(f) + if m is not None: + idx = int(m.group(1)) if len(m.groups()) > 0 else i + entries.append((idx, m.group(0))) + return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] + + +def torch_persistent_save(*args, **kwargs): + for i in range(3): + try: + return torch.save(*args, **kwargs) + except Exception: + if i == 2: + logging.error(traceback.format_exc()) + + +def convert_state_dict_type(state_dict, ttype=torch.FloatTensor): + if isinstance(state_dict, dict): + cpu_dict = OrderedDict() + for k, v in state_dict.items(): + cpu_dict[k] = convert_state_dict_type(v) + return cpu_dict + elif isinstance(state_dict, list): + return [convert_state_dict_type(v) for v in state_dict] + elif torch.is_tensor(state_dict): + return state_dict.type(ttype) + else: + return state_dict + + +def save_state( + filename, args, model_state_dict, criterion, optimizer, lr_scheduler, + num_updates, optim_history=None, extra_state=None, +): + if optim_history is None: + optim_history = [] + if extra_state is None: + extra_state = {} + state_dict = { + 'args': args, + 'model': model_state_dict if model_state_dict else {}, + 'optimizer_history': optim_history + [ + { + 'criterion_name': criterion.__class__.__name__, + 'optimizer_name': optimizer.__class__.__name__, + 'lr_scheduler_state': lr_scheduler.state_dict(), + 'num_updates': num_updates, + } + ], + 'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()), + 'extra_state': extra_state, + } + torch_persistent_save(state_dict, filename) + + +def _upgrade_state_dict(state): + """Helper for upgrading old model checkpoints.""" + # add optimizer_history + if 'optimizer_history' not in state: + state['optimizer_history'] = [ + { + 'criterion_name': 'CrossEntropyCriterion', + 'best_loss': state['best_loss'], + }, + ] + state['last_optimizer_state'] = state['optimizer'] + del state['optimizer'] + del state['best_loss'] + # move extra_state into sub-dictionary + if 'epoch' in state and 'extra_state' not in state: + state['extra_state'] = { + 'epoch': state['epoch'], + 'batch_offset': state['batch_offset'], + 'val_loss': state['val_loss'], + } + del state['epoch'] + del state['batch_offset'] + del state['val_loss'] + # reduce optimizer history's memory usage (only keep the last state) + if 'optimizer' in state['optimizer_history'][-1]: 
+ state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer'] + for optim_hist in state['optimizer_history']: + del optim_hist['optimizer'] + # record the optimizer class name + if 'optimizer_name' not in state['optimizer_history'][-1]: + state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG' + # move best_loss into lr_scheduler_state + if 'lr_scheduler_state' not in state['optimizer_history'][-1]: + state['optimizer_history'][-1]['lr_scheduler_state'] = { + 'best': state['optimizer_history'][-1]['best_loss'], + } + del state['optimizer_history'][-1]['best_loss'] + # keep track of number of updates + if 'num_updates' not in state['optimizer_history'][-1]: + state['optimizer_history'][-1]['num_updates'] = 0 + # old model checkpoints may not have separate source/target positions + if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'): + state['args'].max_source_positions = state['args'].max_positions + state['args'].max_target_positions = state['args'].max_positions + # use stateful training data iterator + if 'train_iterator' not in state['extra_state']: + state['extra_state']['train_iterator'] = { + 'epoch': state['extra_state']['epoch'], + 'iterations_in_epoch': state['extra_state'].get('batch_offset', 0), + } + return state diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index a4b724d01..19fa644aa 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -18,13 +18,12 @@ from fairseq.data import data_utils class Dictionary(object): """A mapping from symbols to consecutive integers""" - def __init__(self, pad='<pad>', eos='</s>', unk='<unk>'): + def __init__(self, pad='<pad>', eos='</s>', unk='<unk>', bos='<s>'): self.unk_word, self.pad_word, self.eos_word = unk, pad, eos self.symbols = [] self.count = [] self.indices = {} - # dictionary indexing starts at 1 for consistency with Lua - self.add_symbol('<Lua heritage>') + self.bos_index = self.add_symbol(bos) self.pad_index = self.add_symbol(pad) self.eos_index = self.add_symbol(eos) self.unk_index = self.add_symbol(unk) @@ -143,6 +142,10 @@ class Dictionary(object): self.symbols = list(new_symbols) self.indices = new_indices + def bos(self): + """Helper to get index of beginning-of-sentence symbol""" + return self.bos_index + def pad(self): """Helper to get index of pad symbol""" return self.pad_index diff --git a/fairseq/data/monolingual_dataset.py b/fairseq/data/monolingual_dataset.py index 13e89955c..e75810c4a 100644 --- a/fairseq/data/monolingual_dataset.py +++ b/fairseq/data/monolingual_dataset.py @@ -62,13 +62,14 @@ class MonolingualDataset(FairseqDataset): """ def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle, - targets=None): + targets=None, add_bos_token=False): self.dataset = dataset self.sizes = np.array(sizes) self.vocab = src_vocab self.tgt_vocab = tgt_vocab self.add_eos_for_other_targets = add_eos_for_other_targets self.shuffle = shuffle + self.add_bos_token = add_bos_token assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \ "targets must be none or one of 'self', 'future', 'past'" @@ -91,6 +92,7 @@ class MonolingualDataset(FairseqDataset): else: source = self.dataset[index] target = None + source, target = self._maybe_add_bos(source, target) return {'id': index, 'source': source, 'target': target} def __len__(self): @@ -129,6 +131,13 @@ class MonolingualDataset(FairseqDataset): return source, self._filter_vocab(target) + def _maybe_add_bos(self, source, target): + if self.add_bos_token: +
source = torch.cat([source.new([self.vocab.bos()]), source]) + if target is not None: + target = torch.cat([target.new([self.tgt_vocab.bos()]), target]) + return source, target + def _filter_vocab(self, target): if len(self.tgt_vocab) != len(self.vocab): def _filter(target): @@ -173,6 +182,7 @@ class MonolingualDataset(FairseqDataset): target = self.vocab.dummy_sentence(tgt_len + 2) source, past_target, future_target = target[1:-1], target[2:], target[:-2] source, target = self._make_source_target(source, past_target, future_target) + source, target = self._maybe_add_bos(source, target) return self.collater([ {'id': i, 'source': source, 'target': target} diff --git a/fairseq/models/fconv.py b/fairseq/models/fconv.py index f7d69c3f1..afe42a67a 100644 --- a/fairseq/models/fconv.py +++ b/fairseq/models/fconv.py @@ -141,7 +141,7 @@ class FConvLanguageModel(FairseqLanguageModel): # make sure all arguments are present in older models base_lm_architecture(args) - if hasattr(args, 'max_target_positions'): + if hasattr(args, 'max_target_positions') and not hasattr(args, 'tokens_per_sample'): args.tokens_per_sample = args.max_target_positions decoder = FConvDecoder( diff --git a/fairseq/models/fconv_self_att.py b/fairseq/models/fconv_self_att.py index 51127bba8..602706e0e 100644 --- a/fairseq/models/fconv_self_att.py +++ b/fairseq/models/fconv_self_att.py @@ -4,7 +4,6 @@ # This source code is licensed under the license found in the LICENSE file in # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. -# import math @@ -12,11 +11,11 @@ import torch import torch.nn as nn import torch.nn.functional as F +from fairseq import checkpoint_utils from fairseq.modules import ( DownsampledMultiHeadAttention, GradMultiply, LayerNorm, LearnedPositionalEmbedding, LinearizedConvolution, ) -from fairseq import utils from . 
import ( FairseqEncoder, CompositeEncoder, FairseqDecoder, FairseqModel, @@ -84,8 +83,7 @@ class FConvModelSelfAtt(FairseqModel): pretrained = eval(args.pretrained) if pretrained: print("| loading pretrained model") - trained_model = utils.load_ensemble_for_inference( - # not actually for inference, but loads pretrained model parameters + trained_model = checkpoint_utils.load_model_ensemble( filenames=[args.pretrained_checkpoint], task=task, )[0][0] diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index ee02a6208..c9c4dc2ac 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -830,6 +830,7 @@ def base_lm_architecture(args): args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False) args.activation_fn = getattr(args, 'activation_fn', 'relu') + args.add_bos_token = getattr(args, 'add_bos_token', False) args.character_embeddings = getattr(args, 'character_embeddings', False) args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim) @@ -927,7 +928,7 @@ def transformer_wmt_en_de(args): base_architecture(args) -# parameters used in the "Attention Is All You Need" paper (Vaswani, et al, 2017) +# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017) @register_model_architecture('transformer', 'transformer_vaswani_wmt_en_de_big') def transformer_vaswani_wmt_en_de_big(args): args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024) diff --git a/fairseq/models/transformer_from_pretrained_xlm.py b/fairseq/models/transformer_from_pretrained_xlm.py index 769a2762f..f151c3db2 100644 --- a/fairseq/models/transformer_from_pretrained_xlm.py +++ b/fairseq/models/transformer_from_pretrained_xlm.py @@ -8,7 +8,7 @@ import os from typing import Any, Dict -from fairseq import utils +from fairseq import checkpoint_utils from fairseq.data.masked_lm_dictionary import MaskedLMDictionary from fairseq.models.transformer import ( TransformerDecoder, @@ -92,7 +92,7 @@ def upgrade_state_dict_with_xlm_weights( if not os.path.exists(pretrained_xlm_checkpoint): raise IOError(f"Model file not found: {pretrained_xlm_checkpoint}") - state = utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint) + state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint) xlm_state_dict = state["model"] for key in xlm_state_dict.keys(): diff --git a/fairseq/modules/gelu.py b/fairseq/modules/gelu.py index 449f289ba..35ca84afd 100644 --- a/fairseq/modules/gelu.py +++ b/fairseq/modules/gelu.py @@ -4,17 +4,15 @@ # This source code is licensed under the license found in the LICENSE file in # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. 
- -import math - -import torch - - """ See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with the corresponding GitHub repo: https://github.com/hendrycks/GELUs """ +import math + +import torch + def gelu_fast(x): if not hasattr(gelu_fast, "_a"): diff --git a/fairseq/optim/adadelta.py b/fairseq/optim/adadelta.py index 409170e22..c175d51d3 100644 --- a/fairseq/optim/adadelta.py +++ b/fairseq/optim/adadelta.py @@ -19,11 +19,15 @@ class Adadelta(FairseqOptimizer): @staticmethod def add_args(parser): """Add optimizer-specific arguments to the parser.""" + # fmt: off parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO', help='coefficient used for computing a running average of squared gradients') parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS', help='term added to the denominator to improve numerical stability') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps') + # fmt: on @property def optimizer_config(self): diff --git a/fairseq/optim/adafactor.py b/fairseq/optim/adafactor.py index 982c077db..0555ae6d6 100644 --- a/fairseq/optim/adafactor.py +++ b/fairseq/optim/adafactor.py @@ -21,6 +21,7 @@ class FairseqAdafactor(FairseqOptimizer): @staticmethod def add_args(parser): """Add optimizer-specific arguments to the parser.""" + # fmt: off parser.add_argument('--adafactor-eps', default='(1e-30, 1e-3)', metavar="E", help='epsilons for Adafactor optimizer') parser.add_argument('--clip-threshold', type=float, default=1.0, metavar="C", @@ -31,11 +32,14 @@ class FairseqAdafactor(FairseqOptimizer): help='beta for first moment estimator. Optional') parser.add_argument('--scale-parameter', action='store_true', help='scale learning rate by root mean square of parameter.') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') parser.add_argument('--warmup-init', action='store_true', help='use relative step for warm-up learning rate schedule') parser.add_argument('--relative-step', action='store_true', help='set learning rate to inverse square root of timestep.' 
'If false, external learning rate applied') + # fmt: on @property def optimizer_config(self): diff --git a/fairseq/optim/adagrad.py b/fairseq/optim/adagrad.py index 50f2dea9b..deafd3df9 100644 --- a/fairseq/optim/adagrad.py +++ b/fairseq/optim/adagrad.py @@ -16,6 +16,14 @@ class Adagrad(FairseqOptimizer): super().__init__(args, params) self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config) + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + @property def optimizer_config(self): """ diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index e6db7ee7e..379927aef 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -30,6 +30,8 @@ class FairseqAdam(FairseqOptimizer): help='betas for Adam optimizer') parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D', help='epsilon for Adam optimizer') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') # fmt: on @property diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index c30a398ed..184311423 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -85,6 +85,8 @@ class CosineSchedule(FairseqLRScheduler): help='factor to grow the length of each period') parser.add_argument('--lr-period-updates', default=-1, type=float, metavar='LR', help='initial number of updates per period') + parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', + help='shrink factor for annealing') # fmt: on def step(self, epoch, val_loss=None): diff --git a/fairseq/optim/lr_scheduler/fixed_schedule.py b/fairseq/optim/lr_scheduler/fixed_schedule.py index 8a22cb7e3..dc65d942c 100644 --- a/fairseq/optim/lr_scheduler/fixed_schedule.py +++ b/fairseq/optim/lr_scheduler/fixed_schedule.py @@ -30,6 +30,8 @@ class FixedSchedule(FairseqLRScheduler): # fmt: off parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', help='force annealing at specified epoch') + parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', + help='shrink factor for annealing, lr_new = (lr * lr_shrink)') parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', help='warmup the learning rate linearly for the first N updates') # fmt: on diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py index f822c8551..b4fe1a624 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py @@ -24,6 +24,14 @@ class ReduceLROnPlateau(FairseqLRScheduler): self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer.optimizer, patience=0, factor=args.lr_shrink) + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', + help='shrink factor for annealing, lr_new = (lr * lr_shrink)') + # fmt: on + def state_dict(self): """Return the LR scheduler state dict.""" return { diff --git a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py index 7b3307a0a..a9ef35a24 100644 --- a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py +++ 
b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py @@ -46,6 +46,8 @@ class TriangularSchedule(FairseqLRScheduler): help='max learning rate, must be more than args.lr') parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', help='initial number of updates per period (cycle length)') + parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', + help='shrink factor for annealing') parser.add_argument('--shrink-min', action='store_true', help='if set, also shrinks min lr') # fmt: on diff --git a/fairseq/optim/nag.py b/fairseq/optim/nag.py index 34c685cd4..b5216578b 100644 --- a/fairseq/optim/nag.py +++ b/fairseq/optim/nag.py @@ -16,6 +16,16 @@ class FairseqNAG(FairseqOptimizer): super().__init__(args, params) self._optimizer = NAG(params, **self.optimizer_config) + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--momentum', default=0.99, type=float, metavar='M', + help='momentum factor') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + @property def optimizer_config(self): """ diff --git a/fairseq/optim/sgd.py b/fairseq/optim/sgd.py index 4304e805d..ac5435207 100644 --- a/fairseq/optim/sgd.py +++ b/fairseq/optim/sgd.py @@ -16,6 +16,16 @@ class SGD(FairseqOptimizer): super().__init__(args, params) self._optimizer = torch.optim.SGD(params, **self.optimizer_config) + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--momentum', default=0.0, type=float, metavar='M', + help='momentum factor') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + @property def optimizer_config(self): """ diff --git a/fairseq/options.py b/fairseq/options.py index 27952c551..ff5b68c6e 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -303,19 +303,13 @@ def add_optimization_args(parser): metavar='LR_1,LR_2,...,LR_N', help='learning rate for the first N epochs; all epochs >N using LR_N' ' (note: this may be interpreted differently depending on --lr-scheduler)') - group.add_argument('--momentum', default=0.99, type=float, metavar='M', - help='momentum factor') - group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', - help='weight decay') # Learning rate schedulers can be found under fairseq/optim/lr_scheduler/ - group.add_argument('--lr-scheduler', default='reduce_lr_on_plateau', + group.add_argument('--lr-scheduler', default='fixed', choices=LR_SCHEDULER_REGISTRY.keys(), help='Learning Rate Scheduler') - group.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', - help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)') group.add_argument('--min-lr', default=-1, type=float, metavar='LR', - help='minimum learning rate') + help='stop training when the learning rate reaches this minimum') # fmt: on return group diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 113e9d33c..222cb75b4 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -81,6 +81,8 @@ class LanguageModelingTask(FairseqTask): help='include future target') parser.add_argument('--past-target', action='store_true', help='include past target') + parser.add_argument('--add-bos-token', action='store_true', + help='prepend beginning of sentence token (<s>)') # fmt: on
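# Illustrative aside, not part of the patch: with --add-bos-token set, the language
# modeling task passes add_bos_token=True to MonolingualDataset, which prepends
# dictionary.bos() to each sample. A minimal sketch of that behavior, assuming the
# patched Dictionary above and an arbitrary token string:
import torch
from fairseq.data.dictionary import Dictionary

vocab = Dictionary()  # bos/pad/eos/unk are registered by the constructor
source = torch.LongTensor([vocab.add_symbol('hello'), vocab.eos()])
# mirrors MonolingualDataset._maybe_add_bos from this patch
source = torch.cat([source.new([vocab.bos()]), source])
assert source[0].item() == vocab.bos()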
def __init__(self, args, dictionary, output_dictionary, targets=None): @@ -185,7 +187,7 @@ class LanguageModelingTask(FairseqTask): self.datasets[split] = MonolingualDataset( dataset, sizes, self.dictionary, self.output_dictionary, add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True, - targets=self.targets, + targets=self.targets, add_bos_token=self.args.add_bos_token, ) def build_dataset_for_inference(self, src_tokens, src_lengths): @@ -205,6 +207,7 @@ class LanguageModelingTask(FairseqTask): self.target_dictionary, add_eos_for_other_targets=False, shuffle=False, + add_bos_token=self.args.add_bos_token, ), eos=self.source_dictionary.eos(), # remove EOS since this will be used as a prefix for generation diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index d7a58eeaa..cab8b5369 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -8,7 +8,7 @@ import itertools import os -from fairseq import options, utils +from fairseq import options from fairseq.data import ( ConcatDataset, data_utils, @@ -69,24 +69,6 @@ class TranslationTask(FairseqTask): help='amount to upsample primary dataset') # fmt: on - @staticmethod - def load_pretrained_model(path, src_dict_path, tgt_dict_path, arg_overrides=None): - model = utils.load_checkpoint_to_cpu(path) - args = model['args'] - state_dict = model['model'] - args = utils.override_model_args(args, arg_overrides) - src_dict = Dictionary.load(src_dict_path) - tgt_dict = Dictionary.load(tgt_dict_path) - assert src_dict.pad() == tgt_dict.pad() - assert src_dict.eos() == tgt_dict.eos() - assert src_dict.unk() == tgt_dict.unk() - - task = TranslationTask(args, src_dict, tgt_dict) - model = task.build_model(args) - model.upgrade_state_dict(state_dict) - model.load_state_dict(state_dict, strict=True) - return model - def __init__(self, args, src_dict, tgt_dict): super().__init__(args) self.src_dict = src_dict @@ -102,6 +84,10 @@ class TranslationTask(FairseqTask): args.left_pad_source = options.eval_bool(args.left_pad_source) args.left_pad_target = options.eval_bool(args.left_pad_target) + # upgrade old checkpoints + if isinstance(args.data, str): + args.data = [args.data] + # find language pair automatically if args.source_lang is None or args.target_lang is None: args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0]) @@ -147,9 +133,7 @@ class TranslationTask(FairseqTask): src_datasets = [] tgt_datasets = [] - data_paths = self.args.data - - for dk, data_path in enumerate(data_paths): + for dk, data_path in enumerate(self.args.data): for k in itertools.count(): split_k = split + (str(k) if k > 0 else '') diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 5c50311cc..4b5ddb361 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -11,10 +11,11 @@ Train a network across multiple GPUs. 
from collections import OrderedDict from itertools import chain +import os import torch -from fairseq import distributed_utils, models, optim, utils +from fairseq import checkpoint_utils, distributed_utils, models, optim, utils from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter from fairseq.optim import lr_scheduler @@ -119,16 +120,31 @@ class Trainer(object): """Save all training state in a checkpoint file.""" if distributed_utils.is_master(self.args): # only save one checkpoint extra_state['train_meters'] = self.meters - utils.save_state( + checkpoint_utils.save_state( filename, self.args, self.get_model().state_dict(), self.criterion, self.optimizer, self.lr_scheduler, self._num_updates, self._optim_history, extra_state, ) def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None): """Load all training state from a checkpoint file.""" - extra_state, self._optim_history, last_optim_state = utils.load_model_state( - filename, self.get_model(), - ) + extra_state, self._optim_history, last_optim_state = None, [], None + + if os.path.exists(filename): + state = checkpoint_utils.load_checkpoint_to_cpu(filename) + + # load model parameters + try: + self.get_model().load_state_dict(state['model'], strict=True) + except Exception: + raise Exception( + 'Cannot load model parameters from checkpoint, ' + 'please ensure that the architectures match.' + ) + + extra_state = state['extra_state'] + self._optim_history = state['optimizer_history'] + last_optim_state = state['last_optimizer_state'] + if last_optim_state is not None and not reset_optimizer: # rebuild optimizer after loading model, since params may have changed self._build_optimizer() @@ -136,9 +152,9 @@ class Trainer(object): # only reload optimizer and lr_scheduler if they match last_optim = self._optim_history[-1] assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \ - 'criterion does not match; please reset the optimizer (--reset-optimizer)' + 'Criterion does not match; please reset the optimizer (--reset-optimizer).' assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \ - 'optimizer does not match; please reset the optimizer (--reset-optimizer)' + 'Optimizer does not match; please reset the optimizer (--reset-optimizer).' 
if not reset_lr_scheduler: self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state']) diff --git a/fairseq/utils.py b/fairseq/utils.py index 902d565f0..585f6d2eb 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -9,182 +9,25 @@ from collections import defaultdict, OrderedDict from typing import Callable import copy import importlib.util -import logging import os -import re import sys -import traceback import warnings import torch import torch.nn.functional as F -from torch.serialization import default_restore_location from fairseq.modules import gelu, gelu_fast -def torch_persistent_save(*args, **kwargs): - for i in range(3): - try: - return torch.save(*args, **kwargs) - except Exception: - if i == 2: - logging.error(traceback.format_exc()) - - -def convert_state_dict_type(state_dict, ttype=torch.FloatTensor): - if isinstance(state_dict, dict): - cpu_dict = OrderedDict() - for k, v in state_dict.items(): - cpu_dict[k] = convert_state_dict_type(v) - return cpu_dict - elif isinstance(state_dict, list): - return [convert_state_dict_type(v) for v in state_dict] - elif torch.is_tensor(state_dict): - return state_dict.type(ttype) - else: - return state_dict - - -def save_state(filename, args, model_state_dict, criterion, optimizer, lr_scheduler, - num_updates, optim_history=None, extra_state=None): - if optim_history is None: - optim_history = [] - if extra_state is None: - extra_state = {} - state_dict = { - 'args': args, - 'model': model_state_dict if model_state_dict else {}, - 'optimizer_history': optim_history + [ - { - 'criterion_name': criterion.__class__.__name__, - 'optimizer_name': optimizer.__class__.__name__, - 'lr_scheduler_state': lr_scheduler.state_dict(), - 'num_updates': num_updates, - } - ], - 'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()), - 'extra_state': extra_state, - } - torch_persistent_save(state_dict, filename) - - -def load_model_state(filename, model): - if not os.path.exists(filename): - return None, [], None - state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu')) - state = _upgrade_state_dict(state) - model.upgrade_state_dict(state['model']) - - # load model parameters - try: - model.load_state_dict(state['model'], strict=True) - except Exception: - raise Exception('Cannot load model parameters from checkpoint, ' - 'please ensure that the architectures match') - - return state['extra_state'], state['optimizer_history'], state['last_optimizer_state'] - - -def _upgrade_state_dict(state): - """Helper for upgrading old model checkpoints.""" - # add optimizer_history - if 'optimizer_history' not in state: - state['optimizer_history'] = [ - { - 'criterion_name': 'CrossEntropyCriterion', - 'best_loss': state['best_loss'], - }, - ] - state['last_optimizer_state'] = state['optimizer'] - del state['optimizer'] - del state['best_loss'] - # move extra_state into sub-dictionary - if 'epoch' in state and 'extra_state' not in state: - state['extra_state'] = { - 'epoch': state['epoch'], - 'batch_offset': state['batch_offset'], - 'val_loss': state['val_loss'], - } - del state['epoch'] - del state['batch_offset'] - del state['val_loss'] - # reduce optimizer history's memory usage (only keep the last state) - if 'optimizer' in state['optimizer_history'][-1]: - state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer'] - for optim_hist in state['optimizer_history']: - del optim_hist['optimizer'] - # record the optimizer class name - if 'optimizer_name' not in state['optimizer_history'][-1]: - 
state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG' - # move best_loss into lr_scheduler_state - if 'lr_scheduler_state' not in state['optimizer_history'][-1]: - state['optimizer_history'][-1]['lr_scheduler_state'] = { - 'best': state['optimizer_history'][-1]['best_loss'], - } - del state['optimizer_history'][-1]['best_loss'] - # keep track of number of updates - if 'num_updates' not in state['optimizer_history'][-1]: - state['optimizer_history'][-1]['num_updates'] = 0 - # old model checkpoints may not have separate source/target positions - if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'): - state['args'].max_source_positions = state['args'].max_positions - state['args'].max_target_positions = state['args'].max_positions - # use stateful training data iterator - if 'train_iterator' not in state['extra_state']: - state['extra_state']['train_iterator'] = { - 'epoch': state['extra_state']['epoch'], - 'iterations_in_epoch': state['extra_state'].get('batch_offset', 0), - } - return state - - -def load_checkpoint_to_cpu(path): - state = torch.load(path, map_location=lambda s, l: default_restore_location(s, 'cpu')) - state = _upgrade_state_dict(state) - return state - - def load_ensemble_for_inference(filenames, task, model_arg_overrides=None): - """Load an ensemble of models for inference. - - model_arg_overrides allows you to pass a dictionary model_arg_overrides -- - {'arg_name': arg} -- to override model args that were used during model - training - """ - # load model architectures and weights - states = [] - for filename in filenames: - if not os.path.exists(filename): - raise IOError('Model file not found: {}'.format(filename)) - state = load_checkpoint_to_cpu(filename) - states.append(state) - - ensemble = [] - for state in states: - args = state['args'] - - if model_arg_overrides is not None: - args = override_model_args(args, model_arg_overrides) - - # build model for ensemble - model = task.build_model(args) - model.upgrade_state_dict(state['model']) - model.load_state_dict(state['model'], strict=True) - ensemble.append(model) - - # some args (e.g., tokens_per_sample) might have been updated while building the model - if model_arg_overrides is not None: - args = override_model_args(args, model_arg_overrides) - - return ensemble, args - - -def override_model_args(args, model_arg_overrides): - # Uses model_arg_overrides {'arg_name': arg} to override model args - for arg_name, arg_val in model_arg_overrides.items(): - setattr(args, arg_name, arg_val) - return args + from fairseq import checkpoint_utils + deprecation_warning( + 'utils.load_ensemble_for_inference is deprecated. ' + 'Please use checkpoint_utils.load_model_ensemble instead.' + ) + return checkpoint_utils.load_model_ensemble( + filenames, arg_overrides=model_arg_overrides, task=task, + ) def move_to_cuda(sample): @@ -379,25 +222,6 @@ def fill_with_neg_inf(t): return t.float().fill_(float('-inf')).type_as(t) -def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'): - """Retrieves all checkpoints found in `path` directory. - - Checkpoints are identified by matching filename to the specified pattern. If - the pattern contains groups, the result will be sorted by the first group in - descending order. 
- """ - pt_regexp = re.compile(pattern) - files = os.listdir(path) - - entries = [] - for i, f in enumerate(files): - m = pt_regexp.fullmatch(f) - if m is not None: - idx = int(m.group(1)) if len(m.groups()) > 0 else i - entries.append((idx, m.group(0))) - return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] - - def resolve_max_positions(*args): """Resolve max position constraints from multiple sources.""" diff --git a/generate.py b/generate.py index bbef4bd1d..10e85b5e9 100644 --- a/generate.py +++ b/generate.py @@ -11,9 +11,8 @@ Translate pre-processed data with a trained model. import torch -from fairseq import bleu, options, progress_bar, tasks, utils +from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils from fairseq.meters import StopwatchMeter, TimeMeter -from fairseq.utils import import_user_module def main(args): @@ -23,7 +22,7 @@ def main(args): assert args.replace_unk is None or args.raw_text, \ '--replace-unk requires a raw text dataset (--raw-text)' - import_user_module(args) + utils.import_user_module(args) if args.max_tokens is None and args.max_sentences is None: args.max_tokens = 12000 @@ -34,7 +33,6 @@ def main(args): # Load dataset splits task = tasks.setup_task(args) task.load_dataset(args.gen_subset) - print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset)))) # Set dictionaries try: @@ -45,8 +43,10 @@ def main(args): # Load ensemble print('| loading model(s) from {}'.format(args.path)) - models, _model_args = utils.load_ensemble_for_inference( - args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides), + models, _model_args = checkpoint_utils.load_model_ensemble( + args.path.split(':'), + arg_overrides=eval(args.model_overrides), + task=task, ) # Optimize ensemble for generation diff --git a/interactive.py b/interactive.py index 586f43a6b..f66a23000 100644 --- a/interactive.py +++ b/interactive.py @@ -15,9 +15,8 @@ import sys import torch -from fairseq import options, tasks, utils +from fairseq import checkpoint_utils, options, tasks, utils from fairseq.sequence_generator import SequenceGenerator -from fairseq.utils import import_user_module Batch = namedtuple('Batch', 'ids src_tokens src_lengths') Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments') @@ -56,7 +55,7 @@ def make_batches(lines, args, task, max_positions): def main(args): - import_user_module(args) + utils.import_user_module(args) if args.buffer_size < 1: args.buffer_size = 1 @@ -77,8 +76,10 @@ def main(args): # Load ensemble print('| loading model(s) from {}'.format(args.path)) - models, _model_args = utils.load_ensemble_for_inference( - args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides), + models, _model_args = checkpoint_utils.load_model_ensemble( + args.path.split(':'), + arg_overrides=eval(args.model_overrides), + task=task, ) # Set dictionaries diff --git a/preprocess.py b/preprocess.py index 163f88097..625f8bd77 100644 --- a/preprocess.py +++ b/preprocess.py @@ -12,10 +12,9 @@ Data pre-processing: build vocabularies and binarize training data. 
from collections import Counter from itertools import zip_longest -from fairseq import options, tasks +from fairseq import options, tasks, utils from fairseq.data import indexed_dataset from fairseq.binarizer import Binarizer -from fairseq.utils import import_user_module from multiprocessing import Pool import os @@ -23,7 +22,7 @@ import shutil def main(args): - import_user_module(args) + utils.import_user_module(args) print(args) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 7a6bc1a13..42078ba49 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -239,7 +239,20 @@ class TestLanguageModeling(unittest.TestCase): with tempfile.TemporaryDirectory('test_fconv_lm') as data_dir: create_dummy_data(data_dir) preprocess_lm_data(data_dir) - train_language_model(data_dir, 'fconv_lm') + train_language_model(data_dir, 'fconv_lm', [ + '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]', + '--decoder-embed-dim', '280', + '--optimizer', 'nag', + '--lr', '0.1', + ]) + eval_lm_main(data_dir) + + def test_transformer_lm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory('test_transformer_lm') as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_language_model(data_dir, 'transformer_lm', ['--add-bos-token']) eval_lm_main(data_dir) @@ -534,7 +547,7 @@ def preprocess_lm_data(data_dir): preprocess.main(preprocess_args) -def train_language_model(data_dir, arch): +def train_language_model(data_dir, arch, extra_flags=None): train_parser = options.get_training_parser() train_args = options.parse_args_and_arch( train_parser, @@ -542,12 +555,10 @@ def train_language_model(data_dir, arch): '--task', 'language_modeling', data_dir, '--arch', arch, - '--optimizer', 'nag', - '--lr', '0.1', + '--optimizer', 'adam', + '--lr', '0.0001', '--criterion', 'adaptive_loss', '--adaptive-softmax-cutoff', '5,10,15', - '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]', - '--decoder-embed-dim', '280', '--max-tokens', '500', '--tokens-per-sample', '500', '--save-dir', data_dir, @@ -555,7 +566,7 @@ def train_language_model(data_dir, arch): '--no-progress-bar', '--distributed-world-size', '1', '--ddp-backend', 'no_c10d', - ], + ] + (extra_flags or []), ) train.main(train_args) diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py index 33aed39e4..d1891db3d 100644 --- a/tests/test_reproducibility.py +++ b/tests/test_reproducibility.py @@ -38,7 +38,7 @@ class TestReproducibility(unittest.TestCase): ] + extra_flags, ) stdout = stdout.getvalue() - train_log, valid_log = map(json.loads, stdout.split('\n')[-4:-2]) + train_log, valid_log = map(json.loads, stdout.split('\n')[-5:-3]) # train epoch 2, resuming from previous checkpoint 1 os.rename( @@ -56,7 +56,7 @@ class TestReproducibility(unittest.TestCase): ] + extra_flags, ) stdout = stdout.getvalue() - train_res_log, valid_res_log = map(json.loads, stdout.split('\n')[-4:-2]) + train_res_log, valid_res_log = map(json.loads, stdout.split('\n')[-5:-3]) def cast(s): return round(float(s), 3) diff --git a/train.py b/train.py index 57b4d4dbc..95c6ac1b8 100644 --- a/train.py +++ b/train.py @@ -17,15 +17,14 @@ import random import torch -from fairseq import distributed_utils, options, progress_bar, tasks, utils +from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils from fairseq.data import iterators from fairseq.trainer import Trainer from fairseq.meters import AverageMeter, StopwatchMeter -from fairseq.utils import import_user_module def 
main(args, init_distributed=False): - import_user_module(args) + utils.import_user_module(args) if args.max_tokens is None: args.max_tokens = 6000 @@ -326,14 +325,18 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss): if not end_of_epoch and args.keep_interval_updates > 0: # remove old checkpoints; checkpoints are sorted in descending order - checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt') + checkpoints = checkpoint_utils.checkpoint_paths( + args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt', + ) for old_chk in checkpoints[args.keep_interval_updates:]: if os.path.lexists(old_chk): os.remove(old_chk) if args.keep_last_epochs > 0: # remove old epoch checkpoints; checkpoints are sorted in descending order - checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint(\d+)\.pt') + checkpoints = checkpoint_utils.checkpoint_paths( + args.save_dir, pattern=r'checkpoint(\d+)\.pt', + ) for old_chk in checkpoints[args.keep_last_epochs:]: if os.path.lexists(old_chk): os.remove(old_chk)
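For reference, a minimal usage sketch of the new checkpoint_utils.load_model_ensemble API that replaces utils.load_ensemble_for_inference throughout this patch. It is condensed from the updated generate.py above; the parser setup and command-line values are illustrative, not part of the patch.

from fairseq import checkpoint_utils, options, tasks

parser = options.get_generation_parser()
args = options.parse_args_and_arch(parser)

task = tasks.setup_task(args)
task.load_dataset(args.gen_subset)

# args.path may name several checkpoints separated by ':'; passing task=task reuses
# the existing task instead of rebuilding one from the args stored in the checkpoint.
models, model_args = checkpoint_utils.load_model_ensemble(
    args.path.split(':'),
    arg_overrides=eval(args.model_overrides),
    task=task,
)
model = models[0].eval()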