Summary:
Adds an exponential moving average (EMA) model for Kaizen semi-supervised training: https://arxiv.org/abs/2106.07759

1. Adds `ema.store_ema` to enable storing the EMA model. The EMA state dict is written to `extra_state` in the checkpoint state dict when saving.
2. Adds `ema.ema_start_update` to control when the EMA starts accumulating (see the sketch after this list).
3. Tasks can use the `uses_ema` property to decide whether the EMA model should be passed to the task (default is False).
4. `load_ema_from_checkpoint` can be used to load the EMA model in place of the regular model for evaluation. Pyspeech has an eval-ema option for this.
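
A minimal sketch of the training-side wiring; `model`, `cfg` (an EMAConfig with store_ema=True), and `num_updates` are assumed context rather than part of this diff:

```
# Hedged sketch: create the EMA shadow model once, then step it after
# every optimizer update. Only build_ema comes from this diff.
from fairseq.models.ema import build_ema

ema = build_ema(model, cfg, None)   # detached deep copy of `model`

# ... inside the train loop, after each optimizer step:
ema.step(model, num_updates)        # decay is 0 before cfg.ema_start_update
```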

```
This module contains the EMA class used to store a copy of the exponentially
decayed model params.

Typical usage of the EMA class involves initializing an object using an
existing model (random or from a seed model) and setting config options such
as ema_decay and ema_start_update, which determine how the EMA model is
updated. After every model update, i.e. at the end of train_step, the EMA
should be updated by passing the new model to the EMA.step function. The EMA
model state dict can be stored in the extra state under the "ema" key and
dumped into a checkpoint and loaded back. The EMA object can be passed to
tasks by setting the task.uses_ema property.
EMA is a smoothed/ensemble model which may perform better when used for
inference or further fine-tuning. The EMA class has a reverse function to
load the EMA params into a model so it can be used like a regular model.
```
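
For concreteness, the per-parameter rule applied on every EMA.step call is a simple convex combination (a sketch; `decay` is ema_decay, forced to 0 before ema_start_update):

```
# ema_param <- decay * ema_param + (1 - decay) * model_param
# e.g. with decay = 0.9999, a parameter jumping from 0.0 to 1.0 moves
# the EMA copy by only (1 - 0.9999) * 1.0 = 1e-4 in a single step.
ema_param.mul_(decay).add_(model_param, alpha=1 - decay)
```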

Reviewed By: cruvadom

Differential Revision: D24238379

fbshipit-source-id: 879d3ba5070a614b7d365f9503af357001e875b2
Author: Vimal Manohar, 2021-09-01 11:43:25 -07:00 (committed by Facebook GitHub Bot)
parent 68a81202a3
commit 8feccf9441
8 changed files with 769 additions and 1 deletion


@@ -810,3 +810,49 @@ def verify_checkpoint_directory(save_dir: str) -> None:
raise e
else:
os.remove(temp_file_path)
def load_ema_from_checkpoint(fpath):
"""Loads exponential moving averaged (EMA) checkpoint from input and
returns a model with ema weights.
Args:
fpath: A string path of checkpoint to load from.
Returns:
A dict of string keys mapping to various values. The 'model' key
from the returned dict should correspond to an OrderedDict mapping
string parameter names to torch Tensors.
"""
params_dict = collections.OrderedDict()
new_state = None
with PathManager.open(fpath, 'rb') as f:
new_state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
),
)
# EMA model is stored in a separate "extra state"
model_params = new_state['extra_state']['ema']
for key in list(model_params.keys()):
p = model_params[key]
if isinstance(p, torch.HalfTensor):
p = p.float()
if key not in params_dict:
params_dict[key] = p.clone()
# NOTE: clone() is needed in case p is a shared parameter
else:
raise ValueError("Key {} is repeated in EMA model params.".format(key))
if len(params_dict) == 0:
raise ValueError(
f"Input checkpoint path '{fpath}' does not contain "
"ema model weights, is this model trained with EMA?"
)
new_state['model'] = params_dict
return new_state
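
A hedged usage sketch of this helper for evaluation; the checkpoint path and the `model` object are placeholders:

```
# Hypothetical: evaluate with EMA weights. "averaged.pt" is a placeholder
# path; `model` must match the architecture the checkpoint was trained with.
from fairseq import checkpoint_utils

state = checkpoint_utils.load_ema_from_checkpoint("averaged.pt")
model.load_state_dict(state["model"], strict=True)
model.eval()
```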


@@ -984,6 +984,36 @@ class InteractiveConfig(FairseqDataclass):
)
@dataclass
class EMAConfig(FairseqDataclass):
store_ema: bool = field(
default=False, metadata={
"help": "store exponential moving average shadow model"
}
)
ema_decay: float = field(
default=0.9999, metadata={
"help": "decay for exponential moving average model"
}
)
ema_start_update: int = field(
default=0, metadata={"help": "start EMA update after this many model updates"}
)
ema_seed_model: Optional[str] = field(
default=None, metadata={
"help": "Seed model to load EMA from. "
"Used to load the EMA model separately from the actual model."
}
)
ema_update_freq: int = field(
default=1, metadata={"help": "do EMA update every this many model updates"}
)
ema_fp32: bool = field(
default=False,
metadata={"help": "if True, store EMA model in fp32 even if model is in fp16"},
)
@dataclass
class FairseqConfig(FairseqDataclass):
common: CommonConfig = CommonConfig()
@@ -1004,3 +1034,4 @@ class FairseqConfig(FairseqDataclass):
scoring: Any = None
bpe: Any = None
tokenizer: Any = None
ema: EMAConfig = EMAConfig()
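
For reference, a sketch of constructing this config directly in Python (field names and defaults are exactly those defined above):

```
from fairseq.dataclass.configs import EMAConfig

ema_cfg = EMAConfig(
    store_ema=True,         # keep an EMA shadow model during training
    ema_decay=0.9999,       # per-update decay
    ema_start_update=1000,  # plain copy of the model until this update
    ema_fp32=True,          # keep an fp32 copy for the update step
)
```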


@@ -0,0 +1,20 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
from .ema import EMA
def build_ema(model, cfg, device):
return EMA(model, cfg, device)
# automatically import any Python files in the models/ema/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
if file.endswith(".py") and not file.startswith("_"):
file_name = file[: file.find(".py")]
importlib.import_module("fairseq.models.ema." + file_name)

fairseq/models/ema/ema.py (new file, 189 lines)

@@ -0,0 +1,189 @@
#!/usr/bin/env python3
"""
This module contains the EMA class used to store a copy of the exponentially
decayed model params.
Typical usage of the EMA class involves initializing an object using an
existing model (random or from a seed model) and setting config options such
as ema_decay and ema_start_update, which determine how the EMA model is
updated. After every model update, i.e. at the end of train_step, the EMA
should be updated by passing the new model to the EMA.step function. The EMA
model state dict can be stored in the extra state under the "ema" key and
dumped into a checkpoint and loaded back. The EMA object can be passed to
tasks by setting the task.uses_ema property.
EMA is a smoothed/ensemble model which may perform better when used for
inference or further fine-tuning. The EMA class has a reverse function to
load the EMA params into a model so it can be used like a regular model.
"""
import copy
import logging
import torch
from fairseq import checkpoint_utils
class EMA(object):
"""Exponential Moving Average of Fairseq Models
EMA keeps a copy of the exponentially decayed model params.
The set of params should include both gradient-descent and
non-gradient descent params, such as batch mean/var and buffers.
This is a modified implementation of
the open source code in https://github.com/zhawe01/fairseq-gec.git,
and internal source code in
fbcode/mobile-vision/projects/classification_pytorch/lib/utils/model_ema.py.
Similar to TF EMA.
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
EMA provides a averaged and smoothed set of model weights, and has been shown to
improve vision models. EMA class does all necessary functions to update, reload,
or init EMA methods.
EMA object is initialized from an arbitrary model. By default, it is stored in
the same device (unless device specified at initialization) and with the
same precision as the model (unless ema_fp32 is True). ema_fp32 is recommended.
This stores the EMA parameters in fp32 only for the EMA update step, and
is used at the default precision otherwise.
EMA is usually enabled using EMAConfig with store_ema=True. Some important
parameters to configure EMA are
1) ema_decay - The decay of EMA
2) ema_update_freq - EMA is updated every this many model updates.
3) ema_start_update - Start EMA update after this many model updates [default 0]
Key methods:
1) step - One update of EMA using new model
2) restore - Update EMA from a state dict
3) reverse - Load EMA into a model
4) get_decay, _set_decay - Used to get or set the decay. Note _set_decay is
called from step.
5) build_fp32_params - Used to initialize or update the fp32 copy of EMA params.
Note this is enabled only when ema_fp32=True
"""
def __init__(self, model, config, device=None):
"""
@param model model to initialize the EMA with
@param config EMAConfig object with configuration like
ema_decay, ema_update_freq, ema_fp32
@param device If provided, copy the EMA to this device (e.g. gpu).
Otherwise the EMA stays on the same device as the model.
"""
self.decay = config.ema_decay
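# Keep a detached deep copy of the model; it is updated manually in
# step() and never receives gradients.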
self.model = copy.deepcopy(model)
self.model.requires_grad_(False)
self.config = config
self.fp32_params = {}
if self.config.ema_seed_model is not None:
state = checkpoint_utils.load_ema_from_checkpoint(self.config.ema_seed_model)
self.model.load_state_dict(state["model"], strict=True)
if device is not None:
logging.info(f"Copying EMA model to device {device}")
self.model = self.model.to(device=device)
if self.config.ema_fp32:
self.build_fp32_params()
self.update_freq_counter = 0
def get_model(self):
return self.model
def build_fp32_params(self, state_dict=None):
"""
Store a copy of the EMA params in fp32.
If a state dict is passed, the EMA params are copied from
the provided state dict. Otherwise, they are copied from the
current EMA model parameters.
"""
if not self.config.ema_fp32:
raise RuntimeError(
"build_fp32_params should not be called if ema_fp32=False. "
"Use ema_fp32=True if this is really intended."
)
if state_dict is None:
state_dict = self.model.state_dict()
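# Only floating-point tensors are cast to fp32; integer buffers
# (e.g. BatchNorm's num_batches_tracked) keep their original dtype.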
def _to_float(t):
return t.float() if torch.is_floating_point(t) else t
for param_key in state_dict:
if param_key in self.fp32_params:
self.fp32_params[param_key].copy_(state_dict[param_key])
else:
self.fp32_params[param_key] = _to_float(state_dict[param_key])
def restore(self, state_dict, build_fp32_params=False):
""" Load data from a model spec into EMA model """
self.model.load_state_dict(state_dict, strict=False)
if build_fp32_params:
self.build_fp32_params(state_dict)
def _set_decay(self, decay):
self.decay = decay
def get_decay(self):
return self.decay
def _step_internal(self, new_model, updates=None):
""" One update of the EMA model based on new model weights """
decay = self.decay
ema_state_dict = {}
ema_params = self.fp32_params if self.config.ema_fp32 else self.model.state_dict()
for key, param in new_model.state_dict().items():
try:
ema_param = ema_params[key]
except KeyError:
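# param not yet tracked by the EMA (e.g. newly added);
# seed it from the current model weights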
ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
if param.shape != ema_param.shape:
raise ValueError(
"incompatible tensor shapes between model param and ema param"
+ "{} vs. {}".format(param.shape, ema_param.shape)
)
if "version" in key:
# Do not decay a model.version pytorch param
continue
ema_param.mul_(decay)
ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1-decay)
ema_state_dict[key] = ema_param
self.restore(ema_state_dict, build_fp32_params=False)
def step(self, new_model, updates=None):
"""
One update of the EMA, which is done every self.config.ema_update_freq
updates of the model.
@param updates The current number of model updates done.
Decay is set to 0 if model updates < ema_start_update, which means
the model is simply copied over to the EMA.
When model updates >= ema_start_update, the EMA is updated with
a decay of self.config.ema_decay.
"""
self._set_decay(
0
if updates is not None
and updates < self.config.ema_start_update
else self.config.ema_decay
)
if updates is not None and self.config.ema_update_freq > 1:
self.update_freq_counter += 1
if self.update_freq_counter >= self.config.ema_update_freq:
self._step_internal(new_model, updates)
self.update_freq_counter = 0
else:
self._step_internal(new_model, updates)
def reverse(self, model):
"""
Load the model parameters from EMA model.
Useful for inference or fine-tuning from the EMA model.
"""
model.load_state_dict(self.model.state_dict(), strict=False)
return model


@@ -20,6 +20,7 @@ from fairseq.dataclass.configs import (
GenerationConfig,
InteractiveConfig,
OptimizationConfig,
EMAConfig,
)
from fairseq.dataclass.utils import gen_parser_from_dataclass
@@ -40,6 +41,7 @@ def get_training_parser(default_task="translation"):
add_model_args(parser)
add_optimization_args(parser)
add_checkpoint_args(parser)
add_ema_args(parser)
return parser
@@ -379,3 +381,8 @@ def get_args(
setattr(args, k, v)
return args
def add_ema_args(parser):
group = parser.add_argument_group("EMA configuration")
gen_parser_from_dataclass(group, EMAConfig())
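# Each EMAConfig field becomes a CLI flag via gen_parser_from_dataclass
# (underscores map to dashes, e.g. store_ema -> --store-ema).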


@@ -22,6 +22,7 @@ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.distributed import utils as distributed_utils
from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics
from fairseq.models.ema import build_ema
from fairseq.nan_detector import NanDetector
from fairseq.optim import lr_scheduler
from omegaconf import OmegaConf
@@ -131,6 +132,7 @@ class Trainer(object):
self._warn_once = set()
self._wrapped_criterion = None
self._wrapped_model = None
self._ema = None
# TODO(myleott): support tpu
if self.cuda and self.data_parallel_world_size > 1:
@@ -256,6 +258,19 @@
self._wrapped_model = self._model
return self._wrapped_model
@property
def ema(self):
if self._ema is None:
self._build_ema()
return self._ema
def _build_ema(self):
if self.cfg.ema.store_ema:
self._ema = build_ema(self._model, self.cfg.ema, self.device)
logger.info(
"Exponential Moving Average Shadow Model is initialized."
)
@property
def optimizer(self):
if self._optimizer is None:
@@ -392,6 +407,12 @@
"previous_training_time": self.cumulative_training_time(),
},
}
if self.cfg.ema.store_ema:
# Save EMA model state as extra state
state_dict["extra_state"]["ema"] = self.ema.get_model().state_dict()
if self.cfg.ema.ema_fp32:
# Save EMA params in fp32
state_dict["extra_state"]["ema_fp32_params"] = self.ema.fp32_params
if not self.cfg.checkpoint.no_save_optimizer_state:
if self._gathered_optim_state is not None:
state_dict["last_optimizer_state"] = self._gathered_optim_state
@@ -552,6 +573,31 @@
if isinstance(meter, meters.TimeMeter):
meter.reset()
if self.cfg.ema.store_ema:
if "ema" not in extra_state:
logger.warning(
"EMA not found in checkpoint, but store_ema is True. "
"Re-initializing EMA from the checkpoint model."
)
self.ema.restore(state["model"], build_fp32_params=self.cfg.ema.ema_fp32)
else:
logger.info(
"Loading EMA from checkpoint"
)
self.ema.restore(extra_state["ema"], build_fp32_params=False)
if self.cfg.ema.ema_fp32:
if "ema_fp32_params" in extra_state:
logger.info(
"Loading EMA fp32 params from checkpoint"
)
self.ema.build_fp32_params(extra_state["ema_fp32_params"])
else:
logger.info(
"Building EMA fp32 params from EMA model in checkpoint"
)
self.ema.build_fp32_params()
logger.info(
"Loaded checkpoint {} (epoch {} @ {} updates)".format(
filename, epoch, self.get_num_updates()
@@ -670,6 +716,13 @@
metrics.log_start_time("train_wall", priority=800, round=0)
# If EMA is enabled through store_ema=True
# and task.uses_ema is True, pass the EMA model as a keyword
# argument to the task.
extra_kwargs = {}
if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False):
extra_kwargs["ema_model"] = self.ema.get_model()
# forward and backward pass
logging_outputs, sample_size, ooms = [], 0, 0
for i, sample in enumerate(samples): # delayed update loop
@@ -705,6 +758,7 @@
optimizer=self.optimizer,
update_num=self.get_num_updates(),
ignore_grad=is_dummy_batch,
**extra_kwargs,
)
del loss
@@ -840,6 +894,7 @@
self.optimizer,
self.get_num_updates(),
ignore_grad=False,
**extra_kwargs,
)
raise
except OverflowError as e:
@@ -871,6 +926,20 @@
if not overflow or self.cfg.distributed_training.ddp_backend == "slow_mo":
self.set_num_updates(self.get_num_updates() + 1)
if self.cfg.ema.store_ema:
# Step EMA forward with new model.
self.ema.step(
self.get_model(),
self.get_num_updates(),
)
metrics.log_scalar(
"ema_decay",
self.ema.get_decay(),
priority=10000,
round=5,
weight=0,
)
if self.tpu:
import torch_xla.core.xla_model as xm
@@ -953,6 +1022,13 @@
xm.rendezvous("valid_step") # wait for all workers
# If EMA is enabled through store_ema=True
# and task.uses_ema is True, pass the EMA model as a keyword
# argument to the task.
extra_kwargs = {}
if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False):
extra_kwargs["ema_model"] = self.ema.get_model()
with torch.no_grad():
self.model.eval()
self.criterion.eval()
@@ -961,7 +1037,7 @@
try:
_loss, sample_size, logging_output = self.task.valid_step(
- sample, self.model, self.criterion
+ sample, self.model, self.criterion, **extra_kwargs
)
except RuntimeError as e:
if "out of memory" in str(e):

tests/gpu/test_ema_gpu.py (new file, 200 lines)

@@ -0,0 +1,200 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional
import torch
from fairseq.models.ema import EMA
class DummyModule(torch.nn.Module):
def __init__(self) -> None:
"""LightningModule for testing purposes
Args:
epoch_min_loss_override (int, optional): Pass in an epoch that will be set to the minimum
validation loss for testing purposes (zero based). If None this is ignored. Defaults to None.
"""
super().__init__()
self.layer = torch.nn.Linear(in_features=32, out_features=2)
self.another_layer = torch.nn.Linear(in_features=2, out_features=2)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.layer(x)
return self.another_layer(x)
@dataclass
class EMAConfig(object):
ema_decay: float = 0.99
ema_start_update: int = 0
ema_fp32: bool = False
ema_seed_model: Optional[str] = None
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestEMAGPU(unittest.TestCase):
def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None):
diff = x.float() - y.float()
diff_norm = torch.norm(diff)
other_norm = torch.norm(y.float())
if msg is None:
msg = "|input - other| > {} + {} * |other|".format(
atol, rtol
)
self.assertLessEqual(
diff_norm,
atol + rtol * other_norm,
msg=msg,
)
def test_ema(self):
model = DummyModule().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
state = deepcopy(model.state_dict())
config = EMAConfig()
ema = EMA(model, config)
# set decay
ema._set_decay(config.ema_decay)
self.assertEqual(ema.get_decay(), config.ema_decay)
# get model
self.assertEqual(ema.get_model(), ema.model)
# Since fp32 params are not used, fp32_params should be empty
self.assertEqual(len(ema.fp32_params), 0)
# EMA step
x = torch.randn(32).cuda()
y = model(x)
loss = y.sum()
loss.backward()
optimizer.step()
ema.step(model)
ema_state_dict = ema.get_model().state_dict()
for key, param in model.state_dict().items():
prev_param = state[key]
ema_param = ema_state_dict[key]
if "version" in key:
# Do not decay a model.version pytorch param
continue
self.assertTorchAllClose(
ema_param,
config.ema_decay * prev_param + (1 - config.ema_decay) * param,
)
# Since fp32 params are not used, fp32_params should be empty
self.assertEqual(len(ema.fp32_params), 0)
# Load EMA into model
model2 = DummyModule().cuda()
ema.reverse(model2)
for key, param in model2.state_dict().items():
ema_param = ema_state_dict[key]
self.assertTrue(
torch.allclose(ema_param, param)
)
def test_ema_fp32(self):
model = DummyModule().cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
state = deepcopy(model.state_dict())
config = EMAConfig(ema_fp32=True)
ema = EMA(model, config)
x = torch.randn(32).cuda()
y = model(x.half())
loss = y.sum()
loss.backward()
optimizer.step()
ema.step(model)
for key, param in model.state_dict().items():
prev_param = state[key]
ema_param = ema.get_model().state_dict()[key]
if "version" in key:
# Do not decay a model.version pytorch param
continue
self.assertIn(key, ema.fp32_params)
# EMA update is done in fp32, and hence the EMA param must be
# closer to the EMA update done in fp32 than in fp16.
self.assertLessEqual(
torch.norm(
ema_param.float() -
(config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float()
),
torch.norm(
ema_param.float() -
(config.ema_decay * prev_param + (1 - config.ema_decay) * param).float()
),
)
self.assertTorchAllClose(
ema_param,
(config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half(),
)
def test_ema_fp16(self):
model = DummyModule().cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
state = deepcopy(model.state_dict())
config = EMAConfig(ema_fp32=False)
ema = EMA(model, config)
# Since fp32 params are not used, fp32_params should be empty
self.assertEqual(len(ema.fp32_params), 0)
x = torch.randn(32).cuda()
y = model(x.half())
loss = y.sum()
loss.backward()
optimizer.step()
ema.step(model)
for key, param in model.state_dict().items():
prev_param = state[key]
ema_param = ema.get_model().state_dict()[key]
if "version" in key:
# Do not decay a model.version pytorch param
continue
# EMA update is done in fp16, and hence the EMA param must be
# closer to the EMA update done in fp16 than in fp32.
self.assertLessEqual(
torch.norm(
ema_param.float() -
(config.ema_decay * prev_param + (1 - config.ema_decay) * param).float()
),
torch.norm(
ema_param.float() -
(config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float()
),
)
self.assertTorchAllClose(
ema_param,
config.ema_decay * prev_param + (1 - config.ema_decay) * param,
)
# Since fp32 params are not used, fp32_params should be empty
self.assertEqual(len(ema.fp32_params), 0)
if __name__ == "__main__":
unittest.main()

tests/test_ema.py (new file, 199 lines)

@@ -0,0 +1,199 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional
import torch
from fairseq.models.ema import EMA
class DummyModule(torch.nn.Module):
def __init__(self) -> None:
"""LightningModule for testing purposes
Args:
epoch_min_loss_override (int, optional): Pass in an epoch that will be set to the minimum
validation loss for testing purposes (zero based). If None this is ignored. Defaults to None.
"""
super().__init__()
self.layer = torch.nn.Linear(in_features=32, out_features=2)
self.another_layer = torch.nn.Linear(in_features=2, out_features=2)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.layer(x)
return self.another_layer(x)
@dataclass
class EMAConfig(object):
ema_decay: float = 0.99
ema_start_update: int = 0
ema_fp32: bool = False
ema_seed_model: Optional[str] = None
class TestEMA(unittest.TestCase):
def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None):
diff = x.float() - y.float()
diff_norm = torch.norm(diff)
other_norm = torch.norm(y.float())
if msg is None:
msg = "|input - other| > {} + {} * |other|".format(
atol, rtol
)
self.assertLessEqual(
diff_norm,
atol + rtol * other_norm,
msg=msg,
)
def test_ema(self):
model = DummyModule()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
state = deepcopy(model.state_dict())
config = EMAConfig()
ema = EMA(model, config)
# set decay
ema._set_decay(config.ema_decay)
self.assertEqual(ema.get_decay(), config.ema_decay)
# get model
self.assertEqual(ema.get_model(), ema.model)
# Since fp32 params are not used, fp32_params should be empty
self.assertEqual(len(ema.fp32_params), 0)
# EMA step
x = torch.randn(32)
y = model(x)
loss = y.sum()
loss.backward()
optimizer.step()
ema.step(model)
ema_state_dict = ema.get_model().state_dict()
for key, param in model.state_dict().items():
prev_param = state[key]
ema_param = ema_state_dict[key]
if "version" in key:
# Do not decay a model.version pytorch param
continue
self.assertTorchAllClose(
ema_param,
config.ema_decay * prev_param + (1 - config.ema_decay) * param,
)
# Since fp32 params are not used, fp32_params should be empty
self.assertEqual(len(ema.fp32_params), 0)
# Load EMA into model
model2 = DummyModule()
ema.reverse(model2)
for key, param in model2.state_dict().items():
ema_param = ema_state_dict[key]
self.assertTrue(
torch.allclose(ema_param, param)
)
def test_ema_fp32(self):
model = DummyModule().half()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
state = deepcopy(model.state_dict())
config = EMAConfig(ema_fp32=True)
ema = EMA(model, config)
x = torch.randn(32)
y = model(x.half())
loss = y.sum()
loss.backward()
optimizer.step()
ema.step(model)
for key, param in model.state_dict().items():
prev_param = state[key]
ema_param = ema.get_model().state_dict()[key]
if "version" in key:
# Do not decay a model.version pytorch param
continue
self.assertIn(key, ema.fp32_params)
# EMA update is done in fp32, and hence the EMA param must be
# closer to the EMA update done in fp32 than in fp16.
self.assertLessEqual(
torch.norm(
ema_param.float() -
(config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float()
),
torch.norm(
ema_param.float() -
(config.ema_decay * prev_param + (1 - config.ema_decay) * param).float()
),
)
self.assertTorchAllClose(
ema_param,
(config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half(),
)
def test_ema_fp16(self):
model = DummyModule().half()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
state = deepcopy(model.state_dict())
config = EMAConfig(ema_fp32=False)
ema = EMA(model, config)
# Since fp32 params are not used, fp32_params should be empty
self.assertEqual(len(ema.fp32_params), 0)
x = torch.randn(32)
y = model(x.half())
loss = y.sum()
loss.backward()
optimizer.step()
ema.step(model)
for key, param in model.state_dict().items():
prev_param = state[key]
ema_param = ema.get_model().state_dict()[key]
if "version" in key:
# Do not decay a model.version pytorch param
continue
# EMA update is done in fp16, and hence the EMA param must be
# closer to the EMA update done in fp16 than in fp32.
self.assertLessEqual(
torch.norm(
ema_param.float() -
(config.ema_decay * prev_param + (1 - config.ema_decay) * param).float()
),
torch.norm(
ema_param.float() -
(config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float()
),
)
self.assertTorchAllClose(
ema_param,
config.ema_decay * prev_param + (1 - config.ema_decay) * param,
)
# Since fp32 params are not used, fp32_params should be empty
self.assertEqual(len(ema.fp32_params), 0)
if __name__ == "__main__":
unittest.main()