wallet: Improve balance caching (#14631)

* wallet: Improve balance caching

* Fix missing entries in `get_balance` RPC and add a test for it

* Fix comment
This commit is contained in:
dustinface 2023-04-03 21:52:11 +07:00 committed by GitHub
parent e8e8e0d709
commit 5e3bf65abb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 169 additions and 61 deletions

View File

@ -90,7 +90,6 @@ class WalletRpcApi:
assert wallet_node is not None
self.service = wallet_node
self.service_name = "chia_wallet"
self.balance_cache: Dict[int, Any] = {}
def get_routes(self) -> Dict[str, Endpoint]:
return {
@ -318,7 +317,6 @@ class WalletRpcApi:
return {"fingerprint": fingerprint}
await self._stop_wallet()
self.balance_cache = {}
started = await self.service._start_with_fingerprint(fingerprint)
if started is True:
return {"fingerprint": fingerprint}
@ -792,61 +790,15 @@ class WalletRpcApi:
async def get_wallet_balance(self, request: Dict) -> EndpointResult:
wallet_id = uint32(int(request["wallet_id"]))
wallet = self.service.wallet_state_manager.wallets[wallet_id]
# If syncing return the last available info or 0s
syncing = self.service.wallet_state_manager.sync_mode
if syncing:
if wallet_id in self.balance_cache:
wallet_balance = self.balance_cache[wallet_id]
else:
wallet_balance = {
"wallet_id": wallet_id,
"confirmed_wallet_balance": 0,
"unconfirmed_wallet_balance": 0,
"spendable_balance": 0,
"pending_change": 0,
"max_send_amount": 0,
"unspent_coin_count": 0,
"pending_coin_removal_count": 0,
"wallet_type": wallet.type(),
}
if self.service.logged_in_fingerprint is not None:
wallet_balance["fingerprint"] = self.service.logged_in_fingerprint
if wallet.type() == WalletType.CAT:
assert isinstance(wallet, CATWallet)
wallet_balance["asset_id"] = wallet.get_asset_id()
else:
async with self.service.wallet_state_manager.lock:
unspent_records = await self.service.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(
wallet_id
)
balance = await wallet.get_confirmed_balance(unspent_records)
pending_balance = await wallet.get_unconfirmed_balance(unspent_records)
spendable_balance = await wallet.get_spendable_balance(unspent_records)
pending_change = await wallet.get_pending_change_balance()
max_send_amount = await wallet.get_max_send_amount(unspent_records)
unconfirmed_removals: Dict[
bytes32, Coin
] = await wallet.wallet_state_manager.unconfirmed_removals_for_wallet(wallet_id)
wallet_balance = {
"wallet_id": wallet_id,
"confirmed_wallet_balance": balance,
"unconfirmed_wallet_balance": pending_balance,
"spendable_balance": spendable_balance,
"pending_change": pending_change,
"max_send_amount": max_send_amount,
"unspent_coin_count": len(unspent_records),
"pending_coin_removal_count": len(unconfirmed_removals),
"wallet_type": wallet.type(),
}
if self.service.logged_in_fingerprint is not None:
wallet_balance["fingerprint"] = self.service.logged_in_fingerprint
if wallet.type() == WalletType.CAT:
assert isinstance(wallet, CATWallet)
wallet_balance["asset_id"] = wallet.get_asset_id()
self.balance_cache[wallet_id] = wallet_balance
balance = await self.service.get_balance(wallet_id)
wallet_balance = balance.to_json_dict()
wallet_balance["wallet_id"] = wallet_id
wallet_balance["wallet_type"] = wallet.type()
if self.service.logged_in_fingerprint is not None:
wallet_balance["fingerprint"] = self.service.logged_in_fingerprint
if wallet.type() == WalletType.CAT:
assert isinstance(wallet, CATWallet)
wallet_balance["asset_id"] = wallet.get_asset_id()
return {"wallet_balance": wallet_balance}
async def get_transaction(self, request: Dict) -> EndpointResult:

View File

@ -46,10 +46,11 @@ from chia.util.config import (
)
from chia.util.db_wrapper import manage_connection
from chia.util.errors import KeychainIsEmpty, KeychainIsLocked, KeychainKeyNotFound, KeychainProxyConnectionFailure
from chia.util.ints import uint32, uint64
from chia.util.ints import uint32, uint64, uint128
from chia.util.keychain import Keychain
from chia.util.path import path_from_root
from chia.util.profiler import mem_profile_task, profile_task
from chia.util.streamable import Streamable, streamable
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.util.new_peak_queue import NewPeakItem, NewPeakQueue, NewPeakQueueTypes
from chia.wallet.util.peer_request_cache import PeerRequestCache, can_use_peer_request_cache
@ -85,6 +86,18 @@ def get_wallet_db_path(root_path: Path, config: Dict[str, Any], key_fingerprint:
return path
@streamable
@dataclasses.dataclass(frozen=True)
class Balance(Streamable):
confirmed_wallet_balance: uint128 = uint128(0)
unconfirmed_wallet_balance: uint128 = uint128(0)
spendable_balance: uint128 = uint128(0)
pending_change: uint64 = uint64(0)
max_send_amount: uint128 = uint128(0)
unspent_coin_count: uint32 = uint32(0)
pending_coin_removal_count: uint32 = uint32(0)
@dataclasses.dataclass
class WalletNode:
config: Dict
@ -103,6 +116,7 @@ class WalletNode:
logged_in_fingerprint: Optional[int] = None
logged_in: bool = False
_keychain_proxy: Optional[KeychainProxy] = None
_balance_cache: Dict[int, Balance] = dataclasses.field(default_factory=dict)
# Peers that we have long synced to
synced_peers: Set[bytes32] = dataclasses.field(default_factory=set)
wallet_peers: Optional[WalletPeers] = None
@ -373,6 +387,10 @@ class WalletNode:
self.log_in(private_key)
self.wallet_state_manager.state_changed("sync_changed")
# Populate the balance caches for all wallets
for wallet_id in self.wallet_state_manager.wallets:
await self.get_balance(wallet_id)
async with self.wallet_state_manager.puzzle_store.lock:
index = await self.wallet_state_manager.puzzle_store.get_last_derivation_path()
if index is None or index < self.wallet_state_manager.initial_num_public_keys - 1:
@ -407,6 +425,7 @@ class WalletNode:
await proxy.close()
await asyncio.sleep(0.5) # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
self.wallet_peers = None
self._balance_cache = {}
def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
self.state_changed_callback = callback
@ -1643,3 +1662,30 @@ class WalletNode:
full_nodes = self.server.get_connections(NodeType.FULL_NODE)
for peer in full_nodes:
await peer.send_message(msg)
async def get_balance(self, wallet_id: uint32) -> Balance:
self.log.debug(f"get_balance - wallet_id: {wallet_id}")
if not self.wallet_state_manager.sync_mode:
self.log.debug(f"get_balance - Updating cache for {wallet_id}")
async with self.wallet_state_manager.lock:
wallet = self.wallet_state_manager.wallets[wallet_id]
unspent_records = await self.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(wallet_id)
balance = await wallet.get_confirmed_balance(unspent_records)
pending_balance = await wallet.get_unconfirmed_balance(unspent_records)
spendable_balance = await wallet.get_spendable_balance(unspent_records)
pending_change = await wallet.get_pending_change_balance()
max_send_amount = await wallet.get_max_send_amount(unspent_records)
unconfirmed_removals: Dict[
bytes32, Coin
] = await wallet.wallet_state_manager.unconfirmed_removals_for_wallet(wallet_id)
self._balance_cache[wallet_id] = Balance(
confirmed_wallet_balance=balance,
unconfirmed_wallet_balance=pending_balance,
spendable_balance=spendable_balance,
pending_change=pending_change,
max_send_amount=max_send_amount,
unspent_coin_count=uint32(len(unspent_records)),
pending_coin_removal_count=uint32(len(unconfirmed_removals)),
)
return self._balance_cache.get(wallet_id, Balance())

View File

@ -224,6 +224,18 @@ async def assert_push_tx_error(node_rpc: FullNodeRpcClient, tx: TransactionRecor
raise ValueError from error
async def assert_get_balance(rpc_client: WalletRpcClient, wallet_node: WalletNode, wallet: WalletProtocol) -> None:
expected_balance = await wallet_node.get_balance(wallet.id())
expected_balance_dict = expected_balance.to_json_dict()
expected_balance_dict["wallet_id"] = wallet.id()
expected_balance_dict["wallet_type"] = wallet.type()
expected_balance_dict["fingerprint"] = wallet_node.logged_in_fingerprint
if wallet.type() == WalletType.CAT:
assert isinstance(wallet, CATWallet)
expected_balance_dict["asset_id"] = wallet.get_asset_id()
assert await rpc_client.get_wallet_balance(wallet.id()) == expected_balance_dict
async def tx_in_mempool(client: WalletRpcClient, transaction_id: bytes32):
tx = await client.get_transaction(1, transaction_id)
return tx.is_in_mempool()
@ -310,6 +322,22 @@ async def test_push_transactions(wallet_rpc_environment: WalletRpcTestEnvironmen
assert tx.confirmed
@pytest.mark.asyncio
async def test_get_balance(wallet_rpc_environment: WalletRpcTestEnvironment):
env = wallet_rpc_environment
wallet: Wallet = env.wallet_1.wallet
wallet_node: WalletNode = env.wallet_1.node
full_node_api: FullNodeSimulator = env.full_node.api
wallet_rpc_client = env.wallet_1.rpc_client
await full_node_api.farm_blocks_to_wallet(2, wallet)
async with wallet_node.wallet_state_manager.lock:
cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet(
wallet_node.wallet_state_manager, wallet, {"identifier": "genesis_by_id"}, uint64(100)
)
await assert_get_balance(wallet_rpc_client, wallet_node, wallet)
await assert_get_balance(wallet_rpc_client, wallet_node, cat_wallet)
@pytest.mark.asyncio
async def test_get_timestamp_for_height(wallet_rpc_environment: WalletRpcTestEnvironment):
env: WalletRpcTestEnvironment = wallet_rpc_environment

View File

@ -2,16 +2,20 @@ from __future__ import annotations
import sys
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
import pytest
from blspy import PrivateKey
from chia.simulator.block_tools import test_constants
from chia.simulator.setup_nodes import SimulatorsAndWallets
from chia.simulator.time_out_assert import time_out_assert
from chia.types.full_block import FullBlock
from chia.types.peer_info import PeerInfo
from chia.util.config import load_config
from chia.util.keychain import Keychain, generate_mnemonic
from chia.wallet.wallet_node import WalletNode
from chia.util.ints import uint16, uint32, uint128
from chia.util.keychain import Keychain, KeyData, generate_mnemonic
from chia.wallet.wallet_node import Balance, WalletNode
@pytest.mark.asyncio
@ -302,3 +306,81 @@ async def test_unique_puzzle_hash_subscriptions(simulator_and_wallet: Simulators
puzzle_hashes = await node.get_puzzle_hashes_to_subscribe()
assert len(puzzle_hashes) > 1
assert len(set(puzzle_hashes)) == len(puzzle_hashes)
@pytest.mark.asyncio
async def test_get_balance(
simulator_and_wallet: SimulatorsAndWallets, self_hostname: str, default_400_blocks: List[FullBlock]
) -> None:
[full_node_api], [(wallet_node, wallet_server)], bt = simulator_and_wallet
full_node_server = full_node_api.full_node.server
def wallet_synced() -> bool:
return full_node_server.node_id in wallet_node.synced_peers
async def restart_with_fingerprint(fingerprint: Optional[int]) -> None:
wallet_node._close() # type: ignore[no-untyped-call] # WalletNode needs hinting here
await wallet_node._await_closed(shutting_down=False)
await wallet_node._start_with_fingerprint(fingerprint=fingerprint)
wallet_id = uint32(1)
initial_fingerprint = wallet_node.logged_in_fingerprint
# TODO, there is a bug in wallet_short_sync_backtrack which leads to a rollback to 0 (-1 which is another a bug) and
# with that to a KeyError when applying the race cache if there are less than WEIGHT_PROOF_RECENT_BLOCKS
# blocks but we still have a peak stored in the DB. So we need to add enough blocks for a weight proof here to
# be able to restart the wallet in this test.
for block in default_400_blocks:
await full_node_api.full_node.add_block(block)
# Initially there should be no sync and no balance
assert not wallet_synced()
assert await wallet_node.get_balance(wallet_id) == Balance()
# Generate some funds, get the balance and make sure it's as expected
await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
await time_out_assert(30, wallet_synced)
generated_funds = await full_node_api.farm_blocks_to_wallet(5, wallet_node.wallet_state_manager.main_wallet)
expected_generated_balance = Balance(
confirmed_wallet_balance=uint128(generated_funds),
unconfirmed_wallet_balance=uint128(generated_funds),
spendable_balance=uint128(generated_funds),
max_send_amount=uint128(generated_funds),
unspent_coin_count=uint32(10),
)
generated_balance = await wallet_node.get_balance(wallet_id)
assert generated_balance == expected_generated_balance
# Load another key without funds, make sure the balance is empty.
other_key = KeyData.generate()
assert wallet_node.local_keychain is not None
wallet_node.local_keychain.add_private_key(other_key.mnemonic_str())
await restart_with_fingerprint(other_key.fingerprint)
assert await wallet_node.get_balance(wallet_id) == Balance()
# Load the initial fingerprint again and make sure the balance is still what we generated earlier
await restart_with_fingerprint(initial_fingerprint)
assert await wallet_node.get_balance(wallet_id) == generated_balance
# Connect and sync to the full node, generate more funds and test the balance caching
# TODO, there is a bug in untrusted sync if we try to sync to the same peak as stored in the DB after restart
# which leads to a rollback to 0 (-1 which is another a bug) and then to a validation error because the
# downloaded weight proof will not be added to the blockchain properly because we still have a peak with the
# same weight stored in the DB but without chain data. The 1 block generation below can be dropped if we just
# also store the chain data or maybe adjust the weight proof consideration logic in new_valid_weight_proof.
await full_node_api.farm_blocks_to_puzzlehash(1)
assert not wallet_synced()
await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
await time_out_assert(30, wallet_synced)
generated_funds += await full_node_api.farm_blocks_to_wallet(5, wallet_node.wallet_state_manager.main_wallet)
expected_more_balance = Balance(
confirmed_wallet_balance=uint128(generated_funds),
unconfirmed_wallet_balance=uint128(generated_funds),
spendable_balance=uint128(generated_funds),
max_send_amount=uint128(generated_funds),
unspent_coin_count=uint32(20),
)
async with wallet_node.wallet_state_manager.set_sync_mode(uint32(100)):
# During sync the balance cache should not become updated, so it still should have the old balance here
assert await wallet_node.get_balance(wallet_id) == expected_generated_balance
# Now after the sync context the cache should become updated to the newly genertated balance
assert await wallet_node.get_balance(wallet_id) == expected_more_balance
# Restart one more time and make sure the balance is still correct after start
await restart_with_fingerprint(initial_fingerprint)
assert await wallet_node.get_balance(wallet_id) == expected_more_balance