chia-blockchain/chia/wallet/wallet_node.py
Arvid Norberg 95e5b97557
Use bls from chia rs (#16715)
* replace blspy imports with chia_rs imports for BLS types

* remove blspy-stubs, since we're dropping the blspy dependency. chia_rs has type stubs already
2023-11-07 09:06:52 -08:00

1699 lines
81 KiB
Python

from __future__ import annotations
import asyncio
import dataclasses
import logging
import multiprocessing
import random
import sys
import time
import traceback
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Set, Tuple, Union, cast
import aiosqlite
from chia_rs import AugSchemeMPL, G1Element, G2Element, PrivateKey
from packaging.version import Version
from chia.consensus.blockchain import AddBlockResult
from chia.consensus.constants import ConsensusConstants
from chia.daemon.keychain_proxy import KeychainProxy, connect_to_keychain_and_validate, wrap_local_keychain
from chia.full_node.full_node_api import FullNodeAPI
from chia.protocols.full_node_protocol import RequestProofOfWeight, RespondProofOfWeight
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.wallet_protocol import (
CoinState,
CoinStateUpdate,
NewPeakWallet,
RegisterForCoinUpdates,
RequestBlockHeader,
RequestChildren,
RespondBlockHeader,
RespondChildren,
RespondToCoinUpdates,
SendTransaction,
)
from chia.rpc.rpc_server import StateChangedProtocol, default_get_connections
from chia.server.node_discovery import WalletPeers
from chia.server.outbound_message import Message, NodeType, make_msg
from chia.server.peer_store_resolver import PeerStoreResolver
from chia.server.server import ChiaServer
from chia.server.ws_connection import WSChiaConnection
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.header_block import HeaderBlock
from chia.types.mempool_inclusion_status import MempoolInclusionStatus
from chia.types.spend_bundle import SpendBundle
from chia.types.weight_proof import WeightProof
from chia.util.config import (
WALLET_PEERS_PATH_KEY_DEPRECATED,
lock_and_load_config,
process_config_start_method,
save_config,
)
from chia.util.db_wrapper import manage_connection
from chia.util.errors import KeychainIsEmpty, KeychainIsLocked, KeychainKeyNotFound, KeychainProxyConnectionFailure
from chia.util.hash import std_hash
from chia.util.ints import uint16, uint32, uint64, uint128
from chia.util.keychain import Keychain
from chia.util.misc import to_batches
from chia.util.path import path_from_root
from chia.util.profiler import mem_profile_task, profile_task
from chia.util.streamable import Streamable, streamable
from chia.wallet.puzzles.clawback.metadata import AutoClaimSettings
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.util.new_peak_queue import NewPeakItem, NewPeakQueue, NewPeakQueueTypes
from chia.wallet.util.peer_request_cache import PeerRequestCache, can_use_peer_request_cache
from chia.wallet.util.wallet_sync_utils import (
PeerRequestException,
fetch_header_blocks_in_range,
request_and_validate_additions,
request_and_validate_removals,
request_header_blocks,
sort_coin_states,
subscribe_to_coin_updates,
subscribe_to_phs,
)
from chia.wallet.util.wallet_types import CoinType, WalletType
from chia.wallet.wallet_state_manager import WalletStateManager
from chia.wallet.wallet_weight_proof_handler import WalletWeightProofHandler, get_wp_fork_point
def get_wallet_db_path(root_path: Path, config: Dict[str, Any], key_fingerprint: str) -> Path:
    """
    Build the wallet database path for the key with ``key_fingerprint``.

    In the configured ``database_path`` template, ``CHALLENGE`` is replaced by the selected network
    name and then ``KEY`` by the fingerprint. Names that do not already carry the current ``v2_r1``
    version tag get their ``v2`` / ``v1`` tags rewritten to it. The result is resolved against
    ``root_path``.
    """
    template: str = config["database_path"]
    db_path_replaced: str = template.replace("CHALLENGE", config["selected_network"])
    db_path_replaced = db_path_replaced.replace("KEY", key_fingerprint)
    # "v2_r1" is the current wallet db version identifier
    if "v2_r1" not in db_path_replaced:
        for legacy_tag in ("v2", "v1"):
            db_path_replaced = db_path_replaced.replace(legacy_tag, "v2_r1")
    return path_from_root(root_path, db_path_replaced)
@streamable
@dataclasses.dataclass(frozen=True)
class Balance(Streamable):
    """
    Immutable balance snapshot for one wallet, as cached per wallet id in ``WalletNode._balance_cache``.

    Every figure defaults to zero.
    NOTE(review): field order presumably defines the Streamable serialization layout — do not reorder
    fields without confirming.
    """

    confirmed_wallet_balance: uint128 = uint128(0)
    unconfirmed_wallet_balance: uint128 = uint128(0)
    spendable_balance: uint128 = uint128(0)
    pending_change: uint64 = uint64(0)
    max_send_amount: uint128 = uint128(0)
    unspent_coin_count: uint32 = uint32(0)
    pending_coin_removal_count: uint32 = uint32(0)
@dataclasses.dataclass
class WalletNode:
    """
    The wallet node service.

    Ties together:
    - the keychain, which supplies the logged-in private key;
    - the WalletStateManager (wallet database and state);
    - the ChiaServer connections to full nodes;
    - optional WalletPeers discovery;
    - the background tasks that process subscriptions / new peaks and retry failed coin states.

    Several integral attributes are only assigned after construction. Until then, the
    ``keychain_proxy``, ``wallet_state_manager``, ``server`` and ``new_peak_queue`` properties
    raise RuntimeError.
    """

    if TYPE_CHECKING:
        from chia.rpc.rpc_server import RpcServiceProtocol

        # Type-checker-only assertion that WalletNode satisfies RpcServiceProtocol
        _protocol_check: ClassVar[RpcServiceProtocol] = cast("WalletNode", None)

    config: Dict[str, Any]
    root_path: Path
    constants: ConsensusConstants
    # When truthy, wrapped directly instead of connecting to the keychain daemon (see ensure_keychain_proxy)
    local_keychain: Optional[Keychain] = None
    log: logging.Logger = logging.getLogger(__name__)

    # Sync data
    # Forwarded to the WalletStateManager once it exists
    state_changed_callback: Optional[StateChangedProtocol] = None
    _wallet_state_manager: Optional[WalletStateManager] = None
    _weight_proof_handler: Optional[WalletWeightProofHandler] = None
    _server: Optional[ChiaServer] = None
    sync_task: Optional[asyncio.Task[None]] = None
    # Fingerprint of the logged-in key: set by log_in, cleared by log_out
    logged_in_fingerprint: Optional[int] = None
    logged_in: bool = False
    _keychain_proxy: Optional[KeychainProxy] = None
    # Balance per wallet id; reset to an empty dict in _await_closed
    _balance_cache: Dict[int, Balance] = dataclasses.field(default_factory=dict)
    # Peers that we have long synced to
    synced_peers: Set[bytes32] = dataclasses.field(default_factory=set)
    wallet_peers: Optional[WalletPeers] = None
    # Request cache per peer node id (see get_cache_for_peer)
    peer_caches: Dict[bytes32, PeerRequestCache] = dataclasses.field(default_factory=dict)
    # Created lazily with 10 permits by add_states_from_peer
    validation_semaphore: Optional[asyncio.Semaphore] = None
    # Set True after a long sync with a trusted peer completes; reset when a trusted peer disconnects
    local_node_synced: bool = False
    # NOTE(review): not referenced in this part of the file; presumably the height gap that triggers a
    # long sync — confirm
    LONG_SYNC_THRESHOLD: int = 300
    # int(time.time()) of the last time already-accepted transactions were scheduled for resend
    last_wallet_tx_resend_time: int = 0
    # Duration in seconds
    coin_state_retry_seconds: int = 10
    # Overridden on start from config "tx_resend_timeout_secs" (default 60 * 60)
    wallet_tx_resend_timeout_secs: int = 1800
    _new_peak_queue: Optional[NewPeakQueue] = None
    _shut_down: bool = False
    _process_new_subscriptions_task: Optional[asyncio.Task[None]] = None
    _retry_failed_states_task: Optional[asyncio.Task[None]] = None
    _secondary_peer_sync_task: Optional[asyncio.Task[None]] = None
    # Peer node id -> std_hash of transaction message data already sent to that peer by _resend_queue
    _tx_messages_in_progress: Dict[bytes32, List[bytes32]] = dataclasses.field(default_factory=dict)
@property
def keychain_proxy(self) -> KeychainProxy:
    """
    The keychain proxy in use.

    Raises RuntimeError while no proxy has been assigned. This accessor is a stop gap until the
    class is refactored so integral attributes are known when the instance is created.
    """
    proxy = self._keychain_proxy
    if proxy is not None:
        return proxy
    raise RuntimeError("keychain proxy not assigned")
@property
def wallet_state_manager(self) -> WalletStateManager:
    """
    The wallet state manager in use.

    Raises RuntimeError while none has been assigned. This accessor is a stop gap until the class
    is refactored so integral attributes are known when the instance is created.
    """
    manager = self._wallet_state_manager
    if manager is not None:
        return manager
    raise RuntimeError("wallet state manager not assigned")
@property
def server(self) -> ChiaServer:
    """
    The ChiaServer this node is attached to.

    Raises RuntimeError while no server has been assigned. This accessor is a stop gap until the
    class is refactored so integral attributes are known when the instance is created.
    """
    chia_server = self._server
    if chia_server is not None:
        return chia_server
    raise RuntimeError("server not assigned")
@property
def new_peak_queue(self) -> NewPeakQueue:
    """
    The priority queue of subscriptions, state updates and new peaks.

    Raises RuntimeError while the queue has not been created. This accessor is a stop gap until
    the class is refactored so integral attributes are known when the instance is created.
    """
    queue = self._new_peak_queue
    if queue is not None:
        return queue
    raise RuntimeError("new peak queue not assigned")
def get_connections(self, request_node_type: Optional[NodeType]) -> List[Dict[str, Any]]:
    """Describe the server's connections, optionally restricted to ``request_node_type``."""
    chia_server = self.server
    return default_get_connections(server=chia_server, request_node_type=request_node_type)
async def ensure_keychain_proxy(self) -> KeychainProxy:
    """
    Return the keychain proxy, creating it on first use.

    Wraps ``local_keychain`` when one is set; otherwise connects to the keychain service and
    raises KeychainProxyConnectionFailure if that connection yields no proxy.
    """
    if self._keychain_proxy is not None:
        return self._keychain_proxy
    if self.local_keychain:
        self._keychain_proxy = wrap_local_keychain(self.local_keychain, log=self.log)
        return self._keychain_proxy
    proxy = await connect_to_keychain_and_validate(self.root_path, self.log)
    self._keychain_proxy = proxy
    if not proxy:
        raise KeychainProxyConnectionFailure()
    return proxy
def get_cache_for_peer(self, peer: WSChiaConnection) -> PeerRequestCache:
    """Return the request cache for ``peer``, creating and registering an empty one on first use."""
    node_id = peer.peer_node_id
    cache = self.peer_caches.get(node_id)
    if cache is None:
        cache = PeerRequestCache()
        self.peer_caches[node_id] = cache
    return cache
def rollback_request_caches(self, reorg_height: int) -> None:
    """Clear everything after ``reorg_height`` from every peer's request cache."""
    for peer_cache in self.peer_caches.values():
        peer_cache.clear_after_height(reorg_height)
async def get_key_for_fingerprint(self, fingerprint: Optional[int]) -> Optional[PrivateKey]:
    """
    Fetch the private key with ``fingerprint`` from the keychain.

    With ``fingerprint`` None, the first private key is returned. Returns None after logging a
    warning when the keychain is empty, locked, or has no matching key. A
    KeychainProxyConnectionFailure is logged with its traceback and re-raised, so the caller can
    decide whether to continue or abort.
    """
    try:
        proxy = await self.ensure_keychain_proxy()
        return await proxy.get_key_for_fingerprint(fingerprint)
    except KeychainIsEmpty:
        self.log.warning("No keys present. Create keys with the UI, or with the 'chia keys' program.")
    except KeychainKeyNotFound:
        self.log.warning(f"Key not found for fingerprint {fingerprint}")
    except KeychainIsLocked:
        self.log.warning("Keyring is locked")
    except KeychainProxyConnectionFailure as e:
        trace = traceback.format_exc()
        self.log.error(f"Missing keychain_proxy: {e} {trace}")
        raise
    return None
async def get_private_key(self, fingerprint: Optional[int]) -> Optional[PrivateKey]:
    """
    Get the private key for ``fingerprint``, falling back to the first key in the keychain.

    A None fingerprint already yields the first key. If a specific fingerprint was requested but
    no key came back, the first key is fetched instead and its fingerprint is logged. Returns None
    when no key is available at all.
    """
    key: Optional[PrivateKey] = await self.get_key_for_fingerprint(fingerprint)
    if key is not None or fingerprint is None:
        return key
    fallback: Optional[PrivateKey] = await self.get_key_for_fingerprint(None)
    if fallback is not None:
        self.log.info(f"Using first key found (fingerprint: {fallback.get_g1().get_fingerprint()})")
    return fallback
def set_resync_on_startup(self, fingerprint: int, enabled: bool = True) -> None:
    """
    Persist, or withdraw, a request to reset the sync data for ``fingerprint`` on next startup.

    Enabling (only when ``enabled is True``) stores the fingerprint under
    ``wallet.reset_sync_for_fingerprint`` in config.yaml. Disabling removes that entry, but only
    if it currently names this fingerprint. The config is saved while its lock is held.
    """
    with lock_and_load_config(self.root_path, "config.yaml") as config:
        wallet_config = config["wallet"]
        if enabled is not True:
            self.log.debug(
                "Trying to disable resync: %s [%s]", fingerprint, wallet_config.get("reset_sync_for_fingerprint")
            )
            if wallet_config.get("reset_sync_for_fingerprint") == fingerprint:
                del wallet_config["reset_sync_for_fingerprint"]
                self.log.info("Disabled resync for wallet fingerprint: %s", fingerprint)
        else:
            wallet_config["reset_sync_for_fingerprint"] = fingerprint
            self.log.info("Enabled resync for wallet fingerprint: %s", fingerprint)
        save_config(self.root_path, "config.yaml", config)
def set_auto_claim(self, auto_claim_config: AutoClaimSettings) -> Dict[str, Any]:
    """
    Apply auto-claim settings, persisting them only when they differ from the current ones.

    A ``batch_size`` below 1 is replaced by 50. On change, both the in-memory config and the
    ``wallet.auto_claim`` entry of config.yaml are updated. Returns a fresh JSON dict of the
    applied settings (not the object stored in the config).
    """
    settings = auto_claim_config
    if settings.batch_size < 1:
        settings = dataclasses.replace(settings, batch_size=uint16(50))
    settings_json = settings.to_json_dict()
    needs_update = "auto_claim" not in self.config or self.config["auto_claim"] != settings_json
    if needs_update:
        # Update in memory config first, then mirror it into config.yaml
        self.config["auto_claim"] = settings_json
        with lock_and_load_config(self.root_path, "config.yaml") as config:
            config["wallet"]["auto_claim"] = self.config["auto_claim"]
            save_config(self.root_path, "config.yaml", config)
    return settings.to_json_dict()
async def reset_sync_db(self, db_path: Union[Path, str], fingerprint: int) -> bool:
    """
    Wipe the synced data in the wallet database at ``db_path`` so the wallet resyncs.

    The schema is checked first. If a table is neither a known wallet table nor matches an ignored
    prefix, an error is logged and False is returned without changes. In that case the resync
    request is left in place.

    Otherwise, whichever of the coin, interest, key/value and NFT tables exist are cleared in one
    transaction. The transaction is committed on success and rolled back on ``aiosqlite.Error``.
    After that, the resync request for ``fingerprint`` is disabled.

    Returns whether the deletions were committed.
    """
    conn: aiosqlite.Connection
    # Table-name prefixes that are not part of core wallet tables, but might appear later
    ignored_prefixes = ("lineage_proofs_", "sqlite_", "MIGRATED_VALID_TIMES_TXS", "MIGRATED_VALID_TIMES_TRADES")
    known_tables = frozenset(
        (
            "coin_record",
            "transaction_record",
            "derivation_paths",
            "users_wallets",
            "users_nfts",
            "action_queue",
            "all_notification_ids",
            "key_val_store",
            "trade_records",
            "trade_record_times",
            "tx_times",
            "pool_state_transitions",
            "singletons",
            "singleton_records",
            "mirrors",
            "launchers",
            "interested_coins",
            "interested_puzzle_hashes",
            "unacknowledged_asset_tokens",
            "coin_of_interest_to_trade_record",
            "notifications",
            "retry_store",
            "unacknowledged_asset_token_states",
            "vc_records",
            "vc_proofs",
        )
    )
    # Tables holding synced data, cleared in this order when present
    tables_to_clear = ("coin_record", "interested_coins", "interested_puzzle_hashes", "key_val_store", "users_nfts")
    async with manage_connection(db_path) as conn:
        self.log.info("Resetting wallet sync data...")
        rows = list(await conn.execute_fetchall("SELECT name FROM sqlite_master WHERE type='table'"))
        unexpected = {row[0] for row in rows} - known_tables
        for name in unexpected:
            if not name.startswith(ignored_prefixes):
                self.log.error(
                    f"Mismatch in expected schema to reset, found unexpected table: {name}. "
                    "Please check if you've run all migration scripts."
                )
                return False
        await conn.execute("BEGIN")
        commit = True
        existing_tables = {row[0] for row in rows}
        try:
            for table in tables_to_clear:
                if table in existing_tables:
                    await conn.execute(f"DELETE FROM {table}")
        except aiosqlite.Error:
            self.log.exception("Error resetting sync tables")
            commit = False
        finally:
            try:
                if commit:
                    self.log.info("Reset wallet sync data completed.")
                    await conn.execute("COMMIT")
                else:
                    self.log.info("Reverting reset resync changes")
                    await conn.execute("ROLLBACK")
            except aiosqlite.Error:
                self.log.exception("Error finishing reset resync db")
            # disable the resync in any case
            self.set_resync_on_startup(fingerprint, False)
        return commit
async def _start(self) -> None:
    """Start the wallet without an explicit fingerprint; the success flag is discarded."""
    await self._start_with_fingerprint(fingerprint=None)
async def _start_with_fingerprint(
    self,
    fingerprint: Optional[int] = None,
) -> bool:
    """
    Start the wallet for the key with ``fingerprint``.

    A falsy fingerprint falls back to the last used fingerprint. If that is also absent, the
    first key in the keychain is used. Returns False (after logging out) when no private key can
    be obtained. Otherwise, on success, it:
    - creates the WalletStateManager;
    - starts the subscription and retry background tasks;
    - logs in, fills the balance caches and ensures the initial puzzle hashes exist;
    - returns True.

    The fingerprint used for the database path, the resync check and the login always comes from
    the private key actually obtained. ``get_private_key`` falls back to the first key when the
    requested one is missing. Keeping the requested fingerprint in that case would open the
    wallet database of a different key than the one in use.
    """
    # Makes sure the coin_state_updates get higher priority than new_peak messages.
    # Delayed instantiation until here to avoid errors.
    # got Future <Future pending> attached to a different loop
    self._new_peak_queue = NewPeakQueue(inner_queue=asyncio.PriorityQueue())
    if not fingerprint:
        fingerprint = self.get_last_used_fingerprint()
    multiprocessing_start_method = process_config_start_method(config=self.config, log=self.log)
    multiprocessing_context = multiprocessing.get_context(method=multiprocessing_start_method)
    self._weight_proof_handler = WalletWeightProofHandler(self.constants, multiprocessing_context)
    self.synced_peers = set()
    private_key = await self.get_private_key(fingerprint)
    if private_key is None:
        self.log_out()
        return False
    # Override with the fingerprint of the key actually fetched. It differs from what was passed
    # when the requested key was missing and get_private_key fell back to the first key.
    fingerprint = private_key.get_g1().get_fingerprint()

    if self.config.get("enable_profiler", False):
        if sys.getprofile() is not None:
            self.log.warning("not enabling profiler, getprofile() is already set")
        else:
            asyncio.create_task(profile_task(self.root_path, "wallet", self.log))

    if self.config.get("enable_memory_profiler", False):
        asyncio.create_task(mem_profile_task(self.root_path, "wallet", self.log))

    path: Path = get_wallet_db_path(self.root_path, self.config, str(fingerprint))
    path.parent.mkdir(parents=True, exist_ok=True)
    if self.config.get("reset_sync_for_fingerprint") == fingerprint:
        await self.reset_sync_db(path, fingerprint)

    self._wallet_state_manager = await WalletStateManager.create(
        private_key,
        self.config,
        path,
        self.constants,
        self.server,
        self.root_path,
        self,
    )

    if self.state_changed_callback is not None:
        self.wallet_state_manager.set_callback(self.state_changed_callback)

    self.last_wallet_tx_resend_time = int(time.time())
    self.wallet_tx_resend_timeout_secs = self.config.get("tx_resend_timeout_secs", 60 * 60)
    self.wallet_state_manager.set_pending_callback(self._pending_tx_handler)
    self._shut_down = False
    self._process_new_subscriptions_task = asyncio.create_task(self._process_new_subscriptions())
    self._retry_failed_states_task = asyncio.create_task(self._retry_failed_states())

    self.sync_event = asyncio.Event()
    self.log_in(private_key)
    self.wallet_state_manager.state_changed("sync_changed")

    # Populate the balance caches for all wallets
    async with self.wallet_state_manager.lock:
        for wallet_id in self.wallet_state_manager.wallets:
            await self._update_balance_cache(wallet_id)

    # Make sure at least the initial set of puzzle hashes has been derived
    async with self.wallet_state_manager.puzzle_store.lock:
        index = await self.wallet_state_manager.puzzle_store.get_last_derivation_path()
        if index is None or index < self.wallet_state_manager.initial_num_public_keys - 1:
            await self.wallet_state_manager.create_more_puzzle_hashes(from_zero=True)

    if self.wallet_peers is None:
        self.initialize_wallet_peers()

    return True
def _close(self) -> None:
    """
    Begin shutdown.

    Logs out, sets ``_shut_down``, cancels weight-proof work and cancels the background tasks.
    Does not wait for any of them to finish.
    """
    self.log.info("self._close")
    self.log_out()
    self._shut_down = True
    handler = self._weight_proof_handler
    if handler is not None:
        handler.cancel_weight_proof_tasks()
    background_tasks = (
        self._process_new_subscriptions_task,
        self._retry_failed_states_task,
        self._secondary_peer_sync_task,
    )
    for task in background_tasks:
        if task is not None:
            task.cancel()
async def _await_closed(self, shutting_down: bool = True) -> None:
    """
    Finish tearing the node down. Steps, in order:

    - close all server connections;
    - wait for WalletPeers to close;
    - close and drop the wallet state manager;
    - only when ``shutting_down``: close and drop the keychain proxy, then pause briefly for
      aiohttp's graceful shutdown;
    - forget wallet peers and reset the balance cache.
    """
    self.log.info("self._await_closed")
    chia_server = self._server
    if chia_server is not None:
        await chia_server.close_all_connections()
    peers = self.wallet_peers
    if peers is not None:
        await peers.ensure_is_closed()
    manager = self._wallet_state_manager
    if manager is not None:
        await manager._await_closed()
        self._wallet_state_manager = None
    proxy = self._keychain_proxy
    if shutting_down and proxy is not None:
        self._keychain_proxy = None
        await proxy.close()
        # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
        await asyncio.sleep(0.5)
    self.wallet_peers = None
    self._balance_cache = {}
def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
    """Remember ``callback`` and, if a wallet state manager exists, register it there together with the pending-tx handler."""
    self.state_changed_callback = callback
    manager = self._wallet_state_manager
    if manager is None:
        return
    manager.set_callback(self.state_changed_callback)
    manager.set_pending_callback(self._pending_tx_handler)
def _pending_tx_handler(self) -> None:
    """Schedule a background resend of pending transactions; a no-op without a wallet state manager."""
    if self._wallet_state_manager is not None:
        asyncio.create_task(self._resend_queue())
async def _resend_queue(self) -> None:
    """
    Push pending transaction messages to connected full nodes that do not have them yet.

    A peer is skipped when either:
    - it already accepted the transaction, or
    - the same message, identified by ``std_hash(msg.data)``, is already recorded for it in
      ``_tx_messages_in_progress``.

    Each successful send is recorded there. Stops as soon as the node is shutting down or has
    lost its server / wallet state manager.
    """

    def stopped() -> bool:
        return self._shut_down or self._server is None or self._wallet_state_manager is None

    if stopped():
        return None
    for msg, sent_peers in await self._messages_to_resend():
        if stopped():
            return None
        msg_name: bytes32 = std_hash(msg.data)
        for peer in self.server.get_connections(NodeType.FULL_NODE):
            node_id = peer.peer_node_id
            if node_id in sent_peers:
                continue
            if msg_name in self._tx_messages_in_progress.get(node_id, ()):
                continue
            self.log.debug(f"sending: {msg}")
            await peer.send_message(msg)
            # Look the entry up again after the await: on_disconnect may have removed it meanwhile
            self._tx_messages_in_progress.setdefault(node_id, []).append(msg_name)
async def _messages_to_resend(self) -> List[Tuple[Message, Set[bytes32]]]:
    """
    Build ``send_transaction`` messages for transactions that still need to be sent.

    Each message is paired with the set of peer ids that already reported SUCCESS for it.
    Already-accepted transactions are included at most once per ``wallet_tx_resend_timeout_secs``.
    Records without a spend bundle are skipped. Returns an empty list when shut down or without
    a wallet state manager.
    """
    if self._wallet_state_manager is None or self._shut_down:
        return []
    now = int(time.time())
    include_accepted = self.last_wallet_tx_resend_time < now - self.wallet_tx_resend_timeout_secs
    if include_accepted:
        self.last_wallet_tx_resend_time = now
    records: List[TransactionRecord] = await self.wallet_state_manager.tx_store.get_not_sent(
        include_accepted_txs=include_accepted
    )
    success = MempoolInclusionStatus.SUCCESS.value
    to_resend: List[Tuple[Message, Set[bytes32]]] = []
    for record in records:
        bundle = record.spend_bundle
        if bundle is None:
            continue
        msg = make_msg(ProtocolMessageTypes.send_transaction, SendTransaction(bundle))
        accepted_by = {bytes32.from_hexstr(peer_hex) for peer_hex, status, _ in record.sent_to if status == success}
        to_resend.append((msg, accepted_by))
    return to_resend
async def _retry_failed_states(self) -> None:
    """
    Background loop that periodically re-applies coin states stored in the retry store.

    Every ``coin_state_retry_seconds``, all stored states are fetched. Each is re-added via
    ``wallet_state_manager.add_coin_states`` inside a DB write transaction:
    - it uses the full node the state came from, or any connected full node if that one is gone;
    - a stored ``fork_height`` of 0 is passed on as None.

    Runs until ``_shut_down`` is set. Cancellation is logged and re-raised. Other errors are
    logged and the loop continues with the next cycle.
    """
    while not self._shut_down:
        try:
            await asyncio.sleep(self.coin_state_retry_seconds)
            # Test the private attribute: the ``wallet_state_manager`` property raises RuntimeError
            # instead of returning None, so checking the property could never skip an iteration.
            if self._wallet_state_manager is None:
                continue
            states_to_retry = await self.wallet_state_manager.retry_store.get_all_states_to_retry()
            for state, peer_id, fork_height in states_to_retry:
                matching_peer = tuple(
                    p for p in self.server.get_connections(NodeType.FULL_NODE) if p.peer_node_id == peer_id
                )
                if len(matching_peer) == 0:
                    try:
                        peer = self.get_full_node_peer()
                        self.log.info(f"disconnected from peer {peer_id}, state will retry with {peer.peer_node_id}")
                    except ValueError:
                        self.log.info(f"disconnected from all peers, cannot retry state: {state}")
                        continue
                else:
                    peer = matching_peer[0]
                async with self.wallet_state_manager.db_wrapper.writer():
                    self.log.info(f"retrying coin_state: {state}")
                    await self.wallet_state_manager.add_coin_states(
                        [state], peer, None if fork_height == 0 else fork_height
                    )
        except asyncio.CancelledError:
            self.log.info("Retry task cancelled, exiting.")
            raise
        except Exception:
            # Any other error would otherwise end this background task for good and silently stop
            # all future retries; log it and try again on the next cycle instead.
            self.log.exception("Error while retrying failed coin states")
async def _process_new_subscriptions(self) -> None:
    """
    Background loop that drains ``new_peak_queue`` until ``_shut_down`` is set.

    Items are dispatched by type:
    - coin-id and puzzle-hash subscriptions are sent to every connected full node; any states
      returned are added under the wallet state manager lock;
    - full-node state updates go to ``state_update_received``;
    - new peaks go to ``new_peak_wallet``, followed by auto-claiming when ``auto_claim.enabled``
      is set in the config.

    Cancellation is logged and re-raised. Any other exception is logged, and the connection to
    the peer being handled (if any) is closed with code 9999 before the loop continues.
    """
    while not self._shut_down:
        # Here we process four types of messages in the queue, where the first one has higher priority (lower
        # number in the queue), and priority decreases for each type.
        peer: Optional[WSChiaConnection] = None
        item: Optional[NewPeakItem] = None
        try:
            peer, item = None, None
            item = await self.new_peak_queue.get()
            assert item is not None
            if item.item_type == NewPeakQueueTypes.COIN_ID_SUBSCRIPTION:
                self.log.debug("Pulled from queue: %s %s", item.item_type.name, item.data)
                # Subscriptions are the highest priority, because we don't want to process any more peaks or
                # state updates until we are sure that we subscribed to everything that we need to. Otherwise,
                # we might not be able to process some state.
                coin_ids: List[bytes32] = item.data
                # The loop rebinds `peer`, so on failure the handler below closes the peer being processed
                for peer in self.server.get_connections(NodeType.FULL_NODE):
                    coin_states: List[CoinState] = await subscribe_to_coin_updates(coin_ids, peer, uint32(0))
                    if len(coin_states) > 0:
                        async with self.wallet_state_manager.lock:
                            await self.add_states_from_peer(coin_states, peer)
            elif item.item_type == NewPeakQueueTypes.PUZZLE_HASH_SUBSCRIPTION:
                self.log.debug("Pulled from queue: %s %s", item.item_type.name, item.data)
                puzzle_hashes: List[bytes32] = item.data
                for peer in self.server.get_connections(NodeType.FULL_NODE):
                    # Puzzle hash subscription
                    coin_states = await subscribe_to_phs(puzzle_hashes, peer, uint32(0))
                    if len(coin_states) > 0:
                        async with self.wallet_state_manager.lock:
                            await self.add_states_from_peer(coin_states, peer)
            elif item.item_type == NewPeakQueueTypes.FULL_NODE_STATE_UPDATED:
                # Note: this can take a while when we have a lot of transactions. We want to process these
                # before new_peaks, since new_peak_wallet requires that we first obtain the state for that peak.
                self.log.debug("Pulled from queue: %s %s", item.item_type.name, item.data[0])
                # item.data[0] is the update (presumably a CoinStateUpdate), item.data[1] the sending peer
                coin_state_update = item.data[0]
                peer = item.data[1]
                assert peer is not None
                await self.state_update_received(coin_state_update, peer)
            elif item.item_type == NewPeakQueueTypes.NEW_PEAK_WALLET:
                self.log.debug("Pulled from queue: %s %s", item.item_type.name, item.data[0])
                # This can take a VERY long time, because it might trigger a long sync. It is OK if we miss some
                # subscriptions or state updates, since all subscriptions and state updates will be handled by
                # long_sync (up to the target height).
                new_peak = item.data[0]
                peer = item.data[1]
                assert peer is not None
                await self.new_peak_wallet(new_peak, peer)
                # Check if any coin needs auto spending
                if self.config.get("auto_claim", {}).get("enabled", False):
                    await self.wallet_state_manager.auto_claim_coins()
            else:
                self.log.debug("Pulled from queue: UNKNOWN %s", item.item_type)
                # The AssertionError is caught and logged by the generic handler below.
                # NOTE(review): under `python -O` this assert is stripped and the item is silently dropped.
                assert False
        except asyncio.CancelledError:
            self.log.info("Queue task cancelled, exiting.")
            raise
        except Exception as e:
            self.log.error(f"Exception handling {item}, {e} {traceback.format_exc()}")
            # Drop the connection to the peer that was being handled when the error occurred
            if peer is not None:
                await peer.close(9999)
def log_in(self, sk: PrivateKey) -> None:
    """Mark the node as logged in with ``sk`` and, best effort, persist its fingerprint as the last used one."""
    fingerprint = sk.get_g1().get_fingerprint()
    self.logged_in_fingerprint = fingerprint
    self.logged_in = True
    self.log.info(f"Wallet is logged in using key with fingerprint: {fingerprint}")
    try:
        self.update_last_used_fingerprint()
    except Exception:
        # Failing to remember the fingerprint must not prevent logging in
        self.log.exception("Non-fatal: Unable to update last used fingerprint.")
def log_out(self) -> None:
    """Forget the logged-in key."""
    self.logged_in = False
    self.logged_in_fingerprint = None
def update_last_used_fingerprint(self) -> None:
    """
    Write the logged-in fingerprint to the last-used-fingerprint file, creating parent directories.

    Raises AssertionError when nobody is logged in.
    """
    fingerprint = self.logged_in_fingerprint
    assert fingerprint is not None
    target_file = self.get_last_used_fingerprint_path()
    target_file.parent.mkdir(parents=True, exist_ok=True)
    contents = str(fingerprint)
    target_file.write_text(contents)
    self.log.info(f"Updated last used fingerprint: {fingerprint}")
def get_last_used_fingerprint(self) -> Optional[int]:
    """
    Read the last used fingerprint from disk.

    Returns None when the file does not exist. Read or parse errors are logged as non-fatal and
    also yield None.
    """
    try:
        source_file = self.get_last_used_fingerprint_path()
        if not source_file.exists():
            return None
        return int(source_file.read_text().strip())
    except Exception:
        self.log.exception("Non-fatal: Unable to read last used fingerprint.")
        return None
def get_last_used_fingerprint_path(self) -> Path:
    """Location of the ``last_used_fingerprint`` file: the directory of the configured ``database_path``."""
    db_dir = path_from_root(self.root_path, self.config["database_path"]).parent
    return db_dir / "last_used_fingerprint"
def set_server(self, server: ChiaServer) -> None:
    """Attach ``server`` to this node and set up wallet peer handling on it."""
    self._server = server
    # Needs the server assigned above: it installs on_connect and may start peer discovery
    self.initialize_wallet_peers()
def initialize_wallet_peers(self) -> None:
    """
    Install ``on_connect`` on the server and, when appropriate, start WalletPeers discovery.

    Discovery is created and started only if it is not running yet, ``connect_to_unknown_peers``
    is enabled (default True) and ``testing`` is off (default False). The network's default full
    node port comes from ``network_overrides``; it falls back to None when that entry is missing.
    """
    self.server.on_connect = self.on_connect
    network_name = self.config["selected_network"]
    try:
        default_port = self.config["network_overrides"]["config"][network_name]["default_full_node_port"]
    except KeyError:
        self.log.info("Default port field not found in config.")
        default_port = None
    connect_to_unknown_peers = self.config.get("connect_to_unknown_peers", True)
    testing = self.config.get("testing", False)
    if self.wallet_peers is not None or not connect_to_unknown_peers or testing:
        return
    self.wallet_peers = WalletPeers(
        self.server,
        self.config["target_peer_count"],
        PeerStoreResolver(
            self.root_path,
            self.config,
            selected_network=network_name,
            peers_file_path_key="wallet_peers_file_path",
            legacy_peer_db_path_key=WALLET_PEERS_PATH_KEY_DEPRECATED,
            default_peers_file_path="wallet/db/wallet_peers.dat",
        ),
        self.config["introducer_peer"],
        self.config.get("dns_servers", ["dns-introducer.chia.net"]),
        self.config["peer_connect_interval"],
        network_name,
        default_port,
        self.log,
    )
    asyncio.create_task(self.wallet_peers.start())
def on_disconnect(self, peer: WSChiaConnection) -> None:
    """
    Clean up after ``peer`` disconnects and notify listeners with "close_connection".

    Drops the peer's request cache, synced marker and in-progress transaction record. Losing a
    trusted peer also clears ``local_node_synced`` and re-runs wallet peer initialization.
    """
    if self.is_trusted(peer):
        self.local_node_synced = False
        self.initialize_wallet_peers()
    node_id = peer.peer_node_id
    self.peer_caches.pop(node_id, None)
    self.synced_peers.discard(node_id)
    self._tx_messages_in_progress.pop(node_id, None)
    self.wallet_state_manager.state_changed("close_connection")
async def on_connect(self, peer: WSChiaConnection) -> None:
    """
    Handle a newly connected peer.

    Rejected peers:
    - full nodes speaking a protocol version older than 0.0.33;
    - untrusted peers, once the local (trusted) node is synced.

    A rejected peer is closed and handling stops there. No pending transactions are pushed to it
    and it is not handed to WalletPeers as a live connection.

    Otherwise:
    - listeners are notified with "add_connection";
    - pending transactions the peer has not accepted yet are sent to it;
    - the peer is passed to WalletPeers when discovery is running.
    """
    if self._wallet_state_manager is None:
        return None

    # Clear any synced marker left over for this peer id
    if peer.peer_node_id in self.synced_peers:
        self.synced_peers.remove(peer.peer_node_id)

    if peer.protocol_version < Version("0.0.33"):
        self.log.info("Disconnecting, full node running old software")
        await peer.close()
        return None

    trusted = self.is_trusted(peer)
    if not trusted and self.local_node_synced:
        await peer.close()
        return None

    self.log.info(f"Connected peer {peer.get_peer_info()} is trusted: {trusted}")
    messages_peer_ids = await self._messages_to_resend()
    self.wallet_state_manager.state_changed("add_connection")
    for msg, peer_ids in messages_peer_ids:
        if peer.peer_node_id in peer_ids:
            continue
        await peer.send_message(msg)

    if self.wallet_peers is not None:
        await self.wallet_peers.on_connect(peer)
async def perform_atomic_rollback(self, fork_height: int, cache: Optional[PeerRequestCache] = None) -> None:
    """
    Roll the wallet back to ``fork_height`` inside a single DB write transaction.

    Inside the transaction it:
    - reverts wallet state via ``reorg_rollback``;
    - sets the finished-sync height to ``fork_height``;
    - clears request-cache entries after that height: in every peer's cache when ``cache`` is
      None, otherwise only in ``cache``.

    On success it also cleans block records and removes the wallets the rollback deleted from
    ``wallet_state_manager.wallets``. Exceptions are logged with their traceback and re-raised.
    Once the transaction has finished without error, a resend of pending transactions is
    scheduled.
    """
    self.log.info(f"perform_atomic_rollback to {fork_height}")
    # this is to start a write transaction
    async with self.wallet_state_manager.db_wrapper.writer():
        try:
            removed_wallet_ids = await self.wallet_state_manager.reorg_rollback(fork_height)
            await self.wallet_state_manager.blockchain.set_finished_sync_up_to(fork_height, in_rollback=True)
            if cache is None:
                self.rollback_request_caches(fork_height)
            else:
                cache.clear_after_height(fork_height)
        except Exception as e:
            tb = traceback.format_exc()
            self.log.error(f"Exception while perform_atomic_rollback: {e} {tb}")
            raise
        else:
            # Only reached when the rollback above succeeded; still inside the write transaction
            await self.wallet_state_manager.blockchain.clean_block_records()
            for wallet_id in removed_wallet_ids:
                self.wallet_state_manager.wallets.pop(wallet_id)

    # this has to be called *after* the transaction commits, otherwise it
    # won't see the changes (since we spawn a new task to handle potential
    # resends)
    self._pending_tx_handler()
async def long_sync(
    self,
    target_height: uint32,
    full_node: WSChiaConnection,
    fork_height: int,
    *,
    rollback: bool,
) -> None:
    """
    Sync algorithm:
    - Download and verify weight proof (if not trusted)
    - Roll back anything after the fork point (if rollback=True)
    - Subscribe to all puzzle_hashes over and over until there are no more updates
    - Subscribe to all coin_ids over and over until there are no more updates
    - rollback=False means that we are just double-checking with this peer to make sure we don't have any
      missing transactions, so we don't need to rollback

    NOTE(review): no weight-proof step is performed in this function body; presumably the caller
    handles it — confirm.

    The finished-sync height advances to ``target_height``, and ``full_node`` is recorded as
    synced, only after both subscription phases ran to completion. The function returns without
    recording progress in two cases:
    - adding states from the peer fails;
    - the node shuts down mid-sync, which makes both loops stop early.
    """

    def is_new_state_update(cs: CoinState) -> bool:
        # Keep states without any height, or created or spent at/after the fork height. States
        # entirely below the fork height are skipped to avoid slow reprocessing.
        if cs.spent_height is None and cs.created_height is None:
            return True
        if cs.spent_height is not None and cs.spent_height >= fork_height:
            return True
        if cs.created_height is not None and cs.created_height >= fork_height:
            return True
        return False

    trusted: bool = self.is_trusted(full_node)
    self.log.info(f"Starting sync trusted: {trusted} to peer {full_node.peer_info.host}")
    start_time = time.time()

    if rollback:
        # we should clear all peers since this is a full rollback
        await self.perform_atomic_rollback(fork_height)
        await self.update_ui()

    # We only process new state updates to avoid slow reprocessing. We set the sync height after adding
    # Things, so we don't have to reprocess these later. There can be many things in ph_update_res.
    already_checked_ph: Set[bytes32] = set()
    while not self._shut_down:
        # Repeat until a pass finds no puzzle hashes we have not subscribed to yet
        await self.wallet_state_manager.create_more_puzzle_hashes()
        all_puzzle_hashes = await self.get_puzzle_hashes_to_subscribe()
        not_checked_puzzle_hashes = set(all_puzzle_hashes) - already_checked_ph
        if not_checked_puzzle_hashes == set():
            break
        for batch in to_batches(not_checked_puzzle_hashes, 1000):
            ph_update_res: List[CoinState] = await subscribe_to_phs(batch.entries, full_node, uint32(0))
            ph_update_res = list(filter(is_new_state_update, ph_update_res))
            if not await self.add_states_from_peer(ph_update_res, full_node):
                # If something goes wrong, abort sync
                return
        already_checked_ph.update(not_checked_puzzle_hashes)

    self.log.info(f"Successfully subscribed and updated {len(already_checked_ph)} puzzle hashes")

    # The number of coin id updates are usually going to be significantly less than ph updates, so we can
    # sync from 0 every time.
    already_checked_coin_ids: Set[bytes32] = set()
    while not self._shut_down:
        all_coin_ids = await self.get_coin_ids_to_subscribe()
        not_checked_coin_ids = set(all_coin_ids) - already_checked_coin_ids
        if not_checked_coin_ids == set():
            break
        for batch in to_batches(not_checked_coin_ids, 1000):
            c_update_res: List[CoinState] = await subscribe_to_coin_updates(batch.entries, full_node, uint32(0))
            if not await self.add_states_from_peer(c_update_res, full_node):
                # If something goes wrong, abort sync
                return
        already_checked_coin_ids.update(not_checked_coin_ids)
    self.log.info(f"Successfully subscribed and updated {len(already_checked_coin_ids)} coin ids")

    if self._shut_down:
        # Both loops above stop early on shutdown, so the subscriptions may be incomplete: do not
        # record target_height as fully synced.
        self.log.info("Sync interrupted by shutdown, not updating the synced height")
        return

    # Only update this fully when the entire sync has completed
    await self.wallet_state_manager.blockchain.set_finished_sync_up_to(target_height)

    if trusted:
        self.local_node_synced = True

    self.wallet_state_manager.state_changed("new_block")

    self.synced_peers.add(full_node.peer_node_id)
    await self.update_ui()

    self.log.info(f"Sync (trusted: {trusted}) duration was: {time.time() - start_time}")
async def add_states_from_peer(
    self,
    items_input: List[CoinState],
    peer: WSChiaConnection,
    fork_height: Optional[uint32] = None,
    height: Optional[uint32] = None,
) -> bool:
    """
    Apply coin states received from ``peer`` to the wallet state manager.

    Trusted peers: states are written directly in chunks of up to 900. If ``fork_height`` and ``height`` are both
    given, ``fork_height != height - 1`` and the peer is already in ``synced_peers``, an atomic rollback to
    ``fork_height`` is performed first.

    Untrusted peers: when ``fork_height`` is given, states are parked in the peer's race cache; otherwise they are
    validated with ``validate_received_state_from_peer`` in chunks of 10 inside asyncio tasks (at most
    ``target_concurrent_tasks`` pending at once) and only valid states are added.

    Reorged states (``created_height is None``) are applied up front in both modes. Duplicate states are dropped and
    the rest are passed through ``filter_spam``.

    Returns False if applying reorged states or a trusted chunk fails, the server is gone, the peer disconnected,
    or shutdown was requested mid-way; otherwise returns whether the peer is still connected once all validation
    tasks have finished. Exceptions inside validation tasks are logged, not raised.
    """
    # Adds the state to the wallet state manager. If the peer is trusted, we do not validate. If the peer is
    # untrusted we do, but we might not add the state, since we need to receive the new_peak message as well.
    assert self._wallet_state_manager is not None
    trusted = self.is_trusted(peer)
    # Validate states in parallel, apply serial
    # TODO: optimize fetching
    if self.validation_semaphore is None:
        # Lazily created; bounds how many validate_and_add tasks validate at the same time
        self.validation_semaphore = asyncio.Semaphore(10)

    # Rollback is handled in wallet_short_sync_backtrack for untrusted peers, so we don't need to do it here.
    # Also it's not safe to rollback, an untrusted peer can give us old fork point and make our TX disappear.
    # wallet_short_sync_backtrack can safely rollback because we validated the weight for the new peak so we
    # know the peer is telling the truth about the reorg.

    # If there is a fork, we need to ensure that we roll back in trusted mode to properly handle reorgs
    cache: PeerRequestCache = self.get_cache_for_peer(peer)
    if (
        trusted
        and fork_height is not None
        and height is not None
        and fork_height != height - 1
        and peer.peer_node_id in self.synced_peers
    ):
        # only one peer told us to rollback so only clear for that peer
        await self.perform_atomic_rollback(fork_height, cache=cache)
    else:
        if fork_height is not None:
            # only one peer told us to rollback so only clear for that peer
            cache.clear_after_height(fork_height)
            self.log.info(f"clear_after_height {fork_height} for peer {peer}")
            if not trusted:
                # Rollback race_cache not in clear_after_height to avoid applying rollbacks from new peak processing
                cache.rollback_race_cache(fork_height=fork_height)

    all_tasks: List[asyncio.Task[None]] = []
    target_concurrent_tasks: int = 30

    # Ensure the list is sorted
    unique_items = set(items_input)
    before = len(unique_items)
    items = await self.wallet_state_manager.filter_spam(sort_coin_states(unique_items))
    num_filtered = before - len(items)
    if num_filtered > 0:
        self.log.info(f"Filtered {num_filtered} spam transactions")

    async def validate_and_add(inner_states: List[CoinState], inner_idx_start: int) -> None:
        # Validates one untrusted chunk and adds only the valid states. Never raises: failures are logged at
        # DEBUG when the peer is closed or we are shutting down, otherwise at ERROR.
        # Note: reads updated_coin_states, which is assigned further below before any task is created.
        try:
            assert self.validation_semaphore is not None
            async with self.validation_semaphore:
                valid_states = [
                    inner_state
                    for inner_state in inner_states
                    if await self.validate_received_state_from_peer(inner_state, peer, cache, fork_height)
                ]
            if len(valid_states) > 0:
                async with self.wallet_state_manager.db_wrapper.writer():
                    self.log.info(
                        f"new coin state received ({inner_idx_start}-"
                        f"{inner_idx_start + len(inner_states) - 1}/ {len(updated_coin_states)})"
                    )
                    await self.wallet_state_manager.add_coin_states(valid_states, peer, fork_height)
        except Exception as e:
            tb = traceback.format_exc()
            log_level = logging.DEBUG if peer.closed or self._shut_down else logging.ERROR
            self.log.log(log_level, f"validate_and_add failed - exception: {e}, traceback: {tb}")

    # Keep chunk size below 1000 just in case, windows has sqlite limits of 999 per query
    # Untrusted has a smaller batch size since validation has to happen which takes a while
    chunk_size: int = 900 if trusted else 10

    # Split into states without a creation height (reorged) and everything else
    reorged_coin_states = []
    updated_coin_states = []
    for coin_state in items:
        if coin_state.created_height is None:
            reorged_coin_states.append(coin_state)
        else:
            updated_coin_states.append(coin_state)

    # Reorged coin states don't require any validation in untrusted mode, so we can just always apply them upfront
    # instead of adding them to the race cache in untrusted mode.
    for batch in to_batches(reorged_coin_states, chunk_size):
        self.log.info(f"Process reorged states: ({len(batch.entries)} / {len(reorged_coin_states)})")
        if not await self.wallet_state_manager.add_coin_states(batch.entries, peer, fork_height):
            self.log.debug("Processing reorged states failed")
            return False

    # 1-based position of the first state in the current batch, used only for progress logging
    idx = 1
    for batch in to_batches(updated_coin_states, chunk_size):
        # Abort (after letting already-started validation tasks finish) if the server or the peer is gone
        if self._server is None:
            self.log.error("No server")
            await asyncio.gather(*all_tasks)
            return False
        if peer.peer_node_id not in self.server.all_connections:
            self.log.error(f"Disconnected from peer {peer.peer_node_id} host {peer.peer_info.host}")
            await asyncio.gather(*all_tasks)
            return False
        if trusted:
            async with self.wallet_state_manager.db_wrapper.writer():
                self.log.info(
                    f"new coin state received ({idx}-{idx + len(batch.entries) - 1}/ {len(updated_coin_states)})"
                )
                if not await self.wallet_state_manager.add_coin_states(batch.entries, peer, fork_height):
                    return False
        else:
            if fork_height is not None:
                # Parked until the matching new peak is processed (see sync_from_untrusted_close_to_peak)
                cache.add_states_to_race_cache(batch.entries)
            else:
                # Throttle: drop finished tasks and poll every 0.1s until fewer than target_concurrent_tasks remain
                while len(all_tasks) >= target_concurrent_tasks:
                    all_tasks = [task for task in all_tasks if not task.done()]
                    await asyncio.sleep(0.1)
                    if self._shut_down:
                        self.log.info("Terminating receipt and validation due to shut down request")
                        await asyncio.gather(*all_tasks)
                        return False
                all_tasks.append(asyncio.create_task(validate_and_add(batch.entries, idx)))
        idx += len(batch.entries)

    # Connectivity is checked both before and after waiting for the validation tasks
    still_connected = self._server is not None and peer.peer_node_id in self.server.all_connections
    await asyncio.gather(*all_tasks)
    await self.update_ui()
    return still_connected and self._server is not None and peer.peer_node_id in self.server.all_connections
def is_timestamp_in_sync(self, timestamp: uint64) -> bool:
    """
    Tell whether ``timestamp`` is recent enough (less than 600 seconds old) for the peak to count as synced.

    When the config has a truthy ``testing`` entry, that value is returned as-is and no time check is done.
    """
    testing_override = self.config.get("testing", False)
    if testing_override:
        return testing_override
    seconds_old = uint64(time.time()) - timestamp
    return seconds_old < 600
def is_trusted(self, peer: WSChiaConnection) -> bool:
    """Ask the server whether ``peer`` is trusted, given the ``trusted_peers`` config entry (default: empty)."""
    configured_trusted = self.config.get("trusted_peers", {})
    return self.server.is_trusted_peer(peer, configured_trusted)
async def state_update_received(self, request: CoinStateUpdate, peer: WSChiaConnection) -> None:
    """
    Handle a ``CoinStateUpdate`` pushed by ``peer``.

    This gets called every time there is a new coin or puzzle hash change in the DB that is of interest to this
    wallet. It is not guaranteed to come for every height, but it is guaranteed to come before the corresponding
    new_peak for each height. Trusted and untrusted peers are handled differently by ``add_states_from_peer``;
    for trusted peers the state is always processed, reorgs included.
    """
    coin_states = request.items
    for state in coin_states:
        self.log.info(f"request coin: {state.coin.name().hex()}{state}")

    wallet_lock = self.wallet_state_manager.lock
    async with wallet_lock:
        await self.add_states_from_peer(coin_states, peer, request.fork_height, request.height)
def get_full_node_peer(self) -> WSChiaConnection:
    """
    Return the single best full node connection.

    Preference: synced & trusted > synced & untrusted > unsynced & trusted > unsynced & untrusted.
    Raises ValueError when no full node is connected.
    """
    ranked_peers: List[WSChiaConnection] = self.get_full_node_peers_in_order()
    if not ranked_peers:
        raise ValueError("No peer connected")
    best_peer = ranked_peers[0]
    return best_peer
def get_full_node_peers_in_order(self) -> List[WSChiaConnection]:
    """
    Return all connected full nodes ordered by preference:
    synced & trusted > synced & untrusted > unsynced & trusted > unsynced & untrusted.

    Peers are shuffled before bucketing, so the order within one category is random.
    Returns an empty list when there is no server.
    """
    if self._server is None:
        return []

    # Buckets keyed by (we synced to it, it is trusted); dict insertion order gives the preference order
    buckets: Dict[Tuple[bool, bool], List[WSChiaConnection]] = {
        (True, True): [],
        (True, False): [],
        (False, True): [],
        (False, False): [],
    }
    candidates: List[WSChiaConnection] = self.server.get_connections(NodeType.FULL_NODE)
    random.shuffle(candidates)
    for candidate in candidates:
        category = (candidate.peer_node_id in self.synced_peers, bool(self.is_trusted(candidate)))
        buckets[category].append(candidate)

    return [node for bucket in buckets.values() for node in bucket]
async def get_timestamp_for_height_from_peer(self, height: uint32, peer: WSChiaConnection) -> Optional[uint64]:
    """
    Returns the timestamp for transaction block at h=height, if not transaction block, backtracks until it finds
    a transaction block.

    Cached timestamps and blocks from the peer's request cache are used first. Fetched blocks are added to that
    cache. If fetching fails at ``height`` itself, the search continues one block lower (the peer might be slightly
    behind but still synced). A fetch failure below ``height`` ends the search. Returns None if nothing was found.
    """
    cache = self.get_cache_for_peer(peer)
    current: int = height
    while current >= 0:
        lookup_height = uint32(current)

        known_timestamp = cache.get_height_timestamp(lookup_height)
        if known_timestamp is not None:
            return known_timestamp

        block = cache.get_block(lookup_height)
        if block is not None:
            self.log.debug(f"get_timestamp_for_height_from_peer use cached block for height {current}")
        else:
            self.log.debug(f"get_timestamp_for_height_from_peer cache miss for height {current}")
            fetched: Optional[List[HeaderBlock]] = await request_header_blocks(peer, lookup_height, lookup_height)
            if fetched is not None and len(fetched) > 0:
                self.log.debug(f"get_timestamp_for_height_from_peer add to cache for height {current}")
                cache.add_to_blocks(fetched[0])
                block = fetched[0]
            elif current < height:
                # Only a failure at the requested height itself is tolerated
                break

        if block is not None and block.foliage_transaction_block is not None:
            return block.foliage_transaction_block.timestamp
        current -= 1
    return None
async def get_timestamp_for_height(self, height: uint32) -> uint64:
    """
    Ask each full node peer in preference order for the timestamp at ``height`` and return the first one found.

    Raises PeerRequestException if no peer yields a timestamp.
    """
    for candidate in self.get_full_node_peers_in_order():
        found = await self.get_timestamp_for_height_from_peer(height, candidate)
        if found is None:
            continue
        return found
    raise PeerRequestException("Error fetching timestamp from all peers")
async def new_peak_wallet(self, new_peak: NewPeakWallet, peer: WSChiaConnection) -> None:
    """
    Process a new-peak announcement from ``peer``.

    The announcement is dropped if its weight is lower than our current peak. Otherwise the announced header
    block is fetched and cross-checked (hash, weight, height). The announcement is also dropped if the peak's
    timestamp is not in sync; an untrusted peer is closed in that case, and a peer that does not answer the
    header request is closed too. The sync work is then handed to ``new_peak_from_trusted`` or
    ``new_peak_from_untrusted``. Finally the finished-sync height is updated (for already-synced peers) and the
    wallet state manager is notified of the new height under its lock.
    """
    if self._wallet_state_manager is None:
        # When logging out of wallet
        self.log.debug("state manager is None (shutdown)")
        return
    # Computed once and reused below, rather than asking the server again later
    trusted: bool = self.is_trusted(peer)
    peak_hb: Optional[HeaderBlock] = await self.wallet_state_manager.blockchain.get_peak_block()
    if peak_hb is not None and new_peak.weight < peak_hb.weight:
        # Discards old blocks, but accepts blocks that are equal in weight to peak
        self.log.debug("skip block with lower weight.")
        return

    request = RequestBlockHeader(new_peak.height)
    response: Optional[RespondBlockHeader] = await peer.call_api(FullNodeAPI.request_block_header, request)
    if response is None:
        self.log.warning(f"Peer {peer.get_peer_info()} did not respond in time.")
        await peer.close(120)
        return

    new_peak_hb: HeaderBlock = response.header_block
    # check response is what we asked for
    if (
        new_peak_hb.header_hash != new_peak.header_hash
        or new_peak_hb.weight != new_peak.weight
        or new_peak_hb.height != new_peak.height
    ):
        self.log.warning(f"bad header block response from Peer {peer.get_peer_info()}.")
        # todo maybe accept the block if
        # new_peak_hb.height == new_peak.height and new_peak_hb.weight >= new_peak.weight

        # dont disconnect from peer, this might be a reorg
        return

    latest_timestamp = await self.get_timestamp_for_height_from_peer(new_peak_hb.height, peer)
    if latest_timestamp is None or not self.is_timestamp_in_sync(latest_timestamp):
        if trusted:
            self.log.debug(f"Trusted peer {peer.get_peer_info()} is not synced.")
        else:
            self.log.warning(f"Non-trusted peer {peer.get_peer_info()} is not synced, disconnecting")
            await peer.close(120)
        return

    if trusted:
        await self.new_peak_from_trusted(new_peak_hb, latest_timestamp, peer)
    else:
        if not await self.new_peak_from_untrusted(new_peak_hb, peer):
            return

    if peer.peer_node_id in self.synced_peers:
        await self.wallet_state_manager.blockchain.set_finished_sync_up_to(new_peak.height)
    # todo why do we call this if there was an exception / the sync is not finished
    async with self.wallet_state_manager.lock:
        await self.wallet_state_manager.new_peak(new_peak.height)
async def new_peak_from_trusted(
    self, new_peak_hb: HeaderBlock, latest_timestamp: uint64, peer: WSChiaConnection
) -> None:
    """
    Record ``new_peak_hb`` as our peak while in sync mode, and long-sync to ``peer`` if we never synced to it.

    The long sync starts 256 blocks below the height yielded by ``set_sync_mode`` (floored at 0) and uses rollback.
    """
    async with self.wallet_state_manager.set_sync_mode(new_peak_hb.height) as current_height:
        await self.wallet_state_manager.blockchain.set_peak_block(new_peak_hb, latest_timestamp)
        # Once synced to a trusted node (and still connected), we rely on it to keep sending state updates,
        # so there is nothing more to do.
        if peer.peer_node_id in self.synced_peers:
            return
        sync_start = uint32(max(0, current_height - 256))
        await self.long_sync(new_peak_hb.height, peer, sync_start, rollback=True)
async def new_peak_from_untrusted(self, new_peak_hb: HeaderBlock, peer: WSChiaConnection) -> bool:
    """
    Decide how to catch up to ``new_peak_hb`` announced by an untrusted ``peer``.

    Short chains (below WEIGHT_PROOF_RECENT_BLOCKS), and already-synced peers we are not far behind, are handled
    block by block via ``sync_from_untrusted_close_to_peak``. In every other case a weight-proof based long sync is
    run. It is the primary sync when we are far behind or have no synced peers, and a secondary sync otherwise.

    Returns False when a secondary sync is skipped because another one is still running, or when the long sync
    raises; in the latter case the error is logged and the peer is closed.
    """
    synced_up_to = await self.wallet_state_manager.blockchain.get_finished_sync_up_to()
    far_behind: bool = new_peak_hb.height - synced_up_to > self.LONG_SYNC_THRESHOLD

    # Short chains, and peers we already synced with and are close to, are fetched one block at a time
    short_chain = new_peak_hb.height < self.constants.WEIGHT_PROOF_RECENT_BLOCKS
    if short_chain or (not far_behind and peer.peer_node_id in self.synced_peers):
        return await self.sync_from_untrusted_close_to_peak(new_peak_hb, peer)

    # We haven't synced fully to this peer yet
    syncing = far_behind or len(self.synced_peers) == 0
    running_task = self._secondary_peer_sync_task
    secondary_sync_running = running_task is not None and not running_task.done()
    if secondary_sync_running and not syncing:
        self.log.info("Will not do secondary sync, there is already another sync task running.")
        return False

    try:
        await self.long_sync_from_untrusted(syncing, new_peak_hb, peer)
    except Exception:
        self.log.exception(f"Error syncing to {peer.get_peer_info()}")
        await peer.close()
        return False
    return True
async def long_sync_from_untrusted(self, syncing: bool, new_peak_hb: HeaderBlock, peer: WSChiaConnection) -> None:
    """
    Run a weight-proof based long sync against an untrusted ``peer``.

    The weight proof for ``new_peak_hb`` is always fetched and validated first.

    When ``syncing`` is True, the long sync runs inline in sync mode with rollback enabled. It starts from the
    earlier of our finished-sync height minus 16 (floored at 0) and the weight-proof fork point.

    Otherwise, a secondary sync from height 0 without rollback is started as a background task and stored in
    ``self._secondary_peer_sync_task``.
    """
    current_height: uint32 = await self.wallet_state_manager.blockchain.get_finished_sync_up_to()
    fork_point_weight_proof = await self.fetch_and_update_weight_proof(peer, new_peak_hb)

    if not syncing:
        # we exit earlier in the case where syncing is False and a Secondary sync is running
        assert self._secondary_peer_sync_task is None or self._secondary_peer_sync_task.done()
        self.log.info("Secondary peer syncing")
        # No rollback happens here, so it's OK to check older updates as well, to ensure that no recent
        # transactions are being hidden.
        self._secondary_peer_sync_task = asyncio.create_task(
            self.long_sync(new_peak_hb.height, peer, 0, rollback=False)
        )
        return

    # This usually happens the first time we start up the wallet: roll back slightly (16 blocks) to be safe, but
    # not too much. If the weight proof fork point is further in the past, roll back to it instead.
    rollback_floor: int = max(0, current_height - 16)
    sync_start = min(rollback_floor, fork_point_weight_proof)
    async with self.wallet_state_manager.set_sync_mode(new_peak_hb.height):
        await self.long_sync(new_peak_hb.height, peer, sync_start, rollback=True)
async def sync_from_untrusted_close_to_peak(self, new_peak_hb: HeaderBlock, peer: WSChiaConnection) -> bool:
    """
    Catch up to ``new_peak_hb`` from an untrusted ``peer`` when we are close to it. Everything runs while
    holding the wallet state manager lock.

    1. Find the fork height. If the new peak is heavier than ours (or we have no peak), use
       ``wallet_short_sync_backtrack``; otherwise use the new peak's height minus one.
    2. For a peer we have not synced with yet, subscribe to all puzzle hashes and coin ids from height 0 and add
       the returned states. The peer is marked as synced if that succeeds.
    3. For an already-synced peer, stop early (returning False) if the new peak is not heavier than ours.
    4. Replay race-cache entries for heights above the fork up to the new peak, then prune the race cache.

    Returns False only in the early-stop case of step 3; True otherwise.
    """
    async with self.wallet_state_manager.lock:
        peak_hb = await self.wallet_state_manager.blockchain.get_peak_block()
        if peak_hb is None or new_peak_hb.weight > peak_hb.weight:
            backtrack_fork_height: int = await self.wallet_short_sync_backtrack(new_peak_hb, peer)
        else:
            backtrack_fork_height = new_peak_hb.height - 1
        cache = self.get_cache_for_peer(peer)
        if peer.peer_node_id not in self.synced_peers:
            # Edge case, this happens when the peak < WEIGHT_PROOF_RECENT_BLOCKS
            # we still want to subscribe for all phs and coins.
            # (Hints are not in filter)
            all_coin_ids: List[bytes32] = await self.get_coin_ids_to_subscribe()
            phs: List[bytes32] = await self.get_puzzle_hashes_to_subscribe()
            ph_updates: List[CoinState] = await subscribe_to_phs(phs, peer, uint32(0))
            coin_updates: List[CoinState] = await subscribe_to_coin_updates(all_coin_ids, peer, uint32(0))
            # Clamped at 0: the backtrack can report a negative fork height
            success = await self.add_states_from_peer(
                ph_updates + coin_updates,
                peer,
                fork_height=uint32(max(backtrack_fork_height, 0)),
            )
            if success:
                self.synced_peers.add(peer.peer_node_id)
        else:
            if peak_hb is not None and new_peak_hb.weight <= peak_hb.weight:
                # Don't process blocks at the same weight
                return False

        # For every block, we need to apply the cache from race_cache
        for potential_height in range(backtrack_fork_height + 1, new_peak_hb.height + 1):
            try:
                race_cache = cache.get_race_cache(potential_height)
            except KeyError:
                # Nothing cached for this height
                continue

            self.log.info(f"Apply race cache - height: {potential_height}, coin_states: {race_cache}")
            await self.add_states_from_peer(list(race_cache), peer)

        # Clear old entries that are no longer relevant
        cache.cleanup_race_cache(min_height=backtrack_fork_height)
        self.wallet_state_manager.state_changed("new_block")
        self.log.info(f"Finished processing new peak of {new_peak_hb.height}")
        return True
async def wallet_short_sync_backtrack(self, header_block: HeaderBlock, peer: WSChiaConnection) -> int:
    """
    Walk back from ``header_block`` through ``peer`` until reaching a block whose parent is already in our
    blockchain. Then roll back local state if the resulting fork height is below our finished-sync height, and
    add the fetched blocks (oldest first) to the blockchain.

    Returns the fork height. It can be -1 if the walk fetched the height-0 block without finding a known parent;
    ``sync_from_untrusted_close_to_peak`` clamps it with ``max(..., 0)``.

    Raises RuntimeError if the peer returns a missing or unexpected header response.
    Raises ValueError if a fetched block is rejected as invalid.
    """
    peak: Optional[HeaderBlock] = await self.wallet_state_manager.blockchain.get_peak_block()

    top = header_block
    blocks = [top]
    # Fetch blocks backwards until we hit the one that we have,
    # then complete them with additions / removals going forward
    fork_height = 0
    if self.wallet_state_manager.blockchain.contains_block(header_block.prev_header_hash):
        fork_height = header_block.height - 1

    while not self.wallet_state_manager.blockchain.contains_block(top.prev_header_hash) and top.height > 0:
        request_prev = RequestBlockHeader(uint32(top.height - 1))
        response_prev: Optional[RespondBlockHeader] = await peer.call_api(
            FullNodeAPI.request_block_header, request_prev
        )
        if response_prev is None or not isinstance(response_prev, RespondBlockHeader):
            raise RuntimeError("bad block header response from peer while syncing")
        prev_head = response_prev.header_block
        blocks.append(prev_head)
        top = prev_head
        fork_height = top.height - 1

    # Oldest block first, so they are added in chain order below
    blocks.reverse()
    # Roll back coins and transactions
    peak_height = await self.wallet_state_manager.blockchain.get_finished_sync_up_to()
    if fork_height < peak_height:
        self.log.info(f"Rolling back to {fork_height}")
        # we should clear all peers since this is a full rollback
        await self.perform_atomic_rollback(fork_height)
        await self.update_ui()

    if peak is not None:
        # NOTE(review): assert-based check; it is skipped when Python runs with -O
        assert header_block.weight >= peak.weight
    for block in blocks:
        # Set blockchain to the latest peak
        res, err = await self.wallet_state_manager.blockchain.add_block(block)
        if res == AddBlockResult.INVALID_BLOCK:
            raise ValueError(err)

    return fork_height
async def update_ui(self) -> None:
    """Emit a ``coin_removed`` then a ``coin_added`` state-change notification for every wallet id."""
    for wallet_id in self.wallet_state_manager.wallets:
        for event_name in ("coin_removed", "coin_added"):
            self.wallet_state_manager.state_changed(event_name, wallet_id)
async def fetch_and_update_weight_proof(self, peer: WSChiaConnection, peak: HeaderBlock) -> int:
    """
    Fetch a weight proof for ``peak`` from ``peer``, validate it, and store it as the wallet's synced weight proof.

    The request timeout comes from the ``weight_proof_timeout`` config value (default 360 seconds). The last
    ``recent_chain_data`` entry of the proof must match ``peak`` in height, weight and header hash.

    Returns the fork point between the previously synced weight proof and the new one (``get_wp_fork_point``).

    Raises ``Exception`` if the peer does not answer, answers with an unexpected message type, sends a proof
    without recent chain data, or sends a proof that does not end at ``peak``.
    """
    assert self._weight_proof_handler is not None
    weight_request = RequestProofOfWeight(peak.height, peak.header_hash)
    wp_timeout = self.config.get("weight_proof_timeout", 360)
    self.log.debug(f"weight proof timeout is {wp_timeout} sec")
    # call_api can return None (no answer), so the result is Optional
    weight_proof_response: Optional[RespondProofOfWeight] = await peer.call_api(
        FullNodeAPI.request_proof_of_weight, weight_request, timeout=wp_timeout
    )

    if weight_proof_response is None:
        raise Exception("weight proof response was none")
    # Same defensive type check as the other peer requests in this file
    if not isinstance(weight_proof_response, RespondProofOfWeight):
        raise Exception(f"unexpected weight proof response type {type(weight_proof_response).__name__}")

    weight_proof = weight_proof_response.wp
    # The peer may be untrusted: reject an empty recent chain explicitly instead of failing with IndexError
    if len(weight_proof.recent_chain_data) == 0:
        raise Exception("weight proof has no recent chain data")
    wp_peak = weight_proof.recent_chain_data[-1]
    if wp_peak.height != peak.height:
        raise Exception("weight proof height does not match peak")
    if wp_peak.weight != peak.weight:
        raise Exception("weight proof weight does not match peak")
    if wp_peak.header_hash != peak.header_hash:
        raise Exception("weight proof peak hash does not match peak")

    old_proof = self.wallet_state_manager.blockchain.synced_weight_proof
    block_records = await self._weight_proof_handler.validate_weight_proof(weight_proof, False, old_proof)
    await self.wallet_state_manager.blockchain.new_valid_weight_proof(weight_proof, block_records)

    return get_wp_fork_point(self.constants, old_proof, weight_proof)
async def get_puzzle_hashes_to_subscribe(self) -> List[bytes32]:
    """
    Collect the puzzle hashes to subscribe to. These are the puzzle hashes the puzzle store returns for
    ``get_all_puzzle_hashes(1)``, plus the first element of every row from the interested store. Duplicates are
    removed.
    """
    wsm = self.wallet_state_manager
    puzzle_hashes = await wsm.puzzle_store.get_all_puzzle_hashes(1)
    interested_rows = await wsm.interested_store.get_interested_puzzle_hashes()
    puzzle_hashes.update(row[0] for row in interested_rows)
    return list(puzzle_hashes)
async def get_coin_ids_to_subscribe(self) -> List[bytes32]:
    """Collect the coin ids to subscribe to: the trade manager's coins of interest plus the interested store's."""
    wsm = self.wallet_state_manager
    coin_ids = await wsm.trade_manager.get_coins_of_interest()
    extra_ids = await wsm.interested_store.get_interested_coin_ids()
    coin_ids.update(extra_ids)
    return list(coin_ids)
async def validate_received_state_from_peer(
    self,
    coin_state: CoinState,
    peer: WSChiaConnection,
    peer_request_cache: PeerRequestCache,
    fork_height: Optional[uint32],
) -> bool:
    """
    Returns True if the coin_state is valid and included in the blockchain proved by the weight proof.

    The creation block is taken from ``peer_request_cache`` or fetched from ``peer``. The coin's addition is
    checked against that block's additions root. If the coin is spent, the same is done for the spend block and
    its removals root. Inclusion in the chain is checked with ``validate_block_inclusion``. If the coin was
    previously known as spent but is now reported unspent, the old spend block is re-validated as well.

    Returns False (without raising) in these cases:
    - the peer is closed;
    - the peer does not return a requested header block (``None`` or an empty list);
    - any check fails.

    Peers whose additions/removals proofs fail are closed. Successfully validated states are recorded in
    ``peer_request_cache``.
    """
    if peer.closed:
        return False
    # Only use the cache if we are talking about states before the fork point. If we are evaluating something
    # in a reorg, we cannot use the cache, since we don't know if it's actually in the new chain after the reorg.
    if can_use_peer_request_cache(coin_state, peer_request_cache, fork_height):
        return True

    spent_height: Optional[uint32] = None if coin_state.spent_height is None else uint32(coin_state.spent_height)
    confirmed_height: Optional[uint32] = (
        None if coin_state.created_height is None else uint32(coin_state.created_height)
    )
    current = await self.wallet_state_manager.coin_store.get_coin_record(coin_state.coin.name())
    # if remote state is same as current local state we skip validation

    # CoinRecord unspent = height 0, coin state = None. We adjust for comparison below
    current_spent_height = None
    if current is not None and current.spent_block_height != 0:
        current_spent_height = current.spent_block_height

    # Same as current state, nothing to do
    if (
        current is not None
        and current_spent_height == spent_height
        and current.confirmed_block_height == confirmed_height
    ):
        peer_request_cache.add_to_states_validated(coin_state)
        return True

    reorg_mode = False

    # If coin was removed from the blockchain
    if confirmed_height is None:
        if current is None:
            # Coin does not exist in local DB, so no need to do anything
            return False
        # This coin got reorged
        reorg_mode = True
        confirmed_height = current.confirmed_block_height

    # request header block for created height
    state_block: Optional[HeaderBlock] = peer_request_cache.get_block(confirmed_height)
    if state_block is None or reorg_mode:
        state_blocks = await request_header_blocks(peer, confirmed_height, confirmed_height)
        # An untrusted peer may answer with an empty list as well as None; both mean "no block"
        if not state_blocks:
            return False
        state_block = state_blocks[0]
        assert state_block is not None
        peer_request_cache.add_to_blocks(state_block)

    # get proof of inclusion
    assert state_block.foliage_transaction_block is not None
    validate_additions_result = await request_and_validate_additions(
        peer,
        peer_request_cache,
        state_block.height,
        state_block.header_hash,
        coin_state.coin.puzzle_hash,
        state_block.foliage_transaction_block.additions_root,
    )

    if validate_additions_result is False:
        self.log.warning("Validate false 1")
        await peer.close(9999)
        return False

    # If spent_height is None, we need to validate that the creation block is actually in the longest blockchain.
    # Otherwise, we don't have to, since we will validate the spent block later.
    if coin_state.spent_height is None:
        validated = await self.validate_block_inclusion(state_block, peer, peer_request_cache)
        if not validated:
            return False

    # TODO: make sure all cases are covered
    if current is not None:
        if spent_height is None and current.spent_block_height != 0:
            # Peer is telling us that coin that was previously known to be spent is not spent anymore
            # Check old state

            spent_state_blocks: Optional[List[HeaderBlock]] = await request_header_blocks(
                peer, current.spent_block_height, current.spent_block_height
            )
            if not spent_state_blocks:
                return False
            spent_state_block = spent_state_blocks[0]
            assert spent_state_block.height == current.spent_block_height
            assert spent_state_block.foliage_transaction_block is not None
            peer_request_cache.add_to_blocks(spent_state_block)

            validate_removals_result: bool = await request_and_validate_removals(
                peer,
                current.spent_block_height,
                spent_state_block.header_hash,
                coin_state.coin.name(),
                spent_state_block.foliage_transaction_block.removals_root,
            )
            if validate_removals_result is False:
                self.log.warning("Validate false 2")
                await peer.close(9999)
                return False
            validated = await self.validate_block_inclusion(spent_state_block, peer, peer_request_cache)
            if not validated:
                return False

    if spent_height is not None:
        # request header block for spent height
        cached_spent_state_block = peer_request_cache.get_block(spent_height)
        if cached_spent_state_block is None:
            spent_state_blocks = await request_header_blocks(peer, spent_height, spent_height)
            if not spent_state_blocks:
                return False
            spent_state_block = spent_state_blocks[0]
            assert spent_state_block.height == spent_height
            assert spent_state_block.foliage_transaction_block is not None
            peer_request_cache.add_to_blocks(spent_state_block)
        else:
            spent_state_block = cached_spent_state_block
        assert spent_state_block is not None
        assert spent_state_block.foliage_transaction_block is not None
        validate_removals_result = await request_and_validate_removals(
            peer,
            spent_state_block.height,
            spent_state_block.header_hash,
            coin_state.coin.name(),
            spent_state_block.foliage_transaction_block.removals_root,
        )
        if validate_removals_result is False:
            self.log.warning("Validate false 3")
            await peer.close(9999)
            return False
        validated = await self.validate_block_inclusion(spent_state_block, peer, peer_request_cache)
        if not validated:
            return False
    peer_request_cache.add_to_states_validated(coin_state)
    return True
async def validate_block_inclusion(
    self, block: HeaderBlock, peer: WSChiaConnection, peer_request_cache: PeerRequestCache
) -> bool:
    """
    Return True if ``block`` can be shown to be part of our chain or of the chain committed to by the synced
    weight proof.

    Checks, in order:
    1. ``block`` is the block stored locally at its height.
    2. ``block`` lies within the weight proof's ``recent_chain_data`` and matches the header hash there.
    3. Otherwise, header blocks from ``block.height + 1`` are fetched from all connected full nodes. The range
       ends either at the first recent-chain block (when within 1000 blocks of it) or at the end of the relevant
       sub-epoch(s). The fetched chain is then checked:
       - the last block matches the weight proof (recent-chain header hash, or the sub-epoch
         ``reward_chain_hash``);
       - every block links back to ``block`` via prev header hashes and reward-chain challenges;
       - the foliage signatures of up to the first 30 fetched blocks pass an aggregate BLS check.

    Returns False when no weight proof is synced, blocks cannot be fetched, or any check fails. Validated
    signatures and reward-chain links are recorded in ``peer_request_cache`` so later calls can stop early.
    """
    # Fast path: the block is the one we already store at this height
    if self.wallet_state_manager.blockchain.contains_height(block.height):
        stored_hash = self.wallet_state_manager.blockchain.height_to_hash(block.height)
        stored_record = self.wallet_state_manager.blockchain.try_block_record(stored_hash)
        if stored_record is not None:
            if stored_record.header_hash == block.header_hash:
                return True

    weight_proof: Optional[WeightProof] = self.wallet_state_manager.blockchain.synced_weight_proof
    if weight_proof is None:
        return False

    if block.height >= weight_proof.recent_chain_data[0].height:
        # this was already validated as part of the wp validation
        index = block.height - weight_proof.recent_chain_data[0].height
        if index >= len(weight_proof.recent_chain_data):
            return False
        if weight_proof.recent_chain_data[index].header_hash != block.header_hash:
            self.log.error("Failed validation 1")
            return False
        return True

    # block is not included in wp recent chain
    start = uint32(block.height + 1)
    compare_to_recent = False
    # Index into weight_proof.sub_epochs whose reward_chain_hash the fetched range must end at
    inserted: int = 0
    first_height_recent = weight_proof.recent_chain_data[0].height
    if start > first_height_recent - 1000:
        # compare up to weight_proof.recent_chain_data[0].height
        compare_to_recent = True
        end = first_height_recent
    else:
        # get ses from wp
        start_height = block.height
        end_height = block.height + 32
        ses_start_height = 0
        end = uint32(0)
        for idx, ses in enumerate(weight_proof.sub_epochs):
            if idx == len(weight_proof.sub_epochs) - 1:
                break
            next_ses_height = uint32(
                (idx + 1) * self.constants.SUB_EPOCH_BLOCKS + weight_proof.sub_epochs[idx + 1].num_blocks_overflow
            )
            # start_ses_hash
            if ses_start_height <= start_height < next_ses_height:
                inserted = idx + 1
                if ses_start_height < end_height < next_ses_height:
                    end = next_ses_height
                    break
                else:
                    if idx > len(weight_proof.sub_epochs) - 3:
                        break
                    # else add extra ses as request start <-> end spans two ses
                    end = uint32(
                        (idx + 2) * self.constants.SUB_EPOCH_BLOCKS
                        + weight_proof.sub_epochs[idx + 2].num_blocks_overflow
                    )
                    inserted += 1
                    break
            ses_start_height = next_ses_height

        if end == 0:
            self.log.error("Error finding sub epoch")
            return False
    # Blocks may be fetched from any connected full node, each tagged with whether it is trusted
    all_peers_c = self.server.get_connections(NodeType.FULL_NODE)
    all_peers = [(con, self.is_trusted(con)) for con in all_peers_c]
    blocks: Optional[List[HeaderBlock]] = await fetch_header_blocks_in_range(
        start, end, peer_request_cache, all_peers
    )
    if blocks is None:
        log_level = logging.DEBUG if self._shut_down or peer.closed else logging.ERROR
        self.log.log(log_level, f"Error fetching blocks {start} {end}")
        return False

    # NOTE(review): blocks[-1] below assumes a non-empty result, and the sub-epoch check assumes the last block
    # has finished sub slots; both would raise IndexError otherwise. Confirm fetch_header_blocks_in_range
    # guarantees this.
    if compare_to_recent and weight_proof.recent_chain_data[0].header_hash != blocks[-1].header_hash:
        self.log.error("Failed validation 3")
        return False

    if not compare_to_recent:
        last = blocks[-1].finished_sub_slots[-1].reward_chain.get_hash()
        if last != weight_proof.sub_epochs[inserted].reward_chain_hash:
            self.log.error("Failed validation 4")
            return False
    # (plot public key, message, signature) triples to verify in one aggregate check
    pk_m_sig: List[Tuple[G1Element, bytes32, G2Element]] = []
    sigs_to_cache: List[HeaderBlock] = []
    blocks_to_cache: List[Tuple[bytes32, uint32]] = []

    signatures_to_validate: int = 30
    for idx in range(len(blocks)):
        en_block = blocks[idx]
        if idx < signatures_to_validate and not peer_request_cache.in_block_signatures_validated(en_block):
            # Validate that the block is buried in the foliage by checking the signatures
            pk_m_sig.append(
                (
                    en_block.reward_chain_block.proof_of_space.plot_public_key,
                    en_block.foliage.foliage_block_data.get_hash(),
                    en_block.foliage.foliage_block_data_signature,
                )
            )
            sigs_to_cache.append(en_block)

        # This is the reward chain challenge. If this is in the cache, it means the prev block
        # has been validated. We must at least check the first block to ensure they are connected
        reward_chain_hash: bytes32 = en_block.reward_chain_block.reward_chain_ip_vdf.challenge
        if idx != 0 and peer_request_cache.in_blocks_validated(reward_chain_hash):
            # As soon as we see a block we have already concluded is in the chain, we can quit.
            if idx > signatures_to_validate:
                break
        else:
            # Validate that the block is committed to by the weight proof
            if idx == 0:
                prev_block_rc_hash: bytes32 = block.reward_chain_block.get_hash()
                prev_hash = block.header_hash
            else:
                prev_block_rc_hash = blocks[idx - 1].reward_chain_block.get_hash()
                prev_hash = blocks[idx - 1].header_hash

            if not en_block.prev_header_hash == prev_hash:
                self.log.error("Failed validation 5")
                return False

            if len(en_block.finished_sub_slots) > 0:
                # Walk the finished sub slots from newest to oldest: each slot's end-of-slot VDF challenge must be
                # the hash of the slot before it, and the oldest must chain from the previous block.
                reversed_slots = en_block.finished_sub_slots.copy()
                reversed_slots.reverse()
                for slot_idx, slot in enumerate(reversed_slots[:-1]):
                    hash_val = reversed_slots[slot_idx + 1].reward_chain.get_hash()
                    if not hash_val == slot.reward_chain.end_of_slot_vdf.challenge:
                        self.log.error("Failed validation 6")
                        return False
                if not prev_block_rc_hash == reversed_slots[-1].reward_chain.end_of_slot_vdf.challenge:
                    self.log.error("Failed validation 7")
                    return False
            else:
                if not prev_block_rc_hash == reward_chain_hash:
                    self.log.error("Failed validation 8")
                    return False
            blocks_to_cache.append((reward_chain_hash, en_block.height))

    agg_sig: G2Element = AugSchemeMPL.aggregate([sig for (_, _, sig) in pk_m_sig])
    if not AugSchemeMPL.aggregate_verify([pk for (pk, _, _) in pk_m_sig], [m for (_, m, _) in pk_m_sig], agg_sig):
        self.log.error("Failed signature validation")
        return False
    # Only cache results once every check above has passed
    for header_block in sigs_to_cache:
        peer_request_cache.add_to_block_signatures_validated(header_block)
    for reward_chain_hash, height in blocks_to_cache:
        peer_request_cache.add_to_blocks_validated(reward_chain_hash, height)
    return True
async def get_coin_state(
    self, coin_names: List[bytes32], peer: WSChiaConnection, fork_height: Optional[uint32] = None
) -> List[CoinState]:
    """
    Register interest in ``coin_names`` with ``peer`` (from height 0) and return the coin states it reports.

    For an untrusted peer, only the states that pass ``validate_received_state_from_peer`` are returned.
    Raises PeerRequestException if the peer gives no response, or a response of the wrong type.
    """
    request = RegisterForCoinUpdates(coin_names, uint32(0))
    response: Optional[RespondToCoinUpdates] = await peer.call_api(FullNodeAPI.register_interest_in_coin, request)
    if response is None or not isinstance(response, RespondToCoinUpdates):
        raise PeerRequestException(f"Was not able to get states for {coin_names}")

    if self.is_trusted(peer):
        return response.coin_states
    return [
        state
        for state in response.coin_states
        if await self.validate_received_state_from_peer(state, peer, self.get_cache_for_peer(peer), fork_height)
    ]
async def fetch_children(
    self, coin_name: bytes32, peer: WSChiaConnection, fork_height: Optional[uint32] = None
) -> List[CoinState]:
    """Request the child coin states of ``coin_name`` from ``peer``.

    The request is a ``RequestChildren`` sent through ``FullNodeAPI.request_children``.
    For an untrusted peer, every returned state is checked with
    ``validate_received_state_from_peer`` using the peer's request cache, and only the
    accepted states are returned, in their original order. A trusted peer's states are
    returned without validation.

    Args:
        coin_name: ID of the parent coin.
        peer: Connection the request is sent to.
        fork_height: Forwarded to ``validate_received_state_from_peer`` for untrusted peers.

    Returns:
        The list of child ``CoinState`` entries.

    Raises:
        PeerRequestException: If the response is missing or is not a ``RespondChildren``.
    """
    response: Optional[RespondChildren] = await peer.call_api(
        FullNodeAPI.request_children, RequestChildren(coin_name)
    )
    # isinstance(None, ...) is False, so this also rejects a missing response.
    if not isinstance(response, RespondChildren):
        raise PeerRequestException(f"Was not able to obtain children {response}")

    if self.is_trusted(peer):
        return response.coin_states

    request_cache = self.get_cache_for_peer(peer)
    return [
        child_state
        for child_state in response.coin_states
        if await self.validate_received_state_from_peer(child_state, peer, request_cache, fork_height)
    ]
# For RPC only. You should use wallet_state_manager.add_pending_transaction for normal wallet business.
async def push_tx(self, spend_bundle: SpendBundle) -> None:
    """Broadcast ``spend_bundle`` to every connected full node.

    One ``send_transaction`` message wrapping a ``SendTransaction`` is built, then sent to
    each connection returned by ``self.server.get_connections(NodeType.FULL_NODE)``. The
    sends happen one at a time, in order. An exception from ``send_message`` propagates,
    and the remaining peers are not sent to.
    """
    tx_message = make_msg(ProtocolMessageTypes.send_transaction, SendTransaction(spend_bundle))
    for full_node_peer in self.server.get_connections(NodeType.FULL_NODE):
        await full_node_peer.send_message(tx_message)
async def _update_balance_cache(self, wallet_id: uint32) -> None:
    """Recompute the balance summary for ``wallet_id`` and store it in ``self._balance_cache``.

    The caller must already hold ``self.wallet_state_manager.lock``; this is asserted on entry.
    Unspent coins are loaded with ``CoinType.CRCAT`` for CR-CAT wallets and
    ``CoinType.NORMAL`` for every other wallet type. The same coin records are then fed
    to the wallet's balance getters.
    """
    assert self.wallet_state_manager.lock.locked(), "WalletStateManager.lock required"
    wallet = self.wallet_state_manager.wallets[wallet_id]
    coin_type = CoinType.CRCAT if wallet.type() == WalletType.CRCAT else CoinType.NORMAL
    unspent = await self.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(wallet_id, coin_type)

    confirmed = await wallet.get_confirmed_balance(unspent)
    unconfirmed = await wallet.get_unconfirmed_balance(unspent)
    spendable = await wallet.get_spendable_balance(unspent)
    change = await wallet.get_pending_change_balance()
    max_sendable = await wallet.get_max_send_amount(unspent)
    pending_removals: Dict[bytes32, Coin] = await (
        wallet.wallet_state_manager.unconfirmed_removals_for_wallet(wallet_id)
    )

    self._balance_cache[wallet_id] = Balance(
        confirmed_wallet_balance=confirmed,
        unconfirmed_wallet_balance=unconfirmed,
        spendable_balance=spendable,
        pending_change=change,
        max_send_amount=max_sendable,
        unspent_coin_count=uint32(len(unspent)),
        pending_coin_removal_count=uint32(len(pending_removals)),
    )
async def get_balance(self, wallet_id: uint32) -> Balance:
    """Return the balance summary for ``wallet_id``.

    When the wallet state manager is not in sync mode, the cached entry is first
    refreshed via ``_update_balance_cache`` while holding ``wallet_state_manager.lock``.
    In sync mode the cache is read as-is. If the cache has no entry for the wallet, a
    default ``Balance()`` is returned.
    """
    self.log.debug(f"get_balance - wallet_id: {wallet_id}")
    if self.wallet_state_manager.sync_mode:
        # No refresh while syncing: serve whatever is already cached.
        return self._balance_cache.get(wallet_id, Balance())

    self.log.debug(f"get_balance - Updating cache for {wallet_id}")
    async with self.wallet_state_manager.lock:
        await self._update_balance_cache(wallet_id)
    return self._balance_cache.get(wallet_id, Balance())