Fix broken EMA in fairseq

Summary: EMA has been broken since D33649708 (995c204337) due to an indentation error: the update-frequency block was nested one level too deep, under the "if updates is not None:" guard, so its else branch bound to that guard and _step_internal was never reached during normal training with the default ema_update_freq of 1.

Reviewed By: cruvadom

Differential Revision: D33809223

fbshipit-source-id: c6c4d0d327443bfea787817040e1832eef0f50e4
Vimal Manohar 2022-01-27 13:01:49 -08:00 committed by Facebook GitHub Bot
parent 4a7835b794
commit 1b61bbad32
2 changed files with 65 additions and 6 deletions

@@ -185,11 +185,11 @@ class EMA(object):
             self._set_decay(
                 0 if updates < self.config.ema_start_update else self.config.ema_decay
             )
-            if updates is not None and self.config.ema_update_freq > 1:
-                self.update_freq_counter += 1
-                if self.update_freq_counter >= self.config.ema_update_freq:
-                    self._step_internal(new_model, updates)
-                    self.update_freq_counter = 0
+        if self.config.ema_update_freq > 1:
+            self.update_freq_counter += 1
+            if self.update_freq_counter >= self.config.ema_update_freq:
+                self._step_internal(new_model, updates)
+                self.update_freq_counter = 0
         else:
             self._step_internal(new_model, updates)
 
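
For context, a minimal standalone sketch of what the indentation error did (the EMASketch class and its names are illustrative, not the fairseq API): with the frequency check nested under the updates guard, the else branch bound to that guard, so with the default update frequency of 1 the internal step never ran whenever an update count was passed.

    # Hypothetical, self-contained repro of the control-flow bug; names are illustrative.
    class EMASketch:
        def __init__(self, update_freq=1):
            self.update_freq = update_freq
            self.update_freq_counter = 0
            self.calls = 0  # how many times the EMA weights were actually updated

        def _step_internal(self, model, updates=None):
            self.calls += 1  # stands in for the real EMA parameter update

        def broken_step(self, model, updates=None):
            if updates is not None:
                # decay selection elided
                if updates is not None and self.update_freq > 1:  # nested one level too deep
                    self.update_freq_counter += 1
                    if self.update_freq_counter >= self.update_freq:
                        self._step_internal(model, updates)
                        self.update_freq_counter = 0
            else:  # binds to "if updates is not None", not to the frequency check
                self._step_internal(model, updates)

        def fixed_step(self, model, updates=None):
            # decay selection elided
            if self.update_freq > 1:
                self.update_freq_counter += 1
                if self.update_freq_counter >= self.update_freq:
                    self._step_internal(model, updates)
                    self.update_freq_counter = 0
            else:
                self._step_internal(model, updates)

    broken, fixed = EMASketch(), EMASketch()
    broken.broken_step(model=None, updates=10)
    fixed.fixed_step(model=None, updates=10)
    print(broken.calls, fixed.calls)  # 0 1 -- the broken version never updates the EMA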

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import unittest
+from unittest.mock import patch
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Optional
@@ -36,9 +37,10 @@ class EMAConfig(object):
     ema_start_update: int = 0
     ema_fp32: bool = False
     ema_seed_model: Optional[str] = None
+    ema_update_freq: int = 1
 
 
-class TestEMAGPU(unittest.TestCase):
+class TestEMA(unittest.TestCase):
     def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None):
         diff = x.float() - y.float()
         diff_norm = torch.norm(diff)
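
The assertions added in the next hunk reduce to the standard EMA update, new_ema = decay * prev_ema + (1 - decay) * param. A quick numeric sanity check of that expected value (plain PyTorch arithmetic, independent of fairseq; the decay value below is only an example):

    import torch

    decay = 0.9999                    # example decay, not a claim about fairseq defaults
    prev_ema = torch.full((4,), 2.0)  # EMA weight before the step
    param = torch.full((4,), 1.0)     # model weight after the optimizer step

    expected = decay * prev_ema + (1 - decay) * param
    print(expected)                   # tensor([1.9999, 1.9999, 1.9999, 1.9999])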
@@ -104,6 +106,63 @@ class TestEMAGPU(unittest.TestCase):
             ema_param = ema_state_dict[key]
             self.assertTrue(torch.allclose(ema_param, param))
 
+        # Check that step_internal is called once
+        with patch.object(
+            ema, "_step_internal", return_value=None
+        ) as mock_method:
+            ema.step(model)
+            mock_method.assert_called_once_with(model, None)
+
+    def _test_ema_start_update(self, updates):
+        model = DummyModule()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        state = deepcopy(model.state_dict())
+        config = EMAConfig(ema_start_update=1)
+        ema = EMA(model, config)
+
+        # EMA step
+        x = torch.randn(32)
+        y = model(x)
+        loss = y.sum()
+        loss.backward()
+        optimizer.step()
+
+        ema.step(model, updates=updates)
+        ema_state_dict = ema.get_model().state_dict()
+        self.assertEqual(ema.get_decay(), 0 if updates == 0 else config.ema_decay)
+
+        for key, param in model.state_dict().items():
+            ema_param = ema_state_dict[key]
+            prev_param = state[key]
+
+            if "version" in key:
+                # Do not decay a model.version pytorch param
+                continue
+
+            if updates == 0:
+                self.assertTorchAllClose(
+                    ema_param,
+                    param,
+                )
+            else:
+                self.assertTorchAllClose(
+                    ema_param,
+                    config.ema_decay * prev_param + (1 - config.ema_decay) * param,
+                )
+
+        # Check that step_internal is called once
+        with patch.object(
+            ema, "_step_internal", return_value=None
+        ) as mock_method:
+            ema.step(model, updates=updates)
+            mock_method.assert_called_once_with(model, updates)
+
+    def test_ema_before_start_update(self):
+        self._test_ema_start_update(updates=0)
+
+    def test_ema_after_start_update(self):
+        self._test_ema_start_update(updates=1)
+
     def test_ema_fp32(self):
         model = DummyModule().half()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)