Rust coin state (#12934)

* use rust types for CoinState and RespondToPhUpdates

* improve test_network_protocol_test to support native types
This commit is contained in:
Arvid Norberg 2022-08-18 19:20:09 +02:00 committed by GitHub
parent bd1a96b404
commit af255a0dbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 57 additions and 43 deletions

View File

@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple
from chia_rs import CoinState, RespondToPhUpdates
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.program import SerializedProgram
from chia.types.blockchain_format.sized_bytes import bytes32
@ -15,6 +17,9 @@ Note: When changing this file, also change protocol_message_types.py, and the pr
"""
__all__ = ["CoinState", "RespondToPhUpdates"]
@streamable
@dataclass(frozen=True)
class RequestPuzzleSolution(Streamable):
@ -178,12 +183,13 @@ class RespondHeaderBlocks(Streamable):
header_blocks: List[HeaderBlock]
@streamable
@dataclass(frozen=True)
class CoinState(Streamable):
coin: Coin
spent_height: Optional[uint32]
created_height: Optional[uint32]
# This class is implemented in Rust
# @streamable
# @dataclass(frozen=True)
# class CoinState(Streamable):
# coin: Coin
# spent_height: Optional[uint32]
# created_height: Optional[uint32]
@streamable
@ -193,12 +199,13 @@ class RegisterForPhUpdates(Streamable):
min_height: uint32
@streamable
@dataclass(frozen=True)
class RespondToPhUpdates(Streamable):
puzzle_hashes: List[bytes32]
min_height: uint32
coin_states: List[CoinState]
# This class is implemented in Rust
# @streamable
# @dataclass(frozen=True)
# class RespondToPhUpdates(Streamable):
# puzzle_hashes: List[bytes32]
# min_height: uint32
# coin_states: List[CoinState]
@streamable

View File

@ -1714,8 +1714,8 @@ class WalletRpcApi:
coin_state.coin,
None,
full_puzzle,
launcher_coin[0].spent_height,
coin_state.created_height if coin_state.created_height else uint32(0),
uint32(launcher_coin[0].spent_height),
uint32(coin_state.created_height) if coin_state.created_height else uint32(0),
)
)
except Exception as e:

View File

@ -377,7 +377,7 @@ class DIDWallet:
)[0]
assert parent_state.spent_height is not None
puzzle_solution_request = wallet_protocol.RequestPuzzleSolution(
coin.parent_coin_info, parent_state.spent_height
coin.parent_coin_info, uint32(parent_state.spent_height)
)
response = await peer.request_puzzle_solution(puzzle_solution_request)
req_puz_sol = response.response

View File

@ -218,7 +218,7 @@ class NFTWallet:
and len(launcher_coin_states) == 1
and launcher_coin_states[0].spent_height is not None
)
mint_height: uint32 = launcher_coin_states[0].spent_height
mint_height: uint32 = uint32(launcher_coin_states[0].spent_height)
self.log.info("Adding a new NFT to wallet: %s", child_coin)
# all is well, lets add NFT to our local db
@ -229,7 +229,7 @@ class NFTWallet:
)
if coin_states is not None:
parent_coin = coin_states[0].coin
confirmed_height = coin_states[0].spent_height
confirmed_height = None if coin_states[0].spent_height is None else uint32(coin_states[0].spent_height)
if parent_coin is None or confirmed_height is None:
raise ValueError("Error finding parent")

View File

@ -49,9 +49,9 @@ class PeerRequestCache:
def add_to_states_validated(self, coin_state: CoinState) -> None:
cs_height: Optional[uint32] = None
if coin_state.spent_height is not None:
cs_height = coin_state.spent_height
cs_height = uint32(coin_state.spent_height)
elif coin_state.created_height is not None:
cs_height = coin_state.created_height
cs_height = uint32(coin_state.created_height)
self._states_validated.put(coin_state.get_hash(), cs_height)
def get_height_timestamp(self, height: uint32) -> Optional[uint64]:

View File

@ -296,9 +296,9 @@ def get_block_challenge(
def last_change_height_cs(cs: CoinState) -> uint32:
if cs.spent_height is not None:
return cs.spent_height
return uint32(cs.spent_height)
if cs.created_height is not None:
return cs.created_height
return uint32(cs.created_height)
# Reorgs should be processed at the beginning
return uint32(0)

View File

@ -1241,8 +1241,10 @@ class WalletNode:
if await can_use_peer_request_cache(coin_state, peer_request_cache, fork_height):
return True
spent_height = coin_state.spent_height
confirmed_height = coin_state.created_height
spent_height: Optional[uint32] = None if coin_state.spent_height is None else uint32(coin_state.spent_height)
confirmed_height: Optional[uint32] = (
None if coin_state.created_height is None else uint32(coin_state.created_height)
)
current = await self.wallet_state_manager.coin_store.get_coin_record(coin_state.coin.name())
# if remote state is same as current local state we skip validation

View File

@ -603,8 +603,8 @@ class WalletStateManager:
self, peer: WSChiaConnection, coin_state: CoinState, fork_height: Optional[uint32]
) -> Tuple[Optional[uint32], Optional[WalletType]]:
if coin_state.created_height is not None and (
self.is_pool_reward(coin_state.created_height, coin_state.coin)
or self.is_farmer_reward(coin_state.created_height, coin_state.coin)
self.is_pool_reward(uint32(coin_state.created_height), coin_state.coin)
or self.is_farmer_reward(uint32(coin_state.created_height), coin_state.coin)
):
return None, None
@ -692,7 +692,7 @@ class WalletStateManager:
await self.interested_store.add_unacknowledged_token(
asset_id,
CATWallet.default_wallet_name_for_unknown_cat(asset_id.hex()),
parent_coin_state.spent_height,
None if parent_coin_state.spent_height is None else uint32(parent_coin_state.spent_height),
parent_coin_state.coin.puzzle_hash,
)
self.state_changed("added_stray_cat")
@ -834,7 +834,7 @@ class WalletStateManager:
)
nft_wallet: NFTWallet = self.wallets[wallet_info.id]
if parent_coin_state.spent_height is not None:
await nft_wallet.remove_coin(coin_spend.coin, parent_coin_state.spent_height)
await nft_wallet.remove_coin(coin_spend.coin, uint32(parent_coin_state.spent_height))
if nft_wallet_info.did_id == new_did_id:
self.log.info(
"Adding new NFT, NFT_ID:%s, DID_ID:%s",
@ -928,7 +928,7 @@ class WalletStateManager:
if local_record is None:
await self.coin_added(
coin_state.coin,
coin_state.created_height,
uint32(coin_state.created_height),
all_unconfirmed,
wallet_id,
wallet_type,
@ -948,18 +948,18 @@ class WalletStateManager:
farmer_reward = False
pool_reward = False
tx_type: int
if self.is_farmer_reward(coin_state.created_height, coin_state.coin):
if self.is_farmer_reward(uint32(coin_state.created_height), coin_state.coin):
farmer_reward = True
tx_type = TransactionType.FEE_REWARD.value
elif self.is_pool_reward(coin_state.created_height, coin_state.coin):
elif self.is_pool_reward(uint32(coin_state.created_height), coin_state.coin):
pool_reward = True
tx_type = TransactionType.COINBASE_REWARD.value
else:
tx_type = TransactionType.INCOMING_TX.value
record = WalletCoinRecord(
coin_state.coin,
coin_state.created_height,
coin_state.spent_height,
uint32(coin_state.created_height),
uint32(coin_state.spent_height),
True,
farmer_reward or pool_reward,
wallet_type,
@ -978,7 +978,7 @@ class WalletStateManager:
if not change:
created_timestamp = await self.wallet_node.get_timestamp_for_height(coin_state.created_height)
tx_record = TransactionRecord(
confirmed_at_height=coin_state.created_height,
confirmed_at_height=uint32(coin_state.created_height),
created_at_time=uint64(created_timestamp),
to_puzzle_hash=(await self.convert_puzzle_hash(wallet_id, coin_state.coin.puzzle_hash)),
amount=uint64(coin_state.coin.amount),
@ -1029,10 +1029,10 @@ class WalletStateManager:
if len(tx_records) > 0:
for tx_record in tx_records:
await self.tx_store.set_confirmed(tx_record.name, coin_state.spent_height)
await self.tx_store.set_confirmed(tx_record.name, uint32(coin_state.spent_height))
else:
tx_record = TransactionRecord(
confirmed_at_height=coin_state.spent_height,
confirmed_at_height=uint32(coin_state.spent_height),
created_at_time=uint64(spent_timestamp),
to_puzzle_hash=(await self.convert_puzzle_hash(wallet_id, to_puzzle_hash)),
amount=uint64(int(amount)),
@ -1052,7 +1052,7 @@ class WalletStateManager:
await self.tx_store.add_transaction_record(tx_record)
else:
await self.coin_store.set_spent(coin_name, coin_state.spent_height)
await self.coin_store.set_spent(coin_name, uint32(coin_state.spent_height))
rem_tx_records: List[TransactionRecord] = []
for out_tx_record in all_unconfirmed:
for rem_coin in out_tx_record.removals:
@ -1060,12 +1060,12 @@ class WalletStateManager:
rem_tx_records.append(out_tx_record)
for tx_record in rem_tx_records:
await self.tx_store.set_confirmed(tx_record.name, coin_state.spent_height)
await self.tx_store.set_confirmed(tx_record.name, uint32(coin_state.spent_height))
for unconfirmed_record in all_unconfirmed:
for rem_coin in unconfirmed_record.removals:
if rem_coin == coin_state.coin:
self.log.info(f"Setting tx_id: {unconfirmed_record.name} to confirmed")
await self.tx_store.set_confirmed(unconfirmed_record.name, coin_state.spent_height)
await self.tx_store.set_confirmed(unconfirmed_record.name, uint32(coin_state.spent_height))
if record.wallet_type == WalletType.POOLING_WALLET:
if coin_state.spent_height is not None and coin_state.coin.amount == uint64(1):
@ -1085,13 +1085,15 @@ class WalletStateManager:
break
await self.coin_added(
new_singleton_coin,
coin_state.spent_height,
uint32(coin_state.spent_height),
[],
uint32(record.wallet_id),
record.wallet_type,
peer,
)
await self.coin_store.set_spent(curr_coin_state.coin.name(), curr_coin_state.spent_height)
await self.coin_store.set_spent(
curr_coin_state.coin.name(), uint32(curr_coin_state.spent_height)
)
await self.add_interested_coin_ids([new_singleton_coin.name()])
new_coin_state: List[CoinState] = await self.wallet_node.get_coin_state(
[new_singleton_coin.name()], peer=peer, fork_height=fork_height
@ -1138,7 +1140,7 @@ class WalletStateManager:
self.main_wallet,
child.coin.name(),
[launcher_spend],
child.spent_height,
uint32(child.spent_height),
name="pool_wallet",
)
launcher_spend_additions = launcher_spend.additions()
@ -1146,7 +1148,7 @@ class WalletStateManager:
coin_added = launcher_spend_additions[0]
await self.coin_added(
coin_added,
coin_state.spent_height,
uint32(coin_state.spent_height),
[],
pool_wallet.id(),
WalletType(pool_wallet.type()),

View File

@ -8,7 +8,7 @@ dependencies = [
"chiapos==1.0.10", # proof of space
"clvm==0.9.7",
"clvm_tools==0.4.5", # Currying, Program.to, other conveniences
"chia_rs==0.1.7",
"chia_rs==0.1.8",
"clvm-tools-rs==0.1.19", # Rust implementation of clvm_tools' compiler
"aiohttp==3.8.1", # HTTP server for full node rpc
"aiosqlite==0.17.0", # asyncio wrapper for sqlite, to store blocks

View File

@ -23,6 +23,9 @@ def types_in_module(mod: Any) -> Set[str]:
obj = getattr(mod, sym)
if hasattr(obj, "__module__") and obj.__module__ == mod_name:
ret.append(sym)
if hasattr(mod, "__all__"):
ret += getattr(mod, "__all__")
return set(ret)