Add fairseq-validate

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/765

Differential Revision: D16763357

Pulled By: myleott

fbshipit-source-id: 758b03158e486ee82786e2d5bf4e46073b50c503
Authored by Myle Ott on 2019-08-13 13:03:40 -07:00; committed by Facebook Github Bot
parent a33ac060de
commit d015d23a1f
4 changed files with 141 additions and 1 deletion
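Usage sketch (not part of the diff): the new entry point can be driven programmatically the same way the updated tests below do. The data directory, checkpoint path, and task here are placeholders, and the import assumes the new module is importable as fairseq_cli.validate, matching the console-script entry added to setup.py in this commit.

# Illustrative sketch only; data_dir and checkpoint are placeholders.
import os

from fairseq import options
from fairseq_cli import validate  # assumed import path, per the setup.py entry point below

data_dir = 'data-bin/example'  # hypothetical binarized dataset directory
checkpoint = os.path.join(data_dir, 'checkpoint_last.pt')  # hypothetical checkpoint

parser = options.get_validation_parser()
args = options.parse_args_and_arch(parser, [
    data_dir,
    '--task', 'translation',
    '--path', checkpoint,
    '--valid-subset', 'valid',
    '--max-tokens', '500',
    '--no-progress-bar',
])
validate.main(args)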


@@ -48,6 +48,14 @@ def get_eval_lm_parser(default_task='language_modeling'):
    return parser


def get_validation_parser(default_task=None):
    parser = get_parser('Validation', default_task)
    add_dataset_args(parser, train=True)
    group = parser.add_argument_group('Evaluation')
    add_common_eval_args(group)
    return parser


def eval_str_list(x, type=float):
    if x is None:
        return None


@@ -60,8 +60,9 @@ setup(
            'fairseq-generate = fairseq_cli.generate:cli_main',
            'fairseq-interactive = fairseq_cli.interactive:cli_main',
            'fairseq-preprocess = fairseq_cli.preprocess:cli_main',
            'fairseq-train = fairseq_cli.train:cli_main',
            'fairseq-score = fairseq_cli.score:main',
            'fairseq-train = fairseq_cli.train:cli_main',
            'fairseq-validate = fairseq_cli.validate:cli_main',
        ],
    },
)
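The added console_scripts entry means an installed package exposes a fairseq-validate executable that simply calls cli_main() from the new module. Roughly, the wrapper setuptools generates behaves like the sketch below; it is shown only to illustrate the entry point and is not part of this commit.

# Rough sketch of the generated 'fairseq-validate' wrapper; not part of this commit.
import sys

from fairseq_cli.validate import cli_main

if __name__ == '__main__':
    sys.exit(cli_main())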


@@ -20,6 +20,7 @@ import train
import generate
import interactive
import eval_lm
import validate


class TestTranslation(unittest.TestCase):


@@ -476,6 +477,21 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation'
    )
    train.main(train_args)

    # test validation
    validate_parser = options.get_validation_parser()
    validate_args = options.parse_args_and_arch(
        validate_parser,
        [
            '--task', task,
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--valid-subset', 'valid',
            '--max-tokens', '500',
            '--no-progress-bar',
        ]
    )
    validate.main(validate_args)


def generate_main(data_dir, extra_flags=None):
    generate_parser = options.get_generation_parser()


@@ -541,6 +557,21 @@ def train_language_model(data_dir, arch, extra_flags=None):
    )
    train.main(train_args)

    # test validation
    validate_parser = options.get_validation_parser()
    validate_args = options.parse_args_and_arch(
        validate_parser,
        [
            '--task', 'language_modeling',
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--valid-subset', 'valid',
            '--max-tokens', '500',
            '--no-progress-bar',
        ]
    )
    validate.main(validate_args)


def eval_lm_main(data_dir):
    eval_lm_parser = options.get_eval_lm_parser()

validate.py (new file, 100 lines)

@@ -0,0 +1,100 @@
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch

from fairseq import checkpoint_utils, options, progress_bar, utils


def main(args, override_args=None):
    utils.import_user_module(args)

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if override_args is not None:
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    print(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    criterion.eval()

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=0)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception('Cannot find dataset: ' + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple'
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        log_output = task.aggregate_logging_outputs(log_outputs, criterion)
        progress.print(log_output, tag=subset, step=i)


def cli_main():
    parser = options.get_validation_parser()
    args = options.parse_args_and_arch(parser)

    # only override args that are explicitly given on the command line
    override_parser = options.get_validation_parser()
    override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True)

    main(args, override_args)


if __name__ == '__main__':
    cli_main()
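One note on cli_main above: parsing the command line a second time with suppress_defaults=True keeps only the flags that were explicitly given, so only those values override the arguments stored in the checkpoint when the model, task, and criterion are rebuilt. A minimal sketch of that override flow, with hypothetical values:

# Sketch of how explicitly-given flags flow into checkpoint loading as arg_overrides.
# The checkpoint path and the override value are hypothetical.
from fairseq import checkpoint_utils

overrides = {'max_tokens': 500}  # e.g. the user explicitly passed --max-tokens 500
models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
    ['checkpoints/checkpoint_last.pt'],  # hypothetical checkpoint
    arg_overrides=overrides,
)
# model_args.max_tokens now reflects 500 rather than the value saved at training time.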