don't use half precision in test_ema on CPU (#3408)

Summary:
X-link: https://github.com/fairinternal/fairseq-py/pull/3408

Pull Request resolved: https://github.com/facebookresearch/fairseq/pull/4443

Fixes errors introduced in D35571505.

Reviewed By: ngimel

Differential Revision: D36726254

fbshipit-source-id: dde8964c47426839b03c842574669ae9428031c6
Author: Jongsoo Park
Date: 2022-05-26 21:14:17 -07:00
Committed by: Facebook GitHub Bot
Parent: b5e7b25091
Commit: e0884db9a7

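Before the diff, a minimal sketch of the pattern the change applies, assuming a toy stand-in for the test's DummyModule (the TinyModule below and its shapes are illustrative, not the fairseq code): half precision is selected only when CUDA is available, and the test falls back to float32 on CPU, where half-precision Linear kernels may be unavailable.

import torch
import torch.nn as nn

# Illustrative stand-in for the test's DummyModule (not the fairseq class).
class TinyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(32, 32)

    def forward(self, x):
        return self.linear(x)

# Pick half precision only when a GPU is present; CPU builds may lack
# half-precision Linear (matmul) kernels, so fall back to float32 there.
dtype = torch.half if torch.cuda.is_available() else torch.float
device = "cuda" if torch.cuda.is_available() else "cpu"

model = TinyModule().to(device=device, dtype=dtype)
x = torch.randn(32, device=device, dtype=dtype)
loss = model(x).sum()  # forward pass runs in fp16 on GPU, fp32 on CPU
loss.backward()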

@@ -160,14 +160,17 @@ class TestEMA(unittest.TestCase):
         self._test_ema_start_update(updates=1)
 
     def test_ema_fp32(self):
-        model = DummyModule().half()
+        # CPU no longer supports Linear in half precision
+        dtype = torch.half if torch.cuda.is_available() else torch.float
+        model = DummyModule().to(dtype)
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
         state = deepcopy(model.state_dict())
         config = EMAConfig(ema_fp32=True)
         ema = EMA(model, config)
 
         x = torch.randn(32)
-        y = model(x.half())
+        y = model(x.to(dtype))
         loss = y.sum()
         loss.backward()
         optimizer.step()
@@ -192,7 +195,7 @@ class TestEMA(unittest.TestCase):
                         config.ema_decay * prev_param.float()
                         + (1 - config.ema_decay) * param.float()
                     )
-                    .half()
+                    .to(dtype)
                     .float()
                 ),
                 torch.norm(
@@ -207,10 +210,14 @@ class TestEMA(unittest.TestCase):
                     (
                         config.ema_decay * prev_param.float()
                         + (1 - config.ema_decay) * param.float()
-                    ).half(),
+                    ).to(dtype),
                 )
 
     def test_ema_fp16(self):
+        # CPU no longer supports Linear in half precision
+        if not torch.cuda.is_available():
+            return
+
         model = DummyModule().half()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
         state = deepcopy(model.state_dict())
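
The fp16 test above simply returns early when CUDA is unavailable. A sketch of an equivalent guard using unittest's standard skip decorator (an alternative shown for illustration, not what the patch does; the class and layer below are hypothetical):

import unittest
import torch

class TestEMAFP16(unittest.TestCase):
    # Skipping reports the test as skipped rather than silently passing it
    # on CPU-only workers, where half-precision Linear is not exercised.
    @unittest.skipUnless(torch.cuda.is_available(), "test_ema_fp16 requires CUDA")
    def test_ema_fp16(self):
        model = torch.nn.Linear(32, 32).cuda().half()
        y = model(torch.randn(32, device="cuda", dtype=torch.half))
        self.assertEqual(y.dtype, torch.half)

Either form keeps CPU-only CI green while still exercising the half-precision path on GPU workers.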