more explicit and complete handling of api decorator data (#13542)

* more explicit and complete handling of api decorator data

* fix

* .message_class

* actually, those are different types...

* tweak

* simplify

* learn that functools.wraps copies random attributes
This commit is contained in:
Kyle Altendorf 2022-09-30 18:47:56 -04:00 committed by GitHub
parent 2aa117b58a
commit 2fc65e1178
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 87 additions and 50 deletions

View File

@ -30,6 +30,7 @@ from chia.server.ssl_context import private_ssl_paths, public_ssl_paths
from chia.server.ws_connection import WSChiaConnection
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.peer_info import PeerInfo
from chia.util.api_decorators import get_metadata
from chia.util.errors import Err, ProtocolError
from chia.util.ints import uint16
from chia.util.network import is_in_network, is_localhost, select_port
@ -581,7 +582,8 @@ class ChiaServer:
self.log.error(f"Non existing function: {message_type}")
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])
if not hasattr(f, "api_function"):
metadata = get_metadata(function=f)
if not metadata.api_function:
self.log.error(f"Peer trying to call non api function {message_type}")
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])
@ -591,12 +593,12 @@ class ChiaServer:
return None
timeout: Optional[int] = 600
if hasattr(f, "execute_task"):
if metadata.execute_task:
# Don't timeout on methods with execute_task decorator, these need to run fully
self.execute_tasks.add(task_id)
timeout = None
if hasattr(f, "peer_required"):
if metadata.peer_required:
coroutine = f(full_message.data, connection)
else:
coroutine = f(full_message.data)

View File

@ -5,7 +5,7 @@ import contextlib
import logging
import time
import traceback
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, get_type_hints
from typing import Any, Callable, Dict, List, Optional, Tuple
from aiohttp import WSCloseCode, WSMessage, WSMsgType
@ -17,12 +17,12 @@ from chia.protocols.shared_protocol import Capability, Handshake
from chia.server.outbound_message import Message, NodeType, make_msg
from chia.server.rate_limits import RateLimiter
from chia.types.peer_info import PeerInfo
from chia.util.api_decorators import get_metadata
from chia.util.errors import Err, ProtocolError
from chia.util.ints import uint8, uint16
# Each message is prepended with LENGTH_BYTES bytes specifying the length
from chia.util.network import class_for_type, is_localhost
from chia.util.streamable import Streamable
# Max size 2^(8*4) which is around 4GiB
LENGTH_BYTES: int = 4
@ -56,19 +56,6 @@ class WSChiaConnection:
# Local properties
self.ws: Any = ws
self.local_type = local_type
self.request_types: Dict[str, Type[Streamable]] = {}
for name, method in vars(class_for_type(self.local_type)).items():
is_api_function = getattr(method, "api_function", False)
if not is_api_function:
continue
# It would probably be good to move this into the decorator so it is only
# run a single time per decoration instead of on every new connection here.
# It would also be good to better identify the single parameter of interest.
self.request_types[name] = [
hint for name, hint in get_type_hints(method).items() if name not in {"return", "peer"}
][-1]
self.local_port = server_port
self.local_capabilities_for_handshake = local_capabilities_for_handshake
self.local_capabilities: List[Capability] = [
@ -347,7 +334,8 @@ class WSChiaConnection:
await self.ban_peer_bad_protocol(self.error_message)
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [error_message])
result = self.request_types[ProtocolMessageTypes(result.type).name].from_bytes(result.data)
recv_method = getattr(class_for_type(self.local_type), recv_message_type.name)
result = get_metadata(recv_method).message_class.from_bytes(result.data)
return result
return invoke

View File

@ -2,31 +2,63 @@ from __future__ import annotations
import functools
import logging
from dataclasses import dataclass, field
from inspect import signature
from typing import Any, Callable, Coroutine, List, Optional, Union, get_type_hints
from typing import TYPE_CHECKING, Any, Callable, Coroutine, List, Optional, Union, get_type_hints
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.server.outbound_message import Message
from chia.server.ws_connection import WSChiaConnection
from chia.util.streamable import Streamable, _T_Streamable
log = logging.getLogger(__name__)
converted_api_f_type = Union[
Callable[[Union[bytes, _T_Streamable]], Coroutine[Any, Any, Optional[Message]]],
Callable[[Union[bytes, _T_Streamable], WSChiaConnection], Coroutine[Any, Any, Optional[Message]]],
]
if TYPE_CHECKING:
from chia.server.ws_connection import WSChiaConnection
initial_api_f_type = Union[
Callable[[Any, _T_Streamable], Coroutine[Any, Any, Optional[Message]]],
Callable[[Any, _T_Streamable, WSChiaConnection], Coroutine[Any, Any, Optional[Message]]],
]
converted_api_f_type = Union[
Callable[[Union[bytes, _T_Streamable]], Coroutine[Any, Any, Optional[Message]]],
Callable[[Union[bytes, _T_Streamable], WSChiaConnection], Coroutine[Any, Any, Optional[Message]]],
]
initial_api_f_type = Union[
Callable[[Any, _T_Streamable], Coroutine[Any, Any, Optional[Message]]],
Callable[[Any, _T_Streamable, WSChiaConnection], Coroutine[Any, Any, Optional[Message]]],
]
metadata_attribute_name = "_chia_api_metadata"
@dataclass
class ApiMetadata:
api_function: bool = False
peer_required: bool = False
bytes_required: bool = False
execute_task: bool = False
reply_type: List[ProtocolMessageTypes] = field(default_factory=list)
message_class: Optional[Any] = None
def get_metadata(function: Callable[..., Any]) -> ApiMetadata:
maybe_metadata: Optional[ApiMetadata] = getattr(function, metadata_attribute_name, None)
if maybe_metadata is None:
return ApiMetadata()
return maybe_metadata
def set_default_and_get_metadata(function: Callable[..., Any]) -> ApiMetadata:
maybe_metadata: Optional[ApiMetadata] = getattr(function, metadata_attribute_name, None)
if maybe_metadata is None:
metadata = ApiMetadata()
setattr(function, metadata_attribute_name, metadata)
else:
metadata = maybe_metadata
return metadata
def api_request(f: initial_api_f_type) -> converted_api_f_type: # type: ignore
annotations = get_type_hints(f)
sig = signature(f)
@functools.wraps(f)
def f_substitute(*args, **kwargs) -> Any: # type: ignore
binding = sig.bind(*args, **kwargs)
@ -36,28 +68,40 @@ def api_request(f: initial_api_f_type) -> converted_api_f_type: # type: ignore
# Converts each parameter from a Python dictionary, into an instance of the object
# specified by the type annotation (signature) of the function that is being called (f)
# The method can also be called with the target type instead of a dictionary.
for param_name, param_class in annotations.items():
if param_name != "return" and isinstance(inter[param_name], Streamable):
if param_class.__name__ == "bytes":
continue
if hasattr(f, "bytes_required"):
inter[f"{param_name}_bytes"] = bytes(inter[param_name])
continue
if param_name != "return" and isinstance(inter[param_name], bytes):
if param_class.__name__ == "bytes":
continue
if hasattr(f, "bytes_required"):
inter[f"{param_name}_bytes"] = inter[param_name]
inter[param_name] = param_class.from_bytes(inter[param_name])
for param_name, param_class in non_bytes_parameter_annotations.items():
original = inter[param_name]
if isinstance(original, Streamable):
if metadata.bytes_required:
inter[f"{param_name}_bytes"] = bytes(original)
elif isinstance(original, bytes):
if metadata.bytes_required:
inter[f"{param_name}_bytes"] = original
inter[param_name] = param_class.from_bytes(original)
return f(**inter) # type: ignore
setattr(f_substitute, "api_function", True)
non_bytes_parameter_annotations = {
name: hint for name, hint in get_type_hints(f).items() if name not in {"self", "return"} if hint is not bytes
}
sig = signature(f)
# Note that `functools.wraps()` is copying over the metadata attribute from `f()`
# onto `f_substitute()`.
metadata = set_default_and_get_metadata(function=f_substitute)
metadata.api_function = True
# It would be good to better identify the single parameter of interest.
metadata.message_class = [
hint for name, hint in get_type_hints(f).items() if name not in {"self", "peer", "return"}
][-1]
return f_substitute
def peer_required(func: Callable[..., Any]) -> Callable[..., Any]:
def inner() -> Callable[..., Any]:
setattr(func, "peer_required", True)
metadata = set_default_and_get_metadata(function=func)
metadata.peer_required = True
return func
return inner()
@ -65,7 +109,8 @@ def peer_required(func: Callable[..., Any]) -> Callable[..., Any]:
def bytes_required(func: Callable[..., Any]) -> Callable[..., Any]:
def inner() -> Callable[..., Any]:
setattr(func, "bytes_required", True)
metadata = set_default_and_get_metadata(function=func)
metadata.bytes_required = True
return func
return inner()
@ -73,7 +118,8 @@ def bytes_required(func: Callable[..., Any]) -> Callable[..., Any]:
def execute_task(func: Callable[..., Any]) -> Callable[..., Any]:
def inner() -> Callable[..., Any]:
setattr(func, "execute_task", True)
metadata = set_default_and_get_metadata(function=func)
metadata.execute_task = True
return func
return inner()
@ -82,7 +128,8 @@ def execute_task(func: Callable[..., Any]) -> Callable[..., Any]:
def reply_type(prot_type: List[ProtocolMessageTypes]) -> Callable[..., Any]:
def wrap(func: Callable[..., Any]) -> Callable[..., Any]:
def inner() -> Callable[..., Any]:
setattr(func, "reply_type", prot_type)
metadata = set_default_and_get_metadata(function=func)
metadata.reply_type.extend(prot_type)
return func
return inner()