From 0fef0b933eb15eef18421c513410e870d6f8e218 Mon Sep 17 00:00:00 2001
From: Mariano Sorgente
Date: Tue, 22 Oct 2019 16:44:01 +0900
Subject: [PATCH] Changed to mypy from pyright, fix tests (full node still broken)

---
 README.md | 8 ++
 mypy.ini | 5 +
 pyrightconfig.json | 15 ---
 setup.py | 2 +-
 src/blockchain.py | 77 +++++++------
 src/consensus/constants.py | 4 +-
 src/consensus/pot_iterations.py | 6 +-
 src/farmer.py | 15 ++-
 src/full_node.py | 4 +-
 src/plotter.py | 9 +-
 src/protocols/farmer_protocol.py | 8 ++
 src/protocols/peer_protocol.py | 94 ++++++++++--------
 src/protocols/plotter_protocol.py | 10 ++
 src/protocols/pool_protocol.py | 6 ++
 src/protocols/shared_protocol.py | 3 +
 src/protocols/timelord_protocol.py | 5 +
 src/protocols/wallet_protocol.py | 7 ++
 src/server/connection.py | 8 +-
 src/server/server.py | 14 +--
 src/store/full_node_store.py | 4 +-
 src/timelord.py | 13 +--
 src/types/block_body.py | 7 +-
 src/types/block_header.py | 9 +-
 src/types/challenge.py | 6 +-
 src/types/classgroup.py | 6 +-
 src/types/coinbase.py | 18 +---
 src/types/fees_target.py | 2 +
 src/types/full_block.py | 6 +-
 src/types/peer_info.py | 6 +-
 src/types/proof_of_space.py | 13 +--
 src/types/proof_of_time.py | 9 +-
 src/types/transaction.py | 6 +-
 src/types/trunk_block.py | 6 +-
 src/util/bin_methods.py | 18 ----
 src/util/byte_types.py | 22 +++--
 src/util/ints.py | 24 +++--
 src/util/streamable.py | 147 ++++++++++++++++-------
 src/util/struct_stream.py | 25 ++++-
 src/util/type_checking.py | 20 ++--
 test1.py | 25 -----
 tests/block_tools.py | 51 ++++++----
 tests/test_blockchain.py | 139 +++++++++++++++++----
 tests/util/test_streamable.py | 16 ++--
 tests/util/test_type_checking.py | 18 ++--
 44 files changed, 538 insertions(+), 378 deletions(-)
 create mode 100644 mypy.ini
 delete mode 100644 pyrightconfig.json
 delete mode 100644 src/util/bin_methods.py
 delete mode 100644 test1.py

diff --git a/README.md b/README.md
index 14113522d0e7..f18a310b9895 100644
--- a/README.md
+++ b/README.md
@@ -54,3 +54,11 @@ py.test tests -s -v
 flake8 src
 pyright
 ```
+
+### Configure VS code
+1. Install Python extension
+2. Set the environment to ./.venv/bin/python
+3. Install mypy plugin
+4. Preferences > Settings > Python > Linting > flake8 enabled
+5. Preferences > Settings > Python > Linting > mypy enabled
+6.
Preferences > Settings > mypy > Targets: set to ./src and ./tests diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000000..cbfbfc82fd90 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,5 @@ +[mypy] +ignore_missing_imports = True + +[mypy-lib] +ignore_errors = True diff --git a/pyrightconfig.json b/pyrightconfig.json deleted file mode 100644 index f78018dd7b14..000000000000 --- a/pyrightconfig.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "ignore": ["./lib", - "./src/server/server.py", - "./src/types/block_header.py", - "./src/util/streamable.py", - "./src/util/type_checking.py", - "./src/util/cbor_message.py", - "./src/util/cbor.py", - "./src/util/byte_types.py"], - "include": ["./src"], - "pythonVersion": "3.7", - "venvPath": ".", - "venv": "./.venv", - "reportMissingImports": false - } \ No newline at end of file diff --git a/setup.py b/setup.py index ec3ccd6a03d5..db3b9e948186 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup dependencies = ['blspy', 'cbor2', 'pyyaml'] -dev_dependencies = ['pytest', 'flake8', 'ipython'] +dev_dependencies = ['pytest', 'flake8', 'ipython', 'mypy', 'pytest-asyncio'] setup( name='chiablockchain', diff --git a/src/blockchain.py b/src/blockchain.py index 15cdc39502c5..0c2e4f154715 100644 --- a/src/blockchain.py +++ b/src/blockchain.py @@ -42,12 +42,15 @@ class Blockchain: self.store = store self.heads: List[FullBlock] = [] - self.lca_block: FullBlock = None + self.lca_block: FullBlock self.height_to_hash: Dict[uint64, bytes32] = {} + + async def initialize(self): self.genesis = FullBlock.from_bytes(self.constants["GENESIS_BLOCK"]) - result = self.receive_block(self.genesis) + result = await self.receive_block(self.genesis) if result != ReceiveBlockResult.ADDED_TO_HEAD: raise InvalidGenesisBlock() + assert self.lca_block def get_current_heads(self) -> List[TrunkBlock]: """ @@ -89,9 +92,9 @@ class Blockchain: curr_block = curr_full_block.trunk_block trunks: List[Tuple[int, TrunkBlock]] = [] for height, index in sorted_heights: - if height > curr_block.challenge.height: + if height > curr_block.height: raise ValueError("Height is not valid for tip {tip_header_hash}") - while height < curr_block.challenge.height: + while height < curr_block.height: curr_full_block = (await self.store.get_block(curr_block.header.data.prev_header_hash)).trunk_block trunks.append((index, curr_block)) return [b for index, b in sorted(trunks)] @@ -102,9 +105,9 @@ class Blockchain: where both blockchains are equal. 
""" lca: TrunkBlock = self.lca_block.trunk_block - assert lca.challenge.height < alternate_chain[-1].challenge.height + assert lca.height < alternate_chain[-1].height low = 0 - high = lca.challenge.height + high = lca.height while low + 1 < high: mid = (low + high) // 2 if self.height_to_hash[uint64(mid)] != alternate_chain[mid].header.get_hash(): @@ -185,7 +188,7 @@ class Blockchain: block2 = await self.store.get_block(self.height_to_hash[height2]) if not block3: block3 = await self.store.get_block(self.height_to_hash[height3]) - assert block1 is not None and block2 is not None and block3 is not None + assert block2 is not None and block3 is not None # Current difficulty parameter (diff of block h = i - 1) Tc = await self.get_next_difficulty(block.prev_header_hash) @@ -278,7 +281,7 @@ class Blockchain: block1 = await self.store.get_block(self.height_to_hash[height1]) if not block2: block2 = await self.store.get_block(self.height_to_hash[height2]) - assert block1 is not None and block2 is not None + assert block2 is not None if block1: timestamp1 = block1.trunk_block.header.data.timestamp @@ -314,7 +317,7 @@ class Blockchain: # Block is valid and connected, so it can be added to the blockchain. await self.store.save_block(block) - if await self._reconsider_heads(block): + if await self._reconsider_heads(block, genesis): return ReceiveBlockResult.ADDED_TO_HEAD else: return ReceiveBlockResult.ADDED_AS_ORPHAN @@ -330,25 +333,27 @@ class Blockchain: return False # 2. Check Now+2hrs > timestamp > avg timestamp of last 11 blocks + prev_block: Optional[FullBlock] = None if not genesis: + # TODO: do something about first 11 blocks last_timestamps: List[uint64] = [] - prev_block: Optional[FullBlock] = await self.store.get_block(block.prev_header_hash) + prev_block = await self.store.get_block(block.prev_header_hash) + if not prev_block or not prev_block.trunk_block: + return False curr = prev_block while len(last_timestamps) < self.constants["NUMBER_OF_TIMESTAMPS"]: last_timestamps.append(curr.trunk_block.header.data.timestamp) - try: - curr = await self.store.get_block(curr.prev_header_hash) - except KeyError: + fetched = await self.store.get_block(curr.prev_header_hash) + if not fetched: break + curr = fetched if len(last_timestamps) != self.constants["NUMBER_OF_TIMESTAMPS"] and curr.body.coinbase.height != 0: return False - prev_time: uint64 = uint64(sum(last_timestamps) / len(last_timestamps)) + prev_time: uint64 = uint64(int(sum(last_timestamps) / len(last_timestamps))) if block.trunk_block.header.data.timestamp < prev_time: return False if block.trunk_block.header.data.timestamp > time.time() + self.constants["MAX_FUTURE_TIME"]: return False - else: - prev_block: Optional[FullBlock] = None # 3. Check filter hash is correct TODO @@ -359,10 +364,14 @@ class Blockchain: # 5. Check extension data, if any is added # 6. Compute challenge of parent + challenge_hash: bytes32 if not genesis: - challenge_hash: bytes32 = prev_block.trunk_block.challenge.get_hash() + assert prev_block + assert prev_block.trunk_block.challenge + challenge_hash = prev_block.trunk_block.challenge.get_hash() else: - challenge_hash: bytes32 = block.trunk_block.proof_of_time.output.challenge_hash + assert block.trunk_block.proof_of_time + challenge_hash = block.trunk_block.proof_of_time.output.challenge_hash # 7. Check plotter signature of header data is valid based on plotter key if not block.trunk_block.header.plotter_signature.verify( @@ -377,6 +386,7 @@ class Blockchain: # 9. 
Check coinbase height = parent coinbase height + 1 if not genesis: + assert prev_block if block.body.coinbase.height != prev_block.body.coinbase.height + 1: return False else: @@ -406,17 +416,21 @@ class Blockchain: and extends something in the blockchain. """ # 1. Validate unfinished block (check the rest of the conditions) - if not self.validate_unfinished_block(block, genesis): + if not (await self.validate_unfinished_block(block, genesis)): return False + difficulty: uint64 + ips: uint64 if not genesis: - difficulty: uint64 = await self.get_next_difficulty(block.prev_header_hash) - ips: uint64 = await self.get_next_ips(block.prev_header_hash) + difficulty = await self.get_next_difficulty(block.prev_header_hash) + ips = await self.get_next_ips(block.prev_header_hash) else: - difficulty: uint64 = uint64(self.constants["DIFFICULTY_STARTING"]) - ips: uint64 = uint64(self.constants["VDF_IPS_STARTING"]) + difficulty = uint64(self.constants["DIFFICULTY_STARTING"]) + ips = uint64(self.constants["VDF_IPS_STARTING"]) # 2. Check proof of space hash + if not block.trunk_block.challenge or not block.trunk_block.proof_of_time: + return False if block.trunk_block.proof_of_space.get_hash() != block.trunk_block.challenge.proof_of_space_hash: return False @@ -439,7 +453,7 @@ class Blockchain: if not genesis: prev_block: FullBlock = await self.store.get_block(block.prev_header_hash) - if not prev_block: + if not prev_block or not prev_block.trunk_block.challenge: return False # 5. and check if PoT.output.challenge_hash matches @@ -475,11 +489,11 @@ class Blockchain: return True - async def _reconsider_heights(self, old_lca: FullBlock, new_lca: FullBlock): + async def _reconsider_heights(self, old_lca: Optional[FullBlock], new_lca: FullBlock): """ Update the mapping from height to block hash, when the lca changes. """ - curr_old: TrunkBlock = old_lca.trunk_block if old_lca else None + curr_old: Optional[TrunkBlock] = old_lca.trunk_block if old_lca else None curr_new: TrunkBlock = new_lca.trunk_block while True: if not curr_old or curr_old.height < curr_new.height: @@ -497,7 +511,7 @@ class Blockchain: curr_new = (await self.store.get_block(curr_new.prev_header_hash)).trunk_block curr_old = (await self.store.get_block(curr_old.prev_header_hash)).trunk_block - async def _reconsider_lca(self): + async def _reconsider_lca(self, genesis: bool): """ Update the least common ancestor of the heads. This is useful, since we can just assume there is one block per height before the LCA (and use the height_to_hash dict). @@ -508,10 +522,13 @@ class Blockchain: i = heights.index(max(heights)) cur[i] = await self.store.get_block(cur[i].prev_header_hash) heights[i] = cur[i].height - self._reconsider_heights(self.lca_block, cur[0]) + if genesis: + await self._reconsider_heights(None, cur[0]) + else: + await self._reconsider_heights(self.lca_block, cur[0]) self.lca_block = cur[0] - async def _reconsider_heads(self, block: FullBlock) -> bool: + async def _reconsider_heads(self, block: FullBlock, genesis: bool) -> bool: """ When a new block is added, this is called, to check if the new block is heavier than one of the heads. 
@@ -521,6 +538,6 @@ class Blockchain: while len(self.heads) > self.constants["NUMBER_OF_HEADS"]: self.heads.sort(key=lambda b: b.weight, reverse=True) self.heads.pop() - self._reconsider_lca() + await self._reconsider_lca(genesis) return True return False diff --git a/src/consensus/constants.py b/src/consensus/constants.py index da845e7415a6..441a1adbd122 100644 --- a/src/consensus/constants.py +++ b/src/consensus/constants.py @@ -1,4 +1,6 @@ -constants = { +from typing import Dict, Any + +constants: Dict[str, Any] = { "NUMBER_OF_HEADS": 3, # The number of tips each full node keeps track of and propagates "DIFFICULTY_STARTING": 500, # These are in units of 2^32 "DIFFICULTY_FACTOR": 3, # The next difficulty is truncated to range [prev / FACTOR, prev * FACTOR] diff --git a/src/consensus/pot_iterations.py b/src/consensus/pot_iterations.py index d03a05ecc6d9..073d985dcde7 100644 --- a/src/consensus/pot_iterations.py +++ b/src/consensus/pot_iterations.py @@ -45,7 +45,7 @@ def calculate_iterations_quality(quality: bytes32, size: uint8, difficulty: uint min_iterations = min_block_time * vdf_ips dec_iters = (Decimal(int(difficulty) << 32) * (_quality_to_decimal(quality) / _expected_plot_size(size))) - iters_final = uint64(min_iterations + dec_iters.to_integral_exact(rounding=ROUND_UP)) + iters_final = uint64(int(min_iterations + dec_iters.to_integral_exact(rounding=ROUND_UP))) assert iters_final >= 1 return iters_final @@ -74,5 +74,5 @@ def calculate_ips_from_iterations(proof_of_space: ProofOfSpace, challenge_hash: min_iterations = uint64(iterations - iters_rounded) ips = min_iterations / min_block_time assert ips >= 1 - assert uint64(ips) == ips - return uint64(ips) + assert uint64(int(ips)) == ips + return uint64(int(ips)) diff --git a/src/farmer.py b/src/farmer.py index 0ae3d4c89ac3..05b24c61b055 100644 --- a/src/farmer.py +++ b/src/farmer.py @@ -114,9 +114,9 @@ async def respond_proof_of_space(response: plotter_protocol.RespondProofOfSpace) async with state.lock: estimate_secs: float = number_iters / state.proof_of_time_estimate_ips if estimate_secs < config['pool_share_threshold']: - request = plotter_protocol.RequestPartialProof(response.quality, - sha256(bytes.fromhex(config['farmer_target'])).digest()) - yield OutboundMessage(NodeType.PLOTTER, Message("request_partial_proof", request), Delivery.RESPOND) + request1 = plotter_protocol.RequestPartialProof(response.quality, + sha256(bytes.fromhex(config['farmer_target'])).digest()) + yield OutboundMessage(NodeType.PLOTTER, Message("request_partial_proof", request1), Delivery.RESPOND) if estimate_secs < config['propagate_threshold']: async with state.lock: if new_proof_height not in state.coinbase_rewards: @@ -124,10 +124,10 @@ async def respond_proof_of_space(response: plotter_protocol.RespondProofOfSpace) return coinbase, signature = state.coinbase_rewards[new_proof_height] - request = farmer_protocol.RequestHeaderHash(challenge_hash, coinbase, signature, - bytes.fromhex(config['farmer_target']), response.proof) + request2 = farmer_protocol.RequestHeaderHash(challenge_hash, coinbase, signature, + bytes.fromhex(config['farmer_target']), response.proof) - yield OutboundMessage(NodeType.FULL_NODE, Message("request_header_hash", request), Delivery.BROADCAST) + yield OutboundMessage(NodeType.FULL_NODE, Message("request_header_hash", request2), Delivery.BROADCAST) @api_request @@ -235,8 +235,7 @@ async def proof_of_space_arrived(proof_of_space_arrived: farmer_protocol.ProofOf if proof_of_space_arrived.height not in state.unfinished_challenges: 
state.unfinished_challenges[proof_of_space_arrived.height] = [] else: - state.unfinished_challenges[proof_of_space_arrived.height].append( - proof_of_space_arrived.quality_string) + state.unfinished_challenges[proof_of_space_arrived.height].append(proof_of_space_arrived.quality) @api_request diff --git a/src/full_node.py b/src/full_node.py index a4d8f276f319..0d476f6cc8f6 100644 --- a/src/full_node.py +++ b/src/full_node.py @@ -46,6 +46,7 @@ async def send_heads_to_farmers() -> AsyncGenerator[OutboundMessage, None]: requests: List[farmer_protocol.ProofOfSpaceFinalized] = [] async with (await store.get_lock()): for head in blockchain.get_current_heads(): + assert head.proof_of_time and head.challenge prev_challenge_hash = head.proof_of_time.output.challenge_hash challenge_hash = head.challenge.get_hash() height = head.challenge.height @@ -70,6 +71,7 @@ async def send_challenges_to_timelords() -> AsyncGenerator[OutboundMessage, None requests: List[timelord_protocol.ChallengeStart] = [] async with (await store.get_lock()): for head in blockchain.get_current_heads(): + assert head.challenge challenge_hash = head.challenge.get_hash() requests.append(timelord_protocol.ChallengeStart(challenge_hash, head.challenge.height)) for request in requests: @@ -84,7 +86,7 @@ async def on_connect() -> AsyncGenerator[OutboundMessage, None]: async with (await store.get_lock()): heads: List[TrunkBlock] = blockchain.get_current_heads() for h in heads: - blocks.append(blockchain.get_block(h.header.get_hash())) + blocks.append(await blockchain.get_block(h.header.get_hash())) for block in blocks: request = peer_protocol.Block(block) yield OutboundMessage(NodeType.FULL_NODE, Message("block", request), Delivery.RESPOND) diff --git a/src/plotter.py b/src/plotter.py index 0dc1a3e57580..fee98da6c433 100644 --- a/src/plotter.py +++ b/src/plotter.py @@ -16,7 +16,7 @@ from src.server.outbound_message import OutboundMessage, Delivery, Message, Node class PlotterState: # From filename to prover - provers = {} + provers: Dict[str, DiskProver] = {} lock: Lock = Lock() # From quality to (challenge_hash, filename, index) challenge_hashes: Dict[bytes32, Tuple[bytes32, str, uint8]] = {} @@ -102,18 +102,19 @@ async def request_proof_of_space(request: plotter_protocol.RequestProofOfSpace): log.warning(f"Quality {request.quality} not found") return if index is not None: + proof_xs: bytes try: - proof_xs: bytes = state.provers[filename].get_full_proof(challenge_hash, index) + proof_xs = state.provers[filename].get_full_proof(challenge_hash, index) except RuntimeError: state.provers[filename] = DiskProver(filename) - proof_xs: bytes = state.provers[filename].get_full_proof(challenge_hash, index) + proof_xs = state.provers[filename].get_full_proof(challenge_hash, index) pool_pubkey = PublicKey.from_bytes(bytes.fromhex(config['plots'][filename]['pool_pk'])) plot_pubkey = PrivateKey.from_bytes(bytes.fromhex(config['plots'][filename]['sk'])).get_public_key() proof_of_space: ProofOfSpace = ProofOfSpace(pool_pubkey, plot_pubkey, uint8(config['plots'][filename]['k']), - list(proof_xs)) + [uint8(b) for b in proof_xs]) response = plotter_protocol.RespondProofOfSpace( request.quality, diff --git a/src/protocols/farmer_protocol.py b/src/protocols/farmer_protocol.py index 597446c74948..88e82c91cb44 100644 --- a/src/protocols/farmer_protocol.py +++ b/src/protocols/farmer_protocol.py @@ -4,12 +4,14 @@ from src.types.sized_bytes import bytes32 from src.util.ints import uint64, uint32 from src.types.proof_of_space import ProofOfSpace from 
src.types.coinbase import CoinbaseInfo +from dataclasses import dataclass """ Protocol between farmer and full node. """ +@dataclass(frozen=True) @cbor_message(tag=2000) class ProofOfSpaceFinalized: challenge_hash: bytes32 @@ -18,17 +20,20 @@ class ProofOfSpaceFinalized: difficulty: uint64 +@dataclass(frozen=True) @cbor_message(tag=2001) class ProofOfSpaceArrived: height: uint32 quality: bytes32 +@dataclass(frozen=True) @cbor_message(tag=2002) class DeepReorgNotification: pass +@dataclass(frozen=True) @cbor_message(tag=2003) class RequestHeaderHash: challenge_hash: bytes32 @@ -38,12 +43,14 @@ class RequestHeaderHash: proof_of_space: ProofOfSpace +@dataclass(frozen=True) @cbor_message(tag=2004) class HeaderHash: pos_hash: bytes32 header_hash: bytes32 +@dataclass(frozen=True) @cbor_message(tag=2005) class HeaderSignature: pos_hash: bytes32 @@ -51,6 +58,7 @@ class HeaderSignature: header_signature: PrependSignature +@dataclass(frozen=True) @cbor_message(tag=2006) class ProofOfTimeRate: pot_estimate_ips: uint64 diff --git a/src/protocols/peer_protocol.py b/src/protocols/peer_protocol.py index 24056f1b6e38..948a337fdd30 100644 --- a/src/protocols/peer_protocol.py +++ b/src/protocols/peer_protocol.py @@ -7,116 +7,130 @@ from src.types.proof_of_time import ProofOfTime from src.types.trunk_block import TrunkBlock from src.types.full_block import FullBlock from src.types.peer_info import PeerInfo +from dataclasses import dataclass """ Protocol between full nodes. """ -""" -Receive a transaction id from a peer. -""" +@dataclass(frozen=True) @cbor_message(tag=4000) class TransactionId: + """ + Receive a transaction id from a peer. + """ transaction_id: bytes32 -""" -Request a transaction from a peer. -""" +@dataclass(frozen=True) @cbor_message(tag=4001) class RequestTransaction: + """ + Request a transaction from a peer. + """ transaction_id: bytes32 -""" -Receive a transaction from a peer. -""" +@dataclass(frozen=True) @cbor_message(tag=4002) -class Transaction: +class NewTransaction: + """ + Receive a transaction from a peer. + """ transaction: Transaction -""" -Receive a new proof of time from a peer. -""" +@dataclass(frozen=True) @cbor_message(tag=4003) class NewProofOfTime: + """ + Receive a new proof of time from a peer. + """ proof: ProofOfTime -""" -Receive an unfinished block from a peer. -""" +@dataclass(frozen=True) @cbor_message(tag=4004) class UnfinishedBlock: + """ + Receive an unfinished block from a peer. + """ # Block that does not have ProofOfTime and Challenge block: FullBlock -""" -Requests a block from a peer. -""" +@dataclass(frozen=True) @cbor_message(tag=4005) class RequestBlock: + """ + Requests a block from a peer. + """ header_hash: bytes32 -""" -Receive a block from a peer. -""" +@dataclass(frozen=True) @cbor_message(tag=4006) class Block: + """ + Receive a block from a peer. + """ block: FullBlock -""" -Return full list of peers -""" +@dataclass(frozen=True) @cbor_message(tag=4007) class RequestPeers: + """ + Return full list of peers + """ pass -""" -Update list of peers -""" +@dataclass(frozen=True) @cbor_message(tag=4008) class Peers: + """ + Update list of peers + """ peer_list: List[PeerInfo] -""" -Request trunks of blocks that are ancestors of the specified tip. -""" +@dataclass(frozen=True) @cbor_message(tag=4009) class RequestTrunkBlocks: + """ + Request trunks of blocks that are ancestors of the specified tip. + """ tip_header_hash: bytes32 heights: List[uint64] -""" -Sends trunk blocks that are ancestors of the specified tip, at the specified heights. 
-""" +@dataclass(frozen=True) @cbor_message(tag=4010) class TrunkBlocks: + """ + Sends trunk blocks that are ancestors of the specified tip, at the specified heights. + """ tip_header_hash: bytes32 trunk_blocks: List[TrunkBlock] -""" -Request download of blocks, in the blockchain that has 'tip_header_hash' as the tip -""" +@dataclass(frozen=True) @cbor_message(tag=4011) class RequestSyncBlocks: + """ + Request download of blocks, in the blockchain that has 'tip_header_hash' as the tip + """ tip_header_hash: bytes32 heights: List[uint64] -""" -Send blocks to peer. -""" +@dataclass(frozen=True) @cbor_message(tag=4012) class SyncBlocks: + """ + Send blocks to peer. + """ tip_header_hash: bytes32 blocks: List[FullBlock] diff --git a/src/protocols/plotter_protocol.py b/src/protocols/plotter_protocol.py index 65489461b823..20b718e6d7f2 100644 --- a/src/protocols/plotter_protocol.py +++ b/src/protocols/plotter_protocol.py @@ -4,22 +4,26 @@ from src.util.streamable import List from src.types.sized_bytes import bytes32 from src.types.proof_of_space import ProofOfSpace from src.util.ints import uint8 +from dataclasses import dataclass """ Protocol between plotter and farmer. """ +@dataclass(frozen=True) @cbor_message(tag=1000) class PlotterHandshake: pool_pubkeys: List[PublicKey] +@dataclass(frozen=True) @cbor_message(tag=1001) class NewChallenge: challenge_hash: bytes32 +@dataclass(frozen=True) @cbor_message(tag=1002) class ChallengeResponse: challenge_hash: bytes32 @@ -27,35 +31,41 @@ class ChallengeResponse: plot_size: uint8 +@dataclass(frozen=True) @cbor_message(tag=1003) class RequestProofOfSpace: quality: bytes32 +@dataclass(frozen=True) @cbor_message(tag=1004) class RespondProofOfSpace: quality: bytes32 proof: ProofOfSpace +@dataclass(frozen=True) @cbor_message(tag=1005) class RequestHeaderSignature: quality: bytes32 header_hash: bytes32 +@dataclass(frozen=True) @cbor_message(tag=1006) class RespondHeaderSignature: quality: bytes32 header_hash_signature: PrependSignature +@dataclass(frozen=True) @cbor_message(tag=1007) class RequestPartialProof: quality: bytes32 farmer_target_hash: bytes32 +@dataclass(frozen=True) @cbor_message(tag=1008) class RespondPartialProof: quality: bytes32 diff --git a/src/protocols/pool_protocol.py b/src/protocols/pool_protocol.py index 1b598c55dabe..0df1e06e2ef9 100644 --- a/src/protocols/pool_protocol.py +++ b/src/protocols/pool_protocol.py @@ -6,24 +6,28 @@ from src.util.ints import uint32, uint64 from src.types.coinbase import CoinbaseInfo from src.types.challenge import Challenge from src.types.proof_of_space import ProofOfSpace +from dataclasses import dataclass """ Protocol between farmer and pool. 
""" +@dataclass(frozen=True) @streamable class SignedCoinbase: coinbase: CoinbaseInfo coinbase_signature: PrependSignature +@dataclass(frozen=True) @cbor_message(tag=5000) class RequestData: min_height: Optional[uint32] farmer_id: Optional[str] +@dataclass(frozen=True) @cbor_message(tag=5001) class RespondData: posting_url: str @@ -32,6 +36,7 @@ class RespondData: coinbase_info: List[SignedCoinbase] +@dataclass(frozen=True) @cbor_message(tag=5002) class Partial: challenge: Challenge @@ -41,6 +46,7 @@ class Partial: signature: PrependSignature +@dataclass(frozen=True) @cbor_message(tag=5003) class PartialAck: pass diff --git a/src/protocols/shared_protocol.py b/src/protocols/shared_protocol.py index fde106e33ef8..e52706835839 100644 --- a/src/protocols/shared_protocol.py +++ b/src/protocols/shared_protocol.py @@ -1,17 +1,20 @@ from src.util.cbor_message import cbor_message from src.types.sized_bytes import bytes32 +from dataclasses import dataclass protocol_version = "0.0.1" """ Handshake when establishing a connection between two servers. """ +@dataclass(frozen=True) @cbor_message(tag=7000) class Handshake: version: str node_id: bytes32 +@dataclass(frozen=True) @cbor_message(tag=7001) class HandshakeAck: pass diff --git a/src/protocols/timelord_protocol.py b/src/protocols/timelord_protocol.py index ceba9b882220..e6e70e6bfcba 100644 --- a/src/protocols/timelord_protocol.py +++ b/src/protocols/timelord_protocol.py @@ -2,6 +2,7 @@ from src.util.cbor_message import cbor_message from src.types.sized_bytes import bytes32 from src.util.ints import uint32, uint64 from src.types.proof_of_time import ProofOfTime +from dataclasses import dataclass """ Protocol between timelord and full node. @@ -12,22 +13,26 @@ If don't have the unfinished block, ignore Validate PoT Call self.Block """ +@dataclass(frozen=True) @cbor_message(tag=3000) class ProofOfTimeFinished: proof: ProofOfTime +@dataclass(frozen=True) @cbor_message(tag=3001) class ChallengeStart: challenge_hash: bytes32 height: uint32 +@dataclass(frozen=True) @cbor_message(tag=3002) class ChallengeEnd: challenge_hash: bytes32 +@dataclass(frozen=True) @cbor_message(tag=3003) class ProofOfSpaceInfo: challenge_hash: bytes32 diff --git a/src/protocols/wallet_protocol.py b/src/protocols/wallet_protocol.py index 57d390e62bfd..677cfec8ea17 100644 --- a/src/protocols/wallet_protocol.py +++ b/src/protocols/wallet_protocol.py @@ -7,29 +7,34 @@ from src.types.challenge import Challenge from src.types.proof_of_space import ProofOfSpace from src.types.proof_of_time import ProofOfTime from src.types.transaction import Transaction +from dataclasses import dataclass """ Protocol between wallet (SPV node) and full node. 
""" +@dataclass(frozen=True) @cbor_message(tag=6000) class SendTransaction: transaction: Transaction +@dataclass(frozen=True) @cbor_message(tag=6001) class NewHead: header_hash: bytes32 height: uint32 +@dataclass(frozen=True) @cbor_message(tag=6002) class RequestHeaders: header_hash: bytes32 previous_heights_desired: List[uint32] +@dataclass(frozen=True) @cbor_message(tag=6003) class Headers: proof_of_time: ProofOfTime @@ -38,11 +43,13 @@ class Headers: bip158_filter: bytes +@dataclass(frozen=True) @cbor_message(tag=6004) class RequestBody: body_hash: bytes32 +@dataclass(frozen=True) @cbor_message(tag=6005) class RespondBody: body: BlockBody diff --git a/src/server/connection.py b/src/server/connection.py index 517f71da08da..658c700f5cae 100644 --- a/src/server/connection.py +++ b/src/server/connection.py @@ -1,6 +1,6 @@ from asyncio import StreamReader, StreamWriter from asyncio import Lock -from typing import List +from typing import List, Any from src.util import cbor from src.server.outbound_message import Message, NodeType from src.types.sized_bytes import bytes32 @@ -35,9 +35,9 @@ class Connection: async def read_one_message(self) -> Message: size = await self.reader.readexactly(LENGTH_BYTES) full_message_length = int.from_bytes(size, "big") - full_message = await self.reader.readexactly(full_message_length) - full_message = cbor.loads(full_message) - return Message(full_message["function"], full_message["data"]) + full_message: bytes = await self.reader.readexactly(full_message_length) + full_message_loaded: Any = cbor.loads(full_message) + return Message(full_message_loaded["function"], full_message_loaded["data"]) def close(self): self.writer.close() diff --git a/src/server/server.py b/src/server/server.py index aca62189b90e..e7119bb28ab8 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -1,7 +1,7 @@ import logging import asyncio import random -from typing import Tuple, AsyncGenerator, Callable, Optional +from typing import Tuple, AsyncGenerator, Callable, Optional, List from types import ModuleType from lib.aiter.aiter.server import start_server_aiter from lib.aiter.aiter.map_aiter import map_aiter @@ -155,18 +155,18 @@ async def expand_outbound_messages(pair: Tuple[Connection, OutboundMessage]) -> yield connection, outbound_message.message elif outbound_message.delivery_method == Delivery.RANDOM: # Select a random peer. - to_yield = None + to_yield_single: Tuple[Connection, Message] async with global_connections.get_lock(): - typed_peers = [peer for peer in await global_connections.get_connections() - if peer.connection_type == outbound_message.peer_type] + typed_peers: List[Connection] = [peer for peer in await global_connections.get_connections() + if peer.connection_type == outbound_message.peer_type] if len(typed_peers) == 0: return - to_yield = (random.choice(typed_peers), outbound_message.message) - yield to_yield + to_yield_single = (random.choice(typed_peers), outbound_message.message) + yield to_yield_single elif (outbound_message.delivery_method == Delivery.BROADCAST or outbound_message.delivery_method == Delivery.BROADCAST_TO_OTHERS): # Broadcast to all peers. 
- to_yield = [] + to_yield: List[Tuple[Connection, Message]] = [] async with global_connections.get_lock(): for peer in await global_connections.get_connections(): if peer.connection_type == outbound_message.peer_type: diff --git a/src/store/full_node_store.py b/src/store/full_node_store.py index e245af07c639..a5a666a9ab4b 100644 --- a/src/store/full_node_store.py +++ b/src/store/full_node_store.py @@ -14,7 +14,7 @@ class FullNodeStore: def __init__(self): self.lock = Lock() - def initialize(self): + async def initialize(self): self.full_blocks: Dict[str, FullBlock] = {} self.sync_mode: bool = True @@ -48,7 +48,7 @@ class FullNodeStore: self.full_blocks[block.header_hash] = block async def get_block(self, header_hash: str) -> Optional[FullBlock]: - self.full_blocks.get(header_hash) + return self.full_blocks.get(header_hash) async def set_sync_mode(self, sync_mode: bool): self.sync_mode = sync_mode diff --git a/src/timelord.py b/src/timelord.py index ed65b9420941..2af6b431396d 100644 --- a/src/timelord.py +++ b/src/timelord.py @@ -13,7 +13,8 @@ from src.util.api_decorators import api_request from src.protocols import timelord_protocol from src.types.proof_of_time import ProofOfTimeOutput, ProofOfTime from src.types.classgroup import ClassgroupElement -from src.util.ints import uint8 +from src.types.sized_bytes import bytes32 +from src.util.ints import uint8, uint32, uint64 from src.consensus.constants import constants from src.server.outbound_message import OutboundMessage, Delivery, Message, NodeType @@ -24,9 +25,9 @@ class TimelordState: active_discriminants: Dict = {} active_discriminants_start_time: Dict = {} pending_iters: Dict = {} - done_discriminants = [] - seen_discriminants = [] - active_heights = [] + done_discriminants: List[bytes32] = [] + seen_discriminants: List[bytes32] = [] + active_heights: List[uint32] = [] log = logging.getLogger(__name__) @@ -83,7 +84,7 @@ async def challenge_start(challenge_start: timelord_protocol.ChallengeStart): e_to_str = str(e) log.error(f"Connection to VDF server error message: {e_to_str}") await asyncio.sleep(5) - if not writer: + if not writer or not reader: raise Exception("Unable to connect to VDF server") writer.write((str(len(str(disc))) + str(disc)).encode()) @@ -134,7 +135,7 @@ async def challenge_start(challenge_start: timelord_protocol.ChallengeStart): e_to_str = str(e) log.error(f"Socket error: {e_to_str}") - iterations_needed = int.from_bytes(stdout_bytes_io.read(8), "big", signed=True) + iterations_needed = uint64(int.from_bytes(stdout_bytes_io.read(8), "big", signed=True)) y = ClassgroupElement.parse(stdout_bytes_io) proof_bytes: bytes = stdout_bytes_io.read() diff --git a/src/types/block_body.py b/src/types/block_body.py index b3af15968bf1..43d1f5f59fb7 100644 --- a/src/types/block_body.py +++ b/src/types/block_body.py @@ -1,12 +1,15 @@ from typing import Optional from blspy import PrependSignature, Signature -from src.util.streamable import streamable +from src.util.streamable import streamable, Streamable from src.types.coinbase import CoinbaseInfo from src.types.fees_target import FeesTarget from src.types.sized_bytes import bytes32 +from dataclasses import dataclass + +@dataclass(frozen=True) @streamable -class BlockBody: +class BlockBody(Streamable): coinbase: CoinbaseInfo coinbase_signature: PrependSignature fees_target_info: FeesTarget diff --git a/src/types/block_header.py b/src/types/block_header.py index ddbc76848c1d..abcaf55c5a23 100644 --- a/src/types/block_header.py +++ b/src/types/block_header.py @@ -1,11 +1,13 @@ 
from blspy import PrependSignature -from src.util.streamable import streamable +from src.util.streamable import streamable, Streamable from src.util.ints import uint64 from src.types.sized_bytes import bytes32 +from dataclasses import dataclass +@dataclass(frozen=True) @streamable -class BlockHeaderData: +class BlockHeaderData(Streamable): prev_header_hash: bytes32 timestamp: uint64 filter_hash: bytes32 @@ -14,8 +16,9 @@ class BlockHeaderData: extension_data: bytes32 +@dataclass(frozen=True) @streamable -class BlockHeader: +class BlockHeader(Streamable): data: BlockHeaderData plotter_signature: PrependSignature diff --git a/src/types/challenge.py b/src/types/challenge.py index 446dc102c5b1..96119ee698b2 100644 --- a/src/types/challenge.py +++ b/src/types/challenge.py @@ -1,10 +1,12 @@ -from src.util.streamable import streamable +from src.util.streamable import streamable, Streamable from src.types.sized_bytes import bytes32 from src.util.ints import uint32, uint64 +from dataclasses import dataclass +@dataclass(frozen=True) @streamable -class Challenge: +class Challenge(Streamable): proof_of_space_hash: bytes32 proof_of_time_output_hash: bytes32 height: uint32 diff --git a/src/types/classgroup.py b/src/types/classgroup.py index 143cd52be069..02398b60c11a 100644 --- a/src/types/classgroup.py +++ b/src/types/classgroup.py @@ -1,8 +1,10 @@ -from src.util.streamable import streamable +from src.util.streamable import streamable, Streamable from src.util.ints import int1024 +from dataclasses import dataclass +@dataclass(frozen=True) @streamable -class ClassgroupElement: +class ClassgroupElement(Streamable): a: int1024 b: int1024 diff --git a/src/types/coinbase.py b/src/types/coinbase.py index 76e42482e367..aa77a6461296 100644 --- a/src/types/coinbase.py +++ b/src/types/coinbase.py @@ -1,22 +1,12 @@ -from src.util.streamable import streamable +from src.util.streamable import streamable, Streamable from src.types.sized_bytes import bytes32 from src.util.ints import uint64, uint32 from dataclasses import dataclass -# @streamable - -@dataclass -class CoinbaseInfo: +@dataclass(frozen=True) +@streamable +class CoinbaseInfo(Streamable): height: uint32 amount: uint64 puzzle_hash: bytes32 - - -def f(c: CoinbaseInfo) -> CoinbaseInfo: - return c - - -a: int = f(124) - -b = CoinbaseInfo \ No newline at end of file diff --git a/src/types/fees_target.py b/src/types/fees_target.py index 3d40c20249f0..6555d6403a76 100644 --- a/src/types/fees_target.py +++ b/src/types/fees_target.py @@ -1,8 +1,10 @@ from src.util.streamable import streamable from src.types.sized_bytes import bytes32 from src.util.ints import uint64 +from dataclasses import dataclass +@dataclass(frozen=True) @streamable class FeesTarget: puzzle_hash: bytes32 diff --git a/src/types/full_block.py b/src/types/full_block.py index 447040f07543..d1409d667fb7 100644 --- a/src/types/full_block.py +++ b/src/types/full_block.py @@ -1,12 +1,14 @@ from src.util.ints import uint32, uint64 from src.types.sized_bytes import bytes32 -from src.util.streamable import streamable +from src.util.streamable import streamable, Streamable from src.types.block_body import BlockBody from src.types.trunk_block import TrunkBlock +from dataclasses import dataclass +@dataclass(frozen=True) @streamable -class FullBlock: +class FullBlock(Streamable): trunk_block: TrunkBlock body: BlockBody diff --git a/src/types/peer_info.py b/src/types/peer_info.py index 784939e22384..3a923cd8a4b0 100644 --- a/src/types/peer_info.py +++ b/src/types/peer_info.py @@ -1,10 +1,12 @@ -from 
src.util.streamable import streamable +from src.util.streamable import streamable, Streamable from src.types.sized_bytes import bytes32 from src.util.ints import uint32 +from dataclasses import dataclass +@dataclass(frozen=True) @streamable -class PeerInfo: +class PeerInfo(Streamable): host: str port: uint32 node_id: bytes32 diff --git a/src/types/proof_of_space.py b/src/types/proof_of_space.py index a0d8e4c269df..d34184890c15 100644 --- a/src/types/proof_of_space.py +++ b/src/types/proof_of_space.py @@ -2,34 +2,31 @@ from typing import List, Optional from hashlib import sha256 from chiapos import Verifier from blspy import PublicKey -from src.util.streamable import streamable +from src.util.streamable import streamable, Streamable from src.util.ints import uint8 from src.types.sized_bytes import bytes32 +from dataclasses import dataclass +@dataclass(frozen=True) @streamable -class ProofOfSpace: +class ProofOfSpace(Streamable): pool_pubkey: PublicKey plot_pubkey: PublicKey size: uint8 proof: List[uint8] - _cached_quality = None - def get_plot_seed(self) -> bytes32: return self.calculate_plot_seed(self.pool_pubkey, self.plot_pubkey) def verify_and_get_quality(self, challenge_hash: bytes32) -> Optional[bytes32]: - if self._cached_quality: - return self._cached_quality v: Verifier = Verifier() plot_seed: bytes32 = self.get_plot_seed() quality_str = v.validate_proof(plot_seed, self.size, challenge_hash, bytes(self.proof)) if not quality_str: return None - self._cached_quality: bytes32 = self.quality_str_to_quality(challenge_hash, quality_str) - return self._cached_quality + return self.quality_str_to_quality(challenge_hash, quality_str) @staticmethod def calculate_plot_seed(pool_pubkey: PublicKey, plot_pubkey: PublicKey) -> bytes32: diff --git a/src/types/proof_of_time.py b/src/types/proof_of_time.py index f3c73e752092..b671ad02cc14 100644 --- a/src/types/proof_of_time.py +++ b/src/types/proof_of_time.py @@ -1,6 +1,7 @@ from typing import List -from src.util.streamable import streamable +from dataclasses import dataclass from src.types.sized_bytes import bytes32 +from src.util.streamable import streamable, Streamable from src.types.classgroup import ClassgroupElement from src.util.ints import uint8, uint64 from lib.chiavdf.inkfish.proof_of_time import check_proof_of_time_nwesolowski @@ -8,15 +9,17 @@ from lib.chiavdf.inkfish.create_discriminant import create_discriminant from lib.chiavdf.inkfish.classgroup import ClassGroup +@dataclass(frozen=True) @streamable -class ProofOfTimeOutput: +class ProofOfTimeOutput(Streamable): challenge_hash: bytes32 number_of_iterations: uint64 output: ClassgroupElement +@dataclass(frozen=True) @streamable -class ProofOfTime: +class ProofOfTime(Streamable): output: ProofOfTimeOutput witness_type: uint8 witness: List[uint8] diff --git a/src/types/transaction.py b/src/types/transaction.py index d16c39c8893e..1630fd31ee45 100644 --- a/src/types/transaction.py +++ b/src/types/transaction.py @@ -1,6 +1,8 @@ -from src.util.streamable import streamable +from src.util.streamable import streamable, Streamable +from dataclasses import dataclass +@dataclass(frozen=True) @streamable -class Transaction: +class Transaction(Streamable): pass diff --git a/src/types/trunk_block.py b/src/types/trunk_block.py index 9294c7cf94e2..0d4227aabfda 100644 --- a/src/types/trunk_block.py +++ b/src/types/trunk_block.py @@ -1,13 +1,15 @@ from typing import Optional -from src.util.streamable import streamable +from dataclasses import dataclass +from src.util.streamable import streamable, 
Streamable from src.types.block_header import BlockHeader from src.types.challenge import Challenge from src.types.proof_of_space import ProofOfSpace from src.types.proof_of_time import ProofOfTime +@dataclass(frozen=True) @streamable -class TrunkBlock: +class TrunkBlock(Streamable): proof_of_space: ProofOfSpace proof_of_time: Optional[ProofOfTime] challenge: Optional[Challenge] diff --git a/src/util/bin_methods.py b/src/util/bin_methods.py deleted file mode 100644 index 82293ede7a10..000000000000 --- a/src/util/bin_methods.py +++ /dev/null @@ -1,18 +0,0 @@ -import io - -from typing import Any - - -class BinMethods: - """ - Create "from_bytes" and "serialize" methods in terms of "parse" and "stream" methods. - """ - @classmethod - def from_bytes(cls: Any, blob: bytes) -> Any: - f = io.BytesIO(blob) - return cls.parse(f) - - def serialize(self: Any) -> bytes: - f = io.BytesIO() - self.stream(f) - return bytes(f.getvalue()) diff --git a/src/util/byte_types.py b/src/util/byte_types.py index d9f5512cc326..1acba3ffc10f 100644 --- a/src/util/byte_types.py +++ b/src/util/byte_types.py @@ -1,6 +1,5 @@ from typing import Any, BinaryIO - -from .bin_methods import BinMethods +import io def make_sized_bytes(size): @@ -14,9 +13,9 @@ def make_sized_bytes(size): v = bytes(v) if not isinstance(v, bytes) or len(v) != size: raise ValueError("bad %s initializer %s" % (name, v)) - return bytes.__new__(self, v) + return bytes.__new__(self, v) # type: ignore - @classmethod + @classmethod # type: ignore def parse(cls, f: BinaryIO) -> Any: b = f.read(size) assert len(b) == size @@ -25,12 +24,23 @@ def make_sized_bytes(size): def stream(self, f): f.write(self) + @classmethod # type: ignore + def from_bytes(cls: Any, blob: bytes) -> Any: + f = io.BytesIO(blob) + return cls.parse(f) + + def serialize(self: Any) -> bytes: + f = io.BytesIO() + self.stream(f) + return bytes(f.getvalue()) + def __str__(self): return self.hex() def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, str(self)) - namespace = dict(__new__=__new__, parse=parse, stream=stream, __str__=__str__, __repr__=__repr__) + namespace = dict(__new__=__new__, parse=parse, stream=stream, from_bytes=from_bytes, + serialize=serialize, __str__=__str__, __repr__=__repr__) - return type(name, (bytes, BinMethods), namespace) + return type(name, (bytes,), namespace) diff --git a/src/util/ints.py b/src/util/ints.py index 1652ab7399d4..a544d8e15eac 100644 --- a/src/util/ints.py +++ b/src/util/ints.py @@ -2,36 +2,44 @@ from src.util.struct_stream import StructStream from typing import Any, BinaryIO -class int8(int, StructStream): +class int8(StructStream): PACK = "!b" + bits = 8 -class uint8(int, StructStream): +class uint8(StructStream): PACK = "!B" + bits = 8 -class int16(int, StructStream): +class int16(StructStream): PACK = "!h" + bits = 16 -class uint16(int, StructStream): +class uint16(StructStream): PACK = "!H" + bits = 16 -class int32(int, StructStream): +class int32(StructStream): PACK = "!l" + bits = 32 -class uint32(int, StructStream): +class uint32(StructStream): PACK = "!L" + bits = 32 -class int64(int, StructStream): +class int64(StructStream): PACK = "!q" + bits = 64 -class uint64(int, StructStream): +class uint64(StructStream): PACK = "!Q" + bits = 64 class int1024(int): diff --git a/src/util/streamable.py b/src/util/streamable.py index 6a29a5992630..05574635547c 100644 --- a/src/util/streamable.py +++ b/src/util/streamable.py @@ -1,10 +1,11 @@ +# flake8: noqa from __future__ import annotations +import io from typing import Type, 
BinaryIO, get_type_hints, Any, List from hashlib import sha256 from blspy import PublicKey, Signature, PrependSignature from src.util.type_checking import strictdataclass, is_type_List, is_type_SpecificOptional from src.types.sized_bytes import bytes32 -from src.util.bin_methods import BinMethods from src.util.ints import uint32 @@ -35,68 +36,88 @@ def streamable(cls: Any): This class is used for deterministic serialization and hashing, for consensus critical objects such as the block header. + + Make sure to use the Streamable class as a parent class when using the streamable decorator, + as it will allow linters to recognize the methods that are added by the decorator. Also, + use the @dataclass(frozen=True) decorator as well, for linters to recognize constructor + arguments. """ - class _Local: - @classmethod - def parse_one_item(cls: Type[cls.__name__], f_type: Type, f: BinaryIO): - if is_type_List(f_type): - inner_type: Type = f_type.__args__[0] - full_list: List[inner_type] = [] - assert inner_type != List.__args__[0] - list_size: uint32 = int.from_bytes(f.read(4), "big") - for list_index in range(list_size): - full_list.append(cls.parse_one_item(inner_type, f)) - return full_list - if is_type_SpecificOptional(f_type): - inner_type: Type = f_type.__args__[0] - is_present: bool = f.read(1) == bytes([1]) - if is_present: - return cls.parse_one_item(inner_type, f) - else: - return None - if hasattr(f_type, "parse"): - return f_type.parse(f) - if hasattr(f_type, "from_bytes") and size_hints[f_type.__name__]: - return f_type.from_bytes(f.read(size_hints[f_type.__name__])) - else: - raise RuntimeError(f"Type {f_type} does not have parse") - - @classmethod - def parse(cls: Type[cls.__name__], f: BinaryIO) -> cls.__name__: - values = [] - for _, f_type in get_type_hints(cls).items(): - values.append(cls.parse_one_item(f_type, f)) - return cls(*values) - - def stream_one_item(self, f_type: Type, item, f: BinaryIO) -> None: - if is_type_List(f_type): - assert is_type_List(type(item)) - f.write(uint32(len(item)).to_bytes(4, "big")) - inner_type: Type = f_type.__args__[0] - assert inner_type != List.__args__[0] - for element in item: - self.stream_one_item(inner_type, element, f) - elif is_type_SpecificOptional(f_type): - inner_type: Type = f_type.__args__[0] - if item is None: - f.write(bytes([0])) - else: - f.write(bytes([1])) - self.stream_one_item(inner_type, item, f) - elif hasattr(f_type, "stream"): - item.stream(f) - elif hasattr(f_type, "serialize"): - f.write(item.serialize()) - else: - raise NotImplementedError(f"can't stream {item}, {f_type}") - - def stream(self, f: BinaryIO) -> None: - for f_name, f_type in get_type_hints(self).items(): - self.stream_one_item(f_type, getattr(self, f_name), f) - - def get_hash(self) -> bytes32: - return bytes32(sha256(self.serialize()).digest()) - cls1 = strictdataclass(cls) - return type(cls.__name__, (cls1, BinMethods, _Local), {}) + return type(cls.__name__, (cls1, Streamable), {}) + + +class Streamable: + @classmethod + def parse_one_item(cls: Type[cls.__name__], f_type: Type, f: BinaryIO): # type: ignore + inner_type: Type + if is_type_List(f_type): + inner_type = f_type.__args__[0] + full_list: List[inner_type] = [] # type: ignore + assert inner_type != List.__args__[0] # type: ignore + list_size: uint32 = uint32(int.from_bytes(f.read(4), "big")) + for list_index in range(list_size): + full_list.append(cls.parse_one_item(inner_type, f)) # type: ignore + return full_list + if is_type_SpecificOptional(f_type): + inner_type = f_type.__args__[0] + 
is_present: bool = f.read(1) == bytes([1]) + if is_present: + return cls.parse_one_item(inner_type, f) # type: ignore + else: + return None + if hasattr(f_type, "parse"): + return f_type.parse(f) + if hasattr(f_type, "from_bytes") and size_hints[f_type.__name__]: + return f_type.from_bytes(f.read(size_hints[f_type.__name__])) + else: + raise RuntimeError(f"Type {f_type} does not have parse") + + @classmethod + def parse(cls: Type[cls.__name__], f: BinaryIO) -> cls.__name__: # type: ignore + values = [] + for _, f_type in get_type_hints(cls).items(): + values.append(cls.parse_one_item(f_type, f)) # type: ignore + return cls(*values) + + def stream_one_item(self, f_type: Type, item, f: BinaryIO) -> None: + inner_type: Type + if is_type_List(f_type): + assert is_type_List(type(item)) + f.write(uint32(len(item)).to_bytes(4, "big")) + inner_type = f_type.__args__[0] + assert inner_type != List.__args__[0] # type: ignore + for element in item: + self.stream_one_item(inner_type, element, f) + elif is_type_SpecificOptional(f_type): + inner_type = f_type.__args__[0] + if item is None: + f.write(bytes([0])) + else: + f.write(bytes([1])) + self.stream_one_item(inner_type, item, f) + elif hasattr(f_type, "stream"): + item.stream(f) + elif hasattr(f_type, "serialize"): + f.write(item.serialize()) + else: + raise NotImplementedError(f"can't stream {item}, {f_type}") + + def stream(self, f: BinaryIO) -> None: + for f_name, f_type in get_type_hints(self).items(): # type: ignore + self.stream_one_item(f_type, getattr(self, f_name), f) + + def get_hash(self) -> bytes32: + return bytes32(sha256(self.serialize()).digest()) + + @classmethod + def from_bytes(cls: Any, blob: bytes) -> Any: + f = io.BytesIO(blob) + return cls.parse(f) + + def serialize(self: Any) -> bytes: + f = io.BytesIO() + self.stream(f) + return bytes(f.getvalue()) + + diff --git a/src/util/struct_stream.py b/src/util/struct_stream.py index 5ac3b165343a..87ab06709624 100644 --- a/src/util/struct_stream.py +++ b/src/util/struct_stream.py @@ -1,18 +1,37 @@ + import struct +import io from typing import Any, BinaryIO -from src.util.bin_methods import BinMethods - -class StructStream(BinMethods): +class StructStream(int): PACK = "" + bits = 1 + """ Create a class that can parse and stream itself based on a struct.pack template string. 
""" + def __new__(cls: Any, value: int): + if value.bit_length() > cls.bits: + raise ValueError(f"Value {value} of size {value.bit_length()} does not fit into " + f"{cls.__name__} of size {cls.bits}") + + return int.__new__(cls, value) # type: ignore + @classmethod def parse(cls: Any, f: BinaryIO) -> Any: return cls(*struct.unpack(cls.PACK, f.read(struct.calcsize(cls.PACK)))) def stream(self, f): f.write(struct.pack(self.PACK, self)) + + @classmethod + def from_bytes(cls: Any, blob: bytes) -> Any: # type: ignore + f = io.BytesIO(blob) + return cls.parse(f) + + def serialize(self: Any) -> bytes: + f = io.BytesIO() + self.stream(f) + return bytes(f.getvalue()) diff --git a/src/util/type_checking.py b/src/util/type_checking.py index a1ae1ddc280c..03f8de01c4ba 100644 --- a/src/util/type_checking.py +++ b/src/util/type_checking.py @@ -24,9 +24,9 @@ def strictdataclass(cls: Any): """ def parse_item(self, item: Any, f_name: str, f_type: Type) -> Any: if is_type_List(f_type): - collected_list: f_type = [] + collected_list: List = [] inner_type: Type = f_type.__args__[0] - assert inner_type != List.__args__[0] + assert inner_type != List.__args__[0] # type: ignore if not is_type_List(type(item)): raise ValueError(f"Wrong type for {f_name}, need a list.") for el in item: @@ -36,7 +36,7 @@ def strictdataclass(cls: Any): if item is None: return None else: - inner_type: Type = f_type.__args__[0] + inner_type: Type = f_type.__args__[0] # type: ignore return self.parse_item(item, f_name, inner_type) if not isinstance(item, f_type): try: @@ -47,18 +47,18 @@ def strictdataclass(cls: Any): raise ValueError(f"Wrong type for {f_name}") return item - def __init__(self, *args): + def __post_init__(self): fields = get_type_hints(self) - la, lf = len(args), len(fields) - if la != lf: - raise ValueError("got %d and expected %d args" % (la, lf)) - for a, (f_name, f_type) in zip(args, fields.items()): - object.__setattr__(self, f_name, self.parse_item(a, f_name, f_type)) + data = self.__dict__ + for (f_name, f_type) in fields.items(): + if f_name not in data: + raise ValueError(f"Field {f_name} not present") + object.__setattr__(self, f_name, self.parse_item(data[f_name], f_name, f_type)) class NoTypeChecking: __no_type_check__ = True - cls1 = dataclasses.dataclass(_cls=cls, init=False, frozen=True) + cls1 = dataclasses.dataclass(_cls=cls, init=False, frozen=True) # type: ignore if dataclasses.fields(cls1) == (): return type(cls.__name__, (cls1, _Local, NoTypeChecking), {}) return type(cls.__name__, (cls1, _Local), {}) diff --git a/test1.py b/test1.py deleted file mode 100644 index 2193345fb61e..000000000000 --- a/test1.py +++ /dev/null @@ -1,25 +0,0 @@ -import asyncio - - -def cb(r, w): - print("Connected", w.get_extra_info("peername")) - - -async def main(): - server = await asyncio.start_server(cb, "127.0.0.1", 8000) - server2 = await asyncio.start_server(cb, "127.0.0.1", 8001) - - _, _ = await asyncio.open_connection("127.0.0.1", 8001) - _, _ = await asyncio.open_connection("127.0.0.1", 8001) - _, _ = await asyncio.open_connection("127.0.0.1", 8001) - await asyncio.sleep(2) - server2_socket = server2.sockets[0] - print("Socket", server2.sockets) - - r, w = await asyncio.open_connection(sock=server2_socket) - print("Opened connection", w.get_extra_info("peername"), w.transport) - await server.serve_forever() - - - -asyncio.run(main()) \ No newline at end of file diff --git a/tests/block_tools.py b/tests/block_tools.py index 094decc4406d..de4c71d8a6f2 100644 --- a/tests/block_tools.py +++ b/tests/block_tools.py @@ 
-3,7 +3,7 @@ import os import sys from hashlib import sha256 from chiapos import DiskPlotter, DiskProver -from typing import List, Dict +from typing import List, Dict, Any from blspy import PublicKey, PrivateKey, PrependSignature from src.types.sized_bytes import bytes32 from src.types.full_block import FullBlock @@ -27,7 +27,7 @@ from src.consensus.pot_iterations import calculate_ips_from_iterations # Can't go much lower than 19, since plots start having no solutions -k = 19 +k: uint8 = uint8(19) # Uses many plots for testing, in order to guarantee proofs of space at every height num_plots = 80 # Use the empty string as the seed for the private key @@ -39,7 +39,7 @@ plot_pks: List[PublicKey] = [sk.get_public_key() for sk in plot_sks] farmer_sk: PrivateKey = PrivateKey.from_seed(b'coinbase') coinbase_target = sha256(farmer_sk.get_public_key().serialize()).digest() fee_target = sha256(farmer_sk.get_public_key().serialize()).digest() -n_wesolowski = 3 +n_wesolowski = uint8(3) class BlockTools: @@ -70,7 +70,7 @@ class BlockTools: block_list: List[FullBlock] = [], seconds_per_block=constants["BLOCK_TIME_TARGET"], seed: uint64 = uint64(0)) -> List[FullBlock]: - test_constants = constants.copy() + test_constants: Dict[str, Any] = constants.copy() for key, value in input_constants.items(): test_constants[key] = value @@ -92,6 +92,7 @@ class BlockTools: curr_difficulty = block_list[-1].weight - block_list[-2].weight prev_difficulty = (block_list[-1 - test_constants["DIFFICULTY_EPOCH"]].weight - block_list[-2 - test_constants["DIFFICULTY_EPOCH"]].weight) + assert block_list[-1].trunk_block.proof_of_time curr_ips = calculate_ips_from_iterations(block_list[-1].trunk_block.proof_of_space, block_list[-1].trunk_block.proof_of_time.output.challenge_hash, curr_difficulty, @@ -110,15 +111,22 @@ class BlockTools: height2 = uint64(next_height - (test_constants["DIFFICULTY_EPOCH"]) - 1) height3 = uint64(next_height - (test_constants["DIFFICULTY_DELAY"]) - 1) if height1 >= 0: - timestamp1 = block_list[height1].trunk_block.header.data.timestamp - iters1 = block_list[height1].trunk_block.challenge.total_iters + block1 = block_list[height1] + assert block1.trunk_block.challenge + iters1 = block1.trunk_block.challenge.total_iters + timestamp1 = block1.trunk_block.header.data.timestamp else: - timestamp1 = (block_list[0].trunk_block.header.data.timestamp - + block1 = block_list[0] + assert block1.trunk_block.challenge + timestamp1 = (block1.trunk_block.header.data.timestamp - test_constants["BLOCK_TIME_TARGET"]) - iters1 = block_list[0].trunk_block.challenge.total_iters + iters1 = block1.trunk_block.challenge.total_iters timestamp2 = block_list[height2].trunk_block.header.data.timestamp timestamp3 = block_list[height3].trunk_block.header.data.timestamp - iters3 = block_list[height3].trunk_block.challenge.total_iters + + block3 = block_list[height3] + assert block3.trunk_block.challenge + iters3 = block3.trunk_block.challenge.total_iters term1 = (test_constants["DIFFICULTY_DELAY"] * prev_difficulty * (timestamp3 - timestamp2) * test_constants["BLOCK_TIME_TARGET"]) @@ -153,7 +161,7 @@ class BlockTools: """ Creates the genesis block with the specified details. 
""" - test_constants = constants.copy() + test_constants: Dict[str, Any] = constants.copy() for key, value in input_constants.items(): test_constants[key] = value @@ -164,7 +172,7 @@ class BlockTools: bytes([0]*32), uint64(0), uint64(0), - uint64(time.time()), + uint64(int(time.time())), uint64(test_constants["DIFFICULTY_STARTING"]), uint64(test_constants["VDF_IPS_STARTING"]), seed @@ -176,14 +184,16 @@ class BlockTools: """ Creates the next block with the specified details. """ - test_constants = constants.copy() + test_constants: Dict[str, Any] = constants.copy() for key, value in input_constants.items(): test_constants[key] = value + assert prev_block.trunk_block.challenge + return self._create_block( test_constants, prev_block.trunk_block.challenge.get_hash(), - prev_block.height + 1, + uint32(prev_block.height + 1), prev_block.header_hash, prev_block.trunk_block.challenge.total_iters, prev_block.weight, @@ -203,7 +213,7 @@ class BlockTools: prover = None plot_pk = None plot_sk = None - qualities = [] + qualities: List[bytes] = [] for pn in range(num_plots): seeded_pn = (pn + 17 * seed) % num_plots # Allow passing in seed, to create reorgs and different chains filename = self.filenames[seeded_pn] @@ -214,11 +224,14 @@ class BlockTools: if len(qualities) > 0: break + assert prover + assert plot_pk + assert plot_sk if len(qualities) == 0: raise NoProofsOfSpaceFound("No proofs for this challenge") proof_xs: bytes = prover.get_full_proof(challenge_hash, 0) - proof_of_space: ProofOfSpace = ProofOfSpace(pool_pk, plot_pk, k, list(proof_xs)) + proof_of_space: ProofOfSpace = ProofOfSpace(pool_pk, plot_pk, k, [uint8(b) for b in proof_xs]) number_iters: uint64 = pot_iterations.calculate_iterations(proof_of_space, challenge_hash, difficulty, ips, test_constants["MIN_BLOCK_TIME"]) @@ -237,7 +250,7 @@ class BlockTools: coinbase_target) coinbase_sig: PrependSignature = pool_sk.sign_prepend(coinbase.serialize()) - fees_target: FeesTarget = FeesTarget(fee_target, 0) + fees_target: FeesTarget = FeesTarget(fee_target, uint64(0)) body: BlockBody = BlockBody(coinbase, coinbase_sig, fees_target, None, bytes([0]*32)) @@ -250,7 +263,7 @@ class BlockTools: header: BlockHeader = BlockHeader(header_data, header_hash_sig) challenge = Challenge(proof_of_space.get_hash(), proof_of_time.get_hash(), height, - prev_weight + difficulty, prev_iters + number_iters) + uint64(prev_weight + difficulty), uint64(prev_iters + number_iters)) trunk_block = TrunkBlock(proof_of_space, proof_of_time, challenge, header) full_block: FullBlock = FullBlock(trunk_block, body) @@ -261,5 +274,5 @@ class BlockTools: # This code generates a genesis block, uncomment to output genesis block to terminal # This might take a while, using the python VDF implementation. 
# Run by doing python -m tests.block_tools -bt = BlockTools() -print(bt.create_genesis_block({}, bytes([1]*32), uint64(0)).serialize()) +# bt = BlockTools() +# print(bt.create_genesis_block({}, bytes([1]*32), uint64(0)).serialize()) diff --git a/tests/test_blockchain.py b/tests/test_blockchain.py index e0a8b53935cb..c5090e9c1201 100644 --- a/tests/test_blockchain.py +++ b/tests/test_blockchain.py @@ -1,5 +1,7 @@ import time import pytest +import asyncio +from typing import Dict, Any from blspy import PrivateKey from src.consensus.constants import constants from src.types.coinbase import CoinbaseInfo @@ -10,17 +12,19 @@ from src.types.trunk_block import TrunkBlock from src.types.full_block import FullBlock from src.types.block_header import BlockHeaderData from src.blockchain import Blockchain, ReceiveBlockResult -from src.util.ints import uint64 +from src.store.full_node_store import FullNodeStore +from src.util.ints import uint64, uint32 from tests.block_tools import BlockTools bt = BlockTools() -test_constants = { +test_constants: Dict[str, Any] = { "DIFFICULTY_STARTING": 5, "DISCRIMINANT_SIZE_BITS": 16, "BLOCK_TIME_TARGET": 10, "MIN_BLOCK_TIME": 2, + "DIFFICULTY_FACTOR": 3, "DIFFICULTY_EPOCH": 12, # The number of blocks per epoch "DIFFICULTY_WARP_FACTOR": 4, # DELAY divides EPOCH in order to warp efficiently. "DIFFICULTY_DELAY": 3 # EPOCH / WARP_FACTOR @@ -28,30 +32,46 @@ test_constants = { test_constants["GENESIS_BLOCK"] = bt.create_genesis_block(test_constants, bytes([0]*32), uint64(0)).serialize() +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + class TestGenesisBlock(): - def test_basic_blockchain(self): - bc1: Blockchain = Blockchain() + @pytest.mark.asyncio + async def test_basic_blockchain(self): + store = FullNodeStore() + await store.initialize() + bc1: Blockchain = Blockchain(store) + await bc1.initialize() assert len(bc1.get_current_heads()) == 1 genesis_block = bc1.get_current_heads()[0] assert genesis_block.height == 0 - assert bc1.get_trunk_blocks_by_height([uint64(0)], genesis_block.header_hash)[0] == genesis_block - assert bc1.get_next_difficulty(genesis_block.header_hash) == genesis_block.challenge.total_weight - assert bc1.get_next_ips(genesis_block.header_hash) > 0 + assert genesis_block.challenge + assert (await bc1.get_trunk_blocks_by_height([uint64(0)], genesis_block.header_hash))[0] == genesis_block + assert (await bc1.get_next_difficulty(genesis_block.header_hash)) == genesis_block.challenge.total_weight + assert await bc1.get_next_ips(genesis_block.header_hash) > 0 class TestBlockValidation(): @pytest.fixture(scope="module") - def initial_blockchain(self): + async def initial_blockchain(self): """ Provides a list of 10 valid blocks, as well as a blockchain with 9 blocks added to it. 
""" + store = FullNodeStore() + await store.initialize() blocks = bt.get_consecutive_blocks(test_constants, 10, [], 10) - b: Blockchain = Blockchain(test_constants) + b: Blockchain = Blockchain(store, test_constants) + await b.initialize() for i in range(1, 9): - assert b.receive_block(blocks[i]) == ReceiveBlockResult.ADDED_TO_HEAD + assert (await b.receive_block(blocks[i])) == ReceiveBlockResult.ADDED_TO_HEAD return (blocks, b) - def test_prev_pointer(self, initial_blockchain): + @pytest.mark.asyncio + async def test_prev_pointer(self, initial_blockchain): blocks, b = initial_blockchain block_bad = FullBlock(TrunkBlock( blocks[9].trunk_block.proof_of_space, @@ -66,9 +86,10 @@ class TestBlockValidation(): blocks[9].trunk_block.header.data.extension_data ), blocks[9].trunk_block.header.plotter_signature) ), blocks[9].body) - assert b.receive_block(block_bad) == ReceiveBlockResult.INVALID_BLOCK + assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK - def test_timestamp(self, initial_blockchain): + @pytest.mark.asyncio + async def test_timestamp(self, initial_blockchain): blocks, b = initial_blockchain # Time too far in the past block_bad = FullBlock(TrunkBlock( @@ -84,7 +105,7 @@ class TestBlockValidation(): blocks[9].trunk_block.header.data.extension_data ), blocks[9].trunk_block.header.plotter_signature) ), blocks[9].body) - assert b.receive_block(block_bad) == ReceiveBlockResult.INVALID_BLOCK + assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK # Time too far in the future block_bad = FullBlock(TrunkBlock( @@ -93,7 +114,7 @@ class TestBlockValidation(): blocks[9].trunk_block.challenge, BlockHeader(BlockHeaderData( blocks[9].trunk_block.header.data.prev_header_hash, - time.time() + 3600 * 3, + uint64(int(time.time() + 3600 * 3)), blocks[9].trunk_block.header.data.filter_hash, blocks[9].trunk_block.header.data.proof_of_space_hash, blocks[9].trunk_block.header.data.body_hash, @@ -101,9 +122,10 @@ class TestBlockValidation(): ), blocks[9].trunk_block.header.plotter_signature) ), blocks[9].body) - assert b.receive_block(block_bad) == ReceiveBlockResult.INVALID_BLOCK + assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK - def test_body_hash(self, initial_blockchain): + @pytest.mark.asyncio + async def test_body_hash(self, initial_blockchain): blocks, b = initial_blockchain # Time too far in the past block_bad = FullBlock(TrunkBlock( @@ -119,9 +141,10 @@ class TestBlockValidation(): blocks[9].trunk_block.header.data.extension_data ), blocks[9].trunk_block.header.plotter_signature) ), blocks[9].body) - assert b.receive_block(block_bad) == ReceiveBlockResult.INVALID_BLOCK + assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK - def test_plotter_signature(self, initial_blockchain): + @pytest.mark.asyncio + async def test_plotter_signature(self, initial_blockchain): blocks, b = initial_blockchain # Time too far in the past block_bad = FullBlock(TrunkBlock( @@ -132,9 +155,10 @@ class TestBlockValidation(): blocks[9].trunk_block.header.data, PrivateKey.from_seed(b'0').sign_prepend(b"random junk")) ), blocks[9].body) - assert b.receive_block(block_bad) == ReceiveBlockResult.INVALID_BLOCK + assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK - def test_invalid_pos(self, initial_blockchain): + @pytest.mark.asyncio + async def test_invalid_pos(self, initial_blockchain): blocks, b = initial_blockchain bad_pos = blocks[9].trunk_block.proof_of_space.proof @@ -151,15 +175,16 @@ 
class TestBlockValidation(): blocks[9].trunk_block.challenge, blocks[9].trunk_block.header ), blocks[9].body) - assert b.receive_block(block_bad) == ReceiveBlockResult.INVALID_BLOCK + assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK - def test_invalid_coinbase_height(self, initial_blockchain): + @pytest.mark.asyncio + async def test_invalid_coinbase_height(self, initial_blockchain): blocks, b = initial_blockchain # Coinbase height invalid block_bad = FullBlock(blocks[9].trunk_block, BlockBody( CoinbaseInfo( - 3, + uint32(3), blocks[9].body.coinbase.amount, blocks[9].body.coinbase.puzzle_hash ), @@ -168,41 +193,52 @@ class TestBlockValidation(): blocks[9].body.aggregated_signature, blocks[9].body.solutions_generator )) - assert b.receive_block(block_bad) == ReceiveBlockResult.INVALID_BLOCK + assert (await b.receive_block(block_bad)) == ReceiveBlockResult.INVALID_BLOCK - def test_difficulty_change(self): + @pytest.mark.asyncio + async def test_difficulty_change(self): num_blocks = 20 # Make it 5x faster than target time blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 2) - b: Blockchain = Blockchain(test_constants) + store = FullNodeStore() + await store.initialize() + b: Blockchain = Blockchain(store, test_constants) + await b.initialize() for i in range(1, num_blocks): - assert b.receive_block(blocks[i]) == ReceiveBlockResult.ADDED_TO_HEAD - assert b.get_next_difficulty(blocks[13].header_hash) == b.get_next_difficulty(blocks[12].header_hash) - assert b.get_next_difficulty(blocks[14].header_hash) > b.get_next_difficulty(blocks[13].header_hash) - assert ((b.get_next_difficulty(blocks[14].header_hash) / b.get_next_difficulty(blocks[13].header_hash) - <= constants["DIFFICULTY_FACTOR"])) - assert blocks[-1].trunk_block.challenge.total_iters == 142911 + assert (await b.receive_block(blocks[i])) == ReceiveBlockResult.ADDED_TO_HEAD - assert b.get_next_ips(blocks[1].header_hash) == constants["VDF_IPS_STARTING"] - assert b.get_next_ips(blocks[12].header_hash) == b.get_next_ips(blocks[11].header_hash) - assert b.get_next_ips(blocks[13].header_hash) == b.get_next_ips(blocks[12].header_hash) - assert b.get_next_ips(blocks[14].header_hash) > b.get_next_ips(blocks[13].header_hash) - assert b.get_next_ips(blocks[15].header_hash) == b.get_next_ips(blocks[14].header_hash) + diff_13 = await b.get_next_difficulty(blocks[12].header_hash) + diff_14 = await b.get_next_difficulty(blocks[13].header_hash) + diff_15 = await b.get_next_difficulty(blocks[14].header_hash) + + assert diff_14 == diff_13 + assert diff_15 > diff_14 + assert (diff_15 / diff_14) <= test_constants["DIFFICULTY_FACTOR"] + + assert (await b.get_next_ips(blocks[1].header_hash)) == constants["VDF_IPS_STARTING"] + assert (await b.get_next_ips(blocks[12].header_hash)) == (await b.get_next_ips(blocks[11].header_hash)) + assert (await b.get_next_ips(blocks[13].header_hash)) == (await b.get_next_ips(blocks[12].header_hash)) + assert (await b.get_next_ips(blocks[14].header_hash)) > (await b.get_next_ips(blocks[13].header_hash)) + assert (await b.get_next_ips(blocks[15].header_hash)) == (await b.get_next_ips(blocks[14].header_hash)) class TestReorgs(): - def test_basic_reorg(self): + @pytest.mark.asyncio + async def test_basic_reorg(self): blocks = bt.get_consecutive_blocks(test_constants, 100, [], 9) - b: Blockchain = Blockchain(test_constants) + store = FullNodeStore() + await store.initialize() + b: Blockchain = Blockchain(store, test_constants) + await b.initialize() for block in blocks: - 
b.receive_block(block) + await b.receive_block(block) assert b.get_current_heads()[0].height == 100 blocks_reorg_chain = bt.get_consecutive_blocks(test_constants, 30, blocks[:90], 9, uint64(1)) for reorg_block in blocks_reorg_chain: - result = b.receive_block(reorg_block) + result = await b.receive_block(reorg_block) if reorg_block.height < 90: assert result == ReceiveBlockResult.ALREADY_HAVE_BLOCK elif reorg_block.height < 99: @@ -211,18 +247,21 @@ class TestReorgs(): assert result == ReceiveBlockResult.ADDED_TO_HEAD assert b.get_current_heads()[0].height == 119 - def test_reorg_from_genesis(self): + @pytest.mark.asyncio + async def test_reorg_from_genesis(self): blocks = bt.get_consecutive_blocks(test_constants, 20, [], 9, uint64(0)) - - b: Blockchain = Blockchain(test_constants) + store = FullNodeStore() + await store.initialize() + b: Blockchain = Blockchain(store, test_constants) + await b.initialize() for block in blocks: - b.receive_block(block) + await b.receive_block(block) assert b.get_current_heads()[0].height == 20 # Reorg from genesis blocks_reorg_chain = bt.get_consecutive_blocks(test_constants, 21, [blocks[0]], 9, uint64(1)) for reorg_block in blocks_reorg_chain: - result = b.receive_block(reorg_block) + result = await b.receive_block(reorg_block) if reorg_block.height == 0: assert result == ReceiveBlockResult.ALREADY_HAVE_BLOCK elif reorg_block.height < 19: @@ -233,6 +272,6 @@ class TestReorgs(): # Reorg back to original branch blocks_reorg_chain_2 = bt.get_consecutive_blocks(test_constants, 3, blocks, 9, uint64(3)) - b.receive_block(blocks_reorg_chain_2[20]) == ReceiveBlockResult.ADDED_AS_ORPHAN - assert b.receive_block(blocks_reorg_chain_2[21]) == ReceiveBlockResult.ADDED_TO_HEAD - assert b.receive_block(blocks_reorg_chain_2[22]) == ReceiveBlockResult.ADDED_TO_HEAD + await b.receive_block(blocks_reorg_chain_2[20]) == ReceiveBlockResult.ADDED_AS_ORPHAN + assert (await b.receive_block(blocks_reorg_chain_2[21])) == ReceiveBlockResult.ADDED_TO_HEAD + assert (await b.receive_block(blocks_reorg_chain_2[22])) == ReceiveBlockResult.ADDED_TO_HEAD diff --git a/tests/util/test_streamable.py b/tests/util/test_streamable.py index b4d4821a343e..2645e504da5a 100644 --- a/tests/util/test_streamable.py +++ b/tests/util/test_streamable.py @@ -1,13 +1,15 @@ import unittest +from dataclasses import dataclass from typing import List, Optional -from src.util.streamable import streamable +from src.util.streamable import streamable, Streamable from src.util.ints import uint32 class TestStreamable(unittest.TestCase): def test_basic(self): + @dataclass(frozen=True) @streamable - class TestClass: + class TestClass(Streamable): a: uint32 b: uint32 c: List[uint32] @@ -15,27 +17,29 @@ class TestStreamable(unittest.TestCase): e: Optional[uint32] f: Optional[uint32] - a = TestClass(24, 352, [1, 2, 4], [[1, 2, 3], [3, 4]], 728, None) + a = TestClass(24, 352, [1, 2, 4], [[1, 2, 3], [3, 4]], 728, None) # type: ignore b: bytes = a.serialize() assert a == TestClass.from_bytes(b) def test_variablesize(self): + @dataclass(frozen=True) @streamable - class TestClass2: + class TestClass2(Streamable): a: uint32 b: uint32 c: str - a = TestClass2(1, 2, "3") + a = TestClass2(uint32(1), uint32(2), "3") try: a.serialize() assert False except NotImplementedError: pass + @dataclass(frozen=True) @streamable - class TestClass3: + class TestClass3(Streamable): a: int b = TestClass3(1) diff --git a/tests/util/test_type_checking.py b/tests/util/test_type_checking.py index 3216d2a52fab..bfe97f6e5e06 100644 --- 
a/tests/util/test_type_checking.py +++ b/tests/util/test_type_checking.py @@ -1,4 +1,5 @@ import unittest +from dataclasses import dataclass from src.util.type_checking import is_type_List, is_type_SpecificOptional, strictdataclass from src.util.ints import uint8 from typing import List, Dict, Tuple, Optional @@ -12,7 +13,7 @@ class TestIsTypeList(unittest.TestCase): assert is_type_List(List[int]) assert is_type_List(List[uint8]) assert is_type_List(list) - assert not is_type_List(Tuple) + assert not is_type_List(Tuple) # type: ignore assert not is_type_List(tuple) assert not is_type_List(dict) @@ -29,6 +30,7 @@ class TestIsTypeSpecificOptional(unittest.TestCase): class TestStrictClass(unittest.TestCase): def test_StrictDataClass(self): + @dataclass(frozen=True) @strictdataclass class TestClass1: a: int @@ -39,10 +41,11 @@ class TestStrictClass(unittest.TestCase): assert good assert good.a == 24 assert good.b == "!@12" - good2 = TestClass1(52, bytes([1, 2, 3])) + good2 = TestClass1(52, bytes([1, 2, 3])) # type: ignore assert good2.b == str(bytes([1, 2, 3])) def test_StrictDataClassBad(self): + @dataclass(frozen=True) @strictdataclass class TestClass2: a: int @@ -50,12 +53,13 @@ class TestStrictClass(unittest.TestCase): assert TestClass2(25) try: - TestClass2(1, 2) + TestClass2(1, 2) # type: ignore assert False - except ValueError: + except TypeError: pass def test_StrictDataClassLists(self): + @dataclass(frozen=True) @strictdataclass class TestClass: a: List[int] @@ -63,17 +67,18 @@ class TestStrictClass(unittest.TestCase): assert TestClass([1, 2, 3], [[uint8(200), uint8(25)], [uint8(25)]]) try: - TestClass([1, 2, 3], [[200, uint8(25)], [uint8(25)]]) + TestClass([1, 2, 3], [[uint8(200), uint8(25)], [uint8(25)]]) assert False except AssertionError: pass try: - TestClass([1, 2, 3], [uint8(200), uint8(25)]) + TestClass([1, 2, 3], [uint8(200), uint8(25)]) # type: ignore assert False except ValueError: pass def test_StrictDataClassOptional(self): + @dataclass(frozen=True) @strictdataclass class TestClass: a: Optional[int] @@ -85,6 +90,7 @@ class TestStrictClass(unittest.TestCase): assert good def test_StrictDataClassEmpty(self): + @dataclass(frozen=True) @strictdataclass class A: pass
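Because field coercion now lives in `__post_init__` and `strictdataclass` builds its internal dataclass with `init=False`, the updated tests stack `@dataclass(frozen=True)` on top to get a generated `__init__`; a wrong argument count therefore surfaces as `TypeError` from that `__init__` rather than the old `ValueError`, which is why `test_StrictDataClassBad` changed its expected exception. A minimal sketch of the pattern (the `Example` class is illustrative only):

```python
# Hedged sketch of the decorator stacking used by the updated tests.
from dataclasses import dataclass

from src.util.ints import uint8
from src.util.type_checking import strictdataclass


@dataclass(frozen=True)   # supplies __init__; strictdataclass itself uses init=False
@strictdataclass          # __post_init__ re-checks and coerces every annotated field
class Example:
    a: uint8
    b: str


ok = Example(uint8(1), "x")      # already well-typed, stored unchanged
coerced = Example(5, "y")        # plain int is coerced to uint8 in __post_init__
assert isinstance(coerced.a, uint8)

try:
    Example(uint8(1))            # wrong arity -> TypeError from the dataclass __init__
except TypeError:
    pass
```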
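The blockchain tests are now fully asynchronous: each test builds a `FullNodeStore`, awaits both `store.initialize()` and `Blockchain.initialize()`, and awaits `receive_block`, with a module-scoped `event_loop` fixture backing `@pytest.mark.asyncio` (pytest-asyncio was added to `dev_dependencies` for this). A condensed, hedged sketch of that pattern, reusing the same test constants and `BlockTools` helper:

```python
# Condensed sketch of the async pattern now used in tests/test_blockchain.py.
import asyncio

import pytest

from src.blockchain import Blockchain, ReceiveBlockResult
from src.store.full_node_store import FullNodeStore
from src.util.ints import uint64
from tests.block_tools import BlockTools

bt = BlockTools()
test_constants = {
    "DIFFICULTY_STARTING": 5, "DISCRIMINANT_SIZE_BITS": 16,
    "BLOCK_TIME_TARGET": 10, "MIN_BLOCK_TIME": 2, "DIFFICULTY_FACTOR": 3,
    "DIFFICULTY_EPOCH": 12, "DIFFICULTY_WARP_FACTOR": 4, "DIFFICULTY_DELAY": 3,
}
test_constants["GENESIS_BLOCK"] = bt.create_genesis_block(
    test_constants, bytes([0] * 32), uint64(0)).serialize()


@pytest.fixture(scope="module")
def event_loop():
    # pytest-asyncio runs every async test in this module on one loop
    loop = asyncio.get_event_loop()
    yield loop
    loop.close()


@pytest.mark.asyncio
async def test_blocks_extend_the_head():
    store = FullNodeStore()
    await store.initialize()
    b = Blockchain(store, test_constants)
    await b.initialize()          # receives and validates the genesis block

    blocks = bt.get_consecutive_blocks(test_constants, 10, [], 10)
    for block in blocks[1:9]:
        assert (await b.receive_block(block)) == ReceiveBlockResult.ADDED_TO_HEAD
```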