From aaba00c3549883450728db476424df1412cf86ca Mon Sep 17 00:00:00 2001 From: Adam Kelly <338792+aqk@users.noreply.github.com> Date: Tue, 29 Mar 2022 10:30:46 -0700 Subject: [PATCH] Add more type checks to CAT Wallet (#10934) --- chia/wallet/cat_wallet/cat_wallet.py | 33 ++++++++++++++++++---------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/chia/wallet/cat_wallet/cat_wallet.py b/chia/wallet/cat_wallet/cat_wallet.py index a9d7b6c3c8b9..4b9f2c19a17e 100644 --- a/chia/wallet/cat_wallet/cat_wallet.py +++ b/chia/wallet/cat_wallet/cat_wallet.py @@ -91,10 +91,12 @@ class CATWallet: if name is None: name = "CAT WALLET" - self.wallet_info = await wallet_state_manager.user_store.create_wallet(name, WalletType.CAT, info_as_string) - if self.wallet_info is None: + new_wallet_info = await wallet_state_manager.user_store.create_wallet(name, WalletType.CAT, info_as_string) + if new_wallet_info is None: raise ValueError("Internal Error") + self.wallet_info = new_wallet_info + self.lineage_store = await CATLineageStore.create(self.wallet_state_manager.db_wrapper, self.get_asset_id()) try: @@ -192,11 +194,12 @@ class CATWallet: limitations_program_hash = bytes32(hexstr_to_bytes(limitations_program_hash_hex)) self.cat_info = CATInfo(limitations_program_hash, None) info_as_string = bytes(self.cat_info).hex() - self.wallet_info = await wallet_state_manager.user_store.create_wallet( + new_wallet_info = await wallet_state_manager.user_store.create_wallet( name, WalletType.CAT, info_as_string, in_transaction=in_transaction ) - if self.wallet_info is None: + if new_wallet_info is None: raise Exception("wallet_info is None") + self.wallet_info = new_wallet_info self.lineage_store = await CATLineageStore.create( self.wallet_state_manager.db_wrapper, self.get_asset_id(), in_transaction=in_transaction @@ -477,7 +480,11 @@ class CATWallet: if matched: _, _, inner_puzzle = puzzle_args puzzle_hash = inner_puzzle.get_tree_hash() - pubkey, private = await self.wallet_state_manager.get_keys(puzzle_hash) + ret = await self.wallet_state_manager.get_keys(puzzle_hash) + if ret is None: + # Abort signing the entire SpendBundle - sign all or none + raise RuntimeError(f"Failed to get keys for puzzle_hash {puzzle_hash}") + pubkey, private = ret synthetic_secret_key = calculate_synthetic_secret_key(private, DEFAULT_HIDDEN_PUZZLE_HASH) error, conditions, cost = conditions_dict_for_solution( spend.puzzle_reveal.to_program(), @@ -499,18 +506,20 @@ class CATWallet: return SpendBundle.aggregate([spend_bundle, SpendBundle([], agg_sig)]) async def inner_puzzle_for_cat_puzhash(self, cat_hash: bytes32) -> Program: - record: DerivationRecord = await self.wallet_state_manager.puzzle_store.get_derivation_record_for_puzzle_hash( - cat_hash - ) + record: Optional[ + DerivationRecord + ] = await self.wallet_state_manager.puzzle_store.get_derivation_record_for_puzzle_hash(cat_hash) + if record is None: + raise RuntimeError(f"Missing Derivation Record for CAT puzzle_hash {cat_hash}") inner_puzzle: Program = self.standard_wallet.puzzle_for_pk(bytes(record.pubkey)) return inner_puzzle async def convert_puzzle_hash(self, puzzle_hash: bytes32) -> bytes32: - record: DerivationRecord = await self.wallet_state_manager.puzzle_store.get_derivation_record_for_puzzle_hash( - puzzle_hash - ) + record: Optional[ + DerivationRecord + ] = await self.wallet_state_manager.puzzle_store.get_derivation_record_for_puzzle_hash(puzzle_hash) if record is None: - return puzzle_hash + return puzzle_hash # TODO: check if we have a test for this case! else: return (await self.inner_puzzle_for_cat_puzhash(puzzle_hash)).get_tree_hash()