Add unittests for jitting EMA model

Summary: As title

Reviewed By: nayansinghal

Differential Revision: D32005717

fbshipit-source-id: ebdf1ed0e4a2b9fccffd841d0fa7be0b50ec6b79
This commit is contained in:
Vimal Manohar 2022-01-13 01:52:50 -08:00 committed by Facebook GitHub Bot
parent fa7663c314
commit cf8ff8c3c5

View File

@ -19,6 +19,7 @@ from tests.utils import (
preprocess_translation_data,
train_translation_model,
)
import torch
class TestCheckpointUtils(unittest.TestCase):
@ -103,6 +104,27 @@ class TestCheckpointUtils(unittest.TestCase):
mock_opena.assert_called_with(filename, "wb")
mock_save.assert_called()
def test_load_ema_from_checkpoint(self):
dummy_state = {"a": torch.tensor([1]), "b": torch.tensor([0.1])}
with patch(f"{checkpoint_utils.__name__}.PathManager.open") as mock_open, patch(
f"{checkpoint_utils.__name__}.torch.load") as mock_load:
mock_load.return_value = {
"extra_state": {
"ema": dummy_state
}
}
filename = "ema_checkpoint.pt"
state = checkpoint_utils.load_ema_from_checkpoint(filename)
mock_open.assert_called_with(filename, "rb")
mock_load.assert_called()
self.assertIn("a", state["model"])
self.assertIn("b", state["model"])
self.assertTrue(torch.allclose(dummy_state["a"], state["model"]["a"]))
self.assertTrue(torch.allclose(dummy_state["b"], state["model"]["b"]))
if __name__ == "__main__":
unittest.main()