From 17659714268667ed38ba7a54f646ec9915451c07 Mon Sep 17 00:00:00 2001 From: dustinface <35775977+xdustinface@users.noreply.github.com> Date: Thu, 15 Jun 2023 16:46:31 +0200 Subject: [PATCH] rpc: Fix and test `WalletRpcApi.get_coin_records_by_names` (#15509) --- chia/rpc/wallet_rpc_api.py | 4 +-- chia/wallet/wallet_coin_store.py | 2 ++ tests/wallet/rpc/test_wallet_rpc.py | 51 +++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/chia/rpc/wallet_rpc_api.py b/chia/rpc/wallet_rpc_api.py index 19213b2e8dd8..eb30293c87c8 100644 --- a/chia/rpc/wallet_rpc_api.py +++ b/chia/rpc/wallet_rpc_api.py @@ -82,7 +82,7 @@ from chia.wallet.vc_wallet.vc_store import VCProofs from chia.wallet.vc_wallet.vc_wallet import VCWallet from chia.wallet.wallet import CHIP_0002_SIGN_MESSAGE_PREFIX, Wallet from chia.wallet.wallet_coin_record import WalletCoinRecord -from chia.wallet.wallet_coin_store import CoinRecordOrder, GetCoinRecords +from chia.wallet.wallet_coin_store import CoinRecordOrder, GetCoinRecords, unspent_range from chia.wallet.wallet_info import WalletInfo from chia.wallet.wallet_node import WalletNode from chia.wallet.wallet_protocol import WalletProtocol @@ -1224,7 +1224,7 @@ class WalletRpcApi: kwargs["confirmed_range"] = confirmed_range if "include_spent_coins" in request and not str2bool(request["include_spent_coins"]): - kwargs["spent_range"] = UInt32Range(start=uint32(uint32.MAXIMUM_EXCLUSIVE - 1)) + kwargs["spent_range"] = unspent_range async with self.service.wallet_state_manager.lock: coin_records: List[CoinRecord] = await self.service.wallet_state_manager.get_coin_records_by_coin_ids( diff --git a/chia/wallet/wallet_coin_store.py b/chia/wallet/wallet_coin_store.py index 075f582ec9f5..4a6efdf34cc8 100644 --- a/chia/wallet/wallet_coin_store.py +++ b/chia/wallet/wallet_coin_store.py @@ -17,6 +17,8 @@ from chia.wallet.util.query_filter import AmountFilter, FilterMode, HashFilter from chia.wallet.util.wallet_types import CoinType, WalletType from chia.wallet.wallet_coin_record import WalletCoinRecord +unspent_range = UInt32Range(stop=uint32(0)) + class CoinRecordOrder(IntEnum): confirmed_height = 1 diff --git a/tests/wallet/rpc/test_wallet_rpc.py b/tests/wallet/rpc/test_wallet_rpc.py index cdbe057f49dd..309207c27ba4 100644 --- a/tests/wallet/rpc/test_wallet_rpc.py +++ b/tests/wallet/rpc/test_wallet_rpc.py @@ -1154,6 +1154,57 @@ async def test_offer_endpoints(wallet_rpc_environment: WalletRpcTestEnvironment) ### +@pytest.mark.asyncio +async def test_get_coin_records_by_names(wallet_rpc_environment: WalletRpcTestEnvironment) -> None: + env: WalletRpcTestEnvironment = wallet_rpc_environment + wallet_node: WalletNode = env.wallet_1.node + client: WalletRpcClient = env.wallet_1.rpc_client + store = wallet_node.wallet_state_manager.coin_store + full_node_api = env.full_node.api + # Generate some funds + generated_funds = await generate_funds(full_node_api, env.wallet_1, 5) + address = encode_puzzle_hash(await env.wallet_1.wallet.get_new_puzzlehash(), "txch") + # Spend half of it back to the same wallet get some spent coins in the wallet + tx = await client.send_transaction(1, uint64(generated_funds / 2), address) + assert tx.spend_bundle is not None + await time_out_assert(20, tx_in_mempool, True, client, tx.name) + await farm_transaction(full_node_api, wallet_node, tx.spend_bundle) + await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=5) + # Prepare some records and parameters first + result = await store.get_coin_records() + coins = {record.coin for record in result.records} + coins_unspent = {record.coin for record in result.records if not record.spent} + coin_ids = [coin.name() for coin in coins] + coin_ids_unspent = [coin.name() for coin in coins_unspent] + assert len(coin_ids) > 0 + assert len(coin_ids_unspent) > 0 + # Do some queries to trigger all parameters + # 1. Empty coin_ids + assert await client.get_coin_records_by_names([]) == [] + # 2. All coins + rpc_result = await client.get_coin_records_by_names(coin_ids + coin_ids_unspent) + assert set(record.coin for record in rpc_result) == {*coins, *coins_unspent} + # 3. All spent coins + rpc_result = await client.get_coin_records_by_names(coin_ids, include_spent_coins=True) + assert set(record.coin for record in rpc_result) == coins + # 4. All unspent coins + rpc_result = await client.get_coin_records_by_names(coin_ids_unspent, include_spent_coins=False) + assert set(record.coin for record in rpc_result) == coins_unspent + # 5. Filter start/end height + filter_records = result.records[:10] + assert len(filter_records) == 10 + filter_coin_ids = [record.name() for record in filter_records] + filter_coins = set(record.coin for record in filter_records) + min_height = min(record.confirmed_block_height for record in filter_records) + max_height = max(record.confirmed_block_height for record in filter_records) + assert min_height != max_height + rpc_result = await client.get_coin_records_by_names(filter_coin_ids, start_height=min_height, end_height=max_height) + assert set(record.coin for record in rpc_result) == filter_coins + # 8. Test the failure case + with pytest.raises(ValueError, match="not found"): + await client.get_coin_records_by_names(coin_ids, include_spent_coins=False) + + @pytest.mark.asyncio async def test_did_endpoints(wallet_rpc_environment: WalletRpcTestEnvironment): env: WalletRpcTestEnvironment = wallet_rpc_environment