mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-09-20 08:05:33 +03:00
Add more types (#3441)
This commit is contained in:
parent
4395f97260
commit
5034fcc85d
@ -4,7 +4,7 @@ import pathlib
|
||||
import signal
|
||||
import socket
|
||||
import time
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
|
||||
import pkg_resources
|
||||
|
||||
@ -32,14 +32,14 @@ async def kill_processes():
|
||||
pass
|
||||
|
||||
|
||||
def find_vdf_client():
|
||||
def find_vdf_client() -> pathlib.Path:
|
||||
p = pathlib.Path(pkg_resources.get_distribution("chiavdf").location) / "vdf_client"
|
||||
if p.is_file():
|
||||
return p
|
||||
raise FileNotFoundError("can't find vdf_client binary")
|
||||
|
||||
|
||||
async def spawn_process(host, port, counter):
|
||||
async def spawn_process(host: str, port: int, counter: int):
|
||||
global stopped
|
||||
global active_processes
|
||||
path_to_vdf_client = find_vdf_client()
|
||||
@ -77,7 +77,7 @@ async def spawn_process(host, port, counter):
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
async def spawn_all_processes(config, net_config):
|
||||
async def spawn_all_processes(config: Dict, net_config: Dict):
|
||||
await asyncio.sleep(5)
|
||||
port = config["port"]
|
||||
process_count = config["process_count"]
|
||||
|
@ -159,7 +159,7 @@ class BlockTools:
|
||||
updated_constants = updated_constants.replace(**const_dict)
|
||||
self.constants = updated_constants
|
||||
|
||||
def change_config(self, new_config):
|
||||
def change_config(self, new_config: Dict):
|
||||
self._config = new_config
|
||||
overrides = self._config["network_overrides"]["constants"][self._config["selected_network"]]
|
||||
updated_constants = self.constants.replace_str_to_bytes(**overrides)
|
||||
@ -1162,11 +1162,11 @@ def get_challenges(
|
||||
blocks: Dict[uint32, BlockRecord],
|
||||
finished_sub_slots: List[EndOfSubSlotBundle],
|
||||
prev_header_hash: Optional[bytes32],
|
||||
):
|
||||
) -> Tuple[bytes32, bytes32]:
|
||||
if len(finished_sub_slots) == 0:
|
||||
if prev_header_hash is None:
|
||||
return constants.GENESIS_CHALLENGE, constants.GENESIS_CHALLENGE
|
||||
curr = blocks[prev_header_hash]
|
||||
curr: BlockRecord = blocks[prev_header_hash]
|
||||
while not curr.first_in_sub_slot:
|
||||
curr = blocks[curr.prev_hash]
|
||||
assert curr.finished_challenge_slot_hashes is not None
|
||||
|
@ -11,7 +11,7 @@ def hexstr_to_bytes(input_str: str) -> bytes:
|
||||
return bytes.fromhex(input_str)
|
||||
|
||||
|
||||
def make_sized_bytes(size):
|
||||
def make_sized_bytes(size: int):
|
||||
"""
|
||||
Create a streamable type that subclasses "bytes" but requires instances
|
||||
to be a certain, fixed size.
|
||||
|
@ -83,7 +83,7 @@ def load_config_cli(root_path: Path, filename: str, sub_config: Optional[str] =
|
||||
return unflatten_properties(flattened_props)
|
||||
|
||||
|
||||
def flatten_properties(config: Dict):
|
||||
def flatten_properties(config: Dict) -> Dict:
|
||||
properties = {}
|
||||
for key, value in config.items():
|
||||
if type(value) is dict:
|
||||
@ -94,7 +94,7 @@ def flatten_properties(config: Dict):
|
||||
return properties
|
||||
|
||||
|
||||
def unflatten_properties(config: Dict):
|
||||
def unflatten_properties(config: Dict) -> Dict:
|
||||
properties: Dict = {}
|
||||
for key, value in config.items():
|
||||
if "." in key:
|
||||
@ -114,7 +114,7 @@ def add_property(d: Dict, partial_key: str, value: Any):
|
||||
d[key_1][key_2] = value
|
||||
|
||||
|
||||
def str2bool(v: Any) -> bool:
|
||||
def str2bool(v: Union[str, bool]) -> bool:
|
||||
# Source from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
|
@ -39,7 +39,7 @@ def generate_mnemonic() -> str:
|
||||
return mnemonic
|
||||
|
||||
|
||||
def bytes_to_mnemonic(mnemonic_bytes: bytes):
|
||||
def bytes_to_mnemonic(mnemonic_bytes: bytes) -> str:
|
||||
if len(mnemonic_bytes) not in [16, 20, 24, 28, 32]:
|
||||
raise ValueError(
|
||||
f"Data length should be one of the following: [16, 20, 24, 28, 32], but it is {len(mnemonic_bytes)}."
|
||||
@ -64,7 +64,7 @@ def bytes_to_mnemonic(mnemonic_bytes: bytes):
|
||||
return " ".join(mnemonics)
|
||||
|
||||
|
||||
def bytes_from_mnemonic(mnemonic_str: str):
|
||||
def bytes_from_mnemonic(mnemonic_str: str) -> bytes:
|
||||
mnemonic: List[str] = mnemonic_str.split(" ")
|
||||
if len(mnemonic) not in [12, 15, 18, 21, 24]:
|
||||
raise ValueError("Invalid mnemonic length")
|
||||
@ -93,7 +93,7 @@ def bytes_from_mnemonic(mnemonic_str: str):
|
||||
return entropy_bytes
|
||||
|
||||
|
||||
def mnemonic_to_seed(mnemonic: str, passphrase):
|
||||
def mnemonic_to_seed(mnemonic: str, passphrase: str) -> bytes:
|
||||
"""
|
||||
Uses BIP39 standard to derive a seed from entropy bytes.
|
||||
"""
|
||||
@ -124,7 +124,7 @@ class Keychain:
|
||||
self.testing = testing
|
||||
self.user = user
|
||||
|
||||
def _get_service(self):
|
||||
def _get_service(self) -> str:
|
||||
"""
|
||||
The keychain stores keys under a different name for tests.
|
||||
"""
|
||||
@ -148,7 +148,7 @@ class Keychain:
|
||||
str_bytes[G1Element.SIZE :], # flake8: noqa
|
||||
)
|
||||
|
||||
def _get_private_key_user(self, index: int):
|
||||
def _get_private_key_user(self, index: int) -> str:
|
||||
"""
|
||||
Returns the keychain user string for a key index.
|
||||
"""
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Dict
|
||||
|
||||
from chia.consensus.default_constants import DEFAULT_CONSTANTS
|
||||
from chia.consensus.default_constants import DEFAULT_CONSTANTS, ConsensusConstants
|
||||
|
||||
|
||||
def make_test_constants(test_constants_overrides: Dict):
|
||||
def make_test_constants(test_constants_overrides: Dict) -> ConsensusConstants:
|
||||
return DEFAULT_CONSTANTS.replace(**test_constants_overrides)
|
||||
|
@ -42,7 +42,7 @@ TRUNCATED = bytes([3])
|
||||
|
||||
BLANK = bytes([0] * 32)
|
||||
|
||||
prehashed: Dict = {}
|
||||
prehashed: Dict[bytes, Any] = {}
|
||||
|
||||
|
||||
def init_prehashed():
|
||||
@ -54,14 +54,14 @@ def init_prehashed():
|
||||
init_prehashed()
|
||||
|
||||
|
||||
def hashdown(mystr: bytes):
|
||||
def hashdown(mystr: bytes) -> bytes:
|
||||
assert len(mystr) == 66
|
||||
h = prehashed[bytes(mystr[0:1] + mystr[33:34])].copy()
|
||||
h.update(mystr[1:33] + mystr[34:])
|
||||
return h.digest()[:32]
|
||||
|
||||
|
||||
def compress_root(mystr: bytes):
|
||||
def compress_root(mystr: bytes) -> bytes:
|
||||
assert len(mystr) == 33
|
||||
if mystr[0:1] == MIDDLE:
|
||||
return mystr[1:]
|
||||
@ -71,7 +71,7 @@ def compress_root(mystr: bytes):
|
||||
return sha256(mystr).digest()[:32]
|
||||
|
||||
|
||||
def get_bit(mybytes: bytes, pos: int):
|
||||
def get_bit(mybytes: bytes, pos: int) -> int:
|
||||
assert len(mybytes) == 32
|
||||
return (mybytes[pos // 8] >> (7 - (pos % 8))) & 1
|
||||
|
||||
@ -104,7 +104,7 @@ class Node(metaclass=ABCMeta):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_included(self, tocheck: bytes, depth: int, p: List[bytes]):
|
||||
def is_included(self, tocheck: bytes, depth: int, p: List[bytes]) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -125,7 +125,7 @@ class MerkleSet:
|
||||
else:
|
||||
self.root = root
|
||||
|
||||
def get_root(self) -> Node:
|
||||
def get_root(self) -> bytes:
|
||||
return compress_root(self.root.get_hash())
|
||||
|
||||
def add_already_hashed(self, toadd: bytes):
|
||||
|
@ -1,4 +1,4 @@
|
||||
def format_minutes(minutes):
|
||||
def format_minutes(minutes: int) -> str:
|
||||
|
||||
if not isinstance(minutes, int):
|
||||
return "Invalid"
|
||||
@ -18,10 +18,10 @@ def format_minutes(minutes):
|
||||
days = int(minutes / day_minutes)
|
||||
hours = int(minutes / hour_minutes)
|
||||
|
||||
def format_unit_string(str_unit, count):
|
||||
def format_unit_string(str_unit: str, count: int) -> str:
|
||||
return f"{count} {str_unit}{('s' if count > 1 else '')}"
|
||||
|
||||
def format_unit(unit, count, unit_minutes, next_unit, next_unit_minutes):
|
||||
def format_unit(unit: str, count: int, unit_minutes: int, next_unit: str, next_unit_minutes: int) -> str:
|
||||
formatted = format_unit_string(unit, count)
|
||||
minutes_left = minutes % unit_minutes
|
||||
if minutes_left >= next_unit_minutes:
|
||||
|
@ -1,11 +1,12 @@
|
||||
from typing import Any
|
||||
from chia.server.outbound_message import NodeType
|
||||
|
||||
|
||||
def is_localhost(peer_host: str):
|
||||
def is_localhost(peer_host: str) -> bool:
|
||||
return peer_host == "127.0.0.1" or peer_host == "localhost" or peer_host == "::1" or peer_host == "0:0:0:0:0:0:0:1"
|
||||
|
||||
|
||||
def class_for_type(type: NodeType):
|
||||
def class_for_type(type: NodeType) -> Any:
|
||||
if type is NodeType.FULL_NODE:
|
||||
from chia.full_node.full_node_api import FullNodeAPI
|
||||
|
||||
|
@ -2,7 +2,7 @@ from dataclasses import replace
|
||||
from typing import Any
|
||||
|
||||
|
||||
def recursive_replace(root_obj: Any, replace_str: str, replace_with: Any):
|
||||
def recursive_replace(root_obj: Any, replace_str: str, replace_with: Any) -> Any:
|
||||
split_str = replace_str.split(".")
|
||||
if len(split_str) == 1:
|
||||
return replace(root_obj, **{split_str[0]: replace_with})
|
||||
|
@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def cancel_task_safe(task: Optional[asyncio.Task], log=None):
|
||||
def cancel_task_safe(task: Optional[asyncio.Task], log: Optional[logging.Logger] = None):
|
||||
if task is not None:
|
||||
try:
|
||||
task.cancel()
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import KeysView
|
||||
from typing import KeysView, Generator
|
||||
|
||||
SERVICES_FOR_GROUP = {
|
||||
"all": "chia_harvester chia_timelord_launcher chia_timelord chia_farmer chia_full_node chia_wallet".split(),
|
||||
@ -21,11 +21,11 @@ def all_groups() -> KeysView[str]:
|
||||
return SERVICES_FOR_GROUP.keys()
|
||||
|
||||
|
||||
def services_for_groups(groups):
|
||||
def services_for_groups(groups) -> Generator[str, None, None]:
|
||||
for group in groups:
|
||||
for service in SERVICES_FOR_GROUP[group]:
|
||||
yield service
|
||||
|
||||
|
||||
def validate_service(service) -> bool:
|
||||
def validate_service(service: str) -> bool:
|
||||
return any(service in _ for _ in SERVICES_FOR_GROUP.values())
|
||||
|
@ -6,6 +6,6 @@ except Exception:
|
||||
no_setproctitle = True
|
||||
|
||||
|
||||
def setproctitle(ps_name) -> None:
|
||||
def setproctitle(ps_name: str) -> None:
|
||||
if no_setproctitle is False:
|
||||
pysetproctitle.setproctitle(ps_name)
|
||||
|
Loading…
Reference in New Issue
Block a user