convert datalayer to DBWrapper2 (#13582)

* convert datalayer to DBWrapper2 (all write)

* more read, less write

* remove unneeded connection managers

* and... close it

* data store now creates its own wrapper

* Drop unused hint DataLayer.batch_update_db_wrapper

* require named arguments for most of `DBWrapper2.create()`
Kyle Altendorf 2022-10-03 18:50:12 -04:00, committed by GitHub
parent 305875d0d4
commit 5b39550f73
9 changed files with 580 additions and 700 deletions
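Taken together, these changes move the data layer off the old single-connection DBWrapper (one shared asyncio.Lock, with `lock=` flags threaded through every call) and onto DBWrapper2, which owns a single write connection plus a pool of read connections. A minimal sketch of the pattern the diffs below adopt; the file name and table are illustrative, but the calls (`create()`, `writer()`, `reader()`, `close()`) are the ones visible in this commit:

    import asyncio

    from chia.util.db_wrapper import DBWrapper2

    async def main() -> None:
        db_wrapper = await DBWrapper2.create(database="data_layer.sqlite", reader_count=4)
        try:
            # All writes funnel through the one writer connection; a nested
            # writer() entered from the same task joins the outer transaction.
            async with db_wrapper.writer() as writer:
                await writer.execute("CREATE TABLE IF NOT EXISTS node(hash BLOB PRIMARY KEY)")
            # Reads borrow one of the pooled reader connections.
            async with db_wrapper.reader() as reader:
                cursor = await reader.execute("SELECT COUNT(*) FROM node")
                print(await cursor.fetchone())
        finally:
            await db_wrapper.close()

    asyncio.run(main())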

View File

@@ -9,7 +9,6 @@ from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, Union
import aiohttp
import aiosqlite
from chia.data_layer.data_layer_errors import KeyNotFoundError
from chia.data_layer.data_layer_util import (
@@ -39,7 +38,6 @@ from chia.server.outbound_message import NodeType
from chia.server.server import ChiaServer
from chia.server.ws_connection import WSChiaConnection
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.db_wrapper import DBWrapper
from chia.util.ints import uint32, uint64
from chia.util.path import path_from_root
from chia.wallet.trade_record import TradeRecord
@@ -49,10 +47,7 @@ from chia.wallet.transaction_record import TransactionRecord
class DataLayer:
data_store: DataStore
db_wrapper: DBWrapper
batch_update_db_wrapper: DBWrapper
db_path: Path
connection: Optional[aiosqlite.Connection]
config: Dict[str, Any]
log: logging.Logger
wallet_rpc_init: Awaitable[WalletRpcClient]
@@ -114,9 +109,7 @@ class DataLayer:
self._server = server
async def _start(self) -> None:
self.connection = await aiosqlite.connect(self.db_path)
self.db_wrapper = DBWrapper(self.connection)
self.data_store = await DataStore.create(self.db_wrapper)
self.data_store = await DataStore.create(database=self.db_path)
self.wallet_rpc = await self.wallet_rpc_init
self.subscription_lock: asyncio.Lock = asyncio.Lock()
@@ -133,6 +126,7 @@ class DataLayer:
self.periodically_manage_data_task.cancel()
except asyncio.CancelledError:
pass
await self.data_store.close()
async def create_store(
self, fee: uint64, root: bytes32 = bytes32([0] * 32)
@@ -157,13 +151,12 @@
self,
tree_id: bytes32,
changelist: List[Dict[str, Any]],
lock: bool = True,
) -> bytes32:
async with self.data_store.transaction(lock=lock):
async with self.data_store.transaction():
# Make sure we update based on the latest confirmed root.
async with self.lock:
await self._update_confirmation_status(tree_id=tree_id, lock=False)
pending_root: Optional[Root] = await self.data_store.get_pending_root(tree_id=tree_id, lock=False)
await self._update_confirmation_status(tree_id=tree_id)
pending_root: Optional[Root] = await self.data_store.get_pending_root(tree_id=tree_id)
if pending_root is not None:
raise Exception("Already have a pending root waiting for confirmation.")
@@ -173,7 +166,7 @@
raise ValueError(f"Singleton with launcher ID {tree_id} is not owned by DL Wallet")
t1 = time.monotonic()
batch_hash = await self.data_store.insert_batch(tree_id, changelist, lock=False)
batch_hash = await self.data_store.insert_batch(tree_id, changelist)
t2 = time.monotonic()
self.log.info(f"Data store batch update process time: {t2 - t1}.")
# todo return empty node hash from get_tree_root
@@ -210,21 +203,18 @@
store_id: bytes32,
key: bytes,
root_hash: Optional[bytes32] = None,
lock: bool = True,
) -> bytes32:
async with self.data_store.transaction(lock=lock):
async with self.data_store.transaction():
async with self.lock:
await self._update_confirmation_status(tree_id=store_id, lock=False)
node = await self.data_store.get_node_by_key(tree_id=store_id, key=key, root_hash=root_hash, lock=False)
await self._update_confirmation_status(tree_id=store_id)
node = await self.data_store.get_node_by_key(tree_id=store_id, key=key, root_hash=root_hash)
return node.hash
async def get_value(
self, store_id: bytes32, key: bytes, root_hash: Optional[bytes32] = None, lock: bool = True
) -> Optional[bytes]:
async with self.data_store.transaction(lock=lock):
async def get_value(self, store_id: bytes32, key: bytes, root_hash: Optional[bytes32] = None) -> Optional[bytes]:
async with self.data_store.transaction():
async with self.lock:
await self._update_confirmation_status(tree_id=store_id, lock=False)
res = await self.data_store.get_node_by_key(tree_id=store_id, key=key, root_hash=root_hash, lock=False)
await self._update_confirmation_status(tree_id=store_id)
res = await self.data_store.get_node_by_key(tree_id=store_id, key=key, root_hash=root_hash)
if res is None:
self.log.error("Failed to fetch key")
return None
@@ -281,10 +271,10 @@ class DataLayer:
prev = record
return root_history
async def _update_confirmation_status(self, tree_id: bytes32, lock: bool = True) -> None:
async with self.data_store.transaction(lock=lock):
async def _update_confirmation_status(self, tree_id: bytes32) -> None:
async with self.data_store.transaction():
try:
root = await self.data_store.get_tree_root(tree_id=tree_id, lock=False)
root = await self.data_store.get_tree_root(tree_id=tree_id)
except asyncio.CancelledError:
raise
except Exception:
@@ -293,11 +283,11 @@
if singleton_record is None:
return
if root is None:
pending_root = await self.data_store.get_pending_root(tree_id=tree_id, lock=False)
pending_root = await self.data_store.get_pending_root(tree_id=tree_id)
if pending_root is not None:
if pending_root.generation == 0 and pending_root.node_hash is None:
await self.data_store.change_root_status(pending_root, Status.COMMITTED, lock=False)
await self.data_store.clear_pending_roots(tree_id=tree_id, lock=False)
await self.data_store.change_root_status(pending_root, Status.COMMITTED)
await self.data_store.clear_pending_roots(tree_id=tree_id)
return
else:
root = None
@@ -324,19 +314,19 @@
generation_shift += 1
new_hashes.pop(0)
if generation_shift > 0:
await self.data_store.clear_pending_roots(tree_id=tree_id, lock=False)
await self.data_store.shift_root_generations(tree_id=tree_id, shift_size=generation_shift, lock=False)
await self.data_store.clear_pending_roots(tree_id=tree_id)
await self.data_store.shift_root_generations(tree_id=tree_id, shift_size=generation_shift)
else:
expected_root_hash = None if new_hashes[0] == self.none_bytes else new_hashes[0]
pending_root = await self.data_store.get_pending_root(tree_id=tree_id, lock=False)
pending_root = await self.data_store.get_pending_root(tree_id=tree_id)
if (
pending_root is not None
and pending_root.generation == root.generation + 1
and pending_root.node_hash == expected_root_hash
):
await self.data_store.change_root_status(pending_root, Status.COMMITTED, lock=False)
await self.data_store.build_ancestor_table_for_latest_root(tree_id=tree_id, lock=False)
await self.data_store.clear_pending_roots(tree_id=tree_id, lock=False)
await self.data_store.change_root_status(pending_root, Status.COMMITTED)
await self.data_store.build_ancestor_table_for_latest_root(tree_id=tree_id)
await self.data_store.clear_pending_roots(tree_id=tree_id)
async def fetch_and_validate(self, tree_id: bytes32) -> None:
singleton_record: Optional[SingletonRecord] = await self.wallet_rpc.dl_latest_singleton(tree_id, True)
@@ -557,13 +547,12 @@ class DataLayer:
self,
store_id: bytes32,
inclusions: Tuple[KeyValue, ...],
lock: bool = True,
) -> List[Dict[str, Any]]:
async with self.data_store.transaction(lock=lock):
async with self.data_store.transaction():
changelist: List[Dict[str, Any]] = []
for entry in inclusions:
try:
existing_value = await self.get_value(store_id=store_id, key=entry.key, lock=False)
existing_value = await self.get_value(store_id=store_id, key=entry.key)
except KeyNotFoundError:
existing_value = None
@@ -590,26 +579,22 @@
return changelist
async def process_offered_stores(
self, offer_stores: Tuple[OfferStore, ...], lock: bool = True
) -> Dict[bytes32, StoreProofs]:
async with self.data_store.transaction(lock=lock):
async def process_offered_stores(self, offer_stores: Tuple[OfferStore, ...]) -> Dict[bytes32, StoreProofs]:
async with self.data_store.transaction():
our_store_proofs: Dict[bytes32, StoreProofs] = {}
for offer_store in offer_stores:
async with self.lock:
await self._update_confirmation_status(tree_id=offer_store.store_id, lock=False)
await self._update_confirmation_status(tree_id=offer_store.store_id)
changelist = await self.build_offer_changelist(
store_id=offer_store.store_id,
inclusions=offer_store.inclusions,
lock=False,
)
if len(changelist) > 0:
new_root_hash = await self.batch_insert(
tree_id=offer_store.store_id,
changelist=changelist,
lock=False,
)
else:
existing_root = await self.get_root(store_id=offer_store.store_id)
@@ -626,13 +611,11 @@ class DataLayer:
store_id=offer_store.store_id,
key=entry.key,
root_hash=new_root_hash,
lock=False,
)
proof_of_inclusion = await self.data_store.get_proof_of_inclusion_by_hash(
node_hash=node_hash,
tree_id=offer_store.store_id,
root_hash=new_root_hash,
lock=False,
)
proof = Proof(
key=entry.key,
@@ -659,7 +642,7 @@
fee: uint64,
) -> Offer:
async with self.data_store.transaction():
our_store_proofs = await self.process_offered_stores(offer_stores=maker, lock=False)
our_store_proofs = await self.process_offered_stores(offer_stores=maker)
offer_dict: Dict[Union[uint32, str], int] = {
**{offer_store.store_id.hex(): -1 for offer_store in maker},
@@ -717,7 +700,7 @@
fee: uint64,
) -> TradeRecord:
async with self.data_store.transaction():
our_store_proofs = await self.process_offered_stores(offer_stores=taker, lock=False)
our_store_proofs = await self.process_offered_stores(offer_stores=taker)
offer = TradingOffer.from_bytes(offer_bytes)
summary = await DataLayerWallet.get_offer_summary(offer=offer)
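Why every `lock=` parameter can simply vanish in the hunks above: `DataStore.transaction()` is now backed by the DBWrapper2 writer (the data_store.py diff itself is suppressed below), and the writer tracks the current task, so re-entering it from the same task nests inside the already-open transaction instead of deadlocking on a shared lock. A sketch of the call shape this allows, using methods from this file (the real call sites additionally hold `self.lock`):

    async def example(data_layer: DataLayer, tree_id: bytes32) -> None:
        async with data_layer.data_store.transaction():
            # _update_confirmation_status opens data_store.transaction() again;
            # with DBWrapper2 that nests, so no lock=False plumbing is needed.
            await data_layer._update_confirmation_status(tree_id=tree_id)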

View File

@@ -12,6 +12,7 @@ from typing_extensions import final
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
from chia.util.db_wrapper import DBWrapper2
from chia.util.ints import uint64
from chia.util.streamable import Streamable, streamable
@@ -42,14 +43,15 @@ def leaf_hash(key: bytes, value: bytes) -> bytes32:
return Program.to((key, value)).get_tree_hash() # type: ignore[no-any-return]
async def _debug_dump(db: aiosqlite.Connection, description: str = "") -> None:
cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table';")
print("-" * 50, description, flush=True)
for [name] in await cursor.fetchall():
cursor = await db.execute(f"SELECT * FROM {name}")
print(f"\n -- {name} ------", flush=True)
async for row in cursor:
print(f" {dict(row)}")
async def _debug_dump(db: DBWrapper2, description: str = "") -> None:
async with db.reader() as reader:
cursor = await reader.execute("SELECT name FROM sqlite_master WHERE type='table';")
print("-" * 50, description, flush=True)
for [name] in await cursor.fetchall():
cursor = await reader.execute(f"SELECT * FROM {name}")
print(f"\n -- {name} ------", flush=True)
async for row in cursor:
print(f" {dict(row)}")
async def _dot_dump(data_store: DataStore, store_id: bytes32, root_hash: bytes32) -> str:
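Callers of `_debug_dump` now hand over the whole wrapper and the function opens its own reader, so usage becomes (as in the updated tests further down):

    await _debug_dump(db=data_store.db_wrapper, description="final")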

File diff suppressed because it is too large

View File

@@ -8,12 +8,9 @@ import time
from pathlib import Path
from typing import Dict, Optional
import aiosqlite
from chia.data_layer.data_layer_util import Side, TerminalNode, leaf_hash
from chia.data_layer.data_store import DataStore
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.db_wrapper import DBWrapper
async def generate_datastore(num_nodes: int, slow_mode: bool) -> None:
@@ -25,93 +22,93 @@ async def generate_datastore(num_nodes: int, slow_mode: bool) -> None:
if os.path.exists(db_path):
os.remove(db_path)
connection = await aiosqlite.connect(db_path)
db_wrapper = DBWrapper(connection)
data_store = await DataStore.create(db_wrapper=db_wrapper)
hint_keys_values: Dict[bytes, bytes] = {}
data_store = await DataStore.create(database=db_path)
try:
hint_keys_values: Dict[bytes, bytes] = {}
tree_id = bytes32(b"0" * 32)
await data_store.create_tree(tree_id)
tree_id = bytes32(b"0" * 32)
await data_store.create_tree(tree_id)
insert_time = 0.0
insert_count = 0
autoinsert_time = 0.0
autoinsert_count = 0
delete_time = 0.0
delete_count = 0
insert_time = 0.0
insert_count = 0
autoinsert_time = 0.0
autoinsert_count = 0
delete_time = 0.0
delete_count = 0
for i in range(num_nodes):
key = i.to_bytes(4, byteorder="big")
value = (2 * i).to_bytes(4, byteorder="big")
seed = leaf_hash(key=key, value=value)
reference_node_hash: Optional[bytes32] = await data_store.get_terminal_node_for_seed(tree_id, seed)
side: Optional[Side] = data_store.get_side_for_seed(seed)
for i in range(num_nodes):
key = i.to_bytes(4, byteorder="big")
value = (2 * i).to_bytes(4, byteorder="big")
seed = leaf_hash(key=key, value=value)
reference_node_hash: Optional[bytes32] = await data_store.get_terminal_node_for_seed(tree_id, seed)
side: Optional[Side] = data_store.get_side_for_seed(seed)
if i == 0:
reference_node_hash = None
side = None
if i % 3 == 0:
t1 = time.time()
if not slow_mode:
await data_store.insert(
key=key,
value=value,
tree_id=tree_id,
reference_node_hash=reference_node_hash,
side=side,
hint_keys_values=hint_keys_values,
)
if i == 0:
reference_node_hash = None
side = None
if i % 3 == 0:
t1 = time.time()
if not slow_mode:
await data_store.insert(
key=key,
value=value,
tree_id=tree_id,
reference_node_hash=reference_node_hash,
side=side,
hint_keys_values=hint_keys_values,
)
else:
await data_store.insert(
key=key,
value=value,
tree_id=tree_id,
reference_node_hash=reference_node_hash,
side=side,
use_optimized=False,
)
t2 = time.time()
insert_time += t2 - t1
insert_count += 1
elif i % 3 == 1:
t1 = time.time()
if not slow_mode:
await data_store.autoinsert(
key=key,
value=value,
tree_id=tree_id,
hint_keys_values=hint_keys_values,
)
else:
await data_store.autoinsert(
key=key,
value=value,
tree_id=tree_id,
use_optimized=False,
)
t2 = time.time()
autoinsert_time += t2 - t1
autoinsert_count += 1
else:
await data_store.insert(
key=key,
value=value,
tree_id=tree_id,
reference_node_hash=reference_node_hash,
side=side,
use_optimized=False,
)
t2 = time.time()
insert_time += t2 - t1
insert_count += 1
elif i % 3 == 1:
t1 = time.time()
if not slow_mode:
await data_store.autoinsert(
key=key,
value=value,
tree_id=tree_id,
hint_keys_values=hint_keys_values,
)
else:
await data_store.autoinsert(
key=key,
value=value,
tree_id=tree_id,
use_optimized=False,
)
t2 = time.time()
autoinsert_time += t2 - t1
autoinsert_count += 1
else:
t1 = time.time()
assert reference_node_hash is not None
node = await data_store.get_node(reference_node_hash)
assert isinstance(node, TerminalNode)
if not slow_mode:
await data_store.delete(key=node.key, tree_id=tree_id, hint_keys_values=hint_keys_values)
else:
await data_store.delete(key=node.key, tree_id=tree_id, use_optimized=False)
t2 = time.time()
delete_time += t2 - t1
delete_count += 1
t1 = time.time()
assert reference_node_hash is not None
node = await data_store.get_node(reference_node_hash)
assert isinstance(node, TerminalNode)
if not slow_mode:
await data_store.delete(key=node.key, tree_id=tree_id, hint_keys_values=hint_keys_values)
else:
await data_store.delete(key=node.key, tree_id=tree_id, use_optimized=False)
t2 = time.time()
delete_time += t2 - t1
delete_count += 1
print(f"Average insert time: {insert_time / insert_count}")
print(f"Average autoinsert time: {autoinsert_time / autoinsert_count}")
print(f"Average delete time: {delete_time / delete_count}")
print(f"Total time for {num_nodes} operations: {insert_time + autoinsert_time + delete_time}")
root = await data_store.get_tree_root(tree_id=tree_id)
print(f"Root hash: {root.node_hash}")
await connection.close()
print(f"Average insert time: {insert_time / insert_count}")
print(f"Average autoinsert time: {autoinsert_time / autoinsert_count}")
print(f"Average delete time: {delete_time / delete_count}")
print(f"Total time for {num_nodes} operations: {insert_time + autoinsert_time + delete_time}")
root = await data_store.get_tree_root(tree_id=tree_id)
print(f"Root hash: {root.node_hash}")
finally:
await data_store.close()
if __name__ == "__main__":

View File

@@ -19,52 +19,6 @@ else:
SQLITE_MAX_VARIABLE_NUMBER = 32700
class DBWrapper:
"""
This object handles HeaderBlocks and Blocks stored in DB used by wallet.
"""
db: aiosqlite.Connection
lock: asyncio.Lock
def __init__(self, connection: aiosqlite.Connection):
self.db = connection
self.lock = asyncio.Lock()
async def begin_transaction(self):
cursor = await self.db.execute("BEGIN TRANSACTION")
await cursor.close()
async def rollback_transaction(self):
# Also rolls back the coin store, since both stores must be updated at once
if self.db.in_transaction:
cursor = await self.db.execute("ROLLBACK")
await cursor.close()
async def commit_transaction(self) -> None:
await self.db.commit()
@contextlib.asynccontextmanager
async def locked_transaction(self, *, lock=True):
# TODO: look into contextvars perhaps instead of this manual lock tracking
if not lock:
yield
return
# TODO: add a lock acquisition timeout
# maybe https://docs.python.org/3/library/asyncio-task.html#asyncio.wait_for
async with self.lock:
await self.begin_transaction()
try:
yield
except BaseException:
await self.rollback_transaction()
raise
else:
await self.commit_transaction()
async def execute_fetchone(
c: aiosqlite.Connection, sql: str, parameters: Iterable[Any] = None
) -> Optional[sqlite3.Row]:
@@ -157,18 +111,25 @@ class DBWrapper2:
async def create(
cls,
database: Union[str, Path],
*,
db_version: int = 1,
uri: bool = False,
reader_count: int = 4,
log_path: Optional[Path] = None,
journal_mode: str = "WAL",
synchronous: Optional[str] = None,
foreign_keys: bool = False,
row_factory: Optional[Type[aiosqlite.Row]] = None,
) -> DBWrapper2:
write_connection = await create_connection(database=database, uri=uri, log_path=log_path, name="writer")
await (await write_connection.execute(f"pragma journal_mode={journal_mode}")).close()
if synchronous is not None:
await (await write_connection.execute(f"pragma synchronous={synchronous}")).close()
await (await write_connection.execute(f"pragma foreign_keys={'ON' if foreign_keys else 'OFF'}")).close()
write_connection.row_factory = row_factory
self = cls(connection=write_connection, db_version=db_version)
for index in range(reader_count):
@@ -178,6 +139,7 @@
log_path=log_path,
name=f"reader-{index}",
)
read_connection.row_factory = row_factory
await self.add_connection(c=read_connection)
return self
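The bare `*` makes every parameter after `database` keyword-only, so a positional call such as `DBWrapper2.create(path, 2)` now raises a TypeError instead of silently binding `2` to `db_version`. A sketch of a call exercising the new parameters (the values are illustrative; `db_path` is any Path or str):

    db_wrapper = await DBWrapper2.create(
        database=db_path,
        reader_count=4,
        journal_mode="WAL",
        foreign_keys=True,          # emits pragma foreign_keys=ON on the writer
        row_factory=aiosqlite.Row,  # applied to the writer and to each reader
    )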

View File

@@ -3,13 +3,13 @@ from __future__ import annotations
import contextlib
import os
import pathlib
import random
import subprocess
import sys
import sysconfig
import time
from typing import Any, AsyncIterable, Awaitable, Callable, Dict, Iterator, List
import aiosqlite
import pytest
import pytest_asyncio
@@ -19,7 +19,6 @@ from _pytest.fixtures import SubRequest
from chia.data_layer.data_layer_util import NodeType, Status
from chia.data_layer.data_store import DataStore
from chia.types.blockchain_format.tree_hash import bytes32
from chia.util.db_wrapper import DBWrapper
from tests.core.data_layer.util import (
ChiaRoot,
Example,
@@ -90,17 +89,9 @@ def create_example_fixture(request: SubRequest) -> Callable[[DataStore, bytes32]
return request.param # type: ignore[no-any-return]
@pytest_asyncio.fixture(name="db_connection", scope="function")
async def db_connection_fixture() -> AsyncIterable[aiosqlite.Connection]:
async with aiosqlite.connect(":memory:") as connection:
# make sure this is on for tests even if we disable it at run time
await connection.execute("PRAGMA foreign_keys = ON")
yield connection
@pytest.fixture(name="db_wrapper", scope="function")
def db_wrapper_fixture(db_connection: aiosqlite.Connection) -> DBWrapper:
return DBWrapper(db_connection)
@pytest.fixture(name="database_uri")
def database_uri_fixture() -> str:
return f"file:db_{random.randint(0, 99999999)}?mode=memory&cache=shared"
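The `mode=memory&cache=shared` URI is what makes in-memory databases usable with DBWrapper2: the wrapper opens several connections (one writer plus `reader_count` readers), and only shared-cache mode lets them all see the same database, while the random suffix keeps parallel tests from colliding. The effect, illustrated with bare sqlite3 and a hypothetical name:

    import sqlite3

    uri = "file:db_example?mode=memory&cache=shared"
    a = sqlite3.connect(uri, uri=True)
    b = sqlite3.connect(uri, uri=True)
    a.execute("CREATE TABLE t(x)")
    b.execute("SELECT * FROM t")  # works: both connections share the one database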
@pytest.fixture(name="tree_id", scope="function")
@@ -111,8 +102,10 @@ def tree_id_fixture() -> bytes32:
@pytest_asyncio.fixture(name="raw_data_store", scope="function")
async def raw_data_store_fixture(db_wrapper: DBWrapper) -> DataStore:
return await DataStore.create(db_wrapper=db_wrapper)
async def raw_data_store_fixture(database_uri: str) -> AsyncIterable[DataStore]:
store = await DataStore.create(database=database_uri, uri=True)
yield store
await store.close()
@pytest_asyncio.fixture(name="data_store", scope="function")

View File

@@ -7,7 +7,6 @@ from pathlib import Path
from random import Random
from typing import Any, Awaitable, Callable, Dict, List, Set, Tuple
import aiosqlite
import pytest
from chia.data_layer.data_layer_errors import NodeHashError, TreeGenerationIncrementingError
@@ -38,7 +37,7 @@ from chia.data_layer.download_data import (
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.tree_hash import bytes32
from chia.util.byte_types import hexstr_to_bytes
from chia.util.db_wrapper import DBWrapper
from chia.util.db_wrapper import DBWrapper2
from tests.core.data_layer.util import Example, add_0123_example, add_01234567_example
log = logging.getLogger(__name__)
@@ -59,8 +58,8 @@ table_columns: Dict[str, List[str]] = {
@pytest.mark.asyncio
async def test_valid_node_values_fixture_are_valid(data_store: DataStore, valid_node_values: Dict[str, Any]) -> None:
async with data_store.db_wrapper.locked_transaction():
await data_store.db.execute(
async with data_store.db_wrapper.writer() as writer:
await writer.execute(
"""
INSERT INTO node(hash, node_type, left, right, key, value)
VALUES(:hash, :node_type, :left, :right, :key, :value)
@@ -72,20 +71,29 @@ async def test_valid_node_values_fixture_are_valid(data_store: DataStore, valid_
@pytest.mark.parametrize(argnames=["table_name", "expected_columns"], argvalues=table_columns.items())
@pytest.mark.asyncio
async def test_create_creates_tables_and_columns(
db_wrapper: DBWrapper, table_name: str, expected_columns: List[str]
database_uri: str, table_name: str, expected_columns: List[str]
) -> None:
# Never string-interpolate sql queries... Except maybe in tests when it does not
# allow you to parametrize the query.
query = f"pragma table_info({table_name});"
cursor = await db_wrapper.db.execute(query)
columns = await cursor.fetchall()
assert columns == []
db_wrapper = await DBWrapper2.create(database=database_uri, uri=True, reader_count=1)
try:
async with db_wrapper.reader() as reader:
cursor = await reader.execute(query)
columns = await cursor.fetchall()
assert columns == []
await DataStore.create(db_wrapper=db_wrapper)
cursor = await db_wrapper.db.execute(query)
columns = await cursor.fetchall()
assert [column[1] for column in columns] == expected_columns
store = await DataStore.create(database=database_uri, uri=True)
try:
async with db_wrapper.reader() as reader:
cursor = await reader.execute(query)
columns = await cursor.fetchall()
assert [column[1] for column in columns] == expected_columns
finally:
await store.close()
finally:
await db_wrapper.close()
@pytest.mark.asyncio
@@ -195,15 +203,14 @@ async def test_insert_internal_node_does_nothing_if_matching(data_store: DataSto
ancestors = await data_store.get_ancestors(node_hash=kv_node.hash, tree_id=tree_id)
parent = ancestors[0]
async with data_store.db_wrapper.locked_transaction():
cursor = await data_store.db.execute("SELECT * FROM node")
async with data_store.db_wrapper.reader() as reader:
cursor = await reader.execute("SELECT * FROM node")
before = await cursor.fetchall()
async with data_store.db_wrapper.locked_transaction():
await data_store._insert_internal_node(left_hash=parent.left_hash, right_hash=parent.right_hash)
await data_store._insert_internal_node(left_hash=parent.left_hash, right_hash=parent.right_hash)
async with data_store.db_wrapper.locked_transaction():
cursor = await data_store.db.execute("SELECT * FROM node")
async with data_store.db_wrapper.reader() as reader:
cursor = await reader.execute("SELECT * FROM node")
after = await cursor.fetchall()
assert after == before
@@ -215,15 +222,14 @@ async def test_insert_terminal_node_does_nothing_if_matching(data_store: DataSto
kv_node = await data_store.get_node_by_key(key=b"\x04", tree_id=tree_id)
async with data_store.db_wrapper.locked_transaction():
cursor = await data_store.db.execute("SELECT * FROM node")
async with data_store.db_wrapper.reader() as reader:
cursor = await reader.execute("SELECT * FROM node")
before = await cursor.fetchall()
async with data_store.db_wrapper.locked_transaction():
await data_store._insert_terminal_node(key=kv_node.key, value=kv_node.value)
await data_store._insert_terminal_node(key=kv_node.key, value=kv_node.value)
async with data_store.db_wrapper.locked_transaction():
cursor = await data_store.db.execute("SELECT * FROM node")
async with data_store.db_wrapper.reader() as reader:
cursor = await reader.execute("SELECT * FROM node")
after = await cursor.fetchall()
assert after == before
@@ -237,7 +243,7 @@ async def test_build_a_tree(
) -> None:
example = await create_example(data_store, tree_id)
await _debug_dump(db=data_store.db, description="final")
await _debug_dump(db=data_store.db_wrapper, description="final")
actual = await data_store.get_tree_as_program(tree_id=tree_id)
# print("actual ", actual.as_python())
# print("expected", example.expected.as_python())
@@ -361,54 +367,52 @@ async def test_batch_update(data_store: DataStore, tree_id: bytes32, use_optimiz
db_path = tmp_path.joinpath("dl_server_util.sqlite")
connection = await aiosqlite.connect(db_path)
db_wrapper = DBWrapper(connection)
single_op_data_store = await DataStore.create(db_wrapper=db_wrapper)
single_op_data_store = await DataStore.create(database=db_path)
try:
await single_op_data_store.create_tree(tree_id, status=Status.COMMITTED)
random = Random()
random.seed(100, version=2)
await single_op_data_store.create_tree(tree_id, status=Status.COMMITTED)
random = Random()
random.seed(100, version=2)
batch: List[Dict[str, Any]] = []
keys: List[bytes] = []
hint_keys_values: Dict[bytes, bytes] = {}
for operation in range(num_batches * num_ops_per_batch):
if random.randint(0, 4) > 0 or len(keys) == 0:
key = operation.to_bytes(4, byteorder="big")
value = (2 * operation).to_bytes(4, byteorder="big")
if use_optimized:
await single_op_data_store.autoinsert(
key=key,
value=value,
tree_id=tree_id,
hint_keys_values=hint_keys_values,
status=Status.COMMITTED,
)
batch: List[Dict[str, Any]] = []
keys: List[bytes] = []
hint_keys_values: Dict[bytes, bytes] = {}
for operation in range(num_batches * num_ops_per_batch):
if random.randint(0, 4) > 0 or len(keys) == 0:
key = operation.to_bytes(4, byteorder="big")
value = (2 * operation).to_bytes(4, byteorder="big")
if use_optimized:
await single_op_data_store.autoinsert(
key=key,
value=value,
tree_id=tree_id,
hint_keys_values=hint_keys_values,
status=Status.COMMITTED,
)
else:
await single_op_data_store.autoinsert(
key=key, value=value, tree_id=tree_id, use_optimized=False, status=Status.COMMITTED
)
batch.append({"action": "insert", "key": key, "value": value})
keys.append(key)
else:
await single_op_data_store.autoinsert(
key=key, value=value, tree_id=tree_id, use_optimized=False, status=Status.COMMITTED
)
batch.append({"action": "insert", "key": key, "value": value})
keys.append(key)
else:
key = random.choice(keys)
keys.remove(key)
if use_optimized:
await single_op_data_store.delete(
key=key, tree_id=tree_id, hint_keys_values=hint_keys_values, status=Status.COMMITTED
)
else:
await single_op_data_store.delete(
key=key, tree_id=tree_id, use_optimized=False, status=Status.COMMITTED
)
batch.append({"action": "delete", "key": key})
if (operation + 1) % num_ops_per_batch == 0:
saved_batches.append(batch)
batch = []
root = await single_op_data_store.get_tree_root(tree_id=tree_id)
saved_roots.append(root)
await connection.close()
key = random.choice(keys)
keys.remove(key)
if use_optimized:
await single_op_data_store.delete(
key=key, tree_id=tree_id, hint_keys_values=hint_keys_values, status=Status.COMMITTED
)
else:
await single_op_data_store.delete(
key=key, tree_id=tree_id, use_optimized=False, status=Status.COMMITTED
)
batch.append({"action": "delete", "key": key})
if (operation + 1) % num_ops_per_batch == 0:
saved_batches.append(batch)
batch = []
root = await single_op_data_store.get_tree_root(tree_id=tree_id)
saved_roots.append(root)
finally:
await single_op_data_store.close()
for batch_number, batch in enumerate(saved_batches):
assert len(batch) == num_ops_per_batch
@@ -745,7 +749,7 @@ async def test_proof_of_inclusion_by_hash(data_store: DataStore, tree_id: bytes3
proof = await data_store.get_proof_of_inclusion_by_hash(node_hash=node.hash, tree_id=tree_id)
print(node)
await _debug_dump(db=data_store.db)
await _debug_dump(db=data_store.db_wrapper)
expected_layers = [
ProofOfInclusionLayer(
@@ -862,9 +866,9 @@ invalid_program_hex = b"\xab\xcd".hex()
async def test_check_roots_are_incrementing_missing_zero(raw_data_store: DataStore) -> None:
tree_id = hexstr_to_bytes("c954ab71ffaf5b0f129b04b35fdc7c84541f4375167e730e2646bfcfdb7cf2cd")
async with raw_data_store.db_wrapper.locked_transaction():
async with raw_data_store.db_wrapper.writer() as writer:
for generation in range(1, 5):
await raw_data_store.db.execute(
await writer.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
@@ -888,9 +892,9 @@ async def test_check_roots_are_incrementing_missing_zero(raw_data_store: DataSto
async def test_check_roots_are_incrementing_gap(raw_data_store: DataStore) -> None:
tree_id = hexstr_to_bytes("c954ab71ffaf5b0f129b04b35fdc7c84541f4375167e730e2646bfcfdb7cf2cd")
async with raw_data_store.db_wrapper.locked_transaction():
async with raw_data_store.db_wrapper.writer() as writer:
for generation in [*range(5), *range(6, 10)]:
await raw_data_store.db.execute(
await writer.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
@@ -912,8 +916,8 @@ async def test_check_roots_are_incrementing_gap(raw_data_store: DataStore) -> No
@pytest.mark.asyncio
async def test_check_hashes_internal(raw_data_store: DataStore) -> None:
async with raw_data_store.db_wrapper.locked_transaction():
await raw_data_store.db.execute(
async with raw_data_store.db_wrapper.writer() as writer:
await writer.execute(
"INSERT INTO node(hash, node_type, left, right) VALUES(:hash, :node_type, :left, :right)",
{
"hash": a_bytes_32,
@@ -932,8 +936,8 @@ async def test_check_hashes_internal(raw_data_store: DataStore) -> None:
@pytest.mark.asyncio
async def test_check_hashes_terminal(raw_data_store: DataStore) -> None:
async with raw_data_store.db_wrapper.locked_transaction():
await raw_data_store.db.execute(
async with raw_data_store.db_wrapper.writer() as writer:
await writer.execute(
"INSERT INTO node(hash, node_type, key, value) VALUES(:hash, :node_type, :key, :value)",
{
"hash": a_bytes_32,
@@ -1178,34 +1182,34 @@ async def test_data_server_files(data_store: DataStore, tree_id: bytes32, test_d
db_path = tmp_path.joinpath("dl_server_util.sqlite")
connection = await aiosqlite.connect(db_path)
db_wrapper = DBWrapper(connection)
data_store_server = await DataStore.create(db_wrapper=db_wrapper)
await data_store_server.create_tree(tree_id, status=Status.COMMITTED)
random = Random()
random.seed(100, version=2)
data_store_server = await DataStore.create(database=db_path)
try:
await data_store_server.create_tree(tree_id, status=Status.COMMITTED)
random = Random()
random.seed(100, version=2)
keys: List[bytes] = []
counter = 0
keys: List[bytes] = []
counter = 0
for batch in range(num_batches):
changelist: List[Dict[str, Any]] = []
for operation in range(num_ops_per_batch):
if random.randint(0, 4) > 0 or len(keys) == 0:
key = counter.to_bytes(4, byteorder="big")
value = (2 * counter).to_bytes(4, byteorder="big")
keys.append(key)
changelist.append({"action": "insert", "key": key, "value": value})
else:
key = random.choice(keys)
keys.remove(key)
changelist.append({"action": "delete", "key": key})
counter += 1
await data_store_server.insert_batch(tree_id, changelist, status=Status.COMMITTED)
root = await data_store_server.get_tree_root(tree_id)
await write_files_for_root(data_store_server, tree_id, root, tmp_path)
roots.append(root)
await connection.close()
for batch in range(num_batches):
changelist: List[Dict[str, Any]] = []
for operation in range(num_ops_per_batch):
if random.randint(0, 4) > 0 or len(keys) == 0:
key = counter.to_bytes(4, byteorder="big")
value = (2 * counter).to_bytes(4, byteorder="big")
keys.append(key)
changelist.append({"action": "insert", "key": key, "value": value})
else:
key = random.choice(keys)
keys.remove(key)
changelist.append({"action": "delete", "key": key})
counter += 1
await data_store_server.insert_batch(tree_id, changelist, status=Status.COMMITTED)
root = await data_store_server.get_tree_root(tree_id)
await write_files_for_root(data_store_server, tree_id, root, tmp_path)
roots.append(root)
finally:
await data_store_server.close()
generation = 1
assert len(roots) == num_batches

View File

@@ -18,9 +18,9 @@ async def test_node_update_fails(data_store: DataStore, tree_id: bytes32) -> Non
await add_01234567_example(data_store=data_store, tree_id=tree_id)
node = await data_store.get_node_by_key(key=b"\x04", tree_id=tree_id)
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^updates not allowed to the node table$"):
await data_store.db.execute(
await writer.execute(
"UPDATE node SET value = :value WHERE hash == :hash",
{
"hash": node.hash,
@@ -39,9 +39,9 @@ async def test_node_hash_must_be_32(
) -> None:
valid_node_values["hash"] = bytes([0] * length)
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO node(hash, node_type, left, right, key, value)
VALUES(:hash, :node_type, :left, :right, :key, :value)
@@ -58,9 +58,9 @@ async def test_node_hash_must_not_be_null(
) -> None:
valid_node_values["hash"] = None
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^NOT NULL constraint failed: node.hash$"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO node(hash, node_type, left, right, key, value)
VALUES(:hash, :node_type, :left, :right, :key, :value)
@@ -78,9 +78,9 @@ async def test_node_type_must_be_valid(
) -> None:
valid_node_values["node_type"] = bad_node_type
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO node(hash, node_type, left, right, key, value)
VALUES(:hash, :node_type, :left, :right, :key, :value)
@@ -103,9 +103,9 @@ async def test_node_internal_child_not_null(data_store: DataStore, tree_id: byte
elif side == Side.RIGHT:
values["right"] = None
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO node(hash, node_type, left, right, key, value)
VALUES(:hash, :node_type, :left, :right, :key, :value)
@@ -136,9 +136,9 @@ async def test_node_internal_must_be_valid_reference(
else:
assert False
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^FOREIGN KEY constraint failed$"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO node(hash, node_type, left, right, key, value)
VALUES(:hash, :node_type, :left, :right, :key, :value)
@@ -155,9 +155,9 @@ async def test_node_terminal_key_value_not_null(data_store: DataStore, tree_id:
values = create_valid_node_values(node_type=NodeType.TERMINAL)
values[key_or_value] = None
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO node(hash, node_type, left, right, key, value)
VALUES(:hash, :node_type, :left, :right, :key, :value)
@@ -173,9 +173,9 @@ async def test_root_tree_id_must_be_32(data_store: DataStore, tree_id: bytes32,
bad_tree_id = bytes([0] * length)
values = {"tree_id": bad_tree_id, "generation": 0, "node_hash": example.terminal_nodes[0], "status": Status.PENDING}
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
@@ -189,9 +189,9 @@ async def test_root_tree_id_must_not_be_null(data_store: DataStore, tree_id: byt
example = await add_01234567_example(data_store=data_store, tree_id=tree_id)
values = {"tree_id": None, "generation": 0, "node_hash": example.terminal_nodes[0], "status": Status.PENDING}
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^NOT NULL constraint failed: root.tree_id$"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
@@ -213,9 +213,9 @@ async def test_root_generation_must_not_be_less_than_zero(
"status": Status.PENDING,
}
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
@@ -234,9 +234,9 @@ async def test_root_generation_must_not_be_null(data_store: DataStore, tree_id:
"status": Status.PENDING,
}
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^NOT NULL constraint failed: root.generation$"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
@@ -249,9 +249,9 @@ async def test_root_generation_must_not_be_null(data_store: DataStore, tree_id:
async def test_root_node_hash_must_reference(data_store: DataStore) -> None:
values = {"tree_id": bytes32([0] * 32), "generation": 0, "node_hash": bytes32([0] * 32), "status": Status.PENDING}
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^FOREIGN KEY constraint failed$"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
@@ -271,9 +271,9 @@ async def test_root_status_must_be_valid(data_store: DataStore, tree_id: bytes32
"status": bad_status,
}
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
@@ -287,9 +287,9 @@ async def test_root_status_must_not_be_null(data_store: DataStore, tree_id: byte
example = await add_01234567_example(data_store=data_store, tree_id=tree_id)
values = {"tree_id": bytes32([0] * 32), "generation": 0, "node_hash": example.terminal_nodes[0], "status": None}
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^NOT NULL constraint failed: root.status$"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
@@ -303,9 +303,9 @@ async def test_root_tree_id_generation_must_be_unique(data_store: DataStore, tre
example = await add_01234567_example(data_store=data_store, tree_id=tree_id)
values = {"tree_id": tree_id, "generation": 0, "node_hash": example.terminal_nodes[0], "status": Status.COMMITTED}
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^UNIQUE constraint failed: root.tree_id, root.generation$"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
@@ -321,10 +321,10 @@ async def test_ancestors_ancestor_must_be_32(
tree_id: bytes32,
length: int,
) -> None:
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
node_hash = await data_store._insert_terminal_node(key=b"\x00", value=b"\x01")
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO ancestors(hash, ancestor, tree_id, generation)
VALUES(:hash, :ancestor, :tree_id, :generation)
@@ -340,10 +340,10 @@ async def test_ancestors_tree_id_must_be_32(
tree_id: bytes32,
length: int,
) -> None:
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
node_hash = await data_store._insert_terminal_node(key=b"\x00", value=b"\x01")
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO ancestors(hash, ancestor, tree_id, generation)
VALUES(:hash, :ancestor, :tree_id, :generation)
@@ -359,9 +359,9 @@ async def test_subscriptions_tree_id_must_be_32(
tree_id: bytes32,
length: int,
) -> None:
async with data_store.db_wrapper.locked_transaction():
async with data_store.db_wrapper.writer() as writer:
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
await writer.execute(
"""
INSERT INTO subscriptions(tree_id, url, ignore_till, num_consecutive_failures, from_wallet)
VALUES(:tree_id, :url, :ignore_till, :num_consecutive_failures, :from_wallet)

View File

@@ -1,8 +1,6 @@
from pathlib import Path
from chia.util.db_wrapper import DBWrapper2
from chia.util.db_wrapper import DBWrapper
import tempfile
import aiosqlite
class DBConnection:
@@ -20,18 +18,3 @@ class DBConnection:
async def __aexit__(self, exc_t, exc_v, exc_tb) -> None:
await self._db_wrapper.close()
self.db_path.unlink()
# This is just here until all DBWrappers have been upgraded to DBWrapper2
class DBConnection1:
async def __aenter__(self) -> DBWrapper:
self.db_path = Path(tempfile.NamedTemporaryFile().name)
if self.db_path.exists():
self.db_path.unlink()
self._db_connection = await aiosqlite.connect(self.db_path)
self._db_wrapper = DBWrapper(self._db_connection)
return self._db_wrapper
async def __aexit__(self, exc_t, exc_v, exc_tb) -> None:
await self._db_connection.close()
self.db_path.unlink()
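For completeness, the surviving `DBConnection` helper yields a ready DBWrapper2 and tears it down in `__aexit__` as shown above. A hedged usage sketch; the `2` mirrors how other chia tests pass a db_version to this helper, since its `__init__` and `__aenter__` are not shown in this hunk:

    async with DBConnection(2) as db_wrapper:
        async with db_wrapper.writer() as conn:
            await conn.execute("CREATE TABLE t(x)")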