Prevent redundant peer calls in coin_added (#15735)

* Refine coin_added

* Resolve comments

* Fix assert

* Fix pre-commit

* Resolve comments

* Resolve comments

* Resolve comments

* Resolve comment

* Add todo

* Add generic

* resolve comment

* Fix cat

* Resolve comments

* Fix test
This commit is contained in:
Kronus91 2023-08-31 10:42:30 -07:00 committed by GitHub
parent 1463f0c84e
commit 3a82aacbbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 179 additions and 98 deletions

View File

@ -108,7 +108,8 @@ class Mirror:
@final
class DataLayerWallet:
if TYPE_CHECKING:
_protocol_check: ClassVar[WalletProtocol] = cast("DataLayerWallet", None)
# TODO Create DataLayer coin data model if necessary
_protocol_check: ClassVar[WalletProtocol[object]] = cast("DataLayerWallet", None)
wallet_state_manager: WalletStateManager
log: logging.Logger
@ -847,7 +848,7 @@ class DataLayerWallet:
# SYNCING #
###########
async def coin_added(self, coin: Coin, height: uint32, peer: WSChiaConnection) -> None:
async def coin_added(self, coin: Coin, height: uint32, peer: WSChiaConnection, coin_data: Optional[object]) -> None:
if coin.puzzle_hash == create_mirror_puzzle().get_tree_hash():
parent_state: CoinState = (
await self.wallet_state_manager.wallet_node.get_coin_state([coin.parent_coin_info], peer=peer)

View File

@ -65,7 +65,7 @@ class PoolWallet:
if TYPE_CHECKING:
from chia.wallet.wallet_protocol import WalletProtocol
_protocol_check: ClassVar[WalletProtocol] = cast("PoolWallet", None)
_protocol_check: ClassVar[WalletProtocol[object]] = cast("PoolWallet", None)
MINIMUM_INITIAL_BALANCE = 1
MINIMUM_RELATIVE_LOCK_HEIGHT = 5
@ -991,7 +991,7 @@ class PoolWallet:
async def get_max_send_amount(self, records: Optional[Set[WalletCoinRecord]] = None) -> uint128:
return uint128(0)
async def coin_added(self, coin: Coin, height: uint32, peer: WSChiaConnection) -> None:
async def coin_added(self, coin: Coin, height: uint32, peer: WSChiaConnection, coin_data: Optional[object]) -> None:
pass
async def select_coins(self, amount: uint64, coin_selection_config: CoinSelectionConfig) -> Set[Coin]:

View File

@ -52,7 +52,7 @@ from chia.wallet.derive_keys import (
match_address_to_sk,
)
from chia.wallet.did_wallet import did_wallet_puzzles
from chia.wallet.did_wallet.did_info import DIDInfo
from chia.wallet.did_wallet.did_info import DIDCoinData, DIDInfo
from chia.wallet.did_wallet.did_wallet import DIDWallet
from chia.wallet.did_wallet.did_wallet_puzzles import (
DID_INNERPUZ_MOD,
@ -1466,7 +1466,7 @@ class WalletRpcApi:
:return:
"""
entity_id: bytes32 = decode_puzzle_hash(request["id"])
selected_wallet: Optional[WalletProtocol] = None
selected_wallet: Optional[WalletProtocol[Any]] = None
is_hex = request.get("is_hex", False)
if isinstance(is_hex, str):
is_hex = bool(is_hex)
@ -2100,7 +2100,13 @@ class WalletRpcApi:
if curried_args is None:
return {"success": False, "error": "The coin is not a DID."}
p2_puzzle, recovery_list_hash, num_verification, singleton_struct, metadata = curried_args
did_data: DIDCoinData = DIDCoinData(
p2_puzzle,
bytes32(recovery_list_hash.atom),
uint16(num_verification.as_int()),
singleton_struct,
metadata,
)
hinted_coins = compute_spend_hints_and_additions(coin_spend)
# Hint is required, if it doesn't have any hint then it should be invalid
hint: Optional[bytes32] = None
@ -2258,7 +2264,12 @@ class WalletRpcApi:
coin_state.coin, uint32(coin_state.created_height), uint32(0), False, False, wallet_type, wallet_id
)
await self.service.wallet_state_manager.coin_store.add_coin_record(coin_record, coin_state.coin.name())
await did_wallet.coin_added(coin_state.coin, uint32(coin_state.created_height), peer)
await did_wallet.coin_added(
coin_state.coin,
uint32(coin_state.created_height),
peer,
did_data,
)
return {"success": True, "latest_coin_id": coin_state.coin.name().hex()}
@tx_endpoint

View File

@ -17,6 +17,14 @@ class CATInfo(Streamable):
my_tail: Optional[Program] # this is the program
@streamable
@dataclass(frozen=True)
class CATCoinData(Streamable):
mod_hash: bytes32
tail_program_hash: bytes32
inner_puzzle: Program
# We used to store all of the lineage proofs here but it was very slow to serialize for a lot of transactions
# so we moved it to CATLineageStore. We keep this around for migration purposes.
@streamable

View File

@ -23,7 +23,7 @@ from chia.util.condition_tools import conditions_dict_for_solution, pkm_pairs_fo
from chia.util.hash import std_hash
from chia.util.ints import uint32, uint64, uint128
from chia.wallet.cat_wallet.cat_constants import DEFAULT_CATS
from chia.wallet.cat_wallet.cat_info import CATInfo, LegacyCATInfo
from chia.wallet.cat_wallet.cat_info import CATCoinData, CATInfo, LegacyCATInfo
from chia.wallet.cat_wallet.cat_utils import (
CAT_MOD,
SpendableCAT,
@ -69,7 +69,7 @@ QUOTED_MOD_HASH = calculate_hash_of_quoted_mod_hash(CAT_MOD_HASH)
class CATWallet:
if TYPE_CHECKING:
_protocol_check: ClassVar[WalletProtocol] = cast("CATWallet", None)
_protocol_check: ClassVar[WalletProtocol[CATCoinData]] = cast("CATWallet", None)
wallet_state_manager: WalletStateManager
log: logging.Logger
@ -349,7 +349,10 @@ class CATWallet:
)
)
async def coin_added(self, coin: Coin, height: uint32, peer: WSChiaConnection) -> None:
async def coin_added(
self, coin: Coin, height: uint32, peer: WSChiaConnection, coin_data: Optional[CATCoinData]
) -> None:
# TODO Use coin_data instead of calling peer API
"""Notification from wallet state manager that wallet has been received."""
self.log.info(f"CAT wallet has been notified that {coin.name().hex()} was added")

View File

@ -6,7 +6,7 @@ from typing import List, Optional, Tuple
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint64
from chia.util.ints import uint16, uint64
from chia.util.streamable import Streamable, streamable
from chia.wallet.lineage_proof import LineageProof
@ -24,3 +24,13 @@ class DIDInfo(Streamable):
temp_pubkey: Optional[bytes]
sent_recovery_transaction: bool
metadata: str # JSON of the user defined metadata
@streamable
@dataclass(frozen=True)
class DIDCoinData(Streamable):
p2_puzzle: Program
recovery_list_hash: bytes32
num_verification: uint16
singleton_struct: Program
metadata: Program

View File

@ -24,7 +24,7 @@ from chia.wallet.conditions import Condition
from chia.wallet.derivation_record import DerivationRecord
from chia.wallet.derive_keys import master_sk_to_wallet_sk_unhardened
from chia.wallet.did_wallet import did_wallet_puzzles
from chia.wallet.did_wallet.did_info import DIDInfo
from chia.wallet.did_wallet.did_info import DIDCoinData, DIDInfo
from chia.wallet.did_wallet.did_wallet_puzzles import uncurry_innerpuz
from chia.wallet.lineage_proof import LineageProof
from chia.wallet.payment import Payment
@ -49,13 +49,13 @@ from chia.wallet.util.wallet_types import WalletType
from chia.wallet.wallet import CHIP_0002_SIGN_MESSAGE_PREFIX, Wallet
from chia.wallet.wallet_coin_record import WalletCoinRecord
from chia.wallet.wallet_info import WalletInfo
from chia.wallet.wallet_protocol import WalletProtocol
class DIDWallet:
if TYPE_CHECKING:
from chia.wallet.wallet_protocol import WalletProtocol
_protocol_check: ClassVar[WalletProtocol] = cast("DIDWallet", None)
if TYPE_CHECKING:
_protocol_check: ClassVar[WalletProtocol[DIDCoinData]] = cast("DIDWallet", None)
wallet_state_manager: Any
log: logging.Logger
@ -343,9 +343,9 @@ class DIDWallet:
# We can improve this interface by passing in the CoinSpend, as well
# We need to change DID Wallet coin_added to expect p2 spends as well as recovery spends,
# or only call it in the recovery spend case
async def coin_added(self, coin: Coin, _: uint32, peer: WSChiaConnection):
async def coin_added(self, coin: Coin, _: uint32, peer: WSChiaConnection, coin_data: Optional[DIDCoinData]):
"""Notification from wallet state manager that wallet has been received."""
# TODO Use coin_data instead of calling peer API
parent = self.get_parent_for_coin(coin)
if parent is None:
# this is the first time we received it, check it's a DID coin

View File

@ -59,7 +59,7 @@ _T_NFTWallet = TypeVar("_T_NFTWallet", bound="NFTWallet")
class NFTWallet:
if TYPE_CHECKING:
_protocol_check: ClassVar[WalletProtocol] = cast("NFTWallet", None)
_protocol_check: ClassVar[WalletProtocol[UncurriedNFT]] = cast("NFTWallet", None)
wallet_state_manager: Any
log: logging.Logger
@ -157,8 +157,11 @@ class NFTWallet:
raise KeyError(f"Couldn't find coin with id: {nft_coin_id}")
return nft_coin
async def coin_added(self, coin: Coin, height: uint32, peer: WSChiaConnection) -> None:
async def coin_added(
self, coin: Coin, height: uint32, peer: WSChiaConnection, coin_data: Optional[UncurriedNFT]
) -> None:
"""Notification from wallet state manager that wallet has been received."""
# TODO Use coin_data instead of calling peer API
self.log.info(f"NFT wallet %s has been notified that {coin} was added", self.get_name())
if await self.nft_store.exists(coin.name()):
# already added

View File

@ -7,6 +7,7 @@ from typing import Optional, Type, TypeVar
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint16
from chia.util.streamable import Streamable, streamable
from chia.wallet.puzzles.load_clvm import load_clvm_maybe_recompile
log = logging.getLogger(__name__)
@ -19,8 +20,9 @@ NFT_OWNERSHIP_LAYER = load_clvm_maybe_recompile(
_T_UncurriedNFT = TypeVar("_T_UncurriedNFT", bound="UncurriedNFT")
@streamable
@dataclass(frozen=True)
class UncurriedNFT:
class UncurriedNFT(Streamable):
"""
A simple solution for uncurry NFT puzzle.
Initial the class with a full NFT puzzle, it will do a deep uncurry.
@ -168,7 +170,7 @@ class UncurriedNFT:
return None
return cls(
nft_mod_hash=nft_mod_hash,
nft_mod_hash=bytes32(nft_mod_hash.atom),
nft_state_layer=nft_state_layer,
singleton_struct=singleton_struct,
singleton_mod_hash=singleton_mod_hash,

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeVar
from chia.util.ints import uint8, uint32
from chia.util.streamable import Streamable, streamable
@ -42,13 +42,16 @@ class RemarkDataType(IntEnum):
CLAWBACK = 2
T = TypeVar("T", contravariant=True)
@dataclass(frozen=True)
class WalletIdentifier:
id: uint32
type: WalletType
@classmethod
def create(cls, wallet: WalletProtocol) -> WalletIdentifier:
def create(cls, wallet: WalletProtocol[T]) -> WalletIdentifier:
return cls(wallet.id(), wallet.type())

View File

@ -20,7 +20,7 @@ from chia.util.byte_types import hexstr_to_bytes
from chia.util.hash import std_hash
from chia.util.ints import uint8, uint32, uint64, uint128
from chia.util.misc import VersionedBlob
from chia.wallet.cat_wallet.cat_info import CRCATInfo
from chia.wallet.cat_wallet.cat_info import CATCoinData, CRCATInfo
from chia.wallet.cat_wallet.cat_utils import CAT_MOD, construct_cat_puzzle
from chia.wallet.cat_wallet.cat_wallet import CATWallet
from chia.wallet.coin_selection import select_coins
@ -191,7 +191,9 @@ class CRCATWallet(CATWallet):
async def set_tail_program(self, tail_program: str) -> None: # pragma: no cover
raise NotImplementedError("set_tail_program is a legacy method and is not available on CR-CAT wallets")
async def coin_added(self, coin: Coin, height: uint32, peer: WSChiaConnection) -> None:
async def coin_added(
self, coin: Coin, height: uint32, peer: WSChiaConnection, coin_data: Optional[CATCoinData]
) -> None:
"""Notification from wallet state manager that wallet has been received."""
self.log.info(f"CR-CAT wallet has been notified that {coin.name().hex()} was added")
try:
@ -884,4 +886,4 @@ class CRCATWallet(CATWallet):
if TYPE_CHECKING:
_dummy: WalletProtocol = CRCATWallet()
_dummy: WalletProtocol[CATCoinData] = CRCATWallet()

View File

@ -65,6 +65,7 @@ VIRAL_BACKDOOR: Program = load_clvm_maybe_recompile(
# (a (i 7 (q . 7) (q 4 2 (q () ()))) 1)
ACS_TRANSFER_PROGRAM: Program = Program.to([2, [3, 7, (1, 7), [1, 4, 2, [1, None, None]]], 1])
# Hashes
EXTIGENT_METADATA_LAYER_HASH = EXTIGENT_METADATA_LAYER.get_tree_hash()
P2_ANNOUNCED_DELEGATED_PUZZLE_HASH: bytes32 = P2_ANNOUNCED_DELEGATED_PUZZLE.get_tree_hash()

View File

@ -20,6 +20,7 @@ from chia.types.coin_spend import CoinSpend
from chia.types.spend_bundle import SpendBundle
from chia.util.hash import std_hash
from chia.util.ints import uint32, uint64, uint128
from chia.util.streamable import Streamable
from chia.wallet.conditions import Condition, UnknownCondition
from chia.wallet.did_wallet.did_wallet import DIDWallet
from chia.wallet.payment import Payment
@ -94,11 +95,14 @@ class VCWallet:
def id(self) -> uint32:
return self.wallet_info.id
async def coin_added(self, coin: Coin, height: uint32, peer: WSChiaConnection) -> None:
async def coin_added(
self, coin: Coin, height: uint32, peer: WSChiaConnection, coin_data: Optional[Streamable]
) -> None:
"""
An unspent coin has arrived to our wallet. Get the parent spend to construct the current VerifiedCredential
representation of the coin and add it to the DB if it's the newest version of the singleton.
"""
# TODO Use coin_data instead of calling peer API
wallet_node = self.wallet_state_manager.wallet_node
coin_states: Optional[List[CoinState]] = await wallet_node.get_coin_state([coin.parent_coin_info], peer=peer)
if coin_states is None:
@ -664,4 +668,4 @@ class VCWallet:
if TYPE_CHECKING:
_dummy: WalletProtocol = VCWallet() # pragma: no cover
_dummy: WalletProtocol[VerifiedCredential] = VCWallet() # pragma: no cover

View File

@ -16,10 +16,12 @@ from chia.types.coin_spend import CoinSpend
from chia.types.spend_bundle import SpendBundle
from chia.util.hash import std_hash
from chia.util.ints import uint32, uint64, uint128
from chia.util.streamable import Streamable
from chia.wallet.coin_selection import select_coins
from chia.wallet.conditions import Condition
from chia.wallet.derivation_record import DerivationRecord
from chia.wallet.payment import Payment
from chia.wallet.puzzles.clawback.metadata import ClawbackMetadata
from chia.wallet.puzzles.p2_delegated_puzzle_or_hidden_puzzle import (
DEFAULT_HIDDEN_PUZZLE_HASH,
calculate_synthetic_secret_key,
@ -55,7 +57,7 @@ CHIP_0002_SIGN_MESSAGE_PREFIX = "Chia Signed Message"
class Wallet:
if TYPE_CHECKING:
_protocol_check: ClassVar[WalletProtocol] = cast("Wallet", None)
_protocol_check: ClassVar[WalletProtocol[ClawbackMetadata]] = cast("Wallet", None)
wallet_info: WalletInfo
wallet_state_manager: WalletStateManager
@ -533,7 +535,7 @@ class Wallet:
# WSChiaConnection is only imported for type checking
async def coin_added(
self, coin: Coin, height: uint32, peer: WSChiaConnection
self, coin: Coin, height: uint32, peer: WSChiaConnection, coin_data: Optional[Streamable]
) -> None: # pylint: disable=used-before-assignment
pass

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, TypeVar
from blspy import G1Element
from typing_extensions import NotRequired, Protocol, TypedDict
@ -20,8 +20,10 @@ from chia.wallet.wallet_info import WalletInfo
if TYPE_CHECKING:
from chia.wallet.wallet_state_manager import WalletStateManager
T = TypeVar("T", contravariant=True)
class WalletProtocol(Protocol):
class WalletProtocol(Protocol[T]):
@classmethod
def type(cls) -> WalletType:
...
@ -29,7 +31,7 @@ class WalletProtocol(Protocol):
def id(self) -> uint32:
...
async def coin_added(self, coin: Coin, height: uint32, peer: WSChiaConnection) -> None:
async def coin_added(self, coin: Coin, height: uint32, peer: WSChiaConnection, coin_data: Optional[T]) -> None:
...
async def select_coins(

View File

@ -8,21 +8,7 @@ import traceback
from contextlib import asynccontextmanager
from pathlib import Path
from secrets import token_bytes
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
import aiosqlite
from blspy import G1Element, G2Element, PrivateKey
@ -56,12 +42,13 @@ from chia.util.db_synchronous import db_synchronous_on
from chia.util.db_wrapper import DBWrapper2
from chia.util.errors import Err
from chia.util.hash import std_hash
from chia.util.ints import uint32, uint64, uint128
from chia.util.ints import uint16, uint32, uint64, uint128
from chia.util.lru_cache import LRUCache
from chia.util.misc import UInt32Range, UInt64Range, VersionedBlob
from chia.util.path import path_from_root
from chia.util.streamable import Streamable
from chia.wallet.cat_wallet.cat_constants import DEFAULT_CATS
from chia.wallet.cat_wallet.cat_info import CATInfo, CRCATInfo
from chia.wallet.cat_wallet.cat_info import CATCoinData, CATInfo, CRCATInfo
from chia.wallet.cat_wallet.cat_utils import CAT_MOD, CAT_MOD_HASH, construct_cat_puzzle, match_cat_puzzle
from chia.wallet.cat_wallet.cat_wallet import CATWallet
from chia.wallet.conditions import Condition
@ -75,6 +62,7 @@ from chia.wallet.derive_keys import (
master_sk_to_wallet_sk_unhardened,
master_sk_to_wallet_sk_unhardened_intermediate,
)
from chia.wallet.did_wallet.did_info import DIDCoinData
from chia.wallet.did_wallet.did_wallet import DIDWallet
from chia.wallet.did_wallet.did_wallet_puzzles import DID_INNERPUZ_MOD, match_did_puzzle
from chia.wallet.key_val_store import KeyValStore
@ -130,7 +118,7 @@ from chia.wallet.wallet_retry_store import WalletRetryStore
from chia.wallet.wallet_transaction_store import WalletTransactionStore
from chia.wallet.wallet_user_store import WalletUserStore
TWalletType = TypeVar("TWalletType", bound=WalletProtocol)
TWalletType = TypeVar("TWalletType", bound=WalletProtocol[Any])
if TYPE_CHECKING:
from chia.wallet.wallet_node import WalletNode
@ -163,7 +151,7 @@ class WalletStateManager:
db_wrapper: DBWrapper2
main_wallet: Wallet
wallets: Dict[uint32, WalletProtocol]
wallets: Dict[uint32, WalletProtocol[Any]]
private_key: PrivateKey
trade_manager: TradeManager
@ -255,7 +243,7 @@ class WalletStateManager:
AssetType.CAT: CATWallet,
}
wallet: Optional[WalletProtocol] = None
wallet: Optional[WalletProtocol[Any]] = None
for wallet_info in await self.get_all_wallet_info_entries():
wallet_type = WalletType(wallet_info.type)
if wallet_type == WalletType.STANDARD_WALLET:
@ -706,40 +694,44 @@ class WalletStateManager:
async def determine_coin_type(
self, peer: WSChiaConnection, coin_state: CoinState, fork_height: Optional[uint32]
) -> Optional[WalletIdentifier]:
) -> Tuple[Optional[WalletIdentifier], Optional[Streamable]]:
if coin_state.created_height is not None and (
self.is_pool_reward(uint32(coin_state.created_height), coin_state.coin)
or self.is_farmer_reward(uint32(coin_state.created_height), coin_state.coin)
):
return None
return None, None
response: List[CoinState] = await self.wallet_node.get_coin_state(
[coin_state.coin.parent_coin_info], peer=peer, fork_height=fork_height
)
if len(response) == 0:
self.log.warning(f"Could not find a parent coin with ID: {coin_state.coin.parent_coin_info}")
return None
return None, None
parent_coin_state = response[0]
assert parent_coin_state.spent_height == coin_state.created_height
coin_spend = await fetch_coin_spend_for_coin_state(parent_coin_state, peer)
if coin_spend is None:
return None
puzzle = Program.from_bytes(bytes(coin_spend.puzzle_reveal))
uncurried = uncurry_puzzle(puzzle)
# Check if the coin is a CAT
cat_curried_args = match_cat_puzzle(uncurried)
if cat_curried_args is not None:
return await self.handle_cat(
cat_curried_args,
parent_coin_state,
coin_state,
coin_spend,
peer,
fork_height,
cat_mod_hash, tail_program_hash, cat_inner_puzzle = cat_curried_args
cat_data: CATCoinData = CATCoinData(
bytes32(cat_mod_hash.atom), bytes32(tail_program_hash.atom), cat_inner_puzzle
)
return (
await self.handle_cat(
cat_data,
parent_coin_state,
coin_state,
coin_spend,
peer,
fork_height,
),
cat_data,
)
# Check if the coin is a NFT
@ -747,27 +739,36 @@ class WalletStateManager:
# First spend where 1 mojo coin -> Singleton launcher -> NFT -> NFT
uncurried_nft = UncurriedNFT.uncurry(uncurried.mod, uncurried.args)
if uncurried_nft is not None and coin_state.coin.amount % 2 == 1:
return await self.handle_nft(coin_spend, uncurried_nft, parent_coin_state, coin_state)
return await self.handle_nft(coin_spend, uncurried_nft, parent_coin_state, coin_state), uncurried_nft
# Check if the coin is a DID
did_curried_args = match_did_puzzle(uncurried.mod, uncurried.args)
if did_curried_args is not None and coin_state.coin.amount % 2 == 1:
return await self.handle_did(did_curried_args, parent_coin_state, coin_state, coin_spend, peer)
p2_puzzle, recovery_list_hash, num_verification, singleton_struct, metadata = did_curried_args
did_data: DIDCoinData = DIDCoinData(
p2_puzzle,
bytes32(recovery_list_hash.atom),
uint16(num_verification.as_int()),
singleton_struct,
metadata,
)
return await self.handle_did(did_data, parent_coin_state, coin_state, coin_spend, peer), did_data
# Check if the coin is clawback
solution = coin_spend.solution.to_program()
clawback_metadata = match_clawback_puzzle(uncurried, puzzle, solution)
if clawback_metadata is not None:
return await self.handle_clawback(clawback_metadata, coin_state, coin_spend, peer)
clawback_coin_data = match_clawback_puzzle(uncurried, puzzle, solution)
if clawback_coin_data is not None:
return await self.handle_clawback(clawback_coin_data, coin_state, coin_spend, peer), clawback_coin_data
# Check if the coin is a VC
is_vc, err_msg = VerifiedCredential.is_vc(uncurried)
if is_vc:
return await self.handle_vc(coin_spend)
vc: VerifiedCredential = VerifiedCredential.get_next_from_coin_spend(coin_spend)
return await self.handle_vc(vc), vc
await self.notification_manager.potentially_add_new_notification(coin_state, coin_spend)
return None
return None, None
async def auto_claim_coins(self) -> None:
# Get unspent clawback coin
@ -933,7 +934,7 @@ class WalletStateManager:
async def handle_cat(
self,
curried_args: Iterator[Program],
parent_data: CATCoinData,
parent_coin_state: CoinState,
coin_state: CoinState,
coin_spend: CoinSpend,
@ -942,14 +943,13 @@ class WalletStateManager:
) -> Optional[WalletIdentifier]:
"""
Handle the new coin when it is a CAT
:param curried_args: Curried arg of the CAT mod
:param parent_data: Parent CAT coin uncurried metadata
:param parent_coin_state: Parent coin state
:param coin_state: Current coin state
:param coin_spend: New coin spend
:param fork_height: Current block height
:return: Wallet ID & Wallet Type
"""
mod_hash, tail_hash, inner_puzzle = curried_args
hinted_coin = compute_spend_hints_and_additions(coin_spend)[coin_state.coin.name()]
assert hinted_coin.hint is not None, f"hint missing for coin {hinted_coin.coin}"
derivation_record = await self.puzzle_store.get_derivation_record_for_puzzle_hash(hinted_coin.hint)
@ -959,7 +959,7 @@ class WalletStateManager:
return None
else:
our_inner_puzzle: Program = self.main_wallet.puzzle_for_pk(derivation_record.pubkey)
asset_id: bytes32 = bytes32(bytes(tail_hash)[1:])
asset_id: bytes32 = parent_data.tail_program_hash
cat_puzzle = construct_cat_puzzle(CAT_MOD, asset_id, our_inner_puzzle, CAT_MOD_HASH)
is_crcat: bool = False
if cat_puzzle.get_tree_hash() != coin_state.coin.puzzle_hash:
@ -1005,7 +1005,7 @@ class WalletStateManager:
)
self.state_changed("converted cat wallet to cr", wallet_info.id)
return WalletIdentifier(wallet_info.id, WalletType(WalletType.CRCAT))
if bytes(tail_hash).hex()[2:] in self.default_cats or self.config.get(
if parent_data.tail_program_hash.hex() in self.default_cats or self.config.get(
"automatically_add_unknown_cats", False
):
if is_crcat:
@ -1018,7 +1018,7 @@ class WalletStateManager:
)
else:
cat_wallet = await CATWallet.get_or_create_wallet_for_cat(
self, self.main_wallet, bytes(tail_hash).hex()[2:]
self, self.main_wallet, parent_data.tail_program_hash.hex()
)
return WalletIdentifier.create(cat_wallet)
else:
@ -1039,7 +1039,7 @@ class WalletStateManager:
async def handle_did(
self,
curried_args: Iterator[Program],
parent_data: DIDCoinData,
parent_coin_state: CoinState,
coin_state: CoinState,
coin_spend: CoinSpend,
@ -1047,21 +1047,21 @@ class WalletStateManager:
) -> Optional[WalletIdentifier]:
"""
Handle the new coin when it is a DID
:param curried_args: Curried arg of the DID mod
:param parent_data: Curried data of the DID coin
:param parent_coin_state: Parent coin state
:param coin_state: Current coin state
:param coin_spend: New coin spend
:return: Wallet ID & Wallet Type
"""
p2_puzzle, recovery_list_hash, num_verification, singleton_struct, metadata = curried_args
inner_puzzle_hash = p2_puzzle.get_tree_hash()
inner_puzzle_hash = parent_data.p2_puzzle.get_tree_hash()
self.log.info(f"parent: {parent_coin_state.coin.name()} inner_puzzle_hash for parent is {inner_puzzle_hash}")
hinted_coin = compute_spend_hints_and_additions(coin_spend)[coin_state.coin.name()]
assert hinted_coin.hint is not None, f"hint missing for coin {hinted_coin.coin}"
derivation_record = await self.puzzle_store.get_derivation_record_for_puzzle_hash(hinted_coin.hint)
launch_id: bytes32 = bytes32(bytes(singleton_struct.rest().first())[1:])
launch_id: bytes32 = bytes32(parent_data.singleton_struct.rest().first().atom)
if derivation_record is None:
self.log.info(f"Received state for the coin that doesn't belong to us {coin_state}")
# Check if it was owned by us
@ -1086,11 +1086,19 @@ class WalletStateManager:
self.log.info(f"Found DID, launch_id {launch_id}.")
did_puzzle = DID_INNERPUZ_MOD.curry(
our_inner_puzzle, recovery_list_hash, num_verification, singleton_struct, metadata
our_inner_puzzle,
parent_data.recovery_list_hash,
parent_data.num_verification,
parent_data.singleton_struct,
parent_data.metadata,
)
full_puzzle = create_singleton_puzzle(did_puzzle, launch_id)
did_puzzle_empty_recovery = DID_INNERPUZ_MOD.curry(
our_inner_puzzle, Program.to([]).get_tree_hash(), uint64(0), singleton_struct, metadata
our_inner_puzzle,
Program.to([]).get_tree_hash(),
uint64(0),
parent_data.singleton_struct,
parent_data.metadata,
)
full_puzzle_empty_recovery = create_singleton_puzzle(did_puzzle_empty_recovery, launch_id)
if full_puzzle.get_tree_hash() != coin_state.coin.puzzle_hash:
@ -1366,9 +1374,8 @@ class WalletStateManager:
await self.tx_store.add_transaction_record(tx_record)
return None
async def handle_vc(self, parent_coin_spend: CoinSpend) -> Optional[WalletIdentifier]:
async def handle_vc(self, vc: VerifiedCredential) -> Optional[WalletIdentifier]:
# Check the ownership
vc: VerifiedCredential = VerifiedCredential.get_next_from_coin_spend(parent_coin_spend)
derivation_record: Optional[DerivationRecord] = await self.puzzle_store.get_derivation_record_for_puzzle_hash(
vc.inner_puzzle_hash
)
@ -1421,7 +1428,7 @@ class WalletStateManager:
await self.retry_store.remove_state(coin_state)
wallet_identifier = await self.get_wallet_identifier_for_puzzle_hash(coin_state.coin.puzzle_hash)
coin_data: Optional[Streamable] = None
# If we already have this coin, & it was spent & confirmed at the same heights, then return (done)
if local_record is not None:
local_spent = None
@ -1440,7 +1447,7 @@ class WalletStateManager:
elif local_record is not None:
wallet_identifier = WalletIdentifier(uint32(local_record.wallet_id), local_record.wallet_type)
elif coin_state.created_height is not None:
wallet_identifier = await self.determine_coin_type(peer, coin_state, fork_height)
wallet_identifier, coin_data = await self.determine_coin_type(peer, coin_state, fork_height)
try:
dl_wallet = self.get_dl_wallet()
except ValueError:
@ -1481,6 +1488,7 @@ class WalletStateManager:
wallet_identifier.type,
peer,
coin_name,
coin_data,
)
# if the coin has been spent
@ -1682,6 +1690,7 @@ class WalletStateManager:
record.wallet_type,
peer,
coin_name,
coin_data,
)
await self.coin_store.set_spent(
curr_coin_state.coin.name(), uint32(curr_coin_state.spent_height)
@ -1773,6 +1782,7 @@ class WalletStateManager:
pool_wallet.type(),
peer,
coin_name,
coin_data,
)
await self.add_interested_coin_ids([coin_name])
@ -1888,6 +1898,7 @@ class WalletStateManager:
wallet_type: WalletType,
peer: WSChiaConnection,
coin_name: bytes32,
coin_data: Optional[Streamable],
) -> None:
"""
Adding coin to DB
@ -1951,7 +1962,7 @@ class WalletStateManager:
)
await self.coin_store.add_coin_record(coin_record, coin_name)
await self.wallets[wallet_id].coin_added(coin, height, peer)
await self.wallets[wallet_id].coin_added(coin, height, peer, coin_data)
await self.create_more_puzzle_hashes()
@ -2055,7 +2066,7 @@ class WalletStateManager:
result = await self.coin_store.get_coin_records(**kwargs)
return [await self.get_coin_record_by_wallet_record(record) for record in result.records]
async def get_wallet_for_coin(self, coin_id: bytes32) -> Optional[WalletProtocol]:
async def get_wallet_for_coin(self, coin_id: bytes32) -> Optional[WalletProtocol[Any]]:
coin_record = await self.coin_store.get_coin_record(coin_id)
if coin_record is None:
return None
@ -2108,7 +2119,7 @@ class WalletStateManager:
async def get_all_wallet_info_entries(self, wallet_type: Optional[WalletType] = None) -> List[WalletInfo]:
return await self.user_store.get_all_wallet_info_entries(wallet_type)
async def get_wallet_for_asset_id(self, asset_id: str) -> Optional[WalletProtocol]:
async def get_wallet_for_asset_id(self, asset_id: str) -> Optional[WalletProtocol[Any]]:
for wallet_id, wallet in self.wallets.items():
if wallet.type() in (WalletType.CAT, WalletType.CRCAT):
assert isinstance(wallet, CATWallet)
@ -2125,7 +2136,7 @@ class WalletStateManager:
return wallet
return None
async def get_wallet_for_puzzle_info(self, puzzle_driver: PuzzleInfo) -> Optional[WalletProtocol]:
async def get_wallet_for_puzzle_info(self, puzzle_driver: PuzzleInfo) -> Optional[WalletProtocol[Any]]:
for wallet in self.wallets.values():
match_function = getattr(wallet, "match_puzzle_info", None)
if match_function is not None and callable(match_function):
@ -2145,7 +2156,7 @@ class WalletStateManager:
},
)
async def add_new_wallet(self, wallet: WalletProtocol) -> None:
async def add_new_wallet(self, wallet: WalletProtocol[Any]) -> None:
self.wallets[wallet.id()] = wallet
await self.create_more_puzzle_hashes()
self.state_changed("wallet_created")

View File

@ -5,9 +5,13 @@ from typing import AsyncIterator
import pytest
from chia.protocols.wallet_protocol import CoinState
from chia.server.outbound_message import NodeType
from chia.simulator.setup_nodes import SimulatorsAndWallets
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint32
from chia.types.peer_info import PeerInfo
from chia.util.ints import uint16, uint32
from chia.wallet.derivation_record import DerivationRecord
from chia.wallet.derive_keys import master_sk_to_wallet_sk, master_sk_to_wallet_sk_unhardened
from chia.wallet.util.wallet_types import WalletType
@ -77,3 +81,17 @@ async def test_get_private_key_failure(simulator_and_wallet: SimulatorsAndWallet
invalid_puzzle_hash = bytes32(b"1" * 32)
with pytest.raises(ValueError, match=f"No key for puzzle hash: {invalid_puzzle_hash.hex()}"):
await wallet_state_manager.get_private_key(bytes32(b"1" * 32))
@pytest.mark.asyncio
async def test_determine_coin_type(simulator_and_wallet: SimulatorsAndWallets, self_hostname: str) -> None:
full_nodes, wallets, _ = simulator_and_wallet
full_node_api = full_nodes[0]
full_node_server = full_node_api.full_node.server
wallet_node, wallet_server = wallets[0]
await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
wallet_state_manager: WalletStateManager = wallet_node.wallet_state_manager
peer = wallet_node.server.get_connections(NodeType.FULL_NODE)[0]
assert (None, None) == await wallet_state_manager.determine_coin_type(
peer, CoinState(Coin(bytes32(b"1" * 32), bytes32(b"1" * 32), 0), uint32(0), uint32(0)), None
)