refactor @api_request decorator and family to use a single decorator (#13565)

* more explicit and complete handling of api decorator data

* fix

* .message_class

* actually, those are different types...

* tweak

* simplify

* learn that functools.wraps copies random attributes

* hack the ~planet~ `@api_request` decorator

* R not T

* more future

* implementation detail renames
This commit is contained in:
Kyle Altendorf 2022-10-20 16:10:23 -04:00 committed by GitHub
parent 7abc062d0d
commit bae4e0c5ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 179 additions and 284 deletions

View File

@ -31,7 +31,7 @@ from chia.server.server import ssl_context_for_root
from chia.ssl.create_ssl import get_mozilla_ca_crt
from chia.types.blockchain_format.pool_target import PoolTarget
from chia.types.blockchain_format.proof_of_space import ProofOfSpace
from chia.util.api_decorators import api_request, peer_required
from chia.util.api_decorators import api_request
from chia.util.ints import uint32, uint64
@ -51,8 +51,7 @@ class FarmerAPI:
def __init__(self, farmer) -> None:
self.farmer = farmer
@api_request
@peer_required
@api_request(peer_required=True)
async def new_proof_of_space(
self, new_proof_of_space: harvester_protocol.NewProofOfSpace, peer: ws.WSChiaConnection
):
@ -283,7 +282,7 @@ class FarmerAPI:
return
@api_request
@api_request()
async def respond_signatures(self, response: harvester_protocol.RespondSignatures):
"""
There are two cases: receiving signatures for sps, or receiving signatures for the block.
@ -438,7 +437,7 @@ class FarmerAPI:
FARMER PROTOCOL (FARMER <-> FULL NODE)
"""
@api_request
@api_request()
async def new_signage_point(self, new_signage_point: farmer_protocol.NewSignagePoint):
try:
pool_difficulties: List[PoolDifficulty] = []
@ -494,7 +493,7 @@ class FarmerAPI:
self.farmer.cache_add_time[new_signage_point.challenge_chain_sp] = uint64(int(time.time()))
self.farmer.state_changed("new_signage_point", {"sp_hash": new_signage_point.challenge_chain_sp})
@api_request
@api_request()
async def request_signed_values(self, full_node_request: farmer_protocol.RequestSignedValues):
if full_node_request.quality_string not in self.farmer.quality_str_to_identifiers:
self.farmer.log.error(f"Do not have quality string {full_node_request.quality_string}")
@ -513,7 +512,7 @@ class FarmerAPI:
msg = make_msg(ProtocolMessageTypes.request_signatures, request)
await self.farmer.server.send_to_specific([msg], node_id)
@api_request
@api_request()
async def farming_info(self, request: farmer_protocol.FarmingInfo):
self.farmer.state_changed(
"new_farming_info",
@ -529,42 +528,34 @@ class FarmerAPI:
},
)
@api_request
@peer_required
@api_request(peer_required=True)
async def respond_plots(self, _: harvester_protocol.RespondPlots, peer: ws.WSChiaConnection):
self.farmer.log.warning(f"Respond plots came too late from: {peer.get_peer_logging()}")
@api_request
@peer_required
@api_request(peer_required=True)
async def plot_sync_start(self, message: PlotSyncStart, peer: ws.WSChiaConnection):
await self.farmer.plot_sync_receivers[peer.peer_node_id].sync_started(message)
@api_request
@peer_required
@api_request(peer_required=True)
async def plot_sync_loaded(self, message: PlotSyncPlotList, peer: ws.WSChiaConnection):
await self.farmer.plot_sync_receivers[peer.peer_node_id].process_loaded(message)
@api_request
@peer_required
@api_request(peer_required=True)
async def plot_sync_removed(self, message: PlotSyncPathList, peer: ws.WSChiaConnection):
await self.farmer.plot_sync_receivers[peer.peer_node_id].process_removed(message)
@api_request
@peer_required
@api_request(peer_required=True)
async def plot_sync_invalid(self, message: PlotSyncPathList, peer: ws.WSChiaConnection):
await self.farmer.plot_sync_receivers[peer.peer_node_id].process_invalid(message)
@api_request
@peer_required
@api_request(peer_required=True)
async def plot_sync_keys_missing(self, message: PlotSyncPathList, peer: ws.WSChiaConnection):
await self.farmer.plot_sync_receivers[peer.peer_node_id].process_keys_missing(message)
@api_request
@peer_required
@api_request(peer_required=True)
async def plot_sync_duplicates(self, message: PlotSyncPathList, peer: ws.WSChiaConnection):
await self.farmer.plot_sync_receivers[peer.peer_node_id].process_duplicates(message)
@api_request
@peer_required
@api_request(peer_required=True)
async def plot_sync_done(self, message: PlotSyncDone, peer: ws.WSChiaConnection):
await self.farmer.plot_sync_receivers[peer.peer_node_id].sync_done(message)

View File

@ -50,7 +50,7 @@ from chia.types.mempool_item import MempoolItem
from chia.types.peer_info import PeerInfo
from chia.types.transaction_queue_entry import TransactionQueueEntry
from chia.types.unfinished_block import UnfinishedBlock
from chia.util.api_decorators import api_request, peer_required, bytes_required, execute_task, reply_type
from chia.util.api_decorators import api_request
from chia.util.full_block_utils import header_block_from_block
from chia.util.generator_tools import get_block_header, tx_removals_and_additions
from chia.util.hash import std_hash
@ -80,9 +80,7 @@ class FullNodeAPI:
def api_ready(self) -> bool:
return self.full_node.initialized
@peer_required
@api_request
@reply_type([ProtocolMessageTypes.respond_peers])
@api_request(peer_required=True, reply_types=[ProtocolMessageTypes.respond_peers])
async def request_peers(
self, _request: full_node_protocol.RequestPeers, peer: ws.WSChiaConnection
) -> Optional[Message]:
@ -94,8 +92,7 @@ class FullNodeAPI:
return msg
return None
@peer_required
@api_request
@api_request(peer_required=True)
async def respond_peers(
self, request: full_node_protocol.RespondPeers, peer: ws.WSChiaConnection
) -> Optional[Message]:
@ -104,8 +101,7 @@ class FullNodeAPI:
await self.full_node.full_node_peers.respond_peers(request, peer.get_peer_info(), True)
return None
@peer_required
@api_request
@api_request(peer_required=True)
async def respond_peers_introducer(
self, request: introducer_protocol.RespondPeersIntroducer, peer: ws.WSChiaConnection
) -> Optional[Message]:
@ -116,9 +112,7 @@ class FullNodeAPI:
await peer.close()
return None
@execute_task
@peer_required
@api_request
@api_request(peer_required=True, execute_task=True)
async def new_peak(self, request: full_node_protocol.NewPeak, peer: ws.WSChiaConnection) -> None:
"""
A peer notifies us that they have added a new peak to their blockchain. If we don't have it,
@ -139,8 +133,7 @@ class FullNodeAPI:
await self.full_node.new_peak(request, peer)
return None
@peer_required
@api_request
@api_request(peer_required=True)
async def new_transaction(
self, transaction: full_node_protocol.NewTransaction, peer: ws.WSChiaConnection
) -> Optional[Message]:
@ -223,8 +216,7 @@ class FullNodeAPI:
return None
return None
@api_request
@reply_type([ProtocolMessageTypes.respond_transaction])
@api_request(reply_types=[ProtocolMessageTypes.respond_transaction])
async def request_transaction(self, request: full_node_protocol.RequestTransaction) -> Optional[Message]:
"""Peer has requested a full transaction from us."""
# Ignore if syncing
@ -239,9 +231,7 @@ class FullNodeAPI:
msg = make_msg(ProtocolMessageTypes.respond_transaction, transaction)
return msg
@peer_required
@api_request
@bytes_required
@api_request(peer_required=True, bytes_required=True)
async def respond_transaction(
self,
tx: full_node_protocol.RespondTransaction,
@ -271,8 +261,7 @@ class FullNodeAPI:
)
return None
@api_request
@reply_type([ProtocolMessageTypes.respond_proof_of_weight])
@api_request(reply_types=[ProtocolMessageTypes.respond_proof_of_weight])
async def request_proof_of_weight(self, request: full_node_protocol.RequestProofOfWeight) -> Optional[Message]:
if self.full_node.weight_proof_handler is None:
return None
@ -312,13 +301,12 @@ class FullNodeAPI:
self.full_node.full_node_store.serialized_wp_message = message
return message
@api_request
@api_request()
async def respond_proof_of_weight(self, request: full_node_protocol.RespondProofOfWeight) -> Optional[Message]:
self.log.warning("Received proof of weight too late.")
return None
@api_request
@reply_type([ProtocolMessageTypes.respond_block, ProtocolMessageTypes.reject_block])
@api_request(reply_types=[ProtocolMessageTypes.respond_block, ProtocolMessageTypes.reject_block])
async def request_block(self, request: full_node_protocol.RequestBlock) -> Optional[Message]:
if not self.full_node.blockchain.contains_height(request.height):
reject = RejectBlock(request.height)
@ -335,8 +323,7 @@ class FullNodeAPI:
return make_msg(ProtocolMessageTypes.respond_block, full_node_protocol.RespondBlock(block))
return make_msg(ProtocolMessageTypes.reject_block, RejectBlock(request.height))
@api_request
@reply_type([ProtocolMessageTypes.respond_blocks, ProtocolMessageTypes.reject_blocks])
@api_request(reply_types=[ProtocolMessageTypes.respond_blocks, ProtocolMessageTypes.reject_blocks])
async def request_blocks(self, request: full_node_protocol.RequestBlocks) -> Optional[Message]:
if request.end_height < request.start_height or request.end_height - request.start_height > 32:
reject = RejectBlocks(request.start_height, request.end_height)
@ -392,21 +379,20 @@ class FullNodeAPI:
return msg
@api_request
@api_request()
async def reject_block(self, request: full_node_protocol.RejectBlock) -> None:
self.log.debug(f"reject_block {request.height}")
@api_request
@api_request()
async def reject_blocks(self, request: full_node_protocol.RejectBlocks) -> None:
self.log.debug(f"reject_blocks {request.start_height} {request.end_height}")
@api_request
@api_request()
async def respond_blocks(self, request: full_node_protocol.RespondBlocks) -> None:
self.log.warning("Received unsolicited/late blocks")
return None
@api_request
@peer_required
@api_request(peer_required=True)
async def respond_block(
self,
respond_block: full_node_protocol.RespondBlock,
@ -419,7 +405,7 @@ class FullNodeAPI:
self.log.warning(f"Received unsolicited/late block from peer {peer.get_peer_logging()}")
return None
@api_request
@api_request()
async def new_unfinished_block(
self, new_unfinished_block: full_node_protocol.NewUnfinishedBlock
) -> Optional[Message]:
@ -451,8 +437,7 @@ class FullNodeAPI:
return msg
@api_request
@reply_type([ProtocolMessageTypes.respond_unfinished_block])
@api_request(reply_types=[ProtocolMessageTypes.respond_unfinished_block])
async def request_unfinished_block(
self, request_unfinished_block: full_node_protocol.RequestUnfinishedBlock
) -> Optional[Message]:
@ -467,9 +452,7 @@ class FullNodeAPI:
return msg
return None
@peer_required
@api_request
@bytes_required
@api_request(peer_required=True, bytes_required=True)
async def respond_unfinished_block(
self,
respond_unfinished_block: full_node_protocol.RespondUnfinishedBlock,
@ -483,8 +466,7 @@ class FullNodeAPI:
)
return None
@api_request
@peer_required
@api_request(peer_required=True)
async def new_signage_point_or_end_of_sub_slot(
self, new_sp: full_node_protocol.NewSignagePointOrEndOfSubSlot, peer: ws.WSChiaConnection
) -> Optional[Message]:
@ -566,8 +548,7 @@ class FullNodeAPI:
return make_msg(ProtocolMessageTypes.request_signage_point_or_end_of_sub_slot, full_node_request)
@api_request
@reply_type([ProtocolMessageTypes.respond_signage_point, ProtocolMessageTypes.respond_end_of_sub_slot])
@api_request(reply_types=[ProtocolMessageTypes.respond_signage_point, ProtocolMessageTypes.respond_end_of_sub_slot])
async def request_signage_point_or_end_of_sub_slot(
self, request: full_node_protocol.RequestSignagePointOrEndOfSubSlot
) -> Optional[Message]:
@ -610,8 +591,7 @@ class FullNodeAPI:
self.log.info(f"Don't have signage point {request}")
return None
@peer_required
@api_request
@api_request(peer_required=True)
async def respond_signage_point(
self, request: full_node_protocol.RespondSignagePoint, peer: ws.WSChiaConnection
) -> Optional[Message]:
@ -666,8 +646,7 @@ class FullNodeAPI:
return None
@peer_required
@api_request
@api_request(peer_required=True)
async def respond_end_of_sub_slot(
self, request: full_node_protocol.RespondEndOfSubSlot, peer: ws.WSChiaConnection
) -> Optional[Message]:
@ -676,8 +655,7 @@ class FullNodeAPI:
msg, _ = await self.full_node.respond_end_of_sub_slot(request, peer)
return msg
@peer_required
@api_request
@api_request(peer_required=True)
async def request_mempool_transactions(
self,
request: full_node_protocol.RequestMempoolTransactions,
@ -694,8 +672,7 @@ class FullNodeAPI:
return None
# FARMER PROTOCOL
@api_request
@peer_required
@api_request(peer_required=True)
async def declare_proof_of_space(
self, request: farmer_protocol.DeclareProofOfSpace, peer: ws.WSChiaConnection
) -> Optional[Message]:
@ -984,8 +961,7 @@ class FullNodeAPI:
)
return None
@api_request
@peer_required
@api_request(peer_required=True)
async def signed_values(
self, farmer_request: farmer_protocol.SignedValues, peer: ws.WSChiaConnection
) -> Optional[Message]:
@ -1051,8 +1027,7 @@ class FullNodeAPI:
return None
# TIMELORD PROTOCOL
@peer_required
@api_request
@api_request(peer_required=True)
async def new_infusion_point_vdf(
self, request: timelord_protocol.NewInfusionPointVDF, peer: ws.WSChiaConnection
) -> Optional[Message]:
@ -1062,8 +1037,7 @@ class FullNodeAPI:
async with self.full_node.timelord_lock:
return await self.full_node.new_infusion_point_vdf(request, peer)
@peer_required
@api_request
@api_request(peer_required=True)
async def new_signage_point_vdf(
self, request: timelord_protocol.NewSignagePointVDF, peer: ws.WSChiaConnection
) -> None:
@ -1079,8 +1053,7 @@ class FullNodeAPI:
)
await self.respond_signage_point(full_node_message, peer)
@peer_required
@api_request
@api_request(peer_required=True)
async def new_end_of_sub_slot_vdf(
self, request: timelord_protocol.NewEndOfSubSlotVDF, peer: ws.WSChiaConnection
) -> Optional[Message]:
@ -1105,7 +1078,7 @@ class FullNodeAPI:
else:
return msg
@api_request
@api_request()
async def request_block_header(self, request: wallet_protocol.RequestBlockHeader) -> Optional[Message]:
header_hash = self.full_node.blockchain.height_to_hash(request.height)
if header_hash is None:
@ -1146,7 +1119,7 @@ class FullNodeAPI:
)
return msg
@api_request
@api_request()
async def request_additions(self, request: wallet_protocol.RequestAdditions) -> Optional[Message]:
if request.header_hash is None:
header_hash: Optional[bytes32] = self.full_node.blockchain.height_to_hash(request.height)
@ -1201,7 +1174,7 @@ class FullNodeAPI:
response = wallet_protocol.RespondAdditions(request.height, header_hash, coins_map, proofs_map)
return make_msg(ProtocolMessageTypes.respond_additions, response)
@api_request
@api_request()
async def request_removals(self, request: wallet_protocol.RequestRemovals) -> Optional[Message]:
block: Optional[FullBlock] = await self.full_node.block_store.get_full_block(request.header_hash)
@ -1266,7 +1239,7 @@ class FullNodeAPI:
msg = make_msg(ProtocolMessageTypes.respond_removals, response)
return msg
@api_request
@api_request()
async def send_transaction(
self, request: wallet_protocol.SendTransaction, *, test: bool = False
) -> Optional[Message]:
@ -1307,7 +1280,7 @@ class FullNodeAPI:
response = wallet_protocol.TransactionAck(spend_name, uint8(status.value), error_name)
return make_msg(ProtocolMessageTypes.transaction_ack, response)
@api_request
@api_request()
async def request_puzzle_solution(self, request: wallet_protocol.RequestPuzzleSolution) -> Optional[Message]:
coin_name = request.coin_name
height = request.height
@ -1343,7 +1316,7 @@ class FullNodeAPI:
response_msg = make_msg(ProtocolMessageTypes.respond_puzzle_solution, response)
return response_msg
@api_request
@api_request()
async def request_block_headers(self, request: wallet_protocol.RequestBlockHeaders) -> Optional[Message]:
"""Returns header blocks by directly streaming bytes into Message
@ -1390,7 +1363,7 @@ class FullNodeAPI:
respond_header_blocks_manually_streamed += b"".join(header_blocks_bytes)
return make_msg(ProtocolMessageTypes.respond_block_headers, respond_header_blocks_manually_streamed)
@api_request
@api_request()
async def request_header_blocks(self, request: wallet_protocol.RequestHeaderBlocks) -> Optional[Message]:
"""DEPRECATED: please use RequestBlockHeaders"""
if request.end_height < request.start_height or request.end_height - request.start_height > 32:
@ -1424,17 +1397,14 @@ class FullNodeAPI:
)
return msg
@api_request
@api_request()
async def respond_compact_proof_of_time(self, request: timelord_protocol.RespondCompactProofOfTime) -> None:
if self.full_node.sync_store.get_sync_mode():
return None
await self.full_node.respond_compact_proof_of_time(request)
return None
@execute_task
@peer_required
@api_request
@bytes_required
@api_request(peer_required=True, bytes_required=True, execute_task=True)
async def new_compact_vdf(
self, request: full_node_protocol.NewCompactVDF, peer: ws.WSChiaConnection, request_bytes: bytes = b""
) -> None:
@ -1462,9 +1432,7 @@ class FullNodeAPI:
self.full_node.compact_vdf_requests.remove(name)
return None
@peer_required
@api_request
@reply_type([ProtocolMessageTypes.respond_compact_vdf])
@api_request(peer_required=True, reply_types=[ProtocolMessageTypes.respond_compact_vdf])
async def request_compact_vdf(
self, request: full_node_protocol.RequestCompactVDF, peer: ws.WSChiaConnection
) -> None:
@ -1473,8 +1441,7 @@ class FullNodeAPI:
await self.full_node.request_compact_vdf(request, peer)
return None
@peer_required
@api_request
@api_request(peer_required=True)
async def respond_compact_vdf(
self, request: full_node_protocol.RespondCompactVDF, peer: ws.WSChiaConnection
) -> None:
@ -1483,8 +1450,7 @@ class FullNodeAPI:
await self.full_node.respond_compact_vdf(request, peer)
return None
@peer_required
@api_request
@api_request(peer_required=True)
async def register_interest_in_puzzle_hash(
self, request: wallet_protocol.RegisterForPhUpdates, peer: ws.WSChiaConnection
) -> Message:
@ -1525,8 +1491,7 @@ class FullNodeAPI:
msg = make_msg(ProtocolMessageTypes.respond_to_ph_update, response)
return msg
@peer_required
@api_request
@api_request(peer_required=True)
async def register_interest_in_coin(
self, request: wallet_protocol.RegisterForCoinUpdates, peer: ws.WSChiaConnection
) -> Message:
@ -1555,7 +1520,7 @@ class FullNodeAPI:
msg = make_msg(ProtocolMessageTypes.respond_to_coin_update, response)
return msg
@api_request
@api_request()
async def request_children(self, request: wallet_protocol.RequestChildren) -> Optional[Message]:
coin_records: List[CoinRecord] = await self.full_node.coin_store.get_coin_records_by_parent_ids(
True, [request.coin_name]
@ -1565,7 +1530,7 @@ class FullNodeAPI:
msg = make_msg(ProtocolMessageTypes.respond_children, response)
return msg
@api_request
@api_request()
async def request_ses_hashes(self, request: wallet_protocol.RequestSESInfo) -> Message:
"""Returns the start and end height of a sub-epoch for the height specified in request"""

View File

@ -18,7 +18,7 @@ from chia.server.outbound_message import make_msg
from chia.server.ws_connection import WSChiaConnection
from chia.types.blockchain_format.proof_of_space import ProofOfSpace
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.api_decorators import api_request, peer_required
from chia.util.api_decorators import api_request
from chia.util.ints import uint8, uint32, uint64
from chia.wallet.derive_keys import master_sk_to_local_sk
@ -29,8 +29,7 @@ class HarvesterAPI:
def __init__(self, harvester: Harvester):
self.harvester = harvester
@peer_required
@api_request
@api_request(peer_required=True)
async def harvester_handshake(
self, harvester_handshake: harvester_protocol.HarvesterHandshake, peer: WSChiaConnection
):
@ -46,8 +45,7 @@ class HarvesterAPI:
await self.harvester.plot_sync_sender.start()
self.harvester.plot_manager.start_refreshing()
@peer_required
@api_request
@api_request(peer_required=True)
async def new_signage_point_harvester(
self, new_challenge: harvester_protocol.NewSignagePointHarvester, peer: WSChiaConnection
):
@ -242,7 +240,7 @@ class HarvesterAPI:
},
)
@api_request
@api_request()
async def request_signatures(self, request: harvester_protocol.RequestSignatures):
"""
The farmer requests a signature on the header hash, for one of the proofs that we found.
@ -291,7 +289,7 @@ class HarvesterAPI:
return make_msg(ProtocolMessageTypes.respond_signatures, response)
@api_request
@api_request()
async def request_plots(self, _: harvester_protocol.RequestPlots):
plots_response = []
plots, failed_to_open_filenames, no_key_filenames = self.harvester.get_plots()
@ -312,6 +310,6 @@ class HarvesterAPI:
response = harvester_protocol.RespondPlots(plots_response, failed_to_open_filenames, no_key_filenames)
return make_msg(ProtocolMessageTypes.respond_plots, response)
@api_request
@api_request()
async def plot_sync_response(self, response: PlotSyncResponse):
self.harvester.plot_sync_sender.set_response(response)

View File

@ -8,7 +8,7 @@ from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.server.outbound_message import Message, make_msg
from chia.server.ws_connection import WSChiaConnection
from chia.types.peer_info import TimestampedPeerInfo
from chia.util.api_decorators import api_request, peer_required
from chia.util.api_decorators import api_request
from chia.util.ints import uint64
@ -21,8 +21,7 @@ class IntroducerAPI:
def _set_state_changed_callback(self, callback: Callable):
pass
@peer_required
@api_request
@api_request(peer_required=True)
async def request_peers_introducer(
self,
request: RequestPeersIntroducer,

View File

@ -7,7 +7,7 @@ from chia.protocols import full_node_protocol, wallet_protocol
from chia.seeder.crawler import Crawler
from chia.server.outbound_message import Message
from chia.server.server import ChiaServer
from chia.util.api_decorators import api_request, peer_required
from chia.util.api_decorators import api_request
class CrawlerAPI:
@ -31,76 +31,70 @@ class CrawlerAPI:
def log(self):
return self.crawler.log
@peer_required
@api_request
@api_request(peer_required=True)
async def request_peers(self, _request: full_node_protocol.RequestPeers, peer: ws.WSChiaConnection):
pass
@peer_required
@api_request
@api_request(peer_required=True)
async def respond_peers(
self, request: full_node_protocol.RespondPeers, peer: ws.WSChiaConnection
) -> Optional[Message]:
pass
@peer_required
@api_request
@api_request(peer_required=True)
async def new_peak(self, request: full_node_protocol.NewPeak, peer: ws.WSChiaConnection) -> Optional[Message]:
await self.crawler.new_peak(request, peer)
return None
@api_request
@api_request()
async def new_transaction(self, transaction: full_node_protocol.NewTransaction) -> Optional[Message]:
pass
@api_request
@peer_required
@api_request(peer_required=True)
async def new_signage_point_or_end_of_sub_slot(
self, new_sp: full_node_protocol.NewSignagePointOrEndOfSubSlot, peer: ws.WSChiaConnection
) -> Optional[Message]:
pass
@api_request
@api_request()
async def new_unfinished_block(
self, new_unfinished_block: full_node_protocol.NewUnfinishedBlock
) -> Optional[Message]:
pass
@peer_required
@api_request
@api_request(peer_required=True)
async def new_compact_vdf(self, request: full_node_protocol.NewCompactVDF, peer: ws.WSChiaConnection):
pass
@api_request
@api_request()
async def request_transaction(self, request: full_node_protocol.RequestTransaction) -> Optional[Message]:
pass
@api_request
@api_request()
async def request_proof_of_weight(self, request: full_node_protocol.RequestProofOfWeight) -> Optional[Message]:
pass
@api_request
@api_request()
async def request_block(self, request: full_node_protocol.RequestBlock) -> Optional[Message]:
pass
@api_request
@api_request()
async def request_blocks(self, request: full_node_protocol.RequestBlocks) -> Optional[Message]:
pass
@api_request
@api_request()
async def request_unfinished_block(
self, request_unfinished_block: full_node_protocol.RequestUnfinishedBlock
) -> Optional[Message]:
pass
@api_request
@api_request()
async def request_signage_point_or_end_of_sub_slot(
self, request: full_node_protocol.RequestSignagePointOrEndOfSubSlot
) -> Optional[Message]:
pass
@peer_required
@api_request
@api_request(peer_required=True)
async def request_mempool_transactions(
self,
request: full_node_protocol.RequestMempoolTransactions,
@ -108,22 +102,22 @@ class CrawlerAPI:
) -> Optional[Message]:
pass
@api_request
@api_request()
async def request_block_header(self, request: wallet_protocol.RequestBlockHeader) -> Optional[Message]:
pass
@api_request
@api_request()
async def request_additions(self, request: wallet_protocol.RequestAdditions) -> Optional[Message]:
pass
@api_request
@api_request()
async def request_removals(self, request: wallet_protocol.RequestRemovals) -> Optional[Message]:
pass
@api_request
@api_request()
async def request_puzzle_solution(self, request: wallet_protocol.RequestPuzzleSolution) -> Optional[Message]:
pass
@api_request
@api_request()
async def request_header_blocks(self, request: wallet_protocol.RequestHeaderBlocks) -> Optional[Message]:
pass

View File

@ -21,8 +21,8 @@ class TimelordAPI:
def _set_state_changed_callback(self, callback: Callable):
self.timelord.state_changed_callback = callback
@api_request
async def new_peak_timelord(self, new_peak: timelord_protocol.NewPeakTimelord):
@api_request()
async def new_peak_timelord(self, new_peak: timelord_protocol.NewPeakTimelord) -> None:
if self.timelord.last_state is None:
return None
async with self.timelord.lock:
@ -49,7 +49,7 @@ class TimelordAPI:
self.timelord.state_changed("new_peak", {"height": new_peak.reward_chain_block.height})
self.timelord.new_subslot_end = None
@api_request
@api_request()
async def new_unfinished_block_timelord(self, new_unfinished_block: timelord_protocol.NewUnfinishedBlockTimelord):
if self.timelord.last_state is None:
return None
@ -81,7 +81,7 @@ class TimelordAPI:
self.timelord.total_unfinished += 1
log.debug(f"Non-overflow unfinished block, total {self.timelord.total_unfinished}")
@api_request
@api_request()
async def request_compact_proof_of_time(self, vdf_info: timelord_protocol.RequestCompactProofOfTime):
async with self.timelord.lock:
if not self.timelord.bluebox_mode:

View File

@ -1,29 +1,21 @@
from __future__ import annotations
import functools
import logging
from dataclasses import dataclass, field
from inspect import signature
from typing import TYPE_CHECKING, Any, Callable, Coroutine, List, Optional, Union, get_type_hints
from typing import Any, Callable, List, Optional, TypeVar, Union, get_type_hints
from typing_extensions import Concatenate, ParamSpec
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.server.outbound_message import Message
from chia.util.streamable import Streamable, _T_Streamable
from chia.util.streamable import Streamable
log = logging.getLogger(__name__)
if TYPE_CHECKING:
from chia.server.ws_connection import WSChiaConnection
P = ParamSpec("P")
R = TypeVar("R")
S = TypeVar("S", bound=Streamable)
Self = TypeVar("Self")
converted_api_f_type = Union[
Callable[[Union[bytes, _T_Streamable]], Coroutine[Any, Any, Optional[Message]]],
Callable[[Union[bytes, _T_Streamable], WSChiaConnection], Coroutine[Any, Any, Optional[Message]]],
]
initial_api_f_type = Union[
Callable[[Any, _T_Streamable], Coroutine[Any, Any, Optional[Message]]],
Callable[[Any, _T_Streamable, WSChiaConnection], Coroutine[Any, Any, Optional[Message]]],
]
metadata_attribute_name = "_chia_api_metadata"
@ -34,11 +26,11 @@ class ApiMetadata:
peer_required: bool = False
bytes_required: bool = False
execute_task: bool = False
reply_type: List[ProtocolMessageTypes] = field(default_factory=list)
reply_types: List[ProtocolMessageTypes] = field(default_factory=list)
message_class: Optional[Any] = None
def get_metadata(function: Callable[..., Any]) -> ApiMetadata:
def get_metadata(function: Callable[..., object]) -> ApiMetadata:
maybe_metadata: Optional[ApiMetadata] = getattr(function, metadata_attribute_name, None)
if maybe_metadata is None:
return ApiMetadata()
@ -46,92 +38,54 @@ def get_metadata(function: Callable[..., Any]) -> ApiMetadata:
return maybe_metadata
def set_default_and_get_metadata(function: Callable[..., Any]) -> ApiMetadata:
maybe_metadata: Optional[ApiMetadata] = getattr(function, metadata_attribute_name, None)
def _set_metadata(function: Callable[..., object], metadata: ApiMetadata) -> None:
setattr(function, metadata_attribute_name, metadata)
if maybe_metadata is None:
metadata = ApiMetadata()
setattr(function, metadata_attribute_name, metadata)
# TODO: This hinting does not express that the returned callable *_bytes parameter
# corresponding to the first parameter name will be filled in by the wrapper.
def api_request(
peer_required: bool = False,
bytes_required: bool = False,
execute_task: bool = False,
reply_types: Optional[List[ProtocolMessageTypes]] = None,
) -> Callable[[Callable[Concatenate[Self, S, P], R]], Callable[Concatenate[Self, Union[bytes, S], P], R]]:
non_optional_reply_types: List[ProtocolMessageTypes]
if reply_types is None:
non_optional_reply_types = []
else:
metadata = maybe_metadata
non_optional_reply_types = reply_types
return metadata
def api_request(f: initial_api_f_type) -> converted_api_f_type: # type: ignore
@functools.wraps(f)
def f_substitute(*args, **kwargs) -> Any: # type: ignore
binding = sig.bind(*args, **kwargs)
binding.apply_defaults()
inter = dict(binding.arguments)
# Converts each parameter from a Python dictionary, into an instance of the object
# specified by the type annotation (signature) of the function that is being called (f)
# The method can also be called with the target type instead of a dictionary.
for param_name, param_class in non_bytes_parameter_annotations.items():
original = inter[param_name]
if isinstance(original, Streamable):
def inner(f: Callable[Concatenate[Self, S, P], R]) -> Callable[Concatenate[Self, Union[bytes, S], P], R]:
def wrapper(self: Self, original: Union[bytes, S], *args: P.args, **kwargs: P.kwargs) -> R:
arg: S
if isinstance(original, bytes):
if metadata.bytes_required:
inter[f"{param_name}_bytes"] = bytes(original)
elif isinstance(original, bytes):
kwargs[message_name_bytes] = original
arg = message_class.from_bytes(original)
else:
arg = original
if metadata.bytes_required:
inter[f"{param_name}_bytes"] = original
inter[param_name] = param_class.from_bytes(original)
return f(**inter) # type: ignore
kwargs[message_name_bytes] = bytes(original)
non_bytes_parameter_annotations = {
name: hint for name, hint in get_type_hints(f).items() if name not in {"self", "return"} if hint is not bytes
}
sig = signature(f)
return f(self, arg, *args, **kwargs)
# Note that `functools.wraps()` is copying over the metadata attribute from `f()`
# onto `f_substitute()`.
metadata = set_default_and_get_metadata(function=f_substitute)
metadata.api_function = True
message_name, message_class = next(
(name, hint) for name, hint in get_type_hints(f).items() if name not in {"self", "peer", "return"}
)
message_name_bytes = f"{message_name}_bytes"
# It would be good to better identify the single parameter of interest.
metadata.message_class = [
hint for name, hint in get_type_hints(f).items() if name not in {"self", "peer", "return"}
][-1]
metadata = ApiMetadata(
api_function=True,
peer_required=peer_required,
bytes_required=bytes_required,
execute_task=execute_task,
reply_types=non_optional_reply_types,
message_class=message_class,
)
return f_substitute
_set_metadata(function=wrapper, metadata=metadata)
return wrapper
def peer_required(func: Callable[..., Any]) -> Callable[..., Any]:
def inner() -> Callable[..., Any]:
metadata = set_default_and_get_metadata(function=func)
metadata.peer_required = True
return func
return inner()
def bytes_required(func: Callable[..., Any]) -> Callable[..., Any]:
def inner() -> Callable[..., Any]:
metadata = set_default_and_get_metadata(function=func)
metadata.bytes_required = True
return func
return inner()
def execute_task(func: Callable[..., Any]) -> Callable[..., Any]:
def inner() -> Callable[..., Any]:
metadata = set_default_and_get_metadata(function=func)
metadata.execute_task = True
return func
return inner()
def reply_type(prot_type: List[ProtocolMessageTypes]) -> Callable[..., Any]:
def wrap(func: Callable[..., Any]) -> Callable[..., Any]:
def inner() -> Callable[..., Any]:
metadata = set_default_and_get_metadata(function=func)
metadata.reply_type.extend(prot_type)
return func
return inner()
return wrap
return inner

View File

@ -2,7 +2,7 @@ from chia.protocols import full_node_protocol, introducer_protocol, wallet_proto
from chia.server.outbound_message import NodeType
from chia.server.ws_connection import WSChiaConnection
from chia.types.mempool_inclusion_status import MempoolInclusionStatus
from chia.util.api_decorators import api_request, peer_required, execute_task
from chia.util.api_decorators import api_request
from chia.util.errors import Err
from chia.wallet.wallet_node import WalletNode
@ -21,8 +21,7 @@ class WalletNodeAPI:
def api_ready(self):
return self.wallet_node.logged_in
@peer_required
@api_request
@api_request(peer_required=True)
async def respond_removals(self, response: wallet_protocol.RespondRemovals, peer: WSChiaConnection):
pass
@ -32,16 +31,14 @@ class WalletNodeAPI:
"""
pass
@api_request
@api_request()
async def reject_additions_request(self, response: wallet_protocol.RejectAdditionsRequest):
"""
The full node has rejected our request for additions.
"""
pass
@execute_task
@peer_required
@api_request
@api_request(peer_required=True, execute_task=True)
async def new_peak_wallet(self, peak: wallet_protocol.NewPeakWallet, peer: WSChiaConnection):
"""
The full node sent as a new peak
@ -49,28 +46,26 @@ class WalletNodeAPI:
self.wallet_node.node_peaks[peer.peer_node_id] = (peak.height, peak.header_hash)
await self.wallet_node.new_peak_queue.new_peak_wallet(peak, peer)
@api_request
@api_request()
async def reject_header_request(self, response: wallet_protocol.RejectHeaderRequest):
"""
The full node has rejected our request for a header.
"""
pass
@api_request
@api_request()
async def respond_block_header(self, response: wallet_protocol.RespondBlockHeader):
pass
@peer_required
@api_request
@api_request(peer_required=True)
async def respond_additions(self, response: wallet_protocol.RespondAdditions, peer: WSChiaConnection):
pass
@api_request
@api_request()
async def respond_proof_of_weight(self, response: full_node_protocol.RespondProofOfWeight):
pass
@peer_required
@api_request
@api_request(peer_required=True)
async def transaction_ack(self, ack: wallet_protocol.TransactionAck, peer: WSChiaConnection):
"""
This is an ack for our previous SendTransaction call. This removes the transaction from
@ -104,8 +99,7 @@ class WalletNodeAPI:
else:
await wallet_state_manager.remove_from_queue(ack.txid, name, status, None)
@peer_required
@api_request
@api_request(peer_required=True)
async def respond_peers_introducer(
self, request: introducer_protocol.RespondPeersIntroducer, peer: WSChiaConnection
):
@ -115,8 +109,7 @@ class WalletNodeAPI:
if peer is not None and peer.connection_type is NodeType.INTRODUCER:
await peer.close()
@peer_required
@api_request
@api_request(peer_required=True)
async def respond_peers(self, request: full_node_protocol.RespondPeers, peer: WSChiaConnection):
if self.wallet_node.wallet_peers is None:
return None
@ -126,53 +119,54 @@ class WalletNodeAPI:
return None
@api_request
@api_request()
async def respond_puzzle_solution(self, request: wallet_protocol.RespondPuzzleSolution):
self.log.error("Unexpected message `respond_puzzle_solution`. Peer might be slow to respond")
return None
@api_request
@api_request()
async def reject_puzzle_solution(self, request: wallet_protocol.RejectPuzzleSolution):
self.log.warning(f"Reject puzzle solution: {request}")
@api_request
@api_request()
async def respond_header_blocks(self, request: wallet_protocol.RespondHeaderBlocks):
pass
@api_request
@api_request()
async def respond_block_headers(self, request: wallet_protocol.RespondBlockHeaders):
pass
@api_request
@api_request()
async def reject_header_blocks(self, request: wallet_protocol.RejectHeaderBlocks):
self.log.warning(f"Reject header blocks: {request}")
@api_request
@api_request()
async def reject_block_headers(self, request: wallet_protocol.RejectBlockHeaders):
pass
@execute_task
@peer_required
@api_request
@api_request(peer_required=True, execute_task=True)
async def coin_state_update(self, request: wallet_protocol.CoinStateUpdate, peer: WSChiaConnection):
await self.wallet_node.new_peak_queue.full_node_state_updated(request, peer)
@api_request
# TODO: Review this hinting issue around this rust type not being a Streamable
# subclass, as you might expect it wouldn't be. Maybe we can get the
# protocol working right back at the api_request definition.
@api_request() # type: ignore[type-var]
async def respond_to_ph_update(self, request: wallet_protocol.RespondToPhUpdates):
pass
@api_request
@api_request()
async def respond_to_coin_update(self, request: wallet_protocol.RespondToCoinUpdates):
pass
@api_request
@api_request()
async def respond_children(self, request: wallet_protocol.RespondChildren):
pass
@api_request
@api_request()
async def respond_ses_hashes(self, request: wallet_protocol.RespondSESInfo):
pass
@api_request
@api_request()
async def respond_blocks(self, request: full_node_protocol.RespondBlocks) -> None:
pass

View File

@ -27,7 +27,7 @@ from chia.util.errors import Err
from chia.util.ints import uint64, uint32
from chia.util.hash import std_hash
from chia.types.mempool_inclusion_status import MempoolInclusionStatus
from chia.util.api_decorators import api_request, peer_required, bytes_required
from chia.util.api_decorators import api_request
from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions
from chia.full_node.pending_tx_cache import PendingTxCache
from blspy import G2Element
@ -163,11 +163,9 @@ class TestMempool:
assert spend_bundle is not None
@peer_required
@api_request
@bytes_required
@api_request(peer_required=True, bytes_required=True)
async def respond_transaction(
node: FullNodeAPI,
self: FullNodeAPI,
tx: full_node_protocol.RespondTransaction,
peer: ws.WSChiaConnection,
tx_bytes: bytes = b"",
@ -179,11 +177,11 @@ async def respond_transaction(
"""
assert tx_bytes != b""
spend_name = std_hash(tx_bytes)
if spend_name in node.full_node.full_node_store.pending_tx_request:
node.full_node.full_node_store.pending_tx_request.pop(spend_name)
if spend_name in node.full_node.full_node_store.peers_with_tx:
node.full_node.full_node_store.peers_with_tx.pop(spend_name)
return await node.full_node.respond_transaction(tx.transaction, spend_name, peer, test)
if spend_name in self.full_node.full_node_store.pending_tx_request:
self.full_node.full_node_store.pending_tx_request.pop(spend_name)
if spend_name in self.full_node.full_node_store.peers_with_tx:
self.full_node.full_node_store.peers_with_tx.pop(spend_name)
return await self.full_node.respond_transaction(tx.transaction, spend_name, peer, test)
async def next_block(full_node_1, wallet_a, bt) -> Coin:
@ -383,7 +381,7 @@ class TestMempoolManager:
async def send_sb(self, node: FullNodeAPI, sb: SpendBundle) -> Optional[Message]:
tx = wallet_protocol.SendTransaction(sb)
return await node.send_transaction(tx, test=True) # type: ignore
return await node.send_transaction(tx, test=True)
async def gen_and_send_sb(self, node, peer, *args, **kwargs):
sb = generate_test_spend_bundle(*args, **kwargs)
@ -527,11 +525,12 @@ class TestMempoolManager:
pool_reward_puzzle_hash=reward_ph,
)
_, dummy_node_id = await add_dummy_connection(server_1, bt.config["self_hostname"], 100)
dummy_peer = None
for node_id, wsc in server_1.all_connections.items():
if node_id == dummy_node_id:
dummy_peer = wsc
break
else:
raise Exception("dummy peer not found")
for block in blocks:
await full_node_1.full_node.respond_block(full_node_protocol.RespondBlock(block))
@ -564,11 +563,12 @@ class TestMempoolManager:
time_per_block=10,
)
_, dummy_node_id = await add_dummy_connection(server_1, bt.config["self_hostname"], 100)
dummy_peer = None
for node_id, wsc in server_1.all_connections.items():
if node_id == dummy_node_id:
dummy_peer = wsc
break
else:
raise Exception("dummy peer not found")
for block in blocks:
await full_node_1.full_node.respond_block(full_node_protocol.RespondBlock(block))