mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-07-14 18:50:22 +03:00
Summary: This is the same as https://github.com/fairinternal/fairseq-py/issues/3003 but for main instead of gshard. The lint test runs the latest version of black, which is 22.1.0 right now and seems to be incompatible with the 21.12b0 version that is set up in pre-commit. This means that some files that were validly formatted in the past no longer are. This PR formats these files with 22.1.0 and autoupdates the pre-commit config to use that black version too. (Note: this is the second time this has happened. A solution would be to pin the lint test to the same version as the one in the pre-commit hook, which was used to format everything cleanly, so that we have stable formatting.) Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/3004 Reviewed By: dianaml0 Differential Revision: D33917490 Pulled By: Mortimerp9 fbshipit-source-id: d55e800b976f94545cdab4132daa7c45cbd0e34c
128 lines
4.7 KiB
Python
128 lines
4.7 KiB
Python
import contextlib
|
|
import tempfile
|
|
import unittest
|
|
from io import StringIO
|
|
|
|
import numpy as np
|
|
|
|
from tests.utils import create_dummy_data, preprocess_lm_data, train_language_model
|
|
|
|
# Plasma support is optional: record whether both pyarrow's plasma module and
# fairseq's plasma helpers could be imported so the tests can be skipped
# when they are missing.
try:
    from pyarrow import plasma
    from fairseq.data.plasma_utils import PlasmaStore, PlasmaView
except ImportError:
    PYARROW_AVAILABLE = False
else:
    PYARROW_AVAILABLE = True

# Placeholder path argument passed to every PlasmaView built in the tests.
dummy_path = "dummy"
|
|
|
|
|
|
@unittest.skipUnless(PYARROW_AVAILABLE, "")
class TestPlasmaView(unittest.TestCase):
    """Tests for ``PlasmaView`` running against real plasma store processes.

    Every test gets its own store (started in ``setUp``) listening on the
    path of a fresh temporary file, plus a client connected to that store.
    The whole class is skipped when the plasma imports are unavailable.
    """

    def setUp(self) -> None:
        """Start a 10000-byte plasma store and connect a client to it."""
        self.tmp_file = tempfile.NamedTemporaryFile()  # noqa: P201
        self.path = self.tmp_file.name
        self.server = PlasmaStore.start(path=self.path, nbytes=10000)
        self.client = plasma.connect(self.path, num_retries=10)

    def tearDown(self) -> None:
        """Disconnect the client, release the temp file and stop the store."""
        self.client.disconnect()
        self.tmp_file.close()
        self.server.kill()

    def test_two_servers_do_not_share_object_id_space(self):
        """The same hash key on two different stores yields independent data."""
        data_server_1 = np.array([0, 1])
        data_server_2 = np.array([2, 3])
        server_2_path = self.path
        with tempfile.NamedTemporaryFile() as server_1_path:
            server = PlasmaStore.start(path=server_1_path.name, nbytes=10000)
            # Kill the extra store even when an assertion fails, otherwise
            # the process would outlive the test.
            try:
                arr1 = PlasmaView(
                    data_server_1, dummy_path, 1, plasma_path=server_1_path.name
                )
                assert len(arr1.client.list()) == 1
                assert (arr1.array == data_server_1).all()
                # Same (path, key) pair, but stored on the setUp store.
                arr2 = PlasmaView(
                    data_server_2, dummy_path, 1, plasma_path=server_2_path
                )
                assert (arr2.array == data_server_2).all()
                assert (arr1.array == data_server_1).all()
            finally:
                server.kill()

    def test_hash_collision(self):
        """Reusing a hash key reuses the stored object; a new key adds one."""
        data_server_1 = np.array([0, 1])
        data_server_2 = np.array([2, 3])
        arr1 = PlasmaView(data_server_1, dummy_path, 1, plasma_path=self.path)
        assert len(arr1.client.list()) == 1
        # Same path and key as arr1: the test expects no new object and
        # expects arr2 to expose the data that arr1 already stored.
        arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=self.path)
        assert len(arr1.client.list()) == 1
        assert len(arr2.client.list()) == 1
        assert (arr2.array == data_server_1).all()
        # New hash key based on tuples
        arr3 = PlasmaView(
            data_server_2, dummy_path, (1, 12312312312, None), plasma_path=self.path
        )
        assert (
            len(arr2.client.list()) == 2
        ), "No new object was created by using a novel hash key"
        assert (
            arr3.object_id in arr2.client.list()
        ), "No new object was created by using a novel hash key"
        assert (
            arr3.object_id in arr3.client.list()
        ), "No new object was created by using a novel hash key"
        del arr3, arr2, arr1

    @staticmethod
    def _assert_view_equal(pv1, pv2):
        """Assert that two views expose element-wise equal arrays."""
        np.testing.assert_array_equal(pv1.array, pv2.array)

    def test_putting_same_array_twice(self):
        """Re-creating a view for an existing key must not grow the store."""
        data = np.array([4, 4, 4])
        arr1 = PlasmaView(data, dummy_path, 1, plasma_path=self.path)
        assert len(self.client.list()) == 1
        arr1b = PlasmaView(
            data, dummy_path, 1, plasma_path=self.path
        )  # should not change contents of store
        arr1c = PlasmaView(
            None, dummy_path, 1, plasma_path=self.path
        )  # should not change contents of store

        assert len(self.client.list()) == 1
        self._assert_view_equal(arr1, arr1b)
        self._assert_view_equal(arr1, arr1c)
        PlasmaView(
            data, dummy_path, 2, plasma_path=self.path
        )  # new object id, adds new entry
        assert len(self.client.list()) == 2

        new_client = plasma.connect(self.path)
        # Always release the extra connection, even if the check fails.
        try:
            assert len(new_client.list()) == 2  # new client can access same objects
        finally:
            new_client.disconnect()
        assert isinstance(arr1.object_id, plasma.ObjectID)
        del arr1b
        del arr1c

    def test_plasma_store_full_raises(self):
        """Storing more data than the store can hold raises PlasmaStoreFull."""
        with tempfile.NamedTemporaryFile() as new_path:
            server = PlasmaStore.start(path=new_path.name, nbytes=10000)
            # Kill the extra store even when the expected error is not raised.
            try:
                with self.assertRaises(plasma.PlasmaStoreFull):
                    # 10000 float64 values (80000 bytes) cannot fit into a
                    # store created with nbytes=10000.
                    PlasmaView(
                        np.random.rand(10000, 1),
                        dummy_path,
                        1,
                        plasma_path=new_path.name,
                    )
            finally:
                server.kill()

    def test_object_id_overflow(self):
        """Smoke test: building an object id for key 2**21 must not raise."""
        PlasmaView.get_object_id("", 2**21)

    def test_training_lm_plasma(self):
        """Train a small transformer LM with ``--use-plasma-view`` enabled."""
        # Training output is captured so it does not pollute the test log.
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
                create_dummy_data(data_dir)
                preprocess_lm_data(data_dir)
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--use-plasma-view", "--plasma-path", self.path],
                    run_validation=True,
                )
|