diff --git a/chia/data_layer/data_layer.py b/chia/data_layer/data_layer.py index d8b72b9723e4..b94a37d5e34c 100644 --- a/chia/data_layer/data_layer.py +++ b/chia/data_layer/data_layer.py @@ -9,7 +9,6 @@ from pathlib import Path from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, Union import aiohttp -import aiosqlite from chia.data_layer.data_layer_errors import KeyNotFoundError from chia.data_layer.data_layer_util import ( @@ -39,7 +38,6 @@ from chia.server.outbound_message import NodeType from chia.server.server import ChiaServer from chia.server.ws_connection import WSChiaConnection from chia.types.blockchain_format.sized_bytes import bytes32 -from chia.util.db_wrapper import DBWrapper from chia.util.ints import uint32, uint64 from chia.util.path import path_from_root from chia.wallet.trade_record import TradeRecord @@ -49,10 +47,7 @@ from chia.wallet.transaction_record import TransactionRecord class DataLayer: data_store: DataStore - db_wrapper: DBWrapper - batch_update_db_wrapper: DBWrapper db_path: Path - connection: Optional[aiosqlite.Connection] config: Dict[str, Any] log: logging.Logger wallet_rpc_init: Awaitable[WalletRpcClient] @@ -114,9 +109,7 @@ class DataLayer: self._server = server async def _start(self) -> None: - self.connection = await aiosqlite.connect(self.db_path) - self.db_wrapper = DBWrapper(self.connection) - self.data_store = await DataStore.create(self.db_wrapper) + self.data_store = await DataStore.create(database=self.db_path) self.wallet_rpc = await self.wallet_rpc_init self.subscription_lock: asyncio.Lock = asyncio.Lock() @@ -133,6 +126,7 @@ class DataLayer: self.periodically_manage_data_task.cancel() except asyncio.CancelledError: pass + await self.data_store.close() async def create_store( self, fee: uint64, root: bytes32 = bytes32([0] * 32) @@ -157,13 +151,12 @@ class DataLayer: self, tree_id: bytes32, changelist: List[Dict[str, Any]], - lock: bool = True, ) -> bytes32: - async with self.data_store.transaction(lock=lock): + async with self.data_store.transaction(): # Make sure we update based on the latest confirmed root. 
async with self.lock: - await self._update_confirmation_status(tree_id=tree_id, lock=False) - pending_root: Optional[Root] = await self.data_store.get_pending_root(tree_id=tree_id, lock=False) + await self._update_confirmation_status(tree_id=tree_id) + pending_root: Optional[Root] = await self.data_store.get_pending_root(tree_id=tree_id) if pending_root is not None: raise Exception("Already have a pending root waiting for confirmation.") @@ -173,7 +166,7 @@ class DataLayer: raise ValueError(f"Singleton with launcher ID {tree_id} is not owned by DL Wallet") t1 = time.monotonic() - batch_hash = await self.data_store.insert_batch(tree_id, changelist, lock=False) + batch_hash = await self.data_store.insert_batch(tree_id, changelist) t2 = time.monotonic() self.log.info(f"Data store batch update process time: {t2 - t1}.") # todo return empty node hash from get_tree_root @@ -210,21 +203,18 @@ class DataLayer: store_id: bytes32, key: bytes, root_hash: Optional[bytes32] = None, - lock: bool = True, ) -> bytes32: - async with self.data_store.transaction(lock=lock): + async with self.data_store.transaction(): async with self.lock: - await self._update_confirmation_status(tree_id=store_id, lock=False) - node = await self.data_store.get_node_by_key(tree_id=store_id, key=key, root_hash=root_hash, lock=False) + await self._update_confirmation_status(tree_id=store_id) + node = await self.data_store.get_node_by_key(tree_id=store_id, key=key, root_hash=root_hash) return node.hash - async def get_value( - self, store_id: bytes32, key: bytes, root_hash: Optional[bytes32] = None, lock: bool = True - ) -> Optional[bytes]: - async with self.data_store.transaction(lock=lock): + async def get_value(self, store_id: bytes32, key: bytes, root_hash: Optional[bytes32] = None) -> Optional[bytes]: + async with self.data_store.transaction(): async with self.lock: - await self._update_confirmation_status(tree_id=store_id, lock=False) - res = await self.data_store.get_node_by_key(tree_id=store_id, key=key, root_hash=root_hash, lock=False) + await self._update_confirmation_status(tree_id=store_id) + res = await self.data_store.get_node_by_key(tree_id=store_id, key=key, root_hash=root_hash) if res is None: self.log.error("Failed to fetch key") return None @@ -281,10 +271,10 @@ class DataLayer: prev = record return root_history - async def _update_confirmation_status(self, tree_id: bytes32, lock: bool = True) -> None: - async with self.data_store.transaction(lock=lock): + async def _update_confirmation_status(self, tree_id: bytes32) -> None: + async with self.data_store.transaction(): try: - root = await self.data_store.get_tree_root(tree_id=tree_id, lock=False) + root = await self.data_store.get_tree_root(tree_id=tree_id) except asyncio.CancelledError: raise except Exception: @@ -293,11 +283,11 @@ class DataLayer: if singleton_record is None: return if root is None: - pending_root = await self.data_store.get_pending_root(tree_id=tree_id, lock=False) + pending_root = await self.data_store.get_pending_root(tree_id=tree_id) if pending_root is not None: if pending_root.generation == 0 and pending_root.node_hash is None: - await self.data_store.change_root_status(pending_root, Status.COMMITTED, lock=False) - await self.data_store.clear_pending_roots(tree_id=tree_id, lock=False) + await self.data_store.change_root_status(pending_root, Status.COMMITTED) + await self.data_store.clear_pending_roots(tree_id=tree_id) return else: root = None @@ -324,19 +314,19 @@ class DataLayer: generation_shift += 1 new_hashes.pop(0) if 
generation_shift > 0: - await self.data_store.clear_pending_roots(tree_id=tree_id, lock=False) - await self.data_store.shift_root_generations(tree_id=tree_id, shift_size=generation_shift, lock=False) + await self.data_store.clear_pending_roots(tree_id=tree_id) + await self.data_store.shift_root_generations(tree_id=tree_id, shift_size=generation_shift) else: expected_root_hash = None if new_hashes[0] == self.none_bytes else new_hashes[0] - pending_root = await self.data_store.get_pending_root(tree_id=tree_id, lock=False) + pending_root = await self.data_store.get_pending_root(tree_id=tree_id) if ( pending_root is not None and pending_root.generation == root.generation + 1 and pending_root.node_hash == expected_root_hash ): - await self.data_store.change_root_status(pending_root, Status.COMMITTED, lock=False) - await self.data_store.build_ancestor_table_for_latest_root(tree_id=tree_id, lock=False) - await self.data_store.clear_pending_roots(tree_id=tree_id, lock=False) + await self.data_store.change_root_status(pending_root, Status.COMMITTED) + await self.data_store.build_ancestor_table_for_latest_root(tree_id=tree_id) + await self.data_store.clear_pending_roots(tree_id=tree_id) async def fetch_and_validate(self, tree_id: bytes32) -> None: singleton_record: Optional[SingletonRecord] = await self.wallet_rpc.dl_latest_singleton(tree_id, True) @@ -557,13 +547,12 @@ class DataLayer: self, store_id: bytes32, inclusions: Tuple[KeyValue, ...], - lock: bool = True, ) -> List[Dict[str, Any]]: - async with self.data_store.transaction(lock=lock): + async with self.data_store.transaction(): changelist: List[Dict[str, Any]] = [] for entry in inclusions: try: - existing_value = await self.get_value(store_id=store_id, key=entry.key, lock=False) + existing_value = await self.get_value(store_id=store_id, key=entry.key) except KeyNotFoundError: existing_value = None @@ -590,26 +579,22 @@ class DataLayer: return changelist - async def process_offered_stores( - self, offer_stores: Tuple[OfferStore, ...], lock: bool = True - ) -> Dict[bytes32, StoreProofs]: - async with self.data_store.transaction(lock=lock): + async def process_offered_stores(self, offer_stores: Tuple[OfferStore, ...]) -> Dict[bytes32, StoreProofs]: + async with self.data_store.transaction(): our_store_proofs: Dict[bytes32, StoreProofs] = {} for offer_store in offer_stores: async with self.lock: - await self._update_confirmation_status(tree_id=offer_store.store_id, lock=False) + await self._update_confirmation_status(tree_id=offer_store.store_id) changelist = await self.build_offer_changelist( store_id=offer_store.store_id, inclusions=offer_store.inclusions, - lock=False, ) if len(changelist) > 0: new_root_hash = await self.batch_insert( tree_id=offer_store.store_id, changelist=changelist, - lock=False, ) else: existing_root = await self.get_root(store_id=offer_store.store_id) @@ -626,13 +611,11 @@ class DataLayer: store_id=offer_store.store_id, key=entry.key, root_hash=new_root_hash, - lock=False, ) proof_of_inclusion = await self.data_store.get_proof_of_inclusion_by_hash( node_hash=node_hash, tree_id=offer_store.store_id, root_hash=new_root_hash, - lock=False, ) proof = Proof( key=entry.key, @@ -659,7 +642,7 @@ class DataLayer: fee: uint64, ) -> Offer: async with self.data_store.transaction(): - our_store_proofs = await self.process_offered_stores(offer_stores=maker, lock=False) + our_store_proofs = await self.process_offered_stores(offer_stores=maker) offer_dict: Dict[Union[uint32, str], int] = { **{offer_store.store_id.hex(): -1 for 
offer_store in maker}, @@ -717,7 +700,7 @@ class DataLayer: fee: uint64, ) -> TradeRecord: async with self.data_store.transaction(): - our_store_proofs = await self.process_offered_stores(offer_stores=taker, lock=False) + our_store_proofs = await self.process_offered_stores(offer_stores=taker) offer = TradingOffer.from_bytes(offer_bytes) summary = await DataLayerWallet.get_offer_summary(offer=offer) diff --git a/chia/data_layer/data_layer_util.py b/chia/data_layer/data_layer_util.py index 8e48525dc19a..c1ecbb9d9a28 100644 --- a/chia/data_layer/data_layer_util.py +++ b/chia/data_layer/data_layer_util.py @@ -12,6 +12,7 @@ from typing_extensions import final from chia.types.blockchain_format.program import Program from chia.types.blockchain_format.sized_bytes import bytes32 from chia.util.byte_types import hexstr_to_bytes +from chia.util.db_wrapper import DBWrapper2 from chia.util.ints import uint64 from chia.util.streamable import Streamable, streamable @@ -42,14 +43,15 @@ def leaf_hash(key: bytes, value: bytes) -> bytes32: return Program.to((key, value)).get_tree_hash() # type: ignore[no-any-return] -async def _debug_dump(db: aiosqlite.Connection, description: str = "") -> None: - cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table';") - print("-" * 50, description, flush=True) - for [name] in await cursor.fetchall(): - cursor = await db.execute(f"SELECT * FROM {name}") - print(f"\n -- {name} ------", flush=True) - async for row in cursor: - print(f" {dict(row)}") +async def _debug_dump(db: DBWrapper2, description: str = "") -> None: + async with db.reader() as reader: + cursor = await reader.execute("SELECT name FROM sqlite_master WHERE type='table';") + print("-" * 50, description, flush=True) + for [name] in await cursor.fetchall(): + cursor = await reader.execute(f"SELECT * FROM {name}") + print(f"\n -- {name} ------", flush=True) + async for row in cursor: + print(f" {dict(row)}") async def _dot_dump(data_store: DataStore, store_id: bytes32, root_hash: bytes32) -> str: diff --git a/chia/data_layer/data_store.py b/chia/data_layer/data_store.py index 5b801d39e8ee..4c9335f689f5 100644 --- a/chia/data_layer/data_store.py +++ b/chia/data_layer/data_store.py @@ -4,7 +4,8 @@ import logging from collections import defaultdict from contextlib import asynccontextmanager from dataclasses import dataclass, replace -from typing import Any, AsyncIterator, Awaitable, BinaryIO, Callable, Dict, List, Optional, Set, Tuple +from pathlib import Path +from typing import Any, AsyncIterator, Awaitable, BinaryIO, Callable, Dict, List, Optional, Set, Tuple, Union import aiosqlite @@ -30,7 +31,7 @@ from chia.data_layer.data_layer_util import ( ) from chia.types.blockchain_format.program import Program from chia.types.blockchain_format.sized_bytes import bytes32 -from chia.util.db_wrapper import DBWrapper +from chia.util.db_wrapper import DBWrapper2 log = logging.getLogger(__name__) @@ -43,24 +44,26 @@ log = logging.getLogger(__name__) class DataStore: """A key/value store with the pairs being terminal nodes in a CLVM object tree.""" - db: aiosqlite.Connection - db_wrapper: DBWrapper + db_wrapper: DBWrapper2 @classmethod - async def create(cls, db_wrapper: DBWrapper) -> "DataStore": - self = cls(db=db_wrapper.db, db_wrapper=db_wrapper) - self.db.row_factory = aiosqlite.Row + async def create(cls, database: Union[str, Path], uri: bool = False) -> "DataStore": + db_wrapper = await DBWrapper2.create( + database=database, + uri=uri, + journal_mode="WAL", + # Setting to FULL despite other 
locations being configurable. If there are + # performance issues we can consider the implications of other options. + synchronous="FULL", + # If foreign key checking gets turned off, please add corresponding check + # methods and enable foreign key checking in the tests. + foreign_keys=True, + row_factory=aiosqlite.Row, + ) + self = cls(db_wrapper=db_wrapper) - await self.db.execute("pragma journal_mode=wal") - # Setting to FULL despite other locations being configurable. If there are - # performance issues we can consider other the implications of other options. - await self.db.execute("pragma synchronous=FULL") - # If foreign key checking gets turned off, please add corresponding check - # methods. - await self.db.execute("PRAGMA foreign_keys=ON") - - async with self.db_wrapper.locked_transaction(): - await self.db.execute( + async with db_wrapper.writer() as writer: + await writer.execute( f""" CREATE TABLE IF NOT EXISTS node( hash BLOB PRIMARY KEY NOT NULL CHECK(length(hash) == 32), @@ -88,7 +91,7 @@ class DataStore: ) """ ) - await self.db.execute( + await writer.execute( """ CREATE TRIGGER IF NOT EXISTS no_node_updates BEFORE UPDATE ON node @@ -97,7 +100,7 @@ END """ ) - await self.db.execute( + await writer.execute( f""" CREATE TABLE IF NOT EXISTS root( tree_id BLOB NOT NULL CHECK(length(tree_id) == 32), @@ -116,7 +119,7 @@ # and the node table also enforcing a similar relationship in the # other direction. # FOREIGN KEY(ancestor) REFERENCES ancestors(ancestor) - await self.db.execute( + await writer.execute( """ CREATE TABLE IF NOT EXISTS ancestors( hash BLOB NOT NULL REFERENCES node, @@ -128,7 +131,7 @@ ) """ ) - await self.db.execute( + await writer.execute( """ CREATE TABLE IF NOT EXISTS subscriptions( tree_id BLOB NOT NULL CHECK(length(tree_id) == 32), @@ -140,7 +143,7 @@ ) """ ) - await self.db.execute( + await writer.execute( """ CREATE INDEX IF NOT EXISTS node_hash ON root(node_hash) """ @@ -148,9 +151,12 @@ return self + async def close(self) -> None: + await self.db_wrapper.close() + @asynccontextmanager - async def transaction(self, lock: bool = True) -> AsyncIterator[None]: - async with self.db_wrapper.locked_transaction(lock=lock): + async def transaction(self) -> AsyncIterator[None]: + async with self.db_wrapper.writer(): yield async def _insert_root( @@ -164,42 +170,44 @@ # https://github.com/Chia-Network/chia-blockchain/pull/9284 tree_id = bytes32(tree_id) - if generation is None: - existing_generation = await self.get_tree_generation(tree_id=tree_id, lock=False) + async with self.db_wrapper.writer() as writer: + if generation is None: + existing_generation = await self.get_tree_generation(tree_id=tree_id) - if existing_generation is None: - generation = 0 - else: - generation = existing_generation + 1 + if existing_generation is None: + generation = 0 + else: + generation = existing_generation + 1 - await self.db.execute( - """ - INSERT INTO root(tree_id, generation, node_hash, status) VALUES(:tree_id, :generation, :node_hash, :status) - """, - { - "tree_id": tree_id, - "generation": generation, - "node_hash": None if node_hash is None else node_hash, - "status": status.value, - }, - ) - - # `node_hash` is now a root, so it has no ancestor. - # Don't change the ancestor table unless the root is committed.
- if node_hash is not None and status == Status.COMMITTED: - values = { - "hash": node_hash, - "tree_id": tree_id, - "generation": generation, - } - await self.db.execute( + await writer.execute( """ - INSERT INTO ancestors(hash, ancestor, tree_id, generation) - VALUES (:hash, NULL, :tree_id, :generation) + INSERT INTO root(tree_id, generation, node_hash, status) + VALUES(:tree_id, :generation, :node_hash, :status) """, - values, + { + "tree_id": tree_id, + "generation": generation, + "node_hash": None if node_hash is None else node_hash, + "status": status.value, + }, ) + # `node_hash` is now a root, so it has no ancestor. + # Don't change the ancestor table unless the root is committed. + if node_hash is not None and status == Status.COMMITTED: + values = { + "hash": node_hash, + "tree_id": tree_id, + "generation": generation, + } + await writer.execute( + """ + INSERT INTO ancestors(hash, ancestor, tree_id, generation) + VALUES (:hash, NULL, :tree_id, :generation) + """, + values, + ) + async def _insert_node( self, node_hash: bytes32, @@ -219,32 +227,34 @@ class DataStore: "value": value, } - cursor = await self.db.execute("SELECT * FROM node WHERE hash == :hash", {"hash": node_hash}) - result = await cursor.fetchone() + async with self.db_wrapper.writer() as writer: + cursor = await writer.execute("SELECT * FROM node WHERE hash == :hash", {"hash": node_hash}) + result = await cursor.fetchone() - if result is None: - await self.db.execute( - """ - INSERT INTO node(hash, node_type, left, right, key, value) - VALUES(:hash, :node_type, :left, :right, :key, :value) - """, - values, - ) - else: - result_dict = dict(result) - if result_dict != values: - raise Exception(f"Requested insertion of node with matching hash but other values differ: {node_hash}") + if result is None: + await writer.execute( + """ + INSERT INTO node(hash, node_type, left, right, key, value) + VALUES(:hash, :node_type, :left, :right, :key, :value) + """, + values, + ) + else: + result_dict = dict(result) + if result_dict != values: + raise Exception( + f"Requested insertion of node with matching hash but other values differ: {node_hash}" + ) async def insert_node(self, node_type: NodeType, value1: bytes, value2: bytes) -> None: - async with self.db_wrapper.locked_transaction(lock=True): - if node_type == NodeType.INTERNAL: - left_hash = bytes32(value1) - right_hash = bytes32(value2) - node_hash = internal_hash(left_hash, right_hash) - await self._insert_node(node_hash, node_type, bytes32(value1), bytes32(value2), None, None) - else: - node_hash = leaf_hash(key=value1, value=value2) - await self._insert_node(node_hash, node_type, None, None, value1, value2) + if node_type == NodeType.INTERNAL: + left_hash = bytes32(value1) + right_hash = bytes32(value2) + node_hash = internal_hash(left_hash, right_hash) + await self._insert_node(node_hash, node_type, bytes32(value1), bytes32(value2), None, None) + else: + node_hash = leaf_hash(key=value1, value=value2) + await self._insert_node(node_hash, node_type, None, None, value1, value2) async def _insert_internal_node(self, left_hash: bytes32, right_hash: bytes32) -> bytes32: node_hash: bytes32 = internal_hash(left_hash=left_hash, right_hash=right_hash) @@ -269,33 +279,34 @@ class DataStore: ) -> None: node_hash = internal_hash(left_hash=left_hash, right_hash=right_hash) - for hash in (left_hash, right_hash): - values = { - "hash": hash, - "ancestor": node_hash, - "tree_id": tree_id, - "generation": generation, - } - cursor = await self.db.execute( - "SELECT * FROM ancestors 
WHERE hash == :hash AND generation == :generation AND tree_id == :tree_id", {"hash": hash, "generation": generation, "tree_id": tree_id}, ) - result = await cursor.fetchone() - if result is None: - await self.db.execute( - """ - INSERT INTO ancestors(hash, ancestor, tree_id, generation) - VALUES (:hash, :ancestor, :tree_id, :generation) - """, - values, + async with self.db_wrapper.writer() as writer: + for hash in (left_hash, right_hash): + values = { + "hash": hash, + "ancestor": node_hash, + "tree_id": tree_id, + "generation": generation, + } + cursor = await writer.execute( + "SELECT * FROM ancestors WHERE hash == :hash AND generation == :generation AND tree_id == :tree_id", + {"hash": hash, "generation": generation, "tree_id": tree_id}, ) - else: - result_dict = dict(result) - if result_dict != values: - raise Exception( - "Requested insertion of ancestor, where ancestor differ, but other values are identical: " - f"{hash} {generation} {tree_id}" + result = await cursor.fetchone() + if result is None: + await writer.execute( + """ + INSERT INTO ancestors(hash, ancestor, tree_id, generation) + VALUES (:hash, :ancestor, :tree_id, :generation) + """, + values, ) + else: + result_dict = dict(result) + if result_dict != values: + raise Exception( + "Requested insertion of ancestor, where ancestors differ, but other values are identical: " + f"{hash} {generation} {tree_id}" + ) async def _insert_terminal_node(self, key: bytes, value: bytes) -> bytes32: # forcing type hint here for: @@ -314,9 +325,9 @@ return node_hash - async def get_pending_root(self, tree_id: bytes32, *, lock: bool = True) -> Optional[Root]: - async with self.db_wrapper.locked_transaction(lock=lock): - cursor = await self.db.execute( + async def get_pending_root(self, tree_id: bytes32) -> Optional[Root]: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute( "SELECT * FROM root WHERE tree_id == :tree_id AND status == :status", {"tree_id": tree_id, "status": Status.PENDING.value}, ) @@ -332,22 +343,22 @@ return Root.from_row(row=row) - async def clear_pending_roots(self, tree_id: bytes32, *, lock: bool = True) -> None: - async with self.db_wrapper.locked_transaction(lock=lock): - await self.db.execute( + async def clear_pending_roots(self, tree_id: bytes32) -> None: + async with self.db_wrapper.writer() as writer: + await writer.execute( "DELETE FROM root WHERE tree_id == :tree_id AND status == :status", {"tree_id": tree_id, "status": Status.PENDING.value}, ) - async def shift_root_generations(self, tree_id: bytes32, shift_size: int, *, lock: bool = True) -> None: - async with self.db_wrapper.locked_transaction(lock=lock): - root = await self.get_tree_root(tree_id=tree_id, lock=False) + async def shift_root_generations(self, tree_id: bytes32, shift_size: int) -> None: + async with self.db_wrapper.writer(): + root = await self.get_tree_root(tree_id=tree_id) for _ in range(shift_size): await self._insert_root(tree_id=tree_id, node_hash=root.node_hash, status=Status.COMMITTED) - async def change_root_status(self, root: Root, status: Status = Status.PENDING, lock: bool = True) -> None: - async with self.db_wrapper.locked_transaction(lock=lock): - await self.db.execute( + async def change_root_status(self, root: Root, status: Status = Status.PENDING) -> None: + async with self.db_wrapper.writer() as writer: + await writer.execute( "UPDATE root SET status = ? WHERE tree_id=?
and generation = ?", ( status.value, @@ -363,7 +374,7 @@ class DataStore: "tree_id": root.tree_id, "generation": root.generation, } - await self.db.execute( + await writer.execute( """ INSERT INTO ancestors(hash, ancestor, tree_id, generation) VALUES (:hash, NULL, :tree_id, :generation) @@ -376,9 +387,9 @@ class DataStore: # pylint seems to think these are bound methods not unbound methods. await check(self) # pylint: disable=too-many-function-args - async def _check_roots_are_incrementing(self, *, lock: bool = True) -> None: - async with self.db_wrapper.locked_transaction(lock=lock): - cursor = await self.db.execute("SELECT * FROM root ORDER BY tree_id, generation") + async def _check_roots_are_incrementing(self) -> None: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute("SELECT * FROM root ORDER BY tree_id, generation") roots = [Root.from_row(row=row) async for row in cursor] roots_by_tree: Dict[bytes32, List[Root]] = defaultdict(list) @@ -396,9 +407,9 @@ class DataStore: if len(bad_trees) > 0: raise TreeGenerationIncrementingError(tree_ids=bad_trees) - async def _check_hashes(self, *, lock: bool = True) -> None: - async with self.db_wrapper.locked_transaction(lock=lock): - cursor = await self.db.execute("SELECT * FROM node") + async def _check_hashes(self) -> None: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute("SELECT * FROM node") bad_node_hashes: List[bytes32] = [] async for row in cursor: @@ -419,29 +430,27 @@ class DataStore: _check_hashes, ) - async def create_tree(self, tree_id: bytes32, *, lock: bool = True, status: Status = Status.PENDING) -> bool: - async with self.db_wrapper.locked_transaction(lock=lock): - await self._insert_root(tree_id=tree_id, node_hash=None, status=status) + async def create_tree(self, tree_id: bytes32, status: Status = Status.PENDING) -> bool: + await self._insert_root(tree_id=tree_id, node_hash=None, status=status) return True - async def table_is_empty(self, tree_id: bytes32, *, lock: bool = True) -> bool: - async with self.db_wrapper.locked_transaction(lock=lock): - tree_root = await self.get_tree_root(tree_id=tree_id, lock=False) + async def table_is_empty(self, tree_id: bytes32) -> bool: + tree_root = await self.get_tree_root(tree_id=tree_id) return tree_root.node_hash is None - async def get_tree_ids(self, *, lock: bool = True) -> Set[bytes32]: - async with self.db_wrapper.locked_transaction(lock=lock): - cursor = await self.db.execute("SELECT DISTINCT tree_id FROM root") + async def get_tree_ids(self) -> Set[bytes32]: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute("SELECT DISTINCT tree_id FROM root") - tree_ids = {bytes32(row["tree_id"]) async for row in cursor} + tree_ids = {bytes32(row["tree_id"]) async for row in cursor} return tree_ids - async def get_tree_generation(self, tree_id: bytes32, *, lock: bool = True) -> int: - async with self.db_wrapper.locked_transaction(lock=lock): - cursor = await self.db.execute( + async def get_tree_generation(self, tree_id: bytes32) -> int: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute( "SELECT MAX(generation) FROM root WHERE tree_id == :tree_id AND status == :status", {"tree_id": tree_id, "status": Status.COMMITTED.value}, ) @@ -452,11 +461,11 @@ class DataStore: generation: int = row["MAX(generation)"] return generation - async def get_tree_root(self, tree_id: bytes32, generation: Optional[int] = None, *, lock: bool = True) -> Root: - async with 
self.db_wrapper.locked_transaction(lock=lock): + async def get_tree_root(self, tree_id: bytes32, generation: Optional[int] = None) -> Root: + async with self.db_wrapper.reader() as reader: if generation is None: - generation = await self.get_tree_generation(tree_id=tree_id, lock=False) - cursor = await self.db.execute( + generation = await self.get_tree_generation(tree_id=tree_id) + cursor = await reader.execute( "SELECT * FROM root WHERE tree_id == :tree_id AND generation == :generation AND status == :status", {"tree_id": tree_id, "generation": generation, "status": Status.COMMITTED.value}, ) @@ -471,9 +480,9 @@ class DataStore: return Root.from_row(row=row) - async def tree_id_exists(self, tree_id: bytes32, *, lock: bool = True) -> bool: - async with self.db_wrapper.locked_transaction(lock=lock): - cursor = await self.db.execute( + async def tree_id_exists(self, tree_id: bytes32) -> bool: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute( "SELECT 1 FROM root WHERE tree_id == :tree_id AND status == :status", {"tree_id": tree_id, "status": Status.COMMITTED.value}, ) @@ -483,11 +492,9 @@ class DataStore: return False return True - async def get_roots_between( - self, tree_id: bytes32, generation_begin: int, generation_end: int, *, lock: bool = True - ) -> List[Root]: - async with self.db_wrapper.locked_transaction(lock=lock): - cursor = await self.db.execute( + async def get_roots_between(self, tree_id: bytes32, generation_begin: int, generation_end: int) -> List[Root]: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute( "SELECT * FROM root WHERE tree_id == :tree_id " "AND generation >= :generation_begin AND generation < :generation_end ORDER BY generation ASC", {"tree_id": tree_id, "generation_begin": generation_begin, "generation_end": generation_end}, @@ -497,12 +504,12 @@ class DataStore: return roots async def get_last_tree_root_by_hash( - self, tree_id: bytes32, hash: Optional[bytes32], max_generation: Optional[int] = None, *, lock: bool = True + self, tree_id: bytes32, hash: Optional[bytes32], max_generation: Optional[int] = None ) -> Optional[Root]: - async with self.db_wrapper.locked_transaction(lock=lock): + async with self.db_wrapper.reader() as reader: max_generation_str = f"AND generation < {max_generation} " if max_generation is not None else "" node_hash_str = "AND node_hash == :node_hash " if hash is not None else "AND node_hash is NULL " - cursor = await self.db.execute( + cursor = await reader.execute( "SELECT * FROM root WHERE tree_id == :tree_id " f"{max_generation_str}" f"{node_hash_str}" @@ -520,16 +527,14 @@ class DataStore: node_hash: bytes32, tree_id: bytes32, root_hash: Optional[bytes32] = None, - *, - lock: bool = True, ) -> List[InternalNode]: - async with self.db_wrapper.locked_transaction(lock=lock): + async with self.db_wrapper.reader() as reader: if root_hash is None: - root = await self.get_tree_root(tree_id=tree_id, lock=False) + root = await self.get_tree_root(tree_id=tree_id) root_hash = root.node_hash if root_hash is None: raise Exception(f"Root hash is unspecified for tree ID: {tree_id.hex()}") - cursor = await self.db.execute( + cursor = await reader.execute( """ WITH RECURSIVE tree_from_root_hash(hash, node_type, left, right, key, value, depth) AS ( @@ -561,11 +566,11 @@ class DataStore: return ancestors async def get_ancestors_optimized( - self, node_hash: bytes32, tree_id: bytes32, generation: Optional[int] = None, lock: bool = True + self, node_hash: bytes32, tree_id: bytes32, generation: 
Optional[int] = None ) -> List[InternalNode]: - async with self.db_wrapper.locked_transaction(lock=lock): + async with self.db_wrapper.reader(): nodes = [] - root = await self.get_tree_root(tree_id=tree_id, generation=generation, lock=False) + root = await self.get_tree_root(tree_id=tree_id, generation=generation) if root.node_hash is None: return [] @@ -582,14 +587,12 @@ class DataStore: return nodes - async def get_internal_nodes( - self, tree_id: bytes32, root_hash: Optional[bytes32] = None, *, lock: bool = True - ) -> List[InternalNode]: - async with self.db_wrapper.locked_transaction(lock=lock): + async def get_internal_nodes(self, tree_id: bytes32, root_hash: Optional[bytes32] = None) -> List[InternalNode]: + async with self.db_wrapper.reader() as reader: if root_hash is None: - root = await self.get_tree_root(tree_id=tree_id, lock=False) + root = await self.get_tree_root(tree_id=tree_id) root_hash = root.node_hash - cursor = await self.db.execute( + cursor = await reader.execute( """ WITH RECURSIVE tree_from_root_hash(hash, node_type, left, right, key, value) AS ( @@ -613,14 +616,12 @@ class DataStore: return internal_nodes - async def get_keys_values( - self, tree_id: bytes32, root_hash: Optional[bytes32] = None, *, lock: bool = True - ) -> List[TerminalNode]: - async with self.db_wrapper.locked_transaction(lock=lock): + async def get_keys_values(self, tree_id: bytes32, root_hash: Optional[bytes32] = None) -> List[TerminalNode]: + async with self.db_wrapper.reader() as reader: if root_hash is None: - root = await self.get_tree_root(tree_id=tree_id, lock=False) + root = await self.get_tree_root(tree_id=tree_id) root_hash = root.node_hash - cursor = await self.db.execute( + cursor = await reader.execute( """ WITH RECURSIVE tree_from_root_hash(hash, node_type, left, right, key, value, depth, rights) AS ( @@ -665,9 +666,9 @@ class DataStore: return terminal_nodes - async def get_node_type(self, node_hash: bytes32, *, lock: bool = True) -> NodeType: - async with self.db_wrapper.locked_transaction(lock=lock): - cursor = await self.db.execute("SELECT node_type FROM node WHERE hash == :hash", {"hash": node_hash}) + async def get_node_type(self, node_hash: bytes32) -> NodeType: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute("SELECT node_type FROM node WHERE hash == :hash", {"hash": node_hash}) raw_node_type = await cursor.fetchone() if raw_node_type is None: @@ -675,17 +676,15 @@ class DataStore: return NodeType(raw_node_type["node_type"]) - async def get_terminal_node_for_seed( - self, tree_id: bytes32, seed: bytes32, *, lock: bool = True - ) -> Optional[bytes32]: + async def get_terminal_node_for_seed(self, tree_id: bytes32, seed: bytes32) -> Optional[bytes32]: path = int.from_bytes(seed, byteorder="big") - async with self.db_wrapper.locked_transaction(lock=lock): - root = await self.get_tree_root(tree_id, lock=False) + async with self.db_wrapper.reader(): + root = await self.get_tree_root(tree_id) if root is None or root.node_hash is None: return None node_hash = root.node_hash while True: - node = await self.get_node(node_hash, lock=False) + node = await self.get_node(node_hash) assert node is not None if isinstance(node, TerminalNode): break @@ -709,17 +708,15 @@ class DataStore: hint_keys_values: Optional[Dict[bytes, bytes]] = None, use_optimized: bool = True, status: Status = Status.PENDING, - *, - lock: bool = True, ) -> bytes32: - async with self.db_wrapper.locked_transaction(lock=lock): - was_empty = await self.table_is_empty(tree_id=tree_id, 
lock=False) + async with self.db_wrapper.writer(): + was_empty = await self.table_is_empty(tree_id=tree_id) if was_empty: reference_node_hash = None side = None else: seed = leaf_hash(key=key, value=value) - reference_node_hash = await self.get_terminal_node_for_seed(tree_id, seed, lock=False) + reference_node_hash = await self.get_terminal_node_for_seed(tree_id, seed) side = self.get_side_for_seed(seed) return await self.insert( @@ -729,24 +726,20 @@ class DataStore: reference_node_hash=reference_node_hash, side=side, hint_keys_values=hint_keys_values, - lock=False, use_optimized=use_optimized, status=status, ) - async def get_keys_values_dict(self, tree_id: bytes32, *, lock: bool = True) -> Dict[bytes, bytes]: - async with self.db_wrapper.locked_transaction(lock=lock): - pairs = await self.get_keys_values(tree_id=tree_id, lock=False) - return {node.key: node.value for node in pairs} + async def get_keys_values_dict(self, tree_id: bytes32) -> Dict[bytes, bytes]: + pairs = await self.get_keys_values(tree_id=tree_id) + return {node.key: node.value for node in pairs} - async def get_keys( - self, tree_id: bytes32, root_hash: Optional[bytes32] = None, *, lock: bool = True - ) -> List[bytes]: - async with self.db_wrapper.locked_transaction(lock=lock): + async def get_keys(self, tree_id: bytes32, root_hash: Optional[bytes32] = None) -> List[bytes]: + async with self.db_wrapper.reader() as reader: if root_hash is None: - root = await self.get_tree_root(tree_id=tree_id, lock=False) + root = await self.get_tree_root(tree_id=tree_id) root_hash = root.node_hash - cursor = await self.db.execute( + cursor = await reader.execute( """ WITH RECURSIVE tree_from_root_hash(hash, node_type, left, right, key) AS ( @@ -776,17 +769,15 @@ class DataStore: hint_keys_values: Optional[Dict[bytes, bytes]] = None, use_optimized: bool = True, status: Status = Status.PENDING, - *, - lock: bool = True, ) -> bytes32: - async with self.db_wrapper.locked_transaction(lock=lock): - was_empty = await self.table_is_empty(tree_id=tree_id, lock=False) - root = await self.get_tree_root(tree_id=tree_id, lock=False) + async with self.db_wrapper.writer(): + was_empty = await self.table_is_empty(tree_id=tree_id) + root = await self.get_tree_root(tree_id=tree_id) if not was_empty: if hint_keys_values is None: # TODO: is there any way the db can enforce this? 
- pairs = await self.get_keys_values(tree_id=tree_id, lock=False) + pairs = await self.get_keys_values(tree_id=tree_id) if any(key == node.key for node in pairs): raise Exception(f"Key already present: {key.hex()}") else: @@ -797,7 +788,7 @@ class DataStore: if not was_empty: raise Exception(f"Reference node hash must be specified for non-empty tree: {tree_id.hex()}") else: - reference_node_type = await self.get_node_type(node_hash=reference_node_hash, lock=False) + reference_node_type = await self.get_node_type(node_hash=reference_node_hash) if reference_node_type == NodeType.INTERNAL: raise Exception("can not insert a new key/value on an internal node") @@ -823,14 +814,12 @@ class DataStore: if use_optimized: ancestors: List[InternalNode] = await self.get_ancestors_optimized( - node_hash=reference_node_hash, tree_id=tree_id, lock=False + node_hash=reference_node_hash, tree_id=tree_id ) else: - ancestors = await self.get_ancestors_optimized( - node_hash=reference_node_hash, tree_id=tree_id, lock=False - ) + ancestors = await self.get_ancestors_optimized(node_hash=reference_node_hash, tree_id=tree_id) ancestors_2: List[InternalNode] = await self.get_ancestors( - node_hash=reference_node_hash, tree_id=tree_id, lock=False + node_hash=reference_node_hash, tree_id=tree_id ) if ancestors != ancestors_2: raise RuntimeError("Ancestors optimized didn't produce the expected result.") @@ -890,12 +879,10 @@ class DataStore: hint_keys_values: Optional[Dict[bytes, bytes]] = None, use_optimized: bool = True, status: Status = Status.PENDING, - *, - lock: bool = True, ) -> None: - async with self.db_wrapper.locked_transaction(lock=lock): + async with self.db_wrapper.writer(): if hint_keys_values is None: - node = await self.get_node_by_key(key=key, tree_id=tree_id, lock=False) + node = await self.get_node_by_key(key=key, tree_id=tree_id) else: if bytes(key) not in hint_keys_values: log.debug(f"Request to delete an unknown key ignored: {key.hex()}") @@ -905,14 +892,10 @@ class DataStore: node = TerminalNode(node_hash, key, value) del hint_keys_values[bytes(key)] if use_optimized: - ancestors: List[InternalNode] = await self.get_ancestors_optimized( - node_hash=node.hash, tree_id=tree_id, lock=False - ) + ancestors: List[InternalNode] = await self.get_ancestors_optimized(node_hash=node.hash, tree_id=tree_id) else: - ancestors = await self.get_ancestors_optimized(node_hash=node.hash, tree_id=tree_id, lock=False) - ancestors_2: List[InternalNode] = await self.get_ancestors( - node_hash=node.hash, tree_id=tree_id, lock=False - ) + ancestors = await self.get_ancestors_optimized(node_hash=node.hash, tree_id=tree_id) + ancestors_2: List[InternalNode] = await self.get_ancestors(node_hash=node.hash, tree_id=tree_id) if ancestors != ancestors_2: raise RuntimeError("Ancestors optimized didn't produce the expected result.") @@ -943,7 +926,7 @@ class DataStore: old_child_hash = parent.hash new_child_hash = other_hash - new_generation = await self.get_tree_generation(tree_id, lock=False) + 1 + new_generation = await self.get_tree_generation(tree_id) + 1 # update ancestors after inserting root, to keep table constraints. 
insert_ancestors_cache: List[Tuple[bytes32, bytes32, bytes32]] = [] # more parents to handle so let's traverse them @@ -977,12 +960,10 @@ class DataStore: tree_id: bytes32, changelist: List[Dict[str, Any]], status: Status = Status.PENDING, - *, - lock: bool = True, ) -> Optional[bytes32]: - async with self.db_wrapper.locked_transaction(lock=lock): - hint_keys_values = await self.get_keys_values_dict(tree_id, lock=False) - old_root = await self.get_tree_root(tree_id, lock=False) + async with self.db_wrapper.writer(): + hint_keys_values = await self.get_keys_values_dict(tree_id) + old_root = await self.get_tree_root(tree_id) for change in changelist: if change["action"] == "insert": key = change["key"] @@ -990,7 +971,7 @@ class DataStore: reference_node_hash = change.get("reference_node_hash", None) side = change.get("side", None) if reference_node_hash is None and side is None: - await self.autoinsert(key, value, tree_id, hint_keys_values, True, Status.COMMITTED, lock=False) + await self.autoinsert(key, value, tree_id, hint_keys_values, True, Status.COMMITTED) else: if reference_node_hash is None or side is None: raise Exception("Provide both reference_node_hash and side or neither.") @@ -1003,27 +984,24 @@ class DataStore: hint_keys_values, True, Status.COMMITTED, - lock=False, ) elif change["action"] == "delete": key = change["key"] - await self.delete(key, tree_id, hint_keys_values, True, Status.COMMITTED, lock=False) + await self.delete(key, tree_id, hint_keys_values, True, Status.COMMITTED) else: raise Exception(f"Operation in batch is not insert or delete: {change}") - root = await self.get_tree_root(tree_id=tree_id, lock=False) + root = await self.get_tree_root(tree_id=tree_id) # We delete all "temporary" records stored in root and ancestor tables and store only the final result. 
- await self.rollback_to_generation(tree_id, old_root.generation, lock=False) + await self.rollback_to_generation(tree_id, old_root.generation) if root.node_hash == old_root.node_hash: raise ValueError("Changelist resulted in no change to tree data") - await self.insert_root_with_ancestor_table( - tree_id=tree_id, node_hash=root.node_hash, status=status, lock=False - ) + await self.insert_root_with_ancestor_table(tree_id=tree_id, node_hash=root.node_hash, status=status) if status == Status.PENDING: - new_root = await self.get_pending_root(tree_id=tree_id, lock=False) + new_root = await self.get_pending_root(tree_id=tree_id) assert new_root is not None elif status == Status.COMMITTED: - new_root = await self.get_tree_root(tree_id=tree_id, lock=False) + new_root = await self.get_tree_root(tree_id=tree_id) else: raise Exception(f"No known status: {status}") if new_root.node_hash != root.node_hash: @@ -1043,42 +1021,41 @@ class DataStore: tree_id: bytes32, generation: Optional[int] = None, ) -> Optional[InternalNode]: - if generation is None: - generation = await self.get_tree_generation(tree_id=tree_id, lock=False) - cursor = await self.db.execute( - """ - SELECT * from node INNER JOIN ( - SELECT ancestors.ancestor AS hash, MAX(ancestors.generation) AS generation - FROM ancestors - WHERE ancestors.hash == :hash - AND ancestors.tree_id == :tree_id - AND ancestors.generation <= :generation - GROUP BY hash - ) asc on asc.hash == node.hash - """, - {"hash": node_hash, "tree_id": tree_id, "generation": generation}, - ) - row = await cursor.fetchone() - if row is None: - return None - return InternalNode.from_row(row=row) + async with self.db_wrapper.reader() as reader: + if generation is None: + generation = await self.get_tree_generation(tree_id=tree_id) + cursor = await reader.execute( + """ + SELECT * from node INNER JOIN ( + SELECT ancestors.ancestor AS hash, MAX(ancestors.generation) AS generation + FROM ancestors + WHERE ancestors.hash == :hash + AND ancestors.tree_id == :tree_id + AND ancestors.generation <= :generation + GROUP BY hash + ) asc on asc.hash == node.hash + """, + {"hash": node_hash, "tree_id": tree_id, "generation": generation}, + ) + row = await cursor.fetchone() + if row is None: + return None + return InternalNode.from_row(row=row) - async def build_ancestor_table_for_latest_root(self, tree_id: bytes32, *, lock: bool = True) -> None: - async with self.db_wrapper.locked_transaction(lock=lock): - root = await self.get_tree_root(tree_id=tree_id, lock=False) + async def build_ancestor_table_for_latest_root(self, tree_id: bytes32) -> None: + async with self.db_wrapper.writer(): + root = await self.get_tree_root(tree_id=tree_id) if root.node_hash is None: return previous_root = await self.get_tree_root( tree_id=tree_id, generation=max(root.generation - 1, 0), - lock=False, ) if previous_root.node_hash is not None: previous_internal_nodes: List[InternalNode] = await self.get_internal_nodes( tree_id=tree_id, root_hash=previous_root.node_hash, - lock=False, ) known_hashes: Set[bytes32] = set(node.hash for node in previous_internal_nodes) else: @@ -1086,7 +1063,6 @@ class DataStore: internal_nodes: List[InternalNode] = await self.get_internal_nodes( tree_id=tree_id, root_hash=root.node_hash, - lock=False, ) for node in internal_nodes: # We already have the same values in ancestor tables, if we have the same internal node. 
@@ -1095,24 +1071,21 @@ class DataStore: await self._insert_ancestor_table(node.left_hash, node.right_hash, tree_id, root.generation) async def insert_root_with_ancestor_table( - self, tree_id: bytes32, node_hash: Optional[bytes32], status: Status = Status.PENDING, *, lock: bool = True + self, tree_id: bytes32, node_hash: Optional[bytes32], status: Status = Status.PENDING ) -> None: - async with self.db_wrapper.locked_transaction(lock=lock): + async with self.db_wrapper.writer(): await self._insert_root(tree_id=tree_id, node_hash=node_hash, status=status) # Don't update the ancestor table for non-committed status. if status == Status.COMMITTED: - await self.build_ancestor_table_for_latest_root(tree_id=tree_id, lock=False) + await self.build_ancestor_table_for_latest_root(tree_id=tree_id) async def get_node_by_key( self, key: bytes, tree_id: bytes32, root_hash: Optional[bytes32] = None, - *, - lock: bool = True, ) -> TerminalNode: - async with self.db_wrapper.locked_transaction(lock=lock): - nodes = await self.get_keys_values(tree_id=tree_id, root_hash=root_hash, lock=False) + nodes = await self.get_keys_values(tree_id=tree_id, root_hash=root_hash) for node in nodes: if node.key == key: @@ -1120,9 +1093,9 @@ class DataStore: raise KeyNotFoundError(key=key) - async def get_node(self, node_hash: bytes32, *, lock: bool = True) -> Node: - async with self.db_wrapper.locked_transaction(lock=lock): - cursor = await self.db.execute("SELECT * FROM node WHERE hash == :hash", {"hash": node_hash}) + async def get_node(self, node_hash: bytes32) -> Node: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute("SELECT * FROM node WHERE hash == :hash", {"hash": node_hash}) row = await cursor.fetchone() if row is None: @@ -1131,14 +1104,14 @@ class DataStore: node = row_to_node(row=row) return node - async def get_tree_as_program(self, tree_id: bytes32, *, lock: bool = True) -> Program: - async with self.db_wrapper.locked_transaction(lock=lock): - root = await self.get_tree_root(tree_id=tree_id, lock=False) + async def get_tree_as_program(self, tree_id: bytes32) -> Program: + async with self.db_wrapper.reader() as reader: + root = await self.get_tree_root(tree_id=tree_id) # TODO: consider actual proper behavior assert root.node_hash is not None - root_node = await self.get_node(node_hash=root.node_hash, lock=False) + root_node = await self.get_node(node_hash=root.node_hash) - cursor = await self.db.execute( + cursor = await reader.execute( """ WITH RECURSIVE tree_from_root_hash(hash, node_type, left, right, key, value) AS ( @@ -1171,14 +1144,11 @@ class DataStore: node_hash: bytes32, tree_id: bytes32, root_hash: Optional[bytes32] = None, - *, - lock: bool = True, ) -> ProofOfInclusion: """Collect the information for a proof of inclusion of a hash in the Merkle tree. """ - async with self.db_wrapper.locked_transaction(lock=lock): - ancestors = await self.get_ancestors(node_hash=node_hash, tree_id=tree_id, root_hash=root_hash, lock=False) + ancestors = await self.get_ancestors(node_hash=node_hash, tree_id=tree_id, root_hash=root_hash) layers: List[ProofOfInclusionLayer] = [] child_hash = node_hash @@ -1206,19 +1176,17 @@ class DataStore: self, key: bytes, tree_id: bytes32, - *, - lock: bool = True, ) -> ProofOfInclusion: """Collect the information for a proof of inclusion of a key and its value in the Merkle tree. 
""" - async with self.db_wrapper.locked_transaction(lock=lock): - node = await self.get_node_by_key(key=key, tree_id=tree_id, lock=False) - return await self.get_proof_of_inclusion_by_hash(node_hash=node.hash, tree_id=tree_id, lock=False) + async with self.db_wrapper.reader(): + node = await self.get_node_by_key(key=key, tree_id=tree_id) + return await self.get_proof_of_inclusion_by_hash(node_hash=node.hash, tree_id=tree_id) - async def get_first_generation(self, node_hash: bytes32, tree_id: bytes32, *, lock: bool = True) -> int: - async with self.db_wrapper.locked_transaction(lock=lock): - cursor = await self.db.execute( + async def get_first_generation(self, node_hash: bytes32, tree_id: bytes32) -> int: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute( "SELECT MIN(generation) AS generation FROM ancestors WHERE hash == :hash AND tree_id == :tree_id", {"hash": node_hash, "tree_id": tree_id}, ) @@ -1236,22 +1204,20 @@ class DataStore: tree_id: bytes32, deltas_only: bool, writer: BinaryIO, - *, - lock: bool = True, ) -> None: if node_hash == bytes32([0] * 32): return if deltas_only: - generation = await self.get_first_generation(node_hash, tree_id, lock=lock) + generation = await self.get_first_generation(node_hash, tree_id) # Root's generation is not the first time we see this hash, so it's not a new delta. if root.generation != generation: return - node = await self.get_node(node_hash, lock=lock) + node = await self.get_node(node_hash) to_write = b"" if isinstance(node, InternalNode): - await self.write_tree_to_file(root, node.left_hash, tree_id, deltas_only, writer, lock=lock) - await self.write_tree_to_file(root, node.right_hash, tree_id, deltas_only, writer, lock=lock) + await self.write_tree_to_file(root, node.left_hash, tree_id, deltas_only, writer) + await self.write_tree_to_file(root, node.right_hash, tree_id, deltas_only, writer) to_write = bytes(SerializedNode(False, bytes(node.left_hash), bytes(node.right_hash))) elif isinstance(node, TerminalNode): to_write = bytes(SerializedNode(True, node.key, node.value)) @@ -1261,18 +1227,16 @@ class DataStore: writer.write(len(to_write).to_bytes(4, byteorder="big")) writer.write(to_write) - async def update_subscriptions_from_wallet( - self, tree_id: bytes32, new_urls: List[str], *, lock: bool = True - ) -> None: - async with self.db_wrapper.locked_transaction(lock=lock): - cursor = await self.db.execute( + async def update_subscriptions_from_wallet(self, tree_id: bytes32, new_urls: List[str]) -> None: + async with self.db_wrapper.writer() as writer: + cursor = await writer.execute( "SELECT * FROM subscriptions WHERE from_wallet == 1 AND tree_id == :tree_id", { "tree_id": tree_id, }, ) old_urls = [row["url"] async for row in cursor] - cursor = await self.db.execute( + cursor = await writer.execute( "SELECT * FROM subscriptions WHERE from_wallet == 0 AND tree_id == :tree_id", { "tree_id": tree_id, @@ -1282,7 +1246,7 @@ class DataStore: additions = {url for url in new_urls if url not in old_urls} removals = [url for url in old_urls if url not in new_urls] for url in removals: - await self.db.execute( + await writer.execute( "DELETE FROM subscriptions WHERE url == :url AND tree_id == :tree_id", { "url": url, @@ -1291,7 +1255,7 @@ class DataStore: ) for url in additions: if url not in from_subscriptions_urls: - await self.db.execute( + await writer.execute( "INSERT INTO subscriptions(tree_id, url, ignore_till, num_consecutive_failures, from_wallet) " "VALUES (:tree_id, :url, 0, 0, 1)", { @@ -1300,17 +1264,17 @@ 
class DataStore: }, ) - async def subscribe(self, subscription: Subscription, *, lock: bool = True) -> None: - async with self.db_wrapper.locked_transaction(lock=lock): + async def subscribe(self, subscription: Subscription) -> None: + async with self.db_wrapper.writer() as writer: # Add a fake subscription, so we always have the tree_id, even with no URLs. - await self.db.execute( + await writer.execute( "INSERT INTO subscriptions(tree_id, url, ignore_till, num_consecutive_failures, from_wallet) " "VALUES (:tree_id, NULL, NULL, NULL, 0)", { "tree_id": subscription.tree_id, }, ) - all_subscriptions = await self.get_subscriptions(lock=False) + all_subscriptions = await self.get_subscriptions() old_subscription = next( ( old_subscription @@ -1324,7 +1288,7 @@ class DataStore: old_urls = set(server_info.url for server_info in old_subscription.servers_info) new_servers = [server_info for server_info in subscription.servers_info if server_info.url not in old_urls] for server_info in new_servers: - await self.db.execute( + await writer.execute( "INSERT INTO subscriptions(tree_id, url, ignore_till, num_consecutive_failures, from_wallet) " "VALUES (:tree_id, :url, :ignore_till, :num_consecutive_failures, 0)", { @@ -1335,10 +1299,10 @@ class DataStore: }, ) - async def remove_subscriptions(self, tree_id: bytes32, urls: List[str], *, lock: bool = True) -> None: - async with self.db_wrapper.locked_transaction(lock=lock): + async def remove_subscriptions(self, tree_id: bytes32, urls: List[str]) -> None: + async with self.db_wrapper.writer() as writer: for url in urls: - await self.db.execute( + await writer.execute( "DELETE FROM subscriptions WHERE tree_id == :tree_id AND url == :url", { "tree_id": tree_id, @@ -1346,27 +1310,27 @@ class DataStore: }, ) - async def unsubscribe(self, tree_id: bytes32, *, lock: bool = True) -> None: - async with self.db_wrapper.locked_transaction(lock=lock): - await self.db.execute( + async def unsubscribe(self, tree_id: bytes32) -> None: + async with self.db_wrapper.writer() as writer: + await writer.execute( "DELETE FROM subscriptions WHERE tree_id == :tree_id", {"tree_id": tree_id}, ) - async def rollback_to_generation(self, tree_id: bytes32, target_generation: int, *, lock: bool = True) -> None: - async with self.db_wrapper.locked_transaction(lock=lock): - await self.db.execute( + async def rollback_to_generation(self, tree_id: bytes32, target_generation: int) -> None: + async with self.db_wrapper.writer() as writer: + await writer.execute( "DELETE FROM ancestors WHERE tree_id == :tree_id AND generation > :target_generation", {"tree_id": tree_id, "target_generation": target_generation}, ) - await self.db.execute( + await writer.execute( "DELETE FROM root WHERE tree_id == :tree_id AND generation > :target_generation", {"tree_id": tree_id, "target_generation": target_generation}, ) - async def update_server_info(self, tree_id: bytes32, server_info: ServerInfo, *, lock: bool = True) -> None: - async with self.db_wrapper.locked_transaction(lock=lock): - await self.db.execute( + async def update_server_info(self, tree_id: bytes32, server_info: ServerInfo) -> None: + async with self.db_wrapper.writer() as writer: + await writer.execute( "UPDATE subscriptions SET ignore_till = :ignore_till, " "num_consecutive_failures = :num_consecutive_failures WHERE tree_id = :tree_id AND url = :url", { @@ -1377,27 +1341,23 @@ class DataStore: }, ) - async def received_incorrect_file( - self, tree_id: bytes32, server_info: ServerInfo, timestamp: int, *, lock: bool = True - ) -> None: + 
async def received_incorrect_file(self, tree_id: bytes32, server_info: ServerInfo, timestamp: int) -> None: SEVEN_DAYS_BAN = 7 * 24 * 60 * 60 new_server_info = replace( server_info, num_consecutive_failures=server_info.num_consecutive_failures + 1, ignore_till=max(server_info.ignore_till, timestamp + SEVEN_DAYS_BAN), ) - await self.update_server_info(tree_id, new_server_info, lock=lock) + await self.update_server_info(tree_id, new_server_info) - async def received_correct_file(self, tree_id: bytes32, server_info: ServerInfo, *, lock: bool = True) -> None: + async def received_correct_file(self, tree_id: bytes32, server_info: ServerInfo) -> None: new_server_info = replace( server_info, num_consecutive_failures=0, ) - await self.update_server_info(tree_id, new_server_info, lock=lock) + await self.update_server_info(tree_id, new_server_info) - async def server_misses_file( - self, tree_id: bytes32, server_info: ServerInfo, timestamp: int, *, lock: bool = True - ) -> None: + async def server_misses_file(self, tree_id: bytes32, server_info: ServerInfo, timestamp: int) -> None: BAN_TIME_BY_MISSING_COUNT = [5 * 60] * 3 + [15 * 60] * 3 + [60 * 60] * 2 + [240 * 60] index = min(server_info.num_consecutive_failures, len(BAN_TIME_BY_MISSING_COUNT) - 1) new_server_info = replace( @@ -1405,12 +1365,10 @@ class DataStore: num_consecutive_failures=server_info.num_consecutive_failures + 1, ignore_till=max(server_info.ignore_till, timestamp + BAN_TIME_BY_MISSING_COUNT[index]), ) - await self.update_server_info(tree_id, new_server_info, lock=lock) + await self.update_server_info(tree_id, new_server_info) - async def get_available_servers_for_store( - self, tree_id: bytes32, timestamp: int, *, lock: bool = True - ) -> List[ServerInfo]: - subscriptions = await self.get_subscriptions(lock=lock) + async def get_available_servers_for_store(self, tree_id: bytes32, timestamp: int) -> List[ServerInfo]: + subscriptions = await self.get_subscriptions() subscription = next((subscription for subscription in subscriptions if subscription.tree_id == tree_id), None) if subscription is None: return [] @@ -1420,11 +1378,11 @@ class DataStore: servers_info.append(server_info) return servers_info - async def get_subscriptions(self, *, lock: bool = True) -> List[Subscription]: + async def get_subscriptions(self) -> List[Subscription]: subscriptions: List[Subscription] = [] - async with self.db_wrapper.locked_transaction(lock=lock): - cursor = await self.db.execute( + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute( "SELECT * from subscriptions", ) async for row in cursor: @@ -1457,12 +1415,10 @@ class DataStore: tree_id: bytes32, hash_1: bytes32, hash_2: bytes32, - *, - lock: bool = True, ) -> Set[DiffData]: - async with self.db_wrapper.locked_transaction(lock=lock): - old_pairs = set(await self.get_keys_values(tree_id, hash_1, lock=False)) - new_pairs = set(await self.get_keys_values(tree_id, hash_2, lock=False)) + async with self.db_wrapper.reader(): + old_pairs = set(await self.get_keys_values(tree_id, hash_1)) + new_pairs = set(await self.get_keys_values(tree_id, hash_2)) if len(old_pairs) == 0 and hash_1 != bytes32([0] * 32): return set() if len(new_pairs) == 0 and hash_2 != bytes32([0] * 32): diff --git a/chia/data_layer/util/benchmark.py b/chia/data_layer/util/benchmark.py index 67664ddc7a09..c4051df17818 100644 --- a/chia/data_layer/util/benchmark.py +++ b/chia/data_layer/util/benchmark.py @@ -8,12 +8,9 @@ import time from pathlib import Path from typing import Dict, Optional -import 
aiosqlite - from chia.data_layer.data_layer_util import Side, TerminalNode, leaf_hash from chia.data_layer.data_store import DataStore from chia.types.blockchain_format.sized_bytes import bytes32 -from chia.util.db_wrapper import DBWrapper async def generate_datastore(num_nodes: int, slow_mode: bool) -> None: @@ -25,93 +22,93 @@ async def generate_datastore(num_nodes: int, slow_mode: bool) -> None: if os.path.exists(db_path): os.remove(db_path) - connection = await aiosqlite.connect(db_path) - db_wrapper = DBWrapper(connection) - data_store = await DataStore.create(db_wrapper=db_wrapper) - hint_keys_values: Dict[bytes, bytes] = {} + data_store = await DataStore.create(database=db_path) + try: + hint_keys_values: Dict[bytes, bytes] = {} - tree_id = bytes32(b"0" * 32) - await data_store.create_tree(tree_id) + tree_id = bytes32(b"0" * 32) + await data_store.create_tree(tree_id) - insert_time = 0.0 - insert_count = 0 - autoinsert_time = 0.0 - autoinsert_count = 0 - delete_time = 0.0 - delete_count = 0 + insert_time = 0.0 + insert_count = 0 + autoinsert_time = 0.0 + autoinsert_count = 0 + delete_time = 0.0 + delete_count = 0 - for i in range(num_nodes): - key = i.to_bytes(4, byteorder="big") - value = (2 * i).to_bytes(4, byteorder="big") - seed = leaf_hash(key=key, value=value) - reference_node_hash: Optional[bytes32] = await data_store.get_terminal_node_for_seed(tree_id, seed) - side: Optional[Side] = data_store.get_side_for_seed(seed) + for i in range(num_nodes): + key = i.to_bytes(4, byteorder="big") + value = (2 * i).to_bytes(4, byteorder="big") + seed = leaf_hash(key=key, value=value) + reference_node_hash: Optional[bytes32] = await data_store.get_terminal_node_for_seed(tree_id, seed) + side: Optional[Side] = data_store.get_side_for_seed(seed) - if i == 0: - reference_node_hash = None - side = None - if i % 3 == 0: - t1 = time.time() - if not slow_mode: - await data_store.insert( - key=key, - value=value, - tree_id=tree_id, - reference_node_hash=reference_node_hash, - side=side, - hint_keys_values=hint_keys_values, - ) + if i == 0: + reference_node_hash = None + side = None + if i % 3 == 0: + t1 = time.time() + if not slow_mode: + await data_store.insert( + key=key, + value=value, + tree_id=tree_id, + reference_node_hash=reference_node_hash, + side=side, + hint_keys_values=hint_keys_values, + ) + else: + await data_store.insert( + key=key, + value=value, + tree_id=tree_id, + reference_node_hash=reference_node_hash, + side=side, + use_optimized=False, + ) + t2 = time.time() + insert_time += t2 - t1 + insert_count += 1 + elif i % 3 == 1: + t1 = time.time() + if not slow_mode: + await data_store.autoinsert( + key=key, + value=value, + tree_id=tree_id, + hint_keys_values=hint_keys_values, + ) + else: + await data_store.autoinsert( + key=key, + value=value, + tree_id=tree_id, + use_optimized=False, + ) + t2 = time.time() + autoinsert_time += t2 - t1 + autoinsert_count += 1 else: - await data_store.insert( - key=key, - value=value, - tree_id=tree_id, - reference_node_hash=reference_node_hash, - side=side, - use_optimized=False, - ) - t2 = time.time() - insert_time += t2 - t1 - insert_count += 1 - elif i % 3 == 1: - t1 = time.time() - if not slow_mode: - await data_store.autoinsert( - key=key, - value=value, - tree_id=tree_id, - hint_keys_values=hint_keys_values, - ) - else: - await data_store.autoinsert( - key=key, - value=value, - tree_id=tree_id, - use_optimized=False, - ) - t2 = time.time() - autoinsert_time += t2 - t1 - autoinsert_count += 1 - else: - t1 = time.time() - assert 
reference_node_hash is not None - node = await data_store.get_node(reference_node_hash) - assert isinstance(node, TerminalNode) - if not slow_mode: - await data_store.delete(key=node.key, tree_id=tree_id, hint_keys_values=hint_keys_values) - else: - await data_store.delete(key=node.key, tree_id=tree_id, use_optimized=False) - t2 = time.time() - delete_time += t2 - t1 - delete_count += 1 + t1 = time.time() + assert reference_node_hash is not None + node = await data_store.get_node(reference_node_hash) + assert isinstance(node, TerminalNode) + if not slow_mode: + await data_store.delete(key=node.key, tree_id=tree_id, hint_keys_values=hint_keys_values) + else: + await data_store.delete(key=node.key, tree_id=tree_id, use_optimized=False) + t2 = time.time() + delete_time += t2 - t1 + delete_count += 1 - print(f"Average insert time: {insert_time / insert_count}") - print(f"Average autoinsert time: {autoinsert_time / autoinsert_count}") - print(f"Average delete time: {delete_time / delete_count}") - print(f"Total time for {num_nodes} operations: {insert_time + autoinsert_time + delete_time}") - root = await data_store.get_tree_root(tree_id=tree_id) - print(f"Root hash: {root.node_hash}") - await connection.close() + print(f"Average insert time: {insert_time / insert_count}") + print(f"Average autoinsert time: {autoinsert_time / autoinsert_count}") + print(f"Average delete time: {delete_time / delete_count}") + print(f"Total time for {num_nodes} operations: {insert_time + autoinsert_time + delete_time}") + root = await data_store.get_tree_root(tree_id=tree_id) + print(f"Root hash: {root.node_hash}") + finally: + await data_store.close() if __name__ == "__main__": diff --git a/chia/util/db_wrapper.py b/chia/util/db_wrapper.py index ae3282cb250b..0f77744952b4 100644 --- a/chia/util/db_wrapper.py +++ b/chia/util/db_wrapper.py @@ -19,52 +19,6 @@ else: SQLITE_MAX_VARIABLE_NUMBER = 32700 -class DBWrapper: - """ - This object handles HeaderBlocks and Blocks stored in DB used by wallet. 
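
The server_misses_file and received_incorrect_file changes in the data_store.py hunks above keep the pre-existing escalating ban schedule and drop only the lock plumbing. A minimal sketch of that schedule, using a stand-in ServerInfo dataclass for illustration rather than the real one from chia.data_layer.data_layer_util:

from dataclasses import dataclass, replace

# Three misses cost 5 minutes each, the next three 15 minutes, then two at
# one hour, and four hours for every consecutive miss after that.
BAN_TIME_BY_MISSING_COUNT = [5 * 60] * 3 + [15 * 60] * 3 + [60 * 60] * 2 + [240 * 60]
SEVEN_DAYS_BAN = 7 * 24 * 60 * 60  # serving an incorrect file jumps straight here


@dataclass(frozen=True)
class ServerInfo:  # stand-in for illustration only
    url: str
    num_consecutive_failures: int
    ignore_till: int


def ban_for_missing_file(server_info: ServerInfo, timestamp: int) -> ServerInfo:
    # Clamp the index so repeat offenders stay at the four-hour tier.
    index = min(server_info.num_consecutive_failures, len(BAN_TIME_BY_MISSING_COUNT) - 1)
    return replace(
        server_info,
        num_consecutive_failures=server_info.num_consecutive_failures + 1,
        # max() so an existing longer ban is never shortened by a new miss.
        ignore_till=max(server_info.ignore_till, timestamp + BAN_TIME_BY_MISSING_COUNT[index]),
    )
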
- """ - - db: aiosqlite.Connection - lock: asyncio.Lock - - def __init__(self, connection: aiosqlite.Connection): - self.db = connection - self.lock = asyncio.Lock() - - async def begin_transaction(self): - cursor = await self.db.execute("BEGIN TRANSACTION") - await cursor.close() - - async def rollback_transaction(self): - # Also rolls back the coin store, since both stores must be updated at once - if self.db.in_transaction: - cursor = await self.db.execute("ROLLBACK") - await cursor.close() - - async def commit_transaction(self) -> None: - await self.db.commit() - - @contextlib.asynccontextmanager - async def locked_transaction(self, *, lock=True): - # TODO: look into contextvars perhaps instead of this manual lock tracking - if not lock: - yield - return - - # TODO: add a lock acquisition timeout - # maybe https://docs.python.org/3/library/asyncio-task.html#asyncio.wait_for - - async with self.lock: - await self.begin_transaction() - try: - yield - except BaseException: - await self.rollback_transaction() - raise - else: - await self.commit_transaction() - - async def execute_fetchone( c: aiosqlite.Connection, sql: str, parameters: Iterable[Any] = None ) -> Optional[sqlite3.Row]: @@ -157,18 +111,25 @@ class DBWrapper2: async def create( cls, database: Union[str, Path], + *, db_version: int = 1, uri: bool = False, reader_count: int = 4, log_path: Optional[Path] = None, journal_mode: str = "WAL", synchronous: Optional[str] = None, + foreign_keys: bool = False, + row_factory: Optional[Type[aiosqlite.Row]] = None, ) -> DBWrapper2: write_connection = await create_connection(database=database, uri=uri, log_path=log_path, name="writer") await (await write_connection.execute(f"pragma journal_mode={journal_mode}")).close() if synchronous is not None: await (await write_connection.execute(f"pragma synchronous={synchronous}")).close() + await (await write_connection.execute(f"pragma foreign_keys={'ON' if foreign_keys else 'OFF'}")).close() + + write_connection.row_factory = row_factory + self = cls(connection=write_connection, db_version=db_version) for index in range(reader_count): @@ -178,6 +139,7 @@ class DBWrapper2: log_path=log_path, name=f"reader-{index}", ) + read_connection.row_factory = row_factory await self.add_connection(c=read_connection) return self diff --git a/tests/core/data_layer/conftest.py b/tests/core/data_layer/conftest.py index cc466b885e58..d1b60614a00a 100644 --- a/tests/core/data_layer/conftest.py +++ b/tests/core/data_layer/conftest.py @@ -3,13 +3,13 @@ from __future__ import annotations import contextlib import os import pathlib +import random import subprocess import sys import sysconfig import time from typing import Any, AsyncIterable, Awaitable, Callable, Dict, Iterator, List -import aiosqlite import pytest import pytest_asyncio @@ -19,7 +19,6 @@ from _pytest.fixtures import SubRequest from chia.data_layer.data_layer_util import NodeType, Status from chia.data_layer.data_store import DataStore from chia.types.blockchain_format.tree_hash import bytes32 -from chia.util.db_wrapper import DBWrapper from tests.core.data_layer.util import ( ChiaRoot, Example, @@ -90,17 +89,9 @@ def create_example_fixture(request: SubRequest) -> Callable[[DataStore, bytes32] return request.param # type: ignore[no-any-return] -@pytest_asyncio.fixture(name="db_connection", scope="function") -async def db_connection_fixture() -> AsyncIterable[aiosqlite.Connection]: - async with aiosqlite.connect(":memory:") as connection: - # make sure this is on for tests even if we disable it at run time - 
await connection.execute("PRAGMA foreign_keys = ON") - yield connection - - -@pytest.fixture(name="db_wrapper", scope="function") -def db_wrapper_fixture(db_connection: aiosqlite.Connection) -> DBWrapper: - return DBWrapper(db_connection) +@pytest.fixture(name="database_uri") +def database_uri_fixture() -> str: + return f"file:db_{random.randint(0, 99999999)}?mode=memory&cache=shared" @pytest.fixture(name="tree_id", scope="function") @@ -111,8 +102,10 @@ def tree_id_fixture() -> bytes32: @pytest_asyncio.fixture(name="raw_data_store", scope="function") -async def raw_data_store_fixture(db_wrapper: DBWrapper) -> DataStore: - return await DataStore.create(db_wrapper=db_wrapper) +async def raw_data_store_fixture(database_uri: str) -> AsyncIterable[DataStore]: + store = await DataStore.create(database=database_uri, uri=True) + yield store + await store.close() @pytest_asyncio.fixture(name="data_store", scope="function") diff --git a/tests/core/data_layer/test_data_store.py b/tests/core/data_layer/test_data_store.py index c08ed1254e84..6d534232268e 100644 --- a/tests/core/data_layer/test_data_store.py +++ b/tests/core/data_layer/test_data_store.py @@ -7,7 +7,6 @@ from pathlib import Path from random import Random from typing import Any, Awaitable, Callable, Dict, List, Set, Tuple -import aiosqlite import pytest from chia.data_layer.data_layer_errors import NodeHashError, TreeGenerationIncrementingError @@ -38,7 +37,7 @@ from chia.data_layer.download_data import ( from chia.types.blockchain_format.program import Program from chia.types.blockchain_format.tree_hash import bytes32 from chia.util.byte_types import hexstr_to_bytes -from chia.util.db_wrapper import DBWrapper +from chia.util.db_wrapper import DBWrapper2 from tests.core.data_layer.util import Example, add_0123_example, add_01234567_example log = logging.getLogger(__name__) @@ -59,8 +58,8 @@ table_columns: Dict[str, List[str]] = { @pytest.mark.asyncio async def test_valid_node_values_fixture_are_valid(data_store: DataStore, valid_node_values: Dict[str, Any]) -> None: - async with data_store.db_wrapper.locked_transaction(): - await data_store.db.execute( + async with data_store.db_wrapper.writer() as writer: + await writer.execute( """ INSERT INTO node(hash, node_type, left, right, key, value) VALUES(:hash, :node_type, :left, :right, :key, :value) @@ -72,20 +71,29 @@ async def test_valid_node_values_fixture_are_valid(data_store: DataStore, valid_ @pytest.mark.parametrize(argnames=["table_name", "expected_columns"], argvalues=table_columns.items()) @pytest.mark.asyncio async def test_create_creates_tables_and_columns( - db_wrapper: DBWrapper, table_name: str, expected_columns: List[str] + database_uri: str, table_name: str, expected_columns: List[str] ) -> None: # Never string-interpolate sql queries... Except maybe in tests when it does not # allow you to parametrize the query. 
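
The new database_uri fixture in conftest.py above replaces the shared aiosqlite connection with a named in-memory database. A short synchronous sqlite3 sketch (table name illustrative) of why the mode=memory&cache=shared URI matters: a plain ":memory:" database is private to a single connection, so DBWrapper2's separate writer and reader connections would each see an empty database, while a shared-cache URI gives every connection opened with uri=True the same in-memory store, and the random name keeps concurrently running tests apart.

import random
import sqlite3

uri = f"file:db_{random.randint(0, 99999999)}?mode=memory&cache=shared"
writer = sqlite3.connect(uri, uri=True)
reader = sqlite3.connect(uri, uri=True)

writer.execute("CREATE TABLE example(x INTEGER)")
writer.execute("INSERT INTO example VALUES (1)")
writer.commit()

# The second connection sees the first one's data because both attach to the
# same shared in-memory cache; with ":memory:" this would raise
# "no such table: example".
assert reader.execute("SELECT x FROM example").fetchall() == [(1,)]

reader.close()
writer.close()
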
query = f"pragma table_info({table_name});" - cursor = await db_wrapper.db.execute(query) - columns = await cursor.fetchall() - assert columns == [] + db_wrapper = await DBWrapper2.create(database=database_uri, uri=True, reader_count=1) + try: + async with db_wrapper.reader() as reader: + cursor = await reader.execute(query) + columns = await cursor.fetchall() + assert columns == [] - await DataStore.create(db_wrapper=db_wrapper) - cursor = await db_wrapper.db.execute(query) - columns = await cursor.fetchall() - assert [column[1] for column in columns] == expected_columns + store = await DataStore.create(database=database_uri, uri=True) + try: + async with db_wrapper.reader() as reader: + cursor = await reader.execute(query) + columns = await cursor.fetchall() + assert [column[1] for column in columns] == expected_columns + finally: + await store.close() + finally: + await db_wrapper.close() @pytest.mark.asyncio @@ -195,15 +203,14 @@ async def test_insert_internal_node_does_nothing_if_matching(data_store: DataSto ancestors = await data_store.get_ancestors(node_hash=kv_node.hash, tree_id=tree_id) parent = ancestors[0] - async with data_store.db_wrapper.locked_transaction(): - cursor = await data_store.db.execute("SELECT * FROM node") + async with data_store.db_wrapper.reader() as reader: + cursor = await reader.execute("SELECT * FROM node") before = await cursor.fetchall() - async with data_store.db_wrapper.locked_transaction(): - await data_store._insert_internal_node(left_hash=parent.left_hash, right_hash=parent.right_hash) + await data_store._insert_internal_node(left_hash=parent.left_hash, right_hash=parent.right_hash) - async with data_store.db_wrapper.locked_transaction(): - cursor = await data_store.db.execute("SELECT * FROM node") + async with data_store.db_wrapper.reader() as reader: + cursor = await reader.execute("SELECT * FROM node") after = await cursor.fetchall() assert after == before @@ -215,15 +222,14 @@ async def test_insert_terminal_node_does_nothing_if_matching(data_store: DataSto kv_node = await data_store.get_node_by_key(key=b"\x04", tree_id=tree_id) - async with data_store.db_wrapper.locked_transaction(): - cursor = await data_store.db.execute("SELECT * FROM node") + async with data_store.db_wrapper.reader() as reader: + cursor = await reader.execute("SELECT * FROM node") before = await cursor.fetchall() - async with data_store.db_wrapper.locked_transaction(): - await data_store._insert_terminal_node(key=kv_node.key, value=kv_node.value) + await data_store._insert_terminal_node(key=kv_node.key, value=kv_node.value) - async with data_store.db_wrapper.locked_transaction(): - cursor = await data_store.db.execute("SELECT * FROM node") + async with data_store.db_wrapper.reader() as reader: + cursor = await reader.execute("SELECT * FROM node") after = await cursor.fetchall() assert after == before @@ -237,7 +243,7 @@ async def test_build_a_tree( ) -> None: example = await create_example(data_store, tree_id) - await _debug_dump(db=data_store.db, description="final") + await _debug_dump(db=data_store.db_wrapper, description="final") actual = await data_store.get_tree_as_program(tree_id=tree_id) # print("actual ", actual.as_python()) # print("expected", example.expected.as_python()) @@ -361,54 +367,52 @@ async def test_batch_update(data_store: DataStore, tree_id: bytes32, use_optimiz db_path = tmp_path.joinpath("dl_server_util.sqlite") - connection = await aiosqlite.connect(db_path) - db_wrapper = DBWrapper(connection) - single_op_data_store = await 
DataStore.create(db_wrapper=db_wrapper) + single_op_data_store = await DataStore.create(database=db_path) + try: + await single_op_data_store.create_tree(tree_id, status=Status.COMMITTED) + random = Random() + random.seed(100, version=2) - await single_op_data_store.create_tree(tree_id, status=Status.COMMITTED) - random = Random() - random.seed(100, version=2) - - batch: List[Dict[str, Any]] = [] - keys: List[bytes] = [] - hint_keys_values: Dict[bytes, bytes] = {} - for operation in range(num_batches * num_ops_per_batch): - if random.randint(0, 4) > 0 or len(keys) == 0: - key = operation.to_bytes(4, byteorder="big") - value = (2 * operation).to_bytes(4, byteorder="big") - if use_optimized: - await single_op_data_store.autoinsert( - key=key, - value=value, - tree_id=tree_id, - hint_keys_values=hint_keys_values, - status=Status.COMMITTED, - ) + batch: List[Dict[str, Any]] = [] + keys: List[bytes] = [] + hint_keys_values: Dict[bytes, bytes] = {} + for operation in range(num_batches * num_ops_per_batch): + if random.randint(0, 4) > 0 or len(keys) == 0: + key = operation.to_bytes(4, byteorder="big") + value = (2 * operation).to_bytes(4, byteorder="big") + if use_optimized: + await single_op_data_store.autoinsert( + key=key, + value=value, + tree_id=tree_id, + hint_keys_values=hint_keys_values, + status=Status.COMMITTED, + ) + else: + await single_op_data_store.autoinsert( + key=key, value=value, tree_id=tree_id, use_optimized=False, status=Status.COMMITTED + ) + batch.append({"action": "insert", "key": key, "value": value}) + keys.append(key) else: - await single_op_data_store.autoinsert( - key=key, value=value, tree_id=tree_id, use_optimized=False, status=Status.COMMITTED - ) - batch.append({"action": "insert", "key": key, "value": value}) - keys.append(key) - else: - key = random.choice(keys) - keys.remove(key) - if use_optimized: - await single_op_data_store.delete( - key=key, tree_id=tree_id, hint_keys_values=hint_keys_values, status=Status.COMMITTED - ) - else: - await single_op_data_store.delete( - key=key, tree_id=tree_id, use_optimized=False, status=Status.COMMITTED - ) - batch.append({"action": "delete", "key": key}) - if (operation + 1) % num_ops_per_batch == 0: - saved_batches.append(batch) - batch = [] - root = await single_op_data_store.get_tree_root(tree_id=tree_id) - saved_roots.append(root) - - await connection.close() + key = random.choice(keys) + keys.remove(key) + if use_optimized: + await single_op_data_store.delete( + key=key, tree_id=tree_id, hint_keys_values=hint_keys_values, status=Status.COMMITTED + ) + else: + await single_op_data_store.delete( + key=key, tree_id=tree_id, use_optimized=False, status=Status.COMMITTED + ) + batch.append({"action": "delete", "key": key}) + if (operation + 1) % num_ops_per_batch == 0: + saved_batches.append(batch) + batch = [] + root = await single_op_data_store.get_tree_root(tree_id=tree_id) + saved_roots.append(root) + finally: + await single_op_data_store.close() for batch_number, batch in enumerate(saved_batches): assert len(batch) == num_ops_per_batch @@ -745,7 +749,7 @@ async def test_proof_of_inclusion_by_hash(data_store: DataStore, tree_id: bytes3 proof = await data_store.get_proof_of_inclusion_by_hash(node_hash=node.hash, tree_id=tree_id) print(node) - await _debug_dump(db=data_store.db) + await _debug_dump(db=data_store.db_wrapper) expected_layers = [ ProofOfInclusionLayer( @@ -862,9 +866,9 @@ invalid_program_hex = b"\xab\xcd".hex() async def test_check_roots_are_incrementing_missing_zero(raw_data_store: DataStore) -> None: 
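
The rewritten tests above all repeat the same create/try/finally/close shape around DataStore. A hypothetical helper, not part of this change, could fold that into an async context manager; a sketch assuming only the DataStore.create and close interface shown in this diff:

import contextlib
from pathlib import Path
from typing import AsyncIterator, Union

from chia.data_layer.data_store import DataStore


@contextlib.asynccontextmanager
async def managed_data_store(database: Union[str, Path], uri: bool = False) -> AsyncIterator[DataStore]:
    store = await DataStore.create(database=database, uri=uri)
    try:
        yield store
    finally:
        # Mirrors the explicit close() calls added throughout this diff.
        await store.close()

A test body could then read "async with managed_data_store(db_path) as store:" instead of nesting try/finally by hand.
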
tree_id = hexstr_to_bytes("c954ab71ffaf5b0f129b04b35fdc7c84541f4375167e730e2646bfcfdb7cf2cd") - async with raw_data_store.db_wrapper.locked_transaction(): + async with raw_data_store.db_wrapper.writer() as writer: for generation in range(1, 5): - await raw_data_store.db.execute( + await writer.execute( """ INSERT INTO root(tree_id, generation, node_hash, status) VALUES(:tree_id, :generation, :node_hash, :status) @@ -888,9 +892,9 @@ async def test_check_roots_are_incrementing_missing_zero(raw_data_store: DataSto async def test_check_roots_are_incrementing_gap(raw_data_store: DataStore) -> None: tree_id = hexstr_to_bytes("c954ab71ffaf5b0f129b04b35fdc7c84541f4375167e730e2646bfcfdb7cf2cd") - async with raw_data_store.db_wrapper.locked_transaction(): + async with raw_data_store.db_wrapper.writer() as writer: for generation in [*range(5), *range(6, 10)]: - await raw_data_store.db.execute( + await writer.execute( """ INSERT INTO root(tree_id, generation, node_hash, status) VALUES(:tree_id, :generation, :node_hash, :status) @@ -912,8 +916,8 @@ async def test_check_roots_are_incrementing_gap(raw_data_store: DataStore) -> No @pytest.mark.asyncio async def test_check_hashes_internal(raw_data_store: DataStore) -> None: - async with raw_data_store.db_wrapper.locked_transaction(): - await raw_data_store.db.execute( + async with raw_data_store.db_wrapper.writer() as writer: + await writer.execute( "INSERT INTO node(hash, node_type, left, right) VALUES(:hash, :node_type, :left, :right)", { "hash": a_bytes_32, @@ -932,8 +936,8 @@ async def test_check_hashes_internal(raw_data_store: DataStore) -> None: @pytest.mark.asyncio async def test_check_hashes_terminal(raw_data_store: DataStore) -> None: - async with raw_data_store.db_wrapper.locked_transaction(): - await raw_data_store.db.execute( + async with raw_data_store.db_wrapper.writer() as writer: + await writer.execute( "INSERT INTO node(hash, node_type, key, value) VALUES(:hash, :node_type, :key, :value)", { "hash": a_bytes_32, @@ -1178,34 +1182,34 @@ async def test_data_server_files(data_store: DataStore, tree_id: bytes32, test_d db_path = tmp_path.joinpath("dl_server_util.sqlite") - connection = await aiosqlite.connect(db_path) - db_wrapper = DBWrapper(connection) - data_store_server = await DataStore.create(db_wrapper=db_wrapper) - await data_store_server.create_tree(tree_id, status=Status.COMMITTED) - random = Random() - random.seed(100, version=2) + data_store_server = await DataStore.create(database=db_path) + try: + await data_store_server.create_tree(tree_id, status=Status.COMMITTED) + random = Random() + random.seed(100, version=2) - keys: List[bytes] = [] - counter = 0 + keys: List[bytes] = [] + counter = 0 - for batch in range(num_batches): - changelist: List[Dict[str, Any]] = [] - for operation in range(num_ops_per_batch): - if random.randint(0, 4) > 0 or len(keys) == 0: - key = counter.to_bytes(4, byteorder="big") - value = (2 * counter).to_bytes(4, byteorder="big") - keys.append(key) - changelist.append({"action": "insert", "key": key, "value": value}) - else: - key = random.choice(keys) - keys.remove(key) - changelist.append({"action": "delete", "key": key}) - counter += 1 - await data_store_server.insert_batch(tree_id, changelist, status=Status.COMMITTED) - root = await data_store_server.get_tree_root(tree_id) - await write_files_for_root(data_store_server, tree_id, root, tmp_path) - roots.append(root) - await connection.close() + for batch in range(num_batches): + changelist: List[Dict[str, Any]] = [] + for operation in 
range(num_ops_per_batch): + if random.randint(0, 4) > 0 or len(keys) == 0: + key = counter.to_bytes(4, byteorder="big") + value = (2 * counter).to_bytes(4, byteorder="big") + keys.append(key) + changelist.append({"action": "insert", "key": key, "value": value}) + else: + key = random.choice(keys) + keys.remove(key) + changelist.append({"action": "delete", "key": key}) + counter += 1 + await data_store_server.insert_batch(tree_id, changelist, status=Status.COMMITTED) + root = await data_store_server.get_tree_root(tree_id) + await write_files_for_root(data_store_server, tree_id, root, tmp_path) + roots.append(root) + finally: + await data_store_server.close() generation = 1 assert len(roots) == num_batches diff --git a/tests/core/data_layer/test_data_store_schema.py b/tests/core/data_layer/test_data_store_schema.py index 7e37aacd1f82..705c28745e1a 100644 --- a/tests/core/data_layer/test_data_store_schema.py +++ b/tests/core/data_layer/test_data_store_schema.py @@ -18,9 +18,9 @@ async def test_node_update_fails(data_store: DataStore, tree_id: bytes32) -> Non await add_01234567_example(data_store=data_store, tree_id=tree_id) node = await data_store.get_node_by_key(key=b"\x04", tree_id=tree_id) - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^updates not allowed to the node table$"): - await data_store.db.execute( + await writer.execute( "UPDATE node SET value = :value WHERE hash == :hash", { "hash": node.hash, @@ -39,9 +39,9 @@ async def test_node_hash_must_be_32( ) -> None: valid_node_values["hash"] = bytes([0] * length) - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO node(hash, node_type, left, right, key, value) VALUES(:hash, :node_type, :left, :right, :key, :value) @@ -58,9 +58,9 @@ async def test_node_hash_must_not_be_null( ) -> None: valid_node_values["hash"] = None - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^NOT NULL constraint failed: node.hash$"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO node(hash, node_type, left, right, key, value) VALUES(:hash, :node_type, :left, :right, :key, :value) @@ -78,9 +78,9 @@ async def test_node_type_must_be_valid( ) -> None: valid_node_values["node_type"] = bad_node_type - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO node(hash, node_type, left, right, key, value) VALUES(:hash, :node_type, :left, :right, :key, :value) @@ -103,9 +103,9 @@ async def test_node_internal_child_not_null(data_store: DataStore, tree_id: byte elif side == Side.RIGHT: values["right"] = None - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO node(hash, node_type, left, right, key, value) VALUES(:hash, :node_type, :left, :right, :key, :value) @@ -136,9 +136,9 @@ async def 
test_node_internal_must_be_valid_reference( else: assert False - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^FOREIGN KEY constraint failed$"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO node(hash, node_type, left, right, key, value) VALUES(:hash, :node_type, :left, :right, :key, :value) @@ -155,9 +155,9 @@ async def test_node_terminal_key_value_not_null(data_store: DataStore, tree_id: values = create_valid_node_values(node_type=NodeType.TERMINAL) values[key_or_value] = None - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO node(hash, node_type, left, right, key, value) VALUES(:hash, :node_type, :left, :right, :key, :value) @@ -173,9 +173,9 @@ async def test_root_tree_id_must_be_32(data_store: DataStore, tree_id: bytes32, bad_tree_id = bytes([0] * length) values = {"tree_id": bad_tree_id, "generation": 0, "node_hash": example.terminal_nodes[0], "status": Status.PENDING} - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO root(tree_id, generation, node_hash, status) VALUES(:tree_id, :generation, :node_hash, :status) @@ -189,9 +189,9 @@ async def test_root_tree_id_must_not_be_null(data_store: DataStore, tree_id: byt example = await add_01234567_example(data_store=data_store, tree_id=tree_id) values = {"tree_id": None, "generation": 0, "node_hash": example.terminal_nodes[0], "status": Status.PENDING} - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^NOT NULL constraint failed: root.tree_id$"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO root(tree_id, generation, node_hash, status) VALUES(:tree_id, :generation, :node_hash, :status) @@ -213,9 +213,9 @@ async def test_root_generation_must_not_be_less_than_zero( "status": Status.PENDING, } - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO root(tree_id, generation, node_hash, status) VALUES(:tree_id, :generation, :node_hash, :status) @@ -234,9 +234,9 @@ async def test_root_generation_must_not_be_null(data_store: DataStore, tree_id: "status": Status.PENDING, } - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^NOT NULL constraint failed: root.generation$"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO root(tree_id, generation, node_hash, status) VALUES(:tree_id, :generation, :node_hash, :status) @@ -249,9 +249,9 @@ async def test_root_generation_must_not_be_null(data_store: DataStore, tree_id: async def test_root_node_hash_must_reference(data_store: DataStore) -> None: values = {"tree_id": bytes32([0] * 32), "generation": 0, "node_hash": bytes32([0] * 32), "status": Status.PENDING} - async with 
data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^FOREIGN KEY constraint failed$"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO root(tree_id, generation, node_hash, status) VALUES(:tree_id, :generation, :node_hash, :status) @@ -271,9 +271,9 @@ async def test_root_status_must_be_valid(data_store: DataStore, tree_id: bytes32 "status": bad_status, } - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO root(tree_id, generation, node_hash, status) VALUES(:tree_id, :generation, :node_hash, :status) @@ -287,9 +287,9 @@ async def test_root_status_must_not_be_null(data_store: DataStore, tree_id: byte example = await add_01234567_example(data_store=data_store, tree_id=tree_id) values = {"tree_id": bytes32([0] * 32), "generation": 0, "node_hash": example.terminal_nodes[0], "status": None} - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^NOT NULL constraint failed: root.status$"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO root(tree_id, generation, node_hash, status) VALUES(:tree_id, :generation, :node_hash, :status) @@ -303,9 +303,9 @@ async def test_root_tree_id_generation_must_be_unique(data_store: DataStore, tre example = await add_01234567_example(data_store=data_store, tree_id=tree_id) values = {"tree_id": tree_id, "generation": 0, "node_hash": example.terminal_nodes[0], "status": Status.COMMITTED} - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, match=r"^UNIQUE constraint failed: root.tree_id, root.generation$"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO root(tree_id, generation, node_hash, status) VALUES(:tree_id, :generation, :node_hash, :status) @@ -321,10 +321,10 @@ async def test_ancestors_ancestor_must_be_32( tree_id: bytes32, length: int, ) -> None: - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: node_hash = await data_store._insert_terminal_node(key=b"\x00", value=b"\x01") with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO ancestors(hash, ancestor, tree_id, generation) VALUES(:hash, :ancestor, :tree_id, :generation) @@ -340,10 +340,10 @@ async def test_ancestors_tree_id_must_be_32( tree_id: bytes32, length: int, ) -> None: - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: node_hash = await data_store._insert_terminal_node(key=b"\x00", value=b"\x01") with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO ancestors(hash, ancestor, tree_id, generation) VALUES(:hash, :ancestor, :tree_id, :generation) @@ -359,9 +359,9 @@ async def test_subscriptions_tree_id_must_be_32( tree_id: bytes32, length: int, ) -> None: - async with data_store.db_wrapper.locked_transaction(): + async with data_store.db_wrapper.writer() as writer: with pytest.raises(sqlite3.IntegrityError, 
match=r"^CHECK constraint failed:"): - await data_store.db.execute( + await writer.execute( """ INSERT INTO subscriptions(tree_id, url, ignore_till, num_consecutive_failures, from_wallet) VALUES(:tree_id, :url, :ignore_till, :num_consecutive_failures, :from_wallet) diff --git a/tests/util/db_connection.py b/tests/util/db_connection.py index b7cbc175c159..5173ee2dcbf7 100644 --- a/tests/util/db_connection.py +++ b/tests/util/db_connection.py @@ -1,8 +1,6 @@ from pathlib import Path from chia.util.db_wrapper import DBWrapper2 -from chia.util.db_wrapper import DBWrapper import tempfile -import aiosqlite class DBConnection: @@ -20,18 +18,3 @@ class DBConnection: async def __aexit__(self, exc_t, exc_v, exc_tb) -> None: await self._db_wrapper.close() self.db_path.unlink() - - -# This is just here until all DBWrappers have been upgraded to DBWrapper2 -class DBConnection1: - async def __aenter__(self) -> DBWrapper: - self.db_path = Path(tempfile.NamedTemporaryFile().name) - if self.db_path.exists(): - self.db_path.unlink() - self._db_connection = await aiosqlite.connect(self.db_path) - self._db_wrapper = DBWrapper(self._db_connection) - return self._db_wrapper - - async def __aexit__(self, exc_t, exc_v, exc_tb) -> None: - await self._db_connection.close() - self.db_path.unlink()
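
Taken together, the migration replaces DBWrapper.locked_transaction with DBWrapper2's writer() and reader() context managers plus an explicit close(). A minimal end-to-end sketch of that pattern; the table name and values are illustrative, while the keyword arguments are the ones added or exercised by this diff:

import asyncio

from chia.util.db_wrapper import DBWrapper2


async def main() -> None:
    db_wrapper = await DBWrapper2.create(
        database="file:example_db?mode=memory&cache=shared",
        uri=True,
        reader_count=2,
        foreign_keys=True,  # keyword added in this diff; defaults to False
    )
    try:
        async with db_wrapper.writer() as writer:
            # writer() wraps the block in a transaction, taking over from the
            # removed DBWrapper.locked_transaction().
            await writer.execute("CREATE TABLE node(hash BLOB PRIMARY KEY)")
            await writer.execute("INSERT INTO node VALUES (:hash)", {"hash": b"\x00" * 32})
        async with db_wrapper.reader() as reader:
            cursor = await reader.execute("SELECT COUNT(*) FROM node")
            row = await cursor.fetchone()
            assert row is not None and row[0] == 1
    finally:
        await db_wrapper.close()


if __name__ == "__main__":
    asyncio.run(main())
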