Merge commit '7c5a4b2199cfb69b1974098bca0ff3f11c78bdf2' into checkpoint/long_lived_atari_from_main_7c5a4b2199cfb69b1974098bca0ff3f11c78bdf2

This commit is contained in:
Amine Khaldi 2022-07-31 20:56:22 +01:00
commit 6fd10bbfef
No known key found for this signature in database
GPG Key ID: B1C074FFC904E2D9
4 changed files with 66 additions and 71 deletions

View File

@ -180,7 +180,7 @@ class PoolWallet:
raise ValueError(f"Invalid internal Pool State: {err}: {initial_target_state}")
async def get_spend_history(self) -> List[Tuple[uint32, CoinSpend]]:
return self.wallet_state_manager.pool_store.get_spends_for_wallet(self.wallet_id)
return await self.wallet_state_manager.pool_store.get_spends_for_wallet(self.wallet_id)
async def get_current_state(self) -> PoolWalletInfo:
history: List[Tuple[uint32, CoinSpend]] = await self.get_spend_history()
@ -228,7 +228,7 @@ class PoolWallet:
return await self.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(self.wallet_id)
async def get_tip(self) -> Tuple[uint32, CoinSpend]:
return self.wallet_state_manager.pool_store.get_spends_for_wallet(self.wallet_id)[-1]
return (await self.wallet_state_manager.pool_store.get_spends_for_wallet(self.wallet_id))[-1]
async def update_pool_config(self) -> None:
current_state: PoolWalletInfo = await self.get_current_state()
@ -286,7 +286,9 @@ class PoolWallet:
self.log.info(f"New PoolWallet singleton tip_coin: {tip_spend} farmed at height {block_height}")
# If we have reached the target state, resets it to None. Loops back to get current state
for _, added_spend in reversed(self.wallet_state_manager.pool_store.get_spends_for_wallet(self.wallet_id)):
for _, added_spend in reversed(
await self.wallet_state_manager.pool_store.get_spends_for_wallet(self.wallet_id)
):
latest_state: Optional[PoolState] = solution_to_pool_state(added_spend)
if latest_state is not None:
if self.target_state == latest_state:
@ -303,9 +305,9 @@ class PoolWallet:
Returns True if the wallet should be removed.
"""
try:
history: List[Tuple[uint32, CoinSpend]] = self.wallet_state_manager.pool_store.get_spends_for_wallet(
history: List[Tuple[uint32, CoinSpend]] = await self.wallet_state_manager.pool_store.get_spends_for_wallet(
self.wallet_id
).copy()
)
prev_state: PoolWalletInfo = await self.get_current_state()
await self.wallet_state_manager.pool_store.rollback(block_height, self.wallet_id, in_transaction)

View File

@ -512,7 +512,6 @@ class WalletNode:
tb = traceback.format_exc()
self.log.error(f"Exception while perform_atomic_rollback: {e} {tb}")
await self.wallet_state_manager.db_wrapper.rollback_transaction()
await self.wallet_state_manager.pool_store.rebuild_cache()
raise
else:
await self.wallet_state_manager.blockchain.clean_block_records()
@ -713,7 +712,6 @@ class WalletNode:
tb = traceback.format_exc()
self.log.error(f"Exception while adding state: {e} {tb}")
await self.wallet_state_manager.db_wrapper.rollback_transaction()
await self.wallet_state_manager.pool_store.rebuild_cache()
else:
await self.wallet_state_manager.blockchain.clean_block_records()
@ -748,7 +746,6 @@ class WalletNode:
await self.wallet_state_manager.db_wrapper.commit_transaction()
except Exception as e:
await self.wallet_state_manager.db_wrapper.rollback_transaction()
await self.wallet_state_manager.pool_store.rebuild_cache()
tb = traceback.format_exc()
self.log.error(f"Error adding states.. {e} {tb}")
return False

View File

@ -1,5 +1,5 @@
import logging
from typing import List, Tuple, Dict, Optional
from typing import List, Tuple
import aiosqlite
@ -13,7 +13,6 @@ log = logging.getLogger(__name__)
class WalletPoolStore:
db_connection: aiosqlite.Connection
db_wrapper: DBWrapper
_state_transitions_cache: Dict[int, List[Tuple[uint32, CoinSpend]]]
@classmethod
async def create(cls, wrapper: DBWrapper):
@ -23,11 +22,15 @@ class WalletPoolStore:
self.db_wrapper = wrapper
await self.db_connection.execute(
"CREATE TABLE IF NOT EXISTS pool_state_transitions(transition_index integer, wallet_id integer, "
"height bigint, coin_spend blob, PRIMARY KEY(transition_index, wallet_id))"
"CREATE TABLE IF NOT EXISTS pool_state_transitions("
" transition_index integer,"
" wallet_id integer,"
" height bigint,"
" coin_spend blob,"
" PRIMARY KEY(transition_index, wallet_id))"
)
await self.db_connection.commit()
await self.rebuild_cache()
return self
async def _clear_database(self):
@ -51,28 +54,48 @@ class WalletPoolStore:
if not in_transaction:
await self.db_wrapper.lock.acquire()
try:
if wallet_id not in self._state_transitions_cache:
self._state_transitions_cache[wallet_id] = []
all_state_transitions: List[Tuple[uint32, CoinSpend]] = self.get_spends_for_wallet(wallet_id)
# find the most recent transition in wallet_id
rows = list(
await self.db_connection.execute_fetchall(
"SELECT transition_index, height, coin_spend "
"FROM pool_state_transitions "
"WHERE wallet_id=? "
"ORDER BY transition_index DESC "
"LIMIT 1",
(wallet_id,),
)
)
serialized_spend = bytes(spend)
if len(rows) == 0:
transition_index = 0
else:
existing = list(
await self.db_connection.execute_fetchall(
"SELECT COUNT(*) "
"FROM pool_state_transitions "
"WHERE wallet_id=? AND height=? AND coin_spend=?",
(wallet_id, height, serialized_spend),
)
)
if existing[0][0] != 0:
# we already have this transition in the DB
return
if (height, spend) in all_state_transitions:
return
if len(all_state_transitions) > 0:
if height < all_state_transitions[-1][0]:
row = rows[0]
if height < row[1]:
raise ValueError("Height cannot go down")
if spend.coin.parent_coin_info != all_state_transitions[-1][1].coin.name():
prev = CoinSpend.from_bytes(row[2])
if spend.coin.parent_coin_info != prev.coin.name():
raise ValueError("New spend does not extend")
all_state_transitions.append((height, spend))
transition_index = row[0]
cursor = await self.db_connection.execute(
"INSERT OR REPLACE INTO pool_state_transitions VALUES (?, ?, ?, ?)",
"INSERT OR IGNORE INTO pool_state_transitions VALUES (?, ?, ?, ?)",
(
len(all_state_transitions) - 1,
transition_index + 1,
wallet_id,
height,
bytes(spend),
serialized_spend,
),
)
await cursor.close()
@ -81,27 +104,16 @@ class WalletPoolStore:
await self.db_connection.commit()
self.db_wrapper.lock.release()
def get_spends_for_wallet(self, wallet_id: int) -> List[Tuple[uint32, CoinSpend]]:
async def get_spends_for_wallet(self, wallet_id: int) -> List[Tuple[uint32, CoinSpend]]:
"""
Retrieves all entries for a wallet ID from the cache, works even if commit is not called yet.
Retrieves all entries for a wallet ID.
"""
return self._state_transitions_cache.get(wallet_id, [])
async def rebuild_cache(self) -> None:
"""
This resets the cache, and loads all entries from the DB. Any entries in the cache that were not committed
are removed. This can happen if a state transition in wallet_blockchain fails.
"""
cursor = await self.db_connection.execute("SELECT * FROM pool_state_transitions ORDER BY transition_index")
rows = await cursor.fetchall()
await cursor.close()
self._state_transitions_cache = {}
for row in rows:
_, wallet_id, height, coin_spend_bytes = row
coin_spend: CoinSpend = CoinSpend.from_bytes(coin_spend_bytes)
if wallet_id not in self._state_transitions_cache:
self._state_transitions_cache[wallet_id] = []
self._state_transitions_cache[wallet_id].append((height, coin_spend))
rows = await self.db_connection.execute_fetchall(
"SELECT height, coin_spend FROM pool_state_transitions WHERE wallet_id=? ORDER BY transition_index",
(wallet_id,),
)
return [(uint32(row[0]), CoinSpend.from_bytes(row[1])) for row in rows]
async def rollback(self, height: int, wallet_id_arg: int, in_transaction: bool) -> None:
"""
@ -113,14 +125,6 @@ class WalletPoolStore:
if not in_transaction:
await self.db_wrapper.lock.acquire()
try:
for wallet_id, items in self._state_transitions_cache.items():
remove_index_start: Optional[int] = None
for i, (item_block_height, _) in enumerate(items):
if item_block_height > height and wallet_id == wallet_id_arg:
remove_index_start = i
break
if remove_index_start is not None:
del items[remove_index_start:]
cursor = await self.db_connection.execute(
"DELETE FROM pool_state_transitions WHERE height>? AND wallet_id=?", (height, wallet_id_arg)
)

View File

@ -51,38 +51,35 @@ class TestWalletPoolStore:
solution_0_alt: CoinSpend = make_child_solution(None, coin_0_alt)
solution_1: CoinSpend = make_child_solution(solution_0)
assert store.get_spends_for_wallet(0) == []
assert store.get_spends_for_wallet(1) == []
assert await store.get_spends_for_wallet(0) == []
assert await store.get_spends_for_wallet(1) == []
await store.add_spend(1, solution_1, 100, True)
assert store.get_spends_for_wallet(1) == [(100, solution_1)]
assert await store.get_spends_for_wallet(1) == [(100, solution_1)]
# Idempotent
await store.add_spend(1, solution_1, 100, True)
assert store.get_spends_for_wallet(1) == [(100, solution_1)]
assert await store.get_spends_for_wallet(1) == [(100, solution_1)]
with pytest.raises(ValueError):
await store.add_spend(1, solution_1, 101, True)
# Rebuild cache, no longer present
await db_wrapper.rollback_transaction()
await store.rebuild_cache()
assert store.get_spends_for_wallet(1) == []
assert await store.get_spends_for_wallet(1) == []
await store.rebuild_cache()
await store.add_spend(1, solution_1, 100, False)
assert store.get_spends_for_wallet(1) == [(100, solution_1)]
assert await store.get_spends_for_wallet(1) == [(100, solution_1)]
solution_1_alt: CoinSpend = make_child_solution(solution_0_alt)
with pytest.raises(ValueError):
await store.add_spend(1, solution_1_alt, 100, False)
assert store.get_spends_for_wallet(1) == [(100, solution_1)]
assert await store.get_spends_for_wallet(1) == [(100, solution_1)]
solution_2: CoinSpend = make_child_solution(solution_1)
await store.add_spend(1, solution_2, 100, False)
await store.rebuild_cache()
solution_3: CoinSpend = make_child_solution(solution_2)
await store.add_spend(1, solution_3, 100)
solution_4: CoinSpend = make_child_solution(solution_3)
@ -90,21 +87,16 @@ class TestWalletPoolStore:
with pytest.raises(ValueError):
await store.add_spend(1, solution_4, 99)
await store.rebuild_cache()
await store.add_spend(1, solution_4, 101)
await store.rebuild_cache()
await store.rollback(101, 1, False)
await store.rebuild_cache()
assert store.get_spends_for_wallet(1) == [
assert await store.get_spends_for_wallet(1) == [
(100, solution_1),
(100, solution_2),
(100, solution_3),
(101, solution_4),
]
await store.rebuild_cache()
await store.rollback(100, 1, False)
await store.rebuild_cache()
assert store.get_spends_for_wallet(1) == [
assert await store.get_spends_for_wallet(1) == [
(100, solution_1),
(100, solution_2),
(100, solution_3),
@ -116,7 +108,7 @@ class TestWalletPoolStore:
solution_5: CoinSpend = make_child_solution(solution_4)
await store.add_spend(1, solution_5, 105)
await store.rollback(99, 1, False)
assert store.get_spends_for_wallet(1) == []
assert await store.get_spends_for_wallet(1) == []
finally:
await db_connection.close()