diff --git a/chia/pools/pool_wallet.py b/chia/pools/pool_wallet.py index b41a2585cb069..f2ca301036362 100644 --- a/chia/pools/pool_wallet.py +++ b/chia/pools/pool_wallet.py @@ -180,7 +180,7 @@ class PoolWallet: raise ValueError(f"Invalid internal Pool State: {err}: {initial_target_state}") async def get_spend_history(self) -> List[Tuple[uint32, CoinSpend]]: - return self.wallet_state_manager.pool_store.get_spends_for_wallet(self.wallet_id) + return await self.wallet_state_manager.pool_store.get_spends_for_wallet(self.wallet_id) async def get_current_state(self) -> PoolWalletInfo: history: List[Tuple[uint32, CoinSpend]] = await self.get_spend_history() @@ -228,7 +228,7 @@ class PoolWallet: return await self.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(self.wallet_id) async def get_tip(self) -> Tuple[uint32, CoinSpend]: - return self.wallet_state_manager.pool_store.get_spends_for_wallet(self.wallet_id)[-1] + return (await self.wallet_state_manager.pool_store.get_spends_for_wallet(self.wallet_id))[-1] async def update_pool_config(self) -> None: current_state: PoolWalletInfo = await self.get_current_state() @@ -286,7 +286,9 @@ class PoolWallet: self.log.info(f"New PoolWallet singleton tip_coin: {tip_spend} farmed at height {block_height}") # If we have reached the target state, resets it to None. Loops back to get current state - for _, added_spend in reversed(self.wallet_state_manager.pool_store.get_spends_for_wallet(self.wallet_id)): + for _, added_spend in reversed( + await self.wallet_state_manager.pool_store.get_spends_for_wallet(self.wallet_id) + ): latest_state: Optional[PoolState] = solution_to_pool_state(added_spend) if latest_state is not None: if self.target_state == latest_state: @@ -303,9 +305,9 @@ class PoolWallet: Returns True if the wallet should be removed. """ try: - history: List[Tuple[uint32, CoinSpend]] = self.wallet_state_manager.pool_store.get_spends_for_wallet( + history: List[Tuple[uint32, CoinSpend]] = await self.wallet_state_manager.pool_store.get_spends_for_wallet( self.wallet_id - ).copy() + ) prev_state: PoolWalletInfo = await self.get_current_state() await self.wallet_state_manager.pool_store.rollback(block_height, self.wallet_id, in_transaction) diff --git a/chia/wallet/wallet_node.py b/chia/wallet/wallet_node.py index 32dcb670dc5ec..e5fc2fafee480 100644 --- a/chia/wallet/wallet_node.py +++ b/chia/wallet/wallet_node.py @@ -512,7 +512,6 @@ class WalletNode: tb = traceback.format_exc() self.log.error(f"Exception while perform_atomic_rollback: {e} {tb}") await self.wallet_state_manager.db_wrapper.rollback_transaction() - await self.wallet_state_manager.pool_store.rebuild_cache() raise else: await self.wallet_state_manager.blockchain.clean_block_records() @@ -713,7 +712,6 @@ class WalletNode: tb = traceback.format_exc() self.log.error(f"Exception while adding state: {e} {tb}") await self.wallet_state_manager.db_wrapper.rollback_transaction() - await self.wallet_state_manager.pool_store.rebuild_cache() else: await self.wallet_state_manager.blockchain.clean_block_records() @@ -748,7 +746,6 @@ class WalletNode: await self.wallet_state_manager.db_wrapper.commit_transaction() except Exception as e: await self.wallet_state_manager.db_wrapper.rollback_transaction() - await self.wallet_state_manager.pool_store.rebuild_cache() tb = traceback.format_exc() self.log.error(f"Error adding states.. {e} {tb}") return False diff --git a/chia/wallet/wallet_pool_store.py b/chia/wallet/wallet_pool_store.py index 2523dd61c0ff4..2505c60e0cb33 100644 --- a/chia/wallet/wallet_pool_store.py +++ b/chia/wallet/wallet_pool_store.py @@ -1,5 +1,5 @@ import logging -from typing import List, Tuple, Dict, Optional +from typing import List, Tuple import aiosqlite @@ -13,7 +13,6 @@ log = logging.getLogger(__name__) class WalletPoolStore: db_connection: aiosqlite.Connection db_wrapper: DBWrapper - _state_transitions_cache: Dict[int, List[Tuple[uint32, CoinSpend]]] @classmethod async def create(cls, wrapper: DBWrapper): @@ -23,11 +22,15 @@ class WalletPoolStore: self.db_wrapper = wrapper await self.db_connection.execute( - "CREATE TABLE IF NOT EXISTS pool_state_transitions(transition_index integer, wallet_id integer, " - "height bigint, coin_spend blob, PRIMARY KEY(transition_index, wallet_id))" + "CREATE TABLE IF NOT EXISTS pool_state_transitions(" + " transition_index integer," + " wallet_id integer," + " height bigint," + " coin_spend blob," + " PRIMARY KEY(transition_index, wallet_id))" ) + await self.db_connection.commit() - await self.rebuild_cache() return self async def _clear_database(self): @@ -51,28 +54,48 @@ class WalletPoolStore: if not in_transaction: await self.db_wrapper.lock.acquire() try: - if wallet_id not in self._state_transitions_cache: - self._state_transitions_cache[wallet_id] = [] - all_state_transitions: List[Tuple[uint32, CoinSpend]] = self.get_spends_for_wallet(wallet_id) + # find the most recent transition in wallet_id + rows = list( + await self.db_connection.execute_fetchall( + "SELECT transition_index, height, coin_spend " + "FROM pool_state_transitions " + "WHERE wallet_id=? " + "ORDER BY transition_index DESC " + "LIMIT 1", + (wallet_id,), + ) + ) + serialized_spend = bytes(spend) + if len(rows) == 0: + transition_index = 0 + else: + existing = list( + await self.db_connection.execute_fetchall( + "SELECT COUNT(*) " + "FROM pool_state_transitions " + "WHERE wallet_id=? AND height=? AND coin_spend=?", + (wallet_id, height, serialized_spend), + ) + ) + if existing[0][0] != 0: + # we already have this transition in the DB + return - if (height, spend) in all_state_transitions: - return - - if len(all_state_transitions) > 0: - if height < all_state_transitions[-1][0]: + row = rows[0] + if height < row[1]: raise ValueError("Height cannot go down") - if spend.coin.parent_coin_info != all_state_transitions[-1][1].coin.name(): + prev = CoinSpend.from_bytes(row[2]) + if spend.coin.parent_coin_info != prev.coin.name(): raise ValueError("New spend does not extend") - - all_state_transitions.append((height, spend)) + transition_index = row[0] cursor = await self.db_connection.execute( - "INSERT OR REPLACE INTO pool_state_transitions VALUES (?, ?, ?, ?)", + "INSERT OR IGNORE INTO pool_state_transitions VALUES (?, ?, ?, ?)", ( - len(all_state_transitions) - 1, + transition_index + 1, wallet_id, height, - bytes(spend), + serialized_spend, ), ) await cursor.close() @@ -81,27 +104,16 @@ class WalletPoolStore: await self.db_connection.commit() self.db_wrapper.lock.release() - def get_spends_for_wallet(self, wallet_id: int) -> List[Tuple[uint32, CoinSpend]]: + async def get_spends_for_wallet(self, wallet_id: int) -> List[Tuple[uint32, CoinSpend]]: """ - Retrieves all entries for a wallet ID from the cache, works even if commit is not called yet. + Retrieves all entries for a wallet ID. """ - return self._state_transitions_cache.get(wallet_id, []) - async def rebuild_cache(self) -> None: - """ - This resets the cache, and loads all entries from the DB. Any entries in the cache that were not committed - are removed. This can happen if a state transition in wallet_blockchain fails. - """ - cursor = await self.db_connection.execute("SELECT * FROM pool_state_transitions ORDER BY transition_index") - rows = await cursor.fetchall() - await cursor.close() - self._state_transitions_cache = {} - for row in rows: - _, wallet_id, height, coin_spend_bytes = row - coin_spend: CoinSpend = CoinSpend.from_bytes(coin_spend_bytes) - if wallet_id not in self._state_transitions_cache: - self._state_transitions_cache[wallet_id] = [] - self._state_transitions_cache[wallet_id].append((height, coin_spend)) + rows = await self.db_connection.execute_fetchall( + "SELECT height, coin_spend FROM pool_state_transitions WHERE wallet_id=? ORDER BY transition_index", + (wallet_id,), + ) + return [(uint32(row[0]), CoinSpend.from_bytes(row[1])) for row in rows] async def rollback(self, height: int, wallet_id_arg: int, in_transaction: bool) -> None: """ @@ -113,14 +125,6 @@ class WalletPoolStore: if not in_transaction: await self.db_wrapper.lock.acquire() try: - for wallet_id, items in self._state_transitions_cache.items(): - remove_index_start: Optional[int] = None - for i, (item_block_height, _) in enumerate(items): - if item_block_height > height and wallet_id == wallet_id_arg: - remove_index_start = i - break - if remove_index_start is not None: - del items[remove_index_start:] cursor = await self.db_connection.execute( "DELETE FROM pool_state_transitions WHERE height>? AND wallet_id=?", (height, wallet_id_arg) ) diff --git a/tests/pools/test_wallet_pool_store.py b/tests/pools/test_wallet_pool_store.py index 27c3152a2c299..f88397af2caa2 100644 --- a/tests/pools/test_wallet_pool_store.py +++ b/tests/pools/test_wallet_pool_store.py @@ -51,38 +51,35 @@ class TestWalletPoolStore: solution_0_alt: CoinSpend = make_child_solution(None, coin_0_alt) solution_1: CoinSpend = make_child_solution(solution_0) - assert store.get_spends_for_wallet(0) == [] - assert store.get_spends_for_wallet(1) == [] + assert await store.get_spends_for_wallet(0) == [] + assert await store.get_spends_for_wallet(1) == [] await store.add_spend(1, solution_1, 100, True) - assert store.get_spends_for_wallet(1) == [(100, solution_1)] + assert await store.get_spends_for_wallet(1) == [(100, solution_1)] # Idempotent await store.add_spend(1, solution_1, 100, True) - assert store.get_spends_for_wallet(1) == [(100, solution_1)] + assert await store.get_spends_for_wallet(1) == [(100, solution_1)] with pytest.raises(ValueError): await store.add_spend(1, solution_1, 101, True) # Rebuild cache, no longer present await db_wrapper.rollback_transaction() - await store.rebuild_cache() - assert store.get_spends_for_wallet(1) == [] + assert await store.get_spends_for_wallet(1) == [] - await store.rebuild_cache() await store.add_spend(1, solution_1, 100, False) - assert store.get_spends_for_wallet(1) == [(100, solution_1)] + assert await store.get_spends_for_wallet(1) == [(100, solution_1)] solution_1_alt: CoinSpend = make_child_solution(solution_0_alt) with pytest.raises(ValueError): await store.add_spend(1, solution_1_alt, 100, False) - assert store.get_spends_for_wallet(1) == [(100, solution_1)] + assert await store.get_spends_for_wallet(1) == [(100, solution_1)] solution_2: CoinSpend = make_child_solution(solution_1) await store.add_spend(1, solution_2, 100, False) - await store.rebuild_cache() solution_3: CoinSpend = make_child_solution(solution_2) await store.add_spend(1, solution_3, 100) solution_4: CoinSpend = make_child_solution(solution_3) @@ -90,21 +87,16 @@ class TestWalletPoolStore: with pytest.raises(ValueError): await store.add_spend(1, solution_4, 99) - await store.rebuild_cache() await store.add_spend(1, solution_4, 101) - await store.rebuild_cache() await store.rollback(101, 1, False) - await store.rebuild_cache() - assert store.get_spends_for_wallet(1) == [ + assert await store.get_spends_for_wallet(1) == [ (100, solution_1), (100, solution_2), (100, solution_3), (101, solution_4), ] - await store.rebuild_cache() await store.rollback(100, 1, False) - await store.rebuild_cache() - assert store.get_spends_for_wallet(1) == [ + assert await store.get_spends_for_wallet(1) == [ (100, solution_1), (100, solution_2), (100, solution_3), @@ -116,7 +108,7 @@ class TestWalletPoolStore: solution_5: CoinSpend = make_child_solution(solution_4) await store.add_spend(1, solution_5, 105) await store.rollback(99, 1, False) - assert store.get_spends_for_wallet(1) == [] + assert await store.get_spends_for_wallet(1) == [] finally: await db_connection.close()