diff --git a/.github/workflows/build-test-macos-plot_sync.yml b/.github/workflows/build-test-macos-plot_sync.yml new file mode 100644 index 000000000000..d18e76cf7d20 --- /dev/null +++ b/.github/workflows/build-test-macos-plot_sync.yml @@ -0,0 +1,107 @@ +# +# THIS FILE IS GENERATED. SEE https://github.com/Chia-Network/chia-blockchain/tree/main/tests#readme +# +name: MacOS plot_sync Tests + +on: + push: + branches: + - main + tags: + - '**' + pull_request: + branches: + - '**' + +concurrency: + # SHA is added to the end if on `main` to let all main workflows run + group: ${{ github.ref }}-${{ github.workflow }}-${{ github.event_name }}-${{ github.ref == 'refs/heads/main' && github.sha || '' }} + cancel-in-progress: true + +jobs: + build: + name: MacOS plot_sync Tests + runs-on: ${{ matrix.os }} + timeout-minutes: 30 + strategy: + fail-fast: false + max-parallel: 4 + matrix: + python-version: [3.8, 3.9] + os: [macOS-latest] + env: + CHIA_ROOT: ${{ github.workspace }}/.chia/mainnet + JOB_FILE_NAME: tests_${{ matrix.os }}_python-${{ matrix.python-version }}_plot_sync + + steps: + - name: Checkout Code + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Setup Python environment + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Create keychain for CI use + run: | + security create-keychain -p foo chiachain + security default-keychain -s chiachain + security unlock-keychain -p foo chiachain + security set-keychain-settings -t 7200 -u chiachain + + - name: Get pip cache dir + id: pip-cache + run: | + echo "::set-output name=dir::$(pip cache dir)" + + - name: Cache pip + uses: actions/cache@v3 + with: + # Note that new runners may break this https://github.com/actions/cache/issues/292 + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Checkout test blocks and plots + uses: actions/checkout@v3 + with: + 
repository: 'Chia-Network/test-cache' + path: '.chia' + ref: '0.28.0' + fetch-depth: 1 + + - name: Run install script + env: + INSTALL_PYTHON_VERSION: ${{ matrix.python-version }} + run: | + brew install boost + sh install.sh -d + +# Omitted installing Timelord + + - name: Test plot_sync code with pytest + run: | + . ./activate + venv/bin/coverage run --rcfile=.coveragerc ./venv/bin/py.test tests/plot_sync/test_*.py --durations=10 -n 4 -m "not benchmark" + + - name: Process coverage data + run: | + venv/bin/coverage combine --rcfile=.coveragerc .coverage.* + venv/bin/coverage xml --rcfile=.coveragerc -o coverage.xml + mkdir coverage_reports + cp .coverage "coverage_reports/.coverage.${{ env.JOB_FILE_NAME }}" + cp coverage.xml "coverage_reports/coverage.${{ env.JOB_FILE_NAME }}.xml" + venv/bin/coverage report --rcfile=.coveragerc --show-missing + + - name: Publish coverage + uses: actions/upload-artifact@v2 + with: + name: coverage + path: coverage_reports/* + if-no-files-found: error +# +# THIS FILE IS GENERATED. SEE https://github.com/Chia-Network/chia-blockchain/tree/main/tests#readme +# diff --git a/.github/workflows/build-test-ubuntu-plot_sync.yml b/.github/workflows/build-test-ubuntu-plot_sync.yml new file mode 100644 index 000000000000..886573b5743d --- /dev/null +++ b/.github/workflows/build-test-ubuntu-plot_sync.yml @@ -0,0 +1,109 @@ +# +# THIS FILE IS GENERATED. 
SEE https://github.com/Chia-Network/chia-blockchain/tree/main/tests#readme +# +name: Ubuntu plot_sync Test + +on: + push: + branches: + - main + tags: + - '**' + pull_request: + branches: + - '**' + +concurrency: + # SHA is added to the end if on `main` to let all main workflows run + group: ${{ github.ref }}-${{ github.workflow }}-${{ github.event_name }}-${{ github.ref == 'refs/heads/main' && github.sha || '' }} + cancel-in-progress: true + +jobs: + build: + name: Ubuntu plot_sync Test + runs-on: ${{ matrix.os }} + timeout-minutes: 30 + strategy: + fail-fast: false + max-parallel: 4 + matrix: + python-version: [3.7, 3.8, 3.9] + os: [ubuntu-latest] + env: + CHIA_ROOT: ${{ github.workspace }}/.chia/mainnet + JOB_FILE_NAME: tests_${{ matrix.os }}_python-${{ matrix.python-version }}_plot_sync + + steps: + - name: Checkout Code + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Setup Python environment + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache npm + uses: actions/cache@v3 + with: + path: ~/.npm + key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }} + restore-keys: | + ${{ runner.os }}-node- + + - name: Get pip cache dir + id: pip-cache + run: | + echo "::set-output name=dir::$(pip cache dir)" + + - name: Cache pip + uses: actions/cache@v3 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Checkout test blocks and plots + uses: actions/checkout@v3 + with: + repository: 'Chia-Network/test-cache' + path: '.chia' + ref: '0.28.0' + fetch-depth: 1 + + - name: Run install script + env: + INSTALL_PYTHON_VERSION: ${{ matrix.python-version }} + run: | + sh install.sh -d + +# Omitted installing Timelord + + - name: Test plot_sync code with pytest + run: | + . 
./activate + venv/bin/coverage run --rcfile=.coveragerc ./venv/bin/py.test tests/plot_sync/test_*.py --durations=10 -n 4 -m "not benchmark" -p no:monitor + + - name: Process coverage data + run: | + venv/bin/coverage combine --rcfile=.coveragerc .coverage.* + venv/bin/coverage xml --rcfile=.coveragerc -o coverage.xml + mkdir coverage_reports + cp .coverage "coverage_reports/.coverage.${{ env.JOB_FILE_NAME }}" + cp coverage.xml "coverage_reports/coverage.${{ env.JOB_FILE_NAME }}.xml" + venv/bin/coverage report --rcfile=.coveragerc --show-missing + + - name: Publish coverage + uses: actions/upload-artifact@v2 + with: + name: coverage + path: coverage_reports/* + if-no-files-found: error + +# Omitted resource usage check + +# +# THIS FILE IS GENERATED. SEE https://github.com/Chia-Network/chia-blockchain/tree/main/tests#readme +# diff --git a/chia/farmer/farmer.py b/chia/farmer/farmer.py index 33f3b973a727..f7e50e9ed89e 100644 --- a/chia/farmer/farmer.py +++ b/chia/farmer/farmer.py @@ -18,6 +18,8 @@ from chia.daemon.keychain_proxy import ( connect_to_keychain_and_validate, wrap_local_keychain, ) +from chia.plot_sync.receiver import Receiver +from chia.plot_sync.delta import Delta from chia.pools.pool_config import PoolWalletConfig, load_pool_config, add_auth_key from chia.protocols import farmer_protocol, harvester_protocol from chia.protocols.pool_protocol import ( @@ -59,32 +61,12 @@ log = logging.getLogger(__name__) UPDATE_POOL_INFO_INTERVAL: int = 3600 UPDATE_POOL_FARMER_INFO_INTERVAL: int = 300 -UPDATE_HARVESTER_CACHE_INTERVAL: int = 90 """ HARVESTER PROTOCOL (FARMER <-> HARVESTER) """ -class HarvesterCacheEntry: - def __init__(self): - self.data: Optional[dict] = None - self.last_update: float = 0 - - def bump_last_update(self): - self.last_update = time.time() - - def set_data(self, data): - self.data = data - self.bump_last_update() - - def needs_update(self, update_interval: int): - return time.time() - self.last_update > update_interval - - def expired(self, 
update_interval: int): - return time.time() - self.last_update > update_interval * 10 - - class Farmer: def __init__( self, @@ -115,8 +97,7 @@ class Farmer: # to periodically clear the memory self.cache_add_time: Dict[bytes32, uint64] = {} - # Interval to request plots from connected harvesters - self.update_harvester_cache_interval = UPDATE_HARVESTER_CACHE_INTERVAL + self.plot_sync_receivers: Dict[bytes32, Receiver] = {} self.cache_clear_task: Optional[asyncio.Task] = None self.update_pool_state_task: Optional[asyncio.Task] = None @@ -137,8 +118,6 @@ class Farmer: # Last time we updated pool_state based on the config file self.last_config_access_time: uint64 = uint64(0) - self.harvester_cache: Dict[str, Dict[str, HarvesterCacheEntry]] = {} - async def ensure_keychain_proxy(self) -> KeychainProxy: if self.keychain_proxy is None: if self.local_keychain: @@ -256,6 +235,7 @@ class Farmer: self.harvester_handshake_task = None if peer.connection_type is NodeType.HARVESTER: + self.plot_sync_receivers[peer.peer_node_id] = Receiver(peer, self.plot_sync_callback) self.harvester_handshake_task = asyncio.create_task(handshake_task()) def set_server(self, server): @@ -274,6 +254,13 @@ class Farmer: def on_disconnect(self, connection: ws.WSChiaConnection): self.log.info(f"peer disconnected {connection.get_peer_logging()}") self.state_changed("close_connection", {}) + if connection.connection_type is NodeType.HARVESTER: + del self.plot_sync_receivers[connection.peer_node_id] + + async def plot_sync_callback(self, peer_id: bytes32, delta: Delta) -> None: + log.info(f"plot_sync_callback: peer_id {peer_id}, delta {delta}") + if not delta.empty(): + self.state_changed("new_plots", await self.get_harvesters()) async def _pool_get_pool_info(self, pool_config: PoolWalletConfig) -> Optional[Dict]: try: @@ -642,80 +629,17 @@ class Farmer: return None - async def update_cached_harvesters(self) -> bool: - # First remove outdated cache entries - self.log.debug(f"update_cached_harvesters 
cache entries: {len(self.harvester_cache)}") - remove_hosts = [] - for host, host_cache in self.harvester_cache.items(): - remove_peers = [] - for peer_id, peer_cache in host_cache.items(): - # If the peer cache is expired it means the harvester didn't respond for too long - if peer_cache.expired(self.update_harvester_cache_interval): - remove_peers.append(peer_id) - for key in remove_peers: - del host_cache[key] - if len(host_cache) == 0: - self.log.debug(f"update_cached_harvesters remove host: {host}") - remove_hosts.append(host) - for key in remove_hosts: - del self.harvester_cache[key] - # Now query each harvester and update caches - updated = False - for connection in self.server.get_connections(NodeType.HARVESTER): - cache_entry = await self.get_cached_harvesters(connection) - if cache_entry.needs_update(self.update_harvester_cache_interval): - self.log.debug(f"update_cached_harvesters update harvester: {connection.peer_node_id}") - cache_entry.bump_last_update() - response = await connection.request_plots( - harvester_protocol.RequestPlots(), timeout=self.update_harvester_cache_interval - ) - if response is not None: - if isinstance(response, harvester_protocol.RespondPlots): - new_data: Dict = response.to_json_dict() - if cache_entry.data != new_data: - updated = True - self.log.debug(f"update_cached_harvesters cache updated: {connection.peer_node_id}") - else: - self.log.debug(f"update_cached_harvesters no changes for: {connection.peer_node_id}") - cache_entry.set_data(new_data) - else: - self.log.error( - f"Invalid response from harvester:" - f"peer_host {connection.peer_host}, peer_node_id {connection.peer_node_id}" - ) - else: - self.log.error( - f"Harvester '{connection.peer_host}/{connection.peer_node_id}' did not respond: " - f"(version mismatch or time out {UPDATE_HARVESTER_CACHE_INTERVAL}s)" - ) - return updated - - async def get_cached_harvesters(self, connection: WSChiaConnection) -> HarvesterCacheEntry: - host_cache = 
self.harvester_cache.get(connection.peer_host) - if host_cache is None: - host_cache = {} - self.harvester_cache[connection.peer_host] = host_cache - node_cache = host_cache.get(connection.peer_node_id.hex()) - if node_cache is None: - node_cache = HarvesterCacheEntry() - host_cache[connection.peer_node_id.hex()] = node_cache - return node_cache - async def get_harvesters(self) -> Dict: harvesters: List = [] for connection in self.server.get_connections(NodeType.HARVESTER): self.log.debug(f"get_harvesters host: {connection.peer_host}, node_id: {connection.peer_node_id}") - cache_entry = await self.get_cached_harvesters(connection) - if cache_entry.data is not None: - harvester_object: dict = dict(cache_entry.data) - harvester_object["connection"] = { - "node_id": connection.peer_node_id.hex(), - "host": connection.peer_host, - "port": connection.peer_port, - } - harvesters.append(harvester_object) + receiver = self.plot_sync_receivers.get(connection.peer_node_id) + if receiver is not None: + harvesters.append(receiver.to_dict()) else: - self.log.debug(f"get_harvesters no cache: {connection.peer_host}, node_id: {connection.peer_node_id}") + self.log.debug( + f"get_harvesters invalid peer: {connection.peer_host}, node_id: {connection.peer_node_id}" + ) return {"harvesters": harvesters} @@ -766,9 +690,6 @@ class Farmer: self.state_changed("add_connection", {}) refresh_slept = 0 - # Handles harvester plots cache cleanup and updates - if await self.update_cached_harvesters(): - self.state_changed("new_plots", await self.get_harvesters()) except Exception: log.error(f"_periodically_clear_cache_and_refresh_task failed: {traceback.format_exc()}") diff --git a/chia/farmer/farmer_api.py b/chia/farmer/farmer_api.py index d0cb11577e15..9e94a5052334 100644 --- a/chia/farmer/farmer_api.py +++ b/chia/farmer/farmer_api.py @@ -11,7 +11,13 @@ from chia.consensus.network_type import NetworkType from chia.consensus.pot_iterations import calculate_iterations_quality, 
calculate_sp_interval_iters from chia.farmer.farmer import Farmer from chia.protocols import farmer_protocol, harvester_protocol -from chia.protocols.harvester_protocol import PoolDifficulty +from chia.protocols.harvester_protocol import ( + PoolDifficulty, + PlotSyncStart, + PlotSyncPlotList, + PlotSyncPathList, + PlotSyncDone, +) from chia.protocols.pool_protocol import ( get_current_authentication_token, PoolErrorCode, @@ -518,3 +524,38 @@ class FarmerAPI: @peer_required async def respond_plots(self, _: harvester_protocol.RespondPlots, peer: ws.WSChiaConnection): self.farmer.log.warning(f"Respond plots came too late from: {peer.get_peer_logging()}") + + @api_request + @peer_required + async def plot_sync_start(self, message: PlotSyncStart, peer: ws.WSChiaConnection): + await self.farmer.plot_sync_receivers[peer.peer_node_id].sync_started(message) + + @api_request + @peer_required + async def plot_sync_loaded(self, message: PlotSyncPlotList, peer: ws.WSChiaConnection): + await self.farmer.plot_sync_receivers[peer.peer_node_id].process_loaded(message) + + @api_request + @peer_required + async def plot_sync_removed(self, message: PlotSyncPathList, peer: ws.WSChiaConnection): + await self.farmer.plot_sync_receivers[peer.peer_node_id].process_removed(message) + + @api_request + @peer_required + async def plot_sync_invalid(self, message: PlotSyncPathList, peer: ws.WSChiaConnection): + await self.farmer.plot_sync_receivers[peer.peer_node_id].process_invalid(message) + + @api_request + @peer_required + async def plot_sync_keys_missing(self, message: PlotSyncPathList, peer: ws.WSChiaConnection): + await self.farmer.plot_sync_receivers[peer.peer_node_id].process_keys_missing(message) + + @api_request + @peer_required + async def plot_sync_duplicates(self, message: PlotSyncPathList, peer: ws.WSChiaConnection): + await self.farmer.plot_sync_receivers[peer.peer_node_id].process_duplicates(message) + + @api_request + @peer_required + async def plot_sync_done(self, message: 
PlotSyncDone, peer: ws.WSChiaConnection): + await self.farmer.plot_sync_receivers[peer.peer_node_id].sync_done(message) diff --git a/chia/harvester/harvester.py b/chia/harvester/harvester.py index b63b0c7a16d7..fab6a77da9ca 100644 --- a/chia/harvester/harvester.py +++ b/chia/harvester/harvester.py @@ -8,6 +8,7 @@ from typing import Callable, Dict, List, Optional, Tuple import chia.server.ws_connection as ws # lgtm [py/import-and-import-from] from chia.consensus.constants import ConsensusConstants +from chia.plot_sync.sender import Sender from chia.plotting.manager import PlotManager from chia.plotting.util import ( add_plot_directory, @@ -25,6 +26,7 @@ log = logging.getLogger(__name__) class Harvester: plot_manager: PlotManager + plot_sync_sender: Sender root_path: Path _is_shutdown: bool executor: ThreadPoolExecutor @@ -53,6 +55,7 @@ class Harvester: self.plot_manager = PlotManager( root_path, refresh_parameter=refresh_parameter, refresh_callback=self._plot_refresh_callback ) + self.plot_sync_sender = Sender(self.plot_manager) self._is_shutdown = False self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=config["num_threads"]) self.state_changed_callback = None @@ -70,9 +73,11 @@ class Harvester: self._is_shutdown = True self.executor.shutdown(wait=True) self.plot_manager.stop_refreshing() + self.plot_manager.reset() + self.plot_sync_sender.stop() async def _await_closed(self): - pass + await self.plot_sync_sender.await_closed() def _set_state_changed_callback(self, callback: Callable): self.state_changed_callback = callback @@ -90,12 +95,18 @@ class Harvester: f"duration: {update_result.duration:.2f} seconds, " f"total plots: {len(self.plot_manager.plots)}" ) - if len(update_result.loaded) > 0: - self.event_loop.call_soon_threadsafe(self._state_changed, "plots") + if event == PlotRefreshEvents.started: + self.plot_sync_sender.sync_start(update_result.remaining, self.plot_manager.initial_refresh()) + if event == PlotRefreshEvents.batch_processed: + 
self.plot_sync_sender.process_batch(update_result.loaded, update_result.remaining) + if event == PlotRefreshEvents.done: + self.plot_sync_sender.sync_done(update_result.removed, update_result.duration) def on_disconnect(self, connection: ws.WSChiaConnection): self.log.info(f"peer disconnected {connection.get_peer_logging()}") self._state_changed("close_connection") + self.plot_manager.stop_refreshing() + self.plot_sync_sender.stop() def get_plots(self) -> Tuple[List[Dict], List[str], List[str]]: self.log.debug(f"get_plots prover items: {self.plot_manager.plot_count()}") diff --git a/chia/harvester/harvester_api.py b/chia/harvester/harvester_api.py index 760a57cb8c45..88528678594d 100644 --- a/chia/harvester/harvester_api.py +++ b/chia/harvester/harvester_api.py @@ -10,7 +10,7 @@ from chia.harvester.harvester import Harvester from chia.plotting.util import PlotInfo, parse_plot_info from chia.protocols import harvester_protocol from chia.protocols.farmer_protocol import FarmingInfo -from chia.protocols.harvester_protocol import Plot +from chia.protocols.harvester_protocol import Plot, PlotSyncResponse from chia.protocols.protocol_message_types import ProtocolMessageTypes from chia.server.outbound_message import make_msg from chia.server.ws_connection import WSChiaConnection @@ -30,8 +30,11 @@ class HarvesterAPI: def _set_state_changed_callback(self, callback: Callable): self.harvester.state_changed_callback = callback + @peer_required @api_request - async def harvester_handshake(self, harvester_handshake: harvester_protocol.HarvesterHandshake): + async def harvester_handshake( + self, harvester_handshake: harvester_protocol.HarvesterHandshake, peer: WSChiaConnection + ): """ Handshake between the harvester and farmer. The harvester receives the pool public keys, as well as the farmer pks, which must be put into the plots, before the plotting process begins. 
@@ -40,7 +43,8 @@ class HarvesterAPI: self.harvester.plot_manager.set_public_keys( harvester_handshake.farmer_public_keys, harvester_handshake.pool_public_keys ) - + self.harvester.plot_sync_sender.set_connection(peer) + await self.harvester.plot_sync_sender.start() self.harvester.plot_manager.start_refreshing() @peer_required @@ -289,3 +293,7 @@ class HarvesterAPI: response = harvester_protocol.RespondPlots(plots_response, failed_to_open_filenames, no_key_filenames) return make_msg(ProtocolMessageTypes.respond_plots, response) + + @api_request + async def plot_sync_response(self, response: PlotSyncResponse): + self.harvester.plot_sync_sender.set_response(response) diff --git a/chia/plot_sync/__init__.py b/chia/plot_sync/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/chia/plot_sync/delta.py b/chia/plot_sync/delta.py new file mode 100644 index 000000000000..6a797ffc33f1 --- /dev/null +++ b/chia/plot_sync/delta.py @@ -0,0 +1,59 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Union + +from chia.protocols.harvester_protocol import Plot + + +@dataclass +class DeltaType: + additions: Union[Dict[str, Plot], List[str]] + removals: List[str] + + def __str__(self) -> str: + return f"+{len(self.additions)}/-{len(self.removals)}" + + def clear(self) -> None: + self.additions.clear() + self.removals.clear() + + def empty(self) -> bool: + return len(self.additions) == 0 and len(self.removals) == 0 + + +@dataclass +class PlotListDelta(DeltaType): + additions: Dict[str, Plot] = field(default_factory=dict) + removals: List[str] = field(default_factory=list) + + +@dataclass +class PathListDelta(DeltaType): + additions: List[str] = field(default_factory=list) + removals: List[str] = field(default_factory=list) + + @staticmethod + def from_lists(old: List[str], new: List[str]) -> "PathListDelta": + return PathListDelta([x for x in new if x not in old], [x for x in old if x not in new]) + + +@dataclass +class Delta: + valid: 
PlotListDelta = field(default_factory=PlotListDelta) + invalid: PathListDelta = field(default_factory=PathListDelta) + keys_missing: PathListDelta = field(default_factory=PathListDelta) + duplicates: PathListDelta = field(default_factory=PathListDelta) + + def empty(self) -> bool: + return self.valid.empty() and self.invalid.empty() and self.keys_missing.empty() and self.duplicates.empty() + + def __str__(self) -> str: + return ( + f"valid {self.valid}, invalid {self.invalid}, keys missing: {self.keys_missing}, " + f"duplicates: {self.duplicates}" + ) + + def clear(self) -> None: + self.valid.clear() + self.invalid.clear() + self.keys_missing.clear() + self.duplicates.clear() diff --git a/chia/plot_sync/exceptions.py b/chia/plot_sync/exceptions.py new file mode 100644 index 000000000000..e972a2a2935d --- /dev/null +++ b/chia/plot_sync/exceptions.py @@ -0,0 +1,54 @@ +from typing import Any + +from chia.plot_sync.util import ErrorCodes, State +from chia.protocols.harvester_protocol import PlotSyncIdentifier +from chia.server.ws_connection import NodeType +from chia.util.ints import uint64 + + +class PlotSyncException(Exception): + def __init__(self, message: str, error_code: ErrorCodes) -> None: + super().__init__(message) + self.error_code = error_code + + +class AlreadyStartedError(Exception): + def __init__(self) -> None: + super().__init__("Already started!") + + +class InvalidValueError(PlotSyncException): + def __init__(self, message: str, actual: Any, expected: Any, error_code: ErrorCodes) -> None: + super().__init__(f"{message}: Actual {actual}, Expected {expected}", error_code) + + +class InvalidIdentifierError(InvalidValueError): + def __init__(self, actual_identifier: PlotSyncIdentifier, expected_identifier: PlotSyncIdentifier) -> None: + super().__init__("Invalid identifier", actual_identifier, expected_identifier, ErrorCodes.invalid_identifier) + self.actual_identifier: PlotSyncIdentifier = actual_identifier + self.expected_identifier: PlotSyncIdentifier 
= expected_identifier + + +class InvalidLastSyncIdError(InvalidValueError): + def __init__(self, actual: uint64, expected: uint64) -> None: + super().__init__("Invalid last-sync-id", actual, expected, ErrorCodes.invalid_last_sync_id) + + +class InvalidConnectionTypeError(InvalidValueError): + def __init__(self, actual: NodeType, expected: NodeType) -> None: + super().__init__("Unexpected connection type", actual, expected, ErrorCodes.invalid_connection_type) + + +class PlotAlreadyAvailableError(PlotSyncException): + def __init__(self, state: State, path: str) -> None: + super().__init__(f"{state.name}: Plot already available - {path}", ErrorCodes.plot_already_available) + + +class PlotNotAvailableError(PlotSyncException): + def __init__(self, state: State, path: str) -> None: + super().__init__(f"{state.name}: Plot not available - {path}", ErrorCodes.plot_not_available) + + +class SyncIdsMatchError(PlotSyncException): + def __init__(self, state: State, sync_id: uint64) -> None: + super().__init__(f"{state.name}: Sync ids are equal - {sync_id}", ErrorCodes.sync_ids_match) diff --git a/chia/plot_sync/receiver.py b/chia/plot_sync/receiver.py new file mode 100644 index 000000000000..4df791b53112 --- /dev/null +++ b/chia/plot_sync/receiver.py @@ -0,0 +1,304 @@ +import logging +import time +from typing import Any, Callable, Collection, Coroutine, Dict, List, Optional + +from chia.plot_sync.delta import Delta, PathListDelta, PlotListDelta +from chia.plot_sync.exceptions import ( + InvalidIdentifierError, + InvalidLastSyncIdError, + PlotAlreadyAvailableError, + PlotNotAvailableError, + PlotSyncException, + SyncIdsMatchError, +) +from chia.plot_sync.util import ErrorCodes, State +from chia.protocols.harvester_protocol import ( + Plot, + PlotSyncDone, + PlotSyncError, + PlotSyncIdentifier, + PlotSyncPathList, + PlotSyncPlotList, + PlotSyncResponse, + PlotSyncStart, +) +from chia.server.ws_connection import ProtocolMessageTypes, WSChiaConnection, make_msg +from 
chia.types.blockchain_format.sized_bytes import bytes32 +from chia.util.ints import int16, uint64 +from chia.util.streamable import _T_Streamable + +log = logging.getLogger(__name__) + + +class Receiver: + _connection: WSChiaConnection + _sync_state: State + _delta: Delta + _expected_sync_id: uint64 + _expected_message_id: uint64 + _last_sync_id: uint64 + _last_sync_time: float + _plots: Dict[str, Plot] + _invalid: List[str] + _keys_missing: List[str] + _duplicates: List[str] + _update_callback: Callable[[bytes32, Delta], Coroutine[Any, Any, None]] + + def __init__( + self, connection: WSChiaConnection, update_callback: Callable[[bytes32, Delta], Coroutine[Any, Any, None]] + ) -> None: + self._connection = connection + self._sync_state = State.idle + self._delta = Delta() + self._expected_sync_id = uint64(0) + self._expected_message_id = uint64(0) + self._last_sync_id = uint64(0) + self._last_sync_time = 0 + self._plots = {} + self._invalid = [] + self._keys_missing = [] + self._duplicates = [] + self._update_callback = update_callback # type: ignore[assignment, misc] + + def reset(self) -> None: + self._sync_state = State.idle + self._expected_sync_id = uint64(0) + self._expected_message_id = uint64(0) + self._last_sync_id = uint64(0) + self._last_sync_time = 0 + self._plots.clear() + self._invalid.clear() + self._keys_missing.clear() + self._duplicates.clear() + self._delta.clear() + + def bump_expected_message_id(self) -> None: + self._expected_message_id = uint64(self._expected_message_id + 1) + + def connection(self) -> WSChiaConnection: + return self._connection + + def state(self) -> State: + return self._sync_state + + def expected_sync_id(self) -> uint64: + return self._expected_sync_id + + def expected_message_id(self) -> uint64: + return self._expected_message_id + + def last_sync_id(self) -> uint64: + return self._last_sync_id + + def last_sync_time(self) -> float: + return self._last_sync_time + + def plots(self) -> Dict[str, Plot]: + return 
self._plots + + def invalid(self) -> List[str]: + return self._invalid + + def keys_missing(self) -> List[str]: + return self._keys_missing + + def duplicates(self) -> List[str]: + return self._duplicates + + async def _process( + self, method: Callable[[_T_Streamable], Any], message_type: ProtocolMessageTypes, message: Any + ) -> None: + async def send_response(plot_sync_error: Optional[PlotSyncError] = None) -> None: + if self._connection is not None: + await self._connection.send_message( + make_msg( + ProtocolMessageTypes.plot_sync_response, + PlotSyncResponse(message.identifier, int16(message_type.value), plot_sync_error), + ) + ) + + try: + await method(message) + await send_response() + except InvalidIdentifierError as e: + log.warning(f"_process: InvalidIdentifierError {e}") + await send_response(PlotSyncError(int16(e.error_code), f"{e}", e.expected_identifier)) + except PlotSyncException as e: + log.warning(f"_process: Error {e}") + await send_response(PlotSyncError(int16(e.error_code), f"{e}", None)) + except Exception as e: + log.warning(f"_process: Exception {e}") + await send_response(PlotSyncError(int16(ErrorCodes.unknown), f"{e}", None)) + + def _validate_identifier(self, identifier: PlotSyncIdentifier, start: bool = False) -> None: + sync_id_match = identifier.sync_id == self._expected_sync_id + message_id_match = identifier.message_id == self._expected_message_id + identifier_match = sync_id_match and message_id_match + if (start and not message_id_match) or (not start and not identifier_match): + expected: PlotSyncIdentifier = PlotSyncIdentifier( + identifier.timestamp, self._expected_sync_id, self._expected_message_id + ) + raise InvalidIdentifierError( + identifier, + expected, + ) + + async def _sync_started(self, data: PlotSyncStart) -> None: + if data.initial: + self.reset() + self._validate_identifier(data.identifier, True) + if data.last_sync_id != self.last_sync_id(): + raise InvalidLastSyncIdError(data.last_sync_id, self.last_sync_id()) + 
if data.last_sync_id == data.identifier.sync_id: + raise SyncIdsMatchError(State.idle, data.last_sync_id) + self._expected_sync_id = data.identifier.sync_id + self._delta.clear() + self._sync_state = State.loaded + self.bump_expected_message_id() + + async def sync_started(self, data: PlotSyncStart) -> None: + await self._process(self._sync_started, ProtocolMessageTypes.plot_sync_start, data) + + async def _process_loaded(self, plot_infos: PlotSyncPlotList) -> None: + self._validate_identifier(plot_infos.identifier) + + for plot_info in plot_infos.data: + if plot_info.filename in self._plots or plot_info.filename in self._delta.valid.additions: + raise PlotAlreadyAvailableError(State.loaded, plot_info.filename) + self._delta.valid.additions[plot_info.filename] = plot_info + + if plot_infos.final: + self._sync_state = State.removed + + self.bump_expected_message_id() + + async def process_loaded(self, plot_infos: PlotSyncPlotList) -> None: + await self._process(self._process_loaded, ProtocolMessageTypes.plot_sync_loaded, plot_infos) + + async def process_path_list( + self, + *, + state: State, + next_state: State, + target: Collection[str], + delta: List[str], + paths: PlotSyncPathList, + is_removal: bool = False, + ) -> None: + self._validate_identifier(paths.identifier) + + for path in paths.data: + if is_removal and (path not in target or path in delta): + raise PlotNotAvailableError(state, path) + if not is_removal and path in delta: + raise PlotAlreadyAvailableError(state, path) + delta.append(path) + + if paths.final: + self._sync_state = next_state + + self.bump_expected_message_id() + + async def _process_removed(self, paths: PlotSyncPathList) -> None: + await self.process_path_list( + state=State.removed, + next_state=State.invalid, + target=self._plots, + delta=self._delta.valid.removals, + paths=paths, + is_removal=True, + ) + + async def process_removed(self, paths: PlotSyncPathList) -> None: + await self._process(self._process_removed, 
ProtocolMessageTypes.plot_sync_removed, paths) + + async def _process_invalid(self, paths: PlotSyncPathList) -> None: + await self.process_path_list( + state=State.invalid, + next_state=State.keys_missing, + target=self._invalid, + delta=self._delta.invalid.additions, + paths=paths, + ) + + async def process_invalid(self, paths: PlotSyncPathList) -> None: + await self._process(self._process_invalid, ProtocolMessageTypes.plot_sync_invalid, paths) + + async def _process_keys_missing(self, paths: PlotSyncPathList) -> None: + await self.process_path_list( + state=State.keys_missing, + next_state=State.duplicates, + target=self._keys_missing, + delta=self._delta.keys_missing.additions, + paths=paths, + ) + + async def process_keys_missing(self, paths: PlotSyncPathList) -> None: + await self._process(self._process_keys_missing, ProtocolMessageTypes.plot_sync_keys_missing, paths) + + async def _process_duplicates(self, paths: PlotSyncPathList) -> None: + await self.process_path_list( + state=State.duplicates, + next_state=State.done, + target=self._duplicates, + delta=self._delta.duplicates.additions, + paths=paths, + ) + + async def process_duplicates(self, paths: PlotSyncPathList) -> None: + await self._process(self._process_duplicates, ProtocolMessageTypes.plot_sync_duplicates, paths) + + async def _sync_done(self, data: PlotSyncDone) -> None: + self._validate_identifier(data.identifier) + # Update ids + self._last_sync_id = self._expected_sync_id + self._expected_sync_id = uint64(0) + self._expected_message_id = uint64(0) + # First create the update delta (i.e. 
transform invalid/keys_missing into additions/removals) which we will + # send to the callback receiver below + delta_invalid: PathListDelta = PathListDelta.from_lists(self._invalid, self._delta.invalid.additions) + delta_keys_missing: PathListDelta = PathListDelta.from_lists( + self._keys_missing, self._delta.keys_missing.additions + ) + delta_duplicates: PathListDelta = PathListDelta.from_lists(self._duplicates, self._delta.duplicates.additions) + update = Delta( + PlotListDelta(self._delta.valid.additions.copy(), self._delta.valid.removals.copy()), + delta_invalid, + delta_keys_missing, + delta_duplicates, + ) + # Apply delta + self._plots.update(self._delta.valid.additions) + for removal in self._delta.valid.removals: + del self._plots[removal] + self._invalid = self._delta.invalid.additions.copy() + self._keys_missing = self._delta.keys_missing.additions.copy() + self._duplicates = self._delta.duplicates.additions.copy() + # Update state and bump last sync time + self._sync_state = State.idle + self._last_sync_time = time.time() + # Let the callback receiver know if this sync cycle caused any update + try: + await self._update_callback(self._connection.peer_node_id, update) # type: ignore[misc,call-arg] + except Exception as e: + log.error(f"_update_callback raised: {e}") + self._delta.clear() + + async def sync_done(self, data: PlotSyncDone) -> None: + await self._process(self._sync_done, ProtocolMessageTypes.plot_sync_done, data) + + def to_dict(self) -> Dict[str, Any]: + result: Dict[str, Any] = { + "connection": { + "node_id": self._connection.peer_node_id, + "host": self._connection.peer_host, + "port": self._connection.peer_port, + }, + "plots": list(self._plots.values()), + "failed_to_open_filenames": self._invalid, + "no_key_filenames": self._keys_missing, + "duplicates": self._duplicates, + } + if self._last_sync_time != 0: + result["last_sync_time"] = self._last_sync_time + return result diff --git a/chia/plot_sync/sender.py 
def _convert_plot_info_list(plot_infos: List[PlotInfo]) -> List[Plot]:
    """Convert harvester-side PlotInfo entries into wire-format Plot messages."""
    return [
        Plot(
            filename=info.prover.get_filename(),
            size=info.prover.get_size(),
            plot_id=info.prover.get_id(),
            pool_public_key=info.pool_public_key,
            pool_contract_puzzle_hash=info.pool_contract_puzzle_hash,
            plot_public_key=info.plot_public_key,
            file_size=uint64(info.file_size),
            time_modified=uint64(int(info.time_modified)),
        )
        for info in plot_infos
    ]
@dataclass
class MessageGenerator(Generic[T]):
    """Deferred construction of one sync message: the identifier's timestamp is
    taken when `generate` runs, not when the message was queued."""

    sync_id: uint64
    message_type: ProtocolMessageTypes
    message_id: uint64
    payload_type: Type[T]
    args: Iterable[object]

    def generate(self) -> Tuple[PlotSyncIdentifier, T]:
        message_identifier = PlotSyncIdentifier(uint64(int(time.time())), self.sync_id, self.message_id)
        return message_identifier, self.payload_type(message_identifier, *self.args)
    def _reset(self) -> None:
        # Drop all local sync state and, if the sender task is running, queue a
        # complete fresh sync (start / all batches / done) of the current plot
        # manager state so the farmer gets fully re-synced after the failure.
        log.debug(f"_reset {self}")
        self._last_sync_id = uint64(0)
        self._sync_id = uint64(0)
        self._next_message_id = uint64(0)
        self._messages.clear()
        # Release the lock held by an in-flight sync cycle so a new one can start
        if self._lock.locked():
            self._lock.release()
        if self._task is not None:
            # TODO, Add typing in PlotManager
            self.sync_start(self._plot_manager.plot_count(), True)  # type:ignore[no-untyped-call]
            for remaining, batch in list_to_batches(
                list(self._plot_manager.plots.values()), self._plot_manager.refresh_parameter.batch_size
            ):
                self.process_batch(batch, remaining)
            self.sync_done([], 0)
unexpected sync-id: " f"{response.identifier.sync_id}/{self._response.identifier.sync_id}" + ) + return False + if response.identifier.message_id != self._response.identifier.message_id: + log.warning( + "set_response unexpected message-id: " + f"{response.identifier.message_id}/{self._response.identifier.message_id}" + ) + return False + if response.message_type != int16(self._response.message_type.value): + log.warning( + "set_response unexpected message-type: " f"{response.message_type}/{self._response.message_type.value}" + ) + return False + log.debug(f"set_response valid {response}") + self._response.message = response + return True + + def _add_message(self, message_type: ProtocolMessageTypes, payload_type: Any, *args: Any) -> None: + assert self._sync_id != 0 + message_id = uint64(len(self._messages)) + self._messages.append(MessageGenerator(self._sync_id, message_type, message_id, payload_type, args)) + + async def _send_next_message(self) -> bool: + def failed(message: str) -> bool: + # By forcing a reset we try to get back into a normal state if some not recoverable failure came up. 
+ log.warning(message) + self._reset() + return False + + assert len(self._messages) >= self._next_message_id + message_generator = self._messages[self._next_message_id] + identifier, payload = message_generator.generate() + if self._sync_id == 0 or identifier.sync_id != self._sync_id or identifier.message_id != self._next_message_id: + return failed(f"Invalid message generator {message_generator} for {self}") + + self._response = ExpectedResponse(message_generator.message_type, identifier) + log.debug(f"_send_next_message send {message_generator.message_type.name}: {payload}") + if self._connection is None or not await self._connection.send_message( + make_msg(message_generator.message_type, payload) + ): + return failed(f"Send failed {self._connection}") + if not await self._wait_for_response(): + log.info(f"_send_next_message didn't receive response {self._response}") + return False + + assert self._response.message is not None + if self._response.message.error is not None: + recovered = False + expected = self._response.message.error.expected_identifier + # If we have a recoverable error there is a `expected_identifier` included + if expected is not None: + # If the receiver has a zero sync/message id and we already sent all messages from the current event + # we most likely missed the response to the done message. We can finalize the sync and move on here. 
+ all_sent = ( + self._messages[-1].message_type == ProtocolMessageTypes.plot_sync_done + and self._next_message_id == len(self._messages) - 1 + ) + if expected.sync_id == expected.message_id == 0 and all_sent: + self._finalize_sync() + recovered = True + elif self._sync_id == expected.sync_id and expected.message_id < len(self._messages): + self._next_message_id = expected.message_id + recovered = True + if not recovered: + return failed(f"Not recoverable error {self._response.message}") + return True + + if self._response.message_type == ProtocolMessageTypes.plot_sync_done: + self._finalize_sync() + else: + self.bump_next_message_id() + + return True + + def _add_list_batched(self, message_type: ProtocolMessageTypes, payload_type: Any, data: List[Any]) -> None: + if len(data) == 0: + self._add_message(message_type, payload_type, [], True) + return + for remaining, batch in list_to_batches(data, self._plot_manager.refresh_parameter.batch_size): + self._add_message(message_type, payload_type, batch, remaining == 0) + + def sync_start(self, count: float, initial: bool) -> None: + log.debug(f"sync_start {self}: count {count}, initial {initial}") + self._lock.acquire() + sync_id = int(time.time()) + # Make sure we have unique sync-id's even if we restart refreshing within a second (i.e. 
in tests) + if sync_id == self._last_sync_id: + sync_id = sync_id + 1 + log.debug(f"sync_start {sync_id}") + self._sync_id = uint64(sync_id) + self._add_message( + ProtocolMessageTypes.plot_sync_start, PlotSyncStart, initial, self._last_sync_id, uint32(int(count)) + ) + + def process_batch(self, loaded: List[PlotInfo], remaining: int) -> None: + log.debug(f"process_batch {self}: loaded {len(loaded)}, remaining {remaining}") + if len(loaded) > 0 or remaining == 0: + converted = _convert_plot_info_list(loaded) + self._add_message(ProtocolMessageTypes.plot_sync_loaded, PlotSyncPlotList, converted, remaining == 0) + + def sync_done(self, removed: List[Path], duration: float) -> None: + log.debug(f"sync_done {self}: removed {len(removed)}, duration {duration}") + removed_list = [str(x) for x in removed] + self._add_list_batched( + ProtocolMessageTypes.plot_sync_removed, + PlotSyncPathList, + removed_list, + ) + failed_to_open_list = [str(x) for x in list(self._plot_manager.failed_to_open_filenames)] + self._add_list_batched(ProtocolMessageTypes.plot_sync_invalid, PlotSyncPathList, failed_to_open_list) + no_key_list = [str(x) for x in self._plot_manager.no_key_filenames] + self._add_list_batched(ProtocolMessageTypes.plot_sync_keys_missing, PlotSyncPathList, no_key_list) + # TODO, Add typing in PlotManager + duplicates_list: List[str] = self._plot_manager.get_duplicates().copy() # type:ignore[no-untyped-call] + self._add_list_batched(ProtocolMessageTypes.plot_sync_duplicates, PlotSyncPathList, duplicates_list) + self._add_message(ProtocolMessageTypes.plot_sync_done, PlotSyncDone, uint64(int(duration))) + + def _finalize_sync(self) -> None: + log.debug(f"_finalize_sync {self}") + assert self._sync_id != 0 + self._last_sync_id = self._sync_id + self._sync_id = uint64(0) + self._next_message_id = uint64(0) + self._messages.clear() + self._lock.release() + + def sync_active(self) -> bool: + return self._lock.locked() and self._sync_id != 0 + + def connected(self) -> bool: + 
    async def _run(self) -> None:
        """
        This is the sender task responsible to send new messages during sync as they come into Sender._messages
        triggered by the plot manager callback.
        """
        while not self._stop_requested:
            try:
                # Phase 1: wait until we are connected to a farmer and a sync cycle is active
                while not self.connected() or not self.sync_active():
                    if self._stop_requested:
                        return
                    await asyncio.sleep(0.1)
                # Phase 2: drain the message queue of the active sync cycle
                while not self._stop_requested and self.sync_active():
                    if self._next_message_id >= len(self._messages):
                        # No new message queued yet, poll again shortly
                        await asyncio.sleep(0.1)
                        continue
                    if not await self._send_next_message():
                        # Sending failed, back off before retrying
                        await asyncio.sleep(Constants.message_timeout)
            except Exception as e:
                log.error(f"Exception: {e} {traceback.format_exc()}")
                self._reset()
    def initial_refresh(self) -> bool:
        """Return True while no plot refresh has completed yet since creation or the last reset()."""
        return self._initial
@dataclass(frozen=True)
@streamable
class PlotSyncError(Streamable):
    """Error payload attached to a PlotSyncResponse; `expected_identifier` is set
    for recoverable errors so the sender can rewind to the expected message."""

    code: int16
    message: str
    expected_identifier: Optional[PlotSyncIdentifier]

    def __str__(self) -> str:
        # Fixed mislabeled field: this used to print "count {self.message}",
        # copy-pasted from the list classes' format strings.
        return (
            f"PlotSyncError: code {self.code}, message {self.message}, "
            f"expected_identifier {self.expected_identifier}"
        )
ProtocolMessageTypes(Enum): new_signage_point_harvester = 66 request_plots = 67 respond_plots = 68 + plot_sync_start = 78 + plot_sync_loaded = 79 + plot_sync_removed = 80 + plot_sync_invalid = 81 + plot_sync_keys_missing = 82 + plot_sync_duplicates = 83 + plot_sync_done = 84 + plot_sync_response = 85 # More wallet protocol coin_state_update = 69 diff --git a/chia/protocols/shared_protocol.py b/chia/protocols/shared_protocol.py index d7c0e6cd94a8..ed1cc9e7d680 100644 --- a/chia/protocols/shared_protocol.py +++ b/chia/protocols/shared_protocol.py @@ -5,7 +5,7 @@ from typing import List, Tuple from chia.util.ints import uint8, uint16 from chia.util.streamable import Streamable, streamable -protocol_version = "0.0.33" +protocol_version = "0.0.34" """ Handshake when establishing a connection between two servers. diff --git a/chia/server/rate_limits.py b/chia/server/rate_limits.py index 78f4d69340ff..b70c04b0f8f0 100644 --- a/chia/server/rate_limits.py +++ b/chia/server/rate_limits.py @@ -97,6 +97,14 @@ rate_limits_other = { ProtocolMessageTypes.farm_new_block: RLSettings(200, 200), ProtocolMessageTypes.request_plots: RLSettings(10, 10 * 1024 * 1024), ProtocolMessageTypes.respond_plots: RLSettings(10, 100 * 1024 * 1024), + ProtocolMessageTypes.plot_sync_start: RLSettings(1000, 100 * 1024 * 1024), + ProtocolMessageTypes.plot_sync_loaded: RLSettings(1000, 100 * 1024 * 1024), + ProtocolMessageTypes.plot_sync_removed: RLSettings(1000, 100 * 1024 * 1024), + ProtocolMessageTypes.plot_sync_invalid: RLSettings(1000, 100 * 1024 * 1024), + ProtocolMessageTypes.plot_sync_keys_missing: RLSettings(1000, 100 * 1024 * 1024), + ProtocolMessageTypes.plot_sync_duplicates: RLSettings(1000, 100 * 1024 * 1024), + ProtocolMessageTypes.plot_sync_done: RLSettings(1000, 100 * 1024 * 1024), + ProtocolMessageTypes.plot_sync_response: RLSettings(3000, 100 * 1024 * 1024), ProtocolMessageTypes.coin_state_update: RLSettings(1000, 100 * 1024 * 1024), 
@pytest_asyncio.fixture(scope="function")
async def farmer_two_harvester(tmp_path: Path, bt: BlockTools) -> AsyncIterator[Tuple[List[Service], Service]]:
    # One farmer with two attached harvesters; yields (harvester_services, farmer_service).
    async for _ in setup_farmer_multi_harvester(bt, 2, tmp_path, test_constants):
        yield _
AsyncIterator[Tuple[List[Service], Service]]: + async for _ in setup_farmer_multi_harvester(bt, 3, tmp_path, test_constants): + yield _ + + # TODO: Ideally, the db_version should be the (parameterized) db_version # fixture, to test all versions of the database schema. This doesn't work # because of a hack in shutting down the full node, which means you cannot run diff --git a/tests/core/test_farmer_harvester_rpc.py b/tests/core/test_farmer_harvester_rpc.py index 72a16ba9e44d..e312698fe007 100644 --- a/tests/core/test_farmer_harvester_rpc.py +++ b/tests/core/test_farmer_harvester_rpc.py @@ -112,7 +112,6 @@ async def test_farmer_get_harvesters(harvester_farmer_environment): harvester_rpc_api, harvester_rpc_client, ) = harvester_farmer_environment - farmer_api = farmer_service._api harvester = harvester_service._node num_plots = 0 @@ -125,11 +124,6 @@ async def test_farmer_get_harvesters(harvester_farmer_environment): await time_out_assert(10, non_zero_plots) - # Reset cache and force updates cache every second to make sure the farmer gets the most recent data - update_interval_before = farmer_api.farmer.update_harvester_cache_interval - farmer_api.farmer.update_harvester_cache_interval = 1 - farmer_api.farmer.harvester_cache = {} - async def test_get_harvesters(): harvester.plot_manager.trigger_refresh() await time_out_assert(5, harvester.plot_manager.needs_refresh, value=False) @@ -144,10 +138,6 @@ async def test_farmer_get_harvesters(harvester_farmer_environment): await time_out_assert_custom_interval(30, 1, test_get_harvesters) - # Reset cache and reset update interval to avoid hitting the rate limit - farmer_api.farmer.update_harvester_cache_interval = update_interval_before - farmer_api.farmer.harvester_cache = {} - @pytest.mark.asyncio async def test_farmer_signage_point_endpoints(harvester_farmer_environment): diff --git a/tests/plot_sync/__init__.py b/tests/plot_sync/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git 
def test_list_delta(delta: DeltaType) -> None:
    """Verify empty()/clear() semantics of both path- and plot-list deltas:
    any entry in additions or removals makes the delta non-empty until clear()."""
    assert delta.empty()
    # Identity check instead of `==` comparison; both fall-through branches now
    # carry the same failure message for consistency.
    if type(delta) is PathListDelta:
        assert delta.additions == []
    elif type(delta) is PlotListDelta:
        assert delta.additions == {}
    else:
        assert False, "Invalid delta type"
    assert delta.removals == []
    assert delta.empty()
    if type(delta) is PathListDelta:
        delta.additions.append("0")
    elif type(delta) is PlotListDelta:
        delta.additions["0"] = dummy_plot("0")
    else:
        assert False, "Invalid delta type"
    assert not delta.empty()
    delta.removals.append("0")
    assert not delta.empty()
    delta.additions.clear()
    # Still non-empty: the removals entry keeps the delta populated
    assert not delta.empty()
    delta.clear()
    assert delta.empty()
def test_delta_empty() -> None:
    # Delta.empty() must only be True once all four sub-deltas are empty.
    delta: Delta = Delta()
    all_deltas: List[DeltaType] = [delta.valid, delta.invalid, delta.keys_missing, delta.duplicates]
    assert delta.empty()
    for d1 in all_deltas:
        # Populate every sub-delta, then clear all but `d1` and make sure the
        # single remaining entry keeps the whole delta non-empty.
        delta.valid.additions["0"] = dummy_plot("0")
        delta.invalid.additions.append("0")
        delta.keys_missing.additions.append("0")
        delta.duplicates.additions.append("0")
        assert not delta.empty()
        for d2 in all_deltas:
            if d2 is not d1:
                d2.clear()
            assert not delta.empty()
        assert not delta.empty()
        d1.clear()
        assert delta.empty()
def assert_path_list_matches(expected_list: List[str], actual_list: List[str]) -> None:
    """Assert both lists have equal length and every expected entry occurs in the actual list."""
    assert len(actual_list) == len(expected_list)
    for expected_entry in expected_list:
        assert str(expected_entry) in actual_list
self.valid_delta.additions.update({x.prover.get_filename(): create_mock_plot(x) for x in list_plots}) + + def remove_valid(self, list_paths: List[Path]) -> None: + self.valid_count -= len(list_paths) + self.valid_delta.removals += [str(x) for x in list_paths] + + def add_invalid(self, list_paths: List[Path]) -> None: + self.invalid_count += len(list_paths) + self.invalid_delta.additions += [str(x) for x in list_paths] + + def remove_invalid(self, list_paths: List[Path]) -> None: + self.invalid_count -= len(list_paths) + self.invalid_delta.removals += [str(x) for x in list_paths] + + def add_keys_missing(self, list_paths: List[Path]) -> None: + self.keys_missing_count += len(list_paths) + self.keys_missing_delta.additions += [str(x) for x in list_paths] + + def remove_keys_missing(self, list_paths: List[Path]) -> None: + self.keys_missing_count -= len(list_paths) + self.keys_missing_delta.removals += [str(x) for x in list_paths] + + def add_duplicates(self, list_paths: List[Path]) -> None: + self.duplicates_count += len(list_paths) + self.duplicates_delta.additions += [str(x) for x in list_paths] + + def remove_duplicates(self, list_paths: List[Path]) -> None: + self.duplicates_count -= len(list_paths) + self.duplicates_delta.removals += [str(x) for x in list_paths] + + +@dataclass +class Environment: + root_path: Path + harvester_services: List[Service] + farmer_service: Service + harvesters: List[Harvester] + farmer: Farmer + dir_1: TestDirectory + dir_2: TestDirectory + dir_3: TestDirectory + dir_4: TestDirectory + dir_invalid: TestDirectory + dir_keys_missing: TestDirectory + dir_duplicates: TestDirectory + expected: List[ExpectedResult] + + def get_harvester(self, peer_id: bytes32) -> Optional[Harvester]: + for harvester in self.harvesters: + assert harvester.server is not None + if harvester.server.node_id == peer_id: + return harvester + return None + + def add_directory(self, harvester_index: int, directory: TestDirectory, state: State = State.loaded) -> 
+        # Note: This does not add dir_duplicates since it's important that the duplicated plots are loaded after
+        # the original ones.
+ # self.add_directory(harvester_index, self.dir_duplicates, State.duplicates) + + def remove_all_directories(self, harvester_index: int) -> None: + self.remove_directory(harvester_index, self.dir_1) + self.remove_directory(harvester_index, self.dir_2) + self.remove_directory(harvester_index, self.dir_3) + self.remove_directory(harvester_index, self.dir_4) + self.remove_directory(harvester_index, self.dir_keys_missing, State.keys_missing) + self.remove_directory(harvester_index, self.dir_invalid, State.invalid) + self.remove_directory(harvester_index, self.dir_duplicates, State.duplicates) + + async def plot_sync_callback(self, peer_id: bytes32, delta: Delta) -> None: + harvester: Optional[Harvester] = self.get_harvester(peer_id) + assert harvester is not None + expected = self.expected[self.harvesters.index(harvester)] + assert len(expected.valid_delta.additions) == len(delta.valid.additions) + for path, plot_info in expected.valid_delta.additions.items(): + assert path in delta.valid.additions + plot = harvester.plot_manager.plots.get(Path(path), None) + assert plot is not None + assert plot.prover.get_filename() == delta.valid.additions[path].filename + assert plot.prover.get_size() == delta.valid.additions[path].size + assert plot.prover.get_id() == delta.valid.additions[path].plot_id + assert plot.pool_public_key == delta.valid.additions[path].pool_public_key + assert plot.pool_contract_puzzle_hash == delta.valid.additions[path].pool_contract_puzzle_hash + assert plot.plot_public_key == delta.valid.additions[path].plot_public_key + assert plot.file_size == delta.valid.additions[path].file_size + assert int(plot.time_modified) == delta.valid.additions[path].time_modified + + assert_path_list_matches(expected.valid_delta.removals, delta.valid.removals) + assert_path_list_matches(expected.invalid_delta.additions, delta.invalid.additions) + assert_path_list_matches(expected.invalid_delta.removals, delta.invalid.removals) + 
assert_path_list_matches(expected.keys_missing_delta.additions, delta.keys_missing.additions) + assert_path_list_matches(expected.keys_missing_delta.removals, delta.keys_missing.removals) + assert_path_list_matches(expected.duplicates_delta.additions, delta.duplicates.additions) + assert_path_list_matches(expected.duplicates_delta.removals, delta.duplicates.removals) + expected.valid_delta.clear() + expected.invalid_delta.clear() + expected.keys_missing_delta.clear() + expected.duplicates_delta.clear() + expected.callback_passed = True + + async def run_sync_test(self) -> None: + plot_manager: PlotManager + assert len(self.harvesters) == len(self.expected) + last_sync_ids: List[uint64] = [] + # Run the test in two steps, first trigger the refresh on both harvesters + for harvester in self.harvesters: + plot_manager = harvester.plot_manager + assert harvester.server is not None + receiver = self.farmer.plot_sync_receivers[harvester.server.node_id] + # Make sure to reset the passed flag always before a new run + self.expected[self.harvesters.index(harvester)].callback_passed = False + receiver._update_callback = self.plot_sync_callback + assert harvester.plot_sync_sender._last_sync_id == receiver._last_sync_id + last_sync_ids.append(harvester.plot_sync_sender._last_sync_id) + plot_manager.start_refreshing() + plot_manager.trigger_refresh() + # Then wait for them to be synced with the farmer and validate them + for harvester in self.harvesters: + plot_manager = harvester.plot_manager + assert harvester.server is not None + receiver = self.farmer.plot_sync_receivers[harvester.server.node_id] + await time_out_assert(10, plot_manager.needs_refresh, value=False) + harvester_index = self.harvesters.index(harvester) + await time_out_assert( + 10, synced, True, harvester.plot_sync_sender, receiver, last_sync_ids[harvester_index] + ) + expected = self.expected[harvester_index] + assert plot_manager.plot_count() == len(receiver.plots()) == expected.valid_count + assert 
len(plot_manager.failed_to_open_filenames) == len(receiver.invalid()) == expected.invalid_count + assert len(plot_manager.no_key_filenames) == len(receiver.keys_missing()) == expected.keys_missing_count + assert len(plot_manager.get_duplicates()) == len(receiver.duplicates()) == expected.duplicates_count + assert expected.callback_passed + assert expected.valid_delta.empty() + assert expected.invalid_delta.empty() + assert expected.keys_missing_delta.empty() + assert expected.duplicates_delta.empty() + for path, plot_info in plot_manager.plots.items(): + assert str(path) in receiver.plots() + assert plot_info.prover.get_filename() == receiver.plots()[str(path)].filename + assert plot_info.prover.get_size() == receiver.plots()[str(path)].size + assert plot_info.prover.get_id() == receiver.plots()[str(path)].plot_id + assert plot_info.pool_public_key == receiver.plots()[str(path)].pool_public_key + assert plot_info.pool_contract_puzzle_hash == receiver.plots()[str(path)].pool_contract_puzzle_hash + assert plot_info.plot_public_key == receiver.plots()[str(path)].plot_public_key + assert plot_info.file_size == receiver.plots()[str(path)].file_size + assert int(plot_info.time_modified) == receiver.plots()[str(path)].time_modified + for path in plot_manager.failed_to_open_filenames: + assert str(path) in receiver.invalid() + for path in plot_manager.no_key_filenames: + assert str(path) in receiver.keys_missing() + for path in plot_manager.get_duplicates(): + assert str(path) in receiver.duplicates() + + async def handshake_done(self, index: int) -> bool: + return ( + self.harvesters[index].plot_manager._refresh_thread is not None + and len(self.harvesters[index].plot_manager.farmer_public_keys) > 0 + ) + + +@pytest_asyncio.fixture(scope="function") +async def environment( + bt: BlockTools, tmp_path: Path, farmer_two_harvester: Tuple[List[Service], Service] +) -> Environment: + def new_test_dir(name: str, plot_list: List[Path]) -> TestDirectory: + return 
TestDirectory(tmp_path / "plots" / name, plot_list) + + plots: List[Path] = get_test_plots() + plots_invalid: List[Path] = get_test_plots()[0:3] + plots_keys_missing: List[Path] = get_test_plots("not_in_keychain") + # Create 4 directories where: dir_n contains n plots + directories: List[TestDirectory] = [] + offset: int = 0 + while len(directories) < 4: + dir_number = len(directories) + 1 + directories.append(new_test_dir(f"{dir_number}", plots[offset : offset + dir_number])) + offset += dir_number + + dir_invalid: TestDirectory = new_test_dir("invalid", plots_invalid) + dir_keys_missing: TestDirectory = new_test_dir("keys_missing", plots_keys_missing) + dir_duplicates: TestDirectory = new_test_dir("duplicates", directories[3].plots) + create_default_chia_config(tmp_path) + + # Invalidate the plots in `dir_invalid` + for path in dir_invalid.path_list(): + with open(path, "wb") as file: + file.write(bytes(100)) + + harvester_services: List[Service] + farmer_service: Service + harvester_services, farmer_service = farmer_two_harvester + farmer: Farmer = farmer_service._node + harvesters: List[Harvester] = [await start_harvester_service(service) for service in harvester_services] + for harvester in harvesters: + harvester.plot_manager.set_public_keys( + bt.plot_manager.farmer_public_keys.copy(), bt.plot_manager.pool_public_keys.copy() + ) + + assert len(farmer.plot_sync_receivers) == 2 + + return Environment( + tmp_path, + harvester_services, + farmer_service, + harvesters, + farmer, + directories[0], + directories[1], + directories[2], + directories[3], + dir_invalid, + dir_keys_missing, + dir_duplicates, + [ExpectedResult() for _ in harvesters], + ) + + +@pytest.mark.asyncio +async def test_sync_valid(environment: Environment) -> None: + env: Environment = environment + env.add_directory(0, env.dir_1) + env.add_directory(1, env.dir_2) + await env.run_sync_test() + # Run again two times to make sure we still get the same results in repeated refresh intervals + 
+    # Add the directory to the second harvester too
+    # Drop all but 2 plots with missing keys and test sync in between
+    # Drop all but 2 duplicates and test sync in between
await env.run_sync_test() + + +@pytest.mark.asyncio +async def test_add_and_remove_all_directories(environment: Environment) -> None: + await add_and_validate_all_directories(environment) + await remove_and_validate_all_directories(environment) + + +@pytest.mark.asyncio +async def test_harvester_restart(environment: Environment) -> None: + env: Environment = environment + # Load all directories for both harvesters + await add_and_validate_all_directories(env) + # Stop the harvester and make sure the receiver gets dropped on the farmer and refreshing gets stopped + env.harvester_services[0].stop() + await env.harvester_services[0].wait_closed() + assert len(env.farmer.plot_sync_receivers) == 1 + assert not env.harvesters[0].plot_manager._refreshing_enabled + assert not env.harvesters[0].plot_manager.needs_refresh() + # Start the harvester, wait for the handshake and make sure the receiver comes back + await env.harvester_services[0].start() + await time_out_assert(5, env.handshake_done, True, 0) + assert len(env.farmer.plot_sync_receivers) == 2 + # Remove the duplicates dir to avoid conflicts with the original plots + env.remove_directory(0, env.dir_duplicates) + # Reset the expected data for harvester 0 and re-add all directories because of the restart + env.expected[0] = ExpectedResult() + env.add_all_directories(0) + # Run the refresh two times and make sure everything recovers and stays recovered after harvester restart + await env.run_sync_test() + env.add_directory(0, env.dir_duplicates, State.duplicates) + await env.run_sync_test() + + +@pytest.mark.asyncio +async def test_farmer_restart(environment: Environment) -> None: + env: Environment = environment + # Load all directories for both harvesters + await add_and_validate_all_directories(env) + last_sync_ids: List[uint64] = [] + for i in range(0, len(env.harvesters)): + last_sync_ids.append(env.harvesters[i].plot_sync_sender._last_sync_id) + # Stop the farmer and make sure both receivers get dropped and 
refreshing gets stopped on the harvesters + env.farmer_service.stop() + await env.farmer_service.wait_closed() + assert len(env.farmer.plot_sync_receivers) == 0 + assert not env.harvesters[0].plot_manager._refreshing_enabled + assert not env.harvesters[1].plot_manager._refreshing_enabled + # Start the farmer, wait for the handshake and make sure the receivers come back + await env.farmer_service.start() + await time_out_assert(5, env.handshake_done, True, 0) + await time_out_assert(5, env.handshake_done, True, 1) + assert len(env.farmer.plot_sync_receivers) == 2 + # Do not use run_sync_test here, to have a more realistic test scenario just wait for the harvesters to be synced. + # The handshake should trigger re-sync. + for i in range(0, len(env.harvesters)): + harvester: Harvester = env.harvesters[i] + assert harvester.server is not None + receiver = env.farmer.plot_sync_receivers[harvester.server.node_id] + await time_out_assert(10, synced, True, harvester.plot_sync_sender, receiver, last_sync_ids[i]) + # Validate the sync + for harvester in env.harvesters: + plot_manager: PlotManager = harvester.plot_manager + assert harvester.server is not None + receiver = env.farmer.plot_sync_receivers[harvester.server.node_id] + expected = env.expected[env.harvesters.index(harvester)] + assert plot_manager.plot_count() == len(receiver.plots()) == expected.valid_count + assert len(plot_manager.failed_to_open_filenames) == len(receiver.invalid()) == expected.invalid_count + assert len(plot_manager.no_key_filenames) == len(receiver.keys_missing()) == expected.keys_missing_count + assert len(plot_manager.get_duplicates()) == len(receiver.duplicates()) == expected.duplicates_count diff --git a/tests/plot_sync/test_receiver.py b/tests/plot_sync/test_receiver.py new file mode 100644 index 000000000000..5c63ae412640 --- /dev/null +++ b/tests/plot_sync/test_receiver.py @@ -0,0 +1,376 @@ +import logging +import time +from secrets import token_bytes +from typing import Any, Callable, 
List, Tuple, Type, Union + +import pytest +from blspy import G1Element + +from chia.plot_sync.delta import Delta +from chia.plot_sync.receiver import Receiver +from chia.plot_sync.util import ErrorCodes, State +from chia.protocols.harvester_protocol import ( + Plot, + PlotSyncDone, + PlotSyncIdentifier, + PlotSyncPathList, + PlotSyncPlotList, + PlotSyncResponse, + PlotSyncStart, +) +from chia.server.ws_connection import NodeType +from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.util.ints import uint8, uint32, uint64 +from chia.util.streamable import _T_Streamable +from tests.plot_sync.util import get_dummy_connection + +log = logging.getLogger(__name__) + +next_message_id = uint64(0) + + +def assert_default_values(receiver: Receiver) -> None: + assert receiver.state() == State.idle + assert receiver.expected_sync_id() == 0 + assert receiver.expected_message_id() == 0 + assert receiver.last_sync_id() == 0 + assert receiver.last_sync_time() == 0 + assert receiver.plots() == {} + assert receiver.invalid() == [] + assert receiver.keys_missing() == [] + assert receiver.duplicates() == [] + + +async def dummy_callback(_: bytes32, __: Delta) -> None: + pass + + +class SyncStepData: + state: State + function: Any + payload_type: Any + args: Any + + def __init__( + self, state: State, function: Callable[[_T_Streamable], Any], payload_type: Type[_T_Streamable], *args: Any + ) -> None: + self.state = state + self.function = function + self.payload_type = payload_type + self.args = args + + +def plot_sync_identifier(current_sync_id: uint64, message_id: uint64) -> PlotSyncIdentifier: + return PlotSyncIdentifier(uint64(0), current_sync_id, message_id) + + +def create_payload(payload_type: Any, start: bool, *args: Any) -> Any: + global next_message_id + if start: + next_message_id = uint64(0) + next_identifier = plot_sync_identifier(uint64(1), next_message_id) + next_message_id = uint64(next_message_id + 1) + return payload_type(next_identifier, *args) + + 
+def assert_error_response(plot_sync: Receiver, error_code: ErrorCodes) -> None: + connection = plot_sync.connection() + assert connection is not None + message = connection.last_sent_message + assert message is not None + response: PlotSyncResponse = PlotSyncResponse.from_bytes(message.data) + assert response.error is not None + assert response.error.code == error_code.value + + +def pre_function_validate(receiver: Receiver, data: Union[List[Plot], List[str]], expected_state: State) -> None: + if expected_state == State.loaded: + for plot_info in data: + assert type(plot_info) == Plot + assert plot_info.filename not in receiver.plots() + elif expected_state == State.removed: + for path in data: + assert path in receiver.plots() + elif expected_state == State.invalid: + for path in data: + assert path not in receiver.invalid() + elif expected_state == State.keys_missing: + for path in data: + assert path not in receiver.keys_missing() + elif expected_state == State.duplicates: + for path in data: + assert path not in receiver.duplicates() + + +def post_function_validate(receiver: Receiver, data: Union[List[Plot], List[str]], expected_state: State) -> None: + if expected_state == State.loaded: + for plot_info in data: + assert type(plot_info) == Plot + assert plot_info.filename in receiver._delta.valid.additions + elif expected_state == State.removed: + for path in data: + assert path in receiver._delta.valid.removals + elif expected_state == State.invalid: + for path in data: + assert path in receiver._delta.invalid.additions + elif expected_state == State.keys_missing: + for path in data: + assert path in receiver._delta.keys_missing.additions + elif expected_state == State.duplicates: + for path in data: + assert path in receiver._delta.duplicates.additions + + +@pytest.mark.asyncio +async def run_sync_step(receiver: Receiver, sync_step: SyncStepData, expected_state: State) -> None: + assert receiver.state() == expected_state + last_sync_time_before = 
+    # For the list types, invoke the trigger function in batches
SyncStepData(State.loaded, receiver.process_loaded, PlotSyncPlotList, plot_info_list[10:20], True), + SyncStepData(State.removed, receiver.process_removed, PlotSyncPathList, path_list[0:10], True), + SyncStepData(State.invalid, receiver.process_invalid, PlotSyncPathList, path_list[20:30], True), + SyncStepData(State.keys_missing, receiver.process_keys_missing, PlotSyncPathList, path_list[30:40], True), + SyncStepData(State.duplicates, receiver.process_duplicates, PlotSyncPathList, path_list[10:20], True), + SyncStepData(State.done, receiver.sync_done, PlotSyncDone, uint64(0)), + ] + + return receiver, sync_steps + + +def test_default_values() -> None: + assert_default_values(Receiver(get_dummy_connection(NodeType.HARVESTER), dummy_callback)) # type:ignore[arg-type] + + +@pytest.mark.asyncio +async def test_reset() -> None: + receiver, sync_steps = plot_sync_setup() + connection_before = receiver.connection() + # Assign some dummy values + receiver._sync_state = State.done + receiver._expected_sync_id = uint64(1) + receiver._expected_message_id = uint64(1) + receiver._last_sync_id = uint64(1) + receiver._last_sync_time = time.time() + receiver._invalid = ["1"] + receiver._keys_missing = ["1"] + receiver._delta.valid.additions = receiver.plots().copy() + receiver._delta.valid.removals = ["1"] + receiver._delta.invalid.additions = ["1"] + receiver._delta.invalid.removals = ["1"] + receiver._delta.keys_missing.additions = ["1"] + receiver._delta.keys_missing.removals = ["1"] + receiver._delta.duplicates.additions = ["1"] + receiver._delta.duplicates.removals = ["1"] + # Call `reset` and make sure all expected values are set back to their defaults. 
+ receiver.reset() + assert_default_values(receiver) + assert receiver._delta == Delta() + # Connection should remain + assert receiver.connection() == connection_before + + +@pytest.mark.asyncio +async def test_to_dict() -> None: + receiver, sync_steps = plot_sync_setup() + plot_sync_dict_1 = receiver.to_dict() + assert "plots" in plot_sync_dict_1 and len(plot_sync_dict_1["plots"]) == 10 + assert "failed_to_open_filenames" in plot_sync_dict_1 and len(plot_sync_dict_1["failed_to_open_filenames"]) == 0 + assert "no_key_filenames" in plot_sync_dict_1 and len(plot_sync_dict_1["no_key_filenames"]) == 0 + assert "last_sync_time" not in plot_sync_dict_1 + assert plot_sync_dict_1["connection"] == { + "node_id": receiver.connection().peer_node_id, + "host": receiver.connection().peer_host, + "port": receiver.connection().peer_port, + } + + # We should get equal dicts + plot_sync_dict_2 = receiver.to_dict() + assert plot_sync_dict_1 == plot_sync_dict_2 + + dict_2_paths = [x.filename for x in plot_sync_dict_2["plots"]] + for plot_info in sync_steps[State.loaded].args[0]: + assert plot_info.filename not in dict_2_paths + + # Walk through all states from idle to done and run them with the test data + for state in State: + await run_sync_step(receiver, sync_steps[state], state) + + plot_sync_dict_3 = receiver.to_dict() + dict_3_paths = [x.filename for x in plot_sync_dict_3["plots"]] + for plot_info in sync_steps[State.loaded].args[0]: + assert plot_info.filename in dict_3_paths + + for path in sync_steps[State.removed].args[0]: + assert path not in plot_sync_dict_3["plots"] + + for path in sync_steps[State.invalid].args[0]: + assert path in plot_sync_dict_3["failed_to_open_filenames"] + + for path in sync_steps[State.keys_missing].args[0]: + assert path in plot_sync_dict_3["no_key_filenames"] + + for path in sync_steps[State.duplicates].args[0]: + assert path in plot_sync_dict_3["duplicates"] + + assert plot_sync_dict_3["last_sync_time"] > 0 + + +@pytest.mark.asyncio +async def 
test_sync_flow() -> None: + receiver, sync_steps = plot_sync_setup() + + for plot_info in sync_steps[State.loaded].args[0]: + assert plot_info.filename not in receiver.plots() + + for path in sync_steps[State.removed].args[0]: + assert path in receiver.plots() + + for path in sync_steps[State.invalid].args[0]: + assert path not in receiver.invalid() + + for path in sync_steps[State.keys_missing].args[0]: + assert path not in receiver.keys_missing() + + for path in sync_steps[State.duplicates].args[0]: + assert path not in receiver.duplicates() + + # Walk through all states from idle to done and run them with the test data + for state in State: + await run_sync_step(receiver, sync_steps[state], state) + + for plot_info in sync_steps[State.loaded].args[0]: + assert plot_info.filename in receiver.plots() + + for path in sync_steps[State.removed].args[0]: + assert path not in receiver.plots() + + for path in sync_steps[State.invalid].args[0]: + assert path in receiver.invalid() + + for path in sync_steps[State.keys_missing].args[0]: + assert path in receiver.keys_missing() + + for path in sync_steps[State.duplicates].args[0]: + assert path in receiver.duplicates() + + # We should be in idle state again + assert receiver.state() == State.idle + + +@pytest.mark.asyncio +async def test_invalid_ids() -> None: + receiver, sync_steps = plot_sync_setup() + for state in State: + assert receiver.state() == state + current_step = sync_steps[state] + if receiver.state() == State.idle: + # Set last_sync_id for the tests below + receiver._last_sync_id = uint64(1) + # Test "sync_started last doesn't match" + invalid_last_sync_id_param = PlotSyncStart( + plot_sync_identifier(uint64(0), uint64(0)), False, uint64(2), uint32(0) + ) + await current_step.function(invalid_last_sync_id_param) + assert_error_response(receiver, ErrorCodes.invalid_last_sync_id) + # Test "last_sync_id == new_sync_id" + invalid_sync_id_match_param = PlotSyncStart( + plot_sync_identifier(uint64(1), uint64(0)), 
False, uint64(1), uint32(0) + ) + await current_step.function(invalid_sync_id_match_param) + assert_error_response(receiver, ErrorCodes.sync_ids_match) + # Reset the last_sync_id to the default + receiver._last_sync_id = uint64(0) + else: + # Test invalid sync_id + invalid_sync_id_param = current_step.payload_type( + plot_sync_identifier(uint64(10), uint64(receiver.expected_message_id())), *current_step.args + ) + await current_step.function(invalid_sync_id_param) + assert_error_response(receiver, ErrorCodes.invalid_identifier) + # Test invalid message_id + invalid_message_id_param = current_step.payload_type( + plot_sync_identifier(receiver.expected_sync_id(), uint64(receiver.expected_message_id() + 1)), + *current_step.args, + ) + await current_step.function(invalid_message_id_param) + assert_error_response(receiver, ErrorCodes.invalid_identifier) + payload = create_payload(current_step.payload_type, state == State.idle, *current_step.args) + await current_step.function(payload) + + +@pytest.mark.parametrize( + ["state_to_fail", "expected_error_code"], + [ + pytest.param(State.loaded, ErrorCodes.plot_already_available, id="already available plots"), + pytest.param(State.invalid, ErrorCodes.plot_already_available, id="already available paths"), + pytest.param(State.removed, ErrorCodes.plot_not_available, id="not available"), + ], +) +@pytest.mark.asyncio +async def test_plot_errors(state_to_fail: State, expected_error_code: ErrorCodes) -> None: + receiver, sync_steps = plot_sync_setup() + for state in State: + assert receiver.state() == state + current_step = sync_steps[state] + if state == state_to_fail: + plot_infos, _ = current_step.args + await current_step.function(create_payload(current_step.payload_type, False, plot_infos, False)) + identifier = plot_sync_identifier(receiver.expected_sync_id(), receiver.expected_message_id()) + invalid_payload = current_step.payload_type(identifier, plot_infos, True) + await current_step.function(invalid_payload) + if state 
== state_to_fail: + assert_error_response(receiver, expected_error_code) + return + else: + await current_step.function( + create_payload(current_step.payload_type, state == State.idle, *current_step.args) + ) + assert False, "Didn't fail in the expected state" diff --git a/tests/plot_sync/test_sender.py b/tests/plot_sync/test_sender.py new file mode 100644 index 000000000000..09747ec8bbaf --- /dev/null +++ b/tests/plot_sync/test_sender.py @@ -0,0 +1,102 @@ +import pytest + +from chia.plot_sync.exceptions import AlreadyStartedError, InvalidConnectionTypeError +from chia.plot_sync.sender import ExpectedResponse, Sender +from chia.plot_sync.util import Constants +from chia.protocols.harvester_protocol import PlotSyncIdentifier, PlotSyncResponse +from chia.server.ws_connection import NodeType, ProtocolMessageTypes +from chia.util.ints import int16, uint64 +from tests.block_tools import BlockTools +from tests.plot_sync.util import get_dummy_connection, plot_sync_identifier + + +def test_default_values(bt: BlockTools) -> None: + sender = Sender(bt.plot_manager) + assert sender._plot_manager == bt.plot_manager + assert sender._connection is None + assert sender._sync_id == uint64(0) + assert sender._next_message_id == uint64(0) + assert sender._messages == [] + assert sender._last_sync_id == uint64(0) + assert not sender._stop_requested + assert sender._task is None + assert not sender._lock.locked() + assert sender._response is None + + +def test_set_connection_values(bt: BlockTools) -> None: + farmer_connection = get_dummy_connection(NodeType.FARMER) + sender = Sender(bt.plot_manager) + # Test invalid NodeType values + for connection_type in NodeType: + if connection_type != NodeType.FARMER: + pytest.raises( + InvalidConnectionTypeError, + sender.set_connection, + get_dummy_connection(connection_type, farmer_connection.peer_node_id), + ) + # Test setting a valid connection works + sender.set_connection(farmer_connection) # type:ignore[arg-type] + assert 
sender._connection is not None + assert sender._connection == farmer_connection # type: ignore[comparison-overlap] + + +@pytest.mark.asyncio +async def test_start_stop_send_task(bt: BlockTools) -> None: + sender = Sender(bt.plot_manager) + # Make sure starting/restarting works + for _ in range(2): + assert sender._task is None + await sender.start() + assert sender._task is not None + with pytest.raises(AlreadyStartedError): + await sender.start() + assert not sender._stop_requested + sender.stop() + assert sender._stop_requested + await sender.await_closed() + assert not sender._stop_requested + assert sender._task is None + + +def test_set_response(bt: BlockTools) -> None: + sender = Sender(bt.plot_manager) + + def new_expected_response(sync_id: int, message_id: int, message_type: ProtocolMessageTypes) -> ExpectedResponse: + return ExpectedResponse(message_type, plot_sync_identifier(uint64(sync_id), uint64(message_id))) + + def new_response_message(sync_id: int, message_id: int, message_type: ProtocolMessageTypes) -> PlotSyncResponse: + return PlotSyncResponse( + plot_sync_identifier(uint64(sync_id), uint64(message_id)), int16(int(message_type.value)), None + ) + + response_message = new_response_message(0, 1, ProtocolMessageTypes.plot_sync_start) + assert sender._response is None + # Should trigger unexpected response because `Farmer._response` is `None` + assert not sender.set_response(response_message) + # Set `Farmer._response` and make sure the response gets assigned properly + sender._response = new_expected_response(0, 1, ProtocolMessageTypes.plot_sync_start) + assert sender._response.message is None + assert sender.set_response(response_message) + assert sender._response.message is not None + # Should trigger unexpected response because we already received the message for the currently expected response + assert not sender.set_response(response_message) + # Test expired message + expected_response = new_expected_response(1, 0, 
ProtocolMessageTypes.plot_sync_start) + sender._response = expected_response + expired_identifier = PlotSyncIdentifier( + uint64(expected_response.identifier.timestamp - Constants.message_timeout - 1), + expected_response.identifier.sync_id, + expected_response.identifier.message_id, + ) + expired_message = PlotSyncResponse(expired_identifier, int16(int(ProtocolMessageTypes.plot_sync_start.value)), None) + assert not sender.set_response(expired_message) + # Test invalid sync-id + sender._response = new_expected_response(2, 0, ProtocolMessageTypes.plot_sync_start) + assert not sender.set_response(new_response_message(3, 0, ProtocolMessageTypes.plot_sync_start)) + # Test invalid message-id + sender._response = new_expected_response(2, 1, ProtocolMessageTypes.plot_sync_start) + assert not sender.set_response(new_response_message(2, 2, ProtocolMessageTypes.plot_sync_start)) + # Test invalid message-type + sender._response = new_expected_response(3, 0, ProtocolMessageTypes.plot_sync_start) + assert not sender.set_response(new_response_message(3, 0, ProtocolMessageTypes.plot_sync_loaded)) diff --git a/tests/plot_sync/test_sync_simulated.py b/tests/plot_sync/test_sync_simulated.py new file mode 100644 index 000000000000..ae83dc7b64c4 --- /dev/null +++ b/tests/plot_sync/test_sync_simulated.py @@ -0,0 +1,433 @@ +import asyncio +import functools +import logging +import time +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from secrets import token_bytes +from typing import Any, Dict, List, Optional, Set, Tuple + +import pytest +from blspy import G1Element + +from chia.farmer.farmer_api import Farmer +from chia.harvester.harvester_api import Harvester +from chia.plot_sync.receiver import Receiver +from chia.plot_sync.sender import Sender +from chia.plot_sync.util import Constants +from chia.plotting.manager import PlotManager +from chia.plotting.util import PlotInfo +from chia.protocols.harvester_protocol import PlotSyncError, 
PlotSyncResponse +from chia.server.start_service import Service +from chia.server.ws_connection import ProtocolMessageTypes, WSChiaConnection, make_msg +from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.util.generator_tools import list_to_batches +from chia.util.ints import int16, uint64 +from tests.plot_sync.util import start_harvester_service +from tests.time_out_assert import time_out_assert + +log = logging.getLogger(__name__) + + +class ErrorSimulation(Enum): + DropEveryFourthMessage = 1 + DropThreeMessages = 2 + RespondTooLateEveryFourthMessage = 3 + RespondTwice = 4 + NonRecoverableError = 5 + NotConnected = 6 + + +@dataclass +class TestData: + harvester: Harvester + plot_sync_sender: Sender + plot_sync_receiver: Receiver + event_loop: asyncio.AbstractEventLoop + plots: Dict[Path, PlotInfo] = field(default_factory=dict) + invalid: List[PlotInfo] = field(default_factory=list) + keys_missing: List[PlotInfo] = field(default_factory=list) + duplicates: List[PlotInfo] = field(default_factory=list) + + async def run( + self, + *, + loaded: List[PlotInfo], + removed: List[PlotInfo], + invalid: List[PlotInfo], + keys_missing: List[PlotInfo], + duplicates: List[PlotInfo], + initial: bool, + ) -> None: + for plot_info in loaded: + assert plot_info.prover.get_filename() not in self.plots + for plot_info in removed: + assert plot_info.prover.get_filename() in self.plots + + self.invalid = invalid + self.keys_missing = keys_missing + self.duplicates = duplicates + + removed_paths: List[Path] = [p.prover.get_filename() for p in removed] if removed is not None else [] + invalid_dict: Dict[Path, int] = {p.prover.get_filename(): 0 for p in self.invalid} + keys_missing_set: Set[Path] = set([p.prover.get_filename() for p in self.keys_missing]) + duplicates_set: Set[str] = set([p.prover.get_filename() for p in self.duplicates]) + + # Inject invalid plots into `PlotManager` of the harvester so that the callback calls below can use them + # to sync them to 
the farmer. + self.harvester.plot_manager.failed_to_open_filenames = invalid_dict + # Inject key missing plots into `PlotManager` of the harvester so that the callback calls below can use them + # to sync them to the farmer. + self.harvester.plot_manager.no_key_filenames = keys_missing_set + # Inject duplicated plots into `PlotManager` of the harvester so that the callback calls below can use them + # to sync them to the farmer. + for plot_info in loaded: + plot_path = Path(plot_info.prover.get_filename()) + self.harvester.plot_manager.plot_filename_paths[plot_path.name] = (str(plot_path.parent), set()) + for duplicate in duplicates_set: + plot_path = Path(duplicate) + assert plot_path.name in self.harvester.plot_manager.plot_filename_paths + self.harvester.plot_manager.plot_filename_paths[plot_path.name][1].add(str(plot_path.parent)) + + batch_size = self.harvester.plot_manager.refresh_parameter.batch_size + + # Used to capture the sync id in `run_internal` + sync_id: Optional[uint64] = None + + def run_internal() -> None: + nonlocal sync_id + # Simulate one plot manager refresh cycle by calling the methods directly. 
+ self.harvester.plot_sync_sender.sync_start(len(loaded), initial) + sync_id = self.plot_sync_sender._sync_id + if len(loaded) == 0: + self.harvester.plot_sync_sender.process_batch([], 0) + for remaining, batch in list_to_batches(loaded, batch_size): + self.harvester.plot_sync_sender.process_batch(batch, remaining) + self.harvester.plot_sync_sender.sync_done(removed_paths, 0) + + await self.event_loop.run_in_executor(None, run_internal) + + async def sync_done() -> bool: + assert sync_id is not None + return self.plot_sync_receiver.last_sync_id() == self.plot_sync_sender._last_sync_id == sync_id + + await time_out_assert(60, sync_done) + + for plot_info in loaded: + self.plots[plot_info.prover.get_filename()] = plot_info + for plot_info in removed: + del self.plots[plot_info.prover.get_filename()] + + def validate_plot_sync(self) -> None: + assert len(self.plots) == len(self.plot_sync_receiver.plots()) + assert len(self.invalid) == len(self.plot_sync_receiver.invalid()) + assert len(self.keys_missing) == len(self.plot_sync_receiver.keys_missing()) + for _, plot_info in self.plots.items(): + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.invalid() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.keys_missing() + assert plot_info.prover.get_filename() in self.plot_sync_receiver.plots() + synced_plot = self.plot_sync_receiver.plots()[plot_info.prover.get_filename()] + assert plot_info.prover.get_filename() == synced_plot.filename + assert plot_info.pool_public_key == synced_plot.pool_public_key + assert plot_info.pool_contract_puzzle_hash == synced_plot.pool_contract_puzzle_hash + assert plot_info.plot_public_key == synced_plot.plot_public_key + assert plot_info.file_size == synced_plot.file_size + assert uint64(int(plot_info.time_modified)) == synced_plot.time_modified + for plot_info in self.invalid: + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.plots() + assert plot_info.prover.get_filename() 
in self.plot_sync_receiver.invalid() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.keys_missing() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.duplicates() + for plot_info in self.keys_missing: + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.plots() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.invalid() + assert plot_info.prover.get_filename() in self.plot_sync_receiver.keys_missing() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.duplicates() + for plot_info in self.duplicates: + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.invalid() + assert plot_info.prover.get_filename() not in self.plot_sync_receiver.keys_missing() + assert plot_info.prover.get_filename() in self.plot_sync_receiver.duplicates() + + +@dataclass +class TestRunner: + test_data: List[TestData] + + def __init__( + self, harvesters: List[Harvester], farmer: Farmer, event_loop: asyncio.events.AbstractEventLoop + ) -> None: + self.test_data = [] + for harvester in harvesters: + assert harvester.server is not None + self.test_data.append( + TestData( + harvester, + harvester.plot_sync_sender, + farmer.plot_sync_receivers[harvester.server.node_id], + event_loop, + ) + ) + + async def run( + self, + index: int, + *, + loaded: List[PlotInfo], + removed: List[PlotInfo], + invalid: List[PlotInfo], + keys_missing: List[PlotInfo], + duplicates: List[PlotInfo], + initial: bool, + ) -> None: + await self.test_data[index].run( + loaded=loaded, + removed=removed, + invalid=invalid, + keys_missing=keys_missing, + duplicates=duplicates, + initial=initial, + ) + for data in self.test_data: + data.validate_plot_sync() + + +async def skip_processing(self: Any, _: WSChiaConnection, message_type: ProtocolMessageTypes, message: Any) -> bool: + self.message_counter += 1 + if self.simulate_error == ErrorSimulation.DropEveryFourthMessage: + if self.message_counter % 4 == 
0: + return True + if self.simulate_error == ErrorSimulation.DropThreeMessages: + if 2 < self.message_counter < 6: + return True + if self.simulate_error == ErrorSimulation.RespondTooLateEveryFourthMessage: + if self.message_counter % 4 == 0: + await asyncio.sleep(Constants.message_timeout + 1) + return False + if self.simulate_error == ErrorSimulation.RespondTwice: + await self.connection().send_message( + make_msg( + ProtocolMessageTypes.plot_sync_response, + PlotSyncResponse(message.identifier, int16(message_type.value), None), + ) + ) + if self.simulate_error == ErrorSimulation.NonRecoverableError and self.message_counter > 1: + await self.connection().send_message( + make_msg( + ProtocolMessageTypes.plot_sync_response, + PlotSyncResponse( + message.identifier, int16(message_type.value), PlotSyncError(int16(0), "non recoverable", None) + ), + ) + ) + self.simulate_error = 0 + return True + return False + + +async def _testable_process( + self: Any, peer: WSChiaConnection, message_type: ProtocolMessageTypes, message: Any +) -> None: + if await skip_processing(self, peer, message_type, message): + return + await self.original_process(peer, message_type, message) + + +async def create_test_runner( + harvester_services: List[Service], farmer: Farmer, event_loop: asyncio.events.AbstractEventLoop +) -> TestRunner: + assert len(farmer.plot_sync_receivers) == 0 + harvesters: List[Harvester] = [await start_harvester_service(service) for service in harvester_services] + for receiver in farmer.plot_sync_receivers.values(): + receiver.simulate_error = 0 # type: ignore[attr-defined] + receiver.message_counter = 0 # type: ignore[attr-defined] + receiver.original_process = receiver._process # type: ignore[attr-defined] + receiver._process = functools.partial(_testable_process, receiver) # type: ignore[assignment] + return TestRunner(harvesters, farmer, event_loop) + + +def create_example_plots(count: int) -> List[PlotInfo]: + @dataclass + class DiskProver: + file_name: str + 
plot_id: bytes32 + size: int + + def get_filename(self) -> str: + return self.file_name + + def get_id(self) -> bytes32: + return self.plot_id + + def get_size(self) -> int: + return self.size + + return [ + PlotInfo( + prover=DiskProver(f"{x}", bytes32(token_bytes(32)), x % 255), + pool_public_key=None, + pool_contract_puzzle_hash=None, + plot_public_key=G1Element(), + file_size=uint64(0), + time_modified=time.time(), + ) + for x in range(0, count) + ] + + +@pytest.mark.asyncio +async def test_sync_simulated( + farmer_three_harvester: Tuple[List[Service], Service], event_loop: asyncio.events.AbstractEventLoop +) -> None: + harvester_services: List[Service] + farmer_service: Service + harvester_services, farmer_service = farmer_three_harvester + farmer: Farmer = farmer_service._node + test_runner: TestRunner = await create_test_runner(harvester_services, farmer, event_loop) + plots = create_example_plots(31000) + + await test_runner.run( + 0, loaded=plots[0:10000], removed=[], invalid=[], keys_missing=[], duplicates=plots[0:1000], initial=True + ) + await test_runner.run( + 1, + loaded=plots[10000:20000], + removed=[], + invalid=plots[30000:30100], + keys_missing=[], + duplicates=[], + initial=True, + ) + await test_runner.run( + 2, + loaded=plots[20000:30000], + removed=[], + invalid=[], + keys_missing=plots[30100:30200], + duplicates=[], + initial=True, + ) + await test_runner.run( + 0, + loaded=[], + removed=[], + invalid=plots[30300:30400], + keys_missing=plots[30400:30453], + duplicates=[], + initial=False, + ) + await test_runner.run(0, loaded=[], removed=[], invalid=[], keys_missing=[], duplicates=[], initial=False) + await test_runner.run( + 0, loaded=[], removed=plots[5000:10000], invalid=[], keys_missing=[], duplicates=[], initial=False + ) + await test_runner.run( + 1, loaded=[], removed=plots[10000:20000], invalid=[], keys_missing=[], duplicates=[], initial=False + ) + await test_runner.run( + 2, loaded=[], removed=plots[20000:29000], invalid=[], 
keys_missing=[], duplicates=[], initial=False + ) + await test_runner.run( + 0, loaded=[], removed=plots[0:5000], invalid=[], keys_missing=[], duplicates=[], initial=False + ) + await test_runner.run( + 2, + loaded=plots[5000:10000], + removed=plots[29000:30000], + invalid=plots[30000:30500], + keys_missing=plots[30500:31000], + duplicates=plots[5000:6000], + initial=False, + ) + await test_runner.run( + 2, loaded=[], removed=plots[5000:10000], invalid=[], keys_missing=[], duplicates=[], initial=False + ) + assert len(farmer.plot_sync_receivers) == 3 + for plot_sync in farmer.plot_sync_receivers.values(): + assert len(plot_sync.plots()) == 0 + + +@pytest.mark.parametrize( + "simulate_error", + [ + ErrorSimulation.DropEveryFourthMessage, + ErrorSimulation.DropThreeMessages, + ErrorSimulation.RespondTooLateEveryFourthMessage, + ErrorSimulation.RespondTwice, + ], +) +@pytest.mark.asyncio +async def test_farmer_error_simulation( + farmer_one_harvester: Tuple[List[Service], Service], + event_loop: asyncio.events.AbstractEventLoop, + simulate_error: ErrorSimulation, +) -> None: + Constants.message_timeout = 5 + harvester_services: List[Service] + farmer_service: Service + harvester_services, farmer_service = farmer_one_harvester + test_runner: TestRunner = await create_test_runner(harvester_services, farmer_service._node, event_loop) + batch_size = test_runner.test_data[0].harvester.plot_manager.refresh_parameter.batch_size + plots = create_example_plots(batch_size + 3) + receiver = test_runner.test_data[0].plot_sync_receiver + receiver.simulate_error = simulate_error # type: ignore[attr-defined] + await test_runner.run( + 0, + loaded=plots[0 : batch_size + 1], + removed=[], + invalid=[plots[batch_size + 1]], + keys_missing=[plots[batch_size + 2]], + duplicates=[], + initial=True, + ) + + +@pytest.mark.parametrize("simulate_error", [ErrorSimulation.NonRecoverableError, ErrorSimulation.NotConnected]) +@pytest.mark.asyncio +async def test_sync_reset_cases( + 
farmer_one_harvester: Tuple[List[Service], Service], + event_loop: asyncio.events.AbstractEventLoop, + simulate_error: ErrorSimulation, +) -> None: + harvester_services: List[Service] + farmer_service: Service + harvester_services, farmer_service = farmer_one_harvester + test_runner: TestRunner = await create_test_runner(harvester_services, farmer_service._node, event_loop) + test_data: TestData = test_runner.test_data[0] + plot_manager: PlotManager = test_data.harvester.plot_manager + plots = create_example_plots(30) + # Inject some data into `PlotManager` of the harvester so that we can validate the reset worked and triggered a + # fresh sync of all available data of the plot manager + for plot_info in plots[0:10]: + test_data.plots[plot_info.prover.get_filename()] = plot_info + plot_manager.plots = test_data.plots + test_data.invalid = plots[10:20] + test_data.keys_missing = plots[20:30] + test_data.plot_sync_receiver.simulate_error = simulate_error # type: ignore[attr-defined] + sender: Sender = test_runner.test_data[0].plot_sync_sender + started_sync_id: uint64 = uint64(0) + + plot_manager.failed_to_open_filenames = {p.prover.get_filename(): 0 for p in test_data.invalid} + plot_manager.no_key_filenames = set([p.prover.get_filename() for p in test_data.keys_missing]) + + async def wait_for_reset() -> bool: + assert started_sync_id != 0 + return sender._sync_id != started_sync_id != 0 + + async def sync_done() -> bool: + assert started_sync_id != 0 + return test_data.plot_sync_receiver.last_sync_id() == sender._last_sync_id == started_sync_id + + # Send start and capture the sync_id + sender.sync_start(len(plots), True) + started_sync_id = sender._sync_id + # Sleep 2 seconds to make sure we have a different sync_id after the reset which gets triggered + await asyncio.sleep(2) + saved_connection = sender._connection + if simulate_error == ErrorSimulation.NotConnected: + sender._connection = None + sender.process_batch(plots, 0) + await time_out_assert(60, 
wait_for_reset) + started_sync_id = sender._sync_id + sender._connection = saved_connection + await time_out_assert(60, sync_done) + test_runner.test_data[0].validate_plot_sync() diff --git a/tests/plot_sync/util.py b/tests/plot_sync/util.py new file mode 100644 index 000000000000..823616f50b40 --- /dev/null +++ b/tests/plot_sync/util.py @@ -0,0 +1,53 @@ +import time +from dataclasses import dataclass +from secrets import token_bytes +from typing import Optional + +from chia.harvester.harvester_api import Harvester +from chia.plot_sync.sender import Sender +from chia.protocols.harvester_protocol import PlotSyncIdentifier +from chia.server.start_service import Service +from chia.server.ws_connection import Message, NodeType +from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.util.ints import uint64 +from tests.time_out_assert import time_out_assert + + +@dataclass +class WSChiaConnectionDummy: + connection_type: NodeType + peer_node_id: bytes32 + peer_host: str = "localhost" + peer_port: int = 0 + last_sent_message: Optional[Message] = None + + async def send_message(self, message: Message) -> None: + self.last_sent_message = message + + +def get_dummy_connection(node_type: NodeType, peer_id: Optional[bytes32] = None) -> WSChiaConnectionDummy: + return WSChiaConnectionDummy(node_type, bytes32(token_bytes(32)) if peer_id is None else peer_id) + + +def plot_sync_identifier(current_sync_id: uint64, message_id: uint64) -> PlotSyncIdentifier: + return PlotSyncIdentifier(uint64(int(time.time())), current_sync_id, message_id) + + +async def start_harvester_service(harvester_service: Service) -> Harvester: + # Set the `last_refresh_time` of the plot manager to avoid initial plot loading + harvester: Harvester = harvester_service._node + harvester.plot_manager.last_refresh_time = time.time() + await harvester_service.start() + harvester.plot_manager.stop_refreshing() # type: ignore[no-untyped-call] # TODO, Add typing in PlotManager + + assert 
harvester.plot_sync_sender._sync_id == 0 + assert harvester.plot_sync_sender._next_message_id == 0 + assert harvester.plot_sync_sender._last_sync_id == 0 + assert harvester.plot_sync_sender._messages == [] + + def wait_for_farmer_connection(plot_sync_sender: Sender) -> bool: + return plot_sync_sender._connection is not None + + await time_out_assert(10, wait_for_farmer_connection, True, harvester.plot_sync_sender) + + return harvester diff --git a/tests/plotting/test_plot_manager.py b/tests/plotting/test_plot_manager.py index f2a0e843c255..fea50cc4e171 100644 --- a/tests/plotting/test_plot_manager.py +++ b/tests/plotting/test_plot_manager.py @@ -236,7 +236,7 @@ async def test_plot_refreshing(test_plot_environment): trigger=trigger_remove_plot, test_path=drop_path, expect_loaded=[], - expect_removed=[drop_path], + expect_removed=[], expect_processed=len(env.dir_1) + len(env.dir_2) + len(dir_duplicates), expect_duplicates=len(dir_duplicates), expected_directories=3, @@ -262,7 +262,7 @@ async def test_plot_refreshing(test_plot_environment): trigger=remove_plot_directory, test_path=dir_duplicates.path, expect_loaded=[], - expect_removed=dir_duplicates.path_list(), + expect_removed=[], expect_processed=len(env.dir_1) + len(env.dir_2), expect_duplicates=0, expected_directories=2, @@ -316,7 +316,7 @@ async def test_plot_refreshing(test_plot_environment): trigger=trigger_remove_plot, test_path=drop_path, expect_loaded=[], - expect_removed=[drop_path], + expect_removed=[], expect_processed=len(env.dir_1) + len(env.dir_2) + len(dir_duplicates), expect_duplicates=len(env.dir_1), expected_directories=3, @@ -357,6 +357,17 @@ async def test_plot_refreshing(test_plot_environment): ) +@pytest.mark.asyncio +async def test_initial_refresh_flag(test_plot_environment: TestEnvironment) -> None: + env: TestEnvironment = test_plot_environment + assert env.refresh_tester.plot_manager.initial_refresh() + for _ in range(2): + await env.refresh_tester.run(PlotRefreshResult()) + assert not 
env.refresh_tester.plot_manager.initial_refresh() + env.refresh_tester.plot_manager.reset() + assert env.refresh_tester.plot_manager.initial_refresh() + + @pytest.mark.asyncio async def test_invalid_plots(test_plot_environment): env: TestEnvironment = test_plot_environment diff --git a/tests/setup_nodes.py b/tests/setup_nodes.py index 4cfb701d9c14..03a97c386d67 100644 --- a/tests/setup_nodes.py +++ b/tests/setup_nodes.py @@ -1,12 +1,15 @@ import asyncio import logging from secrets import token_bytes -from typing import Dict, List +from typing import AsyncIterator, Dict, List, Tuple +from pathlib import Path from chia.consensus.constants import ConsensusConstants +from chia.cmds.init_funcs import init from chia.full_node.full_node_api import FullNodeAPI from chia.server.start_service import Service from chia.server.start_wallet import service_kwargs_for_wallet +from chia.util.config import load_config, save_config from chia.util.hash import std_hash from chia.util.ints import uint16, uint32 from chia.util.keychain import bytes_to_mnemonic @@ -294,7 +297,7 @@ async def setup_harvester_farmer(bt: BlockTools, consensus_constants: ConsensusC harvester_rpc_port = find_available_listen_port("harvester rpc") node_iters = [ setup_harvester( - bt, + bt.root_path, bt.config["self_hostname"], harvester_port, harvester_rpc_port, @@ -320,6 +323,62 @@ async def setup_harvester_farmer(bt: BlockTools, consensus_constants: ConsensusC await _teardown_nodes(node_iters) +async def setup_farmer_multi_harvester( + block_tools: BlockTools, + harvester_count: int, + temp_dir: Path, + consensus_constants: ConsensusConstants, +) -> AsyncIterator[Tuple[List[Service], Service]]: + farmer_port = find_available_listen_port("farmer") + farmer_rpc_port = find_available_listen_port("farmer rpc") + + node_iterators = [ + setup_farmer( + block_tools, block_tools.config["self_hostname"], farmer_port, farmer_rpc_port, consensus_constants + ) + ] + + for i in range(0, harvester_count): + root_path: Path 
= temp_dir / str(i) + init(None, root_path) + init(block_tools.root_path / "config" / "ssl" / "ca", root_path) + config = load_config(root_path, "config.yaml") + config["logging"]["log_stdout"] = True + config["selected_network"] = "testnet0" + config["harvester"]["selected_network"] = "testnet0" + harvester_port = find_available_listen_port("harvester") + harvester_rpc_port = find_available_listen_port("harvester rpc") + save_config(root_path, "config.yaml", config) + node_iterators.append( + setup_harvester( + root_path, + block_tools.config["self_hostname"], + harvester_port, + harvester_rpc_port, + farmer_port, + consensus_constants, + False, + ) + ) + + farmer_service = await node_iterators[0].__anext__() + harvester_services = [] + for node in node_iterators[1:]: + harvester_service = await node.__anext__() + harvester_services.append(harvester_service) + + yield harvester_services, farmer_service + + for harvester_service in harvester_services: + harvester_service.stop() + await harvester_service.wait_closed() + + farmer_service.stop() + await farmer_service.wait_closed() + + await _teardown_nodes(node_iterators) + + async def setup_full_system( consensus_constants: ConsensusConstants, shared_b_tools: BlockTools, @@ -353,7 +412,7 @@ async def setup_full_system( node_iters = [ setup_introducer(shared_b_tools, introducer_port), setup_harvester( - shared_b_tools, + shared_b_tools.root_path, shared_b_tools.config["self_hostname"], harvester_port, harvester_rpc_port, diff --git a/tests/setup_services.py b/tests/setup_services.py index 848506fe7fed..27b6b42410dd 100644 --- a/tests/setup_services.py +++ b/tests/setup_services.py @@ -2,6 +2,7 @@ import asyncio import logging import signal import sqlite3 +from pathlib import Path from secrets import token_bytes from typing import AsyncGenerator, Optional @@ -16,8 +17,8 @@ from chia.server.start_timelord import service_kwargs_for_timelord from chia.server.start_wallet import service_kwargs_for_wallet from 
chia.simulator.start_simulator import service_kwargs_for_full_node_simulator from chia.timelord.timelord_launcher import kill_processes, spawn_process -from chia.types.peer_info import PeerInfo from chia.util.bech32m import encode_puzzle_hash +from chia.util.config import load_config, save_config from chia.util.ints import uint16 from chia.util.keychain import bytes_to_mnemonic from tests.block_tools import BlockTools @@ -184,7 +185,7 @@ async def setup_wallet_node( async def setup_harvester( - b_tools: BlockTools, + root_path: Path, self_hostname: str, port, rpc_port, @@ -192,15 +193,14 @@ async def setup_harvester( consensus_constants: ConsensusConstants, start_service: bool = True, ): - - config = b_tools.config["harvester"] - config["port"] = port - config["rpc_port"] = rpc_port - kwargs = service_kwargs_for_harvester(b_tools.root_path, config, consensus_constants) + config = load_config(root_path, "config.yaml") + config["harvester"]["port"] = port + config["harvester"]["rpc_port"] = rpc_port + config["harvester"]["farmer_peer"]["host"] = self_hostname + config["harvester"]["farmer_peer"]["port"] = farmer_port + save_config(root_path, "config.yaml", config) + kwargs = service_kwargs_for_harvester(root_path, config["harvester"], consensus_constants) kwargs.update( - server_listen_ports=[port], - advertised_port=port, - connect_peers=[PeerInfo(self_hostname, farmer_port)], parse_cli_args=False, connect_to_daemon=False, service_name_prefix="test_",