Add Tensorboard support (#530)

Summary:
Enable with the `--tensorboard-logdir` option.
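For example, assuming `tensorboardX` is installed and using placeholder paths for the data directory and log directory (neither is part of this commit):

    pip install tensorboardX
    python train.py data-bin/my_dataset --arch fconv --tensorboard-logdir runs/exp1
    tensorboard --logdir runs/exp1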
Pull Request resolved: https://github.com/pytorch/fairseq/pull/530

Differential Revision: D14218430

Pulled By: myleott

fbshipit-source-id: e7a54f66f928e3bb02ae03fda09b22fa4fa7d053
Author: Myle Ott, 2019-02-25 18:34:10 -08:00 (committed by Facebook Github Bot)
parent 65c1903e4e
commit 44d27e645b
4 changed files with 115 additions and 55 deletions

fairseq/options.py

@@ -138,6 +138,9 @@ def get_parser(desc, default_task='translation'):
                         help='log progress every N batches (when progress bar is disabled)')
     parser.add_argument('--log-format', default=None, help='log format to use',
                         choices=['json', 'none', 'simple', 'tqdm'])
+    parser.add_argument('--tensorboard-logdir', metavar='DIR', default='',
+                        help='path to save logs for tensorboard, should match --logdir '
+                             'of running tensorboard (default: no tensorboard logging)')
     parser.add_argument('--seed', default=1, type=int, metavar='N',
                         help='pseudo random number generator seed')
     parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')

fairseq/progress_bar.py

@@ -12,11 +12,14 @@ Wrapper around various loggers and progress bars (e.g., tqdm).
 from collections import OrderedDict
 import json
 from numbers import Number
+import os
+import re
 import sys

 from tqdm import tqdm

-from fairseq.meters import AverageMeter
+from fairseq import distributed_utils
+from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter


 def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'):
@@ -36,9 +39,25 @@ def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm',
         bar = tqdm_progress_bar(iterator, epoch, prefix)
     else:
         raise ValueError('Unknown log format: {}'.format(args.log_format))
+
+    if args.tensorboard_logdir and distributed_utils.is_master(args):
+        bar = tensorboard_log_wrapper(bar, args.tensorboard_logdir)
+
     return bar


+def format_stat(stat):
+    if isinstance(stat, Number):
+        stat = '{:g}'.format(stat)
+    elif isinstance(stat, AverageMeter):
+        stat = '{:.3f}'.format(stat.avg)
+    elif isinstance(stat, TimeMeter):
+        stat = '{:g}'.format(round(stat.avg))
+    elif isinstance(stat, StopwatchMeter):
+        stat = '{:g}'.format(round(stat.sum))
+    return stat
+
+
 class progress_bar(object):
     """Abstract class for progress bars."""
     def __init__(self, iterable, epoch=None, prefix=None):
@@ -59,11 +78,11 @@ class progress_bar(object):
     def __iter__(self):
         raise NotImplementedError

-    def log(self, stats):
+    def log(self, stats, tag='', step=None):
         """Log intermediate stats according to log_interval."""
         raise NotImplementedError

-    def print(self, stats):
+    def print(self, stats, tag='', step=None):
         """Print end-of-epoch stats."""
         raise NotImplementedError
@@ -79,17 +98,7 @@ class progress_bar(object):
         postfix = OrderedDict(stats)
         # Preprocess stats according to datatype
         for key in postfix.keys():
-            # Number: limit the length of the string
-            if isinstance(postfix[key], Number):
-                postfix[key] = '{:g}'.format(postfix[key])
-            # Meter: display both current and average value
-            elif isinstance(postfix[key], AverageMeter):
-                postfix[key] = '{:.2f} ({:.2f})'.format(
-                    postfix[key].val, postfix[key].avg)
-            # Else for any other type, try to get the string conversion
-            elif not isinstance(postfix[key], str):
-                postfix[key] = str(postfix[key])
-            # Else if it's a string, don't need to preprocess anything
+            postfix[key] = str(format_stat(postfix[key]))
         return postfix
@@ -111,13 +120,15 @@ class json_progress_bar(progress_bar):
             stats = self._format_stats(self.stats, epoch=self.epoch, update=update)
             print(json.dumps(stats), flush=True)

-    def log(self, stats):
+    def log(self, stats, tag='', step=None):
         """Log intermediate stats according to log_interval."""
         self.stats = stats

-    def print(self, stats):
+    def print(self, stats, tag='', step=None):
         """Print end-of-epoch stats."""
         self.stats = stats
+        if tag != '':
+            self.stats = OrderedDict([(tag + '_' + k, v) for k, v in self.stats.items()])
         stats = self._format_stats(self.stats, epoch=self.epoch)
         print(json.dumps(stats), flush=True)
@@ -126,15 +137,10 @@ class json_progress_bar(progress_bar):
         if epoch is not None:
             postfix['epoch'] = epoch
         if update is not None:
-            postfix['update'] = update
+            postfix['update'] = round(update, 3)
         # Preprocess stats according to datatype
         for key in stats.keys():
-            # Meter: display both current and average value
-            if isinstance(stats[key], AverageMeter):
-                postfix[key] = stats[key].val
-                postfix[key + '_avg'] = stats[key].avg
-            else:
-                postfix[key] = stats[key]
+            postfix[key] = format_stat(stats[key])
         return postfix
@@ -148,11 +154,11 @@ class noop_progress_bar(progress_bar):
         for obj in self.iterable:
             yield obj

-    def log(self, stats):
+    def log(self, stats, tag='', step=None):
         """Log intermediate stats according to log_interval."""
         pass

-    def print(self, stats):
+    def print(self, stats, tag='', step=None):
         """Print end-of-epoch stats."""
         pass
@@ -175,11 +181,11 @@ class simple_progress_bar(progress_bar):
                 print('{}: {:5d} / {:d} {}'.format(self.prefix, i, size, postfix),
                       flush=True)

-    def log(self, stats):
+    def log(self, stats, tag='', step=None):
         """Log intermediate stats according to log_interval."""
         self.stats = self._format_stats(stats)

-    def print(self, stats):
+    def print(self, stats, tag='', step=None):
         """Print end-of-epoch stats."""
         postfix = self._str_pipes(self._format_stats(stats))
         print('{} | {}'.format(self.prefix, postfix), flush=True)
@@ -195,11 +201,62 @@ class tqdm_progress_bar(progress_bar):
     def __iter__(self):
         return iter(self.tqdm)

-    def log(self, stats):
+    def log(self, stats, tag='', step=None):
         """Log intermediate stats according to log_interval."""
         self.tqdm.set_postfix(self._format_stats(stats), refresh=False)

-    def print(self, stats):
+    def print(self, stats, tag='', step=None):
         """Print end-of-epoch stats."""
         postfix = self._str_pipes(self._format_stats(stats))
         self.tqdm.write('{} | {}'.format(self.tqdm.desc, postfix))
+
+
+class tensorboard_log_wrapper(progress_bar):
+    """Log to tensorboard."""
+
+    def __init__(self, wrapped_bar, tensorboard_logdir):
+        self.wrapped_bar = wrapped_bar
+        self.tensorboard_logdir = tensorboard_logdir
+
+        try:
+            from tensorboardX import SummaryWriter
+            self.SummaryWriter = SummaryWriter
+            self._writers = {}
+        except ImportError:
+            print("tensorboard or required dependencies not found, "
+                  "please see README for using tensorboard.")
+            self.SummaryWriter = None
+
+    def _writer(self, key):
+        if self.SummaryWriter is None:
+            return None
+        if key not in self._writers:
+            self._writers[key] = self.SummaryWriter(
+                log_dir=os.path.join(self.tensorboard_logdir, key),
+            )
+        return self._writers[key]
+
+    def __iter__(self):
+        return iter(self.wrapped_bar)
+
+    def log(self, stats, tag='', step=None):
+        """Log intermediate stats to tensorboard."""
+        self._log_to_tensorboard(stats, tag, step)
+        self.wrapped_bar.log(stats, tag=tag, step=step)
+
+    def print(self, stats, tag='', step=None):
+        """Print end-of-epoch stats."""
+        self._log_to_tensorboard(stats, tag, step)
+        self.wrapped_bar.print(stats, tag=tag, step=step)
+
+    def _log_to_tensorboard(self, stats, tag='', step=None):
+        writer = self._writer(tag)
+        if writer is None:
+            return
+        if step is None:
+            step = stats['num_updates']
+        for key in stats.keys() - {'num_updates'}:
+            if isinstance(stats[key], AverageMeter):
+                writer.add_scalar(key, stats[key].val, step)
+            elif isinstance(stats[key], Number):
+                writer.add_scalar(key, stats[key], step)
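The wrapper keeps one SummaryWriter per tag, writing each tag to its own subdirectory of --tensorboard-logdir (e.g. runs/exp1/train, runs/exp1/valid), so training and validation curves appear as separate runs in the tensorboard UI. A minimal sketch of the new calling convention; `args`, `itr`, and `stats` are assumed to exist, and the argument values are illustrative, not from this commit:

    from fairseq import progress_bar

    # the tensorboard wrapper is applied only when args.tensorboard_logdir is set
    bar = progress_bar.build_progress_bar(args, itr, epoch=1)
    for i, sample in enumerate(bar):
        ...  # forward/backward, update meters, build stats
        # stats may map keys to plain Numbers or Meter objects; the wrapper
        # reads AverageMeter.val for tensorboard, format_stat handles display
        bar.log(stats, tag='train', step=stats['num_updates'])
    bar.print(stats, tag='train', step=stats['num_updates'])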

tests/test_reproducibility.py

@@ -61,9 +61,9 @@ class TestReproducibility(unittest.TestCase):
         def cast(s):
             return round(float(s), 3)

-        for k in ['loss', 'ppl', 'num_updates', 'gnorm']:
+        for k in ['train_loss', 'train_ppl', 'train_num_updates', 'train_gnorm']:
             self.assertEqual(cast(train_log[k]), cast(train_res_log[k]))
-        for k in ['valid_loss', 'valid_ppl', 'num_updates', 'best']:
+        for k in ['valid_loss', 'valid_ppl', 'valid_num_updates', 'valid_best_loss']:
             self.assertEqual(cast(valid_log[k]), cast(valid_res_log[k]))

     def test_reproducibility(self):

train.py

@@ -150,7 +150,7 @@ def train(args, trainer, task, epoch_itr):
             else:
                 extra_meters[k].update(v)
             stats[k] = extra_meters[k].avg
-        progress.log(stats)
+        progress.log(stats, tag='train', step=stats['num_updates'])

         # ignore the first mini-batch in words-per-second calculation
         if i == 0:
@@ -168,7 +168,7 @@ def train(args, trainer, task, epoch_itr):
     stats = get_training_stats(trainer)
     for k, meter in extra_meters.items():
         stats[k] = meter.avg
-    progress.print(stats)
+    progress.print(stats, tag='train', step=stats['num_updates'])

     # reset training meters
     for k in [
@@ -181,26 +181,26 @@ def train(args, trainer, task, epoch_itr):

 def get_training_stats(trainer):
     stats = collections.OrderedDict()
-    stats['loss'] = '{:.3f}'.format(trainer.get_meter('train_loss').avg)
+    stats['loss'] = trainer.get_meter('train_loss')
     if trainer.get_meter('train_nll_loss').count > 0:
-        nll_loss = trainer.get_meter('train_nll_loss').avg
-        stats['nll_loss'] = '{:.3f}'.format(nll_loss)
+        nll_loss = trainer.get_meter('train_nll_loss')
+        stats['nll_loss'] = nll_loss
     else:
-        nll_loss = trainer.get_meter('train_loss').avg
-    stats['ppl'] = get_perplexity(nll_loss)
-    stats['wps'] = round(trainer.get_meter('wps').avg)
-    stats['ups'] = '{:.1f}'.format(trainer.get_meter('ups').avg)
-    stats['wpb'] = round(trainer.get_meter('wpb').avg)
-    stats['bsz'] = round(trainer.get_meter('bsz').avg)
+        nll_loss = trainer.get_meter('train_loss')
+    stats['ppl'] = get_perplexity(nll_loss.avg)
+    stats['wps'] = trainer.get_meter('wps')
+    stats['ups'] = trainer.get_meter('ups')
+    stats['wpb'] = trainer.get_meter('wpb')
+    stats['bsz'] = trainer.get_meter('bsz')
     stats['num_updates'] = trainer.get_num_updates()
     stats['lr'] = trainer.get_lr()
-    stats['gnorm'] = '{:.3f}'.format(trainer.get_meter('gnorm').avg)
-    stats['clip'] = '{:.0%}'.format(trainer.get_meter('clip').avg)
-    stats['oom'] = trainer.get_meter('oom').avg
+    stats['gnorm'] = trainer.get_meter('gnorm')
+    stats['clip'] = trainer.get_meter('clip')
+    stats['oom'] = trainer.get_meter('oom')
     if trainer.get_meter('loss_scale') is not None:
-        stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg)
+        stats['loss_scale'] = trainer.get_meter('loss_scale')
     stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
-    stats['train_wall'] = round(trainer.get_meter('train_wall').sum)
+    stats['train_wall'] = trainer.get_meter('train_wall')
     return stats
@@ -249,24 +249,24 @@ def validate(args, trainer, task, epoch_itr, subsets):
         stats = get_valid_stats(trainer)
         for k, meter in extra_meters.items():
             stats[k] = meter.avg
-        progress.print(stats)
+        progress.print(stats, tag=subset, step=trainer.get_num_updates())

-        valid_losses.append(stats['valid_loss'])
+        valid_losses.append(stats['loss'].avg)
     return valid_losses


 def get_valid_stats(trainer):
     stats = collections.OrderedDict()
-    stats['valid_loss'] = trainer.get_meter('valid_loss').avg
+    stats['loss'] = trainer.get_meter('valid_loss')
     if trainer.get_meter('valid_nll_loss').count > 0:
-        nll_loss = trainer.get_meter('valid_nll_loss').avg
-        stats['valid_nll_loss'] = nll_loss
+        nll_loss = trainer.get_meter('valid_nll_loss')
+        stats['nll_loss'] = nll_loss
     else:
-        nll_loss = trainer.get_meter('valid_loss').avg
-    stats['valid_ppl'] = get_perplexity(nll_loss)
+        nll_loss = stats['loss']
+    stats['ppl'] = get_perplexity(nll_loss.avg)
     stats['num_updates'] = trainer.get_num_updates()
     if hasattr(save_checkpoint, 'best'):
-        stats['best'] = min(save_checkpoint.best, stats['valid_loss'])
+        stats['best_loss'] = min(save_checkpoint.best, stats['loss'].avg)
     return stats
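With these changes, get_training_stats and get_valid_stats hand Meter objects to the progress bar instead of pre-formatted strings; format_stat renders them only at display time, so tensorboard_log_wrapper can still read the underlying numeric values. A rough illustration of format_stat's behavior, assuming AverageMeter's update(val) API from fairseq.meters:

    from fairseq.meters import AverageMeter
    from fairseq.progress_bar import format_stat

    meter = AverageMeter()
    meter.update(2.3456)         # meter.avg == 2.3456
    print(format_stat(meter))    # '2.346'   (AverageMeter renders via '{:.3f}'.format(avg))
    print(format_stat(3.14159))  # '3.14159' (plain Numbers render via '{:g}')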