Merge internal changes (#654)
Summary:
- Add --add-bos-token option to LM task
- Cleanup utils.py and options.py

Pull Request resolved: https://github.com/pytorch/fairseq/pull/654
Differential Revision: D15041794
Pulled By: myleott
fbshipit-source-id: 3ad00007769d5f48308052cfd40de39c5ffa1a6e
parent: 89a696161b
commit: d45db80431
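The headline change moves checkpoint loading out of ``fairseq.utils`` into the new ``fairseq.checkpoint_utils`` module. A minimal sketch of the updated loading call (hypothetical checkpoint path; the old ``utils.load_ensemble_for_inference`` keeps working through a deprecation shim, see the ``fairseq/utils.py`` hunk below)::

    from fairseq import checkpoint_utils

    # hypothetical path; the task is inferred from the checkpoint's saved args
    models, args = checkpoint_utils.load_model_ensemble(
        ['checkpoints/checkpoint_best.pt'],
    )
    model = models[0].eval()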
@@ -15,9 +15,15 @@ Optimizers update the Model parameters based on the gradients.
     :members:
     :undoc-members:
 
+.. autoclass:: fairseq.optim.adadelta.Adadelta
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.adagrad.Adagrad
+    :members:
+    :undoc-members:
+
 .. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
     :members:
     :undoc-members:
 
 .. autoclass:: fairseq.optim.adam.FairseqAdam
     :members:
     :undoc-members:
@@ -28,11 +28,12 @@ fairseq implements the following high-level training flow::
         lr_scheduler.step_update(num_updates)
     lr_scheduler.step(epoch)
 
-where the default implementation for ``train.train_step`` is roughly::
+where the default implementation for ``task.train_step`` is roughly::
 
     def train_step(self, batch, model, criterion, optimizer):
         loss = criterion(model, batch)
         optimizer.backward(loss)
         return loss
 
 **Registering new plug-ins**
@@ -354,7 +354,7 @@ The model files should appear in the :file:`checkpoints/` directory.
 Finally we can write a short script to evaluate our model on new inputs. Create
 a new file named :file:`eval_classifier.py` with the following contents::
 
-    from fairseq import data, options, tasks, utils
+    from fairseq import checkpoint_utils, data, options, tasks
 
     # Parse command-line arguments for generation
     parser = options.get_generation_parser(default_task='simple_classification')
@@ -365,7 +365,7 @@ a new file named :file:`eval_classifier.py` with the following contents::
 
     # Load model
     print('| loading model from {}'.format(args.path))
-    models, _model_args = utils.load_ensemble_for_inference([args.path], task)
+    models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
     model = models[0]
 
     while True:
eval_lm.py (21 lines changed)
@@ -13,11 +13,10 @@ Evaluate the perplexity of a trained language model.
 import numpy as np
 import torch
 
-from fairseq import options, progress_bar, tasks, utils
+from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
 from fairseq.data import LMContextWindowDataset
 from fairseq.meters import StopwatchMeter, TimeMeter
 from fairseq.sequence_scorer import SequenceScorer
-from fairseq.utils import import_user_module
 
 
 class WordStat(object):
@@ -49,7 +48,7 @@ class WordStat(object):
 def main(parsed_args):
     assert parsed_args.path is not None, '--path required for evaluation!'
 
-    import_user_module(parsed_args)
+    utils.import_user_module(parsed_args)
 
     print(parsed_args)
 
@@ -59,12 +58,17 @@ def main(parsed_args):
 
     # Load ensemble
     print('| loading model(s) from {}'.format(parsed_args.path))
-    models, args = utils.load_ensemble_for_inference(
-        parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
+    models, args = checkpoint_utils.load_model_ensemble(
+        parsed_args.path.split(':'),
+        arg_overrides=eval(parsed_args.model_overrides),
+        task=task,
     )
 
     for arg in vars(parsed_args).keys():
-        if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
+        if arg not in {
+            'self_target', 'future_target', 'past_target', 'tokens_per_sample',
+            'output_size_dictionary', 'add_bos_token',
+        }:
             setattr(args, arg, getattr(parsed_args, arg))
 
     # reduce tokens per sample by the required context window size
@@ -151,6 +155,11 @@ def main(parsed_args):
             tgt_len = tokens.numel()
            pos_scores = hypo['positional_scores'].float()
 
+            if args.add_bos_token:
+                assert hypo['tokens'][0].item() == task.target_dictionary.bos()
+                tokens = tokens[1:]
+                pos_scores = pos_scores[1:]
+
             skipped_toks = 0
             if bpe_toks is not None:
                 for i in range(tgt_len - 1):
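When ``--add-bos-token`` is in effect, the prepended ``<s>`` position is dropped from both the token and score tensors before perplexity is aggregated, so the BOS never counts as a predicted word. A standalone sketch of that slicing with made-up scores (plain PyTorch, not eval_lm's exact surroundings)::

    import torch

    # hypothetical per-position log-probs aligned with: <s> w1 w2 w3
    pos_scores = torch.tensor([-0.1, -1.2, -0.7, -2.3])
    add_bos_token = True
    if add_bos_token:
        pos_scores = pos_scores[1:]  # drop the <s> position, as eval_lm.py does
    print(-pos_scores.mean())  # average negative log-likelihood over real tokens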
@@ -39,7 +39,7 @@ $ fairseq-train --task language_modeling data-bin/wikitext-103 \
     --save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm_wiki103 \
     --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
     --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
-    --criterion adaptive_loss --max-tokens 3072 --update-freq 4 --tokens-per-sample 3072 --seed 1 \
+    --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
     --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d
 
 # Evaluate:
fairseq/checkpoint_utils.py (new file, 177 lines)
@@ -0,0 +1,177 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+from collections import OrderedDict
+import logging
+import os
+import re
+import traceback
+
+import torch
+from torch.serialization import default_restore_location
+
+from fairseq import tasks
+
+
+def load_checkpoint_to_cpu(path):
+    """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
+    state = torch.load(
+        path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
+    )
+    state = _upgrade_state_dict(state)
+    return state
+
+
+def load_model_ensemble(filenames, arg_overrides=None, task=None):
+    """Loads an ensemble of models.
+
+    Args:
+        filenames (List[str]): checkpoint files to load
+        arg_overrides (Dict[str,Any], optional): override model args that
+            were used during model training
+        task (fairseq.tasks.FairseqTask, optional): task to use for loading
+    """
+    ensemble = []
+    for filename in filenames:
+        if not os.path.exists(filename):
+            raise IOError('Model file not found: {}'.format(filename))
+        state = load_checkpoint_to_cpu(filename)
+
+        args = state['args']
+        if arg_overrides is not None:
+            for arg_name, arg_val in arg_overrides.items():
+                setattr(args, arg_name, arg_val)
+
+        if task is None:
+            task = tasks.setup_task(args)
+
+        # build model for ensemble
+        model = task.build_model(args)
+        model.load_state_dict(state['model'], strict=True)
+        ensemble.append(model)
+
+    return ensemble, args
+
+
+def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
+    """Retrieves all checkpoints found in `path` directory.
+
+    Checkpoints are identified by matching filename to the specified pattern. If
+    the pattern contains groups, the result will be sorted by the first group in
+    descending order.
+    """
+    pt_regexp = re.compile(pattern)
+    files = os.listdir(path)
+
+    entries = []
+    for i, f in enumerate(files):
+        m = pt_regexp.fullmatch(f)
+        if m is not None:
+            idx = int(m.group(1)) if len(m.groups()) > 0 else i
+            entries.append((idx, m.group(0)))
+    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
+
+
+def torch_persistent_save(*args, **kwargs):
+    for i in range(3):
+        try:
+            return torch.save(*args, **kwargs)
+        except Exception:
+            if i == 2:
+                logging.error(traceback.format_exc())
+
+
+def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
+    if isinstance(state_dict, dict):
+        cpu_dict = OrderedDict()
+        for k, v in state_dict.items():
+            cpu_dict[k] = convert_state_dict_type(v)
+        return cpu_dict
+    elif isinstance(state_dict, list):
+        return [convert_state_dict_type(v) for v in state_dict]
+    elif torch.is_tensor(state_dict):
+        return state_dict.type(ttype)
+    else:
+        return state_dict
+
+
+def save_state(
+    filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
+    num_updates, optim_history=None, extra_state=None,
+):
+    if optim_history is None:
+        optim_history = []
+    if extra_state is None:
+        extra_state = {}
+    state_dict = {
+        'args': args,
+        'model': model_state_dict if model_state_dict else {},
+        'optimizer_history': optim_history + [
+            {
+                'criterion_name': criterion.__class__.__name__,
+                'optimizer_name': optimizer.__class__.__name__,
+                'lr_scheduler_state': lr_scheduler.state_dict(),
+                'num_updates': num_updates,
+            }
+        ],
+        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
+        'extra_state': extra_state,
+    }
+    torch_persistent_save(state_dict, filename)
+
+
+def _upgrade_state_dict(state):
+    """Helper for upgrading old model checkpoints."""
+    # add optimizer_history
+    if 'optimizer_history' not in state:
+        state['optimizer_history'] = [
+            {
+                'criterion_name': 'CrossEntropyCriterion',
+                'best_loss': state['best_loss'],
+            },
+        ]
+        state['last_optimizer_state'] = state['optimizer']
+        del state['optimizer']
+        del state['best_loss']
+    # move extra_state into sub-dictionary
+    if 'epoch' in state and 'extra_state' not in state:
+        state['extra_state'] = {
+            'epoch': state['epoch'],
+            'batch_offset': state['batch_offset'],
+            'val_loss': state['val_loss'],
+        }
+        del state['epoch']
+        del state['batch_offset']
+        del state['val_loss']
+    # reduce optimizer history's memory usage (only keep the last state)
+    if 'optimizer' in state['optimizer_history'][-1]:
+        state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
+        for optim_hist in state['optimizer_history']:
+            del optim_hist['optimizer']
+    # record the optimizer class name
+    if 'optimizer_name' not in state['optimizer_history'][-1]:
+        state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
+    # move best_loss into lr_scheduler_state
+    if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
+        state['optimizer_history'][-1]['lr_scheduler_state'] = {
+            'best': state['optimizer_history'][-1]['best_loss'],
+        }
+        del state['optimizer_history'][-1]['best_loss']
+    # keep track of number of updates
+    if 'num_updates' not in state['optimizer_history'][-1]:
+        state['optimizer_history'][-1]['num_updates'] = 0
+    # old model checkpoints may not have separate source/target positions
+    if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
+        state['args'].max_source_positions = state['args'].max_positions
+        state['args'].max_target_positions = state['args'].max_positions
+    # use stateful training data iterator
+    if 'train_iterator' not in state['extra_state']:
+        state['extra_state']['train_iterator'] = {
+            'epoch': state['extra_state']['epoch'],
+            'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
+        }
+    return state
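Since ``checkpoint_paths`` sorts by the first regex group in descending order, the newest checkpoint comes first and non-matching files are ignored. A quick check of that behavior (assuming a fairseq install that includes this module)::

    import os
    import tempfile

    from fairseq import checkpoint_utils

    with tempfile.TemporaryDirectory() as d:
        for name in ['checkpoint1.pt', 'checkpoint10.pt', 'checkpoint2.pt', 'notes.txt']:
            open(os.path.join(d, name), 'w').close()
        paths = checkpoint_utils.checkpoint_paths(d)
        # numeric group sorts descending: checkpoint10.pt, checkpoint2.pt, checkpoint1.pt
        print([os.path.basename(p) for p in paths])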
@@ -18,13 +18,12 @@ from fairseq.data import data_utils
 
 class Dictionary(object):
     """A mapping from symbols to consecutive integers"""
-    def __init__(self, pad='<pad>', eos='</s>', unk='<unk>'):
+    def __init__(self, pad='<pad>', eos='</s>', unk='<unk>', bos='<s>'):
         self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
         self.symbols = []
         self.count = []
         self.indices = {}
-        # dictionary indexing starts at 1 for consistency with Lua
-        self.add_symbol('<Lua heritage>')
+        self.bos_index = self.add_symbol(bos)
         self.pad_index = self.add_symbol(pad)
         self.eos_index = self.add_symbol(eos)
         self.unk_index = self.add_symbol(unk)
@@ -143,6 +142,10 @@ class Dictionary(object):
         self.symbols = list(new_symbols)
         self.indices = new_indices
 
+    def bos(self):
+        """Helper to get index of beginning-of-sentence symbol"""
+        return self.bos_index
+
     def pad(self):
         """Helper to get index of pad symbol"""
         return self.pad_index
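With the ``<Lua heritage>`` placeholder gone, ``<s>`` is added first, so the four special symbols occupy indices 0 through 3. A small sketch of the resulting layout (assuming fairseq is importable)::

    from fairseq.data.dictionary import Dictionary

    d = Dictionary()
    # bos, pad, eos, unk are added in that order in __init__
    print(d.bos(), d.pad(), d.eos(), d.unk())  # expected: 0 1 2 3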
@@ -62,13 +62,14 @@ class MonolingualDataset(FairseqDataset):
     """
 
     def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle,
-                 targets=None):
+                 targets=None, add_bos_token=False):
         self.dataset = dataset
         self.sizes = np.array(sizes)
         self.vocab = src_vocab
         self.tgt_vocab = tgt_vocab
         self.add_eos_for_other_targets = add_eos_for_other_targets
         self.shuffle = shuffle
+        self.add_bos_token = add_bos_token
 
         assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \
             "targets must be none or one of 'self', 'future', 'past'"
@@ -91,6 +92,7 @@ class MonolingualDataset(FairseqDataset):
         else:
             source = self.dataset[index]
             target = None
+        source, target = self._maybe_add_bos(source, target)
         return {'id': index, 'source': source, 'target': target}
 
     def __len__(self):
@@ -129,6 +131,13 @@ class MonolingualDataset(FairseqDataset):
 
         return source, self._filter_vocab(target)
 
+    def _maybe_add_bos(self, source, target):
+        if self.add_bos_token:
+            source = torch.cat([source.new([self.vocab.bos()]), source])
+            if target is not None:
+                target = torch.cat([target.new([self.tgt_vocab.bos()]), target])
+        return source, target
+
     def _filter_vocab(self, target):
         if len(self.tgt_vocab) != len(self.vocab):
             def _filter(target):
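``_maybe_add_bos`` uses ``tensor.new`` to build the one-element BOS prefix with the same dtype and device as the sequence before concatenating. The pattern in isolation (hypothetical index, plain PyTorch)::

    import torch

    bos_index = 0  # hypothetical bos position in the vocabulary
    source = torch.LongTensor([11, 12, 13])
    # same construction as MonolingualDataset._maybe_add_bos above
    source = torch.cat([source.new([bos_index]), source])
    print(source)  # tensor([ 0, 11, 12, 13])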
@@ -173,6 +182,7 @@ class MonolingualDataset(FairseqDataset):
         target = self.vocab.dummy_sentence(tgt_len + 2)
         source, past_target, future_target = target[1:-1], target[2:], target[:-2]
         source, target = self._make_source_target(source, past_target, future_target)
+        source, target = self._maybe_add_bos(source, target)
 
         return self.collater([
             {'id': i, 'source': source, 'target': target}
@@ -141,7 +141,7 @@ class FConvLanguageModel(FairseqLanguageModel):
         # make sure all arguments are present in older models
         base_lm_architecture(args)
 
-        if hasattr(args, 'max_target_positions'):
+        if hasattr(args, 'max_target_positions') and not hasattr(args, 'tokens_per_sample'):
             args.tokens_per_sample = args.max_target_positions
 
         decoder = FConvDecoder(
@@ -4,7 +4,6 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
-#
 
 import math
 
@@ -12,11 +11,11 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from fairseq import checkpoint_utils
 from fairseq.modules import (
     DownsampledMultiHeadAttention, GradMultiply, LayerNorm,
     LearnedPositionalEmbedding, LinearizedConvolution,
 )
-from fairseq import utils
 
 from . import (
     FairseqEncoder, CompositeEncoder, FairseqDecoder, FairseqModel,
@@ -84,8 +83,7 @@ class FConvModelSelfAtt(FairseqModel):
         pretrained = eval(args.pretrained)
         if pretrained:
             print("| loading pretrained model")
-            trained_model = utils.load_ensemble_for_inference(
-                # not actually for inference, but loads pretrained model parameters
+            trained_model = checkpoint_utils.load_model_ensemble(
                 filenames=[args.pretrained_checkpoint],
                 task=task,
             )[0][0]
@@ -830,6 +830,7 @@ def base_lm_architecture(args):
     args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
     args.activation_fn = getattr(args, 'activation_fn', 'relu')
 
+    args.add_bos_token = getattr(args, 'add_bos_token', False)
     args.character_embeddings = getattr(args, 'character_embeddings', False)
 
     args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
@@ -927,7 +928,7 @@ def transformer_wmt_en_de(args):
     base_architecture(args)
 
 
-# parameters used in the "Attention Is All You Need" paper (Vaswani, et al, 2017)
+# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
 @register_model_architecture('transformer', 'transformer_vaswani_wmt_en_de_big')
 def transformer_vaswani_wmt_en_de_big(args):
     args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
@@ -8,7 +8,7 @@
 import os
 from typing import Any, Dict
 
-from fairseq import utils
+from fairseq import checkpoint_utils
 from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
 from fairseq.models.transformer import (
     TransformerDecoder,
@@ -92,7 +92,7 @@ def upgrade_state_dict_with_xlm_weights(
     if not os.path.exists(pretrained_xlm_checkpoint):
         raise IOError(f"Model file not found: {pretrained_xlm_checkpoint}")
 
-    state = utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint)
+    state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint)
     xlm_state_dict = state["model"]
     for key in xlm_state_dict.keys():
@@ -4,17 +4,15 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
-import math
-
-import torch
-
-
 """
 See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with
 the corresponding GitHub repo: https://github.com/hendrycks/GELUs
 """
 
+import math
+
+import torch
+
+
 def gelu_fast(x):
     if not hasattr(gelu_fast, "_a"):
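For reference, ``gelu_fast`` computes the tanh approximation from the Hendrycks & Gimpel paper cited in the module docstring; a standalone sketch of that formula (not the module's exact code, which caches the constant on the function object via the ``hasattr`` check above)::

    import math

    import torch

    def gelu_fast_sketch(x):
        # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        a = math.sqrt(2 / math.pi)
        return 0.5 * x * (1 + torch.tanh(a * (x + 0.044715 * x * x * x)))

    print(gelu_fast_sketch(torch.tensor([-1.0, 0.0, 1.0])))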
@@ -19,11 +19,15 @@ class Adadelta(FairseqOptimizer):
     @staticmethod
     def add_args(parser):
         """Add optimizer-specific arguments to the parser."""
+        # fmt: off
         parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO',
                             help='coefficient used for computing a running average of squared gradients')
         parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS',
                             help='term added to the denominator to improve numerical stability')
+        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
+                            help='weight decay')
         parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps')
+        # fmt: on
 
     @property
     def optimizer_config(self):
@@ -21,6 +21,7 @@ class FairseqAdafactor(FairseqOptimizer):
     @staticmethod
     def add_args(parser):
         """Add optimizer-specific arguments to the parser."""
+        # fmt: off
         parser.add_argument('--adafactor-eps', default='(1e-30, 1e-3)', metavar="E",
                             help='epsilons for Adafactor optimizer')
         parser.add_argument('--clip-threshold', type=float, default=1.0, metavar="C",
@@ -31,11 +32,14 @@ class FairseqAdafactor(FairseqOptimizer):
                             help='beta for first moment estimator. Optional')
         parser.add_argument('--scale-parameter', action='store_true',
                             help='scale learning rate by root mean square of parameter.')
+        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
+                            help='weight decay')
         parser.add_argument('--warmup-init', action='store_true',
                             help='use relative step for warm-up learning rate schedule')
         parser.add_argument('--relative-step', action='store_true',
                             help='set learning rate to inverse square root of timestep.'
                                  'If false, external learning rate applied')
+        # fmt: on
 
     @property
     def optimizer_config(self):
@@ -16,6 +16,14 @@ class Adagrad(FairseqOptimizer):
         super().__init__(args, params)
         self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)
 
+    @staticmethod
+    def add_args(parser):
+        """Add optimizer-specific arguments to the parser."""
+        # fmt: off
+        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
+                            help='weight decay')
+        # fmt: on
+
     @property
     def optimizer_config(self):
         """
@@ -30,6 +30,8 @@ class FairseqAdam(FairseqOptimizer):
                             help='betas for Adam optimizer')
         parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
                             help='epsilon for Adam optimizer')
+        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
+                            help='weight decay')
         # fmt: on
 
     @property
@@ -85,6 +85,8 @@ class CosineSchedule(FairseqLRScheduler):
                             help='factor to grow the length of each period')
         parser.add_argument('--lr-period-updates', default=-1, type=float, metavar='LR',
                             help='initial number of updates per period')
+        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
+                            help='shrink factor for annealing')
         # fmt: on
 
     def step(self, epoch, val_loss=None):
@@ -30,6 +30,8 @@ class FixedSchedule(FairseqLRScheduler):
         # fmt: off
         parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
                             help='force annealing at specified epoch')
+        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
+                            help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
         parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
                             help='warmup the learning rate linearly for the first N updates')
         # fmt: on
@@ -24,6 +24,14 @@ class ReduceLROnPlateau(FairseqLRScheduler):
         self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
             self.optimizer.optimizer, patience=0, factor=args.lr_shrink)
 
+    @staticmethod
+    def add_args(parser):
+        """Add arguments to the parser for this LR scheduler."""
+        # fmt: off
+        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
+                            help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
+        # fmt: on
+
     def state_dict(self):
         """Return the LR scheduler state dict."""
         return {
@@ -46,6 +46,8 @@ class TriangularSchedule(FairseqLRScheduler):
                             help='max learning rate, must be more than args.lr')
         parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
                             help='initial number of updates per period (cycle length)')
+        parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
+                            help='shrink factor for annealing')
         parser.add_argument('--shrink-min', action='store_true',
                             help='if set, also shrinks min lr')
         # fmt: on
@@ -16,6 +16,16 @@ class FairseqNAG(FairseqOptimizer):
         super().__init__(args, params)
         self._optimizer = NAG(params, **self.optimizer_config)
 
+    @staticmethod
+    def add_args(parser):
+        """Add optimizer-specific arguments to the parser."""
+        # fmt: off
+        parser.add_argument('--momentum', default=0.99, type=float, metavar='M',
+                            help='momentum factor')
+        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
+                            help='weight decay')
+        # fmt: on
+
     @property
     def optimizer_config(self):
         """
@@ -16,6 +16,16 @@ class SGD(FairseqOptimizer):
         super().__init__(args, params)
         self._optimizer = torch.optim.SGD(params, **self.optimizer_config)
 
+    @staticmethod
+    def add_args(parser):
+        """Add optimizer-specific arguments to the parser."""
+        # fmt: off
+        parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
+                            help='momentum factor')
+        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
+                            help='weight decay')
+        # fmt: on
+
     @property
     def optimizer_config(self):
         """
@@ -303,19 +303,13 @@ def add_optimization_args(parser):
                        metavar='LR_1,LR_2,...,LR_N',
                        help='learning rate for the first N epochs; all epochs >N using LR_N'
                             ' (note: this may be interpreted differently depending on --lr-scheduler)')
-    group.add_argument('--momentum', default=0.99, type=float, metavar='M',
-                       help='momentum factor')
-    group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
-                       help='weight decay')
 
     # Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
-    group.add_argument('--lr-scheduler', default='reduce_lr_on_plateau',
+    group.add_argument('--lr-scheduler', default='fixed',
                        choices=LR_SCHEDULER_REGISTRY.keys(),
                        help='Learning Rate Scheduler')
-    group.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
-                       help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
     group.add_argument('--min-lr', default=-1, type=float, metavar='LR',
-                       help='minimum learning rate')
+                       help='stop training when the learning rate reaches this minimum')
     # fmt: on
     return group
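With ``--momentum``, ``--weight-decay``, and ``--lr-shrink`` removed from the global optimization group, those hyperparameters are now registered by the optimizer or LR scheduler that actually consumes them (see the ``add_args`` additions above). The pattern in isolation, using plain argparse rather than fairseq's registry machinery::

    import argparse

    parser = argparse.ArgumentParser()
    # mirrors SGD.add_args above: each plug-in owns its hyperparameters
    parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
                        help='momentum factor')
    parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                        help='weight decay')

    args = parser.parse_args(['--momentum', '0.9', '--wd', '0.01'])
    print(args.momentum, args.weight_decay)  # 0.9 0.01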
@@ -81,6 +81,8 @@ class LanguageModelingTask(FairseqTask):
                             help='include future target')
         parser.add_argument('--past-target', action='store_true',
                             help='include past target')
+        parser.add_argument('--add-bos-token', action='store_true',
+                            help='prepend beginning of sentence token (<s>)')
         # fmt: on
 
     def __init__(self, args, dictionary, output_dictionary, targets=None):
@@ -185,7 +187,7 @@ class LanguageModelingTask(FairseqTask):
         self.datasets[split] = MonolingualDataset(
             dataset, sizes, self.dictionary, self.output_dictionary,
             add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True,
-            targets=self.targets,
+            targets=self.targets, add_bos_token=self.args.add_bos_token,
         )
 
     def build_dataset_for_inference(self, src_tokens, src_lengths):
@@ -205,6 +207,7 @@ class LanguageModelingTask(FairseqTask):
                 self.target_dictionary,
                 add_eos_for_other_targets=False,
                 shuffle=False,
+                add_bos_token=self.args.add_bos_token,
             ),
             eos=self.source_dictionary.eos(),
             # remove EOS since this will be used as a prefix for generation
@@ -8,7 +8,7 @@
 import itertools
 import os
 
-from fairseq import options, utils
+from fairseq import options
 from fairseq.data import (
     ConcatDataset,
     data_utils,
@@ -69,24 +69,6 @@ class TranslationTask(FairseqTask):
                             help='amount to upsample primary dataset')
         # fmt: on
 
-    @staticmethod
-    def load_pretrained_model(path, src_dict_path, tgt_dict_path, arg_overrides=None):
-        model = utils.load_checkpoint_to_cpu(path)
-        args = model['args']
-        state_dict = model['model']
-        args = utils.override_model_args(args, arg_overrides)
-        src_dict = Dictionary.load(src_dict_path)
-        tgt_dict = Dictionary.load(tgt_dict_path)
-        assert src_dict.pad() == tgt_dict.pad()
-        assert src_dict.eos() == tgt_dict.eos()
-        assert src_dict.unk() == tgt_dict.unk()
-
-        task = TranslationTask(args, src_dict, tgt_dict)
-        model = task.build_model(args)
-        model.upgrade_state_dict(state_dict)
-        model.load_state_dict(state_dict, strict=True)
-        return model
-
     def __init__(self, args, src_dict, tgt_dict):
         super().__init__(args)
         self.src_dict = src_dict
@@ -102,6 +84,10 @@ class TranslationTask(FairseqTask):
         args.left_pad_source = options.eval_bool(args.left_pad_source)
         args.left_pad_target = options.eval_bool(args.left_pad_target)
 
+        # upgrade old checkpoints
+        if isinstance(args.data, str):
+            args.data = [args.data]
+
         # find language pair automatically
         if args.source_lang is None or args.target_lang is None:
             args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])
@@ -147,9 +133,7 @@ class TranslationTask(FairseqTask):
         src_datasets = []
         tgt_datasets = []
 
-        data_paths = self.args.data
-
-        for dk, data_path in enumerate(data_paths):
+        for dk, data_path in enumerate(self.args.data):
             for k in itertools.count():
                 split_k = split + (str(k) if k > 0 else '')
@@ -11,10 +11,11 @@ Train a network across multiple GPUs.
 
 from collections import OrderedDict
 from itertools import chain
+import os
 
 import torch
 
-from fairseq import distributed_utils, models, optim, utils
+from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
 from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
 from fairseq.optim import lr_scheduler
 
@@ -119,16 +120,31 @@ class Trainer(object):
         """Save all training state in a checkpoint file."""
         if distributed_utils.is_master(self.args):  # only save one checkpoint
             extra_state['train_meters'] = self.meters
-            utils.save_state(
+            checkpoint_utils.save_state(
                 filename, self.args, self.get_model().state_dict(), self.criterion, self.optimizer,
                 self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
             )
 
     def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
         """Load all training state from a checkpoint file."""
-        extra_state, self._optim_history, last_optim_state = utils.load_model_state(
-            filename, self.get_model(),
-        )
+        extra_state, self._optim_history, last_optim_state = None, [], None
+
+        if os.path.exists(filename):
+            state = checkpoint_utils.load_checkpoint_to_cpu(filename)
+
+            # load model parameters
+            try:
+                self.get_model().load_state_dict(state['model'], strict=True)
+            except Exception:
+                raise Exception(
+                    'Cannot load model parameters from checkpoint, '
+                    'please ensure that the architectures match.'
+                )
+
+            extra_state = state['extra_state']
+            self._optim_history = state['optimizer_history']
+            last_optim_state = state['last_optimizer_state']
 
         if last_optim_state is not None and not reset_optimizer:
             # rebuild optimizer after loading model, since params may have changed
             self._build_optimizer()
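``save_state`` (now in checkpoint_utils) serializes a dict with ``args``, ``model``, ``optimizer_history``, ``last_optimizer_state``, and ``extra_state`` keys, which is the envelope ``load_checkpoint`` unpacks above. A standalone look at that envelope with stand-in values (illustration only, not a real checkpoint)::

    import os
    import tempfile

    import torch

    # stand-in for the dict checkpoint_utils.save_state writes
    state_dict = {
        'args': None,
        'model': {},
        'optimizer_history': [{
            'criterion_name': 'CrossEntropyCriterion',
            'optimizer_name': 'FairseqNAG',
            'lr_scheduler_state': {'best': None},
            'num_updates': 0,
        }],
        'last_optimizer_state': {},
        'extra_state': {},
    }
    path = os.path.join(tempfile.mkdtemp(), 'checkpoint_demo.pt')
    torch.save(state_dict, path)
    print(sorted(torch.load(path).keys()))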
@@ -136,9 +152,9 @@ class Trainer(object):
                 # only reload optimizer and lr_scheduler if they match
                 last_optim = self._optim_history[-1]
                 assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
-                    'criterion does not match; please reset the optimizer (--reset-optimizer)'
+                    'Criterion does not match; please reset the optimizer (--reset-optimizer).'
                 assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
-                    'optimizer does not match; please reset the optimizer (--reset-optimizer)'
+                    'Optimizer does not match; please reset the optimizer (--reset-optimizer).'
 
                 if not reset_lr_scheduler:
                     self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
fairseq/utils.py (192 lines changed)
@@ -9,182 +9,25 @@ from collections import defaultdict, OrderedDict
 from typing import Callable
 import copy
 import importlib.util
-import logging
 import os
-import re
 import sys
-import traceback
 import warnings
 
 import torch
 import torch.nn.functional as F
-from torch.serialization import default_restore_location
 
 from fairseq.modules import gelu, gelu_fast
 
 
-def torch_persistent_save(*args, **kwargs):
-    for i in range(3):
-        try:
-            return torch.save(*args, **kwargs)
-        except Exception:
-            if i == 2:
-                logging.error(traceback.format_exc())
-
-
-def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
-    if isinstance(state_dict, dict):
-        cpu_dict = OrderedDict()
-        for k, v in state_dict.items():
-            cpu_dict[k] = convert_state_dict_type(v)
-        return cpu_dict
-    elif isinstance(state_dict, list):
-        return [convert_state_dict_type(v) for v in state_dict]
-    elif torch.is_tensor(state_dict):
-        return state_dict.type(ttype)
-    else:
-        return state_dict
-
-
-def save_state(filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
-               num_updates, optim_history=None, extra_state=None):
-    if optim_history is None:
-        optim_history = []
-    if extra_state is None:
-        extra_state = {}
-    state_dict = {
-        'args': args,
-        'model': model_state_dict if model_state_dict else {},
-        'optimizer_history': optim_history + [
-            {
-                'criterion_name': criterion.__class__.__name__,
-                'optimizer_name': optimizer.__class__.__name__,
-                'lr_scheduler_state': lr_scheduler.state_dict(),
-                'num_updates': num_updates,
-            }
-        ],
-        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
-        'extra_state': extra_state,
-    }
-    torch_persistent_save(state_dict, filename)
-
-
-def load_model_state(filename, model):
-    if not os.path.exists(filename):
-        return None, [], None
-    state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
-    state = _upgrade_state_dict(state)
-    model.upgrade_state_dict(state['model'])
-
-    # load model parameters
-    try:
-        model.load_state_dict(state['model'], strict=True)
-    except Exception:
-        raise Exception('Cannot load model parameters from checkpoint, '
-                        'please ensure that the architectures match')
-
-    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
-
-
-def _upgrade_state_dict(state):
-    """Helper for upgrading old model checkpoints."""
-    # add optimizer_history
-    if 'optimizer_history' not in state:
-        state['optimizer_history'] = [
-            {
-                'criterion_name': 'CrossEntropyCriterion',
-                'best_loss': state['best_loss'],
-            },
-        ]
-        state['last_optimizer_state'] = state['optimizer']
-        del state['optimizer']
-        del state['best_loss']
-    # move extra_state into sub-dictionary
-    if 'epoch' in state and 'extra_state' not in state:
-        state['extra_state'] = {
-            'epoch': state['epoch'],
-            'batch_offset': state['batch_offset'],
-            'val_loss': state['val_loss'],
-        }
-        del state['epoch']
-        del state['batch_offset']
-        del state['val_loss']
-    # reduce optimizer history's memory usage (only keep the last state)
-    if 'optimizer' in state['optimizer_history'][-1]:
-        state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
-        for optim_hist in state['optimizer_history']:
-            del optim_hist['optimizer']
-    # record the optimizer class name
-    if 'optimizer_name' not in state['optimizer_history'][-1]:
-        state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
-    # move best_loss into lr_scheduler_state
-    if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
-        state['optimizer_history'][-1]['lr_scheduler_state'] = {
-            'best': state['optimizer_history'][-1]['best_loss'],
-        }
-        del state['optimizer_history'][-1]['best_loss']
-    # keep track of number of updates
-    if 'num_updates' not in state['optimizer_history'][-1]:
-        state['optimizer_history'][-1]['num_updates'] = 0
-    # old model checkpoints may not have separate source/target positions
-    if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
-        state['args'].max_source_positions = state['args'].max_positions
-        state['args'].max_target_positions = state['args'].max_positions
-    # use stateful training data iterator
-    if 'train_iterator' not in state['extra_state']:
-        state['extra_state']['train_iterator'] = {
-            'epoch': state['extra_state']['epoch'],
-            'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
-        }
-    return state
-
-
-def load_checkpoint_to_cpu(path):
-    state = torch.load(path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
-    state = _upgrade_state_dict(state)
-    return state
-
-
 def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
-    """Load an ensemble of models for inference.
-
-    model_arg_overrides allows you to pass a dictionary model_arg_overrides --
-    {'arg_name': arg} -- to override model args that were used during model
-    training
-    """
-    # load model architectures and weights
-    states = []
-    for filename in filenames:
-        if not os.path.exists(filename):
-            raise IOError('Model file not found: {}'.format(filename))
-        state = load_checkpoint_to_cpu(filename)
-        states.append(state)
-
-    ensemble = []
-    for state in states:
-        args = state['args']
-
-        if model_arg_overrides is not None:
-            args = override_model_args(args, model_arg_overrides)
-
-        # build model for ensemble
-        model = task.build_model(args)
-        model.upgrade_state_dict(state['model'])
-        model.load_state_dict(state['model'], strict=True)
-        ensemble.append(model)
-
-        # some args (e.g., tokens_per_sample) might have been updated while building the model
-        if model_arg_overrides is not None:
-            args = override_model_args(args, model_arg_overrides)
-
-    return ensemble, args
-
-
-def override_model_args(args, model_arg_overrides):
-    # Uses model_arg_overrides {'arg_name': arg} to override model args
-    for arg_name, arg_val in model_arg_overrides.items():
-        setattr(args, arg_name, arg_val)
-    return args
+    from fairseq import checkpoint_utils
+    deprecation_warning(
+        'utils.load_ensemble_for_inference is deprecated. '
+        'Please use checkpoint_utils.load_model_ensemble instead.'
+    )
+    return checkpoint_utils.load_model_ensemble(
+        filenames, arg_overrides=model_arg_overrides, task=task,
+    )
 
 
 def move_to_cuda(sample):
@@ -379,25 +222,6 @@ def fill_with_neg_inf(t):
     return t.float().fill_(float('-inf')).type_as(t)
 
 
-def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
-    """Retrieves all checkpoints found in `path` directory.
-
-    Checkpoints are identified by matching filename to the specified pattern. If
-    the pattern contains groups, the result will be sorted by the first group in
-    descending order.
-    """
-    pt_regexp = re.compile(pattern)
-    files = os.listdir(path)
-
-    entries = []
-    for i, f in enumerate(files):
-        m = pt_regexp.fullmatch(f)
-        if m is not None:
-            idx = int(m.group(1)) if len(m.groups()) > 0 else i
-            entries.append((idx, m.group(0)))
-    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
-
-
 def resolve_max_positions(*args):
     """Resolve max position constraints from multiple sources."""
generate.py (12 lines changed)
@@ -11,9 +11,8 @@ Translate pre-processed data with a trained model.
 
 import torch
 
-from fairseq import bleu, options, progress_bar, tasks, utils
+from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils
 from fairseq.meters import StopwatchMeter, TimeMeter
-from fairseq.utils import import_user_module
 
 
 def main(args):
@@ -23,7 +22,7 @@ def main(args):
     assert args.replace_unk is None or args.raw_text, \
         '--replace-unk requires a raw text dataset (--raw-text)'
 
-    import_user_module(args)
+    utils.import_user_module(args)
 
     if args.max_tokens is None and args.max_sentences is None:
         args.max_tokens = 12000
@@ -34,7 +33,6 @@ def main(args):
     # Load dataset splits
     task = tasks.setup_task(args)
     task.load_dataset(args.gen_subset)
-    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
 
     # Set dictionaries
     try:
@@ -45,8 +43,10 @@ def main(args):
 
     # Load ensemble
     print('| loading model(s) from {}'.format(args.path))
-    models, _model_args = utils.load_ensemble_for_inference(
-        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
+    models, _model_args = checkpoint_utils.load_model_ensemble(
+        args.path.split(':'),
+        arg_overrides=eval(args.model_overrides),
+        task=task,
     )
 
     # Optimize ensemble for generation
@@ -15,9 +15,8 @@ import sys
 
 import torch
 
-from fairseq import options, tasks, utils
+from fairseq import checkpoint_utils, options, tasks, utils
 from fairseq.sequence_generator import SequenceGenerator
-from fairseq.utils import import_user_module
 
 Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
 Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
@@ -56,7 +55,7 @@ def make_batches(lines, args, task, max_positions):
 
 
 def main(args):
-    import_user_module(args)
+    utils.import_user_module(args)
 
     if args.buffer_size < 1:
         args.buffer_size = 1
@@ -77,8 +76,10 @@ def main(args):
 
     # Load ensemble
     print('| loading model(s) from {}'.format(args.path))
-    models, _model_args = utils.load_ensemble_for_inference(
-        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
+    models, _model_args = checkpoint_utils.load_model_ensemble(
+        args.path.split(':'),
+        arg_overrides=eval(args.model_overrides),
+        task=task,
     )
 
     # Set dictionaries
@@ -12,10 +12,9 @@ Data pre-processing: build vocabularies and binarize training data.
 from collections import Counter
 from itertools import zip_longest
 
-from fairseq import options, tasks
+from fairseq import options, tasks, utils
 from fairseq.data import indexed_dataset
 from fairseq.binarizer import Binarizer
-from fairseq.utils import import_user_module
 from multiprocessing import Pool
 
 import os
@@ -23,7 +22,7 @@ import shutil
 
 
 def main(args):
-    import_user_module(args)
+    utils.import_user_module(args)
 
     print(args)
 
@@ -239,7 +239,20 @@ class TestLanguageModeling(unittest.TestCase):
             with tempfile.TemporaryDirectory('test_fconv_lm') as data_dir:
                 create_dummy_data(data_dir)
                 preprocess_lm_data(data_dir)
-                train_language_model(data_dir, 'fconv_lm')
+                train_language_model(data_dir, 'fconv_lm', [
+                    '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
+                    '--decoder-embed-dim', '280',
+                    '--optimizer', 'nag',
+                    '--lr', '0.1',
+                ])
                 eval_lm_main(data_dir)
 
+    def test_transformer_lm(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_transformer_lm') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_lm_data(data_dir)
+                train_language_model(data_dir, 'transformer_lm', ['--add-bos-token'])
+                eval_lm_main(data_dir)
+
@@ -534,7 +547,7 @@ def preprocess_lm_data(data_dir):
     preprocess.main(preprocess_args)
 
 
-def train_language_model(data_dir, arch):
+def train_language_model(data_dir, arch, extra_flags=None):
     train_parser = options.get_training_parser()
     train_args = options.parse_args_and_arch(
         train_parser,
@@ -542,12 +555,10 @@ def train_language_model(data_dir, arch):
             '--task', 'language_modeling',
            data_dir,
             '--arch', arch,
-            '--optimizer', 'nag',
-            '--lr', '0.1',
+            '--optimizer', 'adam',
+            '--lr', '0.0001',
             '--criterion', 'adaptive_loss',
             '--adaptive-softmax-cutoff', '5,10,15',
-            '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
-            '--decoder-embed-dim', '280',
             '--max-tokens', '500',
             '--tokens-per-sample', '500',
             '--save-dir', data_dir,
@@ -555,7 +566,7 @@ def train_language_model(data_dir, arch):
             '--no-progress-bar',
             '--distributed-world-size', '1',
             '--ddp-backend', 'no_c10d',
-        ],
+        ] + (extra_flags or []),
    )
     train.main(train_args)
 
@@ -38,7 +38,7 @@ class TestReproducibility(unittest.TestCase):
             ] + extra_flags,
         )
         stdout = stdout.getvalue()
-        train_log, valid_log = map(json.loads, stdout.split('\n')[-4:-2])
+        train_log, valid_log = map(json.loads, stdout.split('\n')[-5:-3])
 
         # train epoch 2, resuming from previous checkpoint 1
         os.rename(
@@ -56,7 +56,7 @@ class TestReproducibility(unittest.TestCase):
             ] + extra_flags,
         )
         stdout = stdout.getvalue()
-        train_res_log, valid_res_log = map(json.loads, stdout.split('\n')[-4:-2])
+        train_res_log, valid_res_log = map(json.loads, stdout.split('\n')[-5:-3])
 
         def cast(s):
             return round(float(s), 3)
train.py (13 lines changed)
@@ -17,15 +17,14 @@ import random
 
 import torch
 
-from fairseq import distributed_utils, options, progress_bar, tasks, utils
+from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils
 from fairseq.data import iterators
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
-from fairseq.utils import import_user_module
 
 
 def main(args, init_distributed=False):
-    import_user_module(args)
+    utils.import_user_module(args)
 
     if args.max_tokens is None:
         args.max_tokens = 6000
@@ -326,14 +325,18 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
 
     if not end_of_epoch and args.keep_interval_updates > 0:
         # remove old checkpoints; checkpoints are sorted in descending order
-        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
+        checkpoints = checkpoint_utils.checkpoint_paths(
+            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
+        )
         for old_chk in checkpoints[args.keep_interval_updates:]:
             if os.path.lexists(old_chk):
                 os.remove(old_chk)
 
     if args.keep_last_epochs > 0:
         # remove old epoch checkpoints; checkpoints are sorted in descending order
-        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint(\d+)\.pt')
+        checkpoints = checkpoint_utils.checkpoint_paths(
+            args.save_dir, pattern=r'checkpoint(\d+)\.pt',
+        )
         for old_chk in checkpoints[args.keep_last_epochs:]:
             if os.path.lexists(old_chk):
                 os.remove(old_chk)