full_node: Move wallet updates into a separate task (#16238)

This commit is contained in:
dustinface 2023-09-13 20:26:47 +07:00 committed by GitHub
parent d7e9090122
commit 2a1e3ae2fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 110 additions and 55 deletions

View File

@ -35,7 +35,7 @@ from chia.full_node.hint_store import HintStore
from chia.full_node.mempool_manager import MempoolManager
from chia.full_node.signage_point import SignagePoint
from chia.full_node.subscriptions import PeerSubscriptions
from chia.full_node.sync_store import SyncStore
from chia.full_node.sync_store import Peak, SyncStore
from chia.full_node.tx_processing_queue import TransactionQueue
from chia.full_node.weight_proof import WeightProofHandler
from chia.protocols import farmer_protocol, full_node_protocol, timelord_protocol, wallet_protocol
@ -89,6 +89,14 @@ class PeakPostProcessingResult:
lookup_coin_ids: List[bytes32] # The coin IDs that we need to look up to notify wallets of changes
@dataclasses.dataclass(frozen=True)
class WalletUpdate:
fork_height: uint32
peak: Peak
coin_records: List[CoinRecord]
hints: Dict[bytes32, bytes32]
class FullNode:
_segment_task: Optional[asyncio.Task[None]]
initialized: bool
@ -128,6 +136,8 @@ class FullNode:
_timelord_lock: Optional[asyncio.Lock]
weight_proof_handler: Optional[WeightProofHandler]
bad_peak_cache: Dict[bytes32, uint32] # hashes of peaks that failed long sync on chip13 Validation
wallet_sync_queue: asyncio.Queue[WalletUpdate]
wallet_sync_task: Optional[asyncio.Task[None]]
@property
def server(self) -> ChiaServer:
@ -191,6 +201,8 @@ class FullNode:
self._timelord_lock = None
self.weight_proof_handler = None
self.bad_peak_cache = {}
self.wallet_sync_queue = asyncio.Queue()
self.wallet_sync_task = None
@property
def block_store(self) -> BlockStore:
@ -415,6 +427,9 @@ class FullNode:
sanitize_weight_proof_only,
)
)
if self.wallet_sync_task is None or self.wallet_sync_task.done():
self.wallet_sync_task = asyncio.create_task(self._wallets_sync_task_handler())
self.initialized = True
if self.full_node_peers is not None:
asyncio.create_task(self.full_node_peers.start())
@ -872,6 +887,7 @@ class FullNode:
self.uncompact_task.cancel()
if self._transaction_queue_task is not None:
self._transaction_queue_task.cancel()
cancel_task_safe(task=self.wallet_sync_task, log=self.log)
cancel_task_safe(task=self._sync_task, log=self.log)
async def _await_closed(self) -> None:
@ -1122,52 +1138,52 @@ class FullNode:
return []
return [c for c in self.server.all_connections.values() if c.peer_node_id in peer_ids]
async def update_wallets(
self,
state_change_summary: StateChangeSummary,
hints: List[Tuple[bytes32, bytes]],
lookup_coin_ids: List[bytes32],
) -> None:
# Looks up coin records in DB for the coins that wallets are interested in
new_states: List[CoinRecord] = await self.coin_store.get_coin_records(lookup_coin_ids)
# Re-arrange to a map, and filter out any non-ph sized hint
coin_id_to_ph_hint: Dict[bytes32, bytes32] = {
coin_id: bytes32(hint) for coin_id, hint in hints if len(hint) == 32
}
async def _wallets_sync_task_handler(self) -> None:
while not self._shut_down:
try:
wallet_update = await self.wallet_sync_queue.get()
await self.update_wallets(wallet_update)
except Exception:
self.log.exception("Wallet sync task failure")
continue
async def update_wallets(self, wallet_update: WalletUpdate) -> None:
self.log.debug(
f"update_wallets - fork_height: {wallet_update.fork_height}, peak_height: {wallet_update.peak.height}"
)
changes_for_peer: Dict[bytes32, Set[CoinState]] = {}
for coin_record in state_change_summary.rolled_back_records + new_states:
cr_name: bytes32 = coin_record.name
for peer in self.subscriptions.peers_for_coin_id(cr_name):
if peer not in changes_for_peer:
changes_for_peer[peer] = set()
changes_for_peer[peer].add(coin_record.coin_state)
for peer in self.subscriptions.peers_for_puzzle_hash(coin_record.coin.puzzle_hash):
if peer not in changes_for_peer:
changes_for_peer[peer] = set()
changes_for_peer[peer].add(coin_record.coin_state)
if cr_name in coin_id_to_ph_hint:
for peer in self.subscriptions.peers_for_puzzle_hash(coin_id_to_ph_hint[cr_name]):
if peer not in changes_for_peer:
changes_for_peer[peer] = set()
changes_for_peer[peer].add(coin_record.coin_state)
for coin_record in wallet_update.coin_records:
coin_id = coin_record.name
subscribed_peers = self.subscriptions.peers_for_coin_id(coin_id)
subscribed_peers.update(self.subscriptions.peers_for_puzzle_hash(coin_record.coin.puzzle_hash))
hint = wallet_update.hints.get(coin_id)
if hint is not None:
subscribed_peers.update(self.subscriptions.peers_for_puzzle_hash(hint))
for peer in subscribed_peers:
changes_for_peer.setdefault(peer, set()).add(coin_record.coin_state)
for peer, changes in changes_for_peer.items():
if peer not in self.server.all_connections:
continue
ws_peer: WSChiaConnection = self.server.all_connections[peer]
state = CoinStateUpdate(
state_change_summary.peak.height,
state_change_summary.fork_height,
state_change_summary.peak.header_hash,
list(changes),
)
msg = make_msg(ProtocolMessageTypes.coin_state_update, state)
await ws_peer.send_message(msg)
connection = self.server.all_connections.get(peer)
if connection is not None:
state = CoinStateUpdate(
wallet_update.peak.height,
wallet_update.fork_height,
wallet_update.peak.header_hash,
list(changes),
)
await connection.send_message(make_msg(ProtocolMessageTypes.coin_state_update, state))
# Tell wallets about the new peak
new_peak_message = make_msg(
ProtocolMessageTypes.new_peak_wallet,
wallet_protocol.NewPeakWallet(
wallet_update.peak.header_hash,
wallet_update.peak.height,
wallet_update.peak.weight,
wallet_update.fork_height,
),
)
await self.server.send_to_all([new_peak_message], NodeType.WALLET)
async def add_block_batch(
self,
@ -1529,18 +1545,26 @@ class FullNode:
else:
await self.server.send_to_all([msg], NodeType.FULL_NODE)
# Tell wallets about the new peak
msg = make_msg(
ProtocolMessageTypes.new_peak_wallet,
wallet_protocol.NewPeakWallet(
record.header_hash,
record.height,
record.weight,
state_change_summary.fork_height,
),
coin_hints: Dict[bytes32, bytes32] = {
coin_id: bytes32(hint) for coin_id, hint in ppp_result.hints if len(hint) == 32
}
peak = Peak(
state_change_summary.peak.header_hash, state_change_summary.peak.height, state_change_summary.peak.weight
)
await self.update_wallets(state_change_summary, ppp_result.hints, ppp_result.lookup_coin_ids)
await self.server.send_to_all([msg], NodeType.WALLET)
# Looks up coin records in DB for the coins that wallets are interested in
new_states = await self.coin_store.get_coin_records(ppp_result.lookup_coin_ids)
await self.wallet_sync_queue.put(
WalletUpdate(
state_change_summary.fork_height,
peak,
state_change_summary.rolled_back_records + new_states,
coin_hints,
)
)
self._state_changed("new_peak")
async def add_block(

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import dataclasses
import logging
import random
import time
from secrets import token_bytes
@ -13,8 +14,10 @@ from clvm.casts import int_to_bytes
from chia.consensus.pot_iterations import is_overflow_block
from chia.full_node.bundle_tools import detect_potential_template_generator
from chia.full_node.full_node import WalletUpdate
from chia.full_node.full_node_api import FullNodeAPI
from chia.full_node.signage_point import SignagePoint
from chia.full_node.sync_store import Peak
from chia.protocols import full_node_protocol
from chia.protocols import full_node_protocol as fnp
from chia.protocols import timelord_protocol, wallet_protocol
@ -28,6 +31,7 @@ from chia.server.server import ChiaServer
from chia.simulator.block_tools import BlockTools, create_block_tools_async, get_signage_point
from chia.simulator.full_node_simulator import FullNodeSimulator
from chia.simulator.keyring import TempKeyring
from chia.simulator.setup_nodes import SimulatorsAndWalletsServices
from chia.simulator.setup_services import setup_full_node
from chia.simulator.simulator_protocol import FarmNewBlockProtocol
from chia.simulator.time_out_assert import time_out_assert, time_out_assert_custom_interval, time_out_messages
@ -37,6 +41,7 @@ from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.proof_of_space import ProofOfSpace, calculate_plot_id_pk, calculate_pos_challenge
from chia.types.blockchain_format.reward_chain_block import RewardChainBlockUnfinished
from chia.types.blockchain_format.serialized_program import SerializedProgram
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.blockchain_format.vdf import CompressibleVDFField, VDFProof
from chia.types.coin_spend import CoinSpend
from chia.types.condition_opcodes import ConditionOpcode
@ -48,7 +53,7 @@ from chia.types.spend_bundle import SpendBundle
from chia.types.unfinished_block import UnfinishedBlock
from chia.util.errors import ConsensusError, Err
from chia.util.hash import std_hash
from chia.util.ints import uint8, uint16, uint32, uint64
from chia.util.ints import uint8, uint16, uint32, uint64, uint128
from chia.util.limited_semaphore import LimitedSemaphore
from chia.util.recursive_replace import recursive_replace
from chia.util.vdf_prover import get_vdf_info_and_proof
@ -2080,3 +2085,29 @@ async def test_node_start_with_existing_blocks(db_version: int) -> None:
assert block_record is not None, f"block_record is None on cycle {cycle + 1}"
assert block_record.height == expected_height, f"wrong height on cycle {cycle + 1}"
@pytest.mark.asyncio
async def test_wallet_sync_task_failure(
one_node: SimulatorsAndWalletsServices, caplog: pytest.LogCaptureFixture
) -> None:
[full_node_service], _, _ = one_node
full_node = full_node_service._node
assert full_node.wallet_sync_task is not None
caplog.set_level(logging.DEBUG)
peak = Peak(bytes32(32 * b"0"), uint32(0), uint128(0))
# WalletUpdate with invalid args to force an exception in FullNode.update_wallets / FullNode.wallet_sync_task
bad_wallet_update = WalletUpdate(-10, peak, [], {}) # type: ignore[arg-type]
await full_node.wallet_sync_queue.put(bad_wallet_update)
await time_out_assert(30, full_node.wallet_sync_queue.empty)
assert "update_wallets - fork_height: -10, peak_height: 0" in caplog.text
assert "Wallet sync task failure" in caplog.text
assert not full_node.wallet_sync_task.done()
caplog.clear()
# WalletUpdate with valid args to test continued processing after failure
good_wallet_update = WalletUpdate(uint32(10), peak, [], {})
await full_node.wallet_sync_queue.put(good_wallet_update)
await time_out_assert(30, full_node.wallet_sync_queue.empty)
assert "update_wallets - fork_height: 10, peak_height: 0" in caplog.text
assert "Wallet sync task failure" not in caplog.text
assert not full_node.wallet_sync_task.done()