Add Tensorboard support (#530)

Summary:
Enable with the `--tensorboard-logdir` option.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/530

Differential Revision: D14218430

Pulled By: myleott

fbshipit-source-id: e7a54f66f928e3bb02ae03fda09b22fa4fa7d053
This commit is contained in:
Myle Ott 2019-02-25 18:34:10 -08:00 committed by Facebook Github Bot
parent 65c1903e4e
commit 44d27e645b
4 changed files with 115 additions and 55 deletions

View File

@ -138,6 +138,9 @@ def get_parser(desc, default_task='translation'):
help='log progress every N batches (when progress bar is disabled)')
parser.add_argument('--log-format', default=None, help='log format to use',
choices=['json', 'none', 'simple', 'tqdm'])
parser.add_argument('--tensorboard-logdir', metavar='DIR', default='',
help='path to save logs for tensorboard, should match --logdir '
'of running tensorboard (default: no tensorboard logging)')
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')

View File

@ -12,11 +12,14 @@ Wrapper around various loggers and progress bars (e.g., tqdm).
from collections import OrderedDict
import json
from numbers import Number
import os
import re
import sys
from tqdm import tqdm
from fairseq.meters import AverageMeter
from fairseq import distributed_utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'):
@ -36,9 +39,25 @@ def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm',
bar = tqdm_progress_bar(iterator, epoch, prefix)
else:
raise ValueError('Unknown log format: {}'.format(args.log_format))
if args.tensorboard_logdir and distributed_utils.is_master(args):
bar = tensorboard_log_wrapper(bar, args.tensorboard_logdir)
return bar
def format_stat(stat):
    """Convert a stat value (plain number or fairseq meter) to a display string.

    Numbers use general float formatting; AverageMeter shows its running
    average to 3 decimals; TimeMeter/StopwatchMeter show their rounded
    average/total time. Any other type is passed through unchanged.
    """
    if isinstance(stat, Number):
        return '{:g}'.format(stat)
    if isinstance(stat, AverageMeter):
        return '{:.3f}'.format(stat.avg)
    if isinstance(stat, TimeMeter):
        return '{:g}'.format(round(stat.avg))
    if isinstance(stat, StopwatchMeter):
        return '{:g}'.format(round(stat.sum))
    # Unknown type: leave it for the caller to stringify.
    return stat
class progress_bar(object):
"""Abstract class for progress bars."""
def __init__(self, iterable, epoch=None, prefix=None):
@ -59,11 +78,11 @@ class progress_bar(object):
def __iter__(self):
raise NotImplementedError
def log(self, stats):
def log(self, stats, tag='', step=None):
"""Log intermediate stats according to log_interval."""
raise NotImplementedError
def print(self, stats):
def print(self, stats, tag='', step=None):
"""Print end-of-epoch stats."""
raise NotImplementedError
@ -79,17 +98,7 @@ class progress_bar(object):
postfix = OrderedDict(stats)
# Preprocess stats according to datatype
for key in postfix.keys():
# Number: limit the length of the string
if isinstance(postfix[key], Number):
postfix[key] = '{:g}'.format(postfix[key])
# Meter: display both current and average value
elif isinstance(postfix[key], AverageMeter):
postfix[key] = '{:.2f} ({:.2f})'.format(
postfix[key].val, postfix[key].avg)
# Else for any other type, try to get the string conversion
elif not isinstance(postfix[key], str):
postfix[key] = str(postfix[key])
# Else if it's a string, don't need to preprocess anything
postfix[key] = str(format_stat(postfix[key]))
return postfix
@ -111,13 +120,15 @@ class json_progress_bar(progress_bar):
stats = self._format_stats(self.stats, epoch=self.epoch, update=update)
print(json.dumps(stats), flush=True)
def log(self, stats):
def log(self, stats, tag='', step=None):
"""Log intermediate stats according to log_interval."""
self.stats = stats
def print(self, stats):
def print(self, stats, tag='', step=None):
"""Print end-of-epoch stats."""
self.stats = stats
if tag != '':
self.stats = OrderedDict([(tag + '_' + k, v) for k, v in self.stats.items()])
stats = self._format_stats(self.stats, epoch=self.epoch)
print(json.dumps(stats), flush=True)
@ -126,15 +137,10 @@ class json_progress_bar(progress_bar):
if epoch is not None:
postfix['epoch'] = epoch
if update is not None:
postfix['update'] = update
postfix['update'] = round(update, 3)
# Preprocess stats according to datatype
for key in stats.keys():
# Meter: display both current and average value
if isinstance(stats[key], AverageMeter):
postfix[key] = stats[key].val
postfix[key + '_avg'] = stats[key].avg
else:
postfix[key] = stats[key]
postfix[key] = format_stat(stats[key])
return postfix
@ -148,11 +154,11 @@ class noop_progress_bar(progress_bar):
for obj in self.iterable:
yield obj
def log(self, stats):
def log(self, stats, tag='', step=None):
"""Log intermediate stats according to log_interval."""
pass
def print(self, stats):
def print(self, stats, tag='', step=None):
"""Print end-of-epoch stats."""
pass
@ -175,11 +181,11 @@ class simple_progress_bar(progress_bar):
print('{}: {:5d} / {:d} {}'.format(self.prefix, i, size, postfix),
flush=True)
def log(self, stats):
def log(self, stats, tag='', step=None):
"""Log intermediate stats according to log_interval."""
self.stats = self._format_stats(stats)
def print(self, stats):
def print(self, stats, tag='', step=None):
"""Print end-of-epoch stats."""
postfix = self._str_pipes(self._format_stats(stats))
print('{} | {}'.format(self.prefix, postfix), flush=True)
@ -195,11 +201,62 @@ class tqdm_progress_bar(progress_bar):
def __iter__(self):
return iter(self.tqdm)
def log(self, stats):
def log(self, stats, tag='', step=None):
"""Log intermediate stats according to log_interval."""
self.tqdm.set_postfix(self._format_stats(stats), refresh=False)
def print(self, stats):
def print(self, stats, tag='', step=None):
"""Print end-of-epoch stats."""
postfix = self._str_pipes(self._format_stats(stats))
self.tqdm.write('{} | {}'.format(self.tqdm.desc, postfix))
class tensorboard_log_wrapper(progress_bar):
    """Progress-bar wrapper that mirrors logged stats to tensorboard.

    Delegates iteration and console display to the wrapped bar while also
    writing numeric stats to tensorboardX event files under
    ``tensorboard_logdir`` (one sub-directory / SummaryWriter per tag).
    """

    def __init__(self, wrapped_bar, tensorboard_logdir):
        # Bar being wrapped; all console output is delegated to it.
        self.wrapped_bar = wrapped_bar
        # Root directory under which per-tag log dirs are created.
        self.tensorboard_logdir = tensorboard_logdir

        try:
            from tensorboardX import SummaryWriter
            self.SummaryWriter = SummaryWriter
            # Lazily-populated map of tag -> SummaryWriter (see _writer).
            self._writers = {}
        except ImportError:
            # tensorboardX is optional: warn and degrade to display-only
            # logging (SummaryWriter=None disables _log_to_tensorboard).
            print("tensorboard or required dependencies not found, "
                  "please see README for using tensorboard.")
            self.SummaryWriter = None

    def _writer(self, key):
        """Return the SummaryWriter for *key*, creating it on first use.

        Returns None when tensorboardX is unavailable.
        """
        if self.SummaryWriter is None:
            return None
        if key not in self._writers:
            self._writers[key] = self.SummaryWriter(
                log_dir=os.path.join(self.tensorboard_logdir, key),
            )
        return self._writers[key]

    def __iter__(self):
        # Iteration is handled entirely by the wrapped bar.
        return iter(self.wrapped_bar)

    def log(self, stats, tag='', step=None):
        """Log intermediate stats to tensorboard."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag='', step=None):
        """Print end-of-epoch stats."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def _log_to_tensorboard(self, stats, tag='', step=None):
        # Write every numeric stat as a scalar; no-op when tensorboardX is
        # missing (writer is None).
        writer = self._writer(tag)
        if writer is None:
            return
        if step is None:
            # Fall back to the global update count carried in the stats dict.
            step = stats['num_updates']
        # 'num_updates' itself is the x-axis, so exclude it from the scalars.
        for key in stats.keys() - {'num_updates'}:
            if isinstance(stats[key], AverageMeter):
                writer.add_scalar(key, stats[key].val, step)
            elif isinstance(stats[key], Number):
                writer.add_scalar(key, stats[key], step)

View File

@ -61,9 +61,9 @@ class TestReproducibility(unittest.TestCase):
def cast(s):
return round(float(s), 3)
for k in ['loss', 'ppl', 'num_updates', 'gnorm']:
for k in ['train_loss', 'train_ppl', 'train_num_updates', 'train_gnorm']:
self.assertEqual(cast(train_log[k]), cast(train_res_log[k]))
for k in ['valid_loss', 'valid_ppl', 'num_updates', 'best']:
for k in ['valid_loss', 'valid_ppl', 'valid_num_updates', 'valid_best_loss']:
self.assertEqual(cast(valid_log[k]), cast(valid_res_log[k]))
def test_reproducibility(self):

View File

@ -150,7 +150,7 @@ def train(args, trainer, task, epoch_itr):
else:
extra_meters[k].update(v)
stats[k] = extra_meters[k].avg
progress.log(stats)
progress.log(stats, tag='train', step=stats['num_updates'])
# ignore the first mini-batch in words-per-second calculation
if i == 0:
@ -168,7 +168,7 @@ def train(args, trainer, task, epoch_itr):
stats = get_training_stats(trainer)
for k, meter in extra_meters.items():
stats[k] = meter.avg
progress.print(stats)
progress.print(stats, tag='train', step=stats['num_updates'])
# reset training meters
for k in [
@ -181,26 +181,26 @@ def train(args, trainer, task, epoch_itr):
def get_training_stats(trainer):
stats = collections.OrderedDict()
stats['loss'] = '{:.3f}'.format(trainer.get_meter('train_loss').avg)
stats['loss'] = trainer.get_meter('train_loss')
if trainer.get_meter('train_nll_loss').count > 0:
nll_loss = trainer.get_meter('train_nll_loss').avg
stats['nll_loss'] = '{:.3f}'.format(nll_loss)
nll_loss = trainer.get_meter('train_nll_loss')
stats['nll_loss'] = nll_loss
else:
nll_loss = trainer.get_meter('train_loss').avg
stats['ppl'] = get_perplexity(nll_loss)
stats['wps'] = round(trainer.get_meter('wps').avg)
stats['ups'] = '{:.1f}'.format(trainer.get_meter('ups').avg)
stats['wpb'] = round(trainer.get_meter('wpb').avg)
stats['bsz'] = round(trainer.get_meter('bsz').avg)
nll_loss = trainer.get_meter('train_loss')
stats['ppl'] = get_perplexity(nll_loss.avg)
stats['wps'] = trainer.get_meter('wps')
stats['ups'] = trainer.get_meter('ups')
stats['wpb'] = trainer.get_meter('wpb')
stats['bsz'] = trainer.get_meter('bsz')
stats['num_updates'] = trainer.get_num_updates()
stats['lr'] = trainer.get_lr()
stats['gnorm'] = '{:.3f}'.format(trainer.get_meter('gnorm').avg)
stats['clip'] = '{:.0%}'.format(trainer.get_meter('clip').avg)
stats['oom'] = trainer.get_meter('oom').avg
stats['gnorm'] = trainer.get_meter('gnorm')
stats['clip'] = trainer.get_meter('clip')
stats['oom'] = trainer.get_meter('oom')
if trainer.get_meter('loss_scale') is not None:
stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg)
stats['loss_scale'] = trainer.get_meter('loss_scale')
stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
stats['train_wall'] = round(trainer.get_meter('train_wall').sum)
stats['train_wall'] = trainer.get_meter('train_wall')
return stats
@ -249,24 +249,24 @@ def validate(args, trainer, task, epoch_itr, subsets):
stats = get_valid_stats(trainer)
for k, meter in extra_meters.items():
stats[k] = meter.avg
progress.print(stats)
progress.print(stats, tag=subset, step=trainer.get_num_updates())
valid_losses.append(stats['valid_loss'])
valid_losses.append(stats['loss'].avg)
return valid_losses
def get_valid_stats(trainer):
stats = collections.OrderedDict()
stats['valid_loss'] = trainer.get_meter('valid_loss').avg
stats['loss'] = trainer.get_meter('valid_loss')
if trainer.get_meter('valid_nll_loss').count > 0:
nll_loss = trainer.get_meter('valid_nll_loss').avg
stats['valid_nll_loss'] = nll_loss
nll_loss = trainer.get_meter('valid_nll_loss')
stats['nll_loss'] = nll_loss
else:
nll_loss = trainer.get_meter('valid_loss').avg
stats['valid_ppl'] = get_perplexity(nll_loss)
nll_loss = stats['loss']
stats['ppl'] = get_perplexity(nll_loss.avg)
stats['num_updates'] = trainer.get_num_updates()
if hasattr(save_checkpoint, 'best'):
stats['best'] = min(save_checkpoint.best, stats['valid_loss'])
stats['best_loss'] = min(save_checkpoint.best, stats['loss'].avg)
return stats