#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
2019-08-13 23:03:40 +03:00
|
|
|
|
2020-01-17 03:12:45 +03:00
|
|
|
import logging
|
|
|
|
import sys
|
|
|
|
|
2019-08-13 23:03:40 +03:00
|
|
|
import torch
|
|
|
|
|
2020-02-27 19:19:48 +03:00
|
|
|
from fairseq import checkpoint_utils, options, utils
|
|
|
|
from fairseq.logging import metrics, progress_bar
|
2019-08-13 23:03:40 +03:00
|
|
|
|
2020-01-17 03:12:45 +03:00
|
|
|
# Route INFO-and-above records to stdout with a timestamped, pipe-separated
# layout. (Per the logging docs, basicConfig does nothing if the root logger
# already has handlers.)
logging.basicConfig(
    level=logging.INFO,
    stream=sys.stdout,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
)

# A fixed logger name is used here; note that __name__ would be '__main__'
# when this file is executed directly as a script.
logger = logging.getLogger('fairseq_cli.validate')
|
2020-01-17 03:12:45 +03:00
|
|
|
|
2019-08-13 23:03:40 +03:00
|
|
|
|
|
|
|
def main(args, override_args=None):
    """Evaluate a checkpoint on one or more validation subsets.

    Loads the model(s) from ``args.path`` and builds the criterion. Then, for
    each comma-separated name in ``args.valid_subset``, it runs
    ``task.valid_step`` on every batch, logs each batch's output, and prints
    the aggregated metrics for the subset.

    Args:
        args: parsed command-line namespace.
        override_args: optional namespace of options that were given
            explicitly on the command line. These options, plus the dict
            parsed from its ``model_overrides`` string, are passed as
            ``arg_overrides`` when the checkpoint is loaded. The namespace
            itself is not modified.

    Raises:
        AssertionError: if neither ``--max-tokens`` nor ``--max-sentences``
            is set.
        Exception: if a requested subset cannot be found (a ``KeyError``
            while loading it); the original error is chained as the cause.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if override_args is not None:
        # Copy: vars() returns the namespace's own __dict__, so updating it
        # directly would mutate the caller's override_args.
        overrides = dict(vars(override_args))
        # NOTE(review): eval() executes arbitrary Python taken from the
        # --model-overrides option. This is tolerable only because the string
        # comes from whoever runs this script; never feed it untrusted input.
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
    )
    model = models[0]

    # Move models to GPU. The loop uses its own variable so `model` keeps
    # referring to models[0], the model actually evaluated below.
    for ensemble_member in models:
        if use_fp16:
            ensemble_member.half()
        if use_cuda:
            ensemble_member.cuda()

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    criterion.eval()

    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError as err:
            raise Exception('Cannot find dataset: ' + subset) from err

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        log_outputs = []
        for step, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
            progress.log(log_output, step=step)
            log_outputs.append(log_output)

        if not log_outputs:
            # No batches were produced, so there is nothing to aggregate. The
            # old code then read an unbound (or stale) loop index below.
            logger.warning("no batches in '{}' subset; skipping".format(subset))
            continue

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        # Index of the last batch of *this* subset.
        progress.print(log_output, tag=subset, step=len(log_outputs) - 1)
|
|
|
|
|
|
|
|
|
|
|
|
def cli_main():
    """Parse the validation command line and run :func:`main`.

    The command line is parsed twice with fresh validation parsers. The first
    parse gives the full argument set, defaults included. The second uses
    ``suppress_defaults=True`` to keep only options given explicitly on the
    command line; ``main`` receives these as overrides.
    """
    full_args = options.parse_args_and_arch(options.get_validation_parser())

    # Only override args that are explicitly given on the command line.
    explicit_args = options.parse_args_and_arch(
        options.get_validation_parser(),
        suppress_defaults=True,
    )

    main(full_args, explicit_args)


if __name__ == '__main__':
    # Run validation only when executed as a script, not when imported.
    cli_main()
|