From 8feccf94412424a4683b01090de36fa77cb4951d Mon Sep 17 00:00:00 2001 From: Vimal Manohar Date: Wed, 1 Sep 2021 11:43:25 -0700 Subject: [PATCH] EMA Summary: Adds Exponential moving average (EMA) model for Kaizen semi-supervised training https://arxiv.org/abs/2106.07759 1. Add `ema.store_ema` to enable storing EMA. EMA will be written to extra_state in the state dict while saving checkpoint. 2. `ema.ema_start_update` to control when the EMA starts accumulating 3. Tasks can use `uses_ema` property to decide if the EMA should be passed to the task. (Default is False) 4. `load_ema_from_checkpoint` can be used to load EMA model in place of the model to be used for evalutation. Pyspeech has eval-ema option for this. ``` This module has the EMA class used to store a copy of the exponentially decayed model params. Typical usage of EMA class involves initializing an object using an existing model (random or from a seed model) and setting the config like ema_decay, ema_start_update which determine how the EMA model is updated. After every update of the model i.e. at the end of the train_step, the EMA should be updated by passing the new model to the EMA.step function. The EMA model state dict can be stored in the extra state under the key of "ema" and dumped into a checkpoint and loaded. The EMA object can be passed to tasks by setting task.uses_ema property. EMA is a smoothed/ensemble model which might have better performance when used for inference or further fine-tuning. EMA class has a reverse function to load the EMA params into a model and use it like a regular model. ``` Reviewed By: cruvadom Differential Revision: D24238379 fbshipit-source-id: 879d3ba5070a614b7d365f9503af357001e875b2 --- fairseq/checkpoint_utils.py | 46 ++++++++ fairseq/dataclass/configs.py | 31 +++++ fairseq/models/ema/__init__.py | 20 ++++ fairseq/models/ema/ema.py | 189 +++++++++++++++++++++++++++++++ fairseq/options.py | 7 ++ fairseq/trainer.py | 78 ++++++++++++- tests/gpu/test_ema_gpu.py | 200 +++++++++++++++++++++++++++++++++ tests/test_ema.py | 199 ++++++++++++++++++++++++++++++++ 8 files changed, 769 insertions(+), 1 deletion(-) create mode 100644 fairseq/models/ema/__init__.py create mode 100644 fairseq/models/ema/ema.py create mode 100644 tests/gpu/test_ema_gpu.py create mode 100644 tests/test_ema.py diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index b8c46f825..ef5d4c902 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -810,3 +810,49 @@ def verify_checkpoint_directory(save_dir: str) -> None: raise e else: os.remove(temp_file_path) + + +def load_ema_from_checkpoint(fpath): + """Loads exponential moving averaged (EMA) checkpoint from input and + returns a model with ema weights. + + Args: + fpath: A string path of checkpoint to load from. + + Returns: + A dict of string keys mapping to various values. The 'model' key + from the returned dict should correspond to an OrderedDict mapping + string parameter names to torch Tensors. + """ + params_dict = collections.OrderedDict() + new_state = None + + with PathManager.open(fpath, 'rb') as f: + new_state = torch.load( + f, + map_location=( + lambda s, _: torch.serialization.default_restore_location(s, 'cpu') + ), + ) + + # EMA model is stored in a separate "extra state" + model_params = new_state['extra_state']['ema'] + + for key in list(model_params.keys()): + p = model_params[key] + if isinstance(p, torch.HalfTensor): + p = p.float() + if key not in params_dict: + params_dict[key] = p.clone() + # NOTE: clone() is needed in case of p is a shared parameter + else: + raise ValueError("Key {} is repeated in EMA model params.".format(key)) + + if len(params_dict) == 0: + raise ValueError( + f"Input checkpoint path '{fpath}' does not contain " + "ema model weights, is this model trained with EMA?" + ) + + new_state['model'] = params_dict + return new_state diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 6a86ea019..952f1ec4d 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -984,6 +984,36 @@ class InteractiveConfig(FairseqDataclass): ) +@dataclass +class EMAConfig(FairseqDataclass): + store_ema: bool = field( + default=False, metadata={ + help: "store exponential moving average shadow model" + } + ) + ema_decay: float = field( + default=0.9999, metadata={ + "help": 'decay for exponential moving average model' + } + ) + ema_start_update : int = field( + default=0, metadata={"help": "start EMA update after this many model updates"} + ) + ema_seed_model : Optional[str] = field( + default=None, metadata={ + "help": "Seed to load EMA model from. " + "Used to load EMA model separately from the actual model." + } + ) + ema_update_freq : int = field( + default=1, metadata={"help": "Do EMA update every this many model updates"} + ) + ema_fp32: bool = field( + default=False, + metadata={"help": "If true, store EMA model in fp32 even if model is in fp16"}, + ) + + @dataclass class FairseqConfig(FairseqDataclass): common: CommonConfig = CommonConfig() @@ -1004,3 +1034,4 @@ class FairseqConfig(FairseqDataclass): scoring: Any = None bpe: Any = None tokenizer: Any = None + ema: EMAConfig = EMAConfig() diff --git a/fairseq/models/ema/__init__.py b/fairseq/models/ema/__init__.py new file mode 100644 index 000000000..503ceaa60 --- /dev/null +++ b/fairseq/models/ema/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + +from .ema import EMA + + +def build_ema(model, cfg, device): + return EMA(model, cfg, device) + + +# automatically import any Python files in the models/ema/ directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.models.ema." + file_name) diff --git a/fairseq/models/ema/ema.py b/fairseq/models/ema/ema.py new file mode 100644 index 000000000..6c0af6932 --- /dev/null +++ b/fairseq/models/ema/ema.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 + +""" +This module has the EMA class used to store a copy of the exponentially decayed +model params. + +Typical usage of EMA class involves initializing an object using an existing +model (random or from a seed model) and setting the config like ema_decay, +ema_start_update which determine how the EMA model is updated. After every +update of the model i.e. at the end of the train_step, the EMA should be updated +by passing the new model to the EMA.step function. The EMA model state dict +can be stored in the extra state under the key of "ema" and dumped +into a checkpoint and loaded. The EMA object can be passed to tasks +by setting task.uses_ema property. +EMA is a smoothed/ensemble model which might have better performance +when used for inference or further fine-tuning. EMA class has a +reverse function to load the EMA params into a model and use it +like a regular model. +""" + +import copy +import logging + +import torch +from fairseq import checkpoint_utils + + +class EMA(object): + """Exponential Moving Average of Fairseq Models + EMA keeps a copy of the exponentially decayed model params. + The set of params should include both gradient-descent and + non-gradient descent params, such as batch mean/var and buffers. + This is a modified implementation of + the open source code in https://github.com/zhawe01/fairseq-gec.git, + and internal source code in + fbcode/mobile-vision/projects/classification_pytorch/lib/utils/model_ema.py. + + Similar to TF EMA. + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage. + EMA provides a averaged and smoothed set of model weights, and has been shown to + improve vision models. EMA class does all necessary functions to update, reload, + or init EMA methods. + + EMA object is initialized from an arbitrary model. By default, it is stored in + the same device (unless device specified at initialization) and with the + same precision as the model (unless ema_fp32 is True). ema_fp32 is recommended. + This stores the EMA parameters in fp32 only for the EMA update step, and + is used at the default precision otherwise. + EMA is usually enabled using EMAConfig with store_ema=True. Some important + parameters to configure EMA are + 1) ema_decay - The decay of EMA + 2) ema_update_freq - EMA is updated every this many model updates. + 3) ema_start_update - Start EMA update after this many model updates [default 0] + + Key methods: + 1) step - One update of EMA using new model + 2) restore - Update EMA from a state dict + 3) reverse - Load EMA into a model + 4) get_decay, _set_decay - Used to get or set the decay. Note _set_decay is + called from step. + 5) build_fp32_params - Used to initialize or update the fp32 copy of EMA params. + Note this is enabled only when ema_fp32=True + """ + + def __init__(self, model, config, device=None): + """ + @param model model to initialize the EMA with + @param config EMAConfig object with configuration like + ema_decay, ema_update_freq, ema_fp32 + @param device If provided, copy EMA to this device (e.g. gpu). + Otherwise EMA is in the same device as the model. + """ + + self.decay = config.ema_decay + self.model = copy.deepcopy(model) + self.model.requires_grad_(False) + self.config = config + self.fp32_params = {} + + if self.config.ema_seed_model is not None: + state = checkpoint_utils.load_ema_from_checkpoint(self.config.ema_seed_model) + self.model.load_state_dict(state["model"], strict=True) + + if device is not None: + logging.info(f"Copying EMA model to device {device}") + self.model = self.model.to(device=device) + + if self.config.ema_fp32: + self.build_fp32_params() + + self.update_freq_counter = 0 + + def get_model(self): + return self.model + + def build_fp32_params(self, state_dict=None): + """ + Store a copy of the EMA params in fp32. + If state dict is passed, the EMA params is copied from + the provided state dict. Otherwise, it is copied from the + current EMA model parameters. + """ + if not self.config.ema_fp32: + raise RuntimeError( + "build_fp32_params should not be called if ema_fp32=False. " + "Use ema_fp32=True if this is really intended." + ) + + if state_dict is None: + state_dict = self.model.state_dict() + + def _to_float(t): + return t.float() if torch.is_floating_point(t) else t + + for param_key in state_dict: + if param_key in self.fp32_params: + self.fp32_params[param_key].copy_(state_dict[param_key]) + else: + self.fp32_params[param_key] = _to_float(state_dict[param_key]) + + def restore(self, state_dict, build_fp32_params=False): + """ Load data from a model spec into EMA model """ + self.model.load_state_dict(state_dict, strict=False) + if build_fp32_params: + self.build_fp32_params(state_dict) + + def _set_decay(self, decay): + self.decay = decay + + def get_decay(self): + return self.decay + + def _step_internal(self, new_model, updates=None): + """ One update of the EMA model based on new model weights """ + decay = self.decay + + ema_state_dict = {} + ema_params = self.fp32_params if self.config.ema_fp32 else self.model.state_dict() + for key, param in new_model.state_dict().items(): + try: + ema_param = ema_params[key] + except KeyError: + ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) + + if param.shape != ema_param.shape: + raise ValueError( + "incompatible tensor shapes between model param and ema param" + + "{} vs. {}".format(param.shape, ema_param.shape) + ) + if "version" in key: + # Do not decay a model.version pytorch param + continue + ema_param.mul_(decay) + ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1-decay) + ema_state_dict[key] = ema_param + self.restore(ema_state_dict, build_fp32_params=False) + + def step(self, new_model, updates=None): + """ + One update of EMA which is done every self.config.ema_update_freq + updates of the model. + + @param updates The current number of model updates done. + Decay is set of 0 if model updates < ema_start_update, which means + the model will be simply copied over to the EMA. + When model updates >= ema_start_updates, then EMA is updated with + a decay of self.config.ema_decay. + """ + self._set_decay( + 0 + if updates is not None + and updates < self.config.ema_start_update + else self.config.ema_decay + ) + if updates is not None and self.config.ema_update_freq > 1: + self.update_freq_counter += 1 + if self.update_freq_counter >= self.config.ema_update_freq: + self._step_internal(new_model, updates) + self.update_freq_counter = 0 + else: + self._step_internal(new_model, updates) + + def reverse(self, model): + """ + Load the model parameters from EMA model. + Useful for inference or fine-tuning from the EMA model. + """ + model.load_state_dict(self.model.state_dict(), strict=False) + return model diff --git a/fairseq/options.py b/fairseq/options.py index 2d9f8381a..03883fc56 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -20,6 +20,7 @@ from fairseq.dataclass.configs import ( GenerationConfig, InteractiveConfig, OptimizationConfig, + EMAConfig, ) from fairseq.dataclass.utils import gen_parser_from_dataclass @@ -40,6 +41,7 @@ def get_training_parser(default_task="translation"): add_model_args(parser) add_optimization_args(parser) add_checkpoint_args(parser) + add_ema_args(parser) return parser @@ -379,3 +381,8 @@ def get_args( setattr(args, k, v) return args + + +def add_ema_args(parser): + group = parser.add_argument_group("EMA configuration") + gen_parser_from_dataclass(group, EMAConfig()) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index c86b1a51e..e46ccfe0b 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -22,6 +22,7 @@ from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.distributed import utils as distributed_utils from fairseq.file_io import PathManager from fairseq.logging import meters, metrics +from fairseq.models.ema import build_ema from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler from omegaconf import OmegaConf @@ -131,6 +132,7 @@ class Trainer(object): self._warn_once = set() self._wrapped_criterion = None self._wrapped_model = None + self._ema = None # TODO(myleott): support tpu if self.cuda and self.data_parallel_world_size > 1: @@ -256,6 +258,19 @@ class Trainer(object): self._wrapped_model = self._model return self._wrapped_model + @property + def ema(self): + if self._ema is None: + self._build_ema() + return self._ema + + def _build_ema(self): + if self.cfg.ema.store_ema: + self._ema = build_ema(self._model, self.cfg.ema, self.device) + logger.info( + "Exponential Moving Average Shadow Model is initialized." + ) + @property def optimizer(self): if self._optimizer is None: @@ -392,6 +407,12 @@ class Trainer(object): "previous_training_time": self.cumulative_training_time(), }, } + if self.cfg.ema.store_ema: + # Save EMA model state as extra state + state_dict["extra_state"]["ema"] = self.ema.get_model().state_dict() + if self.cfg.ema.ema_fp32: + # Save EMA params in fp32 + state_dict["extra_state"]["ema_fp32_params"] = self.ema.fp32_params if not self.cfg.checkpoint.no_save_optimizer_state: if self._gathered_optim_state is not None: state_dict["last_optimizer_state"] = self._gathered_optim_state @@ -552,6 +573,31 @@ class Trainer(object): if isinstance(meter, meters.TimeMeter): meter.reset() + if self.cfg.ema.store_ema: + if "ema" not in extra_state: + logger.warn( + "EMA not found in checkpoint. But store_ema is True. " + "EMA is re-initialized from checkpoint." + ) + self.ema.restore(state["model"], build_fp32_params=self.cfg.ema.ema_fp32) + else: + logger.info( + "Loading EMA from checkpoint" + ) + self.ema.restore(extra_state["ema"], build_fp32_params=False) + + if self.cfg.ema.ema_fp32: + if "ema_fp32_params" in extra_state: + logger.info( + "Loading EMA fp32 params from checkpoint" + ) + self.ema.build_fp32_params(extra_state["ema_fp32_params"]) + else: + logger.info( + "Building EMA fp32 params from EMA model in checkpoint" + ) + self.ema.build_fp32_params() + logger.info( "Loaded checkpoint {} (epoch {} @ {} updates)".format( filename, epoch, self.get_num_updates() @@ -670,6 +716,13 @@ class Trainer(object): metrics.log_start_time("train_wall", priority=800, round=0) + # If EMA is enabled through store_ema=True + # and task.uses_ema is True, pass the EMA model as a keyword + # argument to the task. + extra_kwargs = {} + if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False): + extra_kwargs["ema_model"] = self.ema.get_model() + # forward and backward pass logging_outputs, sample_size, ooms = [], 0, 0 for i, sample in enumerate(samples): # delayed update loop @@ -705,6 +758,7 @@ class Trainer(object): optimizer=self.optimizer, update_num=self.get_num_updates(), ignore_grad=is_dummy_batch, + **extra_kwargs, ) del loss @@ -840,6 +894,7 @@ class Trainer(object): self.optimizer, self.get_num_updates(), ignore_grad=False, + **extra_kwargs, ) raise except OverflowError as e: @@ -871,6 +926,20 @@ class Trainer(object): if not overflow or self.cfg.distributed_training.ddp_backend == "slow_mo": self.set_num_updates(self.get_num_updates() + 1) + if self.cfg.ema.store_ema: + # Step EMA forward with new model. + self.ema.step( + self.get_model(), + self.get_num_updates(), + ) + metrics.log_scalar( + "ema_decay", + self.ema.get_decay(), + priority=10000, + round=5, + weight=0, + ) + if self.tpu: import torch_xla.core.xla_model as xm @@ -953,6 +1022,13 @@ class Trainer(object): xm.rendezvous("valid_step") # wait for all workers + # If EMA is enabled through store_ema=True + # and task.uses_ema is True, pass the EMA model as a keyword + # argument to the task. + extra_kwargs = {} + if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False): + extra_kwargs["ema_model"] = self.ema.get_model() + with torch.no_grad(): self.model.eval() self.criterion.eval() @@ -961,7 +1037,7 @@ class Trainer(object): try: _loss, sample_size, logging_output = self.task.valid_step( - sample, self.model, self.criterion + sample, self.model, self.criterion, **extra_kwargs ) except RuntimeError as e: if "out of memory" in str(e): diff --git a/tests/gpu/test_ema_gpu.py b/tests/gpu/test_ema_gpu.py new file mode 100644 index 000000000..337107d69 --- /dev/null +++ b/tests/gpu/test_ema_gpu.py @@ -0,0 +1,200 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from copy import deepcopy +from dataclasses import dataclass +from typing import Optional + +import torch +from fairseq.models.ema import EMA + + +class DummyModule(torch.nn.Module): + def __init__(self) -> None: + """LightningModule for testing purposes + + Args: + epoch_min_loss_override (int, optional): Pass in an epoch that will be set to the minimum + validation loss for testing purposes (zero based). If None this is ignored. Defaults to None. + """ + super().__init__() + self.layer = torch.nn.Linear(in_features=32, out_features=2) + self.another_layer = torch.nn.Linear(in_features=2, out_features=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer(x) + return self.another_layer(x) + + +@dataclass +class EMAConfig(object): + ema_decay: float = 0.99 + ema_start_update: int = 0 + ema_fp32: bool = False + ema_seed_model: Optional[str] = None + + +@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") +class TestEMAGPU(unittest.TestCase): + def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None): + diff = x.float() - y.float() + diff_norm = torch.norm(diff) + other_norm = torch.norm(y.float()) + + if msg is None: + msg = "|input - other| > {} + {} * |other|".format( + atol, rtol + ) + + self.assertLessEqual( + diff_norm, + atol + rtol * other_norm, + msg=msg, + ) + + def test_ema(self): + model = DummyModule().cuda() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig() + ema = EMA(model, config) + + # set decay + ema._set_decay(config.ema_decay) + self.assertEqual(ema.get_decay(), config.ema_decay) + + # get model + self.assertEqual(ema.get_model(), ema.model) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + # EMA step + x = torch.randn(32).cuda() + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + ema_state_dict = ema.get_model().state_dict() + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema_state_dict[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + self.assertTorchAllClose( + ema_param, + config.ema_decay * prev_param + (1 - config.ema_decay) * param, + ) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + # Load EMA into model + model2 = DummyModule().cuda() + ema.reverse(model2) + + for key, param in model2.state_dict().items(): + ema_param = ema_state_dict[key] + self.assertTrue( + torch.allclose(ema_param, param) + ) + + def test_ema_fp32(self): + model = DummyModule().cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig(ema_fp32=True) + ema = EMA(model, config) + + x = torch.randn(32).cuda() + y = model(x.half()) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema.get_model().state_dict()[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + self.assertIn(key, ema.fp32_params) + + # EMA update is done in fp32, and hence the EMA param must be + # closer to the EMA update done in fp32 than in fp16. + self.assertLessEqual( + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float() + ), + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param + (1 - config.ema_decay) * param).float() + ), + ) + self.assertTorchAllClose( + ema_param, + (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half(), + ) + + def test_ema_fp16(self): + model = DummyModule().cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig(ema_fp32=False) + ema = EMA(model, config) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + x = torch.randn(32).cuda() + y = model(x.half()) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema.get_model().state_dict()[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + + # EMA update is done in fp16, and hence the EMA param must be + # closer to the EMA update done in fp16 than in fp32. + self.assertLessEqual( + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param + (1 - config.ema_decay) * param).float() + ), + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float() + ), + ) + self.assertTorchAllClose( + ema_param, + config.ema_decay * prev_param + (1 - config.ema_decay) * param, + ) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ema.py b/tests/test_ema.py new file mode 100644 index 000000000..88ea65a43 --- /dev/null +++ b/tests/test_ema.py @@ -0,0 +1,199 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from copy import deepcopy +from dataclasses import dataclass +from typing import Optional + +import torch +from fairseq.models.ema import EMA + + +class DummyModule(torch.nn.Module): + def __init__(self) -> None: + """LightningModule for testing purposes + + Args: + epoch_min_loss_override (int, optional): Pass in an epoch that will be set to the minimum + validation loss for testing purposes (zero based). If None this is ignored. Defaults to None. + """ + super().__init__() + self.layer = torch.nn.Linear(in_features=32, out_features=2) + self.another_layer = torch.nn.Linear(in_features=2, out_features=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer(x) + return self.another_layer(x) + + +@dataclass +class EMAConfig(object): + ema_decay: float = 0.99 + ema_start_update: int = 0 + ema_fp32: bool = False + ema_seed_model: Optional[str] = None + + +class TestEMAGPU(unittest.TestCase): + def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None): + diff = x.float() - y.float() + diff_norm = torch.norm(diff) + other_norm = torch.norm(y.float()) + + if msg is None: + msg = "|input - other| > {} + {} * |other|".format( + atol, rtol + ) + + self.assertLessEqual( + diff_norm, + atol + rtol * other_norm, + msg=msg, + ) + + def test_ema(self): + model = DummyModule() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig() + ema = EMA(model, config) + + # set decay + ema._set_decay(config.ema_decay) + self.assertEqual(ema.get_decay(), config.ema_decay) + + # get model + self.assertEqual(ema.get_model(), ema.model) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + # EMA step + x = torch.randn(32) + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + ema_state_dict = ema.get_model().state_dict() + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema_state_dict[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + self.assertTorchAllClose( + ema_param, + config.ema_decay * prev_param + (1 - config.ema_decay) * param, + ) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + # Load EMA into model + model2 = DummyModule() + ema.reverse(model2) + + for key, param in model2.state_dict().items(): + ema_param = ema_state_dict[key] + self.assertTrue( + torch.allclose(ema_param, param) + ) + + def test_ema_fp32(self): + model = DummyModule().half() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig(ema_fp32=True) + ema = EMA(model, config) + + x = torch.randn(32) + y = model(x.half()) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema.get_model().state_dict()[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + self.assertIn(key, ema.fp32_params) + + # EMA update is done in fp32, and hence the EMA param must be + # closer to the EMA update done in fp32 than in fp16. + self.assertLessEqual( + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float() + ), + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param + (1 - config.ema_decay) * param).float() + ), + ) + self.assertTorchAllClose( + ema_param, + (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half(), + ) + + def test_ema_fp16(self): + model = DummyModule().half() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + state = deepcopy(model.state_dict()) + config = EMAConfig(ema_fp32=False) + ema = EMA(model, config) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + x = torch.randn(32) + y = model(x.half()) + loss = y.sum() + loss.backward() + optimizer.step() + + ema.step(model) + + for key, param in model.state_dict().items(): + prev_param = state[key] + ema_param = ema.get_model().state_dict()[key] + + if "version" in key: + # Do not decay a model.version pytorch param + continue + + # EMA update is done in fp16, and hence the EMA param must be + # closer to the EMA update done in fp16 than in fp32. + self.assertLessEqual( + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param + (1 - config.ema_decay) * param).float() + ), + torch.norm( + ema_param.float() - + (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float() + ), + ) + self.assertTorchAllClose( + ema_param, + config.ema_decay * prev_param + (1 - config.ema_decay) * param, + ) + + # Since fp32 params is not used, it should be of size 0 + self.assertEqual(len(ema.fp32_params), 0) + + +if __name__ == "__main__": + unittest.main()