From 371f76e377c197b6a53899c72650f9854c26fd0c Mon Sep 17 00:00:00 2001 From: Almog De Paz Date: Mon, 11 Jul 2022 19:43:12 +0300 Subject: [PATCH] Wp validate from fork (#11826) * only validate segments from fork point * lint * isort * isort --- chia/full_node/weight_proof.py | 27 +++-- chia/wallet/wallet_blockchain.py | 2 +- chia/wallet/wallet_node.py | 13 +-- chia/wallet/wallet_weight_proof_handler.py | 116 +++++++++++---------- tests/wallet/sync/test_wallet_sync.py | 29 ++++++ 5 files changed, 117 insertions(+), 70 deletions(-) diff --git a/chia/full_node/weight_proof.py b/chia/full_node/weight_proof.py index 3ac6fd656715..86ab1de71afd 100644 --- a/chia/full_node/weight_proof.py +++ b/chia/full_node/weight_proof.py @@ -578,7 +578,8 @@ class WeightProofHandler: log.info("validate weight proof recent blocks") if not _validate_recent_blocks(self.constants, wp_recent_chain_bytes, summary_bytes): return False, uint32(0) - return True, self.get_fork_point(summaries) + fork_point, _ = self.get_fork_point(summaries) + return True, fork_point def get_fork_point_no_validations(self, weight_proof: WeightProof) -> Tuple[bool, uint32]: log.debug("get fork point skip validations") @@ -590,7 +591,8 @@ class WeightProofHandler: if summaries is None: log.warning("weight proof failed to validate sub epoch summaries") return False, uint32(0) - return True, self.get_fork_point(summaries) + fork_height, _ = self.get_fork_point(summaries) + return True, fork_height async def validate_weight_proof(self, weight_proof: WeightProof) -> Tuple[bool, uint32, List[SubEpochSummary]]: assert self.blockchain is not None @@ -620,6 +622,7 @@ class WeightProofHandler: log.error("failed weight proof sub epoch sample validation") return False, uint32(0), [] + fork_point, ses_fork_idx = self.get_fork_point(summaries) # timing reference: 1 second # TODO: Consider implementing an async polling closer for the executor. with ProcessPoolExecutor( @@ -648,7 +651,7 @@ class WeightProofHandler: # timing reference: 2 second segments_validated, vdfs_to_validate = _validate_sub_epoch_segments( - self.constants, rng, wp_segment_bytes, summary_bytes + self.constants, rng, wp_segment_bytes, summary_bytes, ses_fork_idx ) await asyncio.sleep(0) # break up otherwise multi-second sync code if not segments_validated: @@ -686,9 +689,9 @@ class WeightProofHandler: log.error("failed validating weight proof recent blocks") return False, uint32(0), [] - return True, self.get_fork_point(summaries), summaries + return True, fork_point, summaries - def get_fork_point(self, received_summaries: List[SubEpochSummary]) -> uint32: + def get_fork_point(self, received_summaries: List[SubEpochSummary]) -> Tuple[uint32, int]: # iterate through sub epoch summaries to find fork point fork_point_index = 0 ses_heights = self.blockchain.get_ses_heights() @@ -702,14 +705,12 @@ class WeightProofHandler: break fork_point_index = idx - if fork_point_index > 2: + if fork_point_index <= 2: # Two summeries can have different blocks and still be identical # This gets resolved after one full sub epoch - height = ses_heights[fork_point_index - 2] - else: - height = uint32(0) + return uint32(0), 0 - return height + return ses_heights[fork_point_index - 2], fork_point_index def _get_weights_for_sampling( @@ -994,6 +995,7 @@ def _validate_sub_epoch_segments( rng: random.Random, weight_proof_bytes: bytes, summaries_bytes: List[bytes], + validate_from: int = 0, ): summaries = summaries_from_bytes(summaries_bytes) sub_epoch_segments: SubEpochSegments = SubEpochSegments.from_bytes(weight_proof_bytes) @@ -1018,6 +1020,11 @@ def _validate_sub_epoch_segments( if not summaries[sub_epoch_n].reward_chain_hash == rc_sub_slot_hash: log.error(f"failed reward_chain_hash validation sub_epoch {sub_epoch_n}") return False + + # skip validation up to fork height + if sub_epoch_n < validate_from: + continue + for idx, segment in enumerate(segments): valid_segment, ip_iters, slot_iters, slots, vdf_list = _validate_segment( constants, segment, curr_ssi, prev_ssi, curr_difficulty, prev_ses, idx == 0, sampled_seg_index == idx diff --git a/chia/wallet/wallet_blockchain.py b/chia/wallet/wallet_blockchain.py index 64ffb7206247..03ee3566c20e 100644 --- a/chia/wallet/wallet_blockchain.py +++ b/chia/wallet/wallet_blockchain.py @@ -77,7 +77,7 @@ class WalletBlockchain(BlockchainInterface): latest_timestamp = self._latest_timestamp if records is None: - success, _, _, records = await self._weight_proof_handler.validate_weight_proof(weight_proof, True) + success, _, records = await self._weight_proof_handler.validate_weight_proof(weight_proof, True) assert success assert records is not None and len(records) > 1 diff --git a/chia/wallet/wallet_node.py b/chia/wallet/wallet_node.py index 4419a2a4c5aa..c7c8213667c8 100644 --- a/chia/wallet/wallet_node.py +++ b/chia/wallet/wallet_node.py @@ -67,6 +67,7 @@ from chia.wallet.util.wallet_sync_utils import ( from chia.wallet.wallet_action import WalletAction from chia.wallet.wallet_coin_record import WalletCoinRecord from chia.wallet.wallet_state_manager import WalletStateManager +from chia.wallet.wallet_weight_proof_handler import get_wp_fork_point @dataclasses.dataclass @@ -990,10 +991,7 @@ class WalletNode: if old_proof is not None: # If the weight proof fork point is in the past, rollback more to ensure we don't have duplicate # state. - wp_fork_point = self.wallet_state_manager.weight_proof_handler.get_fork_point( - old_proof, weight_proof - ) - fork_point = min(fork_point, wp_fork_point) + fork_point = min(fork_point, get_wp_fork_point(self.constants, old_proof, weight_proof)) await self.wallet_state_manager.blockchain.new_weight_proof(weight_proof, block_records) if syncing: @@ -1153,13 +1151,16 @@ class WalletNode: if weight_proof.get_hash() in self.valid_wp_cache: valid, fork_point, summaries, block_records = self.valid_wp_cache[weight_proof.get_hash()] else: + old_proof = self.wallet_state_manager.blockchain.synced_weight_proof + fork_point = get_wp_fork_point(self.constants, old_proof, weight_proof) start_validation = time.time() ( valid, - fork_point, summaries, block_records, - ) = await self.wallet_state_manager.weight_proof_handler.validate_weight_proof(weight_proof) + ) = await self.wallet_state_manager.weight_proof_handler.validate_weight_proof( + weight_proof, False, old_proof + ) if valid: self.valid_wp_cache[weight_proof.get_hash()] = valid, fork_point, summaries, block_records diff --git a/chia/wallet/wallet_weight_proof_handler.py b/chia/wallet/wallet_weight_proof_handler.py index fcd5fe982fd1..6afe5b0d482d 100644 --- a/chia/wallet/wallet_weight_proof_handler.py +++ b/chia/wallet/wallet_weight_proof_handler.py @@ -65,22 +65,23 @@ class WalletWeightProofHandler: self._executor.shutdown(wait=True) async def validate_weight_proof( - self, weight_proof: WeightProof, skip_segment_validation=False - ) -> Tuple[bool, uint32, List[SubEpochSummary], List[BlockRecord]]: + self, weight_proof: WeightProof, skip_segment_validation: bool = False, old_proof: Optional[WeightProof] = None + ) -> Tuple[bool, List[SubEpochSummary], List[BlockRecord]]: + validate_from = get_fork_ses_idx(old_proof, weight_proof) task: asyncio.Task = asyncio.create_task( - self._validate_weight_proof_inner(weight_proof, skip_segment_validation) + self._validate_weight_proof_inner(weight_proof, skip_segment_validation, validate_from) ) self._weight_proof_tasks.append(task) - valid, fork_point, summaries, block_records = await task + valid, summaries, block_records = await task self._weight_proof_tasks.remove(task) - return valid, fork_point, summaries, block_records + return valid, summaries, block_records async def _validate_weight_proof_inner( - self, weight_proof: WeightProof, skip_segment_validation: bool - ) -> Tuple[bool, uint32, List[SubEpochSummary], List[BlockRecord]]: + self, weight_proof: WeightProof, skip_segment_validation: bool, validate_from: int + ) -> Tuple[bool, List[SubEpochSummary], List[BlockRecord]]: assert len(weight_proof.sub_epochs) > 0 if len(weight_proof.sub_epochs) == 0: - return False, uint32(0), [], [] + return False, [], [] peak_height = weight_proof.recent_chain_data[-1].reward_chain_block.height log.info(f"validate weight proof peak height {peak_height}") @@ -94,13 +95,13 @@ class WalletWeightProofHandler: await asyncio.sleep(0) # break up otherwise multi-second sync code if summaries is None: log.error("weight proof failed sub epoch data validation") - return False, uint32(0), [], [] + return False, [], [] seed = summaries[-2].get_hash() rng = random.Random(seed) if not validate_sub_epoch_sampling(rng, sub_epoch_weight_list, weight_proof): log.error("failed weight proof sub epoch sample validation") - return False, uint32(0), [], [] + return False, [], [] summary_bytes, wp_segment_bytes, wp_recent_chain_bytes = vars_to_bytes(summaries, weight_proof) await asyncio.sleep(0) # break up otherwise multi-second sync code @@ -117,12 +118,12 @@ class WalletWeightProofHandler: try: if not skip_segment_validation: segments_validated, vdfs_to_validate = _validate_sub_epoch_segments( - self._constants, rng, wp_segment_bytes, summary_bytes + self._constants, rng, wp_segment_bytes, summary_bytes, validate_from ) await asyncio.sleep(0) # break up otherwise multi-second sync code if not segments_validated: - return False, uint32(0), [], [] + return False, [], [] vdf_chunks = chunks(vdfs_to_validate, self._num_processes) for chunk in vdf_chunks: @@ -144,7 +145,7 @@ class WalletWeightProofHandler: for vdf_task in vdf_tasks: validated = await vdf_task if not validated: - return False, uint32(0), [], [] + return False, [], [] valid_recent_blocks, records_bytes = await recent_blocks_validation_task finally: @@ -155,56 +156,65 @@ class WalletWeightProofHandler: if not valid_recent_blocks: log.error("failed validating weight proof recent blocks") # Verify the data - return False, uint32(0), [], [] + return False, [], [] records = [BlockRecord.from_bytes(b) for b in records_bytes] + return True, summaries, records - # TODO fix find fork point - return True, uint32(0), summaries, records - def get_fork_point(self, old_wp: Optional[WeightProof], new_wp: WeightProof) -> uint32: - """ - iterate through sub epoch summaries to find fork point. This method is conservative, it does not return the - actual fork point, it can return a height that is before the actual fork point. - """ +def get_wp_fork_point(constants: ConsensusConstants, old_wp: Optional[WeightProof], new_wp: WeightProof) -> uint32: + """ + iterate through sub epoch summaries to find fork point. This method is conservative, it does not return the + actual fork point, it can return a height that is before the actual fork point. + """ - if old_wp is None: - return uint32(0) + if old_wp is None: + return uint32(0) - old_ses = set() + overflow = 0 + count = 0 + for idx, new_ses in enumerate(new_wp.sub_epochs): + if idx == len(new_wp.sub_epochs) - 1 or idx == len(old_wp.sub_epochs): + break + if new_ses.reward_chain_hash != old_wp.sub_epochs[idx].reward_chain_hash: + break - for ses in old_wp.sub_epochs: - old_ses.add(ses.reward_chain_hash) - - overflow = 0 - count = 0 - for new_ses in new_wp.sub_epochs: - if new_ses.reward_chain_hash in old_ses: - count += 1 - overflow = new_ses.num_blocks_overflow - continue - else: - break + count = idx + 1 + overflow = new_wp.sub_epochs[idx + 1].num_blocks_overflow + if new_wp.recent_chain_data[0].height < old_wp.recent_chain_data[-1].height: # Try to find an exact fork point - if new_wp.recent_chain_data[0].height >= old_wp.recent_chain_data[0].height: - left_wp = old_wp - right_wp = new_wp - else: - left_wp = new_wp - right_wp = old_wp - - r_index = 0 - l_index = 0 - while r_index < len(right_wp.recent_chain_data) and l_index < len(left_wp.recent_chain_data): - if right_wp.recent_chain_data[r_index].header_hash == left_wp.recent_chain_data[l_index].header_hash: - r_index += 1 + new_wp_index = 0 + old_wp_index = 0 + while new_wp_index < len(new_wp.recent_chain_data) and old_wp_index < len(old_wp.recent_chain_data): + if new_wp.recent_chain_data[new_wp_index].header_hash == old_wp.recent_chain_data[old_wp_index].header_hash: + new_wp_index += 1 continue # Keep incrementing left pointer until we find a match - l_index += 1 - if r_index != 0: + old_wp_index += 1 + if new_wp_index != 0: # We found a matching block, this is the last matching block - return right_wp.recent_chain_data[r_index - 1].height + return new_wp.recent_chain_data[new_wp_index - 1].height - # Just return the matching sub epoch height - return uint32((self._constants.SUB_EPOCH_BLOCKS * count) - overflow) + # Just return the matching sub epoch height + return uint32((constants.SUB_EPOCH_BLOCKS * count) + overflow) + + +def get_fork_ses_idx(old_wp: Optional[WeightProof], new_wp: WeightProof) -> int: + """ + iterate through sub epoch summaries to find fork point. This method is conservative, it does not return the + actual fork point, it can return a height that is before the actual fork point. + """ + + if old_wp is None: + return uint32(0) + ses_index = 0 + for idx, new_ses in enumerate(new_wp.sub_epochs): + if new_ses.reward_chain_hash != old_wp.sub_epochs[idx].reward_chain_hash: + ses_index = idx + break + + if idx == len(old_wp.sub_epochs) - 1: + ses_index = idx + break + return ses_index diff --git a/tests/wallet/sync/test_wallet_sync.py b/tests/wallet/sync/test_wallet_sync.py index 469d3646f2ad..06a5baf3e50a 100644 --- a/tests/wallet/sync/test_wallet_sync.py +++ b/tests/wallet/sync/test_wallet_sync.py @@ -6,6 +6,7 @@ from colorlog import getLogger from chia.consensus.block_record import BlockRecord from chia.consensus.block_rewards import calculate_base_farmer_reward, calculate_pool_reward from chia.full_node.full_node_api import FullNodeAPI +from chia.full_node.weight_proof import WeightProofHandler from chia.protocols import full_node_protocol, wallet_protocol from chia.protocols.protocol_message_types import ProtocolMessageTypes from chia.protocols.shared_protocol import Capability @@ -13,14 +14,17 @@ from chia.protocols.wallet_protocol import RequestAdditions, RespondAdditions, R from chia.server.outbound_message import Message from chia.simulator.simulator_protocol import FarmNewBlockProtocol from chia.types.peer_info import PeerInfo +from chia.util.block_cache import BlockCache from chia.util.hash import std_hash from chia.util.ints import uint16, uint32, uint64 from chia.wallet.transaction_record import TransactionRecord from chia.wallet.util.wallet_types import AmountWithPuzzlehash +from chia.wallet.wallet_weight_proof_handler import get_wp_fork_point from tests.connection_utils import disconnect_all, disconnect_all_and_reconnect from tests.pools.test_pool_rpc import wallet_is_synced from tests.setup_nodes import test_constants from tests.time_out_assert import time_out_assert +from tests.weight_proof.test_weight_proof import load_blocks_dont_validate def wallet_height_at_least(wallet_node, h): @@ -559,3 +563,28 @@ class TestWalletSync: response = RespondAdditions.from_bytes(res4.data) assert response.proofs == [] assert len(response.coins) == 0 + + @pytest.mark.asyncio + async def test_get_wp_fork_point(self, default_10000_blocks): + blocks = default_10000_blocks + header_cache, height_to_hash, sub_blocks, summaries = await load_blocks_dont_validate(blocks) + wpf = WeightProofHandler(test_constants, BlockCache(sub_blocks, header_cache, height_to_hash, summaries)) + wp1 = await wpf.get_proof_of_weight(header_cache[height_to_hash[uint32(9000)]].header_hash) + wp2 = await wpf.get_proof_of_weight(header_cache[height_to_hash[uint32(9030)]].header_hash) + wp3 = await wpf.get_proof_of_weight(header_cache[height_to_hash[uint32(7500)]].header_hash) + wp4 = await wpf.get_proof_of_weight(header_cache[height_to_hash[uint32(8700)]].header_hash) + wp5 = await wpf.get_proof_of_weight(header_cache[height_to_hash[uint32(9700)]].header_hash) + fork12 = get_wp_fork_point(test_constants, wp1, wp2) + fork13 = get_wp_fork_point(test_constants, wp3, wp1) + fork14 = get_wp_fork_point(test_constants, wp4, wp1) + fork23 = get_wp_fork_point(test_constants, wp3, wp2) + fork24 = get_wp_fork_point(test_constants, wp4, wp2) + fork34 = get_wp_fork_point(test_constants, wp3, wp4) + fork45 = get_wp_fork_point(test_constants, wp4, wp5) + assert fork14 == 8700 + assert fork24 == 8700 + assert fork12 == 9000 + assert fork13 in summaries.keys() + assert fork23 in summaries.keys() + assert fork34 in summaries.keys() + assert fork45 in summaries.keys()