From 85d14f561a97c499ea4bc20a1fc048b90c90cf42 Mon Sep 17 00:00:00 2001 From: Izumi Hoshino Date: Wed, 19 Jul 2023 06:24:56 +0900 Subject: [PATCH] =?UTF-8?q?Added=20compression=20level=20and=20harvesting?= =?UTF-8?q?=20mode=20to=20harvester=20protocol/mes=E2=80=A6=20(#15776)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added compression level and harvesting mode to harvester protocol/messages * Added test * Fixed lint error --- chia/harvester/harvester.py | 48 ++++++++++++- chia/harvester/harvester_api.py | 1 + chia/plot_sync/receiver.py | 9 +++ chia/plot_sync/sender.py | 14 +++- chia/plotting/cache.py | 2 +- chia/plotting/manager.py | 18 ++++- chia/plotting/util.py | 63 +++++++++++++++++- chia/protocols/harvester_protocol.py | 5 +- chia/rpc/harvester_rpc_api.py | 56 ++++++++++++++++ chia/rpc/harvester_rpc_client.py | 9 +++ .../farmer_harvester/test_farmer_harvester.py | 61 ++++++++++++++++- tests/plot_sync/test_delta.py | 12 +++- tests/plot_sync/test_plot_sync.py | 3 + tests/plot_sync/test_receiver.py | 20 +++++- tests/plot_sync/test_sender.py | 10 +-- tests/plot_sync/test_sync_simulated.py | 5 +- tests/plotting/test_plot_manager.py | 1 + tests/util/network_protocol_data.py | 1 + tests/util/protocol_messages_bytes-v1.0 | Bin 46803 -> 46807 bytes tests/util/protocol_messages_json.py | 2 + 20 files changed, 318 insertions(+), 22 deletions(-) diff --git a/chia/harvester/harvester.py b/chia/harvester/harvester.py index ef419e72bcc5..3cdae4a1da4b 100644 --- a/chia/harvester/harvester.py +++ b/chia/harvester/harvester.py @@ -15,18 +15,22 @@ from chia.consensus.constants import ConsensusConstants from chia.plot_sync.sender import Sender from chia.plotting.manager import PlotManager from chia.plotting.util import ( + HarvestingMode, PlotRefreshEvents, PlotRefreshResult, PlotsRefreshParameter, add_plot_directory, + get_harvester_config, get_plot_directories, remove_plot, remove_plot_directory, + update_harvester_config, ) from chia.rpc.rpc_server import StateChangedProtocol, default_get_connections from chia.server.outbound_message import NodeType from chia.server.server import ChiaServer from chia.server.ws_connection import WSChiaConnection +from chia.util.ints import uint32 log = logging.getLogger(__name__) @@ -42,6 +46,7 @@ class Harvester: _refresh_lock: asyncio.Lock event_loop: asyncio.events.AbstractEventLoop _server: Optional[ChiaServer] + _mode: HarvestingMode @property def server(self) -> ChiaServer: @@ -73,7 +78,6 @@ class Harvester: self.plot_manager = PlotManager( root_path, refresh_parameter=refresh_parameter, refresh_callback=self._plot_refresh_callback ) - self.plot_sync_sender = Sender(self.plot_manager) self._shut_down = False self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=config["num_threads"]) self._server = None @@ -92,7 +96,7 @@ class Harvester: enforce_gpu_index = config.get("enforce_gpu_index", False) try: - self.plot_manager.configure_decompressor( + self._mode = self.plot_manager.configure_decompressor( context_count, thread_count, disable_cpu_affinity, @@ -105,6 +109,8 @@ class Harvester: self.log.error(f"{type(e)} {e} while configuring decompressor.") raise + self.plot_sync_sender = Sender(self.plot_manager, self._mode) + async def _start(self) -> None: self._refresh_lock = asyncio.Lock() self.event_loop = asyncio.get_running_loop() @@ -171,6 +177,7 @@ class Harvester: "plot_public_key": plot_info.plot_public_key, "file_size": plot_info.file_size, "time_modified": int(plot_info.time_modified), + "compression_level": prover.get_compression_level(), } ) self.log.debug( @@ -203,5 +210,42 @@ class Harvester: self.plot_manager.trigger_refresh() return True + async def get_harvester_config(self) -> Dict[str, Any]: + return get_harvester_config(self.root_path) + + async def update_harvester_config( + self, + *, + use_gpu_harvesting: Optional[bool] = None, + gpu_index: Optional[int] = None, + enforce_gpu_index: Optional[bool] = None, + disable_cpu_affinity: Optional[bool] = None, + parallel_decompressor_count: Optional[int] = None, + decompressor_thread_count: Optional[int] = None, + recursive_plot_scan: Optional[bool] = None, + refresh_parameter_interval_seconds: Optional[uint32] = None, + ) -> bool: + refresh_parameter: Optional[PlotsRefreshParameter] = None + if refresh_parameter_interval_seconds is not None: + refresh_parameter = PlotsRefreshParameter( + interval_seconds=refresh_parameter_interval_seconds, + retry_invalid_seconds=self.plot_manager.refresh_parameter.retry_invalid_seconds, + batch_size=self.plot_manager.refresh_parameter.batch_size, + batch_sleep_milliseconds=self.plot_manager.refresh_parameter.batch_sleep_milliseconds, + ) + + update_harvester_config( + self.root_path, + use_gpu_harvesting=use_gpu_harvesting, + gpu_index=gpu_index, + enforce_gpu_index=enforce_gpu_index, + disable_cpu_affinity=disable_cpu_affinity, + parallel_decompressor_count=parallel_decompressor_count, + decompressor_thread_count=decompressor_thread_count, + recursive_plot_scan=recursive_plot_scan, + refresh_parameter=refresh_parameter, + ) + return True + def set_server(self, server: ChiaServer) -> None: self._server = server diff --git a/chia/harvester/harvester_api.py b/chia/harvester/harvester_api.py index 3dac81e08a58..49b5240ad1aa 100644 --- a/chia/harvester/harvester_api.py +++ b/chia/harvester/harvester_api.py @@ -333,6 +333,7 @@ class HarvesterAPI: plot["plot_public_key"], plot["file_size"], plot["time_modified"], + plot["compression_level"], ) ) diff --git a/chia/plot_sync/receiver.py b/chia/plot_sync/receiver.py index 1de18c6f31e9..a11c4fc34b54 100644 --- a/chia/plot_sync/receiver.py +++ b/chia/plot_sync/receiver.py @@ -18,6 +18,7 @@ from chia.plot_sync.exceptions import ( SyncIdsMatchError, ) from chia.plot_sync.util import ErrorCodes, State, T_PlotSyncMessage +from chia.plotting.util import HarvestingMode from chia.protocols.harvester_protocol import ( Plot, PlotSyncDone, @@ -85,6 +86,7 @@ class Receiver: _total_plot_size: int _total_effective_plot_size: int _update_callback: ReceiverUpdateCallback + _harvesting_mode: Optional[HarvestingMode] def __init__( self, @@ -101,6 +103,7 @@ class Receiver: self._total_plot_size = 0 self._total_effective_plot_size = 0 self._update_callback = update_callback + self._harvesting_mode = None async def trigger_callback(self, update: Optional[Delta] = None) -> None: try: @@ -118,6 +121,7 @@ class Receiver: self._duplicates.clear() self._total_plot_size = 0 self._total_effective_plot_size = 0 + self._harvesting_mode = None def connection(self) -> WSChiaConnection: return self._connection @@ -149,6 +153,9 @@ class Receiver: def total_effective_plot_size(self) -> int: return self._total_effective_plot_size + def harvesting_mode(self) -> Optional[HarvestingMode]: + return self._harvesting_mode + async def _process( self, method: Callable[[T_PlotSyncMessage], Any], message_type: ProtocolMessageTypes, message: T_PlotSyncMessage ) -> None: @@ -203,6 +210,7 @@ class Receiver: self._current_sync.delta.clear() self._current_sync.state = State.loaded self._current_sync.plots_total = data.plot_file_count + self._harvesting_mode = HarvestingMode(data.harvesting_mode) self._current_sync.bump_next_message_id() async def sync_started(self, data: PlotSyncStart) -> None: @@ -370,4 +378,5 @@ class Receiver: "total_effective_plot_size": self._total_effective_plot_size, "syncing": syncing, "last_sync_time": self._last_sync.time_done, + "harvesting_mode": self._harvesting_mode, } diff --git a/chia/plot_sync/sender.py b/chia/plot_sync/sender.py index 2e05feed00c3..37417162e170 100644 --- a/chia/plot_sync/sender.py +++ b/chia/plot_sync/sender.py @@ -13,7 +13,7 @@ from typing_extensions import Protocol from chia.plot_sync.exceptions import AlreadyStartedError, InvalidConnectionTypeError from chia.plot_sync.util import Constants from chia.plotting.manager import PlotManager -from chia.plotting.util import PlotInfo +from chia.plotting.util import HarvestingMode, PlotInfo from chia.protocols.harvester_protocol import ( Plot, PlotSyncDone, @@ -45,6 +45,7 @@ def _convert_plot_info_list(plot_infos: List[PlotInfo]) -> List[Plot]: plot_public_key=plot_info.plot_public_key, file_size=uint64(plot_info.file_size), time_modified=uint64(int(plot_info.time_modified)), + compression_level=plot_info.prover.get_compression_level(), ) ) return converted @@ -98,8 +99,9 @@ class Sender: _stop_requested = False _task: Optional[asyncio.Task[None]] _response: Optional[ExpectedResponse] + _harvesting_mode: HarvestingMode - def __init__(self, plot_manager: PlotManager) -> None: + def __init__(self, plot_manager: PlotManager, harvesting_mode: HarvestingMode) -> None: self._plot_manager = plot_manager self._connection = None self._sync_id = uint64(0) @@ -109,6 +111,7 @@ class Sender: self._stop_requested = False self._task = None self._response = None + self._harvesting_mode = harvesting_mode def __str__(self) -> str: return f"sync_id {self._sync_id}, next_message_id {self._next_message_id}, messages {len(self._messages)}" @@ -268,7 +271,12 @@ class Sender: log.debug(f"sync_start {sync_id}") self._sync_id = uint64(sync_id) self._add_message( - ProtocolMessageTypes.plot_sync_start, PlotSyncStart, initial, self._last_sync_id, uint32(int(count)) + ProtocolMessageTypes.plot_sync_start, + PlotSyncStart, + initial, + self._last_sync_id, + uint32(int(count)), + self._harvesting_mode, ) def process_batch(self, loaded: List[PlotInfo], remaining: int) -> None: diff --git a/chia/plotting/cache.py b/chia/plotting/cache.py index 88704a11a7a2..5a8751397fcf 100644 --- a/chia/plotting/cache.py +++ b/chia/plotting/cache.py @@ -21,7 +21,7 @@ from chia.wallet.derive_keys import master_sk_to_local_sk log = logging.getLogger(__name__) -CURRENT_VERSION: int = 1 +CURRENT_VERSION: int = 2 @streamable diff --git a/chia/plotting/manager.py b/chia/plotting/manager.py index 3cd5e3203a05..ed19daa585fa 100644 --- a/chia/plotting/manager.py +++ b/chia/plotting/manager.py @@ -13,7 +13,14 @@ from chiapos import DiskProver, decompressor_context_queue from chia.consensus.pos_quality import UI_ACTUAL_SPACE_CONSTANT_FACTOR, _expected_plot_size from chia.plotting.cache import Cache, CacheEntry -from chia.plotting.util import PlotInfo, PlotRefreshEvents, PlotRefreshResult, PlotsRefreshParameter, get_plot_filenames +from chia.plotting.util import ( + HarvestingMode, + PlotInfo, + PlotRefreshEvents, + PlotRefreshResult, + PlotsRefreshParameter, + get_plot_filenames, +) from chia.util.misc import to_batches log = logging.getLogger(__name__) @@ -56,7 +63,11 @@ class PlotManager: self.no_key_filenames = set() self.farmer_public_keys = [] self.pool_public_keys = [] - self.cache = Cache(self.root_path.resolve() / "cache" / "plot_manager.dat") + # Since `compression_level` property was added to Cache structure, + # previous cache file formats needs to be reset + # When user downgrades harvester, it looks 'plot_manager.dat` while + # latest harvester reads/writes 'plot_manager_v2.dat` + self.cache = Cache(self.root_path.resolve() / "cache" / "plot_manager_v2.dat") self.match_str = match_str self.open_no_key_filenames = open_no_key_filenames self.last_refresh_time = 0 @@ -84,7 +95,7 @@ class PlotManager: use_gpu_harvesting: bool, gpu_index: int, enforce_gpu_index: bool, - ) -> None: + ) -> HarvestingMode: if max_compression_level_allowed > 7: log.error( "Currently only compression levels up to 7 are allowed, " @@ -107,6 +118,7 @@ class PlotManager: f"Falling back to CPU harvesting: {context_count} decompressors count, {thread_count} threads." ) self.max_compression_level_allowed = max_compression_level_allowed + return HarvestingMode.GPU if is_using_gpu else HarvestingMode.CPU def reset(self) -> None: with self: diff --git a/chia/plotting/util.py b/chia/plotting/util.py index 0663ddb446ad..7d3c955913a5 100644 --- a/chia/plotting/util.py +++ b/chia/plotting/util.py @@ -2,9 +2,9 @@ from __future__ import annotations import logging from dataclasses import dataclass, field -from enum import Enum +from enum import Enum, IntEnum from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from blspy import G1Element, PrivateKey from chiapos import DiskProver @@ -84,6 +84,11 @@ class Params: stripe_size: int = 65536 +class HarvestingMode(IntEnum): + CPU = 1 + GPU = 2 + + def get_plot_directories(root_path: Path, config: Dict = None) -> List[str]: if config is None: config = load_config(root_path, "config.yaml") @@ -146,6 +151,60 @@ def remove_plot(path: Path): path.unlink() +def get_harvester_config(root_path: Path) -> Dict[str, Any]: + config = load_config(root_path, "config.yaml") + + plots_refresh_parameter = ( + config["harvester"].get("plots_refresh_parameter") + if config["harvester"].get("plots_refresh_parameter") is not None + else PlotsRefreshParameter().to_json_dict() + ) + + return { + "use_gpu_harvesting": config["harvester"].get("use_gpu_harvesting"), + "gpu_index": config["harvester"].get("gpu_index"), + "enforce_gpu_index": config["harvester"].get("enforce_gpu_index"), + "disable_cpu_affinity": config["harvester"].get("disable_cpu_affinity"), + "parallel_decompressor_count": config["harvester"].get("parallel_decompressor_count"), + "decompressor_thread_count": config["harvester"].get("decompressor_thread_count"), + "recursive_plot_scan": config["harvester"].get("recursive_plot_scan"), + "plots_refresh_parameter": plots_refresh_parameter, + } + + +def update_harvester_config( + root_path: Path, + *, + use_gpu_harvesting: Optional[bool] = None, + gpu_index: Optional[int] = None, + enforce_gpu_index: Optional[bool] = None, + disable_cpu_affinity: Optional[bool] = None, + parallel_decompressor_count: Optional[int] = None, + decompressor_thread_count: Optional[int] = None, + recursive_plot_scan: Optional[bool] = None, + refresh_parameter: Optional[PlotsRefreshParameter] = None, +): + with lock_and_load_config(root_path, "config.yaml") as config: + if use_gpu_harvesting is not None: + config["harvester"]["use_gpu_harvesting"] = use_gpu_harvesting + if gpu_index is not None: + config["harvester"]["gpu_index"] = gpu_index + if enforce_gpu_index is not None: + config["harvester"]["enforce_gpu_index"] = enforce_gpu_index + if disable_cpu_affinity is not None: + config["harvester"]["disable_cpu_affinity"] = disable_cpu_affinity + if parallel_decompressor_count is not None: + config["harvester"]["parallel_decompressor_count"] = parallel_decompressor_count + if decompressor_thread_count is not None: + config["harvester"]["decompressor_thread_count"] = decompressor_thread_count + if recursive_plot_scan is not None: + config["harvester"]["recursive_plot_scan"] = recursive_plot_scan + if refresh_parameter is not None: + config["harvester"]["plots_refresh_parameter"] = refresh_parameter.to_json_dict() + + save_config(root_path, "config.yaml", config) + + def get_filenames(directory: Path, recursive: bool) -> List[Path]: try: if not directory.exists(): diff --git a/chia/protocols/harvester_protocol.py b/chia/protocols/harvester_protocol.py index 4283b89252c1..4a9871bbbdfd 100644 --- a/chia/protocols/harvester_protocol.py +++ b/chia/protocols/harvester_protocol.py @@ -84,6 +84,7 @@ class Plot(Streamable): plot_public_key: G1Element file_size: uint64 time_modified: uint64 + compression_level: Optional[uint8] @streamable @@ -115,11 +116,13 @@ class PlotSyncStart(Streamable): initial: bool last_sync_id: uint64 plot_file_count: uint32 + harvesting_mode: uint8 def __str__(self) -> str: return ( f"PlotSyncStart: identifier {self.identifier}, initial {self.initial}, " - f"last_sync_id {self.last_sync_id}, plot_file_count {self.plot_file_count}" + f"last_sync_id {self.last_sync_id}, plot_file_count {self.plot_file_count}, " + f"harvesting_mode {self.harvesting_mode}" ) diff --git a/chia/rpc/harvester_rpc_api.py b/chia/rpc/harvester_rpc_api.py index 3844503d61f3..0cacad4653b9 100644 --- a/chia/rpc/harvester_rpc_api.py +++ b/chia/rpc/harvester_rpc_api.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional from chia.harvester.harvester import Harvester from chia.rpc.rpc_server import Endpoint, EndpointResult +from chia.util.ints import uint32 from chia.util.ws_message import WsRpcMessage, create_payload_dict @@ -20,6 +21,8 @@ class HarvesterRpcApi: "/add_plot_directory": self.add_plot_directory, "/get_plot_directories": self.get_plot_directories, "/remove_plot_directory": self.remove_plot_directory, + "/get_harvester_config": self.get_harvester_config, + "/update_harvester_config": self.update_harvester_config, } async def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None) -> List[WsRpcMessage]: @@ -78,3 +81,56 @@ class HarvesterRpcApi: if await self.service.remove_plot_directory(directory_name): return {} raise ValueError(f"Did not remove plot directory {directory_name}") + + async def get_harvester_config(self, _: Dict[str, Any]) -> EndpointResult: + harvester_config = await self.service.get_harvester_config() + return { + "use_gpu_harvesting": harvester_config["use_gpu_harvesting"], + "gpu_index": harvester_config["gpu_index"], + "enforce_gpu_index": harvester_config["enforce_gpu_index"], + "disable_cpu_affinity": harvester_config["disable_cpu_affinity"], + "parallel_decompressor_count": harvester_config["parallel_decompressor_count"], + "decompressor_thread_count": harvester_config["decompressor_thread_count"], + "recursive_plot_scan": harvester_config["recursive_plot_scan"], + "refresh_parameter_interval_seconds": harvester_config["plots_refresh_parameter"].get("interval_seconds"), + } + + async def update_harvester_config(self, request: Dict[str, Any]) -> EndpointResult: + use_gpu_harvesting: Optional[bool] = None + gpu_index: Optional[int] = None + enforce_gpu_index: Optional[bool] = None + disable_cpu_affinity: Optional[bool] = None + parallel_decompressor_count: Optional[int] = None + decompressor_thread_count: Optional[int] = None + recursive_plot_scan: Optional[bool] = None + refresh_parameter_interval_seconds: Optional[uint32] = None + if "use_gpu_harvesting" in request: + use_gpu_harvesting = bool(request["use_gpu_harvesting"]) + if "gpu_index" in request: + gpu_index = int(request["gpu_index"]) + if "enforce_gpu_index" in request: + enforce_gpu_index = bool(request["enforce_gpu_index"]) + if "disable_cpu_affinity" in request: + disable_cpu_affinity = bool(request["disable_cpu_affinity"]) + if "parallel_decompressor_count" in request: + parallel_decompressor_count = int(request["parallel_decompressor_count"]) + if "decompressor_thread_count" in request: + decompressor_thread_count = int(request["decompressor_thread_count"]) + if "recursive_plot_scan" in request: + recursive_plot_scan = bool(request["recursive_plot_scan"]) + if "refresh_parameter_interval_seconds" in request: + refresh_parameter_interval_seconds = uint32(request["refresh_parameter_interval_seconds"]) + if refresh_parameter_interval_seconds < 3: + raise ValueError(f"Plot refresh interval seconds({refresh_parameter_interval_seconds}) is too short") + + await self.service.update_harvester_config( + use_gpu_harvesting=use_gpu_harvesting, + gpu_index=gpu_index, + enforce_gpu_index=enforce_gpu_index, + disable_cpu_affinity=disable_cpu_affinity, + parallel_decompressor_count=parallel_decompressor_count, + decompressor_thread_count=decompressor_thread_count, + recursive_plot_scan=recursive_plot_scan, + refresh_parameter_interval_seconds=refresh_parameter_interval_seconds, + ) + return {} diff --git a/chia/rpc/harvester_rpc_client.py b/chia/rpc/harvester_rpc_client.py index 09893c0fe385..9741787cec18 100644 --- a/chia/rpc/harvester_rpc_client.py +++ b/chia/rpc/harvester_rpc_client.py @@ -43,3 +43,12 @@ class HarvesterRpcClient(RpcClient): # TODO: casting due to lack of type checked deserialization result = cast(bool, response["success"]) return result + + async def get_harvester_config(self) -> Dict[str, Any]: + return await self.fetch("get_harvester_config", {}) + + async def update_harvester_config(self, config: Dict[str, Any]) -> bool: + response = await self.fetch("update_harvester_config", config) + # TODO: casting due to lack of type checked deserialization + result = cast(bool, response["success"]) + return result diff --git a/tests/farmer_harvester/test_farmer_harvester.py b/tests/farmer_harvester/test_farmer_harvester.py index 372ea5dc2609..411aa7557f54 100644 --- a/tests/farmer_harvester/test_farmer_harvester.py +++ b/tests/farmer_harvester/test_farmer_harvester.py @@ -1,23 +1,28 @@ from __future__ import annotations import asyncio -from typing import List, Tuple +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple import pytest from blspy import G1Element +from chia.cmds.cmds_util import get_any_service_client from chia.farmer.farmer import Farmer from chia.farmer.farmer_api import FarmerAPI from chia.harvester.harvester import Harvester from chia.harvester.harvester_api import HarvesterAPI +from chia.plotting.util import PlotsRefreshParameter from chia.protocols import harvester_protocol from chia.protocols.protocol_message_types import ProtocolMessageTypes +from chia.rpc.harvester_rpc_client import HarvesterRpcClient from chia.server.outbound_message import NodeType, make_msg from chia.server.start_service import Service from chia.simulator.block_tools import BlockTools from chia.simulator.time_out_assert import time_out_assert from chia.types.blockchain_format.sized_bytes import bytes32 from chia.types.peer_info import UnresolvedPeerInfo +from chia.util.config import load_config from chia.util.keychain import generate_mnemonic from tests.conftest import HarvesterFarmerEnvironment @@ -26,6 +31,16 @@ def farmer_is_started(farmer: Farmer) -> bool: return farmer.started +async def get_harvester_config(harvester_rpc_port: Optional[int], root_path: Path) -> Dict[str, Any]: + async with get_any_service_client(HarvesterRpcClient, harvester_rpc_port, root_path) as (harvester_client, _): + return await harvester_client.get_harvester_config() + + +async def update_harvester_config(harvester_rpc_port: Optional[int], root_path: Path, config: Dict[str, Any]) -> bool: + async with get_any_service_client(HarvesterRpcClient, harvester_rpc_port, root_path) as (harvester_client, _): + return await harvester_client.update_harvester_config(config) + + @pytest.mark.asyncio async def test_start_with_empty_keychain( farmer_one_harvester_not_started: Tuple[ @@ -150,3 +165,47 @@ async def test_farmer_respond_signatures( # We fail the sps record check expected_error = f"Do not have challenge hash {challenge_hash}" assert expected_error in caplog.text + + +@pytest.mark.asyncio +async def test_harvester_config( + farmer_one_harvester: Tuple[List[Service[Harvester, HarvesterAPI]], Service[Farmer, FarmerAPI], BlockTools] +) -> None: + harvester_services, farmer_service, bt = farmer_one_harvester + harvester_service = harvester_services[0] + + assert harvester_service.rpc_server and harvester_service.rpc_server.webserver + + harvester_rpc_port = harvester_service.rpc_server.webserver.listen_port + harvester_config = await get_harvester_config(harvester_rpc_port, bt.root_path) + assert harvester_config["success"] is True + + def check_config_match(config1: Dict[str, Any], config2: Dict[str, Any]) -> None: + assert config1["harvester"]["use_gpu_harvesting"] == config2["use_gpu_harvesting"] + assert config1["harvester"]["gpu_index"] == config2["gpu_index"] + assert config1["harvester"]["enforce_gpu_index"] == config2["enforce_gpu_index"] + assert config1["harvester"]["disable_cpu_affinity"] == config2["disable_cpu_affinity"] + assert config1["harvester"]["parallel_decompressor_count"] == config2["parallel_decompressor_count"] + assert config1["harvester"]["decompressor_thread_count"] == config2["decompressor_thread_count"] + assert config1["harvester"]["recursive_plot_scan"] == config2["recursive_plot_scan"] + assert ( + config2["refresh_parameter_interval_seconds"] == config1["harvester"]["refresh_parameter_interval_seconds"] + if "refresh_parameter_interval_seconds" in config1["harvester"] + else PlotsRefreshParameter().interval_seconds + ) + + check_config_match(bt.config, harvester_config) + + harvester_config["use_gpu_harvesting"] = not harvester_config["use_gpu_harvesting"] + harvester_config["gpu_index"] += 1 + harvester_config["enforce_gpu_index"] = not harvester_config["enforce_gpu_index"] + harvester_config["disable_cpu_affinity"] = not harvester_config["disable_cpu_affinity"] + harvester_config["parallel_decompressor_count"] += 1 + harvester_config["decompressor_thread_count"] += 1 + harvester_config["recursive_plot_scan"] = not harvester_config["recursive_plot_scan"] + harvester_config["refresh_parameter_interval_seconds"] = harvester_config["refresh_parameter_interval_seconds"] + 1 + + res = await update_harvester_config(harvester_rpc_port, bt.root_path, harvester_config) + assert res is True + new_config = load_config(harvester_service.root_path, "config.yaml") + check_config_match(new_config, harvester_config) diff --git a/tests/plot_sync/test_delta.py b/tests/plot_sync/test_delta.py index 8f7779abb220..cc6365317760 100644 --- a/tests/plot_sync/test_delta.py +++ b/tests/plot_sync/test_delta.py @@ -15,7 +15,17 @@ log = logging.getLogger(__name__) def dummy_plot(path: str) -> Plot: - return Plot(path, uint8(32), bytes32(b"\00" * 32), G1Element(), None, G1Element(), uint64(0), uint64(0)) + return Plot( + filename=path, + size=uint8(32), + plot_id=bytes32(b"\00" * 32), + pool_public_key=G1Element(), + pool_contract_puzzle_hash=None, + plot_public_key=G1Element(), + file_size=uint64(0), + time_modified=uint64(0), + compression_level=uint8(0), + ) @pytest.mark.parametrize( diff --git a/tests/plot_sync/test_plot_sync.py b/tests/plot_sync/test_plot_sync.py index 42e0371ab57b..9bb09ac8bf9a 100644 --- a/tests/plot_sync/test_plot_sync.py +++ b/tests/plot_sync/test_plot_sync.py @@ -73,6 +73,7 @@ class ExpectedResult: G1Element(), uint64(0), uint64(0), + uint8(0), ) self.valid_count += len(list_plots) @@ -193,6 +194,7 @@ class Environment: assert plot.prover.get_filename() == delta.valid.additions[path].filename assert plot.prover.get_size() == delta.valid.additions[path].size assert plot.prover.get_id() == delta.valid.additions[path].plot_id + assert plot.prover.get_compression_level() == delta.valid.additions[path].compression_level assert plot.pool_public_key == delta.valid.additions[path].pool_public_key assert plot.pool_contract_puzzle_hash == delta.valid.additions[path].pool_contract_puzzle_hash assert plot.plot_public_key == delta.valid.additions[path].plot_public_key @@ -253,6 +255,7 @@ class Environment: assert plot_info.prover.get_filename() == receiver.plots()[str(path)].filename assert plot_info.prover.get_size() == receiver.plots()[str(path)].size assert plot_info.prover.get_id() == receiver.plots()[str(path)].plot_id + assert plot_info.prover.get_compression_level() == receiver.plots()[str(path)].compression_level assert plot_info.pool_public_key == receiver.plots()[str(path)].pool_public_key assert plot_info.pool_contract_puzzle_hash == receiver.plots()[str(path)].pool_contract_puzzle_hash assert plot_info.plot_public_key == receiver.plots()[str(path)].plot_public_key diff --git a/tests/plot_sync/test_receiver.py b/tests/plot_sync/test_receiver.py index f80293256e6f..140c2586211a 100644 --- a/tests/plot_sync/test_receiver.py +++ b/tests/plot_sync/test_receiver.py @@ -14,6 +14,7 @@ from chia.consensus.pos_quality import UI_ACTUAL_SPACE_CONSTANT_FACTOR, _expecte from chia.plot_sync.delta import Delta from chia.plot_sync.receiver import Receiver, Sync from chia.plot_sync.util import ErrorCodes, State +from chia.plotting.util import HarvestingMode from chia.protocols.harvester_protocol import ( Plot, PlotSyncDone, @@ -44,6 +45,7 @@ def assert_default_values(receiver: Receiver) -> None: assert receiver.duplicates() == [] assert receiver.total_plot_size() == 0 assert receiver.total_effective_plot_size() == 0 + assert receiver.harvesting_mode() is None async def dummy_callback(_: bytes32, __: Delta) -> None: @@ -177,6 +179,7 @@ def plot_sync_setup() -> Tuple[Receiver, List[SyncStepData]]: plot_public_key=G1Element(), file_size=uint64(random.randint(0, 100)), time_modified=uint64(0), + compression_level=uint8(0), ) for x in path_list ] @@ -188,7 +191,15 @@ def plot_sync_setup() -> Tuple[Receiver, List[SyncStepData]]: sum(UI_ACTUAL_SPACE_CONSTANT_FACTOR * int(_expected_plot_size(plot.size)) for plot in receiver.plots().values()) ) sync_steps: List[SyncStepData] = [ - SyncStepData(State.idle, receiver.sync_started, PlotSyncStart, False, uint64(0), uint32(len(plot_info_list))), + SyncStepData( + State.idle, + receiver.sync_started, + PlotSyncStart, + False, + uint64(0), + uint32(len(plot_info_list)), + uint8(HarvestingMode.CPU), + ), SyncStepData(State.loaded, receiver.process_loaded, PlotSyncPlotList, plot_info_list[10:20], True), SyncStepData(State.removed, receiver.process_removed, PlotSyncPathList, path_list[0:10], True), SyncStepData(State.invalid, receiver.process_invalid, PlotSyncPathList, path_list[20:30], True), @@ -258,6 +269,7 @@ async def test_to_dict(counts_only: bool) -> None: "host": receiver.connection().peer_info.host, "port": receiver.connection().peer_info.port, } + assert plot_sync_dict_1["harvesting_mode"] is None # We should get equal dicts assert plot_sync_dict_1 == receiver.to_dict(counts_only) @@ -299,6 +311,7 @@ async def test_to_dict(counts_only: bool) -> None: ) assert plot_sync_dict_3["last_sync_time"] > 0 assert plot_sync_dict_3["syncing"] is None + assert sync_steps[State.idle].args[3] == plot_sync_dict_3["harvesting_mode"] # Trigger a repeated plot sync await receiver.sync_started( @@ -307,6 +320,7 @@ async def test_to_dict(counts_only: bool) -> None: False, receiver.last_sync().sync_id, uint32(1), + uint8(HarvestingMode.CPU), ) ) assert receiver.to_dict()["syncing"] == { @@ -369,13 +383,13 @@ async def test_invalid_ids() -> None: receiver._last_sync.sync_id = uint64(1) # Test "sync_started last doesn't match" invalid_last_sync_id_param = PlotSyncStart( - plot_sync_identifier(uint64(0), uint64(0)), False, uint64(2), uint32(0) + plot_sync_identifier(uint64(0), uint64(0)), False, uint64(2), uint32(0), uint8(HarvestingMode.CPU) ) await current_step.function(invalid_last_sync_id_param) assert_error_response(receiver, ErrorCodes.invalid_last_sync_id) # Test "last_sync_id == new_sync_id" invalid_sync_id_match_param = PlotSyncStart( - plot_sync_identifier(uint64(1), uint64(0)), False, uint64(1), uint32(0) + plot_sync_identifier(uint64(1), uint64(0)), False, uint64(1), uint32(0), uint8(HarvestingMode.CPU) ) await current_step.function(invalid_sync_id_match_param) assert_error_response(receiver, ErrorCodes.sync_ids_match) diff --git a/tests/plot_sync/test_sender.py b/tests/plot_sync/test_sender.py index f71cb62ef7bb..ad5eff6626bd 100644 --- a/tests/plot_sync/test_sender.py +++ b/tests/plot_sync/test_sender.py @@ -5,6 +5,7 @@ import pytest from chia.plot_sync.exceptions import AlreadyStartedError, InvalidConnectionTypeError from chia.plot_sync.sender import ExpectedResponse, Sender from chia.plot_sync.util import Constants +from chia.plotting.util import HarvestingMode from chia.protocols.harvester_protocol import PlotSyncIdentifier, PlotSyncResponse from chia.protocols.protocol_message_types import ProtocolMessageTypes from chia.server.outbound_message import NodeType @@ -14,7 +15,7 @@ from tests.plot_sync.util import get_dummy_connection, plot_sync_identifier def test_default_values(bt: BlockTools) -> None: - sender = Sender(bt.plot_manager) + sender = Sender(bt.plot_manager, HarvestingMode.CPU) assert sender._plot_manager == bt.plot_manager assert sender._connection is None assert sender._sync_id == uint64(0) @@ -24,11 +25,12 @@ def test_default_values(bt: BlockTools) -> None: assert not sender._stop_requested assert sender._task is None assert sender._response is None + assert sender._harvesting_mode == HarvestingMode.CPU def test_set_connection_values(bt: BlockTools) -> None: farmer_connection = get_dummy_connection(NodeType.FARMER) - sender = Sender(bt.plot_manager) + sender = Sender(bt.plot_manager, HarvestingMode.CPU) # Test invalid NodeType values for connection_type in NodeType: if connection_type != NodeType.FARMER: @@ -45,7 +47,7 @@ def test_set_connection_values(bt: BlockTools) -> None: @pytest.mark.asyncio async def test_start_stop_send_task(bt: BlockTools) -> None: - sender = Sender(bt.plot_manager) + sender = Sender(bt.plot_manager, HarvestingMode.CPU) # Make sure starting/restarting works for _ in range(2): assert sender._task is None @@ -62,7 +64,7 @@ async def test_start_stop_send_task(bt: BlockTools) -> None: def test_set_response(bt: BlockTools) -> None: - sender = Sender(bt.plot_manager) + sender = Sender(bt.plot_manager, HarvestingMode.CPU) def new_expected_response(sync_id: int, message_id: int, message_type: ProtocolMessageTypes) -> ExpectedResponse: return ExpectedResponse(message_type, plot_sync_identifier(uint64(sync_id), uint64(message_id))) diff --git a/tests/plot_sync/test_sync_simulated.py b/tests/plot_sync/test_sync_simulated.py index a625345a68e6..70e558a725b2 100644 --- a/tests/plot_sync/test_sync_simulated.py +++ b/tests/plot_sync/test_sync_simulated.py @@ -30,7 +30,7 @@ from chia.server.ws_connection import WSChiaConnection from chia.simulator.block_tools import BlockTools from chia.simulator.time_out_assert import time_out_assert from chia.types.blockchain_format.sized_bytes import bytes32 -from chia.util.ints import int16, uint64 +from chia.util.ints import int16, uint8, uint64 from chia.util.misc import to_batches from tests.plot_sync.util import start_harvester_service @@ -275,6 +275,9 @@ def create_example_plots(count: int) -> List[PlotInfo]: def get_size(self) -> int: return self.size + def get_compression_level(self) -> uint8: + return uint8(0) + return [ PlotInfo( prover=DiskProver(f"{x}", bytes32(token_bytes(32)), 25 + x % 26), diff --git a/tests/plotting/test_plot_manager.py b/tests/plotting/test_plot_manager.py index 27adea53c7d2..3758ae4fd2ea 100644 --- a/tests/plotting/test_plot_manager.py +++ b/tests/plotting/test_plot_manager.py @@ -502,6 +502,7 @@ async def test_plot_info_caching(environment, bt): assert plot_manager.plots[path].prover.get_id() == plot_info.prover.get_id() assert plot_manager.plots[path].prover.get_memo() == plot_info.prover.get_memo() assert plot_manager.plots[path].prover.get_size() == plot_info.prover.get_size() + assert plot_manager.plots[path].prover.get_compression_level() == plot_info.prover.get_compression_level() assert plot_manager.plots[path].pool_public_key == plot_info.pool_public_key assert plot_manager.plots[path].pool_contract_puzzle_hash == plot_info.pool_contract_puzzle_hash assert plot_manager.plots[path].plot_public_key == plot_info.plot_public_key diff --git a/tests/util/network_protocol_data.py b/tests/util/network_protocol_data.py index 5297f14118eb..074b3767825e 100644 --- a/tests/util/network_protocol_data.py +++ b/tests/util/network_protocol_data.py @@ -759,6 +759,7 @@ plot = harvester_protocol.Plot( ), uint64(3368414292564311420), uint64(2573238947935295522), + uint8(0), ) request_plots = harvester_protocol.RequestPlots() diff --git a/tests/util/protocol_messages_bytes-v1.0 b/tests/util/protocol_messages_bytes-v1.0 index 2f7aa9a8ed13e64ff6d4a5a84e1afb310198d7a3..3b356b7d147346ae3e6e0425c80aad0a55c8eb73 100644 GIT binary patch delta 38 pcmccomg)LirVW#(Gwz=pGh;6oBLfuNoE$OZAQL0QX33emWC1a^4b1=m delta 33 lcmccqmg(|arVW#(Gwz)nH)Agc0|Z>395v(MX4#p$WC8u;4a)!k diff --git a/tests/util/protocol_messages_json.py b/tests/util/protocol_messages_json.py index 3fe5851cee47..d1c5d7c83633 100644 --- a/tests/util/protocol_messages_json.py +++ b/tests/util/protocol_messages_json.py @@ -2073,6 +2073,7 @@ plot_json: Dict[str, Any] = { "plot_public_key": "0xa04c6b5ac7dfb935f6feecfdd72348ccf1d4be4fe7e26acf271ea3b7d308da61e0a308f7a62495328a81f5147b66634c", "file_size": 3368414292564311420, "time_modified": 2573238947935295522, + "compression_level": 0, } request_plots_json: Dict[str, Any] = {} @@ -2088,6 +2089,7 @@ respond_plots_json: Dict[str, Any] = { "plot_public_key": "0xa04c6b5ac7dfb935f6feecfdd72348ccf1d4be4fe7e26acf271ea3b7d308da61e0a308f7a62495328a81f5147b66634c", "file_size": 3368414292564311420, "time_modified": 2573238947935295522, + "compression_level": 0, } ], "failed_to_open_filenames": ["str"],