Added compression level and harvesting mode to harvester protocol/messages (#15776)

* Added compression level and harvesting mode to harvester protocol/messages

* Added test

* Fixed lint error
Izumi Hoshino 2023-07-19 06:24:56 +09:00 committed by GitHub
parent 1f9081a225
commit 85d14f561a
20 changed files with 318 additions and 22 deletions
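
The main additions are a `compression_level` field on each reported plot, a `harvesting_mode` (CPU/GPU) field on `PlotSyncStart`, and a pair of harvester config RPC endpoints. A minimal sketch of driving the new endpoints from a client, modeled on the test helper added later in this commit (the RPC port and root path are assumed to belong to a running harvester):

# Sketch only: exercises the new config RPCs via the same helper the tests in
# this commit use; harvester_rpc_port and root_path are assumed inputs.
from pathlib import Path
from typing import Any, Dict, Optional

from chia.cmds.cmds_util import get_any_service_client
from chia.rpc.harvester_rpc_client import HarvesterRpcClient


async def toggle_gpu_harvesting(harvester_rpc_port: Optional[int], root_path: Path) -> Dict[str, Any]:
    async with get_any_service_client(HarvesterRpcClient, harvester_rpc_port, root_path) as (client, _):
        config = await client.get_harvester_config()
        # Flip the GPU flag and push it back; the harvester persists the change to config.yaml.
        await client.update_harvester_config({"use_gpu_harvesting": not config["use_gpu_harvesting"]})
        return await client.get_harvester_config()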

View File

@ -15,18 +15,22 @@ from chia.consensus.constants import ConsensusConstants
from chia.plot_sync.sender import Sender
from chia.plotting.manager import PlotManager
from chia.plotting.util import (
HarvestingMode,
PlotRefreshEvents,
PlotRefreshResult,
PlotsRefreshParameter,
add_plot_directory,
get_harvester_config,
get_plot_directories,
remove_plot,
remove_plot_directory,
update_harvester_config,
)
from chia.rpc.rpc_server import StateChangedProtocol, default_get_connections
from chia.server.outbound_message import NodeType
from chia.server.server import ChiaServer
from chia.server.ws_connection import WSChiaConnection
from chia.util.ints import uint32
log = logging.getLogger(__name__)
@ -42,6 +46,7 @@ class Harvester:
_refresh_lock: asyncio.Lock
event_loop: asyncio.events.AbstractEventLoop
_server: Optional[ChiaServer]
_mode: HarvestingMode
@property
def server(self) -> ChiaServer:
@ -73,7 +78,6 @@ class Harvester:
self.plot_manager = PlotManager(
root_path, refresh_parameter=refresh_parameter, refresh_callback=self._plot_refresh_callback
)
self.plot_sync_sender = Sender(self.plot_manager)
self._shut_down = False
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=config["num_threads"])
self._server = None
@ -92,7 +96,7 @@ class Harvester:
enforce_gpu_index = config.get("enforce_gpu_index", False)
try:
self.plot_manager.configure_decompressor(
self._mode = self.plot_manager.configure_decompressor(
context_count,
thread_count,
disable_cpu_affinity,
@ -105,6 +109,8 @@ class Harvester:
self.log.error(f"{type(e)} {e} while configuring decompressor.")
raise
self.plot_sync_sender = Sender(self.plot_manager, self._mode)
async def _start(self) -> None:
self._refresh_lock = asyncio.Lock()
self.event_loop = asyncio.get_running_loop()
@ -171,6 +177,7 @@ class Harvester:
"plot_public_key": plot_info.plot_public_key,
"file_size": plot_info.file_size,
"time_modified": int(plot_info.time_modified),
"compression_level": prover.get_compression_level(),
}
)
self.log.debug(
@ -203,5 +210,42 @@ class Harvester:
self.plot_manager.trigger_refresh()
return True
async def get_harvester_config(self) -> Dict[str, Any]:
return get_harvester_config(self.root_path)
async def update_harvester_config(
self,
*,
use_gpu_harvesting: Optional[bool] = None,
gpu_index: Optional[int] = None,
enforce_gpu_index: Optional[bool] = None,
disable_cpu_affinity: Optional[bool] = None,
parallel_decompressor_count: Optional[int] = None,
decompressor_thread_count: Optional[int] = None,
recursive_plot_scan: Optional[bool] = None,
refresh_parameter_interval_seconds: Optional[uint32] = None,
) -> bool:
refresh_parameter: Optional[PlotsRefreshParameter] = None
if refresh_parameter_interval_seconds is not None:
refresh_parameter = PlotsRefreshParameter(
interval_seconds=refresh_parameter_interval_seconds,
retry_invalid_seconds=self.plot_manager.refresh_parameter.retry_invalid_seconds,
batch_size=self.plot_manager.refresh_parameter.batch_size,
batch_sleep_milliseconds=self.plot_manager.refresh_parameter.batch_sleep_milliseconds,
)
update_harvester_config(
self.root_path,
use_gpu_harvesting=use_gpu_harvesting,
gpu_index=gpu_index,
enforce_gpu_index=enforce_gpu_index,
disable_cpu_affinity=disable_cpu_affinity,
parallel_decompressor_count=parallel_decompressor_count,
decompressor_thread_count=decompressor_thread_count,
recursive_plot_scan=recursive_plot_scan,
refresh_parameter=refresh_parameter,
)
return True
def set_server(self, server: ChiaServer) -> None:
self._server = server

View File

@ -333,6 +333,7 @@ class HarvesterAPI:
plot["plot_public_key"],
plot["file_size"],
plot["time_modified"],
plot["compression_level"],
)
)

View File

@ -18,6 +18,7 @@ from chia.plot_sync.exceptions import (
SyncIdsMatchError,
)
from chia.plot_sync.util import ErrorCodes, State, T_PlotSyncMessage
from chia.plotting.util import HarvestingMode
from chia.protocols.harvester_protocol import (
Plot,
PlotSyncDone,
@ -85,6 +86,7 @@ class Receiver:
_total_plot_size: int
_total_effective_plot_size: int
_update_callback: ReceiverUpdateCallback
_harvesting_mode: Optional[HarvestingMode]
def __init__(
self,
@ -101,6 +103,7 @@ class Receiver:
self._total_plot_size = 0
self._total_effective_plot_size = 0
self._update_callback = update_callback
self._harvesting_mode = None
async def trigger_callback(self, update: Optional[Delta] = None) -> None:
try:
@ -118,6 +121,7 @@ class Receiver:
self._duplicates.clear()
self._total_plot_size = 0
self._total_effective_plot_size = 0
self._harvesting_mode = None
def connection(self) -> WSChiaConnection:
return self._connection
@ -149,6 +153,9 @@ class Receiver:
def total_effective_plot_size(self) -> int:
return self._total_effective_plot_size
def harvesting_mode(self) -> Optional[HarvestingMode]:
return self._harvesting_mode
async def _process(
self, method: Callable[[T_PlotSyncMessage], Any], message_type: ProtocolMessageTypes, message: T_PlotSyncMessage
) -> None:
@ -203,6 +210,7 @@ class Receiver:
self._current_sync.delta.clear()
self._current_sync.state = State.loaded
self._current_sync.plots_total = data.plot_file_count
self._harvesting_mode = HarvestingMode(data.harvesting_mode)
self._current_sync.bump_next_message_id()
async def sync_started(self, data: PlotSyncStart) -> None:
@ -370,4 +378,5 @@ class Receiver:
"total_effective_plot_size": self._total_effective_plot_size,
"syncing": syncing,
"last_sync_time": self._last_sync.time_done,
"harvesting_mode": self._harvesting_mode,
}
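
With the receiver now caching the reported mode, farmer-side code could read it roughly like this (a sketch only; the `receiver` argument is assumed to be a connected plot_sync `Receiver`):

from typing import Optional

from chia.plot_sync.receiver import Receiver
from chia.plotting.util import HarvestingMode


def report_mode(receiver: Receiver) -> None:
    # `receiver` is assumed to have completed at least one plot sync.
    mode: Optional[HarvestingMode] = receiver.harvesting_mode()  # None until the first plot_sync_start
    if mode is HarvestingMode.GPU:
        print("harvester reports GPU harvesting")
    status = receiver.to_dict()  # now also carries a "harvesting_mode" entry
    print(status["harvesting_mode"])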

View File

@ -13,7 +13,7 @@ from typing_extensions import Protocol
from chia.plot_sync.exceptions import AlreadyStartedError, InvalidConnectionTypeError
from chia.plot_sync.util import Constants
from chia.plotting.manager import PlotManager
from chia.plotting.util import PlotInfo
from chia.plotting.util import HarvestingMode, PlotInfo
from chia.protocols.harvester_protocol import (
Plot,
PlotSyncDone,
@ -45,6 +45,7 @@ def _convert_plot_info_list(plot_infos: List[PlotInfo]) -> List[Plot]:
plot_public_key=plot_info.plot_public_key,
file_size=uint64(plot_info.file_size),
time_modified=uint64(int(plot_info.time_modified)),
compression_level=plot_info.prover.get_compression_level(),
)
)
return converted
@ -98,8 +99,9 @@ class Sender:
_stop_requested = False
_task: Optional[asyncio.Task[None]]
_response: Optional[ExpectedResponse]
_harvesting_mode: HarvestingMode
def __init__(self, plot_manager: PlotManager) -> None:
def __init__(self, plot_manager: PlotManager, harvesting_mode: HarvestingMode) -> None:
self._plot_manager = plot_manager
self._connection = None
self._sync_id = uint64(0)
@ -109,6 +111,7 @@ class Sender:
self._stop_requested = False
self._task = None
self._response = None
self._harvesting_mode = harvesting_mode
def __str__(self) -> str:
return f"sync_id {self._sync_id}, next_message_id {self._next_message_id}, messages {len(self._messages)}"
@ -268,7 +271,12 @@ class Sender:
log.debug(f"sync_start {sync_id}")
self._sync_id = uint64(sync_id)
self._add_message(
ProtocolMessageTypes.plot_sync_start, PlotSyncStart, initial, self._last_sync_id, uint32(int(count))
ProtocolMessageTypes.plot_sync_start,
PlotSyncStart,
initial,
self._last_sync_id,
uint32(int(count)),
self._harvesting_mode,
)
def process_batch(self, loaded: List[PlotInfo], remaining: int) -> None:

View File

@ -21,7 +21,7 @@ from chia.wallet.derive_keys import master_sk_to_local_sk
log = logging.getLogger(__name__)
CURRENT_VERSION: int = 1
CURRENT_VERSION: int = 2
@streamable

View File

@ -13,7 +13,14 @@ from chiapos import DiskProver, decompressor_context_queue
from chia.consensus.pos_quality import UI_ACTUAL_SPACE_CONSTANT_FACTOR, _expected_plot_size
from chia.plotting.cache import Cache, CacheEntry
from chia.plotting.util import PlotInfo, PlotRefreshEvents, PlotRefreshResult, PlotsRefreshParameter, get_plot_filenames
from chia.plotting.util import (
HarvestingMode,
PlotInfo,
PlotRefreshEvents,
PlotRefreshResult,
PlotsRefreshParameter,
get_plot_filenames,
)
from chia.util.misc import to_batches
log = logging.getLogger(__name__)
@ -56,7 +63,11 @@ class PlotManager:
self.no_key_filenames = set()
self.farmer_public_keys = []
self.pool_public_keys = []
self.cache = Cache(self.root_path.resolve() / "cache" / "plot_manager.dat")
# Since the `compression_level` property was added to the Cache structure,
# the previous cache file format needs to be reset.
# If the user downgrades the harvester, it looks for `plot_manager.dat`, while
# the latest harvester reads/writes `plot_manager_v2.dat`.
self.cache = Cache(self.root_path.resolve() / "cache" / "plot_manager_v2.dat")
self.match_str = match_str
self.open_no_key_filenames = open_no_key_filenames
self.last_refresh_time = 0
@ -84,7 +95,7 @@ class PlotManager:
use_gpu_harvesting: bool,
gpu_index: int,
enforce_gpu_index: bool,
) -> None:
) -> HarvestingMode:
if max_compression_level_allowed > 7:
log.error(
"Currently only compression levels up to 7 are allowed, "
@ -107,6 +118,7 @@ class PlotManager:
f"Falling back to CPU harvesting: {context_count} decompressors count, {thread_count} threads."
)
self.max_compression_level_allowed = max_compression_level_allowed
return HarvestingMode.GPU if is_using_gpu else HarvestingMode.CPU
def reset(self) -> None:
with self:

View File

@ -2,9 +2,9 @@ from __future__ import annotations
import logging
from dataclasses import dataclass, field
from enum import Enum
from enum import Enum, IntEnum
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from blspy import G1Element, PrivateKey
from chiapos import DiskProver
@ -84,6 +84,11 @@ class Params:
stripe_size: int = 65536
class HarvestingMode(IntEnum):
CPU = 1
GPU = 2
def get_plot_directories(root_path: Path, config: Dict = None) -> List[str]:
if config is None:
config = load_config(root_path, "config.yaml")
@ -146,6 +151,60 @@ def remove_plot(path: Path):
path.unlink()
def get_harvester_config(root_path: Path) -> Dict[str, Any]:
config = load_config(root_path, "config.yaml")
plots_refresh_parameter = (
config["harvester"].get("plots_refresh_parameter")
if config["harvester"].get("plots_refresh_parameter") is not None
else PlotsRefreshParameter().to_json_dict()
)
return {
"use_gpu_harvesting": config["harvester"].get("use_gpu_harvesting"),
"gpu_index": config["harvester"].get("gpu_index"),
"enforce_gpu_index": config["harvester"].get("enforce_gpu_index"),
"disable_cpu_affinity": config["harvester"].get("disable_cpu_affinity"),
"parallel_decompressor_count": config["harvester"].get("parallel_decompressor_count"),
"decompressor_thread_count": config["harvester"].get("decompressor_thread_count"),
"recursive_plot_scan": config["harvester"].get("recursive_plot_scan"),
"plots_refresh_parameter": plots_refresh_parameter,
}
def update_harvester_config(
root_path: Path,
*,
use_gpu_harvesting: Optional[bool] = None,
gpu_index: Optional[int] = None,
enforce_gpu_index: Optional[bool] = None,
disable_cpu_affinity: Optional[bool] = None,
parallel_decompressor_count: Optional[int] = None,
decompressor_thread_count: Optional[int] = None,
recursive_plot_scan: Optional[bool] = None,
refresh_parameter: Optional[PlotsRefreshParameter] = None,
):
with lock_and_load_config(root_path, "config.yaml") as config:
if use_gpu_harvesting is not None:
config["harvester"]["use_gpu_harvesting"] = use_gpu_harvesting
if gpu_index is not None:
config["harvester"]["gpu_index"] = gpu_index
if enforce_gpu_index is not None:
config["harvester"]["enforce_gpu_index"] = enforce_gpu_index
if disable_cpu_affinity is not None:
config["harvester"]["disable_cpu_affinity"] = disable_cpu_affinity
if parallel_decompressor_count is not None:
config["harvester"]["parallel_decompressor_count"] = parallel_decompressor_count
if decompressor_thread_count is not None:
config["harvester"]["decompressor_thread_count"] = decompressor_thread_count
if recursive_plot_scan is not None:
config["harvester"]["recursive_plot_scan"] = recursive_plot_scan
if refresh_parameter is not None:
config["harvester"]["plots_refresh_parameter"] = refresh_parameter.to_json_dict()
save_config(root_path, "config.yaml", config)
def get_filenames(directory: Path, recursive: bool) -> List[Path]:
try:
if not directory.exists():

View File

@ -84,6 +84,7 @@ class Plot(Streamable):
plot_public_key: G1Element
file_size: uint64
time_modified: uint64
compression_level: Optional[uint8]
@streamable
@ -115,11 +116,13 @@ class PlotSyncStart(Streamable):
initial: bool
last_sync_id: uint64
plot_file_count: uint32
harvesting_mode: uint8
def __str__(self) -> str:
return (
f"PlotSyncStart: identifier {self.identifier}, initial {self.initial}, "
f"last_sync_id {self.last_sync_id}, plot_file_count {self.plot_file_count}"
f"last_sync_id {self.last_sync_id}, plot_file_count {self.plot_file_count}, "
f"harvesting_mode {self.harvesting_mode}"
)
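
A short sketch of constructing the extended message, mirroring how the updated tests build it (`plot_sync_identifier` is the test helper used elsewhere in this commit; the values are illustrative):

from chia.plotting.util import HarvestingMode
from chia.protocols.harvester_protocol import PlotSyncStart
from chia.util.ints import uint8, uint32, uint64
from tests.plot_sync.util import plot_sync_identifier

start = PlotSyncStart(
    plot_sync_identifier(uint64(1), uint64(0)),  # identifier (sync_id, message_id)
    True,                                        # initial sync
    uint64(0),                                   # last_sync_id
    uint32(42),                                  # plot_file_count (illustrative)
    uint8(HarvestingMode.CPU),                   # harvesting_mode, new in this commit
)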

View File

@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional
from chia.harvester.harvester import Harvester
from chia.rpc.rpc_server import Endpoint, EndpointResult
from chia.util.ints import uint32
from chia.util.ws_message import WsRpcMessage, create_payload_dict
@ -20,6 +21,8 @@ class HarvesterRpcApi:
"/add_plot_directory": self.add_plot_directory,
"/get_plot_directories": self.get_plot_directories,
"/remove_plot_directory": self.remove_plot_directory,
"/get_harvester_config": self.get_harvester_config,
"/update_harvester_config": self.update_harvester_config,
}
async def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None) -> List[WsRpcMessage]:
@ -78,3 +81,56 @@ class HarvesterRpcApi:
if await self.service.remove_plot_directory(directory_name):
return {}
raise ValueError(f"Did not remove plot directory {directory_name}")
async def get_harvester_config(self, _: Dict[str, Any]) -> EndpointResult:
harvester_config = await self.service.get_harvester_config()
return {
"use_gpu_harvesting": harvester_config["use_gpu_harvesting"],
"gpu_index": harvester_config["gpu_index"],
"enforce_gpu_index": harvester_config["enforce_gpu_index"],
"disable_cpu_affinity": harvester_config["disable_cpu_affinity"],
"parallel_decompressor_count": harvester_config["parallel_decompressor_count"],
"decompressor_thread_count": harvester_config["decompressor_thread_count"],
"recursive_plot_scan": harvester_config["recursive_plot_scan"],
"refresh_parameter_interval_seconds": harvester_config["plots_refresh_parameter"].get("interval_seconds"),
}
async def update_harvester_config(self, request: Dict[str, Any]) -> EndpointResult:
use_gpu_harvesting: Optional[bool] = None
gpu_index: Optional[int] = None
enforce_gpu_index: Optional[bool] = None
disable_cpu_affinity: Optional[bool] = None
parallel_decompressor_count: Optional[int] = None
decompressor_thread_count: Optional[int] = None
recursive_plot_scan: Optional[bool] = None
refresh_parameter_interval_seconds: Optional[uint32] = None
if "use_gpu_harvesting" in request:
use_gpu_harvesting = bool(request["use_gpu_harvesting"])
if "gpu_index" in request:
gpu_index = int(request["gpu_index"])
if "enforce_gpu_index" in request:
enforce_gpu_index = bool(request["enforce_gpu_index"])
if "disable_cpu_affinity" in request:
disable_cpu_affinity = bool(request["disable_cpu_affinity"])
if "parallel_decompressor_count" in request:
parallel_decompressor_count = int(request["parallel_decompressor_count"])
if "decompressor_thread_count" in request:
decompressor_thread_count = int(request["decompressor_thread_count"])
if "recursive_plot_scan" in request:
recursive_plot_scan = bool(request["recursive_plot_scan"])
if "refresh_parameter_interval_seconds" in request:
refresh_parameter_interval_seconds = uint32(request["refresh_parameter_interval_seconds"])
if refresh_parameter_interval_seconds < 3:
raise ValueError(f"Plot refresh interval seconds({refresh_parameter_interval_seconds}) is too short")
await self.service.update_harvester_config(
use_gpu_harvesting=use_gpu_harvesting,
gpu_index=gpu_index,
enforce_gpu_index=enforce_gpu_index,
disable_cpu_affinity=disable_cpu_affinity,
parallel_decompressor_count=parallel_decompressor_count,
decompressor_thread_count=decompressor_thread_count,
recursive_plot_scan=recursive_plot_scan,
refresh_parameter_interval_seconds=refresh_parameter_interval_seconds,
)
return {}
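
For reference, a hypothetical request body for this endpoint; every key is optional and only the keys present are written back to config.yaml (the values below are made up):

# Hypothetical payload; keys match the parsing above, values are illustrative.
request = {
    "use_gpu_harvesting": True,
    "gpu_index": 0,
    "parallel_decompressor_count": 2,
    "decompressor_thread_count": 4,
    "refresh_parameter_interval_seconds": 120,  # values below 3 are rejected with a ValueError
}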

View File

@ -43,3 +43,12 @@ class HarvesterRpcClient(RpcClient):
# TODO: casting due to lack of type checked deserialization
result = cast(bool, response["success"])
return result
async def get_harvester_config(self) -> Dict[str, Any]:
return await self.fetch("get_harvester_config", {})
async def update_harvester_config(self, config: Dict[str, Any]) -> bool:
response = await self.fetch("update_harvester_config", config)
# TODO: casting due to lack of type checked deserialization
result = cast(bool, response["success"])
return result

View File

@ -1,23 +1,28 @@
from __future__ import annotations
import asyncio
from typing import List, Tuple
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import pytest
from blspy import G1Element
from chia.cmds.cmds_util import get_any_service_client
from chia.farmer.farmer import Farmer
from chia.farmer.farmer_api import FarmerAPI
from chia.harvester.harvester import Harvester
from chia.harvester.harvester_api import HarvesterAPI
from chia.plotting.util import PlotsRefreshParameter
from chia.protocols import harvester_protocol
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.rpc.harvester_rpc_client import HarvesterRpcClient
from chia.server.outbound_message import NodeType, make_msg
from chia.server.start_service import Service
from chia.simulator.block_tools import BlockTools
from chia.simulator.time_out_assert import time_out_assert
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.peer_info import UnresolvedPeerInfo
from chia.util.config import load_config
from chia.util.keychain import generate_mnemonic
from tests.conftest import HarvesterFarmerEnvironment
@ -26,6 +31,16 @@ def farmer_is_started(farmer: Farmer) -> bool:
return farmer.started
async def get_harvester_config(harvester_rpc_port: Optional[int], root_path: Path) -> Dict[str, Any]:
async with get_any_service_client(HarvesterRpcClient, harvester_rpc_port, root_path) as (harvester_client, _):
return await harvester_client.get_harvester_config()
async def update_harvester_config(harvester_rpc_port: Optional[int], root_path: Path, config: Dict[str, Any]) -> bool:
async with get_any_service_client(HarvesterRpcClient, harvester_rpc_port, root_path) as (harvester_client, _):
return await harvester_client.update_harvester_config(config)
@pytest.mark.asyncio
async def test_start_with_empty_keychain(
farmer_one_harvester_not_started: Tuple[
@ -150,3 +165,47 @@ async def test_farmer_respond_signatures(
# We fail the sps record check
expected_error = f"Do not have challenge hash {challenge_hash}"
assert expected_error in caplog.text
@pytest.mark.asyncio
async def test_harvester_config(
farmer_one_harvester: Tuple[List[Service[Harvester, HarvesterAPI]], Service[Farmer, FarmerAPI], BlockTools]
) -> None:
harvester_services, farmer_service, bt = farmer_one_harvester
harvester_service = harvester_services[0]
assert harvester_service.rpc_server and harvester_service.rpc_server.webserver
harvester_rpc_port = harvester_service.rpc_server.webserver.listen_port
harvester_config = await get_harvester_config(harvester_rpc_port, bt.root_path)
assert harvester_config["success"] is True
def check_config_match(config1: Dict[str, Any], config2: Dict[str, Any]) -> None:
assert config1["harvester"]["use_gpu_harvesting"] == config2["use_gpu_harvesting"]
assert config1["harvester"]["gpu_index"] == config2["gpu_index"]
assert config1["harvester"]["enforce_gpu_index"] == config2["enforce_gpu_index"]
assert config1["harvester"]["disable_cpu_affinity"] == config2["disable_cpu_affinity"]
assert config1["harvester"]["parallel_decompressor_count"] == config2["parallel_decompressor_count"]
assert config1["harvester"]["decompressor_thread_count"] == config2["decompressor_thread_count"]
assert config1["harvester"]["recursive_plot_scan"] == config2["recursive_plot_scan"]
assert config2["refresh_parameter_interval_seconds"] == (
config1["harvester"]["refresh_parameter_interval_seconds"]
if "refresh_parameter_interval_seconds" in config1["harvester"]
else PlotsRefreshParameter().interval_seconds
)
check_config_match(bt.config, harvester_config)
harvester_config["use_gpu_harvesting"] = not harvester_config["use_gpu_harvesting"]
harvester_config["gpu_index"] += 1
harvester_config["enforce_gpu_index"] = not harvester_config["enforce_gpu_index"]
harvester_config["disable_cpu_affinity"] = not harvester_config["disable_cpu_affinity"]
harvester_config["parallel_decompressor_count"] += 1
harvester_config["decompressor_thread_count"] += 1
harvester_config["recursive_plot_scan"] = not harvester_config["recursive_plot_scan"]
harvester_config["refresh_parameter_interval_seconds"] = harvester_config["refresh_parameter_interval_seconds"] + 1
res = await update_harvester_config(harvester_rpc_port, bt.root_path, harvester_config)
assert res is True
new_config = load_config(harvester_service.root_path, "config.yaml")
check_config_match(new_config, harvester_config)

View File

@ -15,7 +15,17 @@ log = logging.getLogger(__name__)
def dummy_plot(path: str) -> Plot:
return Plot(path, uint8(32), bytes32(b"\00" * 32), G1Element(), None, G1Element(), uint64(0), uint64(0))
return Plot(
filename=path,
size=uint8(32),
plot_id=bytes32(b"\00" * 32),
pool_public_key=G1Element(),
pool_contract_puzzle_hash=None,
plot_public_key=G1Element(),
file_size=uint64(0),
time_modified=uint64(0),
compression_level=uint8(0),
)
@pytest.mark.parametrize(

View File

@ -73,6 +73,7 @@ class ExpectedResult:
G1Element(),
uint64(0),
uint64(0),
uint8(0),
)
self.valid_count += len(list_plots)
@ -193,6 +194,7 @@ class Environment:
assert plot.prover.get_filename() == delta.valid.additions[path].filename
assert plot.prover.get_size() == delta.valid.additions[path].size
assert plot.prover.get_id() == delta.valid.additions[path].plot_id
assert plot.prover.get_compression_level() == delta.valid.additions[path].compression_level
assert plot.pool_public_key == delta.valid.additions[path].pool_public_key
assert plot.pool_contract_puzzle_hash == delta.valid.additions[path].pool_contract_puzzle_hash
assert plot.plot_public_key == delta.valid.additions[path].plot_public_key
@ -253,6 +255,7 @@ class Environment:
assert plot_info.prover.get_filename() == receiver.plots()[str(path)].filename
assert plot_info.prover.get_size() == receiver.plots()[str(path)].size
assert plot_info.prover.get_id() == receiver.plots()[str(path)].plot_id
assert plot_info.prover.get_compression_level() == receiver.plots()[str(path)].compression_level
assert plot_info.pool_public_key == receiver.plots()[str(path)].pool_public_key
assert plot_info.pool_contract_puzzle_hash == receiver.plots()[str(path)].pool_contract_puzzle_hash
assert plot_info.plot_public_key == receiver.plots()[str(path)].plot_public_key

View File

@ -14,6 +14,7 @@ from chia.consensus.pos_quality import UI_ACTUAL_SPACE_CONSTANT_FACTOR, _expecte
from chia.plot_sync.delta import Delta
from chia.plot_sync.receiver import Receiver, Sync
from chia.plot_sync.util import ErrorCodes, State
from chia.plotting.util import HarvestingMode
from chia.protocols.harvester_protocol import (
Plot,
PlotSyncDone,
@ -44,6 +45,7 @@ def assert_default_values(receiver: Receiver) -> None:
assert receiver.duplicates() == []
assert receiver.total_plot_size() == 0
assert receiver.total_effective_plot_size() == 0
assert receiver.harvesting_mode() is None
async def dummy_callback(_: bytes32, __: Delta) -> None:
@ -177,6 +179,7 @@ def plot_sync_setup() -> Tuple[Receiver, List[SyncStepData]]:
plot_public_key=G1Element(),
file_size=uint64(random.randint(0, 100)),
time_modified=uint64(0),
compression_level=uint8(0),
)
for x in path_list
]
@ -188,7 +191,15 @@ def plot_sync_setup() -> Tuple[Receiver, List[SyncStepData]]:
sum(UI_ACTUAL_SPACE_CONSTANT_FACTOR * int(_expected_plot_size(plot.size)) for plot in receiver.plots().values())
)
sync_steps: List[SyncStepData] = [
SyncStepData(State.idle, receiver.sync_started, PlotSyncStart, False, uint64(0), uint32(len(plot_info_list))),
SyncStepData(
State.idle,
receiver.sync_started,
PlotSyncStart,
False,
uint64(0),
uint32(len(plot_info_list)),
uint8(HarvestingMode.CPU),
),
SyncStepData(State.loaded, receiver.process_loaded, PlotSyncPlotList, plot_info_list[10:20], True),
SyncStepData(State.removed, receiver.process_removed, PlotSyncPathList, path_list[0:10], True),
SyncStepData(State.invalid, receiver.process_invalid, PlotSyncPathList, path_list[20:30], True),
@ -258,6 +269,7 @@ async def test_to_dict(counts_only: bool) -> None:
"host": receiver.connection().peer_info.host,
"port": receiver.connection().peer_info.port,
}
assert plot_sync_dict_1["harvesting_mode"] is None
# We should get equal dicts
assert plot_sync_dict_1 == receiver.to_dict(counts_only)
@ -299,6 +311,7 @@ async def test_to_dict(counts_only: bool) -> None:
)
assert plot_sync_dict_3["last_sync_time"] > 0
assert plot_sync_dict_3["syncing"] is None
assert sync_steps[State.idle].args[3] == plot_sync_dict_3["harvesting_mode"]
# Trigger a repeated plot sync
await receiver.sync_started(
@ -307,6 +320,7 @@ async def test_to_dict(counts_only: bool) -> None:
False,
receiver.last_sync().sync_id,
uint32(1),
uint8(HarvestingMode.CPU),
)
)
assert receiver.to_dict()["syncing"] == {
@ -369,13 +383,13 @@ async def test_invalid_ids() -> None:
receiver._last_sync.sync_id = uint64(1)
# Test "sync_started last doesn't match"
invalid_last_sync_id_param = PlotSyncStart(
plot_sync_identifier(uint64(0), uint64(0)), False, uint64(2), uint32(0)
plot_sync_identifier(uint64(0), uint64(0)), False, uint64(2), uint32(0), uint8(HarvestingMode.CPU)
)
await current_step.function(invalid_last_sync_id_param)
assert_error_response(receiver, ErrorCodes.invalid_last_sync_id)
# Test "last_sync_id == new_sync_id"
invalid_sync_id_match_param = PlotSyncStart(
plot_sync_identifier(uint64(1), uint64(0)), False, uint64(1), uint32(0)
plot_sync_identifier(uint64(1), uint64(0)), False, uint64(1), uint32(0), uint8(HarvestingMode.CPU)
)
await current_step.function(invalid_sync_id_match_param)
assert_error_response(receiver, ErrorCodes.sync_ids_match)

View File

@ -5,6 +5,7 @@ import pytest
from chia.plot_sync.exceptions import AlreadyStartedError, InvalidConnectionTypeError
from chia.plot_sync.sender import ExpectedResponse, Sender
from chia.plot_sync.util import Constants
from chia.plotting.util import HarvestingMode
from chia.protocols.harvester_protocol import PlotSyncIdentifier, PlotSyncResponse
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.server.outbound_message import NodeType
@ -14,7 +15,7 @@ from tests.plot_sync.util import get_dummy_connection, plot_sync_identifier
def test_default_values(bt: BlockTools) -> None:
sender = Sender(bt.plot_manager)
sender = Sender(bt.plot_manager, HarvestingMode.CPU)
assert sender._plot_manager == bt.plot_manager
assert sender._connection is None
assert sender._sync_id == uint64(0)
@ -24,11 +25,12 @@ def test_default_values(bt: BlockTools) -> None:
assert not sender._stop_requested
assert sender._task is None
assert sender._response is None
assert sender._harvesting_mode == HarvestingMode.CPU
def test_set_connection_values(bt: BlockTools) -> None:
farmer_connection = get_dummy_connection(NodeType.FARMER)
sender = Sender(bt.plot_manager)
sender = Sender(bt.plot_manager, HarvestingMode.CPU)
# Test invalid NodeType values
for connection_type in NodeType:
if connection_type != NodeType.FARMER:
@ -45,7 +47,7 @@ def test_set_connection_values(bt: BlockTools) -> None:
@pytest.mark.asyncio
async def test_start_stop_send_task(bt: BlockTools) -> None:
sender = Sender(bt.plot_manager)
sender = Sender(bt.plot_manager, HarvestingMode.CPU)
# Make sure starting/restarting works
for _ in range(2):
assert sender._task is None
@ -62,7 +64,7 @@ async def test_start_stop_send_task(bt: BlockTools) -> None:
def test_set_response(bt: BlockTools) -> None:
sender = Sender(bt.plot_manager)
sender = Sender(bt.plot_manager, HarvestingMode.CPU)
def new_expected_response(sync_id: int, message_id: int, message_type: ProtocolMessageTypes) -> ExpectedResponse:
return ExpectedResponse(message_type, plot_sync_identifier(uint64(sync_id), uint64(message_id)))

View File

@ -30,7 +30,7 @@ from chia.server.ws_connection import WSChiaConnection
from chia.simulator.block_tools import BlockTools
from chia.simulator.time_out_assert import time_out_assert
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import int16, uint64
from chia.util.ints import int16, uint8, uint64
from chia.util.misc import to_batches
from tests.plot_sync.util import start_harvester_service
@ -275,6 +275,9 @@ def create_example_plots(count: int) -> List[PlotInfo]:
def get_size(self) -> int:
return self.size
def get_compression_level(self) -> uint8:
return uint8(0)
return [
PlotInfo(
prover=DiskProver(f"{x}", bytes32(token_bytes(32)), 25 + x % 26),

View File

@ -502,6 +502,7 @@ async def test_plot_info_caching(environment, bt):
assert plot_manager.plots[path].prover.get_id() == plot_info.prover.get_id()
assert plot_manager.plots[path].prover.get_memo() == plot_info.prover.get_memo()
assert plot_manager.plots[path].prover.get_size() == plot_info.prover.get_size()
assert plot_manager.plots[path].prover.get_compression_level() == plot_info.prover.get_compression_level()
assert plot_manager.plots[path].pool_public_key == plot_info.pool_public_key
assert plot_manager.plots[path].pool_contract_puzzle_hash == plot_info.pool_contract_puzzle_hash
assert plot_manager.plots[path].plot_public_key == plot_info.plot_public_key

View File

@ -759,6 +759,7 @@ plot = harvester_protocol.Plot(
),
uint64(3368414292564311420),
uint64(2573238947935295522),
uint8(0),
)
request_plots = harvester_protocol.RequestPlots()

View File

@ -2073,6 +2073,7 @@ plot_json: Dict[str, Any] = {
"plot_public_key": "0xa04c6b5ac7dfb935f6feecfdd72348ccf1d4be4fe7e26acf271ea3b7d308da61e0a308f7a62495328a81f5147b66634c",
"file_size": 3368414292564311420,
"time_modified": 2573238947935295522,
"compression_level": 0,
}
request_plots_json: Dict[str, Any] = {}
@ -2088,6 +2089,7 @@ respond_plots_json: Dict[str, Any] = {
"plot_public_key": "0xa04c6b5ac7dfb935f6feecfdd72348ccf1d4be4fe7e26acf271ea3b7d308da61e0a308f7a62495328a81f5147b66634c",
"file_size": 3368414292564311420,
"time_modified": 2573238947935295522,
"compression_level": 0,
}
],
"failed_to_open_filenames": ["str"],