diff --git a/tests/test_checkpoint_utils.py b/tests/test_checkpoint_utils.py index 23ba034f3..1e58ddb11 100644 --- a/tests/test_checkpoint_utils.py +++ b/tests/test_checkpoint_utils.py @@ -19,6 +19,7 @@ from tests.utils import ( preprocess_translation_data, train_translation_model, ) +import torch class TestCheckpointUtils(unittest.TestCase): @@ -103,6 +104,27 @@ class TestCheckpointUtils(unittest.TestCase): mock_opena.assert_called_with(filename, "wb") mock_save.assert_called() + def test_load_ema_from_checkpoint(self): + dummy_state = {"a": torch.tensor([1]), "b": torch.tensor([0.1])} + with patch(f"{checkpoint_utils.__name__}.PathManager.open") as mock_open, patch( + f"{checkpoint_utils.__name__}.torch.load") as mock_load: + + mock_load.return_value = { + "extra_state": { + "ema": dummy_state + } + } + filename = "ema_checkpoint.pt" + state = checkpoint_utils.load_ema_from_checkpoint(filename) + + mock_open.assert_called_with(filename, "rb") + mock_load.assert_called() + + self.assertIn("a", state["model"]) + self.assertIn("b", state["model"]) + self.assertTrue(torch.allclose(dummy_state["a"], state["model"]["a"])) + self.assertTrue(torch.allclose(dummy_state["b"], state["model"]["b"])) + if __name__ == "__main__": unittest.main()