mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-09-19 23:21:46 +03:00
wallet: Improve balance caching (#14631)
* wallet: Improve balance caching * Fix missing entries in `get_balance` RPC and add a test for it * Fix comment
This commit is contained in:
parent
e8e8e0d709
commit
5e3bf65abb
@ -90,7 +90,6 @@ class WalletRpcApi:
|
|||||||
assert wallet_node is not None
|
assert wallet_node is not None
|
||||||
self.service = wallet_node
|
self.service = wallet_node
|
||||||
self.service_name = "chia_wallet"
|
self.service_name = "chia_wallet"
|
||||||
self.balance_cache: Dict[int, Any] = {}
|
|
||||||
|
|
||||||
def get_routes(self) -> Dict[str, Endpoint]:
|
def get_routes(self) -> Dict[str, Endpoint]:
|
||||||
return {
|
return {
|
||||||
@ -318,7 +317,6 @@ class WalletRpcApi:
|
|||||||
return {"fingerprint": fingerprint}
|
return {"fingerprint": fingerprint}
|
||||||
|
|
||||||
await self._stop_wallet()
|
await self._stop_wallet()
|
||||||
self.balance_cache = {}
|
|
||||||
started = await self.service._start_with_fingerprint(fingerprint)
|
started = await self.service._start_with_fingerprint(fingerprint)
|
||||||
if started is True:
|
if started is True:
|
||||||
return {"fingerprint": fingerprint}
|
return {"fingerprint": fingerprint}
|
||||||
@ -792,61 +790,15 @@ class WalletRpcApi:
|
|||||||
async def get_wallet_balance(self, request: Dict) -> EndpointResult:
|
async def get_wallet_balance(self, request: Dict) -> EndpointResult:
|
||||||
wallet_id = uint32(int(request["wallet_id"]))
|
wallet_id = uint32(int(request["wallet_id"]))
|
||||||
wallet = self.service.wallet_state_manager.wallets[wallet_id]
|
wallet = self.service.wallet_state_manager.wallets[wallet_id]
|
||||||
|
balance = await self.service.get_balance(wallet_id)
|
||||||
# If syncing return the last available info or 0s
|
wallet_balance = balance.to_json_dict()
|
||||||
syncing = self.service.wallet_state_manager.sync_mode
|
wallet_balance["wallet_id"] = wallet_id
|
||||||
if syncing:
|
wallet_balance["wallet_type"] = wallet.type()
|
||||||
if wallet_id in self.balance_cache:
|
if self.service.logged_in_fingerprint is not None:
|
||||||
wallet_balance = self.balance_cache[wallet_id]
|
wallet_balance["fingerprint"] = self.service.logged_in_fingerprint
|
||||||
else:
|
if wallet.type() == WalletType.CAT:
|
||||||
wallet_balance = {
|
assert isinstance(wallet, CATWallet)
|
||||||
"wallet_id": wallet_id,
|
wallet_balance["asset_id"] = wallet.get_asset_id()
|
||||||
"confirmed_wallet_balance": 0,
|
|
||||||
"unconfirmed_wallet_balance": 0,
|
|
||||||
"spendable_balance": 0,
|
|
||||||
"pending_change": 0,
|
|
||||||
"max_send_amount": 0,
|
|
||||||
"unspent_coin_count": 0,
|
|
||||||
"pending_coin_removal_count": 0,
|
|
||||||
"wallet_type": wallet.type(),
|
|
||||||
}
|
|
||||||
if self.service.logged_in_fingerprint is not None:
|
|
||||||
wallet_balance["fingerprint"] = self.service.logged_in_fingerprint
|
|
||||||
if wallet.type() == WalletType.CAT:
|
|
||||||
assert isinstance(wallet, CATWallet)
|
|
||||||
wallet_balance["asset_id"] = wallet.get_asset_id()
|
|
||||||
else:
|
|
||||||
async with self.service.wallet_state_manager.lock:
|
|
||||||
unspent_records = await self.service.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(
|
|
||||||
wallet_id
|
|
||||||
)
|
|
||||||
balance = await wallet.get_confirmed_balance(unspent_records)
|
|
||||||
pending_balance = await wallet.get_unconfirmed_balance(unspent_records)
|
|
||||||
spendable_balance = await wallet.get_spendable_balance(unspent_records)
|
|
||||||
pending_change = await wallet.get_pending_change_balance()
|
|
||||||
max_send_amount = await wallet.get_max_send_amount(unspent_records)
|
|
||||||
|
|
||||||
unconfirmed_removals: Dict[
|
|
||||||
bytes32, Coin
|
|
||||||
] = await wallet.wallet_state_manager.unconfirmed_removals_for_wallet(wallet_id)
|
|
||||||
wallet_balance = {
|
|
||||||
"wallet_id": wallet_id,
|
|
||||||
"confirmed_wallet_balance": balance,
|
|
||||||
"unconfirmed_wallet_balance": pending_balance,
|
|
||||||
"spendable_balance": spendable_balance,
|
|
||||||
"pending_change": pending_change,
|
|
||||||
"max_send_amount": max_send_amount,
|
|
||||||
"unspent_coin_count": len(unspent_records),
|
|
||||||
"pending_coin_removal_count": len(unconfirmed_removals),
|
|
||||||
"wallet_type": wallet.type(),
|
|
||||||
}
|
|
||||||
if self.service.logged_in_fingerprint is not None:
|
|
||||||
wallet_balance["fingerprint"] = self.service.logged_in_fingerprint
|
|
||||||
if wallet.type() == WalletType.CAT:
|
|
||||||
assert isinstance(wallet, CATWallet)
|
|
||||||
wallet_balance["asset_id"] = wallet.get_asset_id()
|
|
||||||
self.balance_cache[wallet_id] = wallet_balance
|
|
||||||
|
|
||||||
return {"wallet_balance": wallet_balance}
|
return {"wallet_balance": wallet_balance}
|
||||||
|
|
||||||
async def get_transaction(self, request: Dict) -> EndpointResult:
|
async def get_transaction(self, request: Dict) -> EndpointResult:
|
||||||
|
@ -46,10 +46,11 @@ from chia.util.config import (
|
|||||||
)
|
)
|
||||||
from chia.util.db_wrapper import manage_connection
|
from chia.util.db_wrapper import manage_connection
|
||||||
from chia.util.errors import KeychainIsEmpty, KeychainIsLocked, KeychainKeyNotFound, KeychainProxyConnectionFailure
|
from chia.util.errors import KeychainIsEmpty, KeychainIsLocked, KeychainKeyNotFound, KeychainProxyConnectionFailure
|
||||||
from chia.util.ints import uint32, uint64
|
from chia.util.ints import uint32, uint64, uint128
|
||||||
from chia.util.keychain import Keychain
|
from chia.util.keychain import Keychain
|
||||||
from chia.util.path import path_from_root
|
from chia.util.path import path_from_root
|
||||||
from chia.util.profiler import mem_profile_task, profile_task
|
from chia.util.profiler import mem_profile_task, profile_task
|
||||||
|
from chia.util.streamable import Streamable, streamable
|
||||||
from chia.wallet.transaction_record import TransactionRecord
|
from chia.wallet.transaction_record import TransactionRecord
|
||||||
from chia.wallet.util.new_peak_queue import NewPeakItem, NewPeakQueue, NewPeakQueueTypes
|
from chia.wallet.util.new_peak_queue import NewPeakItem, NewPeakQueue, NewPeakQueueTypes
|
||||||
from chia.wallet.util.peer_request_cache import PeerRequestCache, can_use_peer_request_cache
|
from chia.wallet.util.peer_request_cache import PeerRequestCache, can_use_peer_request_cache
|
||||||
@ -85,6 +86,18 @@ def get_wallet_db_path(root_path: Path, config: Dict[str, Any], key_fingerprint:
|
|||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
@streamable
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class Balance(Streamable):
|
||||||
|
confirmed_wallet_balance: uint128 = uint128(0)
|
||||||
|
unconfirmed_wallet_balance: uint128 = uint128(0)
|
||||||
|
spendable_balance: uint128 = uint128(0)
|
||||||
|
pending_change: uint64 = uint64(0)
|
||||||
|
max_send_amount: uint128 = uint128(0)
|
||||||
|
unspent_coin_count: uint32 = uint32(0)
|
||||||
|
pending_coin_removal_count: uint32 = uint32(0)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class WalletNode:
|
class WalletNode:
|
||||||
config: Dict
|
config: Dict
|
||||||
@ -103,6 +116,7 @@ class WalletNode:
|
|||||||
logged_in_fingerprint: Optional[int] = None
|
logged_in_fingerprint: Optional[int] = None
|
||||||
logged_in: bool = False
|
logged_in: bool = False
|
||||||
_keychain_proxy: Optional[KeychainProxy] = None
|
_keychain_proxy: Optional[KeychainProxy] = None
|
||||||
|
_balance_cache: Dict[int, Balance] = dataclasses.field(default_factory=dict)
|
||||||
# Peers that we have long synced to
|
# Peers that we have long synced to
|
||||||
synced_peers: Set[bytes32] = dataclasses.field(default_factory=set)
|
synced_peers: Set[bytes32] = dataclasses.field(default_factory=set)
|
||||||
wallet_peers: Optional[WalletPeers] = None
|
wallet_peers: Optional[WalletPeers] = None
|
||||||
@ -373,6 +387,10 @@ class WalletNode:
|
|||||||
self.log_in(private_key)
|
self.log_in(private_key)
|
||||||
self.wallet_state_manager.state_changed("sync_changed")
|
self.wallet_state_manager.state_changed("sync_changed")
|
||||||
|
|
||||||
|
# Populate the balance caches for all wallets
|
||||||
|
for wallet_id in self.wallet_state_manager.wallets:
|
||||||
|
await self.get_balance(wallet_id)
|
||||||
|
|
||||||
async with self.wallet_state_manager.puzzle_store.lock:
|
async with self.wallet_state_manager.puzzle_store.lock:
|
||||||
index = await self.wallet_state_manager.puzzle_store.get_last_derivation_path()
|
index = await self.wallet_state_manager.puzzle_store.get_last_derivation_path()
|
||||||
if index is None or index < self.wallet_state_manager.initial_num_public_keys - 1:
|
if index is None or index < self.wallet_state_manager.initial_num_public_keys - 1:
|
||||||
@ -407,6 +425,7 @@ class WalletNode:
|
|||||||
await proxy.close()
|
await proxy.close()
|
||||||
await asyncio.sleep(0.5) # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
|
await asyncio.sleep(0.5) # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
|
||||||
self.wallet_peers = None
|
self.wallet_peers = None
|
||||||
|
self._balance_cache = {}
|
||||||
|
|
||||||
def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
|
def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
|
||||||
self.state_changed_callback = callback
|
self.state_changed_callback = callback
|
||||||
@ -1643,3 +1662,30 @@ class WalletNode:
|
|||||||
full_nodes = self.server.get_connections(NodeType.FULL_NODE)
|
full_nodes = self.server.get_connections(NodeType.FULL_NODE)
|
||||||
for peer in full_nodes:
|
for peer in full_nodes:
|
||||||
await peer.send_message(msg)
|
await peer.send_message(msg)
|
||||||
|
|
||||||
|
async def get_balance(self, wallet_id: uint32) -> Balance:
|
||||||
|
self.log.debug(f"get_balance - wallet_id: {wallet_id}")
|
||||||
|
if not self.wallet_state_manager.sync_mode:
|
||||||
|
self.log.debug(f"get_balance - Updating cache for {wallet_id}")
|
||||||
|
async with self.wallet_state_manager.lock:
|
||||||
|
wallet = self.wallet_state_manager.wallets[wallet_id]
|
||||||
|
unspent_records = await self.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(wallet_id)
|
||||||
|
balance = await wallet.get_confirmed_balance(unspent_records)
|
||||||
|
pending_balance = await wallet.get_unconfirmed_balance(unspent_records)
|
||||||
|
spendable_balance = await wallet.get_spendable_balance(unspent_records)
|
||||||
|
pending_change = await wallet.get_pending_change_balance()
|
||||||
|
max_send_amount = await wallet.get_max_send_amount(unspent_records)
|
||||||
|
|
||||||
|
unconfirmed_removals: Dict[
|
||||||
|
bytes32, Coin
|
||||||
|
] = await wallet.wallet_state_manager.unconfirmed_removals_for_wallet(wallet_id)
|
||||||
|
self._balance_cache[wallet_id] = Balance(
|
||||||
|
confirmed_wallet_balance=balance,
|
||||||
|
unconfirmed_wallet_balance=pending_balance,
|
||||||
|
spendable_balance=spendable_balance,
|
||||||
|
pending_change=pending_change,
|
||||||
|
max_send_amount=max_send_amount,
|
||||||
|
unspent_coin_count=uint32(len(unspent_records)),
|
||||||
|
pending_coin_removal_count=uint32(len(unconfirmed_removals)),
|
||||||
|
)
|
||||||
|
return self._balance_cache.get(wallet_id, Balance())
|
||||||
|
@ -224,6 +224,18 @@ async def assert_push_tx_error(node_rpc: FullNodeRpcClient, tx: TransactionRecor
|
|||||||
raise ValueError from error
|
raise ValueError from error
|
||||||
|
|
||||||
|
|
||||||
|
async def assert_get_balance(rpc_client: WalletRpcClient, wallet_node: WalletNode, wallet: WalletProtocol) -> None:
|
||||||
|
expected_balance = await wallet_node.get_balance(wallet.id())
|
||||||
|
expected_balance_dict = expected_balance.to_json_dict()
|
||||||
|
expected_balance_dict["wallet_id"] = wallet.id()
|
||||||
|
expected_balance_dict["wallet_type"] = wallet.type()
|
||||||
|
expected_balance_dict["fingerprint"] = wallet_node.logged_in_fingerprint
|
||||||
|
if wallet.type() == WalletType.CAT:
|
||||||
|
assert isinstance(wallet, CATWallet)
|
||||||
|
expected_balance_dict["asset_id"] = wallet.get_asset_id()
|
||||||
|
assert await rpc_client.get_wallet_balance(wallet.id()) == expected_balance_dict
|
||||||
|
|
||||||
|
|
||||||
async def tx_in_mempool(client: WalletRpcClient, transaction_id: bytes32):
|
async def tx_in_mempool(client: WalletRpcClient, transaction_id: bytes32):
|
||||||
tx = await client.get_transaction(1, transaction_id)
|
tx = await client.get_transaction(1, transaction_id)
|
||||||
return tx.is_in_mempool()
|
return tx.is_in_mempool()
|
||||||
@ -310,6 +322,22 @@ async def test_push_transactions(wallet_rpc_environment: WalletRpcTestEnvironmen
|
|||||||
assert tx.confirmed
|
assert tx.confirmed
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_balance(wallet_rpc_environment: WalletRpcTestEnvironment):
|
||||||
|
env = wallet_rpc_environment
|
||||||
|
wallet: Wallet = env.wallet_1.wallet
|
||||||
|
wallet_node: WalletNode = env.wallet_1.node
|
||||||
|
full_node_api: FullNodeSimulator = env.full_node.api
|
||||||
|
wallet_rpc_client = env.wallet_1.rpc_client
|
||||||
|
await full_node_api.farm_blocks_to_wallet(2, wallet)
|
||||||
|
async with wallet_node.wallet_state_manager.lock:
|
||||||
|
cat_wallet: CATWallet = await CATWallet.create_new_cat_wallet(
|
||||||
|
wallet_node.wallet_state_manager, wallet, {"identifier": "genesis_by_id"}, uint64(100)
|
||||||
|
)
|
||||||
|
await assert_get_balance(wallet_rpc_client, wallet_node, wallet)
|
||||||
|
await assert_get_balance(wallet_rpc_client, wallet_node, cat_wallet)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_timestamp_for_height(wallet_rpc_environment: WalletRpcTestEnvironment):
|
async def test_get_timestamp_for_height(wallet_rpc_environment: WalletRpcTestEnvironment):
|
||||||
env: WalletRpcTestEnvironment = wallet_rpc_environment
|
env: WalletRpcTestEnvironment = wallet_rpc_environment
|
||||||
|
@ -2,16 +2,20 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from blspy import PrivateKey
|
from blspy import PrivateKey
|
||||||
|
|
||||||
from chia.simulator.block_tools import test_constants
|
from chia.simulator.block_tools import test_constants
|
||||||
from chia.simulator.setup_nodes import SimulatorsAndWallets
|
from chia.simulator.setup_nodes import SimulatorsAndWallets
|
||||||
|
from chia.simulator.time_out_assert import time_out_assert
|
||||||
|
from chia.types.full_block import FullBlock
|
||||||
|
from chia.types.peer_info import PeerInfo
|
||||||
from chia.util.config import load_config
|
from chia.util.config import load_config
|
||||||
from chia.util.keychain import Keychain, generate_mnemonic
|
from chia.util.ints import uint16, uint32, uint128
|
||||||
from chia.wallet.wallet_node import WalletNode
|
from chia.util.keychain import Keychain, KeyData, generate_mnemonic
|
||||||
|
from chia.wallet.wallet_node import Balance, WalletNode
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -302,3 +306,81 @@ async def test_unique_puzzle_hash_subscriptions(simulator_and_wallet: Simulators
|
|||||||
puzzle_hashes = await node.get_puzzle_hashes_to_subscribe()
|
puzzle_hashes = await node.get_puzzle_hashes_to_subscribe()
|
||||||
assert len(puzzle_hashes) > 1
|
assert len(puzzle_hashes) > 1
|
||||||
assert len(set(puzzle_hashes)) == len(puzzle_hashes)
|
assert len(set(puzzle_hashes)) == len(puzzle_hashes)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_balance(
|
||||||
|
simulator_and_wallet: SimulatorsAndWallets, self_hostname: str, default_400_blocks: List[FullBlock]
|
||||||
|
) -> None:
|
||||||
|
[full_node_api], [(wallet_node, wallet_server)], bt = simulator_and_wallet
|
||||||
|
full_node_server = full_node_api.full_node.server
|
||||||
|
|
||||||
|
def wallet_synced() -> bool:
|
||||||
|
return full_node_server.node_id in wallet_node.synced_peers
|
||||||
|
|
||||||
|
async def restart_with_fingerprint(fingerprint: Optional[int]) -> None:
|
||||||
|
wallet_node._close() # type: ignore[no-untyped-call] # WalletNode needs hinting here
|
||||||
|
await wallet_node._await_closed(shutting_down=False)
|
||||||
|
await wallet_node._start_with_fingerprint(fingerprint=fingerprint)
|
||||||
|
|
||||||
|
wallet_id = uint32(1)
|
||||||
|
initial_fingerprint = wallet_node.logged_in_fingerprint
|
||||||
|
|
||||||
|
# TODO, there is a bug in wallet_short_sync_backtrack which leads to a rollback to 0 (-1 which is another a bug) and
|
||||||
|
# with that to a KeyError when applying the race cache if there are less than WEIGHT_PROOF_RECENT_BLOCKS
|
||||||
|
# blocks but we still have a peak stored in the DB. So we need to add enough blocks for a weight proof here to
|
||||||
|
# be able to restart the wallet in this test.
|
||||||
|
for block in default_400_blocks:
|
||||||
|
await full_node_api.full_node.add_block(block)
|
||||||
|
|
||||||
|
# Initially there should be no sync and no balance
|
||||||
|
assert not wallet_synced()
|
||||||
|
assert await wallet_node.get_balance(wallet_id) == Balance()
|
||||||
|
# Generate some funds, get the balance and make sure it's as expected
|
||||||
|
await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
|
||||||
|
await time_out_assert(30, wallet_synced)
|
||||||
|
generated_funds = await full_node_api.farm_blocks_to_wallet(5, wallet_node.wallet_state_manager.main_wallet)
|
||||||
|
expected_generated_balance = Balance(
|
||||||
|
confirmed_wallet_balance=uint128(generated_funds),
|
||||||
|
unconfirmed_wallet_balance=uint128(generated_funds),
|
||||||
|
spendable_balance=uint128(generated_funds),
|
||||||
|
max_send_amount=uint128(generated_funds),
|
||||||
|
unspent_coin_count=uint32(10),
|
||||||
|
)
|
||||||
|
generated_balance = await wallet_node.get_balance(wallet_id)
|
||||||
|
assert generated_balance == expected_generated_balance
|
||||||
|
# Load another key without funds, make sure the balance is empty.
|
||||||
|
other_key = KeyData.generate()
|
||||||
|
assert wallet_node.local_keychain is not None
|
||||||
|
wallet_node.local_keychain.add_private_key(other_key.mnemonic_str())
|
||||||
|
await restart_with_fingerprint(other_key.fingerprint)
|
||||||
|
assert await wallet_node.get_balance(wallet_id) == Balance()
|
||||||
|
# Load the initial fingerprint again and make sure the balance is still what we generated earlier
|
||||||
|
await restart_with_fingerprint(initial_fingerprint)
|
||||||
|
assert await wallet_node.get_balance(wallet_id) == generated_balance
|
||||||
|
# Connect and sync to the full node, generate more funds and test the balance caching
|
||||||
|
# TODO, there is a bug in untrusted sync if we try to sync to the same peak as stored in the DB after restart
|
||||||
|
# which leads to a rollback to 0 (-1 which is another a bug) and then to a validation error because the
|
||||||
|
# downloaded weight proof will not be added to the blockchain properly because we still have a peak with the
|
||||||
|
# same weight stored in the DB but without chain data. The 1 block generation below can be dropped if we just
|
||||||
|
# also store the chain data or maybe adjust the weight proof consideration logic in new_valid_weight_proof.
|
||||||
|
await full_node_api.farm_blocks_to_puzzlehash(1)
|
||||||
|
assert not wallet_synced()
|
||||||
|
await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
|
||||||
|
await time_out_assert(30, wallet_synced)
|
||||||
|
generated_funds += await full_node_api.farm_blocks_to_wallet(5, wallet_node.wallet_state_manager.main_wallet)
|
||||||
|
expected_more_balance = Balance(
|
||||||
|
confirmed_wallet_balance=uint128(generated_funds),
|
||||||
|
unconfirmed_wallet_balance=uint128(generated_funds),
|
||||||
|
spendable_balance=uint128(generated_funds),
|
||||||
|
max_send_amount=uint128(generated_funds),
|
||||||
|
unspent_coin_count=uint32(20),
|
||||||
|
)
|
||||||
|
async with wallet_node.wallet_state_manager.set_sync_mode(uint32(100)):
|
||||||
|
# During sync the balance cache should not become updated, so it still should have the old balance here
|
||||||
|
assert await wallet_node.get_balance(wallet_id) == expected_generated_balance
|
||||||
|
# Now after the sync context the cache should become updated to the newly genertated balance
|
||||||
|
assert await wallet_node.get_balance(wallet_id) == expected_more_balance
|
||||||
|
# Restart one more time and make sure the balance is still correct after start
|
||||||
|
await restart_with_fingerprint(initial_fingerprint)
|
||||||
|
assert await wallet_node.get_balance(wallet_id) == expected_more_balance
|
||||||
|
Loading…
Reference in New Issue
Block a user