cherry-pick: into main RPC marshaller (#17201)

Recreates #16865 targeting main
This commit is contained in:
StartToaster 2024-01-04 10:15:20 -08:00 committed by GitHub
commit a9e0d11438
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 79 additions and 1 deletions

View File

@ -2,17 +2,42 @@ from __future__ import annotations
import logging
import traceback
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple
from typing import Any, Awaitable, Callable, Coroutine, Dict, List, Optional, Tuple, get_type_hints
import aiohttp
from chia.types.blockchain_format.coin import Coin
from chia.util.json_util import obj_to_response
from chia.util.streamable import Streamable
from chia.wallet.conditions import Condition, ConditionValidTimes, conditions_from_json_dicts, parse_timelock_info
from chia.wallet.util.tx_config import TXConfig, TXConfigLoader
log = logging.getLogger(__name__)
# TODO: consolidate this with chia.rpc.rpc_server.Endpoint
# Not all endpoints only take a dictionary so that definition is imperfect
# This definition is weaker than that one however because the arguments can be anything
RpcEndpoint = Callable[..., Awaitable[Dict[str, Any]]]
MarshallableRpcEndpoint = Callable[..., Awaitable[Streamable]]
def marshal(func: MarshallableRpcEndpoint) -> RpcEndpoint:
hints = get_type_hints(func)
request_hint = hints["request"]
assert issubclass(request_hint, Streamable)
request_class = request_hint
async def rpc_endpoint(self, request: Dict[str, Any], *args: object, **kwargs: object) -> Dict[str, Any]:
response_obj: Streamable = await func(
self,
request_class.from_json_dict(request),
*args,
**kwargs,
)
return response_obj.to_json_dict()
return rpc_endpoint
def wrap_http_handler(f) -> Callable:
async def inner(request) -> aiohttp.web.Response:

View File

@ -0,0 +1,53 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List
import pytest
from chia.rpc.util import marshal
from chia.util.ints import uint32
from chia.util.streamable import Streamable, streamable
@streamable
@dataclass(frozen=True)
class SubObject(Streamable):
qux: str
@streamable
@dataclass(frozen=True)
class TestRequestType(Streamable):
foofoo: str
barbar: uint32
bat: bytes
bam: SubObject
@streamable
@dataclass(frozen=True)
class TestResponseObject(Streamable):
qat: List[str]
sub: SubObject
@pytest.mark.anyio
async def test_rpc_marshalling() -> None:
@marshal
async def test_rpc_endpoint(self: None, request: TestRequestType) -> TestResponseObject:
return TestResponseObject(
[request.foofoo, str(request.barbar), request.bat.hex(), request.bam.qux], request.bam
)
assert await test_rpc_endpoint(
None,
{
"foofoo": "foofoo",
"barbar": 1,
"bat": b"\xff",
"bam": {
"qux": "qux",
},
},
) == {"qat": ["foofoo", "1", "ff", "qux"], "sub": {"qux": "qux"}}