ioPath async - Fairseq unittests (#1669)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1669

Adds unit tests for the asynchronous-write integration that was introduced in D26467815 (3100d0b8e5).

Ongoing performance tests: https://fb.quip.com/kjM7Atb1kKbO

Reviewed By: myleott

Differential Revision: D26732660

fbshipit-source-id: faf8cac67b9167af4195358c1a2592804c13562c
This commit is contained in:
Eric Lou 2021-03-03 10:48:42 -08:00 committed by Facebook GitHub Bot
parent 0c32e251e2
commit 7d2394b56f
3 changed files with 31 additions and 1 deletions

View File

@ -170,7 +170,7 @@ class PathManager:
if not IOPathPathManager:
logging.info("ioPath is initializing PathManager.")
try:
from iopath import PathManager
from iopath.common.file_io import PathManager
IOPathPathManager = PathManager()
except Exception:
logging.exception("Failed to initialize ioPath PathManager object.")

View File

@ -9,8 +9,10 @@ import os
import tempfile
import unittest
from io import StringIO
from unittest.mock import patch
from fairseq import checkpoint_utils
from omegaconf import OmegaConf
from tests.utils import (
create_dummy_data,
@ -87,6 +89,19 @@ class TestCheckpointUtils(unittest.TestCase):
self.assertEqual(len(ensemble[0].encoder.layers), 2)
self.assertEqual(len(ensemble[0].decoder.layers), 1)
def test_torch_persistent_save_async(self):
    """Verify the async branch of ``checkpoint_utils.torch_persistent_save``.

    With ``write_checkpoints_asynchronously`` enabled, the checkpoint file
    is expected to be opened through ``PathManager.opena`` in ``"wb"`` mode
    and the serialization helper ``_torch_persistent_save`` is expected to
    be invoked. Both collaborators are mocked, so nothing is written to disk.
    """
    # Minimal config exposing only the flag that this test exercises.
    config = OmegaConf.create(
        {"dataset": {"write_checkpoints_asynchronously": True}}
    )
    checkpoint_path = "async_checkpoint.pt"
    empty_state = {}

    # Patch the async opener first and the save helper second, mirroring the
    # nesting order of the two separate context managers.
    with patch(
        f"{checkpoint_utils.__name__}.PathManager.opena"
    ) as opena_mock, patch(
        f"{checkpoint_utils.__name__}._torch_persistent_save"
    ) as save_mock:
        checkpoint_utils.torch_persistent_save(
            config.dataset, empty_state, checkpoint_path
        )
        opena_mock.assert_called_with(checkpoint_path, "wb")
        save_mock.assert_called()
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
unittest.main()

View File

@ -45,3 +45,18 @@ class TestFileIO(unittest.TestCase):
with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f:
s = f.read()
self.assertEqual(s, self._tmpfile_contents)
def test_file_io_async(self):
    """Check that the first ``opena`` call initializes the ioPath manager.

    The test asserts that ``fairseq.file_io.IOPathPathManager`` is ``None``
    beforehand, opens and closes a file in the temporary directory through
    ``PathManager.opena``, and then asserts that the module-level
    ``IOPathPathManager`` is no longer ``None``. Whatever happens inside the
    ``try`` body, ``PathManager.async_close()`` is called afterwards and is
    required to return a truthy value.
    """
    # Import before entering the `try`: if this import fails, the error must
    # surface as itself rather than being masked by a NameError raised when
    # the `finally` clause references `PathManager`.
    from fairseq.file_io import IOPathPathManager, PathManager

    try:
        # ioPath `PathManager` is initialized after the first `opena` call.
        self.assertIsNone(IOPathPathManager)

        _asyncfile = os.path.join(self._tmpdir, "async.txt")
        f = PathManager.opena(_asyncfile, "wb")
        f.close()

        # The name bound by the first import still refers to the old value,
        # so import again to read the module attribute as it is now.
        from fairseq.file_io import IOPathPathManager

        self.assertIsNotNone(IOPathPathManager)
    finally:
        # Always shut down async writing, even if an assertion above failed.
        self.assertTrue(PathManager.async_close())