fairseq/tests/test_plasma_utils.py
Pierre Andrews f591cc94ca upgrade black for lints (#3004)
Summary:
This is the same as https://github.com/fairinternal/fairseq-py/issues/3003 but for main instead of gshard.

The lint test runs the latest version of black, which is 22.1.0 right now and seems to be incompatible with the 21.12b0 version that is set up in pre-commit. This means that some files that were validly formatted in the past no longer are...

This PR formats these files with 22.1.0 and autoupdates pre-commit config to use that black version too.

(Note: this is the second time this has happened. A solution would be to pin the lint test to the same black version as the pre-commit hook, i.e. the one that was used to format everything cleanly, so that we have stable formatting.)

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/3004

Reviewed By: dianaml0

Differential Revision: D33917490

Pulled By: Mortimerp9

fbshipit-source-id: d55e800b976f94545cdab4132daa7c45cbd0e34c
2022-02-02 04:31:33 -08:00

128 lines
4.7 KiB
Python

import contextlib
import tempfile
import unittest
from io import StringIO
import numpy as np
from tests.utils import create_dummy_data, preprocess_lm_data, train_language_model
# The plasma-backed utilities are optional dependencies: record whether they
# can be imported so the test case below can be skipped when they cannot.
try:
    from pyarrow import plasma

    from fairseq.data.plasma_utils import PlasmaStore, PlasmaView
except ImportError:
    PYARROW_AVAILABLE = False
else:
    PYARROW_AVAILABLE = True

# Second positional argument passed to every PlasmaView built in this module.
dummy_path = "dummy"
@unittest.skipUnless(PYARROW_AVAILABLE, "pyarrow.plasma is not available")
class TestPlasmaView(unittest.TestCase):
    """Exercise ``PlasmaView`` against plasma stores started for each test.

    ``setUp`` starts one store whose path is a fresh temporary file name and
    connects a client to it; both are released through registered cleanups.
    """

    def setUp(self) -> None:
        """Start a 10000-byte plasma store and connect a client to it.

        Each resource registers its own cleanup immediately after it is
        acquired, so the server is still killed and the temporary file is
        still closed when a later step of ``setUp`` (or the test) fails.
        Cleanups run in reverse order: disconnect, kill server, close file.
        """
        self.tmp_file = tempfile.NamedTemporaryFile()  # noqa: P201
        self.addCleanup(self.tmp_file.close)
        self.path = self.tmp_file.name

        self.server = PlasmaStore.start(path=self.path, nbytes=10000)
        self.addCleanup(self.server.kill)

        self.client = plasma.connect(self.path, num_retries=10)
        self.addCleanup(self.client.disconnect)

    def test_two_servers_do_not_share_object_id_space(self):
        """The same hash key on two different stores yields independent data."""
        data_server_1 = np.array([0, 1])
        data_server_2 = np.array([2, 3])
        server_2_path = self.path
        with tempfile.NamedTemporaryFile() as server_1_path:
            server = PlasmaStore.start(path=server_1_path.name, nbytes=10000)
            try:
                arr1 = PlasmaView(
                    data_server_1, dummy_path, 1, plasma_path=server_1_path.name
                )
                assert len(arr1.client.list()) == 1
                assert (arr1.array == data_server_1).all()
                # Same (path, hash key) pair, but stored on the other server.
                arr2 = PlasmaView(
                    data_server_2, dummy_path, 1, plasma_path=server_2_path
                )
                assert (arr2.array == data_server_2).all()
                # Writing to server 2 must not have altered server 1's object.
                assert (arr1.array == data_server_1).all()
            finally:
                # Always stop the extra server, even when an assertion fails.
                server.kill()

    def test_hash_collision(self):
        """Reusing a hash key returns the stored array; a new key adds one."""
        data_server_1 = np.array([0, 1])
        data_server_2 = np.array([2, 3])
        arr1 = PlasmaView(data_server_1, dummy_path, 1, plasma_path=self.path)
        assert len(arr1.client.list()) == 1
        # Same path and hash key as arr1: no new object is stored, and the
        # view exposes the data that was stored first.
        arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=self.path)
        assert len(arr1.client.list()) == 1
        assert len(arr2.client.list()) == 1
        assert (arr2.array == data_server_1).all()
        # New hash key based on tuples
        arr3 = PlasmaView(
            data_server_2, dummy_path, (1, 12312312312, None), plasma_path=self.path
        )
        assert (
            len(arr2.client.list()) == 2
        ), "No new object was created by using a novel hash key"
        assert (
            arr3.object_id in arr2.client.list()
        ), "No new object was created by using a novel hash key"
        assert (
            arr3.object_id in arr3.client.list()
        ), "No new object was created by using a novel hash key"
        del arr3, arr2, arr1

    @staticmethod
    def _assert_view_equal(pv1, pv2):
        """Assert that two views expose element-wise equal arrays."""
        np.testing.assert_array_equal(pv1.array, pv2.array)

    def test_putting_same_array_twice(self):
        """Re-registering the same key leaves the store unchanged."""
        data = np.array([4, 4, 4])
        arr1 = PlasmaView(data, dummy_path, 1, plasma_path=self.path)
        assert len(self.client.list()) == 1
        arr1b = PlasmaView(
            data, dummy_path, 1, plasma_path=self.path
        )  # should not change contents of store
        arr1c = PlasmaView(
            None, dummy_path, 1, plasma_path=self.path
        )  # should not change contents of store
        assert len(self.client.list()) == 1
        self._assert_view_equal(arr1, arr1b)
        self._assert_view_equal(arr1, arr1c)
        PlasmaView(
            data, dummy_path, 2, plasma_path=self.path
        )  # new object id, adds new entry
        assert len(self.client.list()) == 2
        new_client = plasma.connect(self.path)
        try:
            assert len(new_client.list()) == 2  # new client can access same objects
        finally:
            # Release the extra connection opened only for this check.
            new_client.disconnect()
        assert isinstance(arr1.object_id, plasma.ObjectID)
        del arr1b
        del arr1c

    def test_plasma_store_full_raises(self):
        """Storing an array larger than the store raises ``PlasmaStoreFull``."""
        with tempfile.NamedTemporaryFile() as new_path:
            server = PlasmaStore.start(path=new_path.name, nbytes=10000)
            try:
                with self.assertRaises(plasma.PlasmaStoreFull):
                    # 10000 float64 values (80000 bytes) exceed the
                    # 10000-byte store started above.
                    PlasmaView(
                        np.random.rand(10000, 1),
                        dummy_path,
                        1,
                        plasma_path=new_path.name,
                    )
            finally:
                # Always stop the extra server, even when the check fails.
                server.kill()

    def test_object_id_overflow(self):
        """Building an object id from a large hash value must not raise."""
        PlasmaView.get_object_id("", 2**21)

    def test_training_lm_plasma(self):
        """Train a small language model with ``--use-plasma-view`` enabled."""
        # Training output is captured so it does not clutter the test log.
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
                create_dummy_data(data_dir)
                preprocess_lm_data(data_dir)
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--use-plasma-view", "--plasma-path", self.path],
                    run_validation=True,
                )