mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-11-10 02:26:47 +03:00
more explicit and complete handling of api decorator data (#13542)
* more explicit and complete handling of api decorator data * fix * .message_class * actually, those are different types... * tweak * simplify * learn that functools.wraps copies random attributes
This commit is contained in:
parent
2aa117b58a
commit
2fc65e1178
@ -30,6 +30,7 @@ from chia.server.ssl_context import private_ssl_paths, public_ssl_paths
|
||||
from chia.server.ws_connection import WSChiaConnection
|
||||
from chia.types.blockchain_format.sized_bytes import bytes32
|
||||
from chia.types.peer_info import PeerInfo
|
||||
from chia.util.api_decorators import get_metadata
|
||||
from chia.util.errors import Err, ProtocolError
|
||||
from chia.util.ints import uint16
|
||||
from chia.util.network import is_in_network, is_localhost, select_port
|
||||
@ -581,7 +582,8 @@ class ChiaServer:
|
||||
self.log.error(f"Non existing function: {message_type}")
|
||||
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])
|
||||
|
||||
if not hasattr(f, "api_function"):
|
||||
metadata = get_metadata(function=f)
|
||||
if not metadata.api_function:
|
||||
self.log.error(f"Peer trying to call non api function {message_type}")
|
||||
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])
|
||||
|
||||
@ -591,12 +593,12 @@ class ChiaServer:
|
||||
return None
|
||||
|
||||
timeout: Optional[int] = 600
|
||||
if hasattr(f, "execute_task"):
|
||||
if metadata.execute_task:
|
||||
# Don't timeout on methods with execute_task decorator, these need to run fully
|
||||
self.execute_tasks.add(task_id)
|
||||
timeout = None
|
||||
|
||||
if hasattr(f, "peer_required"):
|
||||
if metadata.peer_required:
|
||||
coroutine = f(full_message.data, connection)
|
||||
else:
|
||||
coroutine = f(full_message.data)
|
||||
|
@ -5,7 +5,7 @@ import contextlib
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, get_type_hints
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from aiohttp import WSCloseCode, WSMessage, WSMsgType
|
||||
|
||||
@ -17,12 +17,12 @@ from chia.protocols.shared_protocol import Capability, Handshake
|
||||
from chia.server.outbound_message import Message, NodeType, make_msg
|
||||
from chia.server.rate_limits import RateLimiter
|
||||
from chia.types.peer_info import PeerInfo
|
||||
from chia.util.api_decorators import get_metadata
|
||||
from chia.util.errors import Err, ProtocolError
|
||||
from chia.util.ints import uint8, uint16
|
||||
|
||||
# Each message is prepended with LENGTH_BYTES bytes specifying the length
|
||||
from chia.util.network import class_for_type, is_localhost
|
||||
from chia.util.streamable import Streamable
|
||||
|
||||
# Max size 2^(8*4) which is around 4GiB
|
||||
LENGTH_BYTES: int = 4
|
||||
@ -56,19 +56,6 @@ class WSChiaConnection:
|
||||
# Local properties
|
||||
self.ws: Any = ws
|
||||
self.local_type = local_type
|
||||
self.request_types: Dict[str, Type[Streamable]] = {}
|
||||
for name, method in vars(class_for_type(self.local_type)).items():
|
||||
is_api_function = getattr(method, "api_function", False)
|
||||
if not is_api_function:
|
||||
continue
|
||||
|
||||
# It would probably be good to move this into the decorator so it is only
|
||||
# run a single time per decoration instead of on every new connection here.
|
||||
# It would also be good to better identify the single parameter of interest.
|
||||
self.request_types[name] = [
|
||||
hint for name, hint in get_type_hints(method).items() if name not in {"return", "peer"}
|
||||
][-1]
|
||||
|
||||
self.local_port = server_port
|
||||
self.local_capabilities_for_handshake = local_capabilities_for_handshake
|
||||
self.local_capabilities: List[Capability] = [
|
||||
@ -347,7 +334,8 @@ class WSChiaConnection:
|
||||
await self.ban_peer_bad_protocol(self.error_message)
|
||||
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [error_message])
|
||||
|
||||
result = self.request_types[ProtocolMessageTypes(result.type).name].from_bytes(result.data)
|
||||
recv_method = getattr(class_for_type(self.local_type), recv_message_type.name)
|
||||
result = get_metadata(recv_method).message_class.from_bytes(result.data)
|
||||
return result
|
||||
|
||||
return invoke
|
||||
|
@ -2,31 +2,63 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, Coroutine, List, Optional, Union, get_type_hints
|
||||
from typing import TYPE_CHECKING, Any, Callable, Coroutine, List, Optional, Union, get_type_hints
|
||||
|
||||
from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
||||
from chia.server.outbound_message import Message
|
||||
from chia.server.ws_connection import WSChiaConnection
|
||||
from chia.util.streamable import Streamable, _T_Streamable
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
converted_api_f_type = Union[
|
||||
Callable[[Union[bytes, _T_Streamable]], Coroutine[Any, Any, Optional[Message]]],
|
||||
Callable[[Union[bytes, _T_Streamable], WSChiaConnection], Coroutine[Any, Any, Optional[Message]]],
|
||||
]
|
||||
if TYPE_CHECKING:
|
||||
from chia.server.ws_connection import WSChiaConnection
|
||||
|
||||
initial_api_f_type = Union[
|
||||
Callable[[Any, _T_Streamable], Coroutine[Any, Any, Optional[Message]]],
|
||||
Callable[[Any, _T_Streamable, WSChiaConnection], Coroutine[Any, Any, Optional[Message]]],
|
||||
]
|
||||
converted_api_f_type = Union[
|
||||
Callable[[Union[bytes, _T_Streamable]], Coroutine[Any, Any, Optional[Message]]],
|
||||
Callable[[Union[bytes, _T_Streamable], WSChiaConnection], Coroutine[Any, Any, Optional[Message]]],
|
||||
]
|
||||
|
||||
initial_api_f_type = Union[
|
||||
Callable[[Any, _T_Streamable], Coroutine[Any, Any, Optional[Message]]],
|
||||
Callable[[Any, _T_Streamable, WSChiaConnection], Coroutine[Any, Any, Optional[Message]]],
|
||||
]
|
||||
|
||||
metadata_attribute_name = "_chia_api_metadata"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiMetadata:
|
||||
api_function: bool = False
|
||||
peer_required: bool = False
|
||||
bytes_required: bool = False
|
||||
execute_task: bool = False
|
||||
reply_type: List[ProtocolMessageTypes] = field(default_factory=list)
|
||||
message_class: Optional[Any] = None
|
||||
|
||||
|
||||
def get_metadata(function: Callable[..., Any]) -> ApiMetadata:
|
||||
maybe_metadata: Optional[ApiMetadata] = getattr(function, metadata_attribute_name, None)
|
||||
if maybe_metadata is None:
|
||||
return ApiMetadata()
|
||||
|
||||
return maybe_metadata
|
||||
|
||||
|
||||
def set_default_and_get_metadata(function: Callable[..., Any]) -> ApiMetadata:
|
||||
maybe_metadata: Optional[ApiMetadata] = getattr(function, metadata_attribute_name, None)
|
||||
|
||||
if maybe_metadata is None:
|
||||
metadata = ApiMetadata()
|
||||
setattr(function, metadata_attribute_name, metadata)
|
||||
else:
|
||||
metadata = maybe_metadata
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def api_request(f: initial_api_f_type) -> converted_api_f_type: # type: ignore
|
||||
annotations = get_type_hints(f)
|
||||
sig = signature(f)
|
||||
|
||||
@functools.wraps(f)
|
||||
def f_substitute(*args, **kwargs) -> Any: # type: ignore
|
||||
binding = sig.bind(*args, **kwargs)
|
||||
@ -36,28 +68,40 @@ def api_request(f: initial_api_f_type) -> converted_api_f_type: # type: ignore
|
||||
# Converts each parameter from a Python dictionary, into an instance of the object
|
||||
# specified by the type annotation (signature) of the function that is being called (f)
|
||||
# The method can also be called with the target type instead of a dictionary.
|
||||
for param_name, param_class in annotations.items():
|
||||
if param_name != "return" and isinstance(inter[param_name], Streamable):
|
||||
if param_class.__name__ == "bytes":
|
||||
continue
|
||||
if hasattr(f, "bytes_required"):
|
||||
inter[f"{param_name}_bytes"] = bytes(inter[param_name])
|
||||
continue
|
||||
if param_name != "return" and isinstance(inter[param_name], bytes):
|
||||
if param_class.__name__ == "bytes":
|
||||
continue
|
||||
if hasattr(f, "bytes_required"):
|
||||
inter[f"{param_name}_bytes"] = inter[param_name]
|
||||
inter[param_name] = param_class.from_bytes(inter[param_name])
|
||||
for param_name, param_class in non_bytes_parameter_annotations.items():
|
||||
original = inter[param_name]
|
||||
|
||||
if isinstance(original, Streamable):
|
||||
if metadata.bytes_required:
|
||||
inter[f"{param_name}_bytes"] = bytes(original)
|
||||
elif isinstance(original, bytes):
|
||||
if metadata.bytes_required:
|
||||
inter[f"{param_name}_bytes"] = original
|
||||
inter[param_name] = param_class.from_bytes(original)
|
||||
return f(**inter) # type: ignore
|
||||
|
||||
setattr(f_substitute, "api_function", True)
|
||||
non_bytes_parameter_annotations = {
|
||||
name: hint for name, hint in get_type_hints(f).items() if name not in {"self", "return"} if hint is not bytes
|
||||
}
|
||||
sig = signature(f)
|
||||
|
||||
# Note that `functools.wraps()` is copying over the metadata attribute from `f()`
|
||||
# onto `f_substitute()`.
|
||||
metadata = set_default_and_get_metadata(function=f_substitute)
|
||||
metadata.api_function = True
|
||||
|
||||
# It would be good to better identify the single parameter of interest.
|
||||
metadata.message_class = [
|
||||
hint for name, hint in get_type_hints(f).items() if name not in {"self", "peer", "return"}
|
||||
][-1]
|
||||
|
||||
return f_substitute
|
||||
|
||||
|
||||
def peer_required(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def inner() -> Callable[..., Any]:
|
||||
setattr(func, "peer_required", True)
|
||||
metadata = set_default_and_get_metadata(function=func)
|
||||
metadata.peer_required = True
|
||||
return func
|
||||
|
||||
return inner()
|
||||
@ -65,7 +109,8 @@ def peer_required(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
|
||||
def bytes_required(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def inner() -> Callable[..., Any]:
|
||||
setattr(func, "bytes_required", True)
|
||||
metadata = set_default_and_get_metadata(function=func)
|
||||
metadata.bytes_required = True
|
||||
return func
|
||||
|
||||
return inner()
|
||||
@ -73,7 +118,8 @@ def bytes_required(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
|
||||
def execute_task(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def inner() -> Callable[..., Any]:
|
||||
setattr(func, "execute_task", True)
|
||||
metadata = set_default_and_get_metadata(function=func)
|
||||
metadata.execute_task = True
|
||||
return func
|
||||
|
||||
return inner()
|
||||
@ -82,7 +128,8 @@ def execute_task(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def reply_type(prot_type: List[ProtocolMessageTypes]) -> Callable[..., Any]:
|
||||
def wrap(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def inner() -> Callable[..., Any]:
|
||||
setattr(func, "reply_type", prot_type)
|
||||
metadata = set_default_and_get_metadata(function=func)
|
||||
metadata.reply_type.extend(prot_type)
|
||||
return func
|
||||
|
||||
return inner()
|
||||
|
Loading…
Reference in New Issue
Block a user