diff --git a/chia/full_node/full_node.py b/chia/full_node/full_node.py index 626f297afa3f..66ca7c9037a1 100644 --- a/chia/full_node/full_node.py +++ b/chia/full_node/full_node.py @@ -35,7 +35,7 @@ from chia.full_node.hint_store import HintStore from chia.full_node.mempool_manager import MempoolManager from chia.full_node.signage_point import SignagePoint from chia.full_node.subscriptions import PeerSubscriptions -from chia.full_node.sync_store import SyncStore +from chia.full_node.sync_store import Peak, SyncStore from chia.full_node.tx_processing_queue import TransactionQueue from chia.full_node.weight_proof import WeightProofHandler from chia.protocols import farmer_protocol, full_node_protocol, timelord_protocol, wallet_protocol @@ -89,6 +89,14 @@ class PeakPostProcessingResult: lookup_coin_ids: List[bytes32] # The coin IDs that we need to look up to notify wallets of changes +@dataclasses.dataclass(frozen=True) +class WalletUpdate: + fork_height: uint32 + peak: Peak + coin_records: List[CoinRecord] + hints: Dict[bytes32, bytes32] + + class FullNode: _segment_task: Optional[asyncio.Task[None]] initialized: bool @@ -128,6 +136,8 @@ class FullNode: _timelord_lock: Optional[asyncio.Lock] weight_proof_handler: Optional[WeightProofHandler] bad_peak_cache: Dict[bytes32, uint32] # hashes of peaks that failed long sync on chip13 Validation + wallet_sync_queue: asyncio.Queue[WalletUpdate] + wallet_sync_task: Optional[asyncio.Task[None]] @property def server(self) -> ChiaServer: @@ -191,6 +201,8 @@ class FullNode: self._timelord_lock = None self.weight_proof_handler = None self.bad_peak_cache = {} + self.wallet_sync_queue = asyncio.Queue() + self.wallet_sync_task = None @property def block_store(self) -> BlockStore: @@ -415,6 +427,9 @@ class FullNode: sanitize_weight_proof_only, ) ) + if self.wallet_sync_task is None or self.wallet_sync_task.done(): + self.wallet_sync_task = asyncio.create_task(self._wallets_sync_task_handler()) + self.initialized = True if self.full_node_peers is not None: asyncio.create_task(self.full_node_peers.start()) @@ -872,6 +887,7 @@ class FullNode: self.uncompact_task.cancel() if self._transaction_queue_task is not None: self._transaction_queue_task.cancel() + cancel_task_safe(task=self.wallet_sync_task, log=self.log) cancel_task_safe(task=self._sync_task, log=self.log) async def _await_closed(self) -> None: @@ -1122,52 +1138,52 @@ class FullNode: return [] return [c for c in self.server.all_connections.values() if c.peer_node_id in peer_ids] - async def update_wallets( - self, - state_change_summary: StateChangeSummary, - hints: List[Tuple[bytes32, bytes]], - lookup_coin_ids: List[bytes32], - ) -> None: - # Looks up coin records in DB for the coins that wallets are interested in - new_states: List[CoinRecord] = await self.coin_store.get_coin_records(lookup_coin_ids) - - # Re-arrange to a map, and filter out any non-ph sized hint - coin_id_to_ph_hint: Dict[bytes32, bytes32] = { - coin_id: bytes32(hint) for coin_id, hint in hints if len(hint) == 32 - } + async def _wallets_sync_task_handler(self) -> None: + while not self._shut_down: + try: + wallet_update = await self.wallet_sync_queue.get() + await self.update_wallets(wallet_update) + except Exception: + self.log.exception("Wallet sync task failure") + continue + async def update_wallets(self, wallet_update: WalletUpdate) -> None: + self.log.debug( + f"update_wallets - fork_height: {wallet_update.fork_height}, peak_height: {wallet_update.peak.height}" + ) changes_for_peer: Dict[bytes32, Set[CoinState]] = {} - for coin_record in state_change_summary.rolled_back_records + new_states: - cr_name: bytes32 = coin_record.name - - for peer in self.subscriptions.peers_for_coin_id(cr_name): - if peer not in changes_for_peer: - changes_for_peer[peer] = set() - changes_for_peer[peer].add(coin_record.coin_state) - - for peer in self.subscriptions.peers_for_puzzle_hash(coin_record.coin.puzzle_hash): - if peer not in changes_for_peer: - changes_for_peer[peer] = set() - changes_for_peer[peer].add(coin_record.coin_state) - - if cr_name in coin_id_to_ph_hint: - for peer in self.subscriptions.peers_for_puzzle_hash(coin_id_to_ph_hint[cr_name]): - if peer not in changes_for_peer: - changes_for_peer[peer] = set() - changes_for_peer[peer].add(coin_record.coin_state) + for coin_record in wallet_update.coin_records: + coin_id = coin_record.name + subscribed_peers = self.subscriptions.peers_for_coin_id(coin_id) + subscribed_peers.update(self.subscriptions.peers_for_puzzle_hash(coin_record.coin.puzzle_hash)) + hint = wallet_update.hints.get(coin_id) + if hint is not None: + subscribed_peers.update(self.subscriptions.peers_for_puzzle_hash(hint)) + for peer in subscribed_peers: + changes_for_peer.setdefault(peer, set()).add(coin_record.coin_state) for peer, changes in changes_for_peer.items(): - if peer not in self.server.all_connections: - continue - ws_peer: WSChiaConnection = self.server.all_connections[peer] - state = CoinStateUpdate( - state_change_summary.peak.height, - state_change_summary.fork_height, - state_change_summary.peak.header_hash, - list(changes), - ) - msg = make_msg(ProtocolMessageTypes.coin_state_update, state) - await ws_peer.send_message(msg) + connection = self.server.all_connections.get(peer) + if connection is not None: + state = CoinStateUpdate( + wallet_update.peak.height, + wallet_update.fork_height, + wallet_update.peak.header_hash, + list(changes), + ) + await connection.send_message(make_msg(ProtocolMessageTypes.coin_state_update, state)) + + # Tell wallets about the new peak + new_peak_message = make_msg( + ProtocolMessageTypes.new_peak_wallet, + wallet_protocol.NewPeakWallet( + wallet_update.peak.header_hash, + wallet_update.peak.height, + wallet_update.peak.weight, + wallet_update.fork_height, + ), + ) + await self.server.send_to_all([new_peak_message], NodeType.WALLET) async def add_block_batch( self, @@ -1529,18 +1545,26 @@ class FullNode: else: await self.server.send_to_all([msg], NodeType.FULL_NODE) - # Tell wallets about the new peak - msg = make_msg( - ProtocolMessageTypes.new_peak_wallet, - wallet_protocol.NewPeakWallet( - record.header_hash, - record.height, - record.weight, - state_change_summary.fork_height, - ), + coin_hints: Dict[bytes32, bytes32] = { + coin_id: bytes32(hint) for coin_id, hint in ppp_result.hints if len(hint) == 32 + } + + peak = Peak( + state_change_summary.peak.header_hash, state_change_summary.peak.height, state_change_summary.peak.weight ) - await self.update_wallets(state_change_summary, ppp_result.hints, ppp_result.lookup_coin_ids) - await self.server.send_to_all([msg], NodeType.WALLET) + + # Looks up coin records in DB for the coins that wallets are interested in + new_states = await self.coin_store.get_coin_records(ppp_result.lookup_coin_ids) + + await self.wallet_sync_queue.put( + WalletUpdate( + state_change_summary.fork_height, + peak, + state_change_summary.rolled_back_records + new_states, + coin_hints, + ) + ) + self._state_changed("new_peak") async def add_block( diff --git a/tests/core/full_node/test_full_node.py b/tests/core/full_node/test_full_node.py index 67f410cd4c99..448443cdde9f 100644 --- a/tests/core/full_node/test_full_node.py +++ b/tests/core/full_node/test_full_node.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import dataclasses +import logging import random import time from secrets import token_bytes @@ -13,8 +14,10 @@ from clvm.casts import int_to_bytes from chia.consensus.pot_iterations import is_overflow_block from chia.full_node.bundle_tools import detect_potential_template_generator +from chia.full_node.full_node import WalletUpdate from chia.full_node.full_node_api import FullNodeAPI from chia.full_node.signage_point import SignagePoint +from chia.full_node.sync_store import Peak from chia.protocols import full_node_protocol from chia.protocols import full_node_protocol as fnp from chia.protocols import timelord_protocol, wallet_protocol @@ -28,6 +31,7 @@ from chia.server.server import ChiaServer from chia.simulator.block_tools import BlockTools, create_block_tools_async, get_signage_point from chia.simulator.full_node_simulator import FullNodeSimulator from chia.simulator.keyring import TempKeyring +from chia.simulator.setup_nodes import SimulatorsAndWalletsServices from chia.simulator.setup_services import setup_full_node from chia.simulator.simulator_protocol import FarmNewBlockProtocol from chia.simulator.time_out_assert import time_out_assert, time_out_assert_custom_interval, time_out_messages @@ -37,6 +41,7 @@ from chia.types.blockchain_format.program import Program from chia.types.blockchain_format.proof_of_space import ProofOfSpace, calculate_plot_id_pk, calculate_pos_challenge from chia.types.blockchain_format.reward_chain_block import RewardChainBlockUnfinished from chia.types.blockchain_format.serialized_program import SerializedProgram +from chia.types.blockchain_format.sized_bytes import bytes32 from chia.types.blockchain_format.vdf import CompressibleVDFField, VDFProof from chia.types.coin_spend import CoinSpend from chia.types.condition_opcodes import ConditionOpcode @@ -48,7 +53,7 @@ from chia.types.spend_bundle import SpendBundle from chia.types.unfinished_block import UnfinishedBlock from chia.util.errors import ConsensusError, Err from chia.util.hash import std_hash -from chia.util.ints import uint8, uint16, uint32, uint64 +from chia.util.ints import uint8, uint16, uint32, uint64, uint128 from chia.util.limited_semaphore import LimitedSemaphore from chia.util.recursive_replace import recursive_replace from chia.util.vdf_prover import get_vdf_info_and_proof @@ -2080,3 +2085,29 @@ async def test_node_start_with_existing_blocks(db_version: int) -> None: assert block_record is not None, f"block_record is None on cycle {cycle + 1}" assert block_record.height == expected_height, f"wrong height on cycle {cycle + 1}" + + +@pytest.mark.asyncio +async def test_wallet_sync_task_failure( + one_node: SimulatorsAndWalletsServices, caplog: pytest.LogCaptureFixture +) -> None: + [full_node_service], _, _ = one_node + full_node = full_node_service._node + assert full_node.wallet_sync_task is not None + caplog.set_level(logging.DEBUG) + peak = Peak(bytes32(32 * b"0"), uint32(0), uint128(0)) + # WalletUpdate with invalid args to force an exception in FullNode.update_wallets / FullNode.wallet_sync_task + bad_wallet_update = WalletUpdate(-10, peak, [], {}) # type: ignore[arg-type] + await full_node.wallet_sync_queue.put(bad_wallet_update) + await time_out_assert(30, full_node.wallet_sync_queue.empty) + assert "update_wallets - fork_height: -10, peak_height: 0" in caplog.text + assert "Wallet sync task failure" in caplog.text + assert not full_node.wallet_sync_task.done() + caplog.clear() + # WalletUpdate with valid args to test continued processing after failure + good_wallet_update = WalletUpdate(uint32(10), peak, [], {}) + await full_node.wallet_sync_queue.put(good_wallet_update) + await time_out_assert(30, full_node.wallet_sync_queue.empty) + assert "update_wallets - fork_height: 10, peak_height: 0" in caplog.text + assert "Wallet sync task failure" not in caplog.text + assert not full_node.wallet_sync_task.done()