wallet: Fix and improve untrusted race caching (#16239)

This commit is contained in:
dustinface 2023-09-07 01:00:25 +07:00 committed by GitHub
parent 4ee7fe27be
commit cd1d22d29b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 230 additions and 58 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import asyncio
from typing import Any, Optional, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple
from chia.protocols.wallet_protocol import CoinState
from chia.types.blockchain_format.sized_bytes import bytes32
@ -19,6 +19,9 @@ class PeerRequestCache:
_blocks_validated: LRUCache[bytes32, uint32] # header_hash -> height
_block_signatures_validated: LRUCache[bytes32, uint32] # sig_hash -> height
_additions_in_block: LRUCache[Tuple[bytes32, bytes32], uint32] # header_hash, puzzle_hash -> height
# The wallet gets the state update before receiving the block. In untrusted mode the block is required for the
# coin state validation, so we cache them before we apply them once we received the block.
_race_cache: Dict[uint32, Set[CoinState]]
def __init__(self) -> None:
self._blocks = LRUCache(100)
@ -28,6 +31,7 @@ class PeerRequestCache:
self._blocks_validated = LRUCache(1000)
self._block_signatures_validated = LRUCache(1000)
self._additions_in_block = LRUCache(200)
self._race_cache = {}
def get_block(self, height: uint32) -> Optional[HeaderBlock]:
return self._blocks.get(height)
@ -87,6 +91,27 @@ class PeerRequestCache:
def in_additions_in_block(self, header_hash: bytes32, addition_ph: bytes32) -> bool:
return self._additions_in_block.get((header_hash, addition_ph)) is not None
def add_states_to_race_cache(self, coin_states: List[CoinState]) -> None:
for coin_state in coin_states:
created_height = 0 if coin_state.created_height is None else coin_state.created_height
spent_height = 0 if coin_state.spent_height is None else coin_state.spent_height
max_height = uint32(max(created_height, spent_height))
race_cache = self._race_cache.setdefault(max_height, set())
race_cache.add(coin_state)
def get_race_cache(self, height: int) -> Set[CoinState]:
return self._race_cache[uint32(height)]
def rollback_race_cache(self, *, fork_height: int) -> None:
self._race_cache = {
height: coin_states for height, coin_states in self._race_cache.items() if height <= fork_height
}
def cleanup_race_cache(self, *, min_height: int) -> None:
self._race_cache = {
height: coin_states for height, coin_states in self._race_cache.items() if height >= min_height
}
def clear_after_height(self, height: int) -> None:
# Remove any cached item which relates to an event that happened at a height above height.
new_blocks = LRUCache[uint32, HeaderBlock](self._blocks.capacity)

View File

@ -108,12 +108,6 @@ class Balance(Streamable):
pending_coin_removal_count: uint32 = uint32(0)
@dataclasses.dataclass(frozen=True)
class PeerPeak:
height: uint32
hash: bytes32
@dataclasses.dataclass
class WalletNode:
config: Dict[str, Any]
@ -137,10 +131,6 @@ class WalletNode:
synced_peers: Set[bytes32] = dataclasses.field(default_factory=set)
wallet_peers: Optional[WalletPeers] = None
peer_caches: Dict[bytes32, PeerRequestCache] = dataclasses.field(default_factory=dict)
# in Untrusted mode wallet might get the state update before receiving the block
race_cache: Dict[bytes32, Set[CoinState]] = dataclasses.field(default_factory=dict)
race_cache_hashes: List[Tuple[uint32, bytes32]] = dataclasses.field(default_factory=list)
node_peaks: Dict[bytes32, PeerPeak] = dataclasses.field(default_factory=dict)
validation_semaphore: Optional[asyncio.Semaphore] = None
local_node_synced: bool = False
LONG_SYNC_THRESHOLD: int = 300
@ -458,8 +448,6 @@ class WalletNode:
await proxy.close()
await asyncio.sleep(0.5) # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
self.wallet_peers = None
self.race_cache = {}
self.race_cache_hashes = []
self._balance_cache = {}
def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
@ -682,8 +670,6 @@ class WalletNode:
self.peer_caches.pop(peer.peer_node_id)
if peer.peer_node_id in self.synced_peers:
self.synced_peers.remove(peer.peer_node_id)
if peer.peer_node_id in self.node_peaks:
self.node_peaks.pop(peer.peer_node_id)
self.wallet_state_manager.state_changed("close_connection")
@ -830,7 +816,6 @@ class WalletNode:
peer: WSChiaConnection,
fork_height: Optional[uint32] = None,
height: Optional[uint32] = None,
header_hash: Optional[bytes32] = None,
) -> bool:
# Adds the state to the wallet state manager. If the peer is trusted, we do not validate. If the peer is
# untrusted we do, but we might not add the state, since we need to receive the new_peak message as well.
@ -863,6 +848,9 @@ class WalletNode:
# only one peer told us to rollback so only clear for that peer
cache.clear_after_height(fork_height)
self.log.info(f"clear_after_height {fork_height} for peer {peer}")
if not trusted:
# Rollback race_cache not in clear_after_height to avoid applying rollbacks from new peak processing
cache.rollback_race_cache(fork_height=fork_height)
all_tasks: List[asyncio.Task[None]] = []
target_concurrent_tasks: int = 30
@ -879,11 +867,6 @@ class WalletNode:
try:
assert self.validation_semaphore is not None
async with self.validation_semaphore:
if header_hash is not None:
assert height is not None
for inner_state in inner_states:
self.add_state_to_race_cache(header_hash, height, inner_state)
self.log.info(f"Added to race cache: {height}, {inner_state}")
valid_states = [
inner_state
for inner_state in inner_states
@ -893,7 +876,7 @@ class WalletNode:
async with self.wallet_state_manager.db_wrapper.writer():
self.log.info(
f"new coin state received ({inner_idx_start}-"
f"{inner_idx_start + len(inner_states) - 1}/ {len(items)})"
f"{inner_idx_start + len(inner_states) - 1}/ {len(updated_coin_states)})"
)
await self.wallet_state_manager.add_coin_states(valid_states, peer, fork_height)
except Exception as e:
@ -901,11 +884,28 @@ class WalletNode:
log_level = logging.DEBUG if peer.closed or self._shut_down else logging.ERROR
self.log.log(log_level, f"validate_and_add failed - exception: {e}, traceback: {tb}")
idx = 1
# Keep chunk size below 1000 just in case, windows has sqlite limits of 999 per query
# Untrusted has a smaller batch size since validation has to happen which takes a while
chunk_size: int = 900 if trusted else 10
for batch in to_batches(items, chunk_size):
reorged_coin_states = []
updated_coin_states = []
for coin_state in items:
if coin_state.created_height is None:
reorged_coin_states.append(coin_state)
else:
updated_coin_states.append(coin_state)
# Reorged coin states don't require any validation in untrusted mode, so we can just always apply them upfront
# instead of adding them to the race cache in untrusted mode.
for batch in to_batches(reorged_coin_states, chunk_size):
self.log.info(f"Process reorged states: ({len(batch.entries)} / {len(reorged_coin_states)})")
if not await self.wallet_state_manager.add_coin_states(batch.entries, peer, fork_height):
self.log.debug("Processing reorged states failed")
return False
idx = 1
for batch in to_batches(updated_coin_states, chunk_size):
if self._server is None:
self.log.error("No server")
await asyncio.gather(*all_tasks)
@ -916,18 +916,23 @@ class WalletNode:
return False
if trusted:
async with self.wallet_state_manager.db_wrapper.writer():
self.log.info(f"new coin state received ({idx}-{idx + len(batch.entries) - 1}/ {len(items)})")
self.log.info(
f"new coin state received ({idx}-{idx + len(batch.entries) - 1}/ {len(updated_coin_states)})"
)
if not await self.wallet_state_manager.add_coin_states(batch.entries, peer, fork_height):
return False
else:
while len(all_tasks) >= target_concurrent_tasks:
all_tasks = [task for task in all_tasks if not task.done()]
await asyncio.sleep(0.1)
if self._shut_down:
self.log.info("Terminating receipt and validation due to shut down request")
await asyncio.gather(*all_tasks)
return False
all_tasks.append(asyncio.create_task(validate_and_add(batch.entries, idx)))
if fork_height is not None:
cache.add_states_to_race_cache(batch.entries)
else:
while len(all_tasks) >= target_concurrent_tasks:
all_tasks = [task for task in all_tasks if not task.done()]
await asyncio.sleep(0.1)
if self._shut_down:
self.log.info("Terminating receipt and validation due to shut down request")
await asyncio.gather(*all_tasks)
return False
all_tasks.append(asyncio.create_task(validate_and_add(batch.entries, idx)))
idx += len(batch.entries)
still_connected = self._server is not None and peer.peer_node_id in self.server.all_connections
@ -955,20 +960,6 @@ class WalletNode:
def is_trusted(self, peer: WSChiaConnection) -> bool:
return self.server.is_trusted_peer(peer, self.config.get("trusted_peers", {}))
def add_state_to_race_cache(self, header_hash: bytes32, height: uint32, coin_state: CoinState) -> None:
# Clears old state that is no longer relevant
delete_threshold = 100
for rc_height, rc_hh in self.race_cache_hashes:
if height - delete_threshold >= rc_height:
self.race_cache.pop(rc_hh)
self.race_cache_hashes = [
(rc_height, rc_hh) for rc_height, rc_hh in self.race_cache_hashes if height - delete_threshold < rc_height
]
if header_hash not in self.race_cache:
self.race_cache[header_hash] = set()
self.race_cache[header_hash].add(coin_state)
async def state_update_received(self, request: CoinStateUpdate, peer: WSChiaConnection) -> None:
# This gets called every time there is a new coin or puzzle hash change in the DB
# that is of interest to this wallet. It is not guaranteed to come for every height. This message is guaranteed
@ -983,7 +974,6 @@ class WalletNode:
peer,
request.fork_height,
request.height,
request.peak_hash,
)
def get_full_node_peer(self) -> WSChiaConnection:
@ -1174,6 +1164,7 @@ class WalletNode:
else:
backtrack_fork_height = new_peak_hb.height - 1
cache = self.get_cache_for_peer(peer)
if peer.peer_node_id not in self.synced_peers:
# Edge case, this happens when the peak < WEIGHT_PROOF_RECENT_BLOCKS
# we still want to subscribe for all phs and coins.
@ -1182,12 +1173,10 @@ class WalletNode:
phs: List[bytes32] = await self.get_puzzle_hashes_to_subscribe()
ph_updates: List[CoinState] = await subscribe_to_phs(phs, peer, uint32(0))
coin_updates: List[CoinState] = await subscribe_to_coin_updates(all_coin_ids, peer, uint32(0))
peer_new_peak = self.node_peaks[peer.peer_node_id]
success = await self.add_states_from_peer(
ph_updates + coin_updates,
peer,
height=peer_new_peak.height,
header_hash=peer_new_peak.hash,
fork_height=uint32(max(backtrack_fork_height, 0)),
)
if success:
self.synced_peers.add(peer.peer_node_id)
@ -1198,10 +1187,16 @@ class WalletNode:
# For every block, we need to apply the cache from race_cache
for potential_height in range(backtrack_fork_height + 1, new_peak_hb.height + 1):
header_hash = self.wallet_state_manager.blockchain.height_to_hash(uint32(potential_height))
if header_hash in self.race_cache:
self.log.info(f"Receiving race state: {self.race_cache[header_hash]}")
await self.add_states_from_peer(list(self.race_cache[header_hash]), peer)
try:
race_cache = cache.get_race_cache(potential_height)
except KeyError:
continue
self.log.info(f"Apply race cache - height: {potential_height}, coin_states: {race_cache}")
await self.add_states_from_peer(list(race_cache), peer)
# Clear old entries that are no longer relevant
cache.cleanup_race_cache(min_height=backtrack_fork_height)
self.wallet_state_manager.state_changed("new_block")
self.log.info(f"Finished processing new peak of {new_peak_hb.height}")

View File

@ -8,7 +8,7 @@ from chia.server.ws_connection import WSChiaConnection
from chia.types.mempool_inclusion_status import MempoolInclusionStatus
from chia.util.api_decorators import api_request
from chia.util.errors import Err
from chia.wallet.wallet_node import PeerPeak, WalletNode
from chia.wallet.wallet_node import WalletNode
class WalletNodeAPI:
@ -44,7 +44,6 @@ class WalletNodeAPI:
"""
The full node sent as a new peak
"""
self.wallet_node.node_peaks[peer.peer_node_id] = PeerPeak(peak.height, peak.header_hash)
# For trusted peers check if there are untrusted peers, if so make sure to disconnect them if the trusted node
# is synced.
if self.wallet_node.is_trusted(peer):

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import logging
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional
@ -7,6 +8,7 @@ from typing import Any, Dict, List, Optional
import pytest
from blspy import PrivateKey
from chia.protocols.wallet_protocol import CoinState
from chia.simulator.block_tools import test_constants
from chia.simulator.setup_nodes import SimulatorsAndWallets
from chia.simulator.time_out_assert import time_out_assert
@ -16,6 +18,7 @@ from chia.util.config import load_config
from chia.util.ints import uint16, uint32, uint128
from chia.util.keychain import Keychain, KeyData, generate_mnemonic
from chia.wallet.wallet_node import Balance, WalletNode
from tests.util.misc import CoinGenerator
@pytest.mark.asyncio
@ -384,3 +387,39 @@ async def test_get_balance(
# Restart one more time and make sure the balance is still correct after start
await restart_with_fingerprint(initial_fingerprint)
assert await wallet_node.get_balance(wallet_id) == expected_more_balance
@pytest.mark.asyncio
async def test_add_states_from_peer_reorg_failure(
simulator_and_wallet: SimulatorsAndWallets, self_hostname: str, caplog: pytest.LogCaptureFixture
) -> None:
[full_node_api], [(wallet_node, wallet_server)], _ = simulator_and_wallet
await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_api.server._port)), None)
wallet = wallet_node.wallet_state_manager.main_wallet
await full_node_api.farm_rewards_to_wallet(1, wallet)
coin_generator = CoinGenerator()
coin_states = [CoinState(coin_generator.get().coin, None, None)]
with caplog.at_level(logging.DEBUG):
full_node_peer = list(wallet_server.all_connections.values())[0]
# Close the connection to trigger a state processing failure during reorged coin processing.
await full_node_peer.close()
assert not await wallet_node.add_states_from_peer(coin_states, full_node_peer)
assert "Processing reorged states failed" in caplog.text
@pytest.mark.asyncio
async def test_add_states_from_peer_untrusted_shutdown(
simulator_and_wallet: SimulatorsAndWallets, self_hostname: str, caplog: pytest.LogCaptureFixture
) -> None:
[full_node_api], [(wallet_node, wallet_server)], _ = simulator_and_wallet
await wallet_server.start_client(PeerInfo(self_hostname, uint16(full_node_api.server._port)), None)
wallet = wallet_node.wallet_state_manager.main_wallet
await full_node_api.farm_rewards_to_wallet(1, wallet)
# Close to trigger the shutdown
wallet_node._close()
coin_generator = CoinGenerator()
# Generate enough coin states to fill up the max number validation/add tasks.
coin_states = [CoinState(coin_generator.get().coin, i, i) for i in range(3000)]
with caplog.at_level(logging.INFO):
assert not await wallet_node.add_states_from_peer(coin_states, list(wallet_server.all_connections.values())[0])
assert "Terminating receipt and validation due to shut down request" in caplog.text

View File

@ -1,11 +1,35 @@
from __future__ import annotations
from typing import Collection, List, Optional, Tuple
from typing import Collection, Dict, List, Optional, Set, Tuple
import pytest
from chia_rs import Coin, CoinState
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint64
from chia.wallet.util.peer_request_cache import PeerRequestCache
from chia.wallet.util.wallet_sync_utils import sort_coin_states
coin_states = [
CoinState(Coin(bytes32(b"\00" * 32), bytes32(b"\00" * 32), uint64(1)), None, None),
CoinState(Coin(bytes32(b"\00" * 32), bytes32(b"\11" * 32), uint64(1)), None, 1),
CoinState(Coin(bytes32(b"\00" * 32), bytes32(b"\22" * 32), uint64(1)), 1, 1),
CoinState(Coin(bytes32(b"\00" * 32), bytes32(b"\33" * 32), uint64(1)), 1, 1),
CoinState(Coin(bytes32(b"\00" * 32), bytes32(b"\44" * 32), uint64(1)), 2, 1),
CoinState(Coin(bytes32(b"\00" * 32), bytes32(b"\55" * 32), uint64(1)), 2, 2),
CoinState(Coin(bytes32(b"\00" * 32), bytes32(b"\66" * 32), uint64(1)), 20, 10),
CoinState(Coin(bytes32(b"\00" * 32), bytes32(b"\77" * 32), uint64(1)), None, 20),
]
def assert_race_cache(cache: PeerRequestCache, expected_entries: Dict[int, Set[CoinState]]) -> None:
for i in range(100):
if i in expected_entries:
assert cache.get_race_cache(i) == expected_entries[i], f"failed for {i}"
else:
with pytest.raises(KeyError):
cache.get_race_cache(i)
def dummy_coin_state(*, created_height: Optional[int], spent_height: Optional[int]) -> CoinState:
return CoinState(Coin(bytes(b"0" * 32), bytes(b"0" * 32), 0), spent_height, created_height)
@ -35,3 +59,93 @@ def test_sort_coin_states() -> None:
unsorted_coin_states = set(sorted_coin_states.copy())
assert heights(unsorted_coin_states) != heights(sorted_coin_states)
assert heights(sort_coin_states(unsorted_coin_states)) == heights(sorted_coin_states)
def test_add_states_to_race_cache() -> None:
cache = PeerRequestCache()
expected_entries: Dict[int, Set[CoinState]] = {}
assert_race_cache(cache, expected_entries)
# Repeated adding of the same coin state should not have any impact
expected_entries[0] = {coin_states[0]}
for i in range(3):
cache.add_states_to_race_cache(coin_states[0:1])
assert_race_cache(cache, expected_entries)
# Add a coin state with max height 1
cache.add_states_to_race_cache(coin_states[1:2])
expected_entries[1] = {coin_states[1]}
assert_race_cache(cache, expected_entries)
# Add two more with max height 1
cache.add_states_to_race_cache(coin_states[2:4])
expected_entries[1] = {*coin_states[1:4]}
assert_race_cache(cache, expected_entries)
# Add one with max height 2
cache.add_states_to_race_cache(coin_states[4:5])
expected_entries[2] = {coin_states[4]}
assert_race_cache(cache, expected_entries)
# Adding all again should add all the remaining states
cache.add_states_to_race_cache(coin_states)
expected_entries[0] = {coin_states[0]}
expected_entries[2] = {*coin_states[4:6]}
expected_entries[20] = {*coin_states[6:8]}
assert_race_cache(cache, expected_entries)
def test_cleanup_race_cache() -> None:
cache = PeerRequestCache()
cache.add_states_to_race_cache(coin_states)
expected_race_cache = {
0: {coin_states[0]},
1: {*coin_states[1:4]},
2: {*coin_states[4:6]},
20: {*coin_states[6:8]},
}
assert_race_cache(cache, expected_race_cache)
# Should not have an impact because 0 is the min height
cache.cleanup_race_cache(min_height=0)
assert_race_cache(cache, expected_race_cache)
# Drop all below 19
cache.cleanup_race_cache(min_height=1)
expected_race_cache.pop(0)
assert_race_cache(cache, expected_race_cache)
# Drop all below 19
cache.cleanup_race_cache(min_height=19)
expected_race_cache.pop(1)
expected_race_cache.pop(2)
assert_race_cache(cache, expected_race_cache)
# This should clear the cache
cache.cleanup_race_cache(min_height=100)
expected_race_cache.clear()
assert_race_cache(cache, expected_race_cache)
def test_rollback_race_cache() -> None:
cache = PeerRequestCache()
cache.add_states_to_race_cache(coin_states)
expected_race_cache = {
0: {coin_states[0]},
1: {*coin_states[1:4]},
2: {*coin_states[4:6]},
20: {*coin_states[6:8]},
}
assert_race_cache(cache, expected_race_cache)
# Should not have an impact because 20 is the max height
cache.rollback_race_cache(fork_height=20)
assert_race_cache(cache, expected_race_cache)
# Drop all above 19
cache.rollback_race_cache(fork_height=19)
expected_race_cache.pop(20)
assert_race_cache(cache, expected_race_cache)
# Drop all above 0
cache.rollback_race_cache(fork_height=0)
expected_race_cache.pop(1)
expected_race_cache.pop(2)
assert_race_cache(cache, expected_race_cache)
# This should clear the cache
cache.rollback_race_cache(fork_height=-1)
expected_race_cache.clear()
assert_race_cache(cache, expected_race_cache)