farmer|rpc|tests: Implement paginated harvester plot endpoints (#11365)

* farmer|rpc|tests: Implement paginated harvester plot endpoints

* Simplify filtering

Co-authored-by: Kyle Altendorf <sda@fstab.net>

* Let the API handle the exceptions

* Simplify the other filtering too

Co-authored-by: Kyle Altendorf <sda@fstab.net>

* Simplify count assertions

Co-authored-by: Kyle Altendorf <sda@fstab.net>

* Refactor `is_filter_match` to `plot_matches_filter`

And just convert to `Plot` in tests.

* Move `chia.util.misc.KeyValue` to `chia.rpc.farmer_rpc_api.FilterItem`

* Rename `peer_id` to `node_id` to match `get_harvesters_{summary}`

Co-authored-by: Kyle Altendorf <sda@fstab.net>
This commit is contained in:
dustinface 2022-05-03 18:17:05 +02:00 committed by GitHub
parent 340e26e154
commit d00d045d9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 309 additions and 2 deletions

View File

@ -650,6 +650,12 @@ class Farmer:
return {"harvesters": harvesters}
def get_receiver(self, node_id: bytes32) -> Receiver:
    """Return the plot sync receiver registered for the given harvester node id.

    Raises:
        KeyError: if no receiver is registered for `node_id`.
    """
    maybe_receiver: Optional[Receiver] = self.plot_sync_receivers.get(node_id)
    if maybe_receiver is None:
        raise KeyError(f"Receiver missing for {node_id}")
    return maybe_receiver
async def _periodically_update_pool_state_task(self):
time_slept: uint64 = uint64(0)
config_path: Path = config_path_for_filename(self._root_path, "config.yaml")

View File

@ -1,11 +1,69 @@
import dataclasses
import operator
from typing import Any, Callable, Dict, List, Optional
from typing_extensions import Protocol
from chia.farmer.farmer import Farmer
from chia.plot_sync.receiver import Receiver
from chia.protocols.harvester_protocol import Plot
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
from chia.util.paginator import Paginator
from chia.util.streamable import dataclass_from_dict
from chia.util.ws_message import WsRpcMessage, create_payload_dict
class PaginatedRequestData(Protocol):
    """Structural type shared by all paginated harvester RPC request payloads."""

    # Harvester node id used to look up the plot sync receiver.
    node_id: bytes32
    # Zero-based index of the page to return.
    page: int
    # Number of entries per page.
    page_size: int
@dataclasses.dataclass
class FilterItem:
    """One plot-attribute filter: matches when `value` is a substring of the
    stringified attribute, or when both `value` and the attribute are None
    (see `plot_matches_filter`)."""

    # Name of the Plot attribute to test, e.g. "filename".
    key: str
    # Substring to look for; None matches attributes that are None.
    value: Optional[str]
@dataclasses.dataclass
class PlotInfoRequestData:
    """Request payload for the `get_harvester_plots_valid` endpoint."""

    # Harvester node id used to look up the plot sync receiver.
    node_id: bytes32
    # Zero-based page index.
    page: int
    # Number of entries per page.
    page_size: int
    # Only plots matching every FilterItem are returned.
    filter: List[FilterItem] = dataclasses.field(default_factory=list)
    # Plot attribute to sort by; optional (nullable) attributes are rejected by the endpoint.
    sort_key: str = "filename"
    # Sort in descending order when True.
    reverse: bool = False
@dataclasses.dataclass
class PlotPathRequestData:
    """Request payload for the path-based endpoints
    (`get_harvester_plots_{invalid,keys_missing,duplicates}`)."""

    # Harvester node id used to look up the plot sync receiver.
    node_id: bytes32
    # Zero-based page index.
    page: int
    # Number of entries per page.
    page_size: int
    # Only paths containing every one of these substrings are returned.
    filter: List[str] = dataclasses.field(default_factory=list)
    # Sort paths in descending order when True.
    reverse: bool = False
def paginated_plot_request(source: List[Any], request: PaginatedRequestData) -> Dict[str, object]:
    """Build the RPC response for one page of `source` plus pagination metadata."""
    pages: Paginator = Paginator(source, request.page_size)
    response: Dict[str, object] = {
        "node_id": request.node_id.hex(),
        "page": request.page,
        "page_count": pages.page_count(),
        "total_count": len(source),
        "plots": pages.get_page(request.page),
    }
    return response
def plot_matches_filter(plot: Plot, filter_item: FilterItem) -> bool:
    """Return True when `plot` satisfies `filter_item`.

    A None filter value matches only attributes that are themselves None;
    otherwise the filter value must be a substring of the stringified attribute.
    """
    attribute = getattr(plot, filter_item.key)
    if filter_item.value is None:
        return attribute is None
    return filter_item.value in str(attribute)
class FarmerRpcApi:
def __init__(self, farmer: Farmer):
    """Bind this RPC API to the farmer service instance it exposes."""
    self.service = farmer
@ -21,6 +79,10 @@ class FarmerRpcApi:
"/set_payout_instructions": self.set_payout_instructions,
"/get_harvesters": self.get_harvesters,
"/get_harvesters_summary": self.get_harvesters_summary,
"/get_harvester_plots_valid": self.get_harvester_plots_valid,
"/get_harvester_plots_invalid": self.get_harvester_plots_invalid,
"/get_harvester_plots_keys_missing": self.get_harvester_plots_keys_missing,
"/get_harvester_plots_duplicates": self.get_harvester_plots_duplicates,
"/get_pool_login_link": self.get_pool_login_link,
}
@ -159,6 +221,44 @@ class FarmerRpcApi:
async def get_harvesters_summary(self, _: Dict[str, object]) -> Dict[str, object]:
    """Return harvester information in summary form (delegates to Farmer.get_harvesters(True))."""
    summary = await self.service.get_harvesters(True)
    return summary
async def get_harvester_plots_valid(self, request_dict: Dict[str, object]) -> Dict[str, object]:
    """Return one page of the valid plots reported by the requested harvester.

    Supports filtering on plot attributes and sorting by a non-optional attribute.
    """
    # TODO: Consider having an extra List[PlotInfo] in Receiver to avoid rebuilding the list for each call
    request = dataclass_from_dict(PlotInfoRequestData, request_dict)
    receiver = self.service.get_receiver(request.node_id)
    # Keep only plots that satisfy every requested filter item.
    plots = [
        plot
        for plot in receiver.plots().values()
        if all(plot_matches_filter(plot, filter_item) for filter_item in request.filter)
    ]
    # Optional attributes may be None and are therefore not sortable — reject them as sort keys.
    restricted_sort_keys: List[str] = ["pool_contract_puzzle_hash", "pool_public_key", "plot_public_key"]
    if request.sort_key in restricted_sort_keys:
        raise KeyError(f"Can't sort by optional attributes: {restricted_sort_keys}")
    # Secondary sort on plot_id keeps ordering deterministic since it's unique.
    plots.sort(key=operator.attrgetter(request.sort_key, "plot_id"), reverse=request.reverse)
    return paginated_plot_request(plots, request)
def paginated_plot_path_request(
    self, source_func: Callable[[Receiver], List[str]], request_dict: Dict[str, object]
) -> Dict[str, object]:
    """Shared implementation for the path-based plot endpoints.

    Pulls the path list via `source_func` from the requested harvester's
    receiver, applies substring filtering and ordering, and returns one page.

    Raises:
        KeyError: from `get_receiver` when no receiver exists for the node id.
    """
    request: PlotPathRequestData = dataclass_from_dict(PlotPathRequestData, request_dict)
    receiver = self.service.get_receiver(request.node_id)
    source = source_func(receiver)
    # Note: the original parsed `request_dict` a second time here — redundant, removed.
    # Apply filter: keep only paths containing every filter substring.
    source = [plot for plot in source if all(filter_item in plot for filter_item in request.filter)]
    # Apply reverse: stable, deterministic ordering of the paths.
    source = sorted(source, reverse=request.reverse)
    return paginated_plot_request(source, request)
async def get_harvester_plots_invalid(self, request_dict: Dict[str, object]) -> Dict[str, object]:
    """Return one page of plot paths the harvester reported as invalid."""
    return self.paginated_plot_path_request(Receiver.invalid, request_dict)
async def get_harvester_plots_keys_missing(self, request_dict: Dict[str, object]) -> Dict[str, object]:
    """Return one page of plot paths whose keys are missing from the keychain."""
    return self.paginated_plot_path_request(Receiver.keys_missing, request_dict)
async def get_harvester_plots_duplicates(self, request_dict: Dict[str, object]) -> Dict[str, object]:
    """Return one page of plot paths the harvester reported as duplicates."""
    return self.paginated_plot_path_request(Receiver.duplicates, request_dict)
async def get_pool_login_link(self, request: Dict) -> Dict:
launcher_id: bytes32 = bytes32(hexstr_to_bytes(request["launcher_id"]))
login_link: Optional[str] = await self.service.generate_login_link(launcher_id)

View File

@ -1,7 +1,9 @@
from typing import Any, Dict, List, Optional
from chia.rpc.farmer_rpc_api import PlotInfoRequestData, PlotPathRequestData
from chia.rpc.rpc_client import RpcClient
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.misc import dataclass_to_json_dict
class FarmerRpcClient(RpcClient):
@ -58,6 +60,18 @@ class FarmerRpcClient(RpcClient):
async def get_harvesters_summary(self) -> Dict[str, object]:
    """Fetch summarized harvester information from the farmer RPC service."""
    return await self.fetch("get_harvesters_summary", {})
async def get_harvester_plots_valid(self, request: PlotInfoRequestData) -> Dict[str, Any]:
    """Fetch one page of valid plots from the harvester identified in `request`."""
    payload = dataclass_to_json_dict(request)
    return await self.fetch("get_harvester_plots_valid", payload)
async def get_harvester_plots_invalid(self, request: PlotPathRequestData) -> Dict[str, Any]:
    """Fetch one page of invalid plot paths from the harvester identified in `request`."""
    payload = dataclass_to_json_dict(request)
    return await self.fetch("get_harvester_plots_invalid", payload)
async def get_harvester_plots_keys_missing(self, request: PlotPathRequestData) -> Dict[str, Any]:
    """Fetch one page of plot paths with missing keys from the harvester identified in `request`."""
    payload = dataclass_to_json_dict(request)
    return await self.fetch("get_harvester_plots_keys_missing", payload)
async def get_harvester_plots_duplicates(self, request: PlotPathRequestData) -> Dict[str, Any]:
    """Fetch one page of duplicate plot paths from the harvester identified in `request`."""
    payload = dataclass_to_json_dict(request)
    return await self.fetch("get_harvester_plots_duplicates", payload)
async def get_pool_login_link(self, launcher_id: bytes32) -> Optional[str]:
try:
return (await self.fetch("get_pool_login_link", {"launcher_id": launcher_id.hex()}))["login_link"]

View File

@ -1,4 +1,7 @@
from typing import Sequence, Union
import dataclasses
from typing import Any, Dict, Sequence, Union
from chia.util.streamable import recurse_jsonify
def format_bytes(bytes: int) -> str:
@ -75,3 +78,7 @@ def prompt_yes_no(prompt: str = "(y/n) ") -> bool:
def get_list_or_len(list_in: Sequence[object], length: bool) -> Union[int, Sequence[object]]:
    """Return the length of `list_in` when `length` is truthy, otherwise the sequence itself."""
    if length:
        return len(list_in)
    return list_in
def dataclass_to_json_dict(instance: Any) -> Dict[str, Any]:
    """Serialize a dataclass instance into a JSON-compatible dict via `recurse_jsonify`."""
    plain_dict = dataclasses.asdict(instance)
    return recurse_jsonify(plain_dict)

View File

@ -1,12 +1,28 @@
import logging
import operator
import time
from math import ceil
from os import mkdir
from pathlib import Path
from shutil import copy
from typing import Any, Awaitable, Callable, Dict, List, Union, cast
import pytest
import pytest_asyncio
from chia.consensus.coinbase import create_puzzlehash_for_pk
from chia.plot_sync.receiver import Receiver
from chia.plotting.util import add_plot_directory
from chia.protocols import farmer_protocol
from chia.rpc.farmer_rpc_api import FarmerRpcApi
from chia.protocols.harvester_protocol import Plot
from chia.rpc.farmer_rpc_api import (
FarmerRpcApi,
FilterItem,
PaginatedRequestData,
PlotInfoRequestData,
PlotPathRequestData,
plot_matches_filter,
)
from chia.rpc.farmer_rpc_client import FarmerRpcClient
from chia.rpc.harvester_rpc_api import HarvesterRpcApi
from chia.rpc.harvester_rpc_client import HarvesterRpcClient
@ -18,7 +34,10 @@ from chia.util.config import load_config, lock_and_load_config, save_config
from chia.util.hash import std_hash
from chia.util.ints import uint8, uint16, uint32, uint64
from chia.util.misc import get_list_or_len
from chia.util.streamable import dataclass_from_dict
from chia.wallet.derive_keys import master_sk_to_wallet_sk, master_sk_to_wallet_sk_unhardened
from tests.block_tools import get_plot_dir
from tests.plot_sync.test_delta import dummy_plot
from tests.setup_nodes import setup_harvester_farmer, test_constants
from tests.time_out_assert import time_out_assert, time_out_assert_custom_interval
from tests.util.rpc import validate_get_routes
@ -27,6 +46,14 @@ from tests.util.socket import find_available_listen_port
log = logging.getLogger(__name__)
async def wait_for_plot_sync(receiver: Receiver, previous_last_sync_id: uint64) -> None:
    """Block (up to 30s) until the receiver reports a completed sync with a new, non-zero sync id."""

    def sync_id_advanced():
        sync_id = receiver.last_sync().sync_id
        return sync_id != 0 and sync_id != previous_last_sync_id

    await time_out_assert(30, sync_id_advanced)
@pytest_asyncio.fixture(scope="function")
async def harvester_farmer_simulation(bt, tmp_path):
async for _ in setup_harvester_farmer(bt, tmp_path, test_constants, start_services=True):
@ -367,3 +394,156 @@ async def test_farmer_get_pool_state_plot_count(harvester_farmer_environment, se
await time_out_assert(15, remove_all_and_validate, False)
assert (await farmer_rpc_client.get_pool_state())["pool_state"][0]["plot_count"] == 0
@pytest.mark.parametrize(
    "filter_item, match",
    [
        # Substring matching against a dummy plot whose filename is "123".
        (FilterItem("filename", "1"), True),
        (FilterItem("filename", "12"), True),
        (FilterItem("filename", "123"), True),
        (FilterItem("filename", "1234"), False),
        (FilterItem("filename", "23"), True),
        (FilterItem("filename", "3"), True),
        (FilterItem("filename", "0123"), False),
        # A None filter value matches only attributes that are None.
        (FilterItem("pool_contract_puzzle_hash", None), True),
        (FilterItem("pool_contract_puzzle_hash", "1"), False),
    ],
)
def test_plot_matches_filter(filter_item: FilterItem, match: bool):
    """Verify substring/None semantics of `plot_matches_filter`."""
    assert plot_matches_filter(dummy_plot("123"), filter_item) == match
@pytest.mark.parametrize(
    "endpoint, filtering, sort_key, reverse, expected_plot_count",
    [
        (FarmerRpcClient.get_harvester_plots_valid, [], "filename", False, 20),
        (FarmerRpcClient.get_harvester_plots_valid, [], "size", True, 20),
        (
            FarmerRpcClient.get_harvester_plots_valid,
            [FilterItem("pool_contract_puzzle_hash", None)],
            "file_size",
            True,
            15,
        ),
        (
            FarmerRpcClient.get_harvester_plots_valid,
            [FilterItem("size", "20"), FilterItem("filename", "81")],
            "plot_id",
            False,
            4,
        ),
        (FarmerRpcClient.get_harvester_plots_invalid, [], None, True, 13),
        (FarmerRpcClient.get_harvester_plots_invalid, ["invalid_0"], None, False, 6),
        (FarmerRpcClient.get_harvester_plots_invalid, ["inval", "lid_1/"], None, False, 2),
        (FarmerRpcClient.get_harvester_plots_keys_missing, [], None, True, 3),
        (FarmerRpcClient.get_harvester_plots_keys_missing, ["keys_missing_1"], None, False, 2),
        (FarmerRpcClient.get_harvester_plots_duplicates, [], None, True, 7),
        (FarmerRpcClient.get_harvester_plots_duplicates, ["duplicates_0"], None, False, 3),
    ],
)
@pytest.mark.asyncio
async def test_farmer_get_harvester_plots_endpoints(
    harvester_farmer_environment: Any,
    endpoint: Callable[[FarmerRpcClient, PaginatedRequestData], Awaitable[Dict[str, Any]]],
    filtering: Union[List[FilterItem], List[str]],
    sort_key: str,
    reverse: bool,
    expected_plot_count: int,
) -> None:
    """End-to-end test of the paginated plot endpoints.

    Generates the relevant plot data on the harvester, computes the expected
    (filtered, sorted) result locally, then pages through the endpoint with
    several page sizes and compares each page against the expectation.
    """
    (
        farmer_service,
        farmer_rpc_api,
        farmer_rpc_client,
        harvester_service,
        harvester_rpc_api,
        harvester_rpc_client,
    ) = harvester_farmer_environment
    harvester = harvester_service._node
    harvester_id = harvester_service._server.node_id
    receiver = farmer_service._api.farmer.plot_sync_receivers[harvester_id]

    # Make sure the initial plot sync has finished before reading plot data.
    if receiver.initial_sync():
        await wait_for_plot_sync(receiver, receiver.last_sync().sync_id)

    harvester_plots = (await harvester_rpc_client.get_plots())["plots"]
    plots = []

    # page_size of -1 is a placeholder; it is overwritten in the loop below.
    request: PaginatedRequestData
    if endpoint == FarmerRpcClient.get_harvester_plots_valid:
        request = PlotInfoRequestData(harvester_id, 0, -1, cast(List[FilterItem], filtering), sort_key, reverse)
    else:
        request = PlotPathRequestData(harvester_id, 0, -1, cast(List[str], filtering), reverse)

    def add_plot_directories(prefix: str, count: int) -> List[Path]:
        # Create `count` fresh directories under the harvester root and register them as plot dirs.
        new_paths = []
        for i in range(count):
            new_paths.append(harvester.root_path / f"{prefix}_{i}")
            mkdir(new_paths[-1])
            add_plot_directory(harvester.root_path, str(new_paths[-1]))
        return new_paths

    # Generate the plot data and
    if endpoint == FarmerRpcClient.get_harvester_plots_valid:
        plots = harvester_plots
    elif endpoint == FarmerRpcClient.get_harvester_plots_invalid:
        # Empty files with a .plot extension are reported as invalid.
        invalid_paths = add_plot_directories("invalid", 3)
        for dir_index, r in [(0, range(0, 6)), (1, range(6, 8)), (2, range(8, 13))]:
            plots += [str(invalid_paths[dir_index] / f"{i}.plot") for i in r]
        for plot in plots:
            with open(plot, "w"):
                pass
    elif endpoint == FarmerRpcClient.get_harvester_plots_keys_missing:
        # Copy pre-generated plots whose keys are not in the keychain.
        keys_missing_plots = [path for path in (Path(get_plot_dir()) / "not_in_keychain").iterdir() if path.is_file()]
        keys_missing_paths = add_plot_directories("keys_missing", 2)
        for dir_index, copy_plots in [(0, keys_missing_plots[:1]), (1, keys_missing_plots[1:3])]:
            for plot in copy_plots:
                copy(plot, keys_missing_paths[dir_index])
                plots.append(str(keys_missing_paths[dir_index] / plot.name))
    elif endpoint == FarmerRpcClient.get_harvester_plots_duplicates:
        # Copy existing valid plots into new directories to create duplicates.
        duplicate_paths = add_plot_directories("duplicates", 2)
        for dir_index, r in [(0, range(0, 3)), (1, range(3, 7))]:
            for i in r:
                plot_path = Path(harvester_plots[i]["filename"])
                plots.append(str(duplicate_paths[dir_index] / plot_path.name))
                copy(plot_path, plots[-1])

    # Sort and filter the data
    if endpoint == FarmerRpcClient.get_harvester_plots_valid:
        for filter_item in filtering:
            assert isinstance(filter_item, FilterItem)
            plots = [plot for plot in plots if plot_matches_filter(dataclass_from_dict(Plot, plot), filter_item)]
        plots.sort(key=operator.itemgetter(sort_key, "plot_id"), reverse=reverse)
    else:
        for filter_item in filtering:
            plots = [plot for plot in plots if filter_item in plot]
        plots.sort(reverse=reverse)

    total_count = len(plots)
    assert total_count == expected_plot_count

    # Trigger a plot refresh and wait for the farmer to receive the new state.
    last_sync_id = receiver.last_sync().sync_id
    harvester.plot_manager.trigger_refresh()
    harvester.plot_manager.start_refreshing()
    await wait_for_plot_sync(receiver, last_sync_id)

    # Page through every page for a range of page sizes, including edge cases
    # (page size 1, around the total count, and larger than the total count).
    for page_size in [1, int(total_count / 2), total_count - 1, total_count, total_count + 1, 100]:
        request.page_size = page_size
        expected_page_count = ceil(total_count / page_size)
        for page in range(expected_page_count):
            request.page = page
            page_result = await endpoint(farmer_rpc_client, request)
            offset = page * page_size
            expected_plots = plots[offset : offset + page_size]
            assert page_result == {
                "success": True,
                "node_id": harvester_id.hex(),
                "page": page,
                "page_count": expected_page_count,
                "total_count": total_count,
                "plots": expected_plots,
            }