enable mypy for a some wallet files (#13185)

* enable mypy for chia/wallet/trading/offer.py

* enable mypy for chia/wallet/trade_manager.py
This commit is contained in:
Arvid Norberg 2022-08-26 21:25:11 +02:00 committed by GitHub
parent 000df27dbf
commit 31b5340001
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 37 additions and 33 deletions

View File

@ -1,15 +1,16 @@
from typing import Any, List
from typing import List, Union
from chia_rs import Coin
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.hash import std_hash
from chia.util.ints import uint64
__all__ = ["Coin", "coin_as_list", "hash_coin_ids"]
def coin_as_list(c: Coin) -> List[Any]:
return [c.parent_coin_info, c.puzzle_hash, c.amount]
def coin_as_list(c: Coin) -> List[Union[bytes32, uint64]]:
return [c.parent_coin_info, c.puzzle_hash, uint64(c.amount)]
def hash_coin_ids(coin_ids: List[bytes32]) -> bytes32:

View File

@ -1,7 +1,8 @@
from __future__ import annotations
import warnings
from dataclasses import dataclass
from typing import List
from typing import List, Dict, Any
from blspy import AugSchemeMPL, G2Element
@ -64,7 +65,7 @@ class SpendBundle(Streamable):
def debug(self, agg_sig_additional_data=DEFAULT_CONSTANTS.AGG_SIG_ME_ADDITIONAL_DATA):
debug_spend_bundle(self, agg_sig_additional_data)
def not_ephemeral_additions(self):
def not_ephemeral_additions(self) -> List[Coin]:
all_removals = self.removals()
all_additions = self.additions()
result: List[Coin] = []
@ -86,7 +87,7 @@ class SpendBundle(Streamable):
# 4. remove all code below this point
@classmethod
def from_json_dict(cls, json_dict):
def from_json_dict(cls, json_dict: Dict[str, Any]) -> SpendBundle:
if "coin_solutions" in json_dict:
if "coin_spends" not in json_dict:
json_dict = dict(
@ -97,7 +98,7 @@ class SpendBundle(Streamable):
raise ValueError("JSON contains both `coin_solutions` and `coin_spends`, just use `coin_spends`")
return streamable_from_dict(cls, json_dict)
def to_json_dict(self, include_legacy_keys: bool = True, exclude_modern_keys: bool = True):
def to_json_dict(self, include_legacy_keys: bool = True, exclude_modern_keys: bool = True) -> Dict[str, Any]:
if include_legacy_keys is False and exclude_modern_keys is True:
raise ValueError("`coin_spends` not included in legacy or modern outputs")
d = recurse_jsonify(self)

View File

@ -1,3 +1,4 @@
from __future__ import annotations
import dataclasses
import logging
import time
@ -78,8 +79,8 @@ class TradeManager:
async def create(
wallet_state_manager: Any,
db_wrapper: DBWrapper2,
name: str = None,
):
name: Optional[str] = None,
) -> TradeManager:
self = TradeManager()
if name:
self.log = logging.getLogger(name)
@ -127,7 +128,7 @@ class TradeManager:
async def coins_of_interest_farmed(
self, coin_state: CoinState, fork_height: Optional[uint32], peer: WSChiaConnection
):
) -> None:
"""
If both our coins and other coins in trade got removed that means that trade was successfully executed
If coins from other side of trade got farmed without ours, that means that trade failed because either someone
@ -186,7 +187,7 @@ class TradeManager:
await self.trade_store.set_status(trade.trade_id, TradeStatus.FAILED)
self.log.warning(f"Trade with id: {trade.trade_id} failed")
async def get_locked_coins(self, wallet_id: int = None) -> Dict[bytes32, WalletCoinRecord]:
async def get_locked_coins(self, wallet_id: Optional[int] = None) -> Dict[bytes32, WalletCoinRecord]:
"""Returns a dictionary of confirmed coins that are locked by a trade."""
all_pending = []
pending_accept = await self.get_offers_with_status(TradeStatus.PENDING_ACCEPT)
@ -208,7 +209,7 @@ class TradeManager:
return result
async def get_all_trades(self):
async def get_all_trades(self) -> List[TradeRecord]:
all: List[TradeRecord] = await self.trade_store.get_all_trades()
return all
@ -216,7 +217,7 @@ class TradeManager:
record = await self.trade_store.get_trade_record(trade_id)
return record
async def cancel_pending_offer(self, trade_id: bytes32):
async def cancel_pending_offer(self, trade_id: bytes32) -> None:
await self.trade_store.set_status(trade_id, TradeStatus.CANCELLED)
self.wallet_state_manager.state_changed("offer_cancelled")
@ -389,7 +390,7 @@ class TradeManager:
await self.trade_store.set_status(trade.trade_id, TradeStatus.CANCELLED)
return all_txs
async def save_trade(self, trade: TradeRecord):
async def save_trade(self, trade: TradeRecord) -> None:
await self.trade_store.add_trade_record(trade)
self.wallet_state_manager.state_changed("offer_added")
@ -578,7 +579,7 @@ class TradeManager:
self.log.error(f"Error with creating trade offer: {type(e)}{tb}")
return False, None, str(e)
async def maybe_create_wallets_for_offer(self, offer: Offer):
async def maybe_create_wallets_for_offer(self, offer: Offer) -> None:
for key in offer.arbitrage():
wsm = self.wallet_state_manager
@ -694,7 +695,7 @@ class TradeManager:
self,
offer: Offer,
peer: WSChiaConnection,
fee=uint64(0),
fee: uint64 = uint64(0),
min_coin_amount: Optional[uint64] = None,
) -> Union[Tuple[Literal[True], TradeRecord, None], Tuple[Literal[False], None, str]]:
take_offer_dict: Dict[Union[bytes32, int], int] = {}

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Union, BinaryIO
from blspy import G2Element
from clvm_tools.binutils import disassemble
@ -64,7 +65,7 @@ class Offer:
) -> Dict[Optional[bytes32], List[NotarizedPayment]]:
# This sort should be reproducible in CLVM with `>s`
sorted_coins: List[Coin] = sorted(coins, key=Coin.name)
sorted_coin_list: List[List] = [coin_as_list(c) for c in sorted_coins]
sorted_coin_list: List[List[Union[bytes32, uint64]]] = [coin_as_list(c) for c in sorted_coins]
nonce: bytes32 = Program.to(sorted_coin_list).get_tree_hash()
notarized_payments: Dict[Optional[bytes32], List[NotarizedPayment]] = {}
@ -95,9 +96,9 @@ class Offer:
return announcements
def __post_init__(self):
def __post_init__(self) -> None:
# Verify that there is at least something being offered
offered_coins: Dict[bytes32, List[Coin]] = self.get_offered_coins()
offered_coins: Dict[Optional[bytes32], List[Coin]] = self.get_offered_coins()
if offered_coins == {}:
raise ValueError("Bundle is not offering anything")
if self.get_requested_payments() == {}:
@ -289,7 +290,7 @@ class Offer:
return list(primary_coins)
@classmethod
def aggregate(cls, offers: List["Offer"]) -> "Offer":
def aggregate(cls, offers: List[Offer]) -> Offer:
total_requested_payments: Dict[Optional[bytes32], List[NotarizedPayment]] = {}
total_bundle = SpendBundle([], G2Element())
total_driver_dict: Dict[bytes32, PuzzleInfo] = {}
@ -363,7 +364,7 @@ class Offer:
sibling_spends: str = "("
sibling_puzzles: str = "("
sibling_solutions: str = "("
disassembled_offer_mod: str = disassemble(OFFER_MOD)
disassembled_offer_mod: str = disassemble(OFFER_MOD) # type: ignore
for sibling_coin in offered_coins:
if sibling_coin != coin:
siblings += (
@ -374,7 +375,7 @@ class Offer:
)
sibling_spends += "0x" + bytes(coin_to_spend_dict[sibling_coin]).hex() + ")"
sibling_puzzles += disassembled_offer_mod
sibling_solutions += disassemble(coin_to_solution_dict[sibling_coin])
sibling_solutions += disassemble(coin_to_solution_dict[sibling_coin]) # type: ignore
siblings += ")"
sibling_spends += ")"
sibling_puzzles += ")"
@ -442,7 +443,7 @@ class Offer:
)
@classmethod
def from_spend_bundle(cls, bundle: SpendBundle) -> "Offer":
def from_spend_bundle(cls, bundle: SpendBundle) -> Offer:
# Because of the `to_spend_bundle` method, we need to parse the dummy CoinSpends as `requested_payments`
requested_payments: Dict[Optional[bytes32], List[NotarizedPayment]] = {}
driver_dict: Dict[bytes32, PuzzleInfo] = {}
@ -473,7 +474,7 @@ class Offer:
def name(self) -> bytes32:
return self.to_spend_bundle().name()
def compress(self, version=None) -> bytes:
def compress(self, version: Optional[int] = None) -> bytes:
as_spend_bundle = self.to_spend_bundle()
if version is None:
mods: List[bytes] = [bytes(s.puzzle_reveal.to_program().uncurry()[0]) for s in as_spend_bundle.coin_spends]
@ -481,24 +482,24 @@ class Offer:
return compress_object_with_puzzles(bytes(as_spend_bundle), version)
@classmethod
def from_compressed(cls, compressed_bytes: bytes) -> "Offer":
def from_compressed(cls, compressed_bytes: bytes) -> Offer:
return Offer.from_bytes(decompress_object_with_puzzles(compressed_bytes))
@classmethod
def try_offer_decompression(cls, offer_bytes: bytes) -> "Offer":
def try_offer_decompression(cls, offer_bytes: bytes) -> Offer:
try:
return cls.from_compressed(offer_bytes)
except TypeError:
pass
return cls.from_bytes(offer_bytes)
def to_bech32(self, prefix: str = "offer", compression_version=None) -> str:
def to_bech32(self, prefix: str = "offer", compression_version: Optional[int] = None) -> str:
offer_bytes = self.compress(version=compression_version)
encoded = bech32_encode(prefix, convertbits(list(offer_bytes), 8, 5))
return encoded
@classmethod
def from_bech32(cls, offer_bech32: str) -> "Offer":
def from_bech32(cls, offer_bech32: str) -> Offer:
hrpgot, data = bech32_decode(offer_bech32, max_length=len(offer_bech32))
if data is None:
raise ValueError("Invalid Offer")
@ -509,11 +510,11 @@ class Offer:
# Methods to make this a valid Streamable member
# We basically hijack the SpendBundle versions for most of it
@classmethod
def parse(cls, f) -> "Offer":
def parse(cls, f: BinaryIO) -> Offer:
parsed_bundle = SpendBundle.parse(f)
return cls.from_bytes(bytes(parsed_bundle))
def stream(self, f):
def stream(self, f: BinaryIO) -> None:
as_spend_bundle = SpendBundle.from_bytes(bytes(self))
as_spend_bundle.stream(f)
@ -521,7 +522,7 @@ class Offer:
return bytes(self.to_spend_bundle())
@classmethod
def from_bytes(cls, as_bytes: bytes) -> "Offer":
def from_bytes(cls, as_bytes: bytes) -> Offer:
# Because of the __bytes__ method, we need to parse the dummy CoinSpends as `requested_payments`
bundle = SpendBundle.from_bytes(as_bytes)
return cls.from_spend_bundle(bundle)

File diff suppressed because one or more lines are too long