fairseq/fairseq_cli/validate.py
dianaml0 0dfd6b6240 Add linting with black (#2678)
Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes # (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2678

Reviewed By: Mortimerp9

Differential Revision: D32653381

Pulled By: dianaml0

fbshipit-source-id: 2810d14867cd7d64f4d340740e2b590b82de47fe
2021-11-29 12:32:59 -08:00

154 lines
5.1 KiB
Python

#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import sys
from argparse import Namespace
from itertools import chain
import torch
from omegaconf import DictConfig
from fairseq import checkpoint_utils, distributed_utils, options, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import metrics, progress_bar
from fairseq.utils import reset_logging
# Configure the root logger at import time so that messages go to stdout.
# The verbosity is taken from the LOGLEVEL environment variable and falls
# back to INFO when the variable is not set.
logging.basicConfig(
    stream=sys.stdout,
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)

# Module-level logger used by the validation entry points below.
logger = logging.getLogger("fairseq_cli.validate")
def main(cfg: DictConfig, override_args=None):
    """Evaluate a saved checkpoint on every subset in ``cfg.dataset.valid_subset``.

    Args:
        cfg: Validation configuration. An ``argparse.Namespace`` is accepted
            too and is converted with ``convert_namespace_to_omegaconf``.
        override_args: Optional namespace whose attributes are passed as
            ``arg_overrides`` when loading the checkpoint. If it carries a
            ``model_overrides`` attribute, that string is evaluated and the
            resulting mapping is merged into the overrides.

    Raises:
        AssertionError: If neither ``cfg.dataset.max_tokens`` nor
            ``cfg.dataset.batch_size`` is set.
        Exception: If loading or looking up a subset raises ``KeyError``.
    """
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    reset_logging()

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if cfg.distributed_training.distributed_world_size > 1:
        data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
        data_parallel_rank = distributed_utils.get_data_parallel_rank()
    else:
        data_parallel_world_size = 1
        data_parallel_rank = 0

    if override_args is not None:
        # Copy the namespace dict: vars() returns the namespace's own
        # __dict__, so updating it in place would mutate the caller's object.
        overrides = dict(vars(override_args))
        # SECURITY NOTE: ``model_overrides`` is a user-supplied string that is
        # passed to eval(); it must only ever come from a trusted command line.
        overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
    else:
        overrides = None

    # Load ensemble
    logger.info("loading model(s) from %s", cfg.common_eval.path)
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
    )
    # Validation runs on the first loaded model. The preparation loop below
    # uses its own variable so it cannot silently rebind ``model``.
    model = models[0]

    # Put every loaded model in eval mode and move it to the target
    # precision/device.
    for loaded_model in models:
        loaded_model.eval()
        if use_fp16:
            loaded_model.half()
        if use_cuda:
            loaded_model.cuda()

    # Print args
    logger.info(saved_cfg)

    # Build criterion
    criterion = task.build_criterion(saved_cfg.criterion)
    criterion.eval()

    for subset in cfg.dataset.valid_subset.split(","):
        try:
            task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task)
            dataset = task.dataset(subset)
        except KeyError as err:
            # Keep the original KeyError attached as the explicit cause.
            raise Exception("Cannot find dataset: " + subset) from err

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=cfg.dataset.max_tokens,
            max_sentences=cfg.dataset.batch_size,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
            seed=cfg.common.seed,
            num_shards=data_parallel_world_size,
            shard_id=data_parallel_rank,
            num_workers=cfg.dataset.num_workers,
            data_buffer_size=cfg.dataset.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        )

        log_outputs = []
        # Pre-bind the step index: it is read after the loop, and an empty
        # iterator would otherwise leave it undefined (or stale from the
        # previous subset).
        step = 0
        for step, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
            progress.log(log_output, step=step)
            log_outputs.append(log_output)

        if data_parallel_world_size > 1:
            # Collect the per-worker lists and flatten them into one list.
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=cfg.common.all_gather_list_size,
                group=distributed_utils.get_data_parallel_group(),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=step)
def cli_main():
    """Parse the validation command line and dispatch to ``main``.

    The arguments are parsed twice: once normally to build the full
    configuration, and once with defaults suppressed so that only options
    explicitly given on the command line are passed on as overrides.
    """
    full_parser = options.get_validation_parser()
    parsed_args = options.parse_args_and_arch(full_parser)

    # only override args that are explicitly given on the command line
    explicit_parser = options.get_validation_parser()
    explicit_args = options.parse_args_and_arch(
        explicit_parser,
        suppress_defaults=True,
    )

    cfg = convert_namespace_to_omegaconf(parsed_args)
    distributed_utils.call_main(cfg, main, override_args=explicit_args)
# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    cli_main()