full_node: Move wallet updates into a separate task (#16238)

This commit is contained in:
dustinface 2023-09-13 20:26:47 +07:00 committed by GitHub
parent d7e9090122
commit 2a1e3ae2fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 110 additions and 55 deletions

View File

@ -35,7 +35,7 @@ from chia.full_node.hint_store import HintStore
from chia.full_node.mempool_manager import MempoolManager from chia.full_node.mempool_manager import MempoolManager
from chia.full_node.signage_point import SignagePoint from chia.full_node.signage_point import SignagePoint
from chia.full_node.subscriptions import PeerSubscriptions from chia.full_node.subscriptions import PeerSubscriptions
from chia.full_node.sync_store import SyncStore from chia.full_node.sync_store import Peak, SyncStore
from chia.full_node.tx_processing_queue import TransactionQueue from chia.full_node.tx_processing_queue import TransactionQueue
from chia.full_node.weight_proof import WeightProofHandler from chia.full_node.weight_proof import WeightProofHandler
from chia.protocols import farmer_protocol, full_node_protocol, timelord_protocol, wallet_protocol from chia.protocols import farmer_protocol, full_node_protocol, timelord_protocol, wallet_protocol
@ -89,6 +89,14 @@ class PeakPostProcessingResult:
lookup_coin_ids: List[bytes32] # The coin IDs that we need to look up to notify wallets of changes lookup_coin_ids: List[bytes32] # The coin IDs that we need to look up to notify wallets of changes
@dataclasses.dataclass(frozen=True)
class WalletUpdate:
fork_height: uint32
peak: Peak
coin_records: List[CoinRecord]
hints: Dict[bytes32, bytes32]
class FullNode: class FullNode:
_segment_task: Optional[asyncio.Task[None]] _segment_task: Optional[asyncio.Task[None]]
initialized: bool initialized: bool
@ -128,6 +136,8 @@ class FullNode:
_timelord_lock: Optional[asyncio.Lock] _timelord_lock: Optional[asyncio.Lock]
weight_proof_handler: Optional[WeightProofHandler] weight_proof_handler: Optional[WeightProofHandler]
bad_peak_cache: Dict[bytes32, uint32] # hashes of peaks that failed long sync on chip13 Validation bad_peak_cache: Dict[bytes32, uint32] # hashes of peaks that failed long sync on chip13 Validation
wallet_sync_queue: asyncio.Queue[WalletUpdate]
wallet_sync_task: Optional[asyncio.Task[None]]
@property @property
def server(self) -> ChiaServer: def server(self) -> ChiaServer:
@ -191,6 +201,8 @@ class FullNode:
self._timelord_lock = None self._timelord_lock = None
self.weight_proof_handler = None self.weight_proof_handler = None
self.bad_peak_cache = {} self.bad_peak_cache = {}
self.wallet_sync_queue = asyncio.Queue()
self.wallet_sync_task = None
@property @property
def block_store(self) -> BlockStore: def block_store(self) -> BlockStore:
@ -415,6 +427,9 @@ class FullNode:
sanitize_weight_proof_only, sanitize_weight_proof_only,
) )
) )
if self.wallet_sync_task is None or self.wallet_sync_task.done():
self.wallet_sync_task = asyncio.create_task(self._wallets_sync_task_handler())
self.initialized = True self.initialized = True
if self.full_node_peers is not None: if self.full_node_peers is not None:
asyncio.create_task(self.full_node_peers.start()) asyncio.create_task(self.full_node_peers.start())
@ -872,6 +887,7 @@ class FullNode:
self.uncompact_task.cancel() self.uncompact_task.cancel()
if self._transaction_queue_task is not None: if self._transaction_queue_task is not None:
self._transaction_queue_task.cancel() self._transaction_queue_task.cancel()
cancel_task_safe(task=self.wallet_sync_task, log=self.log)
cancel_task_safe(task=self._sync_task, log=self.log) cancel_task_safe(task=self._sync_task, log=self.log)
async def _await_closed(self) -> None: async def _await_closed(self) -> None:
@ -1122,52 +1138,52 @@ class FullNode:
return [] return []
return [c for c in self.server.all_connections.values() if c.peer_node_id in peer_ids] return [c for c in self.server.all_connections.values() if c.peer_node_id in peer_ids]
async def update_wallets( async def _wallets_sync_task_handler(self) -> None:
self, while not self._shut_down:
state_change_summary: StateChangeSummary, try:
hints: List[Tuple[bytes32, bytes]], wallet_update = await self.wallet_sync_queue.get()
lookup_coin_ids: List[bytes32], await self.update_wallets(wallet_update)
) -> None: except Exception:
# Looks up coin records in DB for the coins that wallets are interested in self.log.exception("Wallet sync task failure")
new_states: List[CoinRecord] = await self.coin_store.get_coin_records(lookup_coin_ids) continue
# Re-arrange to a map, and filter out any non-ph sized hint
coin_id_to_ph_hint: Dict[bytes32, bytes32] = {
coin_id: bytes32(hint) for coin_id, hint in hints if len(hint) == 32
}
async def update_wallets(self, wallet_update: WalletUpdate) -> None:
self.log.debug(
f"update_wallets - fork_height: {wallet_update.fork_height}, peak_height: {wallet_update.peak.height}"
)
changes_for_peer: Dict[bytes32, Set[CoinState]] = {} changes_for_peer: Dict[bytes32, Set[CoinState]] = {}
for coin_record in state_change_summary.rolled_back_records + new_states: for coin_record in wallet_update.coin_records:
cr_name: bytes32 = coin_record.name coin_id = coin_record.name
subscribed_peers = self.subscriptions.peers_for_coin_id(coin_id)
for peer in self.subscriptions.peers_for_coin_id(cr_name): subscribed_peers.update(self.subscriptions.peers_for_puzzle_hash(coin_record.coin.puzzle_hash))
if peer not in changes_for_peer: hint = wallet_update.hints.get(coin_id)
changes_for_peer[peer] = set() if hint is not None:
changes_for_peer[peer].add(coin_record.coin_state) subscribed_peers.update(self.subscriptions.peers_for_puzzle_hash(hint))
for peer in subscribed_peers:
for peer in self.subscriptions.peers_for_puzzle_hash(coin_record.coin.puzzle_hash): changes_for_peer.setdefault(peer, set()).add(coin_record.coin_state)
if peer not in changes_for_peer:
changes_for_peer[peer] = set()
changes_for_peer[peer].add(coin_record.coin_state)
if cr_name in coin_id_to_ph_hint:
for peer in self.subscriptions.peers_for_puzzle_hash(coin_id_to_ph_hint[cr_name]):
if peer not in changes_for_peer:
changes_for_peer[peer] = set()
changes_for_peer[peer].add(coin_record.coin_state)
for peer, changes in changes_for_peer.items(): for peer, changes in changes_for_peer.items():
if peer not in self.server.all_connections: connection = self.server.all_connections.get(peer)
continue if connection is not None:
ws_peer: WSChiaConnection = self.server.all_connections[peer] state = CoinStateUpdate(
state = CoinStateUpdate( wallet_update.peak.height,
state_change_summary.peak.height, wallet_update.fork_height,
state_change_summary.fork_height, wallet_update.peak.header_hash,
state_change_summary.peak.header_hash, list(changes),
list(changes), )
) await connection.send_message(make_msg(ProtocolMessageTypes.coin_state_update, state))
msg = make_msg(ProtocolMessageTypes.coin_state_update, state)
await ws_peer.send_message(msg) # Tell wallets about the new peak
new_peak_message = make_msg(
ProtocolMessageTypes.new_peak_wallet,
wallet_protocol.NewPeakWallet(
wallet_update.peak.header_hash,
wallet_update.peak.height,
wallet_update.peak.weight,
wallet_update.fork_height,
),
)
await self.server.send_to_all([new_peak_message], NodeType.WALLET)
async def add_block_batch( async def add_block_batch(
self, self,
@ -1529,18 +1545,26 @@ class FullNode:
else: else:
await self.server.send_to_all([msg], NodeType.FULL_NODE) await self.server.send_to_all([msg], NodeType.FULL_NODE)
# Tell wallets about the new peak coin_hints: Dict[bytes32, bytes32] = {
msg = make_msg( coin_id: bytes32(hint) for coin_id, hint in ppp_result.hints if len(hint) == 32
ProtocolMessageTypes.new_peak_wallet, }
wallet_protocol.NewPeakWallet(
record.header_hash, peak = Peak(
record.height, state_change_summary.peak.header_hash, state_change_summary.peak.height, state_change_summary.peak.weight
record.weight,
state_change_summary.fork_height,
),
) )
await self.update_wallets(state_change_summary, ppp_result.hints, ppp_result.lookup_coin_ids)
await self.server.send_to_all([msg], NodeType.WALLET) # Looks up coin records in DB for the coins that wallets are interested in
new_states = await self.coin_store.get_coin_records(ppp_result.lookup_coin_ids)
await self.wallet_sync_queue.put(
WalletUpdate(
state_change_summary.fork_height,
peak,
state_change_summary.rolled_back_records + new_states,
coin_hints,
)
)
self._state_changed("new_peak") self._state_changed("new_peak")
async def add_block( async def add_block(

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
import dataclasses import dataclasses
import logging
import random import random
import time import time
from secrets import token_bytes from secrets import token_bytes
@ -13,8 +14,10 @@ from clvm.casts import int_to_bytes
from chia.consensus.pot_iterations import is_overflow_block from chia.consensus.pot_iterations import is_overflow_block
from chia.full_node.bundle_tools import detect_potential_template_generator from chia.full_node.bundle_tools import detect_potential_template_generator
from chia.full_node.full_node import WalletUpdate
from chia.full_node.full_node_api import FullNodeAPI from chia.full_node.full_node_api import FullNodeAPI
from chia.full_node.signage_point import SignagePoint from chia.full_node.signage_point import SignagePoint
from chia.full_node.sync_store import Peak
from chia.protocols import full_node_protocol from chia.protocols import full_node_protocol
from chia.protocols import full_node_protocol as fnp from chia.protocols import full_node_protocol as fnp
from chia.protocols import timelord_protocol, wallet_protocol from chia.protocols import timelord_protocol, wallet_protocol
@ -28,6 +31,7 @@ from chia.server.server import ChiaServer
from chia.simulator.block_tools import BlockTools, create_block_tools_async, get_signage_point from chia.simulator.block_tools import BlockTools, create_block_tools_async, get_signage_point
from chia.simulator.full_node_simulator import FullNodeSimulator from chia.simulator.full_node_simulator import FullNodeSimulator
from chia.simulator.keyring import TempKeyring from chia.simulator.keyring import TempKeyring
from chia.simulator.setup_nodes import SimulatorsAndWalletsServices
from chia.simulator.setup_services import setup_full_node from chia.simulator.setup_services import setup_full_node
from chia.simulator.simulator_protocol import FarmNewBlockProtocol from chia.simulator.simulator_protocol import FarmNewBlockProtocol
from chia.simulator.time_out_assert import time_out_assert, time_out_assert_custom_interval, time_out_messages from chia.simulator.time_out_assert import time_out_assert, time_out_assert_custom_interval, time_out_messages
@ -37,6 +41,7 @@ from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.proof_of_space import ProofOfSpace, calculate_plot_id_pk, calculate_pos_challenge from chia.types.blockchain_format.proof_of_space import ProofOfSpace, calculate_plot_id_pk, calculate_pos_challenge
from chia.types.blockchain_format.reward_chain_block import RewardChainBlockUnfinished from chia.types.blockchain_format.reward_chain_block import RewardChainBlockUnfinished
from chia.types.blockchain_format.serialized_program import SerializedProgram from chia.types.blockchain_format.serialized_program import SerializedProgram
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.blockchain_format.vdf import CompressibleVDFField, VDFProof from chia.types.blockchain_format.vdf import CompressibleVDFField, VDFProof
from chia.types.coin_spend import CoinSpend from chia.types.coin_spend import CoinSpend
from chia.types.condition_opcodes import ConditionOpcode from chia.types.condition_opcodes import ConditionOpcode
@ -48,7 +53,7 @@ from chia.types.spend_bundle import SpendBundle
from chia.types.unfinished_block import UnfinishedBlock from chia.types.unfinished_block import UnfinishedBlock
from chia.util.errors import ConsensusError, Err from chia.util.errors import ConsensusError, Err
from chia.util.hash import std_hash from chia.util.hash import std_hash
from chia.util.ints import uint8, uint16, uint32, uint64 from chia.util.ints import uint8, uint16, uint32, uint64, uint128
from chia.util.limited_semaphore import LimitedSemaphore from chia.util.limited_semaphore import LimitedSemaphore
from chia.util.recursive_replace import recursive_replace from chia.util.recursive_replace import recursive_replace
from chia.util.vdf_prover import get_vdf_info_and_proof from chia.util.vdf_prover import get_vdf_info_and_proof
@ -2080,3 +2085,29 @@ async def test_node_start_with_existing_blocks(db_version: int) -> None:
assert block_record is not None, f"block_record is None on cycle {cycle + 1}" assert block_record is not None, f"block_record is None on cycle {cycle + 1}"
assert block_record.height == expected_height, f"wrong height on cycle {cycle + 1}" assert block_record.height == expected_height, f"wrong height on cycle {cycle + 1}"
@pytest.mark.asyncio
async def test_wallet_sync_task_failure(
one_node: SimulatorsAndWalletsServices, caplog: pytest.LogCaptureFixture
) -> None:
[full_node_service], _, _ = one_node
full_node = full_node_service._node
assert full_node.wallet_sync_task is not None
caplog.set_level(logging.DEBUG)
peak = Peak(bytes32(32 * b"0"), uint32(0), uint128(0))
# WalletUpdate with invalid args to force an exception in FullNode.update_wallets / FullNode.wallet_sync_task
bad_wallet_update = WalletUpdate(-10, peak, [], {}) # type: ignore[arg-type]
await full_node.wallet_sync_queue.put(bad_wallet_update)
await time_out_assert(30, full_node.wallet_sync_queue.empty)
assert "update_wallets - fork_height: -10, peak_height: 0" in caplog.text
assert "Wallet sync task failure" in caplog.text
assert not full_node.wallet_sync_task.done()
caplog.clear()
# WalletUpdate with valid args to test continued processing after failure
good_wallet_update = WalletUpdate(uint32(10), peak, [], {})
await full_node.wallet_sync_queue.put(good_wallet_update)
await time_out_assert(30, full_node.wallet_sync_queue.empty)
assert "update_wallets - fork_height: 10, peak_height: 0" in caplog.text
assert "Wallet sync task failure" not in caplog.text
assert not full_node.wallet_sync_task.done()