Merge commit '371f76e377c197b6a53899c72650f9854c26fd0c' into checkpoint/long_lived_atari_from_main_371f76e377c197b6a53899c72650f9854c26fd0c

This commit is contained in:
Amine Khaldi 2022-08-04 18:32:31 +01:00
commit 5279f2b999
No known key found for this signature in database
GPG Key ID: B1C074FFC904E2D9
5 changed files with 117 additions and 70 deletions

View File

@ -578,7 +578,8 @@ class WeightProofHandler:
log.info("validate weight proof recent blocks")
if not _validate_recent_blocks(self.constants, wp_recent_chain_bytes, summary_bytes):
return False, uint32(0)
return True, self.get_fork_point(summaries)
fork_point, _ = self.get_fork_point(summaries)
return True, fork_point
def get_fork_point_no_validations(self, weight_proof: WeightProof) -> Tuple[bool, uint32]:
log.debug("get fork point skip validations")
@ -590,7 +591,8 @@ class WeightProofHandler:
if summaries is None:
log.warning("weight proof failed to validate sub epoch summaries")
return False, uint32(0)
return True, self.get_fork_point(summaries)
fork_height, _ = self.get_fork_point(summaries)
return True, fork_height
async def validate_weight_proof(self, weight_proof: WeightProof) -> Tuple[bool, uint32, List[SubEpochSummary]]:
assert self.blockchain is not None
@ -620,6 +622,7 @@ class WeightProofHandler:
log.error("failed weight proof sub epoch sample validation")
return False, uint32(0), []
fork_point, ses_fork_idx = self.get_fork_point(summaries)
# timing reference: 1 second
# TODO: Consider implementing an async polling closer for the executor.
with ProcessPoolExecutor(
@ -648,7 +651,7 @@ class WeightProofHandler:
# timing reference: 2 second
segments_validated, vdfs_to_validate = _validate_sub_epoch_segments(
self.constants, rng, wp_segment_bytes, summary_bytes
self.constants, rng, wp_segment_bytes, summary_bytes, ses_fork_idx
)
await asyncio.sleep(0) # break up otherwise multi-second sync code
if not segments_validated:
@ -686,9 +689,9 @@ class WeightProofHandler:
log.error("failed validating weight proof recent blocks")
return False, uint32(0), []
return True, self.get_fork_point(summaries), summaries
return True, fork_point, summaries
def get_fork_point(self, received_summaries: List[SubEpochSummary]) -> uint32:
def get_fork_point(self, received_summaries: List[SubEpochSummary]) -> Tuple[uint32, int]:
# iterate through sub epoch summaries to find fork point
fork_point_index = 0
ses_heights = self.blockchain.get_ses_heights()
@ -702,14 +705,12 @@ class WeightProofHandler:
break
fork_point_index = idx
if fork_point_index > 2:
if fork_point_index <= 2:
# Two summeries can have different blocks and still be identical
# This gets resolved after one full sub epoch
height = ses_heights[fork_point_index - 2]
else:
height = uint32(0)
return uint32(0), 0
return height
return ses_heights[fork_point_index - 2], fork_point_index
def _get_weights_for_sampling(
@ -994,6 +995,7 @@ def _validate_sub_epoch_segments(
rng: random.Random,
weight_proof_bytes: bytes,
summaries_bytes: List[bytes],
validate_from: int = 0,
):
summaries = summaries_from_bytes(summaries_bytes)
sub_epoch_segments: SubEpochSegments = SubEpochSegments.from_bytes(weight_proof_bytes)
@ -1018,6 +1020,11 @@ def _validate_sub_epoch_segments(
if not summaries[sub_epoch_n].reward_chain_hash == rc_sub_slot_hash:
log.error(f"failed reward_chain_hash validation sub_epoch {sub_epoch_n}")
return False
# skip validation up to fork height
if sub_epoch_n < validate_from:
continue
for idx, segment in enumerate(segments):
valid_segment, ip_iters, slot_iters, slots, vdf_list = _validate_segment(
constants, segment, curr_ssi, prev_ssi, curr_difficulty, prev_ses, idx == 0, sampled_seg_index == idx

View File

@ -77,7 +77,7 @@ class WalletBlockchain(BlockchainInterface):
latest_timestamp = self._latest_timestamp
if records is None:
success, _, _, records = await self._weight_proof_handler.validate_weight_proof(weight_proof, True)
success, _, records = await self._weight_proof_handler.validate_weight_proof(weight_proof, True)
assert success
assert records is not None and len(records) > 1

View File

@ -67,6 +67,7 @@ from chia.wallet.util.wallet_sync_utils import (
from chia.wallet.wallet_action import WalletAction
from chia.wallet.wallet_coin_record import WalletCoinRecord
from chia.wallet.wallet_state_manager import WalletStateManager
from chia.wallet.wallet_weight_proof_handler import get_wp_fork_point
@dataclasses.dataclass
@ -997,10 +998,7 @@ class WalletNode:
if old_proof is not None:
# If the weight proof fork point is in the past, rollback more to ensure we don't have duplicate
# state.
wp_fork_point = self.wallet_state_manager.weight_proof_handler.get_fork_point(
old_proof, weight_proof
)
fork_point = min(fork_point, wp_fork_point)
fork_point = min(fork_point, get_wp_fork_point(self.constants, old_proof, weight_proof))
await self.wallet_state_manager.blockchain.new_weight_proof(weight_proof, block_records)
if syncing:
@ -1160,13 +1158,16 @@ class WalletNode:
if weight_proof.get_hash() in self.valid_wp_cache:
valid, fork_point, summaries, block_records = self.valid_wp_cache[weight_proof.get_hash()]
else:
old_proof = self.wallet_state_manager.blockchain.synced_weight_proof
fork_point = get_wp_fork_point(self.constants, old_proof, weight_proof)
start_validation = time.time()
(
valid,
fork_point,
summaries,
block_records,
) = await self.wallet_state_manager.weight_proof_handler.validate_weight_proof(weight_proof)
) = await self.wallet_state_manager.weight_proof_handler.validate_weight_proof(
weight_proof, False, old_proof
)
if valid:
self.valid_wp_cache[weight_proof.get_hash()] = valid, fork_point, summaries, block_records

View File

@ -65,22 +65,23 @@ class WalletWeightProofHandler:
self._executor.shutdown(wait=True)
async def validate_weight_proof(
self, weight_proof: WeightProof, skip_segment_validation=False
) -> Tuple[bool, uint32, List[SubEpochSummary], List[BlockRecord]]:
self, weight_proof: WeightProof, skip_segment_validation: bool = False, old_proof: Optional[WeightProof] = None
) -> Tuple[bool, List[SubEpochSummary], List[BlockRecord]]:
validate_from = get_fork_ses_idx(old_proof, weight_proof)
task: asyncio.Task = asyncio.create_task(
self._validate_weight_proof_inner(weight_proof, skip_segment_validation)
self._validate_weight_proof_inner(weight_proof, skip_segment_validation, validate_from)
)
self._weight_proof_tasks.append(task)
valid, fork_point, summaries, block_records = await task
valid, summaries, block_records = await task
self._weight_proof_tasks.remove(task)
return valid, fork_point, summaries, block_records
return valid, summaries, block_records
async def _validate_weight_proof_inner(
self, weight_proof: WeightProof, skip_segment_validation: bool
) -> Tuple[bool, uint32, List[SubEpochSummary], List[BlockRecord]]:
self, weight_proof: WeightProof, skip_segment_validation: bool, validate_from: int
) -> Tuple[bool, List[SubEpochSummary], List[BlockRecord]]:
assert len(weight_proof.sub_epochs) > 0
if len(weight_proof.sub_epochs) == 0:
return False, uint32(0), [], []
return False, [], []
peak_height = weight_proof.recent_chain_data[-1].reward_chain_block.height
log.info(f"validate weight proof peak height {peak_height}")
@ -94,13 +95,13 @@ class WalletWeightProofHandler:
await asyncio.sleep(0) # break up otherwise multi-second sync code
if summaries is None:
log.error("weight proof failed sub epoch data validation")
return False, uint32(0), [], []
return False, [], []
seed = summaries[-2].get_hash()
rng = random.Random(seed)
if not validate_sub_epoch_sampling(rng, sub_epoch_weight_list, weight_proof):
log.error("failed weight proof sub epoch sample validation")
return False, uint32(0), [], []
return False, [], []
summary_bytes, wp_segment_bytes, wp_recent_chain_bytes = vars_to_bytes(summaries, weight_proof)
await asyncio.sleep(0) # break up otherwise multi-second sync code
@ -117,12 +118,12 @@ class WalletWeightProofHandler:
try:
if not skip_segment_validation:
segments_validated, vdfs_to_validate = _validate_sub_epoch_segments(
self._constants, rng, wp_segment_bytes, summary_bytes
self._constants, rng, wp_segment_bytes, summary_bytes, validate_from
)
await asyncio.sleep(0) # break up otherwise multi-second sync code
if not segments_validated:
return False, uint32(0), [], []
return False, [], []
vdf_chunks = chunks(vdfs_to_validate, self._num_processes)
for chunk in vdf_chunks:
@ -144,7 +145,7 @@ class WalletWeightProofHandler:
for vdf_task in vdf_tasks:
validated = await vdf_task
if not validated:
return False, uint32(0), [], []
return False, [], []
valid_recent_blocks, records_bytes = await recent_blocks_validation_task
finally:
@ -155,56 +156,65 @@ class WalletWeightProofHandler:
if not valid_recent_blocks:
log.error("failed validating weight proof recent blocks")
# Verify the data
return False, uint32(0), [], []
return False, [], []
records = [BlockRecord.from_bytes(b) for b in records_bytes]
return True, summaries, records
# TODO fix find fork point
return True, uint32(0), summaries, records
def get_fork_point(self, old_wp: Optional[WeightProof], new_wp: WeightProof) -> uint32:
"""
iterate through sub epoch summaries to find fork point. This method is conservative, it does not return the
actual fork point, it can return a height that is before the actual fork point.
"""
def get_wp_fork_point(constants: ConsensusConstants, old_wp: Optional[WeightProof], new_wp: WeightProof) -> uint32:
"""
iterate through sub epoch summaries to find fork point. This method is conservative, it does not return the
actual fork point, it can return a height that is before the actual fork point.
"""
if old_wp is None:
return uint32(0)
if old_wp is None:
return uint32(0)
old_ses = set()
overflow = 0
count = 0
for idx, new_ses in enumerate(new_wp.sub_epochs):
if idx == len(new_wp.sub_epochs) - 1 or idx == len(old_wp.sub_epochs):
break
if new_ses.reward_chain_hash != old_wp.sub_epochs[idx].reward_chain_hash:
break
for ses in old_wp.sub_epochs:
old_ses.add(ses.reward_chain_hash)
overflow = 0
count = 0
for new_ses in new_wp.sub_epochs:
if new_ses.reward_chain_hash in old_ses:
count += 1
overflow = new_ses.num_blocks_overflow
continue
else:
break
count = idx + 1
overflow = new_wp.sub_epochs[idx + 1].num_blocks_overflow
if new_wp.recent_chain_data[0].height < old_wp.recent_chain_data[-1].height:
# Try to find an exact fork point
if new_wp.recent_chain_data[0].height >= old_wp.recent_chain_data[0].height:
left_wp = old_wp
right_wp = new_wp
else:
left_wp = new_wp
right_wp = old_wp
r_index = 0
l_index = 0
while r_index < len(right_wp.recent_chain_data) and l_index < len(left_wp.recent_chain_data):
if right_wp.recent_chain_data[r_index].header_hash == left_wp.recent_chain_data[l_index].header_hash:
r_index += 1
new_wp_index = 0
old_wp_index = 0
while new_wp_index < len(new_wp.recent_chain_data) and old_wp_index < len(old_wp.recent_chain_data):
if new_wp.recent_chain_data[new_wp_index].header_hash == old_wp.recent_chain_data[old_wp_index].header_hash:
new_wp_index += 1
continue
# Keep incrementing left pointer until we find a match
l_index += 1
if r_index != 0:
old_wp_index += 1
if new_wp_index != 0:
# We found a matching block, this is the last matching block
return right_wp.recent_chain_data[r_index - 1].height
return new_wp.recent_chain_data[new_wp_index - 1].height
# Just return the matching sub epoch height
return uint32((self._constants.SUB_EPOCH_BLOCKS * count) - overflow)
# Just return the matching sub epoch height
return uint32((constants.SUB_EPOCH_BLOCKS * count) + overflow)
def get_fork_ses_idx(old_wp: Optional[WeightProof], new_wp: WeightProof) -> int:
"""
iterate through sub epoch summaries to find fork point. This method is conservative, it does not return the
actual fork point, it can return a height that is before the actual fork point.
"""
if old_wp is None:
return uint32(0)
ses_index = 0
for idx, new_ses in enumerate(new_wp.sub_epochs):
if new_ses.reward_chain_hash != old_wp.sub_epochs[idx].reward_chain_hash:
ses_index = idx
break
if idx == len(old_wp.sub_epochs) - 1:
ses_index = idx
break
return ses_index

View File

@ -6,6 +6,7 @@ from colorlog import getLogger
from chia.consensus.block_record import BlockRecord
from chia.consensus.block_rewards import calculate_base_farmer_reward, calculate_pool_reward
from chia.full_node.full_node_api import FullNodeAPI
from chia.full_node.weight_proof import WeightProofHandler
from chia.protocols import full_node_protocol, wallet_protocol
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.shared_protocol import Capability
@ -13,14 +14,17 @@ from chia.protocols.wallet_protocol import RequestAdditions, RespondAdditions, R
from chia.server.outbound_message import Message
from chia.simulator.simulator_protocol import FarmNewBlockProtocol
from chia.types.peer_info import PeerInfo
from chia.util.block_cache import BlockCache
from chia.util.hash import std_hash
from chia.util.ints import uint16, uint32, uint64
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.util.wallet_types import AmountWithPuzzlehash
from chia.wallet.wallet_weight_proof_handler import get_wp_fork_point
from tests.connection_utils import disconnect_all, disconnect_all_and_reconnect
from tests.pools.test_pool_rpc import wallet_is_synced
from tests.setup_nodes import test_constants
from tests.time_out_assert import time_out_assert
from tests.weight_proof.test_weight_proof import load_blocks_dont_validate
def wallet_height_at_least(wallet_node, h):
@ -559,3 +563,28 @@ class TestWalletSync:
response = RespondAdditions.from_bytes(res4.data)
assert response.proofs == []
assert len(response.coins) == 0
@pytest.mark.asyncio
async def test_get_wp_fork_point(self, default_10000_blocks):
blocks = default_10000_blocks
header_cache, height_to_hash, sub_blocks, summaries = await load_blocks_dont_validate(blocks)
wpf = WeightProofHandler(test_constants, BlockCache(sub_blocks, header_cache, height_to_hash, summaries))
wp1 = await wpf.get_proof_of_weight(header_cache[height_to_hash[uint32(9000)]].header_hash)
wp2 = await wpf.get_proof_of_weight(header_cache[height_to_hash[uint32(9030)]].header_hash)
wp3 = await wpf.get_proof_of_weight(header_cache[height_to_hash[uint32(7500)]].header_hash)
wp4 = await wpf.get_proof_of_weight(header_cache[height_to_hash[uint32(8700)]].header_hash)
wp5 = await wpf.get_proof_of_weight(header_cache[height_to_hash[uint32(9700)]].header_hash)
fork12 = get_wp_fork_point(test_constants, wp1, wp2)
fork13 = get_wp_fork_point(test_constants, wp3, wp1)
fork14 = get_wp_fork_point(test_constants, wp4, wp1)
fork23 = get_wp_fork_point(test_constants, wp3, wp2)
fork24 = get_wp_fork_point(test_constants, wp4, wp2)
fork34 = get_wp_fork_point(test_constants, wp3, wp4)
fork45 = get_wp_fork_point(test_constants, wp4, wp5)
assert fork14 == 8700
assert fork24 == 8700
assert fork12 == 9000
assert fork13 in summaries.keys()
assert fork23 in summaries.keys()
assert fork34 in summaries.keys()
assert fork45 in summaries.keys()