add DataStore.managed() (#16890)

This commit is contained in:
Kyle Altendorf 2023-12-05 14:31:30 -05:00 committed by GitHub
parent 4f55ffb1eb
commit 99ee3e4c34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 113 additions and 129 deletions

View File

@@ -188,27 +188,25 @@ class DataLayer:
sql_log_path = path_from_root(self.root_path, "log/data_sql.log")
self.log.info(f"logging SQL commands to {sql_log_path}")
self._data_store = await DataStore.create(database=self.db_path, sql_log_path=sql_log_path)
self._wallet_rpc = await self.wallet_rpc_init
async with DataStore.managed(database=self.db_path, sql_log_path=sql_log_path) as self._data_store:
self._wallet_rpc = await self.wallet_rpc_init
self.periodically_manage_data_task = asyncio.create_task(self.periodically_manage_data())
try:
yield
finally:
# TODO: review for anything else we need to do here
self._shut_down = True
if self._wallet_rpc is not None:
self.wallet_rpc.close()
self.periodically_manage_data_task = asyncio.create_task(self.periodically_manage_data())
try:
yield
finally:
# TODO: review for anything else we need to do here
self._shut_down = True
if self._wallet_rpc is not None:
self.wallet_rpc.close()
if self.periodically_manage_data_task is not None:
try:
self.periodically_manage_data_task.cancel()
except asyncio.CancelledError:
pass
if self._data_store is not None:
await self.data_store.close()
if self._wallet_rpc is not None:
await self.wallet_rpc.await_closed()
if self.periodically_manage_data_task is not None:
try:
self.periodically_manage_data_task.cancel()
except asyncio.CancelledError:
pass
if self._wallet_rpc is not None:
await self.wallet_rpc.await_closed()
def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
self.state_changed_callback = callback

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import contextlib
import logging
from collections import defaultdict
from contextlib import asynccontextmanager
@@ -48,10 +49,11 @@ class DataStore:
db_wrapper: DBWrapper2
@classmethod
async def create(
@contextlib.asynccontextmanager
async def managed(
cls, database: Union[str, Path], uri: bool = False, sql_log_path: Optional[Path] = None
) -> DataStore:
db_wrapper = await DBWrapper2.create(
) -> AsyncIterator[DataStore]:
async with DBWrapper2.managed(
database=database,
uri=uri,
journal_mode="WAL",
@@ -63,100 +65,97 @@
foreign_keys=True,
row_factory=aiosqlite.Row,
log_path=sql_log_path,
)
self = cls(db_wrapper=db_wrapper)
) as db_wrapper:
self = cls(db_wrapper=db_wrapper)
async with db_wrapper.writer() as writer:
await writer.execute(
f"""
CREATE TABLE IF NOT EXISTS node(
hash BLOB PRIMARY KEY NOT NULL CHECK(length(hash) == 32),
node_type INTEGER NOT NULL CHECK(
(
node_type == {int(NodeType.INTERNAL)}
AND left IS NOT NULL
AND right IS NOT NULL
AND key IS NULL
AND value IS NULL
)
OR
(
node_type == {int(NodeType.TERMINAL)}
AND left IS NULL
AND right IS NULL
AND key IS NOT NULL
AND value IS NOT NULL
)
),
left BLOB REFERENCES node,
right BLOB REFERENCES node,
key BLOB,
value BLOB
async with db_wrapper.writer() as writer:
await writer.execute(
f"""
CREATE TABLE IF NOT EXISTS node(
hash BLOB PRIMARY KEY NOT NULL CHECK(length(hash) == 32),
node_type INTEGER NOT NULL CHECK(
(
node_type == {int(NodeType.INTERNAL)}
AND left IS NOT NULL
AND right IS NOT NULL
AND key IS NULL
AND value IS NULL
)
OR
(
node_type == {int(NodeType.TERMINAL)}
AND left IS NULL
AND right IS NULL
AND key IS NOT NULL
AND value IS NOT NULL
)
),
left BLOB REFERENCES node,
right BLOB REFERENCES node,
key BLOB,
value BLOB
)
"""
)
"""
)
await writer.execute(
"""
CREATE TRIGGER IF NOT EXISTS no_node_updates
BEFORE UPDATE ON node
BEGIN
SELECT RAISE(FAIL, 'updates not allowed to the node table');
END
"""
)
await writer.execute(
f"""
CREATE TABLE IF NOT EXISTS root(
tree_id BLOB NOT NULL CHECK(length(tree_id) == 32),
generation INTEGER NOT NULL CHECK(generation >= 0),
node_hash BLOB,
status INTEGER NOT NULL CHECK(
{" OR ".join(f"status == {status}" for status in Status)}
),
PRIMARY KEY(tree_id, generation),
FOREIGN KEY(node_hash) REFERENCES node(hash)
await writer.execute(
"""
CREATE TRIGGER IF NOT EXISTS no_node_updates
BEFORE UPDATE ON node
BEGIN
SELECT RAISE(FAIL, 'updates not allowed to the node table');
END
"""
)
"""
)
# TODO: Add ancestor -> hash relationship, this might involve temporarily
# deferring the foreign key enforcement due to the insertion order
# and the node table also enforcing a similar relationship in the
# other direction.
# FOREIGN KEY(ancestor) REFERENCES ancestors(ancestor)
await writer.execute(
"""
CREATE TABLE IF NOT EXISTS ancestors(
hash BLOB NOT NULL REFERENCES node,
ancestor BLOB CHECK(length(ancestor) == 32),
tree_id BLOB NOT NULL CHECK(length(tree_id) == 32),
generation INTEGER NOT NULL,
PRIMARY KEY(hash, tree_id, generation),
FOREIGN KEY(ancestor) REFERENCES node(hash)
await writer.execute(
f"""
CREATE TABLE IF NOT EXISTS root(
tree_id BLOB NOT NULL CHECK(length(tree_id) == 32),
generation INTEGER NOT NULL CHECK(generation >= 0),
node_hash BLOB,
status INTEGER NOT NULL CHECK(
{" OR ".join(f"status == {status}" for status in Status)}
),
PRIMARY KEY(tree_id, generation),
FOREIGN KEY(node_hash) REFERENCES node(hash)
)
"""
)
"""
)
await writer.execute(
"""
CREATE TABLE IF NOT EXISTS subscriptions(
tree_id BLOB NOT NULL CHECK(length(tree_id) == 32),
url TEXT,
ignore_till INTEGER,
num_consecutive_failures INTEGER,
from_wallet tinyint CHECK(from_wallet == 0 OR from_wallet == 1),
PRIMARY KEY(tree_id, url)
# TODO: Add ancestor -> hash relationship, this might involve temporarily
# deferring the foreign key enforcement due to the insertion order
# and the node table also enforcing a similar relationship in the
# other direction.
# FOREIGN KEY(ancestor) REFERENCES ancestors(ancestor)
await writer.execute(
"""
CREATE TABLE IF NOT EXISTS ancestors(
hash BLOB NOT NULL REFERENCES node,
ancestor BLOB CHECK(length(ancestor) == 32),
tree_id BLOB NOT NULL CHECK(length(tree_id) == 32),
generation INTEGER NOT NULL,
PRIMARY KEY(hash, tree_id, generation),
FOREIGN KEY(ancestor) REFERENCES node(hash)
)
"""
)
await writer.execute(
"""
CREATE TABLE IF NOT EXISTS subscriptions(
tree_id BLOB NOT NULL CHECK(length(tree_id) == 32),
url TEXT,
ignore_till INTEGER,
num_consecutive_failures INTEGER,
from_wallet tinyint CHECK(from_wallet == 0 OR from_wallet == 1),
PRIMARY KEY(tree_id, url)
)
"""
)
await writer.execute(
"""
CREATE INDEX IF NOT EXISTS node_hash ON root(node_hash)
"""
)
"""
)
await writer.execute(
"""
CREATE INDEX IF NOT EXISTS node_hash ON root(node_hash)
"""
)
return self
async def close(self) -> None:
await self.db_wrapper.close()
yield self
@asynccontextmanager
async def transaction(self) -> AsyncIterator[None]:

View File

@@ -22,8 +22,7 @@ async def generate_datastore(num_nodes: int, slow_mode: bool) -> None:
if os.path.exists(db_path):
os.remove(db_path)
data_store = await DataStore.create(database=db_path)
try:
async with DataStore.managed(database=db_path) as data_store:
hint_keys_values: Dict[bytes, bytes] = {}
tree_id = bytes32(b"0" * 32)
@@ -107,8 +106,6 @@ async def generate_datastore(num_nodes: int, slow_mode: bool) -> None:
print(f"Total time for {num_nodes} operations: {insert_time + autoinsert_time + delete_time}")
root = await data_store.get_tree_root(tree_id=tree_id)
print(f"Root hash: {root.node_hash}")
finally:
await data_store.close()
if __name__ == "__main__":

View File

@@ -61,9 +61,8 @@ def tree_id_fixture() -> bytes32:
@pytest.fixture(name="raw_data_store", scope="function")
async def raw_data_store_fixture(database_uri: str) -> AsyncIterable[DataStore]:
store = await DataStore.create(database=database_uri, uri=True)
yield store
await store.close()
async with DataStore.managed(database=database_uri, uri=True) as store:
yield store
@pytest.fixture(name="data_store", scope="function")

View File

@@ -88,14 +88,11 @@ async def test_create_creates_tables_and_columns(
columns = await cursor.fetchall()
assert columns == []
store = await DataStore.create(database=database_uri, uri=True)
try:
async with DataStore.managed(database=database_uri, uri=True):
async with db_wrapper.reader() as reader:
cursor = await reader.execute(query)
columns = await cursor.fetchall()
assert [column[1] for column in columns] == expected_columns
finally:
await store.close()
@pytest.mark.anyio
@@ -379,8 +376,7 @@ async def test_batch_update(data_store: DataStore, tree_id: bytes32, use_optimiz
saved_batches: List[List[Dict[str, Any]]] = []
db_uri = generate_in_memory_db_uri()
single_op_data_store = await DataStore.create(database=db_uri, uri=True)
try:
async with DataStore.managed(database=db_uri, uri=True) as single_op_data_store:
await single_op_data_store.create_tree(tree_id, status=Status.COMMITTED)
random = Random()
random.seed(100, version=2)
@@ -423,8 +419,6 @@ async def test_batch_update(data_store: DataStore, tree_id: bytes32, use_optimiz
batch = []
root = await single_op_data_store.get_tree_root(tree_id=tree_id)
saved_roots.append(root)
finally:
await single_op_data_store.close()
for batch_number, batch in enumerate(saved_batches):
assert len(batch) == num_ops_per_batch
@@ -1265,8 +1259,7 @@ async def test_data_server_files(data_store: DataStore, tree_id: bytes32, test_d
num_ops_per_batch = 100
db_uri = generate_in_memory_db_uri()
data_store_server = await DataStore.create(database=db_uri, uri=True)
try:
async with DataStore.managed(database=db_uri, uri=True) as data_store_server:
await data_store_server.create_tree(tree_id, status=Status.COMMITTED)
random = Random()
random.seed(100, version=2)
@@ -1291,8 +1284,6 @@ async def test_data_server_files(data_store: DataStore, tree_id: bytes32, test_d
root = await data_store_server.get_tree_root(tree_id)
await write_files_for_root(data_store_server, tree_id, root, tmp_path, 0)
roots.append(root)
finally:
await data_store_server.close()
generation = 1
assert len(roots) == num_batches