diff --git a/chia/rpc/wallet_rpc_api.py b/chia/rpc/wallet_rpc_api.py index ee46a546987e..3981c5f76f7a 100644 --- a/chia/rpc/wallet_rpc_api.py +++ b/chia/rpc/wallet_rpc_api.py @@ -90,7 +90,6 @@ class WalletRpcApi: assert wallet_node is not None self.service = wallet_node self.service_name = "chia_wallet" - self.balance_cache: Dict[int, Any] = {} def get_routes(self) -> Dict[str, Endpoint]: return { @@ -318,7 +317,6 @@ class WalletRpcApi: return {"fingerprint": fingerprint} await self._stop_wallet() - self.balance_cache = {} started = await self.service._start_with_fingerprint(fingerprint) if started is True: return {"fingerprint": fingerprint} @@ -792,61 +790,15 @@ class WalletRpcApi: async def get_wallet_balance(self, request: Dict) -> EndpointResult: wallet_id = uint32(int(request["wallet_id"])) wallet = self.service.wallet_state_manager.wallets[wallet_id] - - # If syncing return the last available info or 0s - syncing = self.service.wallet_state_manager.sync_mode - if syncing: - if wallet_id in self.balance_cache: - wallet_balance = self.balance_cache[wallet_id] - else: - wallet_balance = { - "wallet_id": wallet_id, - "confirmed_wallet_balance": 0, - "unconfirmed_wallet_balance": 0, - "spendable_balance": 0, - "pending_change": 0, - "max_send_amount": 0, - "unspent_coin_count": 0, - "pending_coin_removal_count": 0, - "wallet_type": wallet.type(), - } - if self.service.logged_in_fingerprint is not None: - wallet_balance["fingerprint"] = self.service.logged_in_fingerprint - if wallet.type() == WalletType.CAT: - assert isinstance(wallet, CATWallet) - wallet_balance["asset_id"] = wallet.get_asset_id() - else: - async with self.service.wallet_state_manager.lock: - unspent_records = await self.service.wallet_state_manager.coin_store.get_unspent_coins_for_wallet( - wallet_id - ) - balance = await wallet.get_confirmed_balance(unspent_records) - pending_balance = await wallet.get_unconfirmed_balance(unspent_records) - spendable_balance = await wallet.get_spendable_balance(unspent_records) - pending_change = await wallet.get_pending_change_balance() - max_send_amount = await wallet.get_max_send_amount(unspent_records) - - unconfirmed_removals: Dict[ - bytes32, Coin - ] = await wallet.wallet_state_manager.unconfirmed_removals_for_wallet(wallet_id) - wallet_balance = { - "wallet_id": wallet_id, - "confirmed_wallet_balance": balance, - "unconfirmed_wallet_balance": pending_balance, - "spendable_balance": spendable_balance, - "pending_change": pending_change, - "max_send_amount": max_send_amount, - "unspent_coin_count": len(unspent_records), - "pending_coin_removal_count": len(unconfirmed_removals), - "wallet_type": wallet.type(), - } - if self.service.logged_in_fingerprint is not None: - wallet_balance["fingerprint"] = self.service.logged_in_fingerprint - if wallet.type() == WalletType.CAT: - assert isinstance(wallet, CATWallet) - wallet_balance["asset_id"] = wallet.get_asset_id() - self.balance_cache[wallet_id] = wallet_balance - + balance = await self.service.get_balance(wallet_id) + wallet_balance = balance.to_json_dict() + wallet_balance["wallet_id"] = wallet_id + wallet_balance["wallet_type"] = wallet.type() + if self.service.logged_in_fingerprint is not None: + wallet_balance["fingerprint"] = self.service.logged_in_fingerprint + if wallet.type() == WalletType.CAT: + assert isinstance(wallet, CATWallet) + wallet_balance["asset_id"] = wallet.get_asset_id() return {"wallet_balance": wallet_balance} async def get_transaction(self, request: Dict) -> EndpointResult: diff --git a/chia/wallet/wallet_node.py b/chia/wallet/wallet_node.py index 60ceac07a5eb..0bcdc296bc39 100644 --- a/chia/wallet/wallet_node.py +++ b/chia/wallet/wallet_node.py @@ -46,10 +46,11 @@ from chia.util.config import ( ) from chia.util.db_wrapper import manage_connection from chia.util.errors import KeychainIsEmpty, KeychainIsLocked, KeychainKeyNotFound, KeychainProxyConnectionFailure -from chia.util.ints import uint32, uint64 +from chia.util.ints import uint32, uint64, uint128 from chia.util.keychain import Keychain from chia.util.path import path_from_root from chia.util.profiler import mem_profile_task, profile_task +from chia.util.streamable import Streamable, streamable from chia.wallet.transaction_record import TransactionRecord from chia.wallet.util.new_peak_queue import NewPeakItem, NewPeakQueue, NewPeakQueueTypes from chia.wallet.util.peer_request_cache import PeerRequestCache, can_use_peer_request_cache @@ -85,6 +86,18 @@ def get_wallet_db_path(root_path: Path, config: Dict[str, Any], key_fingerprint: return path +@streamable +@dataclasses.dataclass(frozen=True) +class Balance(Streamable): + confirmed_wallet_balance: uint128 = uint128(0) + unconfirmed_wallet_balance: uint128 = uint128(0) + spendable_balance: uint128 = uint128(0) + pending_change: uint64 = uint64(0) + max_send_amount: uint128 = uint128(0) + unspent_coin_count: uint32 = uint32(0) + pending_coin_removal_count: uint32 = uint32(0) + + @dataclasses.dataclass class WalletNode: config: Dict @@ -103,6 +116,7 @@ class WalletNode: logged_in_fingerprint: Optional[int] = None logged_in: bool = False _keychain_proxy: Optional[KeychainProxy] = None + _balance_cache: Dict[int, Balance] = dataclasses.field(default_factory=dict) # Peers that we have long synced to synced_peers: Set[bytes32] = dataclasses.field(default_factory=set) wallet_peers: Optional[WalletPeers] = None @@ -373,6 +387,10 @@ class WalletNode: self.log_in(private_key) self.wallet_state_manager.state_changed("sync_changed") + # Populate the balance caches for all wallets + for wallet_id in self.wallet_state_manager.wallets: + await self.get_balance(wallet_id) + async with self.wallet_state_manager.puzzle_store.lock: index = await self.wallet_state_manager.puzzle_store.get_last_derivation_path() if index is None or index < self.wallet_state_manager.initial_num_public_keys - 1: @@ -407,6 +425,7 @@ class WalletNode: await proxy.close() await asyncio.sleep(0.5) # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown self.wallet_peers = None + self._balance_cache = {} def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None: self.state_changed_callback = callback @@ -1643,3 +1662,30 @@ class WalletNode: full_nodes = self.server.get_connections(NodeType.FULL_NODE) for peer in full_nodes: await peer.send_message(msg) + + async def get_balance(self, wallet_id: uint32) -> Balance: + self.log.debug(f"get_balance - wallet_id: {wallet_id}") + if not self.wallet_state_manager.sync_mode: + self.log.debug(f"get_balance - Updating cache for {wallet_id}") + async with self.wallet_state_manager.lock: + wallet = self.wallet_state_manager.wallets[wallet_id] + unspent_records = await self.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(wallet_id) + balance = await wallet.get_confirmed_balance(unspent_records) + pending_balance = await wallet.get_unconfirmed_balance(unspent_records) + spendable_balance = await wallet.get_spendable_balance(unspent_records) + pending_change = await wallet.get_pending_change_balance() + max_send_amount = await wallet.get_max_send_amount(unspent_records) + + unconfirmed_removals: Dict[ + bytes32, Coin + ] = await wallet.wallet_state_manager.unconfirmed_removals_for_wallet(wallet_id) + self._balance_cache[wallet_id] = Balance( + confirmed_wallet_balance=balance, + unconfirmed_wallet_balance=pending_balance, + spendable_balance=spendable_balance, + pending_change=pending_change, + max_send_amount=max_send_amount, + unspent_coin_count=uint32(len(unspent_records)), + pending_coin_removal_count=uint32(len(unconfirmed_removals)), + ) + return self._balance_cache.get(wallet_id, Balance()) diff --git a/tests/wallet/rpc/test_wallet_rpc.py b/tests/wallet/rpc/test_wallet_rpc.py index 738b8e343abd..255753d958ce 100644 --- a/tests/wallet/rpc/test_wallet_rpc.py +++ b/tests/wallet/rpc/test_wallet_rpc.py @@ -224,6 +224,18 @@ async def assert_push_tx_error(node_rpc: FullNodeRpcClient, tx: TransactionRecor raise ValueError from error +async def assert_get_balance(rpc_client: WalletRpcClient, wallet_node: WalletNode, wallet: WalletProtocol) -> None: + expected_balance = await wallet_node.get_balance(wallet.id()) + expected_balance_dict = expected_balance.to_json_dict() + expected_balance_dict["wallet_id"] = wallet.id() + expected_balance_dict["wallet_type"] = wallet.type() + expected_balance_dict["fingerprint"] = wallet_node.logged_in_fingerprint + if wallet.type() == WalletType.CAT: + assert isinstance(wallet, CATWallet) + expected_balance_dict["asset_id"] = wallet.get_asset_id() + assert await rpc_client.get_wallet_balance(wallet.id()) == expected_balance_dict + + async def tx_in_mempool(client: WalletRpcClient, transaction_id: bytes32): tx = await client.get_transaction(1, transaction_id) return tx.is_in_mempool() @@ -310,6 +322,22 @@ async def test_push_transactions(wallet_rpc_environment: WalletRpcTestEnvironmen assert tx.confirmed +@pytest.mark.asyncio +async def test_get_balance(wallet_rpc_environment: WalletRpcTestEnvironment): + env = wallet_rpc_environment + wallet: Wallet = env.wallet_1.wallet + wallet_node: WalletNode = env.wallet_1.node + full_node_api: FullNodeSimulator = env.full_node.api + wallet_rpc_client = env.wallet_1.rpc_client + await full_node_api.farm_blocks_to_wallet(2, wallet) + async with wallet_node.wallet_state_manager.lock: + cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet( + wallet_node.wallet_state_manager, wallet, {"identifier": "genesis_by_id"}, uint64(100) + ) + await assert_get_balance(wallet_rpc_client, wallet_node, wallet) + await assert_get_balance(wallet_rpc_client, wallet_node, cat_wallet) + + @pytest.mark.asyncio async def test_get_timestamp_for_height(wallet_rpc_environment: WalletRpcTestEnvironment): env: WalletRpcTestEnvironment = wallet_rpc_environment diff --git a/tests/wallet/test_wallet_node.py b/tests/wallet/test_wallet_node.py index e41e3ceb0258..9254f6640f9a 100644 --- a/tests/wallet/test_wallet_node.py +++ b/tests/wallet/test_wallet_node.py @@ -2,16 +2,20 @@ from __future__ import annotations import sys from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import pytest from blspy import PrivateKey from chia.simulator.block_tools import test_constants from chia.simulator.setup_nodes import SimulatorsAndWallets +from chia.simulator.time_out_assert import time_out_assert +from chia.types.full_block import FullBlock +from chia.types.peer_info import PeerInfo from chia.util.config import load_config -from chia.util.keychain import Keychain, generate_mnemonic -from chia.wallet.wallet_node import WalletNode +from chia.util.ints import uint16, uint32, uint128 +from chia.util.keychain import Keychain, KeyData, generate_mnemonic +from chia.wallet.wallet_node import Balance, WalletNode @pytest.mark.asyncio @@ -302,3 +306,81 @@ async def test_unique_puzzle_hash_subscriptions(simulator_and_wallet: Simulators puzzle_hashes = await node.get_puzzle_hashes_to_subscribe() assert len(puzzle_hashes) > 1 assert len(set(puzzle_hashes)) == len(puzzle_hashes) + + +@pytest.mark.asyncio +async def test_get_balance( + simulator_and_wallet: SimulatorsAndWallets, self_hostname: str, default_400_blocks: List[FullBlock] +) -> None: + [full_node_api], [(wallet_node, wallet_server)], bt = simulator_and_wallet + full_node_server = full_node_api.full_node.server + + def wallet_synced() -> bool: + return full_node_server.node_id in wallet_node.synced_peers + + async def restart_with_fingerprint(fingerprint: Optional[int]) -> None: + wallet_node._close() # type: ignore[no-untyped-call] # WalletNode needs hinting here + await wallet_node._await_closed(shutting_down=False) + await wallet_node._start_with_fingerprint(fingerprint=fingerprint) + + wallet_id = uint32(1) + initial_fingerprint = wallet_node.logged_in_fingerprint + + # TODO, there is a bug in wallet_short_sync_backtrack which leads to a rollback to 0 (-1 which is another a bug) and + # with that to a KeyError when applying the race cache if there are less than WEIGHT_PROOF_RECENT_BLOCKS + # blocks but we still have a peak stored in the DB. So we need to add enough blocks for a weight proof here to + # be able to restart the wallet in this test. + for block in default_400_blocks: + await full_node_api.full_node.add_block(block) + + # Initially there should be no sync and no balance + assert not wallet_synced() + assert await wallet_node.get_balance(wallet_id) == Balance() + # Generate some funds, get the balance and make sure it's as expected + await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None) + await time_out_assert(30, wallet_synced) + generated_funds = await full_node_api.farm_blocks_to_wallet(5, wallet_node.wallet_state_manager.main_wallet) + expected_generated_balance = Balance( + confirmed_wallet_balance=uint128(generated_funds), + unconfirmed_wallet_balance=uint128(generated_funds), + spendable_balance=uint128(generated_funds), + max_send_amount=uint128(generated_funds), + unspent_coin_count=uint32(10), + ) + generated_balance = await wallet_node.get_balance(wallet_id) + assert generated_balance == expected_generated_balance + # Load another key without funds, make sure the balance is empty. + other_key = KeyData.generate() + assert wallet_node.local_keychain is not None + wallet_node.local_keychain.add_private_key(other_key.mnemonic_str()) + await restart_with_fingerprint(other_key.fingerprint) + assert await wallet_node.get_balance(wallet_id) == Balance() + # Load the initial fingerprint again and make sure the balance is still what we generated earlier + await restart_with_fingerprint(initial_fingerprint) + assert await wallet_node.get_balance(wallet_id) == generated_balance + # Connect and sync to the full node, generate more funds and test the balance caching + # TODO, there is a bug in untrusted sync if we try to sync to the same peak as stored in the DB after restart + # which leads to a rollback to 0 (-1 which is another a bug) and then to a validation error because the + # downloaded weight proof will not be added to the blockchain properly because we still have a peak with the + # same weight stored in the DB but without chain data. The 1 block generation below can be dropped if we just + # also store the chain data or maybe adjust the weight proof consideration logic in new_valid_weight_proof. + await full_node_api.farm_blocks_to_puzzlehash(1) + assert not wallet_synced() + await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None) + await time_out_assert(30, wallet_synced) + generated_funds += await full_node_api.farm_blocks_to_wallet(5, wallet_node.wallet_state_manager.main_wallet) + expected_more_balance = Balance( + confirmed_wallet_balance=uint128(generated_funds), + unconfirmed_wallet_balance=uint128(generated_funds), + spendable_balance=uint128(generated_funds), + max_send_amount=uint128(generated_funds), + unspent_coin_count=uint32(20), + ) + async with wallet_node.wallet_state_manager.set_sync_mode(uint32(100)): + # During sync the balance cache should not become updated, so it still should have the old balance here + assert await wallet_node.get_balance(wallet_id) == expected_generated_balance + # Now after the sync context the cache should become updated to the newly genertated balance + assert await wallet_node.get_balance(wallet_id) == expected_more_balance + # Restart one more time and make sure the balance is still correct after start + await restart_with_fingerprint(initial_fingerprint) + assert await wallet_node.get_balance(wallet_id) == expected_more_balance