Fix broken EMA in fairseq
Summary: EMA has been broken since D33649708 (995c204337) due to an indentation error.
Reviewed By: cruvadom
Differential Revision: D33809223
fbshipit-source-id: c6c4d0d327443bfea787817040e1832eef0f50e4
parent 4a7835b794
commit 1b61bbad32
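Note: the code path this commit repairs is the one that applies the standard exponential-moving-average update to model weights. A minimal sketch of that rule in plain PyTorch (illustrative only, not fairseq's implementation; the decay value is arbitrary):

import torch

@torch.no_grad()
def ema_update(ema_params, model_params, decay=0.9999):
    # Standard EMA rule: ema <- decay * ema + (1 - decay) * param.
    # When step() never reaches _step_internal (the bug below), updates
    # like this are silently skipped and the EMA copy goes stale.
    for e, p in zip(ema_params, model_params):
        e.mul_(decay).add_(p, alpha=1.0 - decay)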
fairseq/models/ema/ema.py

@@ -185,11 +185,11 @@ class EMA(object):
         self._set_decay(
             0 if updates < self.config.ema_start_update else self.config.ema_decay
         )
-        if updates is not None and self.config.ema_update_freq > 1:
-            self.update_freq_counter += 1
-            if self.update_freq_counter >= self.config.ema_update_freq:
-                self._step_internal(new_model, updates)
-                self.update_freq_counter = 0
+        if self.config.ema_update_freq > 1:
+            self.update_freq_counter += 1
+            if self.update_freq_counter >= self.config.ema_update_freq:
+                self._step_internal(new_model, updates)
+                self.update_freq_counter = 0
         else:
             self._step_internal(new_model, updates)

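Note: after the fix, the else branch reliably calls _step_internal when ema_update_freq is 1 (the default), while larger frequencies gate updates through a counter. A standalone sketch of that gating behavior (a simplified stand-in whose names mirror the diff; it is not fairseq code):

# One internal EMA step every ema_update_freq calls to step().
class EMASketch:
    def __init__(self, ema_update_freq=1):
        self.ema_update_freq = ema_update_freq
        self.update_freq_counter = 0
        self.internal_steps = 0  # stand-in for _step_internal side effects

    def step(self):
        if self.ema_update_freq > 1:
            self.update_freq_counter += 1
            if self.update_freq_counter >= self.ema_update_freq:
                self.internal_steps += 1
                self.update_freq_counter = 0
        else:
            self.internal_steps += 1

ema = EMASketch(ema_update_freq=4)
for _ in range(8):
    ema.step()
assert ema.internal_steps == 2  # one internal step per 4 calls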
tests/test_ema.py

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import unittest
+from unittest.mock import patch
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Optional
@@ -36,9 +37,10 @@ class EMAConfig(object):
     ema_start_update: int = 0
     ema_fp32: bool = False
     ema_seed_model: Optional[str] = None
+    ema_update_freq: int = 1
 
 
-class TestEMAGPU(unittest.TestCase):
+class TestEMA(unittest.TestCase):
     def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None):
         diff = x.float() - y.float()
         diff_norm = torch.norm(diff)
@@ -104,6 +106,63 @@ class TestEMAGPU(unittest.TestCase):
             ema_param = ema_state_dict[key]
             self.assertTrue(torch.allclose(ema_param, param))
 
+        # Check that step_internal is called once
+        with patch.object(
+            ema, "_step_internal", return_value=None
+        ) as mock_method:
+            ema.step(model)
+            mock_method.assert_called_once_with(model, None)
+
+    def _test_ema_start_update(self, updates):
+        model = DummyModule()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        state = deepcopy(model.state_dict())
+        config = EMAConfig(ema_start_update=1)
+        ema = EMA(model, config)
+
+        # EMA step
+        x = torch.randn(32)
+        y = model(x)
+        loss = y.sum()
+        loss.backward()
+        optimizer.step()
+
+        ema.step(model, updates=updates)
+        ema_state_dict = ema.get_model().state_dict()
+
+        self.assertEqual(ema.get_decay(), 0 if updates == 0 else config.ema_decay)
+
+        for key, param in model.state_dict().items():
+            ema_param = ema_state_dict[key]
+            prev_param = state[key]
+
+            if "version" in key:
+                # Do not decay a model.version pytorch param
+                continue
+            if updates == 0:
+                self.assertTorchAllClose(
+                    ema_param,
+                    param,
+                )
+            else:
+                self.assertTorchAllClose(
+                    ema_param,
+                    config.ema_decay * prev_param + (1 - config.ema_decay) * param,
+                )
+
+        # Check that step_internal is called once
+        with patch.object(
+            ema, "_step_internal", return_value=None
+        ) as mock_method:
+            ema.step(model, updates=updates)
+            mock_method.assert_called_once_with(model, updates)
+
+    def test_ema_before_start_update(self):
+        self._test_ema_start_update(updates=0)
+
+    def test_ema_after_start_update(self):
+        self._test_ema_start_update(updates=1)
+
     def test_ema_fp32(self):
         model = DummyModule().half()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
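Note: the new _test_ema_start_update helper checks two things: decay is forced to 0 before ema_start_update (so the EMA copy tracks the model exactly), and after the start update the EMA parameters match the closed form decay * prev + (1 - decay) * new. A worked instance of that closed-form assertion, with made-up numbers:

import torch

decay = 0.9999                     # illustrative decay value
prev = torch.tensor([1.0, 2.0])    # EMA copy before the step
new = torch.tensor([0.5, 2.5])     # model params after optimizer.step()
expected = decay * prev + (1 - decay) * new
assert torch.allclose(expected, torch.tensor([0.99995, 2.00005]))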