Move fb_pathmgr registration out of train.py

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/903

Reviewed By: sujitoc

Differential Revision: D18327653

fbshipit-source-id: 739ddbaf54862acdf7b4f1bc3ad538bde5ae00fd
This commit is contained in:
Myle Ott 2019-11-08 04:36:43 -08:00 committed by Facebook GitHub Bot
parent e9171ce19b
commit e98bf7e64f
4 changed files with 9 additions and 27 deletions

View File

@ -173,8 +173,6 @@ def get_parser(desc, default_task='translation'):
parser.add_argument('--tensorboard-logdir', metavar='DIR', default='',
help='path to save logs for tensorboard, should match --logdir '
'of running tensorboard (default: no tensorboard logging)')
parser.add_argument("--tbmf-wrapper", action="store_true",
help="[FB only] ")
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')

View File

@ -16,8 +16,6 @@ import sys
from fairseq import distributed_utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
g_tbmf_wrapper = None
def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'):
if args.log_format is None:
@ -37,17 +35,14 @@ def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm',
else:
raise ValueError('Unknown log format: {}'.format(args.log_format))
if args.tbmf_wrapper and distributed_utils.is_master(args):
global g_tbmf_wrapper
if g_tbmf_wrapper is None:
try:
from fairseq.fb_tbmf_wrapper import fb_tbmf_wrapper
except Exception:
raise ImportError("fb_tbmf_wrapper package not found.")
g_tbmf_wrapper = fb_tbmf_wrapper
bar = g_tbmf_wrapper(bar, args, args.log_interval)
elif args.tensorboard_logdir and distributed_utils.is_master(args):
bar = tensorboard_log_wrapper(bar, args.tensorboard_logdir, args)
if args.tensorboard_logdir and distributed_utils.is_master(args):
try:
# [FB only] custom wrapper for TensorBoard
import palaas # noqa
from fairseq.fb_tbmf_wrapper import fb_tbmf_wrapper
bar = fb_tbmf_wrapper(bar, args, args.log_interval)
except ImportError:
bar = tensorboard_log_wrapper(bar, args.tensorboard_logdir, args)
return bar

View File

@ -173,7 +173,7 @@ class Trainer(object):
try:
from fairseq.fb_pathmgr import fb_pathmgr
bexists = fb_pathmgr.isfile(filename)
except Exception:
except (ModuleNotFoundError, ImportError):
bexists = os.path.exists(filename)
if bexists:

View File

@ -19,21 +19,10 @@ from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
fb_pathmgr_registerd = False
def main(args, init_distributed=False):
utils.import_user_module(args)
try:
from fairseq.fb_pathmgr import fb_pathmgr
global fb_pathmgr_registerd
if not fb_pathmgr_registerd:
fb_pathmgr.register()
fb_pathmgr_registerd = True
except (ModuleNotFoundError, ImportError):
pass
assert args.max_tokens is not None or args.max_sentences is not None, \
'Must specify batch size either with --max-tokens or --max-sentences'