Add more types (#3441)

This commit is contained in:
An Long 2021-05-13 06:32:21 +08:00 committed by GitHub
parent 4395f97260
commit 5034fcc85d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 37 additions and 35 deletions

View File

@ -4,7 +4,7 @@ import pathlib
import signal
import socket
import time
from typing import List
from typing import Dict, List
import pkg_resources
@ -32,14 +32,14 @@ async def kill_processes():
pass
def find_vdf_client():
def find_vdf_client() -> pathlib.Path:
p = pathlib.Path(pkg_resources.get_distribution("chiavdf").location) / "vdf_client"
if p.is_file():
return p
raise FileNotFoundError("can't find vdf_client binary")
async def spawn_process(host, port, counter):
async def spawn_process(host: str, port: int, counter: int):
global stopped
global active_processes
path_to_vdf_client = find_vdf_client()
@ -77,7 +77,7 @@ async def spawn_process(host, port, counter):
await asyncio.sleep(0.1)
async def spawn_all_processes(config, net_config):
async def spawn_all_processes(config: Dict, net_config: Dict):
await asyncio.sleep(5)
port = config["port"]
process_count = config["process_count"]

View File

@ -159,7 +159,7 @@ class BlockTools:
updated_constants = updated_constants.replace(**const_dict)
self.constants = updated_constants
def change_config(self, new_config):
def change_config(self, new_config: Dict):
self._config = new_config
overrides = self._config["network_overrides"]["constants"][self._config["selected_network"]]
updated_constants = self.constants.replace_str_to_bytes(**overrides)
@ -1162,11 +1162,11 @@ def get_challenges(
blocks: Dict[uint32, BlockRecord],
finished_sub_slots: List[EndOfSubSlotBundle],
prev_header_hash: Optional[bytes32],
):
) -> Tuple[bytes32, bytes32]:
if len(finished_sub_slots) == 0:
if prev_header_hash is None:
return constants.GENESIS_CHALLENGE, constants.GENESIS_CHALLENGE
curr = blocks[prev_header_hash]
curr: BlockRecord = blocks[prev_header_hash]
while not curr.first_in_sub_slot:
curr = blocks[curr.prev_hash]
assert curr.finished_challenge_slot_hashes is not None

View File

@ -11,7 +11,7 @@ def hexstr_to_bytes(input_str: str) -> bytes:
return bytes.fromhex(input_str)
def make_sized_bytes(size):
def make_sized_bytes(size: int):
"""
Create a streamable type that subclasses "bytes" but requires instances
to be a certain, fixed size.

View File

@ -83,7 +83,7 @@ def load_config_cli(root_path: Path, filename: str, sub_config: Optional[str] =
return unflatten_properties(flattened_props)
def flatten_properties(config: Dict):
def flatten_properties(config: Dict) -> Dict:
properties = {}
for key, value in config.items():
if type(value) is dict:
@ -94,7 +94,7 @@ def flatten_properties(config: Dict):
return properties
def unflatten_properties(config: Dict):
def unflatten_properties(config: Dict) -> Dict:
properties: Dict = {}
for key, value in config.items():
if "." in key:
@ -114,7 +114,7 @@ def add_property(d: Dict, partial_key: str, value: Any):
d[key_1][key_2] = value
def str2bool(v: Any) -> bool:
def str2bool(v: Union[str, bool]) -> bool:
# Source from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
if isinstance(v, bool):
return v

View File

@ -39,7 +39,7 @@ def generate_mnemonic() -> str:
return mnemonic
def bytes_to_mnemonic(mnemonic_bytes: bytes):
def bytes_to_mnemonic(mnemonic_bytes: bytes) -> str:
if len(mnemonic_bytes) not in [16, 20, 24, 28, 32]:
raise ValueError(
f"Data length should be one of the following: [16, 20, 24, 28, 32], but it is {len(mnemonic_bytes)}."
@ -64,7 +64,7 @@ def bytes_to_mnemonic(mnemonic_bytes: bytes):
return " ".join(mnemonics)
def bytes_from_mnemonic(mnemonic_str: str):
def bytes_from_mnemonic(mnemonic_str: str) -> bytes:
mnemonic: List[str] = mnemonic_str.split(" ")
if len(mnemonic) not in [12, 15, 18, 21, 24]:
raise ValueError("Invalid mnemonic length")
@ -93,7 +93,7 @@ def bytes_from_mnemonic(mnemonic_str: str):
return entropy_bytes
def mnemonic_to_seed(mnemonic: str, passphrase):
def mnemonic_to_seed(mnemonic: str, passphrase: str) -> bytes:
"""
Uses BIP39 standard to derive a seed from entropy bytes.
"""
@ -124,7 +124,7 @@ class Keychain:
self.testing = testing
self.user = user
def _get_service(self):
def _get_service(self) -> str:
"""
The keychain stores keys under a different name for tests.
"""
@ -148,7 +148,7 @@ class Keychain:
str_bytes[G1Element.SIZE :], # flake8: noqa
)
def _get_private_key_user(self, index: int):
def _get_private_key_user(self, index: int) -> str:
"""
Returns the keychain user string for a key index.
"""

View File

@ -1,7 +1,7 @@
from typing import Dict
from chia.consensus.default_constants import DEFAULT_CONSTANTS
from chia.consensus.default_constants import DEFAULT_CONSTANTS, ConsensusConstants
def make_test_constants(test_constants_overrides: Dict):
def make_test_constants(test_constants_overrides: Dict) -> ConsensusConstants:
return DEFAULT_CONSTANTS.replace(**test_constants_overrides)

View File

@ -42,7 +42,7 @@ TRUNCATED = bytes([3])
BLANK = bytes([0] * 32)
prehashed: Dict = {}
prehashed: Dict[bytes, Any] = {}
def init_prehashed():
@ -54,14 +54,14 @@ def init_prehashed():
init_prehashed()
def hashdown(mystr: bytes):
def hashdown(mystr: bytes) -> bytes:
assert len(mystr) == 66
h = prehashed[bytes(mystr[0:1] + mystr[33:34])].copy()
h.update(mystr[1:33] + mystr[34:])
return h.digest()[:32]
def compress_root(mystr: bytes):
def compress_root(mystr: bytes) -> bytes:
assert len(mystr) == 33
if mystr[0:1] == MIDDLE:
return mystr[1:]
@ -71,7 +71,7 @@ def compress_root(mystr: bytes):
return sha256(mystr).digest()[:32]
def get_bit(mybytes: bytes, pos: int):
def get_bit(mybytes: bytes, pos: int) -> int:
assert len(mybytes) == 32
return (mybytes[pos // 8] >> (7 - (pos % 8))) & 1
@ -104,7 +104,7 @@ class Node(metaclass=ABCMeta):
pass
@abstractmethod
def is_included(self, tocheck: bytes, depth: int, p: List[bytes]):
def is_included(self, tocheck: bytes, depth: int, p: List[bytes]) -> bool:
pass
@abstractmethod
@ -125,7 +125,7 @@ class MerkleSet:
else:
self.root = root
def get_root(self) -> Node:
def get_root(self) -> bytes:
return compress_root(self.root.get_hash())
def add_already_hashed(self, toadd: bytes):

View File

@ -1,4 +1,4 @@
def format_minutes(minutes):
def format_minutes(minutes: int) -> str:
if not isinstance(minutes, int):
return "Invalid"
@ -18,10 +18,10 @@ def format_minutes(minutes):
days = int(minutes / day_minutes)
hours = int(minutes / hour_minutes)
def format_unit_string(str_unit, count):
def format_unit_string(str_unit: str, count: int) -> str:
return f"{count} {str_unit}{('s' if count > 1 else '')}"
def format_unit(unit, count, unit_minutes, next_unit, next_unit_minutes):
def format_unit(unit: str, count: int, unit_minutes: int, next_unit: str, next_unit_minutes: int) -> str:
formatted = format_unit_string(unit, count)
minutes_left = minutes % unit_minutes
if minutes_left >= next_unit_minutes:

View File

@ -1,11 +1,12 @@
from typing import Any
from chia.server.outbound_message import NodeType
def is_localhost(peer_host: str):
def is_localhost(peer_host: str) -> bool:
return peer_host == "127.0.0.1" or peer_host == "localhost" or peer_host == "::1" or peer_host == "0:0:0:0:0:0:0:1"
def class_for_type(type: NodeType):
def class_for_type(type: NodeType) -> Any:
if type is NodeType.FULL_NODE:
from chia.full_node.full_node_api import FullNodeAPI

View File

@ -2,7 +2,7 @@ from dataclasses import replace
from typing import Any
def recursive_replace(root_obj: Any, replace_str: str, replace_with: Any):
def recursive_replace(root_obj: Any, replace_str: str, replace_with: Any) -> Any:
split_str = replace_str.split(".")
if len(split_str) == 1:
return replace(root_obj, **{split_str[0]: replace_with})

View File

@ -1,8 +1,9 @@
import asyncio
import logging
from typing import Optional
def cancel_task_safe(task: Optional[asyncio.Task], log=None):
def cancel_task_safe(task: Optional[asyncio.Task], log: Optional[logging.Logger] = None):
if task is not None:
try:
task.cancel()

View File

@ -1,4 +1,4 @@
from typing import KeysView
from typing import KeysView, Generator
SERVICES_FOR_GROUP = {
"all": "chia_harvester chia_timelord_launcher chia_timelord chia_farmer chia_full_node chia_wallet".split(),
@ -21,11 +21,11 @@ def all_groups() -> KeysView[str]:
return SERVICES_FOR_GROUP.keys()
def services_for_groups(groups):
def services_for_groups(groups) -> Generator[str, None, None]:
for group in groups:
for service in SERVICES_FOR_GROUP[group]:
yield service
def validate_service(service) -> bool:
def validate_service(service: str) -> bool:
return any(service in _ for _ in SERVICES_FOR_GROUP.values())

View File

@ -6,6 +6,6 @@ except Exception:
no_setproctitle = True
def setproctitle(ps_name) -> None:
def setproctitle(ps_name: str) -> None:
if no_setproctitle is False:
pysetproctitle.setproctitle(ps_name)