farmer: Make Plot{Info|Path}RequestData streamable + use from_json_dict (#11762)

* Make `PlotInfoRequestData` and `PlotPathRequestData` streamable

* Use `from_json_dict` instead of `dataclass_from_dict`

* Recover `PaginatedRequestData(Protocol)`
This commit is contained in:
dustinface 2022-06-03 19:37:03 +02:00 committed by GitHub
parent 1aeeffedbe
commit a375e3cb8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 20 deletions

View File

@ -9,38 +9,50 @@ from chia.plot_sync.receiver import Receiver
from chia.protocols.harvester_protocol import Plot from chia.protocols.harvester_protocol import Plot
from chia.types.blockchain_format.sized_bytes import bytes32 from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes from chia.util.byte_types import hexstr_to_bytes
from chia.util.ints import uint32
from chia.util.paginator import Paginator from chia.util.paginator import Paginator
from chia.util.streamable import dataclass_from_dict from chia.util.streamable import Streamable, streamable
from chia.util.ws_message import WsRpcMessage, create_payload_dict from chia.util.ws_message import WsRpcMessage, create_payload_dict
class PaginatedRequestData(Protocol): class PaginatedRequestData(Protocol):
node_id: bytes32 @property
page: int def node_id(self) -> bytes32:
page_size: int pass
@property
def page(self) -> uint32:
pass
@property
def page_size(self) -> uint32:
pass
@dataclasses.dataclass @streamable
class FilterItem: @dataclasses.dataclass(frozen=True)
class FilterItem(Streamable):
key: str key: str
value: Optional[str] value: Optional[str]
@dataclasses.dataclass @streamable
class PlotInfoRequestData: @dataclasses.dataclass(frozen=True)
class PlotInfoRequestData(Streamable):
node_id: bytes32 node_id: bytes32
page: int page: uint32
page_size: int page_size: uint32
filter: List[FilterItem] = dataclasses.field(default_factory=list) filter: List[FilterItem] = dataclasses.field(default_factory=list)
sort_key: str = "filename" sort_key: str = "filename"
reverse: bool = False reverse: bool = False
@dataclasses.dataclass @streamable
class PlotPathRequestData: @dataclasses.dataclass(frozen=True)
class PlotPathRequestData(Streamable):
node_id: bytes32 node_id: bytes32
page: int page: uint32
page_size: int page_size: uint32
filter: List[str] = dataclasses.field(default_factory=list) filter: List[str] = dataclasses.field(default_factory=list)
reverse: bool = False reverse: bool = False
@ -232,7 +244,7 @@ class FarmerRpcApi:
async def get_harvester_plots_valid(self, request_dict: Dict[str, object]) -> Dict[str, object]: async def get_harvester_plots_valid(self, request_dict: Dict[str, object]) -> Dict[str, object]:
# TODO: Consider having a extra List[PlotInfo] in Receiver to avoid rebuilding the list for each call # TODO: Consider having a extra List[PlotInfo] in Receiver to avoid rebuilding the list for each call
request = dataclass_from_dict(PlotInfoRequestData, request_dict) request = PlotInfoRequestData.from_json_dict(request_dict)
plot_list = list(self.service.get_receiver(request.node_id).plots().values()) plot_list = list(self.service.get_receiver(request.node_id).plots().values())
# Apply filter # Apply filter
plot_list = [ plot_list = [
@ -249,7 +261,7 @@ class FarmerRpcApi:
def paginated_plot_path_request( def paginated_plot_path_request(
self, source_func: Callable[[Receiver], List[str]], request_dict: Dict[str, object] self, source_func: Callable[[Receiver], List[str]], request_dict: Dict[str, object]
) -> Dict[str, object]: ) -> Dict[str, object]:
request: PlotPathRequestData = dataclass_from_dict(PlotPathRequestData, request_dict) request: PlotPathRequestData = PlotPathRequestData.from_json_dict(request_dict)
receiver = self.service.get_receiver(request.node_id) receiver = self.service.get_receiver(request.node_id)
source = source_func(receiver) source = source_func(receiver)
# Apply filter # Apply filter

View File

@ -1,3 +1,4 @@
import dataclasses
import logging import logging
import operator import operator
import time import time
@ -462,9 +463,11 @@ async def test_farmer_get_harvester_plots_endpoints(
request: PaginatedRequestData request: PaginatedRequestData
if endpoint == FarmerRpcClient.get_harvester_plots_valid: if endpoint == FarmerRpcClient.get_harvester_plots_valid:
request = PlotInfoRequestData(harvester_id, 0, -1, cast(List[FilterItem], filtering), sort_key, reverse) request = PlotInfoRequestData(
harvester_id, uint32(0), uint32(0), cast(List[FilterItem], filtering), sort_key, reverse
)
else: else:
request = PlotPathRequestData(harvester_id, 0, -1, cast(List[str], filtering), reverse) request = PlotPathRequestData(harvester_id, uint32(0), uint32(0), cast(List[str], filtering), reverse)
def add_plot_directories(prefix: str, count: int) -> List[Path]: def add_plot_directories(prefix: str, count: int) -> List[Path]:
new_paths = [] new_paths = []
@ -522,10 +525,10 @@ async def test_farmer_get_harvester_plots_endpoints(
await wait_for_plot_sync(receiver, last_sync_id) await wait_for_plot_sync(receiver, last_sync_id)
for page_size in [1, int(total_count / 2), total_count - 1, total_count, total_count + 1, 100]: for page_size in [1, int(total_count / 2), total_count - 1, total_count, total_count + 1, 100]:
request.page_size = page_size request = dataclasses.replace(request, page_size=uint32(page_size))
expected_page_count = ceil(total_count / page_size) expected_page_count = ceil(total_count / page_size)
for page in range(expected_page_count): for page in range(expected_page_count):
request.page = page request = dataclasses.replace(request, page=uint32(page))
page_result = await endpoint(farmer_rpc_client, request) page_result = await endpoint(farmer_rpc_client, request)
offset = page * page_size offset = page * page_size
expected_plots = plots[offset : offset + page_size] expected_plots = plots[offset : offset + page_size]