wallet: Improve balance caching (#14631)

* wallet: Improve balance caching

* Fix missing entries in `get_balance` RPC and add a test for it

* Fix comment
This commit is contained in:
dustinface 2023-04-03 21:52:11 +07:00 committed by GitHub
parent e8e8e0d709
commit 5e3bf65abb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 169 additions and 61 deletions

View File

@ -90,7 +90,6 @@ class WalletRpcApi:
assert wallet_node is not None assert wallet_node is not None
self.service = wallet_node self.service = wallet_node
self.service_name = "chia_wallet" self.service_name = "chia_wallet"
self.balance_cache: Dict[int, Any] = {}
def get_routes(self) -> Dict[str, Endpoint]: def get_routes(self) -> Dict[str, Endpoint]:
return { return {
@ -318,7 +317,6 @@ class WalletRpcApi:
return {"fingerprint": fingerprint} return {"fingerprint": fingerprint}
await self._stop_wallet() await self._stop_wallet()
self.balance_cache = {}
started = await self.service._start_with_fingerprint(fingerprint) started = await self.service._start_with_fingerprint(fingerprint)
if started is True: if started is True:
return {"fingerprint": fingerprint} return {"fingerprint": fingerprint}
@ -792,61 +790,15 @@ class WalletRpcApi:
async def get_wallet_balance(self, request: Dict) -> EndpointResult: async def get_wallet_balance(self, request: Dict) -> EndpointResult:
wallet_id = uint32(int(request["wallet_id"])) wallet_id = uint32(int(request["wallet_id"]))
wallet = self.service.wallet_state_manager.wallets[wallet_id] wallet = self.service.wallet_state_manager.wallets[wallet_id]
balance = await self.service.get_balance(wallet_id)
# If syncing return the last available info or 0s wallet_balance = balance.to_json_dict()
syncing = self.service.wallet_state_manager.sync_mode wallet_balance["wallet_id"] = wallet_id
if syncing: wallet_balance["wallet_type"] = wallet.type()
if wallet_id in self.balance_cache: if self.service.logged_in_fingerprint is not None:
wallet_balance = self.balance_cache[wallet_id] wallet_balance["fingerprint"] = self.service.logged_in_fingerprint
else: if wallet.type() == WalletType.CAT:
wallet_balance = { assert isinstance(wallet, CATWallet)
"wallet_id": wallet_id, wallet_balance["asset_id"] = wallet.get_asset_id()
"confirmed_wallet_balance": 0,
"unconfirmed_wallet_balance": 0,
"spendable_balance": 0,
"pending_change": 0,
"max_send_amount": 0,
"unspent_coin_count": 0,
"pending_coin_removal_count": 0,
"wallet_type": wallet.type(),
}
if self.service.logged_in_fingerprint is not None:
wallet_balance["fingerprint"] = self.service.logged_in_fingerprint
if wallet.type() == WalletType.CAT:
assert isinstance(wallet, CATWallet)
wallet_balance["asset_id"] = wallet.get_asset_id()
else:
async with self.service.wallet_state_manager.lock:
unspent_records = await self.service.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(
wallet_id
)
balance = await wallet.get_confirmed_balance(unspent_records)
pending_balance = await wallet.get_unconfirmed_balance(unspent_records)
spendable_balance = await wallet.get_spendable_balance(unspent_records)
pending_change = await wallet.get_pending_change_balance()
max_send_amount = await wallet.get_max_send_amount(unspent_records)
unconfirmed_removals: Dict[
bytes32, Coin
] = await wallet.wallet_state_manager.unconfirmed_removals_for_wallet(wallet_id)
wallet_balance = {
"wallet_id": wallet_id,
"confirmed_wallet_balance": balance,
"unconfirmed_wallet_balance": pending_balance,
"spendable_balance": spendable_balance,
"pending_change": pending_change,
"max_send_amount": max_send_amount,
"unspent_coin_count": len(unspent_records),
"pending_coin_removal_count": len(unconfirmed_removals),
"wallet_type": wallet.type(),
}
if self.service.logged_in_fingerprint is not None:
wallet_balance["fingerprint"] = self.service.logged_in_fingerprint
if wallet.type() == WalletType.CAT:
assert isinstance(wallet, CATWallet)
wallet_balance["asset_id"] = wallet.get_asset_id()
self.balance_cache[wallet_id] = wallet_balance
return {"wallet_balance": wallet_balance} return {"wallet_balance": wallet_balance}
async def get_transaction(self, request: Dict) -> EndpointResult: async def get_transaction(self, request: Dict) -> EndpointResult:

View File

@ -46,10 +46,11 @@ from chia.util.config import (
) )
from chia.util.db_wrapper import manage_connection from chia.util.db_wrapper import manage_connection
from chia.util.errors import KeychainIsEmpty, KeychainIsLocked, KeychainKeyNotFound, KeychainProxyConnectionFailure from chia.util.errors import KeychainIsEmpty, KeychainIsLocked, KeychainKeyNotFound, KeychainProxyConnectionFailure
from chia.util.ints import uint32, uint64 from chia.util.ints import uint32, uint64, uint128
from chia.util.keychain import Keychain from chia.util.keychain import Keychain
from chia.util.path import path_from_root from chia.util.path import path_from_root
from chia.util.profiler import mem_profile_task, profile_task from chia.util.profiler import mem_profile_task, profile_task
from chia.util.streamable import Streamable, streamable
from chia.wallet.transaction_record import TransactionRecord from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.util.new_peak_queue import NewPeakItem, NewPeakQueue, NewPeakQueueTypes from chia.wallet.util.new_peak_queue import NewPeakItem, NewPeakQueue, NewPeakQueueTypes
from chia.wallet.util.peer_request_cache import PeerRequestCache, can_use_peer_request_cache from chia.wallet.util.peer_request_cache import PeerRequestCache, can_use_peer_request_cache
@ -85,6 +86,18 @@ def get_wallet_db_path(root_path: Path, config: Dict[str, Any], key_fingerprint:
return path return path
@streamable
@dataclasses.dataclass(frozen=True)
class Balance(Streamable):
confirmed_wallet_balance: uint128 = uint128(0)
unconfirmed_wallet_balance: uint128 = uint128(0)
spendable_balance: uint128 = uint128(0)
pending_change: uint64 = uint64(0)
max_send_amount: uint128 = uint128(0)
unspent_coin_count: uint32 = uint32(0)
pending_coin_removal_count: uint32 = uint32(0)
@dataclasses.dataclass @dataclasses.dataclass
class WalletNode: class WalletNode:
config: Dict config: Dict
@ -103,6 +116,7 @@ class WalletNode:
logged_in_fingerprint: Optional[int] = None logged_in_fingerprint: Optional[int] = None
logged_in: bool = False logged_in: bool = False
_keychain_proxy: Optional[KeychainProxy] = None _keychain_proxy: Optional[KeychainProxy] = None
_balance_cache: Dict[int, Balance] = dataclasses.field(default_factory=dict)
# Peers that we have long synced to # Peers that we have long synced to
synced_peers: Set[bytes32] = dataclasses.field(default_factory=set) synced_peers: Set[bytes32] = dataclasses.field(default_factory=set)
wallet_peers: Optional[WalletPeers] = None wallet_peers: Optional[WalletPeers] = None
@ -373,6 +387,10 @@ class WalletNode:
self.log_in(private_key) self.log_in(private_key)
self.wallet_state_manager.state_changed("sync_changed") self.wallet_state_manager.state_changed("sync_changed")
# Populate the balance caches for all wallets
for wallet_id in self.wallet_state_manager.wallets:
await self.get_balance(wallet_id)
async with self.wallet_state_manager.puzzle_store.lock: async with self.wallet_state_manager.puzzle_store.lock:
index = await self.wallet_state_manager.puzzle_store.get_last_derivation_path() index = await self.wallet_state_manager.puzzle_store.get_last_derivation_path()
if index is None or index < self.wallet_state_manager.initial_num_public_keys - 1: if index is None or index < self.wallet_state_manager.initial_num_public_keys - 1:
@ -407,6 +425,7 @@ class WalletNode:
await proxy.close() await proxy.close()
await asyncio.sleep(0.5) # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown await asyncio.sleep(0.5) # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
self.wallet_peers = None self.wallet_peers = None
self._balance_cache = {}
def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None: def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
self.state_changed_callback = callback self.state_changed_callback = callback
@ -1643,3 +1662,30 @@ class WalletNode:
full_nodes = self.server.get_connections(NodeType.FULL_NODE) full_nodes = self.server.get_connections(NodeType.FULL_NODE)
for peer in full_nodes: for peer in full_nodes:
await peer.send_message(msg) await peer.send_message(msg)
async def get_balance(self, wallet_id: uint32) -> Balance:
self.log.debug(f"get_balance - wallet_id: {wallet_id}")
if not self.wallet_state_manager.sync_mode:
self.log.debug(f"get_balance - Updating cache for {wallet_id}")
async with self.wallet_state_manager.lock:
wallet = self.wallet_state_manager.wallets[wallet_id]
unspent_records = await self.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(wallet_id)
balance = await wallet.get_confirmed_balance(unspent_records)
pending_balance = await wallet.get_unconfirmed_balance(unspent_records)
spendable_balance = await wallet.get_spendable_balance(unspent_records)
pending_change = await wallet.get_pending_change_balance()
max_send_amount = await wallet.get_max_send_amount(unspent_records)
unconfirmed_removals: Dict[
bytes32, Coin
] = await wallet.wallet_state_manager.unconfirmed_removals_for_wallet(wallet_id)
self._balance_cache[wallet_id] = Balance(
confirmed_wallet_balance=balance,
unconfirmed_wallet_balance=pending_balance,
spendable_balance=spendable_balance,
pending_change=pending_change,
max_send_amount=max_send_amount,
unspent_coin_count=uint32(len(unspent_records)),
pending_coin_removal_count=uint32(len(unconfirmed_removals)),
)
return self._balance_cache.get(wallet_id, Balance())

View File

@ -224,6 +224,18 @@ async def assert_push_tx_error(node_rpc: FullNodeRpcClient, tx: TransactionRecor
raise ValueError from error raise ValueError from error
async def assert_get_balance(rpc_client: WalletRpcClient, wallet_node: WalletNode, wallet: WalletProtocol) -> None:
expected_balance = await wallet_node.get_balance(wallet.id())
expected_balance_dict = expected_balance.to_json_dict()
expected_balance_dict["wallet_id"] = wallet.id()
expected_balance_dict["wallet_type"] = wallet.type()
expected_balance_dict["fingerprint"] = wallet_node.logged_in_fingerprint
if wallet.type() == WalletType.CAT:
assert isinstance(wallet, CATWallet)
expected_balance_dict["asset_id"] = wallet.get_asset_id()
assert await rpc_client.get_wallet_balance(wallet.id()) == expected_balance_dict
async def tx_in_mempool(client: WalletRpcClient, transaction_id: bytes32): async def tx_in_mempool(client: WalletRpcClient, transaction_id: bytes32):
tx = await client.get_transaction(1, transaction_id) tx = await client.get_transaction(1, transaction_id)
return tx.is_in_mempool() return tx.is_in_mempool()
@ -310,6 +322,22 @@ async def test_push_transactions(wallet_rpc_environment: WalletRpcTestEnvironmen
assert tx.confirmed assert tx.confirmed
@pytest.mark.asyncio
async def test_get_balance(wallet_rpc_environment: WalletRpcTestEnvironment):
env = wallet_rpc_environment
wallet: Wallet = env.wallet_1.wallet
wallet_node: WalletNode = env.wallet_1.node
full_node_api: FullNodeSimulator = env.full_node.api
wallet_rpc_client = env.wallet_1.rpc_client
await full_node_api.farm_blocks_to_wallet(2, wallet)
async with wallet_node.wallet_state_manager.lock:
cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet(
wallet_node.wallet_state_manager, wallet, {"identifier": "genesis_by_id"}, uint64(100)
)
await assert_get_balance(wallet_rpc_client, wallet_node, wallet)
await assert_get_balance(wallet_rpc_client, wallet_node, cat_wallet)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_timestamp_for_height(wallet_rpc_environment: WalletRpcTestEnvironment): async def test_get_timestamp_for_height(wallet_rpc_environment: WalletRpcTestEnvironment):
env: WalletRpcTestEnvironment = wallet_rpc_environment env: WalletRpcTestEnvironment = wallet_rpc_environment

View File

@ -2,16 +2,20 @@ from __future__ import annotations
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, List, Optional
import pytest import pytest
from blspy import PrivateKey from blspy import PrivateKey
from chia.simulator.block_tools import test_constants from chia.simulator.block_tools import test_constants
from chia.simulator.setup_nodes import SimulatorsAndWallets from chia.simulator.setup_nodes import SimulatorsAndWallets
from chia.simulator.time_out_assert import time_out_assert
from chia.types.full_block import FullBlock
from chia.types.peer_info import PeerInfo
from chia.util.config import load_config from chia.util.config import load_config
from chia.util.keychain import Keychain, generate_mnemonic from chia.util.ints import uint16, uint32, uint128
from chia.wallet.wallet_node import WalletNode from chia.util.keychain import Keychain, KeyData, generate_mnemonic
from chia.wallet.wallet_node import Balance, WalletNode
@pytest.mark.asyncio @pytest.mark.asyncio
@ -302,3 +306,81 @@ async def test_unique_puzzle_hash_subscriptions(simulator_and_wallet: Simulators
puzzle_hashes = await node.get_puzzle_hashes_to_subscribe() puzzle_hashes = await node.get_puzzle_hashes_to_subscribe()
assert len(puzzle_hashes) > 1 assert len(puzzle_hashes) > 1
assert len(set(puzzle_hashes)) == len(puzzle_hashes) assert len(set(puzzle_hashes)) == len(puzzle_hashes)
@pytest.mark.asyncio
async def test_get_balance(
simulator_and_wallet: SimulatorsAndWallets, self_hostname: str, default_400_blocks: List[FullBlock]
) -> None:
[full_node_api], [(wallet_node, wallet_server)], bt = simulator_and_wallet
full_node_server = full_node_api.full_node.server
def wallet_synced() -> bool:
return full_node_server.node_id in wallet_node.synced_peers
async def restart_with_fingerprint(fingerprint: Optional[int]) -> None:
wallet_node._close() # type: ignore[no-untyped-call] # WalletNode needs hinting here
await wallet_node._await_closed(shutting_down=False)
await wallet_node._start_with_fingerprint(fingerprint=fingerprint)
wallet_id = uint32(1)
initial_fingerprint = wallet_node.logged_in_fingerprint
# TODO, there is a bug in wallet_short_sync_backtrack which leads to a rollback to 0 (-1 which is another a bug) and
# with that to a KeyError when applying the race cache if there are less than WEIGHT_PROOF_RECENT_BLOCKS
# blocks but we still have a peak stored in the DB. So we need to add enough blocks for a weight proof here to
# be able to restart the wallet in this test.
for block in default_400_blocks:
await full_node_api.full_node.add_block(block)
# Initially there should be no sync and no balance
assert not wallet_synced()
assert await wallet_node.get_balance(wallet_id) == Balance()
# Generate some funds, get the balance and make sure it's as expected
await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
await time_out_assert(30, wallet_synced)
generated_funds = await full_node_api.farm_blocks_to_wallet(5, wallet_node.wallet_state_manager.main_wallet)
expected_generated_balance = Balance(
confirmed_wallet_balance=uint128(generated_funds),
unconfirmed_wallet_balance=uint128(generated_funds),
spendable_balance=uint128(generated_funds),
max_send_amount=uint128(generated_funds),
unspent_coin_count=uint32(10),
)
generated_balance = await wallet_node.get_balance(wallet_id)
assert generated_balance == expected_generated_balance
# Load another key without funds, make sure the balance is empty.
other_key = KeyData.generate()
assert wallet_node.local_keychain is not None
wallet_node.local_keychain.add_private_key(other_key.mnemonic_str())
await restart_with_fingerprint(other_key.fingerprint)
assert await wallet_node.get_balance(wallet_id) == Balance()
# Load the initial fingerprint again and make sure the balance is still what we generated earlier
await restart_with_fingerprint(initial_fingerprint)
assert await wallet_node.get_balance(wallet_id) == generated_balance
# Connect and sync to the full node, generate more funds and test the balance caching
# TODO, there is a bug in untrusted sync if we try to sync to the same peak as stored in the DB after restart
# which leads to a rollback to 0 (-1 which is another a bug) and then to a validation error because the
# downloaded weight proof will not be added to the blockchain properly because we still have a peak with the
# same weight stored in the DB but without chain data. The 1 block generation below can be dropped if we just
# also store the chain data or maybe adjust the weight proof consideration logic in new_valid_weight_proof.
await full_node_api.farm_blocks_to_puzzlehash(1)
assert not wallet_synced()
await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
await time_out_assert(30, wallet_synced)
generated_funds += await full_node_api.farm_blocks_to_wallet(5, wallet_node.wallet_state_manager.main_wallet)
expected_more_balance = Balance(
confirmed_wallet_balance=uint128(generated_funds),
unconfirmed_wallet_balance=uint128(generated_funds),
spendable_balance=uint128(generated_funds),
max_send_amount=uint128(generated_funds),
unspent_coin_count=uint32(20),
)
async with wallet_node.wallet_state_manager.set_sync_mode(uint32(100)):
# During sync the balance cache should not become updated, so it still should have the old balance here
assert await wallet_node.get_balance(wallet_id) == expected_generated_balance
# Now after the sync context the cache should become updated to the newly genertated balance
assert await wallet_node.get_balance(wallet_id) == expected_more_balance
# Restart one more time and make sure the balance is still correct after start
await restart_with_fingerprint(initial_fingerprint)
assert await wallet_node.get_balance(wallet_id) == expected_more_balance