Reproduce #1781. Add Weights and Biases support

Summary:

Fixes https://github.com/pytorch/fairseq/issues/1790.

Reviewed By: alexeib

Differential Revision: D24579153

fbshipit-source-id: 74a30effa164db9d6376554376e36b1f47618899

Co-authored-by: Nikolay Korolev <korolevns98@gmail.com>
Co-authored-by: Vlad Lyalin <Guitaricet@gmail.com>
Authored by Myle Ott on 2020-11-03 20:46:45 -08:00; committed by Facebook GitHub Bot
parent dd52ed0f38
commit 1a709b2a40
5 changed files with 69 additions and 2 deletions

.gitignore

@@ -131,3 +131,6 @@ data-bin/
# Experimental Folder
experimental/*
# Weights and Biases logs
wandb/


@@ -4,6 +4,7 @@ common:
  log_interval: 100
  log_format: null
  tensorboard_logdir: null
  wandb_project: null
  seed: 1
  cpu: false
  tpu: false


@@ -102,6 +102,12 @@ class CommonConfig(FairseqDataclass):
            "of running tensorboard (default: no tensorboard logging)"
        },
    )
    wandb_project: Optional[str] = field(
        default=None,
        metadata={
            "help": "Weights and Biases project name to use for logging"
        },
    )
    seed: int = field(
        default=1, metadata={"help": "pseudo random number generator seed"}
    )

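For orientation, here is a minimal, self-contained sketch of how the new option behaves as an ordinary dataclass default. This is illustrative code only, not part of the commit; CommonConfigSketch is a made-up stand-in for fairseq's CommonConfig.

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class CommonConfigSketch:
    # mirrors the field added above: unset by default, so W&B logging is opt-in
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": "Weights and Biases project name to use for logging"},
    )

cfg = CommonConfigSketch()
assert cfg.wandb_project is None                       # disabled unless requested
cfg = CommonConfigSketch(wandb_project="my-project")   # opt in per run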

@@ -33,6 +33,7 @@ def progress_bar(
    prefix: Optional[str] = None,
    tensorboard_logdir: Optional[str] = None,
    default_log_format: str = "tqdm",
    wandb_project: Optional[str] = None,
):
    if log_format is None:
        log_format = default_log_format
@@ -60,6 +61,9 @@ def progress_bar(
        except ImportError:
            bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)
    if wandb_project:
        bar = WandBProgressBarWrapper(bar, wandb_project)
    return bar
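
A rough usage sketch of the factory above, assuming fairseq and tqdm are installed; the iterable, stats values, and project name are made up, and wandb only needs to be installed and authenticated for the extra wrapper to do anything beyond logging a warning.

from fairseq.logging import progress_bar

# Wraps the iterable in a tqdm-style bar and, because wandb_project is set,
# also in WandBProgressBarWrapper (defined further down in this diff).
progress = progress_bar.progress_bar(
    iter(range(10)),
    log_format="tqdm",
    wandb_project="my-fairseq-experiments",  # made-up project name
)
for i, batch in enumerate(progress):
    stats = {"loss": 1.0 / (i + 1), "num_updates": i}
    progress.log(stats, tag="train", step=i)  # mirrored to W&B as train/loss
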
@@ -353,3 +357,50 @@ class TensorboardProgressBarWrapper(BaseProgressBar):
            elif isinstance(stats[key], Number):
                writer.add_scalar(key, stats[key], step)
        writer.flush()


try:
    import wandb
except ImportError:
    wandb = None


class WandBProgressBarWrapper(BaseProgressBar):
    """Log to Weights & Biases."""

    def __init__(self, wrapped_bar, wandb_project):
        self.wrapped_bar = wrapped_bar
        if wandb is None:
            logger.warning('wandb not found, pip install wandb')
            return

        # reinit=False to ensure if wandb.init() is called multiple times
        # within one process it still references the same run
        wandb.init(project=wandb_project, reinit=False)

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to Weights & Biases."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def _log_to_wandb(self, stats, tag=None, step=None):
        if wandb is None:
            return
        if step is None:
            step = stats['num_updates']

        prefix = '' if tag is None else tag + '/'

        for key in stats.keys() - {'num_updates'}:
            if isinstance(stats[key], AverageMeter):
                wandb.log({prefix + key: stats[key].val}, step=step)
            elif isinstance(stats[key], Number):
                wandb.log({prefix + key: stats[key]}, step=step)

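To make the wrapper's behaviour concrete, _log_to_wandb above boils down to calls like the following. wandb.init and wandb.log are the real W&B API; the project name and stats dict are made up, and this assumes wandb is installed and authenticated.

import wandb

wandb.init(project="my-project", reinit=False)  # one shared run per process

stats = {"loss": 2.31, "lr": 0.0005, "num_updates": 100}
step = stats["num_updates"]                     # num_updates doubles as the x-axis
for key in stats.keys() - {"num_updates"}:
    # with tag="train" the wrapper would emit e.g. "train/loss" and "train/lr"
    wandb.log({"train/" + key: stats[key]}, step=step)
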

@@ -187,7 +187,10 @@ def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr)
        tensorboard_logdir=(
            cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
        ),
        default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'),
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(
            cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None
        ),
    )
    trainer.begin_epoch(epoch_itr.epoch)
@@ -307,7 +310,10 @@ def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_i
            tensorboard_logdir=(
                cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
            ),
            default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'),
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
            wandb_project=(
                cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None
            ),
        )
        # create a new root metrics aggregator so validation metrics
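
One design point in both train.py hunks: exactly like tensorboard_logdir, the project name is only passed through on the distributed master, so wandb.init() runs once per job rather than once per worker; every other rank receives None and progress_bar() skips the W&B wrapper via the "if wandb_project:" check. A minimal sketch of that gating, with is_master standing in for fairseq's distributed_utils.is_master:

def wandb_project_for_rank(project, is_master):
    # only rank 0 reports to W&B; other workers get None, which disables the wrapper
    return project if is_master else None

assert wandb_project_for_rank("my-project", is_master=True) == "my-project"
assert wandb_project_for_rank("my-project", is_master=False) is None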