mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-09-19 23:21:46 +03:00
wallet: Improve balance caching (#14631)
* wallet: Improve balance caching * Fix missing entries in `get_balance` RPC and add a test for it * Fix comment
This commit is contained in:
parent
e8e8e0d709
commit
5e3bf65abb
@ -90,7 +90,6 @@ class WalletRpcApi:
|
||||
assert wallet_node is not None
|
||||
self.service = wallet_node
|
||||
self.service_name = "chia_wallet"
|
||||
self.balance_cache: Dict[int, Any] = {}
|
||||
|
||||
def get_routes(self) -> Dict[str, Endpoint]:
|
||||
return {
|
||||
@ -318,7 +317,6 @@ class WalletRpcApi:
|
||||
return {"fingerprint": fingerprint}
|
||||
|
||||
await self._stop_wallet()
|
||||
self.balance_cache = {}
|
||||
started = await self.service._start_with_fingerprint(fingerprint)
|
||||
if started is True:
|
||||
return {"fingerprint": fingerprint}
|
||||
@ -792,61 +790,15 @@ class WalletRpcApi:
|
||||
async def get_wallet_balance(self, request: Dict) -> EndpointResult:
|
||||
wallet_id = uint32(int(request["wallet_id"]))
|
||||
wallet = self.service.wallet_state_manager.wallets[wallet_id]
|
||||
|
||||
# If syncing return the last available info or 0s
|
||||
syncing = self.service.wallet_state_manager.sync_mode
|
||||
if syncing:
|
||||
if wallet_id in self.balance_cache:
|
||||
wallet_balance = self.balance_cache[wallet_id]
|
||||
else:
|
||||
wallet_balance = {
|
||||
"wallet_id": wallet_id,
|
||||
"confirmed_wallet_balance": 0,
|
||||
"unconfirmed_wallet_balance": 0,
|
||||
"spendable_balance": 0,
|
||||
"pending_change": 0,
|
||||
"max_send_amount": 0,
|
||||
"unspent_coin_count": 0,
|
||||
"pending_coin_removal_count": 0,
|
||||
"wallet_type": wallet.type(),
|
||||
}
|
||||
if self.service.logged_in_fingerprint is not None:
|
||||
wallet_balance["fingerprint"] = self.service.logged_in_fingerprint
|
||||
if wallet.type() == WalletType.CAT:
|
||||
assert isinstance(wallet, CATWallet)
|
||||
wallet_balance["asset_id"] = wallet.get_asset_id()
|
||||
else:
|
||||
async with self.service.wallet_state_manager.lock:
|
||||
unspent_records = await self.service.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(
|
||||
wallet_id
|
||||
)
|
||||
balance = await wallet.get_confirmed_balance(unspent_records)
|
||||
pending_balance = await wallet.get_unconfirmed_balance(unspent_records)
|
||||
spendable_balance = await wallet.get_spendable_balance(unspent_records)
|
||||
pending_change = await wallet.get_pending_change_balance()
|
||||
max_send_amount = await wallet.get_max_send_amount(unspent_records)
|
||||
|
||||
unconfirmed_removals: Dict[
|
||||
bytes32, Coin
|
||||
] = await wallet.wallet_state_manager.unconfirmed_removals_for_wallet(wallet_id)
|
||||
wallet_balance = {
|
||||
"wallet_id": wallet_id,
|
||||
"confirmed_wallet_balance": balance,
|
||||
"unconfirmed_wallet_balance": pending_balance,
|
||||
"spendable_balance": spendable_balance,
|
||||
"pending_change": pending_change,
|
||||
"max_send_amount": max_send_amount,
|
||||
"unspent_coin_count": len(unspent_records),
|
||||
"pending_coin_removal_count": len(unconfirmed_removals),
|
||||
"wallet_type": wallet.type(),
|
||||
}
|
||||
if self.service.logged_in_fingerprint is not None:
|
||||
wallet_balance["fingerprint"] = self.service.logged_in_fingerprint
|
||||
if wallet.type() == WalletType.CAT:
|
||||
assert isinstance(wallet, CATWallet)
|
||||
wallet_balance["asset_id"] = wallet.get_asset_id()
|
||||
self.balance_cache[wallet_id] = wallet_balance
|
||||
|
||||
balance = await self.service.get_balance(wallet_id)
|
||||
wallet_balance = balance.to_json_dict()
|
||||
wallet_balance["wallet_id"] = wallet_id
|
||||
wallet_balance["wallet_type"] = wallet.type()
|
||||
if self.service.logged_in_fingerprint is not None:
|
||||
wallet_balance["fingerprint"] = self.service.logged_in_fingerprint
|
||||
if wallet.type() == WalletType.CAT:
|
||||
assert isinstance(wallet, CATWallet)
|
||||
wallet_balance["asset_id"] = wallet.get_asset_id()
|
||||
return {"wallet_balance": wallet_balance}
|
||||
|
||||
async def get_transaction(self, request: Dict) -> EndpointResult:
|
||||
|
@ -46,10 +46,11 @@ from chia.util.config import (
|
||||
)
|
||||
from chia.util.db_wrapper import manage_connection
|
||||
from chia.util.errors import KeychainIsEmpty, KeychainIsLocked, KeychainKeyNotFound, KeychainProxyConnectionFailure
|
||||
from chia.util.ints import uint32, uint64
|
||||
from chia.util.ints import uint32, uint64, uint128
|
||||
from chia.util.keychain import Keychain
|
||||
from chia.util.path import path_from_root
|
||||
from chia.util.profiler import mem_profile_task, profile_task
|
||||
from chia.util.streamable import Streamable, streamable
|
||||
from chia.wallet.transaction_record import TransactionRecord
|
||||
from chia.wallet.util.new_peak_queue import NewPeakItem, NewPeakQueue, NewPeakQueueTypes
|
||||
from chia.wallet.util.peer_request_cache import PeerRequestCache, can_use_peer_request_cache
|
||||
@ -85,6 +86,18 @@ def get_wallet_db_path(root_path: Path, config: Dict[str, Any], key_fingerprint:
|
||||
return path
|
||||
|
||||
|
||||
@streamable
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Balance(Streamable):
|
||||
confirmed_wallet_balance: uint128 = uint128(0)
|
||||
unconfirmed_wallet_balance: uint128 = uint128(0)
|
||||
spendable_balance: uint128 = uint128(0)
|
||||
pending_change: uint64 = uint64(0)
|
||||
max_send_amount: uint128 = uint128(0)
|
||||
unspent_coin_count: uint32 = uint32(0)
|
||||
pending_coin_removal_count: uint32 = uint32(0)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class WalletNode:
|
||||
config: Dict
|
||||
@ -103,6 +116,7 @@ class WalletNode:
|
||||
logged_in_fingerprint: Optional[int] = None
|
||||
logged_in: bool = False
|
||||
_keychain_proxy: Optional[KeychainProxy] = None
|
||||
_balance_cache: Dict[int, Balance] = dataclasses.field(default_factory=dict)
|
||||
# Peers that we have long synced to
|
||||
synced_peers: Set[bytes32] = dataclasses.field(default_factory=set)
|
||||
wallet_peers: Optional[WalletPeers] = None
|
||||
@ -373,6 +387,10 @@ class WalletNode:
|
||||
self.log_in(private_key)
|
||||
self.wallet_state_manager.state_changed("sync_changed")
|
||||
|
||||
# Populate the balance caches for all wallets
|
||||
for wallet_id in self.wallet_state_manager.wallets:
|
||||
await self.get_balance(wallet_id)
|
||||
|
||||
async with self.wallet_state_manager.puzzle_store.lock:
|
||||
index = await self.wallet_state_manager.puzzle_store.get_last_derivation_path()
|
||||
if index is None or index < self.wallet_state_manager.initial_num_public_keys - 1:
|
||||
@ -407,6 +425,7 @@ class WalletNode:
|
||||
await proxy.close()
|
||||
await asyncio.sleep(0.5) # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
|
||||
self.wallet_peers = None
|
||||
self._balance_cache = {}
|
||||
|
||||
def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
|
||||
self.state_changed_callback = callback
|
||||
@ -1643,3 +1662,30 @@ class WalletNode:
|
||||
full_nodes = self.server.get_connections(NodeType.FULL_NODE)
|
||||
for peer in full_nodes:
|
||||
await peer.send_message(msg)
|
||||
|
||||
async def get_balance(self, wallet_id: uint32) -> Balance:
|
||||
self.log.debug(f"get_balance - wallet_id: {wallet_id}")
|
||||
if not self.wallet_state_manager.sync_mode:
|
||||
self.log.debug(f"get_balance - Updating cache for {wallet_id}")
|
||||
async with self.wallet_state_manager.lock:
|
||||
wallet = self.wallet_state_manager.wallets[wallet_id]
|
||||
unspent_records = await self.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(wallet_id)
|
||||
balance = await wallet.get_confirmed_balance(unspent_records)
|
||||
pending_balance = await wallet.get_unconfirmed_balance(unspent_records)
|
||||
spendable_balance = await wallet.get_spendable_balance(unspent_records)
|
||||
pending_change = await wallet.get_pending_change_balance()
|
||||
max_send_amount = await wallet.get_max_send_amount(unspent_records)
|
||||
|
||||
unconfirmed_removals: Dict[
|
||||
bytes32, Coin
|
||||
] = await wallet.wallet_state_manager.unconfirmed_removals_for_wallet(wallet_id)
|
||||
self._balance_cache[wallet_id] = Balance(
|
||||
confirmed_wallet_balance=balance,
|
||||
unconfirmed_wallet_balance=pending_balance,
|
||||
spendable_balance=spendable_balance,
|
||||
pending_change=pending_change,
|
||||
max_send_amount=max_send_amount,
|
||||
unspent_coin_count=uint32(len(unspent_records)),
|
||||
pending_coin_removal_count=uint32(len(unconfirmed_removals)),
|
||||
)
|
||||
return self._balance_cache.get(wallet_id, Balance())
|
||||
|
@ -224,6 +224,18 @@ async def assert_push_tx_error(node_rpc: FullNodeRpcClient, tx: TransactionRecor
|
||||
raise ValueError from error
|
||||
|
||||
|
||||
async def assert_get_balance(rpc_client: WalletRpcClient, wallet_node: WalletNode, wallet: WalletProtocol) -> None:
|
||||
expected_balance = await wallet_node.get_balance(wallet.id())
|
||||
expected_balance_dict = expected_balance.to_json_dict()
|
||||
expected_balance_dict["wallet_id"] = wallet.id()
|
||||
expected_balance_dict["wallet_type"] = wallet.type()
|
||||
expected_balance_dict["fingerprint"] = wallet_node.logged_in_fingerprint
|
||||
if wallet.type() == WalletType.CAT:
|
||||
assert isinstance(wallet, CATWallet)
|
||||
expected_balance_dict["asset_id"] = wallet.get_asset_id()
|
||||
assert await rpc_client.get_wallet_balance(wallet.id()) == expected_balance_dict
|
||||
|
||||
|
||||
async def tx_in_mempool(client: WalletRpcClient, transaction_id: bytes32):
|
||||
tx = await client.get_transaction(1, transaction_id)
|
||||
return tx.is_in_mempool()
|
||||
@ -310,6 +322,22 @@ async def test_push_transactions(wallet_rpc_environment: WalletRpcTestEnvironmen
|
||||
assert tx.confirmed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_balance(wallet_rpc_environment: WalletRpcTestEnvironment):
|
||||
env = wallet_rpc_environment
|
||||
wallet: Wallet = env.wallet_1.wallet
|
||||
wallet_node: WalletNode = env.wallet_1.node
|
||||
full_node_api: FullNodeSimulator = env.full_node.api
|
||||
wallet_rpc_client = env.wallet_1.rpc_client
|
||||
await full_node_api.farm_blocks_to_wallet(2, wallet)
|
||||
async with wallet_node.wallet_state_manager.lock:
|
||||
cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet(
|
||||
wallet_node.wallet_state_manager, wallet, {"identifier": "genesis_by_id"}, uint64(100)
|
||||
)
|
||||
await assert_get_balance(wallet_rpc_client, wallet_node, wallet)
|
||||
await assert_get_balance(wallet_rpc_client, wallet_node, cat_wallet)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_timestamp_for_height(wallet_rpc_environment: WalletRpcTestEnvironment):
|
||||
env: WalletRpcTestEnvironment = wallet_rpc_environment
|
||||
|
@ -2,16 +2,20 @@ from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
from blspy import PrivateKey
|
||||
|
||||
from chia.simulator.block_tools import test_constants
|
||||
from chia.simulator.setup_nodes import SimulatorsAndWallets
|
||||
from chia.simulator.time_out_assert import time_out_assert
|
||||
from chia.types.full_block import FullBlock
|
||||
from chia.types.peer_info import PeerInfo
|
||||
from chia.util.config import load_config
|
||||
from chia.util.keychain import Keychain, generate_mnemonic
|
||||
from chia.wallet.wallet_node import WalletNode
|
||||
from chia.util.ints import uint16, uint32, uint128
|
||||
from chia.util.keychain import Keychain, KeyData, generate_mnemonic
|
||||
from chia.wallet.wallet_node import Balance, WalletNode
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -302,3 +306,81 @@ async def test_unique_puzzle_hash_subscriptions(simulator_and_wallet: Simulators
|
||||
puzzle_hashes = await node.get_puzzle_hashes_to_subscribe()
|
||||
assert len(puzzle_hashes) > 1
|
||||
assert len(set(puzzle_hashes)) == len(puzzle_hashes)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_balance(
|
||||
simulator_and_wallet: SimulatorsAndWallets, self_hostname: str, default_400_blocks: List[FullBlock]
|
||||
) -> None:
|
||||
[full_node_api], [(wallet_node, wallet_server)], bt = simulator_and_wallet
|
||||
full_node_server = full_node_api.full_node.server
|
||||
|
||||
def wallet_synced() -> bool:
|
||||
return full_node_server.node_id in wallet_node.synced_peers
|
||||
|
||||
async def restart_with_fingerprint(fingerprint: Optional[int]) -> None:
|
||||
wallet_node._close() # type: ignore[no-untyped-call] # WalletNode needs hinting here
|
||||
await wallet_node._await_closed(shutting_down=False)
|
||||
await wallet_node._start_with_fingerprint(fingerprint=fingerprint)
|
||||
|
||||
wallet_id = uint32(1)
|
||||
initial_fingerprint = wallet_node.logged_in_fingerprint
|
||||
|
||||
# TODO, there is a bug in wallet_short_sync_backtrack which leads to a rollback to 0 (-1 which is another a bug) and
|
||||
# with that to a KeyError when applying the race cache if there are less than WEIGHT_PROOF_RECENT_BLOCKS
|
||||
# blocks but we still have a peak stored in the DB. So we need to add enough blocks for a weight proof here to
|
||||
# be able to restart the wallet in this test.
|
||||
for block in default_400_blocks:
|
||||
await full_node_api.full_node.add_block(block)
|
||||
|
||||
# Initially there should be no sync and no balance
|
||||
assert not wallet_synced()
|
||||
assert await wallet_node.get_balance(wallet_id) == Balance()
|
||||
# Generate some funds, get the balance and make sure it's as expected
|
||||
await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
|
||||
await time_out_assert(30, wallet_synced)
|
||||
generated_funds = await full_node_api.farm_blocks_to_wallet(5, wallet_node.wallet_state_manager.main_wallet)
|
||||
expected_generated_balance = Balance(
|
||||
confirmed_wallet_balance=uint128(generated_funds),
|
||||
unconfirmed_wallet_balance=uint128(generated_funds),
|
||||
spendable_balance=uint128(generated_funds),
|
||||
max_send_amount=uint128(generated_funds),
|
||||
unspent_coin_count=uint32(10),
|
||||
)
|
||||
generated_balance = await wallet_node.get_balance(wallet_id)
|
||||
assert generated_balance == expected_generated_balance
|
||||
# Load another key without funds, make sure the balance is empty.
|
||||
other_key = KeyData.generate()
|
||||
assert wallet_node.local_keychain is not None
|
||||
wallet_node.local_keychain.add_private_key(other_key.mnemonic_str())
|
||||
await restart_with_fingerprint(other_key.fingerprint)
|
||||
assert await wallet_node.get_balance(wallet_id) == Balance()
|
||||
# Load the initial fingerprint again and make sure the balance is still what we generated earlier
|
||||
await restart_with_fingerprint(initial_fingerprint)
|
||||
assert await wallet_node.get_balance(wallet_id) == generated_balance
|
||||
# Connect and sync to the full node, generate more funds and test the balance caching
|
||||
# TODO, there is a bug in untrusted sync if we try to sync to the same peak as stored in the DB after restart
|
||||
# which leads to a rollback to 0 (-1 which is another a bug) and then to a validation error because the
|
||||
# downloaded weight proof will not be added to the blockchain properly because we still have a peak with the
|
||||
# same weight stored in the DB but without chain data. The 1 block generation below can be dropped if we just
|
||||
# also store the chain data or maybe adjust the weight proof consideration logic in new_valid_weight_proof.
|
||||
await full_node_api.farm_blocks_to_puzzlehash(1)
|
||||
assert not wallet_synced()
|
||||
await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
|
||||
await time_out_assert(30, wallet_synced)
|
||||
generated_funds += await full_node_api.farm_blocks_to_wallet(5, wallet_node.wallet_state_manager.main_wallet)
|
||||
expected_more_balance = Balance(
|
||||
confirmed_wallet_balance=uint128(generated_funds),
|
||||
unconfirmed_wallet_balance=uint128(generated_funds),
|
||||
spendable_balance=uint128(generated_funds),
|
||||
max_send_amount=uint128(generated_funds),
|
||||
unspent_coin_count=uint32(20),
|
||||
)
|
||||
async with wallet_node.wallet_state_manager.set_sync_mode(uint32(100)):
|
||||
# During sync the balance cache should not become updated, so it still should have the old balance here
|
||||
assert await wallet_node.get_balance(wallet_id) == expected_generated_balance
|
||||
# Now after the sync context the cache should become updated to the newly genertated balance
|
||||
assert await wallet_node.get_balance(wallet_id) == expected_more_balance
|
||||
# Restart one more time and make sure the balance is still correct after start
|
||||
await restart_with_fingerprint(initial_fingerprint)
|
||||
assert await wallet_node.get_balance(wallet_id) == expected_more_balance
|
||||
|
Loading…
Reference in New Issue
Block a user