Merge internal changes (#654)

Summary:
- Add --add-bos-token option to LM task
- Cleanup utils.py and options.py
Pull Request resolved: https://github.com/pytorch/fairseq/pull/654

Differential Revision: D15041794

Pulled By: myleott

fbshipit-source-id: 3ad00007769d5f48308052cfd40de39c5ffa1a6e
Myle Ott 2019-04-29 19:45:01 -07:00 committed by Facebook Github Bot
parent 89a696161b
commit d45db80431
34 changed files with 368 additions and 278 deletions


@@ -15,9 +15,15 @@ Optimizers update the Model parameters based on the gradients.
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.adadelta.Adadelta
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.adagrad.Adagrad
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
    :members:
    :undoc-members:
.. autoclass:: fairseq.optim.adam.FairseqAdam
    :members:
    :undoc-members:


@@ -28,11 +28,12 @@ fairseq implements the following high-level training flow::

    lr_scheduler.step_update(num_updates)
    lr_scheduler.step(epoch)

where the default implementation for ``train.train_step`` is roughly::
where the default implementation for ``task.train_step`` is roughly::

    def train_step(self, batch, model, criterion, optimizer):
        loss = criterion(model, batch)
        optimizer.backward(loss)
        return loss
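A task plug-in may override this hook. A minimal, hypothetical sketch that follows the simplified signature above (task name invented)::

    from fairseq.tasks import FairseqTask, register_task

    @register_task('my_custom_task')
    class MyCustomTask(FairseqTask):

        def train_step(self, batch, model, criterion, optimizer):
            model.train()                    # make sure dropout etc. are active
            loss = criterion(model, batch)   # simplified contract, as above
            optimizer.backward(loss)
            return loss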
**Registering new plug-ins**


@@ -354,7 +354,7 @@ The model files should appear in the :file:`checkpoints/` directory.
Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classifier.py` with the following contents::
from fairseq import data, options, tasks, utils
from fairseq import checkpoint_utils, data, options, tasks
# Parse command-line arguments for generation
parser = options.get_generation_parser(default_task='simple_classification')
@@ -365,7 +365,7 @@ a new file named :file:`eval_classifier.py` with the following contents::
# Load model
print('| loading model from {}'.format(args.path))
models, _model_args = utils.load_ensemble_for_inference([args.path], task)
models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
model = models[0]
while True:
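For scripts being migrated from the old utils helper, the new loading pattern in isolation (a sketch reusing the names from the snippet above)::

    from fairseq import checkpoint_utils, options, tasks

    parser = options.get_generation_parser(default_task='simple_classification')
    args = options.parse_args_and_arch(parser)
    task = tasks.setup_task(args)

    # returns (list of models, args stored in the checkpoint)
    models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
    model = models[0]
    model.eval()   # disable dropout for inference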


@@ -13,11 +13,10 @@ Evaluate the perplexity of a trained language model.
import numpy as np
import torch
from fairseq import options, progress_bar, tasks, utils
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module
class WordStat(object):
@@ -49,7 +48,7 @@ class WordStat(object):

def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    import_user_module(parsed_args)
    utils.import_user_module(parsed_args)

    print(parsed_args)
@@ -59,12 +58,17 @@ def main(parsed_args):

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(':'),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    for arg in vars(parsed_args).keys():
        if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
        if arg not in {
            'self_target', 'future_target', 'past_target', 'tokens_per_sample',
            'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
@@ -151,6 +155,11 @@ def main(parsed_args):

            tgt_len = tokens.numel()
            pos_scores = hypo['positional_scores'].float()

            if args.add_bos_token:
                assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
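A hedged numeric sketch of why both tensors are sliced (token ids and scores invented; scores are natural-log probabilities):

    import torch

    tokens = torch.tensor([0, 11, 24, 31, 2])                  # index 0 is <s>
    pos_scores = torch.tensor([-0.1, -2.3, -1.9, -2.0, -1.5])

    # <s> was prepended as conditioning context only, so its score must
    # not count toward perplexity; drop position 0 from both tensors:
    tokens, pos_scores = tokens[1:], pos_scores[1:]

    avg_nll = -pos_scores.mean().item()          # 1.925 nats per real token
    ppl = torch.exp(-pos_scores.mean()).item()   # ~6.86, illustrative only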


@@ -39,7 +39,7 @@ $ fairseq-train --task language_modeling data-bin/wikitext-103 \
--save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm_wiki103 \
--max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
--criterion adaptive_loss --max-tokens 3072 --update-freq 4 --tokens-per-sample 3072 --seed 1 \
--criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d
# Evaluate:

fairseq/checkpoint_utils.py (new file, +177 lines)

@@ -0,0 +1,177 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from collections import OrderedDict
import logging
import os
import re
import traceback

import torch
from torch.serialization import default_restore_location

from fairseq import tasks


def load_checkpoint_to_cpu(path):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
    state = torch.load(
        path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
    )
    state = _upgrade_state_dict(state)
    return state


def load_model_ensemble(filenames, arg_overrides=None, task=None):
    """Loads an ensemble of models.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
    """
    ensemble = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = load_checkpoint_to_cpu(filename)

        args = state['args']
        if arg_overrides is not None:
            for arg_name, arg_val in arg_overrides.items():
                setattr(args, arg_name, arg_val)

        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)
    return ensemble, args


def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    pt_regexp = re.compile(pattern)
    files = os.listdir(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = int(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]


def torch_persistent_save(*args, **kwargs):
    for i in range(3):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            if i == 2:
                logging.error(traceback.format_exc())


def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()
        for k, v in state_dict.items():
            cpu_dict[k] = convert_state_dict_type(v)
        return cpu_dict
    elif isinstance(state_dict, list):
        return [convert_state_dict_type(v) for v in state_dict]
    elif torch.is_tensor(state_dict):
        return state_dict.type(ttype)
    else:
        return state_dict


def save_state(
    filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
    num_updates, optim_history=None, extra_state=None,
):
    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        'args': args,
        'model': model_state_dict if model_state_dict else {},
        'optimizer_history': optim_history + [
            {
                'criterion_name': criterion.__class__.__name__,
                'optimizer_name': optimizer.__class__.__name__,
                'lr_scheduler_state': lr_scheduler.state_dict(),
                'num_updates': num_updates,
            }
        ],
        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
        'extra_state': extra_state,
    }
    torch_persistent_save(state_dict, filename)


def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""
    # add optimizer_history
    if 'optimizer_history' not in state:
        state['optimizer_history'] = [
            {
                'criterion_name': 'CrossEntropyCriterion',
                'best_loss': state['best_loss'],
            },
        ]
        state['last_optimizer_state'] = state['optimizer']
        del state['optimizer']
        del state['best_loss']
    # move extra_state into sub-dictionary
    if 'epoch' in state and 'extra_state' not in state:
        state['extra_state'] = {
            'epoch': state['epoch'],
            'batch_offset': state['batch_offset'],
            'val_loss': state['val_loss'],
        }
        del state['epoch']
        del state['batch_offset']
        del state['val_loss']
    # reduce optimizer history's memory usage (only keep the last state)
    if 'optimizer' in state['optimizer_history'][-1]:
        state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
        for optim_hist in state['optimizer_history']:
            del optim_hist['optimizer']
    # record the optimizer class name
    if 'optimizer_name' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
    # move best_loss into lr_scheduler_state
    if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['lr_scheduler_state'] = {
            'best': state['optimizer_history'][-1]['best_loss'],
        }
        del state['optimizer_history'][-1]['best_loss']
    # keep track of number of updates
    if 'num_updates' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['num_updates'] = 0
    # old model checkpoints may not have separate source/target positions
    if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
        state['args'].max_source_positions = state['args'].max_positions
        state['args'].max_target_positions = state['args'].max_positions
    # use stateful training data iterator
    if 'train_iterator' not in state['extra_state']:
        state['extra_state']['train_iterator'] = {
            'epoch': state['extra_state']['epoch'],
            'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
        }
    return state
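A short usage sketch of the two public entry points above (paths and the override value are invented):

    from fairseq import checkpoint_utils

    # load a two-checkpoint ensemble on CPU, overriding one training-time arg
    models, args = checkpoint_utils.load_model_ensemble(
        ['checkpoints/checkpoint_best.pt', 'checkpoints/checkpoint_last.pt'],
        arg_overrides={'tokens_per_sample': 512},
    )

    # enumerate epoch checkpoints, newest first (the order the pruning
    # logic in train.py relies on)
    for path in checkpoint_utils.checkpoint_paths('checkpoints', pattern=r'checkpoint(\d+)\.pt'):
        print(path)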


@@ -18,13 +18,12 @@ from fairseq.data import data_utils

class Dictionary(object):
    """A mapping from symbols to consecutive integers"""

    def __init__(self, pad='<pad>', eos='</s>', unk='<unk>'):
    def __init__(self, pad='<pad>', eos='</s>', unk='<unk>', bos='<s>'):
        self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}
        # dictionary indexing starts at 1 for consistency with Lua
        self.add_symbol('<Lua heritage>')
        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)

@@ -143,6 +142,10 @@ class Dictionary(object):
        self.symbols = list(new_symbols)
        self.indices = new_indices

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.bos_index

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index
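Because <s> takes the slot formerly held by the '<Lua heritage>' placeholder, the long-standing indices of the other specials are preserved; a quick check (sketch, using the constructor defaults shown above):

    from fairseq.data import Dictionary

    d = Dictionary()
    assert (d.bos(), d.pad(), d.eos(), d.unk()) == (0, 1, 2, 3)
    print(d.symbols[:4])   # ['<s>', '<pad>', '</s>', '<unk>']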


@@ -62,13 +62,14 @@ class MonolingualDataset(FairseqDataset):
    """

    def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle,
                 targets=None):
                 targets=None, add_bos_token=False):
        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.add_eos_for_other_targets = add_eos_for_other_targets
        self.shuffle = shuffle
        self.add_bos_token = add_bos_token

        assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \
            "targets must be none or one of 'self', 'future', 'past'"

@@ -91,6 +92,7 @@ class MonolingualDataset(FairseqDataset):
        else:
            source = self.dataset[index]
            target = None
        source, target = self._maybe_add_bos(source, target)
        return {'id': index, 'source': source, 'target': target}

    def __len__(self):

@@ -129,6 +131,13 @@ class MonolingualDataset(FairseqDataset):
        return source, self._filter_vocab(target)

    def _maybe_add_bos(self, source, target):
        if self.add_bos_token:
            source = torch.cat([source.new([self.vocab.bos()]), source])
            if target is not None:
                target = torch.cat([target.new([self.tgt_vocab.bos()]), target])
        return source, target

    def _filter_vocab(self, target):
        if len(self.tgt_vocab) != len(self.vocab):
            def _filter(target):

@@ -173,6 +182,7 @@ class MonolingualDataset(FairseqDataset):
        target = self.vocab.dummy_sentence(tgt_len + 2)
        source, past_target, future_target = target[1:-1], target[2:], target[:-2]
        source, target = self._make_source_target(source, past_target, future_target)
        source, target = self._maybe_add_bos(source, target)
        return self.collater([
            {'id': i, 'source': source, 'target': target}
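In isolation, _maybe_add_bos is just a concatenation; a sketch with invented token ids (vocab.bos() assumed to be 0, per the Dictionary change above):

    import torch

    bos = 0
    source = torch.tensor([11, 24, 31, 2])
    target = torch.tensor([24, 31, 7, 2])

    # what happens when add_bos_token=True:
    source = torch.cat([source.new([bos]), source])   # tensor([ 0, 11, 24, 31,  2])
    target = torch.cat([target.new([bos]), target])   # tensor([ 0, 24, 31,  7,  2])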


@@ -141,7 +141,7 @@ class FConvLanguageModel(FairseqLanguageModel):
        # make sure all arguments are present in older models
        base_lm_architecture(args)

        if hasattr(args, 'max_target_positions'):
        if hasattr(args, 'max_target_positions') and not hasattr(args, 'tokens_per_sample'):
            args.tokens_per_sample = args.max_target_positions

        decoder = FConvDecoder(


@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import math

@@ -12,11 +11,11 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import checkpoint_utils
from fairseq.modules import (
    DownsampledMultiHeadAttention, GradMultiply, LayerNorm,
    LearnedPositionalEmbedding, LinearizedConvolution,
)
from fairseq import utils

from . import (
    FairseqEncoder, CompositeEncoder, FairseqDecoder, FairseqModel,

@@ -84,8 +83,7 @@ class FConvModelSelfAtt(FairseqModel):
        pretrained = eval(args.pretrained)
        if pretrained:
            print("| loading pretrained model")
            trained_model = utils.load_ensemble_for_inference(
            # not actually for inference, but loads pretrained model parameters
            trained_model = checkpoint_utils.load_model_ensemble(
                filenames=[args.pretrained_checkpoint],
                task=task,
            )[0][0]


@@ -830,6 +830,7 @@ def base_lm_architecture(args):
    args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
    args.activation_fn = getattr(args, 'activation_fn', 'relu')

    args.add_bos_token = getattr(args, 'add_bos_token', False)
    args.character_embeddings = getattr(args, 'character_embeddings', False)

    args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)

@@ -927,7 +928,7 @@ def transformer_wmt_en_de(args):
    base_architecture(args)

# parameters used in the "Attention Is All You Need" paper (Vaswani, et al, 2017)
# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture('transformer', 'transformer_vaswani_wmt_en_de_big')
def transformer_vaswani_wmt_en_de_big(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
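The getattr line is what keeps older checkpoints loadable after a new flag is introduced; a trimmed sketch of the pattern (helper name invented):

    import argparse

    def set_lm_defaults(args):
        # attributes missing from old checkpoints fall back to the new default
        args.add_bos_token = getattr(args, 'add_bos_token', False)

    old_args = argparse.Namespace(decoder_embed_dim=512)  # saved before the flag existed
    set_lm_defaults(old_args)
    assert old_args.add_bos_token is False                # no AttributeError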


@@ -8,7 +8,7 @@

import os
from typing import Any, Dict

from fairseq import utils
from fairseq import checkpoint_utils
from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
from fairseq.models.transformer import (
    TransformerDecoder,

@@ -92,7 +92,7 @@ def upgrade_state_dict_with_xlm_weights(
    if not os.path.exists(pretrained_xlm_checkpoint):
        raise IOError(f"Model file not found: {pretrained_xlm_checkpoint}")

    state = utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint)
    state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint)
    xlm_state_dict = state["model"]
    for key in xlm_state_dict.keys():


@@ -4,17 +4,15 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math

import torch

"""
See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with
the corresponding GitHub repo: https://github.com/hendrycks/GELUs
"""

import math

import torch


def gelu_fast(x):
    if not hasattr(gelu_fast, "_a"):


@@ -19,11 +19,15 @@ class Adadelta(FairseqOptimizer):
    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO',
                            help='coefficient used for computing a running average of squared gradients')
        parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS',
                            help='term added to the denominator to improve numerical stability')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps')
        # fmt: on

    @property
    def optimizer_config(self):


@@ -21,6 +21,7 @@ class FairseqAdafactor(FairseqOptimizer):
    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--adafactor-eps', default='(1e-30, 1e-3)', metavar="E",
                            help='epsilons for Adafactor optimizer')
        parser.add_argument('--clip-threshold', type=float, default=1.0, metavar="C",

@@ -31,11 +32,14 @@ class FairseqAdafactor(FairseqOptimizer):
                            help='beta for first moment estimator. Optional')
        parser.add_argument('--scale-parameter', action='store_true',
                            help='scale learning rate by root mean square of parameter.')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        parser.add_argument('--warmup-init', action='store_true',
                            help='use relative step for warm-up learning rate schedule')
        parser.add_argument('--relative-step', action='store_true',
                            help='set learning rate to inverse square root of timestep.'
                                 'If false, external learning rate applied')
        # fmt: on

    @property
    def optimizer_config(self):


@@ -16,6 +16,14 @@ class Adagrad(FairseqOptimizer):
        super().__init__(args, params)
        self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """


@@ -30,6 +30,8 @@ class FairseqAdam(FairseqOptimizer):
                            help='betas for Adam optimizer')
        parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
                            help='epsilon for Adam optimizer')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property


@@ -85,6 +85,8 @@ class CosineSchedule(FairseqLRScheduler):
                            help='factor to grow the length of each period')
        parser.add_argument('--lr-period-updates', default=-1, type=float, metavar='LR',
                            help='initial number of updates per period')
        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
                            help='shrink factor for annealing')
        # fmt: on

    def step(self, epoch, val_loss=None):


@@ -30,6 +30,8 @@ class FixedSchedule(FairseqLRScheduler):
        # fmt: off
        parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
                            help='force annealing at specified epoch')
        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
                            help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
        parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
                            help='warmup the learning rate linearly for the first N updates')
        # fmt: on


@@ -24,6 +24,14 @@ class ReduceLROnPlateau(FairseqLRScheduler):
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer.optimizer, patience=0, factor=args.lr_shrink)

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        # fmt: off
        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
                            help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
        # fmt: on

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {


@@ -46,6 +46,8 @@ class TriangularSchedule(FairseqLRScheduler):
                            help='max learning rate, must be more than args.lr')
        parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
                            help='initial number of updates per period (cycle length)')
        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
                            help='shrink factor for annealing')
        parser.add_argument('--shrink-min', action='store_true',
                            help='if set, also shrinks min lr')
        # fmt: on


@@ -16,6 +16,16 @@ class FairseqNAG(FairseqOptimizer):
        super().__init__(args, params)
        self._optimizer = NAG(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--momentum', default=0.99, type=float, metavar='M',
                            help='momentum factor')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """


@@ -16,6 +16,16 @@ class SGD(FairseqOptimizer):
        super().__init__(args, params)
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
                            help='momentum factor')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """


@@ -303,19 +303,13 @@ def add_optimization_args(parser):
                       metavar='LR_1,LR_2,...,LR_N',
                       help='learning rate for the first N epochs; all epochs >N using LR_N'
                            ' (note: this may be interpreted differently depending on --lr-scheduler)')
    group.add_argument('--momentum', default=0.99, type=float, metavar='M',
                       help='momentum factor')
    group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                       help='weight decay')

    # Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
    group.add_argument('--lr-scheduler', default='reduce_lr_on_plateau',
    group.add_argument('--lr-scheduler', default='fixed',
                       choices=LR_SCHEDULER_REGISTRY.keys(),
                       help='Learning Rate Scheduler')
    group.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
                       help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
    group.add_argument('--min-lr', default=-1, type=float, metavar='LR',
                       help='minimum learning rate')
                       help='stop training when the learning rate reaches this minimum')
    # fmt: on

    return group
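Two behavioral notes ride along with this cleanup: --momentum, --weight-decay, and --lr-shrink are now registered by whichever optimizer or LR scheduler owns them (see the optim/ and lr_scheduler/ hunks above), and the default --lr-scheduler changed from reduce_lr_on_plateau to fixed. A rough CLI sketch of the first consequence (dataset path invented; exact argparse error wording may differ):

    # accepted: the nag optimizer registers --momentum itself
    $ fairseq-train data-bin/example --optimizer nag --lr 0.1 --momentum 0.99

    # now rejected: adam does not define --momentum
    $ fairseq-train data-bin/example --optimizer adam --lr 0.0005 --momentum 0.99
    # error: unrecognized arguments: --momentum 0.99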


@@ -81,6 +81,8 @@ class LanguageModelingTask(FairseqTask):
                            help='include future target')
        parser.add_argument('--past-target', action='store_true',
                            help='include past target')
        parser.add_argument('--add-bos-token', action='store_true',
                            help='prepend beginning of sentence token (<s>)')
        # fmt: on

    def __init__(self, args, dictionary, output_dictionary, targets=None):

@@ -185,7 +187,7 @@ class LanguageModelingTask(FairseqTask):
        self.datasets[split] = MonolingualDataset(
            dataset, sizes, self.dictionary, self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True,
            targets=self.targets,
            targets=self.targets, add_bos_token=self.args.add_bos_token,
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths):

@@ -205,6 +207,7 @@ class LanguageModelingTask(FairseqTask):
                self.target_dictionary,
                add_eos_for_other_targets=False,
                shuffle=False,
                add_bos_token=self.args.add_bos_token,
            ),
            eos=self.source_dictionary.eos(),
            # remove EOS since this will be used as a prefix for generation
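End to end, the flag threads from the CLI through LanguageModelingTask into MonolingualDataset; an illustrative pair of commands (data paths and hyperparameters invented):

    # train a transformer LM with <s> prepended to every sample
    $ fairseq-train data-bin/wikitext-103 --task language_modeling \
        --arch transformer_lm --add-bos-token \
        --optimizer adam --lr 0.0001 --max-tokens 2048 --tokens-per-sample 512

    # eval_lm reads add_bos_token back from the checkpoint args and strips
    # the <s> score, as in the eval_lm.py hunk above
    $ fairseq-eval-lm data-bin/wikitext-103 --path checkpoints/checkpoint_best.pt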


@@ -8,7 +8,7 @@

import itertools
import os

from fairseq import options, utils
from fairseq import options
from fairseq.data import (
    ConcatDataset,
    data_utils,

@@ -69,24 +69,6 @@ class TranslationTask(FairseqTask):
                            help='amount to upsample primary dataset')
        # fmt: on

    @staticmethod
    def load_pretrained_model(path, src_dict_path, tgt_dict_path, arg_overrides=None):
        model = utils.load_checkpoint_to_cpu(path)
        args = model['args']
        state_dict = model['model']
        args = utils.override_model_args(args, arg_overrides)
        src_dict = Dictionary.load(src_dict_path)
        tgt_dict = Dictionary.load(tgt_dict_path)

        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()

        task = TranslationTask(args, src_dict, tgt_dict)
        model = task.build_model(args)
        model.upgrade_state_dict(state_dict)
        model.load_state_dict(state_dict, strict=True)
        return model

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args)
        self.src_dict = src_dict

@@ -102,6 +84,10 @@ class TranslationTask(FairseqTask):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # upgrade old checkpoints
        if isinstance(args.data, str):
            args.data = [args.data]

        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])

@@ -147,9 +133,7 @@ class TranslationTask(FairseqTask):
        src_datasets = []
        tgt_datasets = []

        data_paths = self.args.data

        for dk, data_path in enumerate(data_paths):
        for dk, data_path in enumerate(self.args.data):
            for k in itertools.count():
                split_k = split + (str(k) if k > 0 else '')


@@ -11,10 +11,11 @@ Train a network across multiple GPUs.

from collections import OrderedDict
from itertools import chain
import os

import torch

from fairseq import distributed_utils, models, optim, utils
from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.optim import lr_scheduler

@@ -119,16 +120,31 @@ class Trainer(object):
        """Save all training state in a checkpoint file."""
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            extra_state['train_meters'] = self.meters
            utils.save_state(
            checkpoint_utils.save_state(
                filename, self.args, self.get_model().state_dict(), self.criterion, self.optimizer,
                self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
            )

    def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = utils.load_model_state(
            filename, self.get_model(),
        )
        extra_state, self._optim_history, last_optim_state = None, [], None

        if os.path.exists(filename):
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                self.get_model().load_state_dict(state['model'], strict=True)
            except Exception:
                raise Exception(
                    'Cannot load model parameters from checkpoint, '
                    'please ensure that the architectures match.'
                )

            extra_state = state['extra_state']
            self._optim_history = state['optimizer_history']
            last_optim_state = state['last_optimizer_state']

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

@@ -136,9 +152,9 @@ class Trainer(object):
            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
                'criterion does not match; please reset the optimizer (--reset-optimizer)'
                'Criterion does not match; please reset the optimizer (--reset-optimizer).'
            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                'optimizer does not match; please reset the optimizer (--reset-optimizer)'
                'Optimizer does not match; please reset the optimizer (--reset-optimizer).'

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
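How callers are expected to drive the new code path, as a hedged sketch (a constructed Trainer instance named trainer is assumed, load_checkpoint is assumed to return extra_state, and the path is invented):

    # a missing file is no longer an error: extra_state comes back None
    # and training simply starts fresh
    extra_state = trainer.load_checkpoint('checkpoints/checkpoint_last.pt')

    # keep the model weights but restart optimization, e.g. after switching
    # optimizers (otherwise the criterion/optimizer-name asserts above fire):
    extra_state = trainer.load_checkpoint(
        'checkpoints/checkpoint_last.pt',
        reset_optimizer=True,
        reset_lr_scheduler=True,
    )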


@@ -9,182 +9,25 @@ from collections import defaultdict, OrderedDict
from typing import Callable
import copy
import importlib.util
import logging
import os
import re
import sys
import traceback
import warnings

import torch
import torch.nn.functional as F
from torch.serialization import default_restore_location

from fairseq.modules import gelu, gelu_fast


def torch_persistent_save(*args, **kwargs):
    for i in range(3):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            if i == 2:
                logging.error(traceback.format_exc())


def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()
        for k, v in state_dict.items():
            cpu_dict[k] = convert_state_dict_type(v)
        return cpu_dict
    elif isinstance(state_dict, list):
        return [convert_state_dict_type(v) for v in state_dict]
    elif torch.is_tensor(state_dict):
        return state_dict.type(ttype)
    else:
        return state_dict


def save_state(filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
               num_updates, optim_history=None, extra_state=None):
    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        'args': args,
        'model': model_state_dict if model_state_dict else {},
        'optimizer_history': optim_history + [
            {
                'criterion_name': criterion.__class__.__name__,
                'optimizer_name': optimizer.__class__.__name__,
                'lr_scheduler_state': lr_scheduler.state_dict(),
                'num_updates': num_updates,
            }
        ],
        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
        'extra_state': extra_state,
    }
    torch_persistent_save(state_dict, filename)


def load_model_state(filename, model):
    if not os.path.exists(filename):
        return None, [], None
    state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    state = _upgrade_state_dict(state)
    model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        model.load_state_dict(state['model'], strict=True)
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']


def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""
    # add optimizer_history
    if 'optimizer_history' not in state:
        state['optimizer_history'] = [
            {
                'criterion_name': 'CrossEntropyCriterion',
                'best_loss': state['best_loss'],
            },
        ]
        state['last_optimizer_state'] = state['optimizer']
        del state['optimizer']
        del state['best_loss']
    # move extra_state into sub-dictionary
    if 'epoch' in state and 'extra_state' not in state:
        state['extra_state'] = {
            'epoch': state['epoch'],
            'batch_offset': state['batch_offset'],
            'val_loss': state['val_loss'],
        }
        del state['epoch']
        del state['batch_offset']
        del state['val_loss']
    # reduce optimizer history's memory usage (only keep the last state)
    if 'optimizer' in state['optimizer_history'][-1]:
        state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
        for optim_hist in state['optimizer_history']:
            del optim_hist['optimizer']
    # record the optimizer class name
    if 'optimizer_name' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
    # move best_loss into lr_scheduler_state
    if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['lr_scheduler_state'] = {
            'best': state['optimizer_history'][-1]['best_loss'],
        }
        del state['optimizer_history'][-1]['best_loss']
    # keep track of number of updates
    if 'num_updates' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['num_updates'] = 0
    # old model checkpoints may not have separate source/target positions
    if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
        state['args'].max_source_positions = state['args'].max_positions
        state['args'].max_target_positions = state['args'].max_positions
    # use stateful training data iterator
    if 'train_iterator' not in state['extra_state']:
        state['extra_state']['train_iterator'] = {
            'epoch': state['extra_state']['epoch'],
            'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
        }
    return state


def load_checkpoint_to_cpu(path):
    state = torch.load(path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    state = _upgrade_state_dict(state)
    return state


def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Load an ensemble of models for inference.

    model_arg_overrides allows you to pass a dictionary model_arg_overrides --
    {'arg_name': arg} -- to override model args that were used during model
    training
    """
    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = load_checkpoint_to_cpu(filename)
        states.append(state)

    ensemble = []
    for state in states:
        args = state['args']
        if model_arg_overrides is not None:
            args = override_model_args(args, model_arg_overrides)

        # build model for ensemble
        model = task.build_model(args)
        model.upgrade_state_dict(state['model'])
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)

    # some args (e.g., tokens_per_sample) might have been updated while building the model
    if model_arg_overrides is not None:
        args = override_model_args(args, model_arg_overrides)

    return ensemble, args


def override_model_args(args, model_arg_overrides):
    # Uses model_arg_overrides {'arg_name': arg} to override model args
    for arg_name, arg_val in model_arg_overrides.items():
        setattr(args, arg_name, arg_val)
    return args

    from fairseq import checkpoint_utils
    deprecation_warning(
        'utils.load_ensemble_for_inference is deprecated. '
        'Please use checkpoint_utils.load_model_ensemble instead.'
    )
    return checkpoint_utils.load_model_ensemble(
        filenames, arg_overrides=model_arg_overrides, task=task,
    )
def move_to_cuda(sample):
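For downstream code, migration is a rename plus one keyword change; a hedged before/after sketch (path and override invented; task is assumed to be an already-constructed FairseqTask):

    from fairseq import checkpoint_utils, utils

    # old call site: still works, now emits the deprecation warning above
    models, args = utils.load_ensemble_for_inference(
        ['checkpoints/model.pt'], task, model_arg_overrides={'tokens_per_sample': 512},
    )

    # new equivalent
    models, args = checkpoint_utils.load_model_ensemble(
        ['checkpoints/model.pt'], arg_overrides={'tokens_per_sample': 512}, task=task,
    )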
@@ -379,25 +222,6 @@ def fill_with_neg_inf(t):
    return t.float().fill_(float('-inf')).type_as(t)

def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    pt_regexp = re.compile(pattern)
    files = os.listdir(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = int(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]

def resolve_max_positions(*args):
    """Resolve max position constraints from multiple sources."""


@@ -11,9 +11,8 @@ Translate pre-processed data with a trained model.

import torch

from fairseq import bleu, options, progress_bar, tasks, utils
from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.utils import import_user_module


def main(args):

@@ -23,7 +22,7 @@ def main(args):
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    import_user_module(args)
    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000

@@ -34,7 +33,6 @@ def main(args):
    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionaries
    try:

@@ -45,8 +43,10 @@ def main(args):
    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation


@@ -15,9 +15,8 @@ import sys

import torch

from fairseq import options, tasks, utils
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.sequence_generator import SequenceGenerator
from fairseq.utils import import_user_module

Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')

@@ -56,7 +55,7 @@ def make_batches(lines, args, task, max_positions):

def main(args):
    import_user_module(args)
    utils.import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1

@@ -77,8 +76,10 @@ def main(args):
    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Set dictionaries


@@ -12,10 +12,9 @@ Data pre-processing: build vocabularies and binarize training data.

from collections import Counter
from itertools import zip_longest

from fairseq import options, tasks
from fairseq import options, tasks, utils
from fairseq.data import indexed_dataset
from fairseq.binarizer import Binarizer
from fairseq.utils import import_user_module
from multiprocessing import Pool

import os

@@ -23,7 +22,7 @@ import shutil

def main(args):
    import_user_module(args)
    utils.import_user_module(args)

    print(args)


@@ -239,7 +239,20 @@ class TestLanguageModeling(unittest.TestCase):
            with tempfile.TemporaryDirectory('test_fconv_lm') as data_dir:
                create_dummy_data(data_dir)
                preprocess_lm_data(data_dir)
                train_language_model(data_dir, 'fconv_lm')
                train_language_model(data_dir, 'fconv_lm', [
                    '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
                    '--decoder-embed-dim', '280',
                    '--optimizer', 'nag',
                    '--lr', '0.1',
                ])
                eval_lm_main(data_dir)

    def test_transformer_lm(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_transformer_lm') as data_dir:
                create_dummy_data(data_dir)
                preprocess_lm_data(data_dir)
                train_language_model(data_dir, 'transformer_lm', ['--add-bos-token'])
                eval_lm_main(data_dir)

@@ -534,7 +547,7 @@ def preprocess_lm_data(data_dir):
    preprocess.main(preprocess_args)

def train_language_model(data_dir, arch):
def train_language_model(data_dir, arch, extra_flags=None):
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,

@@ -542,12 +555,10 @@ def train_language_model(data_dir):
            '--task', 'language_modeling',
            data_dir,
            '--arch', arch,
            '--optimizer', 'nag',
            '--lr', '0.1',
            '--optimizer', 'adam',
            '--lr', '0.0001',
            '--criterion', 'adaptive_loss',
            '--adaptive-softmax-cutoff', '5,10,15',
            '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
            '--decoder-embed-dim', '280',
            '--max-tokens', '500',
            '--tokens-per-sample', '500',
            '--save-dir', data_dir,

@@ -555,7 +566,7 @@ def train_language_model(data_dir, arch):
            '--no-progress-bar',
            '--distributed-world-size', '1',
            '--ddp-backend', 'no_c10d',
        ],
        ] + (extra_flags or []),
    )
    train.main(train_args)


@@ -38,7 +38,7 @@ class TestReproducibility(unittest.TestCase):
            ] + extra_flags,
        )
        stdout = stdout.getvalue()
        train_log, valid_log = map(json.loads, stdout.split('\n')[-4:-2])
        train_log, valid_log = map(json.loads, stdout.split('\n')[-5:-3])

        # train epoch 2, resuming from previous checkpoint 1
        os.rename(

@@ -56,7 +56,7 @@ class TestReproducibility(unittest.TestCase):
            ] + extra_flags,
        )
        stdout = stdout.getvalue()
        train_res_log, valid_res_log = map(json.loads, stdout.split('\n')[-5:-3])

        def cast(s):
            return round(float(s), 3)


@@ -17,15 +17,14 @@ import random

import torch

from fairseq import distributed_utils, options, progress_bar, tasks, utils
from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
from fairseq.utils import import_user_module

def main(args, init_distributed=False):
    import_user_module(args)
    utils.import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 6000

@@ -326,14 +325,18 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        checkpoints = checkpoint_utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint(\d+)\.pt')
        checkpoints = checkpoint_utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint(\d+)\.pt',
        )
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
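Since checkpoint_paths returns matches newest-first, slicing past the keep-count yields exactly the stale files; a standalone sketch (listing invented):

    # epoch checkpoints as checkpoint_paths would return them, newest first
    checkpoints = [
        'checkpoints/checkpoint5.pt', 'checkpoints/checkpoint4.pt',
        'checkpoints/checkpoint3.pt', 'checkpoints/checkpoint2.pt',
        'checkpoints/checkpoint1.pt',
    ]

    keep_last_epochs = 2
    for old_chk in checkpoints[keep_last_epochs:]:
        print('would remove', old_chk)   # checkpoint3, checkpoint2, checkpoint1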