diff --git a/chia/wallet/wallet_node.py b/chia/wallet/wallet_node.py index 0c9cefd18e7f..c0dc82d0cc89 100644 --- a/chia/wallet/wallet_node.py +++ b/chia/wallet/wallet_node.py @@ -364,7 +364,7 @@ class WalletNode: connect_to_unknown_peers = self.config.get("connect_to_unknown_peers", True) testing = self.config.get("testing", False) - if connect_to_unknown_peers and not testing: + if self.wallet_peers is None and connect_to_unknown_peers and not testing: self.wallet_peers = WalletPeers( self.server, self.config["target_peer_count"], @@ -388,6 +388,7 @@ class WalletNode: def on_disconnect(self, peer: WSChiaConnection): if self.is_trusted(peer): self.local_node_synced = False + self.initialize_wallet_peers() if peer.peer_node_id in self.untrusted_caches: self.untrusted_caches.pop(peer.peer_node_id) @@ -537,22 +538,32 @@ class WalletNode: if all_coins_state is not None and self.is_trusted(peer): await self.wallet_state_manager.new_coin_state(all_coins_state.coin_states, peer) - async def get_coin_state(self, coin_names: List[bytes32]) -> List[CoinState]: + async def get_coin_state( + self, coin_names: List[bytes32], peer: Optional[WSChiaConnection] = None + ) -> List[CoinState]: assert self.server is not None all_nodes = self.server.connection_by_type[NodeType.FULL_NODE] if len(all_nodes.keys()) == 0: raise ValueError("Not connected to the full node") - first_node = list(all_nodes.values())[0] + + # Use supplied if provided, prioritize trusted otherwise + if peer is None: + for node in list(all_nodes.values()): + if self.is_trusted(node): + peer = node + break + if peer is None: + peer = list(all_nodes.values())[0] + + assert peer is not None msg = wallet_protocol.RegisterForCoinUpdates(coin_names, uint32(0)) - coin_state: Optional[RespondToCoinUpdates] = await first_node.register_interest_in_coin(msg) + coin_state: Optional[RespondToCoinUpdates] = await peer.register_interest_in_coin(msg) assert coin_state is not None - if not self.is_trusted(first_node): + if not self.is_trusted(peer): valid_list = [] for coin in coin_state.coin_states: - valid = await self.validate_received_state_from_peer( - coin, first_node, self.get_cache_for_peer(first_node) - ) + valid = await self.validate_received_state_from_peer(coin, peer, self.get_cache_for_peer(peer)) if valid: valid_list.append(coin) return valid_list @@ -604,6 +615,23 @@ class WalletNode: else: return None + async def last_local_tx_block(self, header_hash: bytes32) -> Optional[BlockRecord]: + assert self.wallet_state_manager is not None + current_hash = header_hash + while True: + if self.wallet_state_manager.blockchain.contains_block(current_hash): + block = self.wallet_state_manager.blockchain.try_block_record(current_hash) + if block is None: + return None + if block.is_transaction_block: + return block + if block.prev_transaction_block_hash is None: + return None + current_hash = block.prev_transaction_block_hash + else: + break + return None + async def fetch_last_tx_from_peer(self, height: uint32, peer: WSChiaConnection) -> Optional[HeaderBlock]: request_height = height while True: @@ -619,6 +647,16 @@ class WalletNode: request_height = uint32(request_height - 1) return None + async def disconnect_and_stop_wpeers(self): + if len(self.server.get_full_node_connections()) > 1: + for peer in self.server.get_full_node_connections(): + if not self.is_trusted(peer): + asyncio.create_task(peer.close()) + + if self.wallet_peers is not None: + await self.wallet_peers.ensure_is_closed() + self.wallet_peers = None + async def get_timestamp_for_height(self, height: uint32) -> uint64: """ Returns the timestamp for transaction block at h=height, if not transaction block, backtracks until it finds @@ -654,37 +692,31 @@ class WalletNode: assert self.wallet_state_manager is not None assert self.server is not None request_time = int(time.time()) - async with self.new_peak_lock: - if self.wallet_state_manager is None: - # When logging out of wallet + + if self.wallet_state_manager is None: + # When logging out of wallet + return + if self.is_trusted(peer): + request = wallet_protocol.RequestBlockHeader(peak.height) + header_response: Optional[RespondBlockHeader] = await peer.request_block_header(request) + assert header_response is not None + + # Get last timestamp + last_tx: Optional[HeaderBlock] = await self.fetch_last_tx_from_peer(peak.height, peer) + latest_timestamp: Optional[uint64] = None + if last_tx is not None: + assert last_tx.foliage_transaction_block is not None + latest_timestamp = last_tx.foliage_transaction_block.timestamp + + # Ignore if not synced + if latest_timestamp is None or self.config["testing"] is False and latest_timestamp < request_time - 600: return - if self.is_trusted(peer): + + # Disconnect from all untrusted peers if our local node is trusted and synced + await self.disconnect_and_stop_wpeers() + + async with self.new_peak_lock: async with self.wallet_state_manager.lock: - request = wallet_protocol.RequestBlockHeader(peak.height) - header_response: Optional[RespondBlockHeader] = await peer.request_block_header(request) - assert header_response is not None - - # Get last timestamp - last_tx: Optional[HeaderBlock] = await self.fetch_last_tx_from_peer(peak.height, peer) - latest_timestamp: Optional[uint64] = None - if last_tx is not None: - assert last_tx.foliage_transaction_block is not None - latest_timestamp = last_tx.foliage_transaction_block.timestamp - - # Ignore if not synced - if ( - latest_timestamp is None - or self.config["testing"] is False - and latest_timestamp < request_time - 600 - ): - return - - # Disconnect from all untrusted peers if our local node is trusted and synced - if len(self.server.get_full_node_connections()) > 1: - for peer in self.server.get_full_node_connections(): - if not self.is_trusted(peer): - asyncio.create_task(peer.close()) - # Sync to trusted node self.local_node_synced = True if peer.peer_node_id not in self.synced_peers: @@ -696,7 +728,8 @@ class WalletNode: self.wallet_state_manager.state_changed("new_block") self.wallet_state_manager.set_sync_mode(False) - else: + else: + async with self.new_peak_lock: request = wallet_protocol.RequestBlockHeader(peak.height) response: Optional[RespondBlockHeader] = await peer.request_block_header(request) if response is None or not isinstance(response, RespondBlockHeader) or response.header_block is None: @@ -718,19 +751,27 @@ class WalletNode: # This block is after our peak, so we don't need to check if node is synced pass else: + tx_timestamp = None if not response.header_block.is_transaction_block: - last_tx_block = await self.fetch_last_tx_from_peer(response.header_block.height, peer) + last_tx_block = None + # Try local first + last_block_record = await self.last_local_tx_block(response.header_block.prev_header_hash) + if last_block_record is not None: + tx_timestamp = last_block_record.timestamp + else: + last_tx_block = await self.fetch_last_tx_from_peer(response.header_block.height, peer) + if last_tx_block is not None: + assert last_tx_block.foliage_transaction_block is not None + tx_timestamp = last_tx_block.foliage_transaction_block.timestamp else: last_tx_block = response.header_block + assert last_tx_block.foliage_transaction_block is not None + tx_timestamp = last_tx_block.foliage_transaction_block.timestamp - if last_tx_block is None: - return - assert last_tx_block is not None - assert last_tx_block.foliage_transaction_block is not None - if ( - self.config["testing"] is False - and last_tx_block.foliage_transaction_block.timestamp < request_time - 600 - ): + if tx_timestamp is None: + return None + + if self.config["testing"] is False and tx_timestamp < request_time - 600: # Full node not synced, don't sync to it self.log.info("Peer we connected to is not fully synced, dropping connection...") await peer.close() @@ -764,11 +805,17 @@ class WalletNode: return assert weight_proof is not None old_proof = self.wallet_state_manager.blockchain.synced_weight_proof + curr_peak = await self.wallet_state_manager.blockchain.get_peak_block() fork_point = 0 + if curr_peak is not None: + fork_point = max(0, curr_peak.height - 32) + if old_proof is not None: - fork_point = self.wallet_state_manager.weight_proof_handler.get_fork_point( + wp_fork_point = self.wallet_state_manager.weight_proof_handler.get_fork_point( old_proof, weight_proof ) + if wp_fork_point != 0: + fork_point = wp_fork_point await self.wallet_state_manager.blockchain.new_weight_proof(weight_proof, block_records) if syncing: @@ -786,6 +833,7 @@ class WalletNode: self.synced_peers.add(peer.peer_node_id) self.wallet_state_manager.state_changed("new_block") + self.wallet_state_manager.set_sync_mode(False) await self.update_ui() except Exception: if syncing: @@ -840,8 +888,6 @@ class WalletNode: fork_height = top.height - 1 blocks.reverse() - # Roll back coins and transactions - self.log.info(f"Rolling back to {fork_height}") await self.wallet_state_manager.reorg_rollback(fork_height) peak = await self.wallet_state_manager.blockchain.get_peak_block() self.rollback_request_caches(fork_height) @@ -1150,7 +1196,7 @@ class WalletNode: if stored_record.header_hash == block.header_hash: return True - weight_proof = self.wallet_state_manager.blockchain.synced_weight_proof + weight_proof: Optional[WeightProof] = self.wallet_state_manager.blockchain.synced_weight_proof if weight_proof is None: return False @@ -1171,26 +1217,30 @@ class WalletNode: compare_to_recent = True end = first_height_recent else: - request = RequestSESInfo(block.height, block.height + 32) - if block.height in peer_request_cache.ses_requests: - res_ses: RespondSESInfo = peer_request_cache.ses_requests[block.height] + if block.height < self.constants.SUB_EPOCH_BLOCKS: + inserted = weight_proof.sub_epochs[1] + end = self.constants.SUB_EPOCH_BLOCKS + inserted.num_blocks_overflow else: - res_ses = await peer.request_ses_hashes(request) - peer_request_cache.ses_requests[block.height] = res_ses + request = RequestSESInfo(block.height, block.height + 32) + if block.height in peer_request_cache.ses_requests: + res_ses: RespondSESInfo = peer_request_cache.ses_requests[block.height] + else: + res_ses = await peer.request_ses_hashes(request) + peer_request_cache.ses_requests[block.height] = res_ses - ses_0 = res_ses.reward_chain_hash[0] - last_height = res_ses.heights[0][-1] # Last height in sub epoch - end = last_height - for idx, ses in enumerate(weight_proof.sub_epochs): - if idx > len(weight_proof.sub_epochs) - 3: - break - if ses.reward_chain_hash == ses_0: - current_ses = ses - inserted = weight_proof.sub_epochs[idx + 2] - break - if current_ses is None: - self.log.error("Failed validation 2") - return False + ses_0 = res_ses.reward_chain_hash[0] + last_height = res_ses.heights[0][-1] # Last height in sub epoch + end = last_height + for idx, ses in enumerate(weight_proof.sub_epochs): + if idx > len(weight_proof.sub_epochs) - 3: + break + if ses.reward_chain_hash == ses_0: + current_ses = ses + inserted = weight_proof.sub_epochs[idx + 2] + break + if current_ses is None: + self.log.error("Failed validation 2") + return False blocks = [] diff --git a/chia/wallet/wallet_state_manager.py b/chia/wallet/wallet_state_manager.py index dadb6a74dca4..91cd6be97065 100644 --- a/chia/wallet/wallet_state_manager.py +++ b/chia/wallet/wallet_state_manager.py @@ -571,7 +571,7 @@ class WalletStateManager: ): return None, None - response: List[CoinState] = await self.wallet_node.get_coin_state([coin_state.coin.parent_coin_info]) + response: List[CoinState] = await self.wallet_node.get_coin_state([coin_state.coin.parent_coin_info], peer) if len(response) == 0: self.log.warning(f"Could not find a parent coin with ID: {coin_state.coin.parent_coin_info}") return None, None @@ -869,7 +869,6 @@ class WalletStateManager: except Exception as e: self.log.debug(f"Not a pool wallet launcher {e}") continue - # solution_to_pool_state may return None but this may not be an error if pool_state is None: self.log.debug("solution_to_pool_state returned None, ignore and continue")