chia|tests|github: Implement, integrate and test plot sync protocol (#9695)

* protocols|server: Define new harvester plot refreshing protocol messages

* protocols: Bump `protocol_version` to `0.0.34`

* tests: Introduce `setup_farmer_multi_harvester`

Allows running a test setup with 1 farmer and multiple harvesters.

* plotting: Add an initial plot loading indication to `PlotManager`

* plotting|tests: Don't add removed duplicates to `total_result.removed`

`PlotRefreshResult.removed` should only contain plots that were properly
loaded before they were removed. It shouldn't contain e.g. removed
duplicates or invalid plots, since those are synced in a separate sync step
and are sent as a whole list every time rather than as a diff.
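The distinction matters on the farmer side: valid plots are applied as a diff, while invalid plots, missing keys and duplicates are replaced wholesale each cycle. A minimal sketch of the two update styles (function names and the int plot values are illustrative, not the actual `Receiver` API):

```python
from typing import Dict, List

def apply_valid_diff(plots: Dict[str, int], additions: Dict[str, int], removals: List[str]) -> None:
    # Properly loaded plots are synced as a diff: apply additions, then drop removals.
    plots.update(additions)
    for path in removals:
        # Would raise KeyError if `removed` contained plots that were never loaded,
        # which is exactly why duplicates/invalid plots must not end up here.
        del plots[path]

def apply_full_list(new_list: List[str]) -> List[str]:
    # Invalid plots, missing keys and duplicates are replaced wholesale each sync.
    return list(new_list)

plots = {"a.plot": 32}
apply_valid_diff(plots, {"b.plot": 33}, ["a.plot"])
```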

* harvester: Reset `PlotManager` on shutdown

* plot_sync: Implement plot sync protocol

* farmer|harvester: Integrate and enable plot sync

* tests: Implement tests for the plot sync protocol

* farmer|tests: Drop obsolete harvester caching code

* setup: Add `chia.plot_sync` to packages

* plot_sync: Type hints in `DeltaType`

* plot_sync: Drop parameters in `super()` calls

* plot_sync: Introduce `send_response` helper in `Receiver._process`

* plot_sync: Add some parentheses

Co-authored-by: Kyle Altendorf <sda@fstab.net>

* plot_sync: Additional hint for a `Receiver.process_path_list` parameter

* plot_sync: Force named parameters in `Receiver.process_path_list`

* test: Fix fixtures after rebase

* tests: Fix sorting after rebase

* tests: Return type hint for `plot_sync_setup`

* tests: Rename `WSChiaConnection` and move it to the outer scope

* tests|plot_sync: More type hints

* tests: Rework some delta tests

* tests: Drop a `range` and iterate over the list directly

* tests: Use the proper flags to overwrite

* test: More missing duplicates tests

* tests: Drop `ExpectedResult.reset`

* tests: Reduce some asserts

* tests: Add messages to some `assert False` statements

* tests: Introduce `ErrorSimulation` enum in `test_sync_simulated.py`

* tests: Use `secrets` instead of `Crypto.Random`

* Fixes after rebase

* Import from `typing_extensions` to support python 3.7

* Drop task name to support python 3.7

* Introduce `Sender.syncing`, `Sender.connected` and a log about the task

* Add `tests/plot_sync/config.py`

* Align the multi harvester fixture with what we do in other places

* Update the workflows

Co-authored-by: Kyle Altendorf <sda@fstab.net>
This commit is contained in:
dustinface 2022-04-08 02:10:44 +02:00, committed by GitHub
parent 8c0cdda880, commit ded9f68583
31 changed files with 2877 additions and 132 deletions


@@ -0,0 +1,107 @@
#
# THIS FILE IS GENERATED. SEE https://github.com/Chia-Network/chia-blockchain/tree/main/tests#readme
#
name: MacOS plot_sync Tests
on:
push:
branches:
- main
tags:
- '**'
pull_request:
branches:
- '**'
concurrency:
# SHA is added to the end if on `main` to let all main workflows run
group: ${{ github.ref }}-${{ github.workflow }}-${{ github.event_name }}-${{ github.ref == 'refs/heads/main' && github.sha || '' }}
cancel-in-progress: true
jobs:
build:
name: MacOS plot_sync Tests
runs-on: ${{ matrix.os }}
timeout-minutes: 30
strategy:
fail-fast: false
max-parallel: 4
matrix:
python-version: [3.8, 3.9]
os: [macOS-latest]
env:
CHIA_ROOT: ${{ github.workspace }}/.chia/mainnet
JOB_FILE_NAME: tests_${{ matrix.os }}_python-${{ matrix.python-version }}_plot_sync
steps:
- name: Checkout Code
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Setup Python environment
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Create keychain for CI use
run: |
security create-keychain -p foo chiachain
security default-keychain -s chiachain
security unlock-keychain -p foo chiachain
security set-keychain-settings -t 7200 -u chiachain
- name: Get pip cache dir
id: pip-cache
run: |
echo "::set-output name=dir::$(pip cache dir)"
- name: Cache pip
uses: actions/cache@v3
with:
# Note that new runners may break this https://github.com/actions/cache/issues/292
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Checkout test blocks and plots
uses: actions/checkout@v3
with:
repository: 'Chia-Network/test-cache'
path: '.chia'
ref: '0.28.0'
fetch-depth: 1
- name: Run install script
env:
INSTALL_PYTHON_VERSION: ${{ matrix.python-version }}
run: |
brew install boost
sh install.sh -d
# Omitted installing Timelord
- name: Test plot_sync code with pytest
run: |
. ./activate
venv/bin/coverage run --rcfile=.coveragerc ./venv/bin/py.test tests/plot_sync/test_*.py --durations=10 -n 4 -m "not benchmark"
- name: Process coverage data
run: |
venv/bin/coverage combine --rcfile=.coveragerc .coverage.*
venv/bin/coverage xml --rcfile=.coveragerc -o coverage.xml
mkdir coverage_reports
cp .coverage "coverage_reports/.coverage.${{ env.JOB_FILE_NAME }}"
cp coverage.xml "coverage_reports/coverage.${{ env.JOB_FILE_NAME }}.xml"
venv/bin/coverage report --rcfile=.coveragerc --show-missing
- name: Publish coverage
uses: actions/upload-artifact@v2
with:
name: coverage
path: coverage_reports/*
if-no-files-found: error
#
# THIS FILE IS GENERATED. SEE https://github.com/Chia-Network/chia-blockchain/tree/main/tests#readme
#


@@ -0,0 +1,109 @@
#
# THIS FILE IS GENERATED. SEE https://github.com/Chia-Network/chia-blockchain/tree/main/tests#readme
#
name: Ubuntu plot_sync Test
on:
push:
branches:
- main
tags:
- '**'
pull_request:
branches:
- '**'
concurrency:
# SHA is added to the end if on `main` to let all main workflows run
group: ${{ github.ref }}-${{ github.workflow }}-${{ github.event_name }}-${{ github.ref == 'refs/heads/main' && github.sha || '' }}
cancel-in-progress: true
jobs:
build:
name: Ubuntu plot_sync Test
runs-on: ${{ matrix.os }}
timeout-minutes: 30
strategy:
fail-fast: false
max-parallel: 4
matrix:
python-version: [3.7, 3.8, 3.9]
os: [ubuntu-latest]
env:
CHIA_ROOT: ${{ github.workspace }}/.chia/mainnet
JOB_FILE_NAME: tests_${{ matrix.os }}_python-${{ matrix.python-version }}_plot_sync
steps:
- name: Checkout Code
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Setup Python environment
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Cache npm
uses: actions/cache@v3
with:
path: ~/.npm
key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-
- name: Get pip cache dir
id: pip-cache
run: |
echo "::set-output name=dir::$(pip cache dir)"
- name: Cache pip
uses: actions/cache@v3
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Checkout test blocks and plots
uses: actions/checkout@v3
with:
repository: 'Chia-Network/test-cache'
path: '.chia'
ref: '0.28.0'
fetch-depth: 1
- name: Run install script
env:
INSTALL_PYTHON_VERSION: ${{ matrix.python-version }}
run: |
sh install.sh -d
# Omitted installing Timelord
- name: Test plot_sync code with pytest
run: |
. ./activate
venv/bin/coverage run --rcfile=.coveragerc ./venv/bin/py.test tests/plot_sync/test_*.py --durations=10 -n 4 -m "not benchmark" -p no:monitor
- name: Process coverage data
run: |
venv/bin/coverage combine --rcfile=.coveragerc .coverage.*
venv/bin/coverage xml --rcfile=.coveragerc -o coverage.xml
mkdir coverage_reports
cp .coverage "coverage_reports/.coverage.${{ env.JOB_FILE_NAME }}"
cp coverage.xml "coverage_reports/coverage.${{ env.JOB_FILE_NAME }}.xml"
venv/bin/coverage report --rcfile=.coveragerc --show-missing
- name: Publish coverage
uses: actions/upload-artifact@v2
with:
name: coverage
path: coverage_reports/*
if-no-files-found: error
# Omitted resource usage check
#
# THIS FILE IS GENERATED. SEE https://github.com/Chia-Network/chia-blockchain/tree/main/tests#readme
#


@@ -18,6 +18,8 @@ from chia.daemon.keychain_proxy import (
connect_to_keychain_and_validate,
wrap_local_keychain,
)
from chia.plot_sync.receiver import Receiver
from chia.plot_sync.delta import Delta
from chia.pools.pool_config import PoolWalletConfig, load_pool_config, add_auth_key
from chia.protocols import farmer_protocol, harvester_protocol
from chia.protocols.pool_protocol import (
@@ -59,32 +61,12 @@ log = logging.getLogger(__name__)
UPDATE_POOL_INFO_INTERVAL: int = 3600
UPDATE_POOL_FARMER_INFO_INTERVAL: int = 300
UPDATE_HARVESTER_CACHE_INTERVAL: int = 90
"""
HARVESTER PROTOCOL (FARMER <-> HARVESTER)
"""
class HarvesterCacheEntry:
def __init__(self):
self.data: Optional[dict] = None
self.last_update: float = 0
def bump_last_update(self):
self.last_update = time.time()
def set_data(self, data):
self.data = data
self.bump_last_update()
def needs_update(self, update_interval: int):
return time.time() - self.last_update > update_interval
def expired(self, update_interval: int):
return time.time() - self.last_update > update_interval * 10
class Farmer:
def __init__(
self,
@@ -115,8 +97,7 @@ class Farmer:
# to periodically clear the memory
self.cache_add_time: Dict[bytes32, uint64] = {}
# Interval to request plots from connected harvesters
self.update_harvester_cache_interval = UPDATE_HARVESTER_CACHE_INTERVAL
self.plot_sync_receivers: Dict[bytes32, Receiver] = {}
self.cache_clear_task: Optional[asyncio.Task] = None
self.update_pool_state_task: Optional[asyncio.Task] = None
@@ -137,8 +118,6 @@ class Farmer:
# Last time we updated pool_state based on the config file
self.last_config_access_time: uint64 = uint64(0)
self.harvester_cache: Dict[str, Dict[str, HarvesterCacheEntry]] = {}
async def ensure_keychain_proxy(self) -> KeychainProxy:
if self.keychain_proxy is None:
if self.local_keychain:
@@ -256,6 +235,7 @@ class Farmer:
self.harvester_handshake_task = None
if peer.connection_type is NodeType.HARVESTER:
self.plot_sync_receivers[peer.peer_node_id] = Receiver(peer, self.plot_sync_callback)
self.harvester_handshake_task = asyncio.create_task(handshake_task())
def set_server(self, server):
@@ -274,6 +254,13 @@ class Farmer:
def on_disconnect(self, connection: ws.WSChiaConnection):
self.log.info(f"peer disconnected {connection.get_peer_logging()}")
self.state_changed("close_connection", {})
if connection.connection_type is NodeType.HARVESTER:
del self.plot_sync_receivers[connection.peer_node_id]
async def plot_sync_callback(self, peer_id: bytes32, delta: Delta) -> None:
log.info(f"plot_sync_callback: peer_id {peer_id}, delta {delta}")
if not delta.empty():
self.state_changed("new_plots", await self.get_harvesters())
async def _pool_get_pool_info(self, pool_config: PoolWalletConfig) -> Optional[Dict]:
try:
@@ -642,80 +629,17 @@ class Farmer:
return None
async def update_cached_harvesters(self) -> bool:
# First remove outdated cache entries
self.log.debug(f"update_cached_harvesters cache entries: {len(self.harvester_cache)}")
remove_hosts = []
for host, host_cache in self.harvester_cache.items():
remove_peers = []
for peer_id, peer_cache in host_cache.items():
# If the peer cache is expired it means the harvester didn't respond for too long
if peer_cache.expired(self.update_harvester_cache_interval):
remove_peers.append(peer_id)
for key in remove_peers:
del host_cache[key]
if len(host_cache) == 0:
self.log.debug(f"update_cached_harvesters remove host: {host}")
remove_hosts.append(host)
for key in remove_hosts:
del self.harvester_cache[key]
# Now query each harvester and update caches
updated = False
for connection in self.server.get_connections(NodeType.HARVESTER):
cache_entry = await self.get_cached_harvesters(connection)
if cache_entry.needs_update(self.update_harvester_cache_interval):
self.log.debug(f"update_cached_harvesters update harvester: {connection.peer_node_id}")
cache_entry.bump_last_update()
response = await connection.request_plots(
harvester_protocol.RequestPlots(), timeout=self.update_harvester_cache_interval
)
if response is not None:
if isinstance(response, harvester_protocol.RespondPlots):
new_data: Dict = response.to_json_dict()
if cache_entry.data != new_data:
updated = True
self.log.debug(f"update_cached_harvesters cache updated: {connection.peer_node_id}")
else:
self.log.debug(f"update_cached_harvesters no changes for: {connection.peer_node_id}")
cache_entry.set_data(new_data)
else:
self.log.error(
f"Invalid response from harvester:"
f"peer_host {connection.peer_host}, peer_node_id {connection.peer_node_id}"
)
else:
self.log.error(
f"Harvester '{connection.peer_host}/{connection.peer_node_id}' did not respond: "
f"(version mismatch or time out {UPDATE_HARVESTER_CACHE_INTERVAL}s)"
)
return updated
async def get_cached_harvesters(self, connection: WSChiaConnection) -> HarvesterCacheEntry:
host_cache = self.harvester_cache.get(connection.peer_host)
if host_cache is None:
host_cache = {}
self.harvester_cache[connection.peer_host] = host_cache
node_cache = host_cache.get(connection.peer_node_id.hex())
if node_cache is None:
node_cache = HarvesterCacheEntry()
host_cache[connection.peer_node_id.hex()] = node_cache
return node_cache
async def get_harvesters(self) -> Dict:
harvesters: List = []
for connection in self.server.get_connections(NodeType.HARVESTER):
self.log.debug(f"get_harvesters host: {connection.peer_host}, node_id: {connection.peer_node_id}")
cache_entry = await self.get_cached_harvesters(connection)
if cache_entry.data is not None:
harvester_object: dict = dict(cache_entry.data)
harvester_object["connection"] = {
"node_id": connection.peer_node_id.hex(),
"host": connection.peer_host,
"port": connection.peer_port,
}
harvesters.append(harvester_object)
receiver = self.plot_sync_receivers.get(connection.peer_node_id)
if receiver is not None:
harvesters.append(receiver.to_dict())
else:
self.log.debug(f"get_harvesters no cache: {connection.peer_host}, node_id: {connection.peer_node_id}")
self.log.debug(
f"get_harvesters invalid peer: {connection.peer_host}, node_id: {connection.peer_node_id}"
)
return {"harvesters": harvesters}
@@ -766,9 +690,6 @@ class Farmer:
self.state_changed("add_connection", {})
refresh_slept = 0
# Handles harvester plots cache cleanup and updates
if await self.update_cached_harvesters():
self.state_changed("new_plots", await self.get_harvesters())
except Exception:
log.error(f"_periodically_clear_cache_and_refresh_task failed: {traceback.format_exc()}")


@@ -11,7 +11,13 @@ from chia.consensus.network_type import NetworkType
from chia.consensus.pot_iterations import calculate_iterations_quality, calculate_sp_interval_iters
from chia.farmer.farmer import Farmer
from chia.protocols import farmer_protocol, harvester_protocol
from chia.protocols.harvester_protocol import PoolDifficulty
from chia.protocols.harvester_protocol import (
PoolDifficulty,
PlotSyncStart,
PlotSyncPlotList,
PlotSyncPathList,
PlotSyncDone,
)
from chia.protocols.pool_protocol import (
get_current_authentication_token,
PoolErrorCode,
@@ -518,3 +524,38 @@ class FarmerAPI:
@peer_required
async def respond_plots(self, _: harvester_protocol.RespondPlots, peer: ws.WSChiaConnection):
self.farmer.log.warning(f"Respond plots came too late from: {peer.get_peer_logging()}")
@api_request
@peer_required
async def plot_sync_start(self, message: PlotSyncStart, peer: ws.WSChiaConnection):
await self.farmer.plot_sync_receivers[peer.peer_node_id].sync_started(message)
@api_request
@peer_required
async def plot_sync_loaded(self, message: PlotSyncPlotList, peer: ws.WSChiaConnection):
await self.farmer.plot_sync_receivers[peer.peer_node_id].process_loaded(message)
@api_request
@peer_required
async def plot_sync_removed(self, message: PlotSyncPathList, peer: ws.WSChiaConnection):
await self.farmer.plot_sync_receivers[peer.peer_node_id].process_removed(message)
@api_request
@peer_required
async def plot_sync_invalid(self, message: PlotSyncPathList, peer: ws.WSChiaConnection):
await self.farmer.plot_sync_receivers[peer.peer_node_id].process_invalid(message)
@api_request
@peer_required
async def plot_sync_keys_missing(self, message: PlotSyncPathList, peer: ws.WSChiaConnection):
await self.farmer.plot_sync_receivers[peer.peer_node_id].process_keys_missing(message)
@api_request
@peer_required
async def plot_sync_duplicates(self, message: PlotSyncPathList, peer: ws.WSChiaConnection):
await self.farmer.plot_sync_receivers[peer.peer_node_id].process_duplicates(message)
@api_request
@peer_required
async def plot_sync_done(self, message: PlotSyncDone, peer: ws.WSChiaConnection):
await self.farmer.plot_sync_receivers[peer.peer_node_id].sync_done(message)
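Each handler above follows the same pattern: resolve the per-peer `Receiver` from `plot_sync_receivers` by `peer_node_id` and forward the message to it. A reduced sketch of that routing (stand-in stub classes, not the real chia types):

```python
import asyncio
from typing import Dict, List

class StubReceiver:
    """Stand-in for chia.plot_sync.receiver.Receiver; records what it processes."""
    def __init__(self) -> None:
        self.messages: List[str] = []

    async def process_loaded(self, message: str) -> None:
        self.messages.append(message)

class FarmerAPIStub:
    """Reduced sketch of the per-peer routing used by the plot_sync_* handlers."""
    def __init__(self) -> None:
        self.plot_sync_receivers: Dict[bytes, StubReceiver] = {}

    async def plot_sync_loaded(self, message: str, peer_id: bytes) -> None:
        # Look up the receiver registered for this peer and forward the message.
        await self.plot_sync_receivers[peer_id].process_loaded(message)

api = FarmerAPIStub()
api.plot_sync_receivers[b"peer-1"] = StubReceiver()
asyncio.run(api.plot_sync_loaded("batch-0", b"peer-1"))
```

Registration happens on connect (`on_connect` adds a `Receiver`) and the entry is deleted again in `on_disconnect`, so a lookup failure would indicate a message from an unregistered peer.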


@@ -8,6 +8,7 @@ from typing import Callable, Dict, List, Optional, Tuple
import chia.server.ws_connection as ws # lgtm [py/import-and-import-from]
from chia.consensus.constants import ConsensusConstants
from chia.plot_sync.sender import Sender
from chia.plotting.manager import PlotManager
from chia.plotting.util import (
add_plot_directory,
@@ -25,6 +26,7 @@ log = logging.getLogger(__name__)
class Harvester:
plot_manager: PlotManager
plot_sync_sender: Sender
root_path: Path
_is_shutdown: bool
executor: ThreadPoolExecutor
@@ -53,6 +55,7 @@ class Harvester:
self.plot_manager = PlotManager(
root_path, refresh_parameter=refresh_parameter, refresh_callback=self._plot_refresh_callback
)
self.plot_sync_sender = Sender(self.plot_manager)
self._is_shutdown = False
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=config["num_threads"])
self.state_changed_callback = None
@@ -70,9 +73,11 @@ class Harvester:
self._is_shutdown = True
self.executor.shutdown(wait=True)
self.plot_manager.stop_refreshing()
self.plot_manager.reset()
self.plot_sync_sender.stop()
async def _await_closed(self):
pass
await self.plot_sync_sender.await_closed()
def _set_state_changed_callback(self, callback: Callable):
self.state_changed_callback = callback
@@ -90,12 +95,18 @@ class Harvester:
f"duration: {update_result.duration:.2f} seconds, "
f"total plots: {len(self.plot_manager.plots)}"
)
if len(update_result.loaded) > 0:
self.event_loop.call_soon_threadsafe(self._state_changed, "plots")
if event == PlotRefreshEvents.started:
self.plot_sync_sender.sync_start(update_result.remaining, self.plot_manager.initial_refresh())
if event == PlotRefreshEvents.batch_processed:
self.plot_sync_sender.process_batch(update_result.loaded, update_result.remaining)
if event == PlotRefreshEvents.done:
self.plot_sync_sender.sync_done(update_result.removed, update_result.duration)
def on_disconnect(self, connection: ws.WSChiaConnection):
self.log.info(f"peer disconnected {connection.get_peer_logging()}")
self._state_changed("close_connection")
self.plot_manager.stop_refreshing()
self.plot_sync_sender.stop()
def get_plots(self) -> Tuple[List[Dict], List[str], List[str]]:
self.log.debug(f"get_plots prover items: {self.plot_manager.plot_count()}")


@@ -10,7 +10,7 @@ from chia.harvester.harvester import Harvester
from chia.plotting.util import PlotInfo, parse_plot_info
from chia.protocols import harvester_protocol
from chia.protocols.farmer_protocol import FarmingInfo
from chia.protocols.harvester_protocol import Plot
from chia.protocols.harvester_protocol import Plot, PlotSyncResponse
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.server.outbound_message import make_msg
from chia.server.ws_connection import WSChiaConnection
@ -30,8 +30,11 @@ class HarvesterAPI:
def _set_state_changed_callback(self, callback: Callable):
self.harvester.state_changed_callback = callback
@peer_required
@api_request
async def harvester_handshake(self, harvester_handshake: harvester_protocol.HarvesterHandshake):
async def harvester_handshake(
self, harvester_handshake: harvester_protocol.HarvesterHandshake, peer: WSChiaConnection
):
"""
Handshake between the harvester and farmer. The harvester receives the pool public keys,
as well as the farmer pks, which must be put into the plots, before the plotting process begins.
@@ -40,7 +43,8 @@
self.harvester.plot_manager.set_public_keys(
harvester_handshake.farmer_public_keys, harvester_handshake.pool_public_keys
)
self.harvester.plot_sync_sender.set_connection(peer)
await self.harvester.plot_sync_sender.start()
self.harvester.plot_manager.start_refreshing()
@peer_required
@@ -289,3 +293,7 @@
response = harvester_protocol.RespondPlots(plots_response, failed_to_open_filenames, no_key_filenames)
return make_msg(ProtocolMessageTypes.respond_plots, response)
@api_request
async def plot_sync_response(self, response: PlotSyncResponse):
self.harvester.plot_sync_sender.set_response(response)


chia/plot_sync/delta.py (new file, 59 lines)

@@ -0,0 +1,59 @@
from dataclasses import dataclass, field
from typing import Dict, List, Union
from chia.protocols.harvester_protocol import Plot
@dataclass
class DeltaType:
additions: Union[Dict[str, Plot], List[str]]
removals: List[str]
def __str__(self) -> str:
return f"+{len(self.additions)}/-{len(self.removals)}"
def clear(self) -> None:
self.additions.clear()
self.removals.clear()
def empty(self) -> bool:
return len(self.additions) == 0 and len(self.removals) == 0
@dataclass
class PlotListDelta(DeltaType):
additions: Dict[str, Plot] = field(default_factory=dict)
removals: List[str] = field(default_factory=list)
@dataclass
class PathListDelta(DeltaType):
additions: List[str] = field(default_factory=list)
removals: List[str] = field(default_factory=list)
@staticmethod
def from_lists(old: List[str], new: List[str]) -> "PathListDelta":
return PathListDelta([x for x in new if x not in old], [x for x in old if x not in new])
@dataclass
class Delta:
valid: PlotListDelta = field(default_factory=PlotListDelta)
invalid: PathListDelta = field(default_factory=PathListDelta)
keys_missing: PathListDelta = field(default_factory=PathListDelta)
duplicates: PathListDelta = field(default_factory=PathListDelta)
def empty(self) -> bool:
return self.valid.empty() and self.invalid.empty() and self.keys_missing.empty() and self.duplicates.empty()
def __str__(self) -> str:
return (
f"valid {self.valid}, invalid {self.invalid}, keys missing: {self.keys_missing}, "
f"duplicates: {self.duplicates}"
)
def clear(self) -> None:
self.valid.clear()
self.invalid.clear()
self.keys_missing.clear()
self.duplicates.clear()
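`PathListDelta.from_lists` derives additions and removals purely by membership checks, so input ordering doesn't matter. A self-contained copy of the same logic, runnable outside the chia tree:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class PathListDelta:
    additions: List[str] = field(default_factory=list)
    removals: List[str] = field(default_factory=list)

    @staticmethod
    def from_lists(old: List[str], new: List[str]) -> "PathListDelta":
        # Entries only in `new` are additions; entries only in `old` are removals.
        return PathListDelta([x for x in new if x not in old], [x for x in old if x not in new])

delta = PathListDelta.from_lists(["a.plot", "b.plot"], ["b.plot", "c.plot"])
```

This is how `_sync_done` in the receiver turns the whole-list invalid/keys-missing/duplicates payloads into the diff handed to the farmer's update callback.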


@@ -0,0 +1,54 @@
from typing import Any
from chia.plot_sync.util import ErrorCodes, State
from chia.protocols.harvester_protocol import PlotSyncIdentifier
from chia.server.ws_connection import NodeType
from chia.util.ints import uint64
class PlotSyncException(Exception):
def __init__(self, message: str, error_code: ErrorCodes) -> None:
super().__init__(message)
self.error_code = error_code
class AlreadyStartedError(Exception):
def __init__(self) -> None:
super().__init__("Already started!")
class InvalidValueError(PlotSyncException):
def __init__(self, message: str, actual: Any, expected: Any, error_code: ErrorCodes) -> None:
super().__init__(f"{message}: Actual {actual}, Expected {expected}", error_code)
class InvalidIdentifierError(InvalidValueError):
def __init__(self, actual_identifier: PlotSyncIdentifier, expected_identifier: PlotSyncIdentifier) -> None:
super().__init__("Invalid identifier", actual_identifier, expected_identifier, ErrorCodes.invalid_identifier)
self.actual_identifier: PlotSyncIdentifier = actual_identifier
self.expected_identifier: PlotSyncIdentifier = expected_identifier
class InvalidLastSyncIdError(InvalidValueError):
def __init__(self, actual: uint64, expected: uint64) -> None:
super().__init__("Invalid last-sync-id", actual, expected, ErrorCodes.invalid_last_sync_id)
class InvalidConnectionTypeError(InvalidValueError):
def __init__(self, actual: NodeType, expected: NodeType) -> None:
super().__init__("Unexpected connection type", actual, expected, ErrorCodes.invalid_connection_type)
class PlotAlreadyAvailableError(PlotSyncException):
def __init__(self, state: State, path: str) -> None:
super().__init__(f"{state.name}: Plot already available - {path}", ErrorCodes.plot_already_available)
class PlotNotAvailableError(PlotSyncException):
def __init__(self, state: State, path: str) -> None:
super().__init__(f"{state.name}: Plot not available - {path}", ErrorCodes.plot_not_available)
class SyncIdsMatchError(PlotSyncException):
def __init__(self, state: State, sync_id: uint64) -> None:
super().__init__(f"{state.name}: Sync ids are equal - {sync_id}", ErrorCodes.sync_ids_match)
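These exception types let `Receiver._process` map validation failures to protocol error codes in one `except` chain: known protocol errors keep their own code, anything else becomes `unknown`. A minimal sketch of that pattern (the `ErrorCodes` values here are illustrative; the real enum lives in `chia.plot_sync.util`):

```python
from enum import IntEnum

class ErrorCodes(IntEnum):
    # Illustrative values only; see chia.plot_sync.util for the real enum.
    unknown = -1
    invalid_last_sync_id = 2

class PlotSyncException(Exception):
    def __init__(self, message: str, error_code: ErrorCodes) -> None:
        super().__init__(message)
        self.error_code = error_code

class InvalidLastSyncIdError(PlotSyncException):
    def __init__(self, actual: int, expected: int) -> None:
        super().__init__(
            f"Invalid last-sync-id: Actual {actual}, Expected {expected}",
            ErrorCodes.invalid_last_sync_id,
        )

def classify(exc: Exception) -> ErrorCodes:
    # Known protocol errors carry their own code; anything else maps to `unknown`.
    if isinstance(exc, PlotSyncException):
        return exc.error_code
    return ErrorCodes.unknown
```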

chia/plot_sync/receiver.py (new file, 304 lines)

@@ -0,0 +1,304 @@
import logging
import time
from typing import Any, Callable, Collection, Coroutine, Dict, List, Optional
from chia.plot_sync.delta import Delta, PathListDelta, PlotListDelta
from chia.plot_sync.exceptions import (
InvalidIdentifierError,
InvalidLastSyncIdError,
PlotAlreadyAvailableError,
PlotNotAvailableError,
PlotSyncException,
SyncIdsMatchError,
)
from chia.plot_sync.util import ErrorCodes, State
from chia.protocols.harvester_protocol import (
Plot,
PlotSyncDone,
PlotSyncError,
PlotSyncIdentifier,
PlotSyncPathList,
PlotSyncPlotList,
PlotSyncResponse,
PlotSyncStart,
)
from chia.server.ws_connection import ProtocolMessageTypes, WSChiaConnection, make_msg
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import int16, uint64
from chia.util.streamable import _T_Streamable
log = logging.getLogger(__name__)
class Receiver:
_connection: WSChiaConnection
_sync_state: State
_delta: Delta
_expected_sync_id: uint64
_expected_message_id: uint64
_last_sync_id: uint64
_last_sync_time: float
_plots: Dict[str, Plot]
_invalid: List[str]
_keys_missing: List[str]
_duplicates: List[str]
_update_callback: Callable[[bytes32, Delta], Coroutine[Any, Any, None]]
def __init__(
self, connection: WSChiaConnection, update_callback: Callable[[bytes32, Delta], Coroutine[Any, Any, None]]
) -> None:
self._connection = connection
self._sync_state = State.idle
self._delta = Delta()
self._expected_sync_id = uint64(0)
self._expected_message_id = uint64(0)
self._last_sync_id = uint64(0)
self._last_sync_time = 0
self._plots = {}
self._invalid = []
self._keys_missing = []
self._duplicates = []
self._update_callback = update_callback # type: ignore[assignment, misc]
def reset(self) -> None:
self._sync_state = State.idle
self._expected_sync_id = uint64(0)
self._expected_message_id = uint64(0)
self._last_sync_id = uint64(0)
self._last_sync_time = 0
self._plots.clear()
self._invalid.clear()
self._keys_missing.clear()
self._duplicates.clear()
self._delta.clear()
def bump_expected_message_id(self) -> None:
self._expected_message_id = uint64(self._expected_message_id + 1)
def connection(self) -> WSChiaConnection:
return self._connection
def state(self) -> State:
return self._sync_state
def expected_sync_id(self) -> uint64:
return self._expected_sync_id
def expected_message_id(self) -> uint64:
return self._expected_message_id
def last_sync_id(self) -> uint64:
return self._last_sync_id
def last_sync_time(self) -> float:
return self._last_sync_time
def plots(self) -> Dict[str, Plot]:
return self._plots
def invalid(self) -> List[str]:
return self._invalid
def keys_missing(self) -> List[str]:
return self._keys_missing
def duplicates(self) -> List[str]:
return self._duplicates
async def _process(
self, method: Callable[[_T_Streamable], Any], message_type: ProtocolMessageTypes, message: Any
) -> None:
async def send_response(plot_sync_error: Optional[PlotSyncError] = None) -> None:
if self._connection is not None:
await self._connection.send_message(
make_msg(
ProtocolMessageTypes.plot_sync_response,
PlotSyncResponse(message.identifier, int16(message_type.value), plot_sync_error),
)
)
try:
await method(message)
await send_response()
except InvalidIdentifierError as e:
log.warning(f"_process: InvalidIdentifierError {e}")
await send_response(PlotSyncError(int16(e.error_code), f"{e}", e.expected_identifier))
except PlotSyncException as e:
log.warning(f"_process: Error {e}")
await send_response(PlotSyncError(int16(e.error_code), f"{e}", None))
except Exception as e:
log.warning(f"_process: Exception {e}")
await send_response(PlotSyncError(int16(ErrorCodes.unknown), f"{e}", None))
def _validate_identifier(self, identifier: PlotSyncIdentifier, start: bool = False) -> None:
sync_id_match = identifier.sync_id == self._expected_sync_id
message_id_match = identifier.message_id == self._expected_message_id
identifier_match = sync_id_match and message_id_match
if (start and not message_id_match) or (not start and not identifier_match):
expected: PlotSyncIdentifier = PlotSyncIdentifier(
identifier.timestamp, self._expected_sync_id, self._expected_message_id
)
raise InvalidIdentifierError(
identifier,
expected,
)
async def _sync_started(self, data: PlotSyncStart) -> None:
if data.initial:
self.reset()
self._validate_identifier(data.identifier, True)
if data.last_sync_id != self.last_sync_id():
raise InvalidLastSyncIdError(data.last_sync_id, self.last_sync_id())
if data.last_sync_id == data.identifier.sync_id:
raise SyncIdsMatchError(State.idle, data.last_sync_id)
self._expected_sync_id = data.identifier.sync_id
self._delta.clear()
self._sync_state = State.loaded
self.bump_expected_message_id()
async def sync_started(self, data: PlotSyncStart) -> None:
await self._process(self._sync_started, ProtocolMessageTypes.plot_sync_start, data)
async def _process_loaded(self, plot_infos: PlotSyncPlotList) -> None:
self._validate_identifier(plot_infos.identifier)
for plot_info in plot_infos.data:
if plot_info.filename in self._plots or plot_info.filename in self._delta.valid.additions:
raise PlotAlreadyAvailableError(State.loaded, plot_info.filename)
self._delta.valid.additions[plot_info.filename] = plot_info
if plot_infos.final:
self._sync_state = State.removed
self.bump_expected_message_id()
async def process_loaded(self, plot_infos: PlotSyncPlotList) -> None:
await self._process(self._process_loaded, ProtocolMessageTypes.plot_sync_loaded, plot_infos)
async def process_path_list(
self,
*,
state: State,
next_state: State,
target: Collection[str],
delta: List[str],
paths: PlotSyncPathList,
is_removal: bool = False,
) -> None:
self._validate_identifier(paths.identifier)
for path in paths.data:
if is_removal and (path not in target or path in delta):
raise PlotNotAvailableError(state, path)
if not is_removal and path in delta:
raise PlotAlreadyAvailableError(state, path)
delta.append(path)
if paths.final:
self._sync_state = next_state
self.bump_expected_message_id()
async def _process_removed(self, paths: PlotSyncPathList) -> None:
await self.process_path_list(
state=State.removed,
next_state=State.invalid,
target=self._plots,
delta=self._delta.valid.removals,
paths=paths,
is_removal=True,
)
async def process_removed(self, paths: PlotSyncPathList) -> None:
await self._process(self._process_removed, ProtocolMessageTypes.plot_sync_removed, paths)
async def _process_invalid(self, paths: PlotSyncPathList) -> None:
await self.process_path_list(
state=State.invalid,
next_state=State.keys_missing,
target=self._invalid,
delta=self._delta.invalid.additions,
paths=paths,
)
async def process_invalid(self, paths: PlotSyncPathList) -> None:
await self._process(self._process_invalid, ProtocolMessageTypes.plot_sync_invalid, paths)
async def _process_keys_missing(self, paths: PlotSyncPathList) -> None:
await self.process_path_list(
state=State.keys_missing,
next_state=State.duplicates,
target=self._keys_missing,
delta=self._delta.keys_missing.additions,
paths=paths,
)
async def process_keys_missing(self, paths: PlotSyncPathList) -> None:
await self._process(self._process_keys_missing, ProtocolMessageTypes.plot_sync_keys_missing, paths)
async def _process_duplicates(self, paths: PlotSyncPathList) -> None:
await self.process_path_list(
state=State.duplicates,
next_state=State.done,
target=self._duplicates,
delta=self._delta.duplicates.additions,
paths=paths,
)
async def process_duplicates(self, paths: PlotSyncPathList) -> None:
await self._process(self._process_duplicates, ProtocolMessageTypes.plot_sync_duplicates, paths)
async def _sync_done(self, data: PlotSyncDone) -> None:
self._validate_identifier(data.identifier)
# Update ids
self._last_sync_id = self._expected_sync_id
self._expected_sync_id = uint64(0)
self._expected_message_id = uint64(0)
# First create the update delta (i.e. transform invalid/keys_missing into additions/removals) which we will
# send to the callback receiver below
delta_invalid: PathListDelta = PathListDelta.from_lists(self._invalid, self._delta.invalid.additions)
delta_keys_missing: PathListDelta = PathListDelta.from_lists(
self._keys_missing, self._delta.keys_missing.additions
)
delta_duplicates: PathListDelta = PathListDelta.from_lists(self._duplicates, self._delta.duplicates.additions)
update = Delta(
PlotListDelta(self._delta.valid.additions.copy(), self._delta.valid.removals.copy()),
delta_invalid,
delta_keys_missing,
delta_duplicates,
)
# Apply delta
self._plots.update(self._delta.valid.additions)
for removal in self._delta.valid.removals:
del self._plots[removal]
self._invalid = self._delta.invalid.additions.copy()
self._keys_missing = self._delta.keys_missing.additions.copy()
self._duplicates = self._delta.duplicates.additions.copy()
# Update state and bump last sync time
self._sync_state = State.idle
self._last_sync_time = time.time()
# Let the callback receiver know if this sync cycle caused any update
try:
await self._update_callback(self._connection.peer_node_id, update) # type: ignore[misc,call-arg]
except Exception as e:
log.error(f"_update_callback raised: {e}")
self._delta.clear()
async def sync_done(self, data: PlotSyncDone) -> None:
await self._process(self._sync_done, ProtocolMessageTypes.plot_sync_done, data)
def to_dict(self) -> Dict[str, Any]:
result: Dict[str, Any] = {
"connection": {
"node_id": self._connection.peer_node_id,
"host": self._connection.peer_host,
"port": self._connection.peer_port,
},
"plots": list(self._plots.values()),
"failed_to_open_filenames": self._invalid,
"no_key_filenames": self._keys_missing,
"duplicates": self._duplicates,
}
if self._last_sync_time != 0:
result["last_sync_time"] = self._last_sync_time
return result
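The per-path validation in `Receiver.process_path_list` above can be sketched in isolation. This is a hypothetical, simplified stand-in that returns an error string instead of raising `PlotNotAvailableError`/`PlotAlreadyAvailableError`; the names `validate_path` and its parameters are illustrative only:

```python
from typing import Dict, List

# Hypothetical, simplified version of the checks in Receiver.process_path_list:
# a removal must refer to a currently known plot that is not already queued in
# the delta, while an addition must not already be queued in the delta.
def validate_path(path: str, target: Dict[str, object], delta: List[str], is_removal: bool) -> str:
    if is_removal and (path not in target or path in delta):
        return "plot_not_available"
    if not is_removal and path in delta:
        return "plot_already_available"
    delta.append(path)
    return "ok"
```

Valid paths get appended to the delta, mirroring how the receiver accumulates batches until a message arrives with `paths.final` set.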

chia/plot_sync/sender.py (new file)

@ -0,0 +1,327 @@
import asyncio
import logging
import threading
import time
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Generic, Iterable, List, Optional, Tuple, Type, TypeVar
from typing_extensions import Protocol
from chia.plot_sync.exceptions import AlreadyStartedError, InvalidConnectionTypeError
from chia.plot_sync.util import Constants
from chia.plotting.manager import PlotManager
from chia.plotting.util import PlotInfo
from chia.protocols.harvester_protocol import (
Plot,
PlotSyncDone,
PlotSyncIdentifier,
PlotSyncPathList,
PlotSyncPlotList,
PlotSyncResponse,
PlotSyncStart,
)
from chia.server.ws_connection import NodeType, ProtocolMessageTypes, WSChiaConnection, make_msg
from chia.util.generator_tools import list_to_batches
from chia.util.ints import int16, uint32, uint64
log = logging.getLogger(__name__)
def _convert_plot_info_list(plot_infos: List[PlotInfo]) -> List[Plot]:
converted: List[Plot] = []
for plot_info in plot_infos:
converted.append(
Plot(
filename=plot_info.prover.get_filename(),
size=plot_info.prover.get_size(),
plot_id=plot_info.prover.get_id(),
pool_public_key=plot_info.pool_public_key,
pool_contract_puzzle_hash=plot_info.pool_contract_puzzle_hash,
plot_public_key=plot_info.plot_public_key,
file_size=uint64(plot_info.file_size),
time_modified=uint64(int(plot_info.time_modified)),
)
)
return converted
class PayloadType(Protocol):
def __init__(self, identifier: PlotSyncIdentifier, *args: object) -> None:
...
T = TypeVar("T", bound=PayloadType)
@dataclass
class MessageGenerator(Generic[T]):
sync_id: uint64
message_type: ProtocolMessageTypes
message_id: uint64
payload_type: Type[T]
args: Iterable[object]
def generate(self) -> Tuple[PlotSyncIdentifier, T]:
identifier = PlotSyncIdentifier(uint64(int(time.time())), self.sync_id, self.message_id)
payload = self.payload_type(identifier, *self.args)
return identifier, payload
@dataclass
class ExpectedResponse:
message_type: ProtocolMessageTypes
identifier: PlotSyncIdentifier
message: Optional[PlotSyncResponse] = None
def __str__(self) -> str:
return (
f"expected_message_type: {self.message_type.name}, "
f"expected_identifier: {self.identifier}, message {self.message}"
)
class Sender:
_plot_manager: PlotManager
_connection: Optional[WSChiaConnection]
_sync_id: uint64
_next_message_id: uint64
_messages: List[MessageGenerator[PayloadType]]
_last_sync_id: uint64
_stop_requested = False
_task: Optional[asyncio.Task] # type: ignore[type-arg] # Asks for Task parameter which doesn't work
_lock: threading.Lock
_response: Optional[ExpectedResponse]
def __init__(self, plot_manager: PlotManager) -> None:
self._plot_manager = plot_manager
self._connection = None
self._sync_id = uint64(0)
self._next_message_id = uint64(0)
self._messages = []
self._last_sync_id = uint64(0)
self._stop_requested = False
self._task = None
self._lock = threading.Lock()
self._response = None
def __str__(self) -> str:
return f"sync_id {self._sync_id}, next_message_id {self._next_message_id}, messages {len(self._messages)}"
async def start(self) -> None:
if self._task is not None and self._stop_requested:
await self.await_closed()
if self._task is None:
self._task = asyncio.create_task(self._run())
# TODO, Add typing in PlotManager
if not self._plot_manager.initial_refresh() or self._sync_id != 0: # type:ignore[no-untyped-call]
self._reset()
else:
raise AlreadyStartedError()
def stop(self) -> None:
self._stop_requested = True
async def await_closed(self) -> None:
if self._task is not None:
await self._task
self._task = None
self._reset()
self._stop_requested = False
def set_connection(self, connection: WSChiaConnection) -> None:
assert connection.connection_type is not None
if connection.connection_type != NodeType.FARMER:
raise InvalidConnectionTypeError(connection.connection_type, NodeType.HARVESTER)
self._connection = connection
def bump_next_message_id(self) -> None:
self._next_message_id = uint64(self._next_message_id + 1)
def _reset(self) -> None:
log.debug(f"_reset {self}")
self._last_sync_id = uint64(0)
self._sync_id = uint64(0)
self._next_message_id = uint64(0)
self._messages.clear()
if self._lock.locked():
self._lock.release()
if self._task is not None:
# TODO, Add typing in PlotManager
self.sync_start(self._plot_manager.plot_count(), True) # type:ignore[no-untyped-call]
for remaining, batch in list_to_batches(
list(self._plot_manager.plots.values()), self._plot_manager.refresh_parameter.batch_size
):
self.process_batch(batch, remaining)
self.sync_done([], 0)
async def _wait_for_response(self) -> bool:
start = time.time()
assert self._response is not None
while time.time() - start < Constants.message_timeout and self._response.message is None:
await asyncio.sleep(0.1)
return self._response.message is not None
def set_response(self, response: PlotSyncResponse) -> bool:
if self._response is None or self._response.message is not None:
log.warning(f"set_response skip unexpected response: {response}")
return False
if time.time() - float(response.identifier.timestamp) > Constants.message_timeout:
log.warning(f"set_response skip expired response: {response}")
return False
if response.identifier.sync_id != self._response.identifier.sync_id:
log.warning(
"set_response unexpected sync-id: " f"{response.identifier.sync_id}/{self._response.identifier.sync_id}"
)
return False
if response.identifier.message_id != self._response.identifier.message_id:
log.warning(
"set_response unexpected message-id: "
f"{response.identifier.message_id}/{self._response.identifier.message_id}"
)
return False
if response.message_type != int16(self._response.message_type.value):
log.warning(
"set_response unexpected message-type: " f"{response.message_type}/{self._response.message_type.value}"
)
return False
log.debug(f"set_response valid {response}")
self._response.message = response
return True
def _add_message(self, message_type: ProtocolMessageTypes, payload_type: Any, *args: Any) -> None:
assert self._sync_id != 0
message_id = uint64(len(self._messages))
self._messages.append(MessageGenerator(self._sync_id, message_type, message_id, payload_type, args))
async def _send_next_message(self) -> bool:
def failed(message: str) -> bool:
# By forcing a reset we try to get back into a normal state if some non-recoverable failure came up.
log.warning(message)
self._reset()
return False
assert len(self._messages) >= self._next_message_id
message_generator = self._messages[self._next_message_id]
identifier, payload = message_generator.generate()
if self._sync_id == 0 or identifier.sync_id != self._sync_id or identifier.message_id != self._next_message_id:
return failed(f"Invalid message generator {message_generator} for {self}")
self._response = ExpectedResponse(message_generator.message_type, identifier)
log.debug(f"_send_next_message send {message_generator.message_type.name}: {payload}")
if self._connection is None or not await self._connection.send_message(
make_msg(message_generator.message_type, payload)
):
return failed(f"Send failed {self._connection}")
if not await self._wait_for_response():
log.info(f"_send_next_message didn't receive response {self._response}")
return False
assert self._response.message is not None
if self._response.message.error is not None:
recovered = False
expected = self._response.message.error.expected_identifier
# If we have a recoverable error, an `expected_identifier` is included
if expected is not None:
# If the receiver has a zero sync/message id and we already sent all messages from the current sync
# we most likely missed the response to the done message. We can finalize the sync and move on here.
all_sent = (
self._messages[-1].message_type == ProtocolMessageTypes.plot_sync_done
and self._next_message_id == len(self._messages) - 1
)
if expected.sync_id == expected.message_id == 0 and all_sent:
self._finalize_sync()
recovered = True
elif self._sync_id == expected.sync_id and expected.message_id < len(self._messages):
self._next_message_id = expected.message_id
recovered = True
if not recovered:
return failed(f"Not recoverable error {self._response.message}")
return True
if self._response.message_type == ProtocolMessageTypes.plot_sync_done:
self._finalize_sync()
else:
self.bump_next_message_id()
return True
def _add_list_batched(self, message_type: ProtocolMessageTypes, payload_type: Any, data: List[Any]) -> None:
if len(data) == 0:
self._add_message(message_type, payload_type, [], True)
return
for remaining, batch in list_to_batches(data, self._plot_manager.refresh_parameter.batch_size):
self._add_message(message_type, payload_type, batch, remaining == 0)
def sync_start(self, count: float, initial: bool) -> None:
log.debug(f"sync_start {self}: count {count}, initial {initial}")
self._lock.acquire()
sync_id = int(time.time())
# Make sure we have unique sync ids even if we restart refreshing within a second (e.g. in tests)
if sync_id == self._last_sync_id:
sync_id = sync_id + 1
log.debug(f"sync_start {sync_id}")
self._sync_id = uint64(sync_id)
self._add_message(
ProtocolMessageTypes.plot_sync_start, PlotSyncStart, initial, self._last_sync_id, uint32(int(count))
)
def process_batch(self, loaded: List[PlotInfo], remaining: int) -> None:
log.debug(f"process_batch {self}: loaded {len(loaded)}, remaining {remaining}")
if len(loaded) > 0 or remaining == 0:
converted = _convert_plot_info_list(loaded)
self._add_message(ProtocolMessageTypes.plot_sync_loaded, PlotSyncPlotList, converted, remaining == 0)
def sync_done(self, removed: List[Path], duration: float) -> None:
log.debug(f"sync_done {self}: removed {len(removed)}, duration {duration}")
removed_list = [str(x) for x in removed]
self._add_list_batched(
ProtocolMessageTypes.plot_sync_removed,
PlotSyncPathList,
removed_list,
)
failed_to_open_list = [str(x) for x in list(self._plot_manager.failed_to_open_filenames)]
self._add_list_batched(ProtocolMessageTypes.plot_sync_invalid, PlotSyncPathList, failed_to_open_list)
no_key_list = [str(x) for x in self._plot_manager.no_key_filenames]
self._add_list_batched(ProtocolMessageTypes.plot_sync_keys_missing, PlotSyncPathList, no_key_list)
# TODO, Add typing in PlotManager
duplicates_list: List[str] = self._plot_manager.get_duplicates().copy() # type:ignore[no-untyped-call]
self._add_list_batched(ProtocolMessageTypes.plot_sync_duplicates, PlotSyncPathList, duplicates_list)
self._add_message(ProtocolMessageTypes.plot_sync_done, PlotSyncDone, uint64(int(duration)))
def _finalize_sync(self) -> None:
log.debug(f"_finalize_sync {self}")
assert self._sync_id != 0
self._last_sync_id = self._sync_id
self._sync_id = uint64(0)
self._next_message_id = uint64(0)
self._messages.clear()
self._lock.release()
def sync_active(self) -> bool:
return self._lock.locked() and self._sync_id != 0
def connected(self) -> bool:
return self._connection is not None
async def _run(self) -> None:
"""
This is the sender task, responsible for sending new messages during sync as they come into Sender._messages,
triggered by the plot manager callback.
"""
while not self._stop_requested:
try:
while not self.connected() or not self.sync_active():
if self._stop_requested:
return
await asyncio.sleep(0.1)
while not self._stop_requested and self.sync_active():
if self._next_message_id >= len(self._messages):
await asyncio.sleep(0.1)
continue
if not await self._send_next_message():
await asyncio.sleep(Constants.message_timeout)
except Exception as e:
log.error(f"Exception: {e} {traceback.format_exc()}")
self._reset()
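`_add_list_batched` above relies on `list_to_batches` yielding `(remaining, batch)` pairs so that the last batch carries `remaining == 0` and gets flagged as final. A hypothetical stand-in for that helper (the real one lives in `chia.util.generator_tools`; this sketch only illustrates the contract the sender depends on):

```python
from typing import Iterator, List, Tuple, TypeVar

T = TypeVar("T")

# Hypothetical stand-in for chia.util.generator_tools.list_to_batches:
# yields (remaining_after_this_batch, batch); remaining == 0 marks the final batch.
def list_to_batches(items: List[T], batch_size: int) -> Iterator[Tuple[int, List[T]]]:
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        batch = items[start : start + batch_size]
        yield len(items) - (start + len(batch)), batch
```

An empty input yields no batches at all, which is why `_add_list_batched` special-cases `len(data) == 0` and sends a single empty, final message instead.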

chia/plot_sync/util.py (new file)

@ -0,0 +1,27 @@
from enum import IntEnum
class Constants:
message_timeout: int = 10
class State(IntEnum):
idle = 0
loaded = 1
removed = 2
invalid = 3
keys_missing = 4
duplicates = 5
done = 6
class ErrorCodes(IntEnum):
unknown = -1
invalid_state = 0
invalid_peer_id = 1
invalid_identifier = 2
invalid_last_sync_id = 3
invalid_connection_type = 4
plot_already_available = 5
plot_not_available = 6
sync_ids_match = 7
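The `State` values above advance strictly in declaration order during a sync: each list message keeps the receiver in its current state until the final batch arrives, which moves it to the next state. A small sketch of that progression, assuming the enum as defined here (the `next_state` helper is illustrative, not part of the protocol code):

```python
from enum import IntEnum

class State(IntEnum):
    idle = 0
    loaded = 1
    removed = 2
    invalid = 3
    keys_missing = 4
    duplicates = 5
    done = 6

# Non-final batches keep the current state; a final batch advances to the
# next state in declaration order, and done is terminal.
def next_state(current: State, final: bool) -> State:
    if not final or current == State.done:
        return current
    return State(current + 1)
```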


@ -131,6 +131,7 @@ class PlotManager:
_refresh_thread: Optional[threading.Thread]
_refreshing_enabled: bool
_refresh_callback: Callable
_initial: bool
def __init__(
self,
@ -158,6 +159,7 @@ class PlotManager:
self._refresh_thread = None
self._refreshing_enabled = False
self._refresh_callback = refresh_callback # type: ignore
self._initial = True
def __enter__(self):
self._lock.acquire()
@ -172,6 +174,7 @@ class PlotManager:
self.plot_filename_paths.clear()
self.failed_to_open_filenames.clear()
self.no_key_filenames.clear()
self._initial = True
def set_refresh_callback(self, callback: Callable):
self._refresh_callback = callback # type: ignore
@ -180,6 +183,9 @@ class PlotManager:
self.farmer_public_keys = farmer_public_keys
self.pool_public_keys = pool_public_keys
def initial_refresh(self):
return self._initial
def public_keys_available(self):
return len(self.farmer_public_keys) and len(self.pool_public_keys)
@ -262,7 +268,6 @@ class PlotManager:
loaded_plot = Path(path) / Path(plot_filename)
if loaded_plot not in plot_paths:
paths_to_remove.append(path)
total_result.removed.append(loaded_plot)
for path in paths_to_remove:
duplicated_paths.remove(path)
@ -290,6 +295,9 @@ class PlotManager:
if self._refreshing_enabled:
self._refresh_callback(PlotRefreshEvents.done, total_result)
# Reset the initial refresh indication
self._initial = False
# Cleanup unused cache
available_ids = set([plot_info.prover.get_id() for plot_info in self.plots.values()])
invalid_cache_keys = [plot_id for plot_id in self.cache.keys() if plot_id not in available_ids]


@ -5,7 +5,7 @@ from blspy import G1Element, G2Element
from chia.types.blockchain_format.proof_of_space import ProofOfSpace
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint8, uint64
from chia.util.ints import int16, uint8, uint32, uint64
from chia.util.streamable import Streamable, streamable
"""
@ -95,3 +95,80 @@ class RespondPlots(Streamable):
plots: List[Plot]
failed_to_open_filenames: List[str]
no_key_filenames: List[str]
@dataclass(frozen=True)
@streamable
class PlotSyncIdentifier(Streamable):
timestamp: uint64
sync_id: uint64
message_id: uint64
@dataclass(frozen=True)
@streamable
class PlotSyncStart(Streamable):
identifier: PlotSyncIdentifier
initial: bool
last_sync_id: uint64
plot_file_count: uint32
def __str__(self) -> str:
return (
f"PlotSyncStart: identifier {self.identifier}, initial {self.initial}, "
f"last_sync_id {self.last_sync_id}, plot_file_count {self.plot_file_count}"
)
@dataclass(frozen=True)
@streamable
class PlotSyncPathList(Streamable):
identifier: PlotSyncIdentifier
data: List[str]
final: bool
def __str__(self) -> str:
return f"PlotSyncPathList: identifier {self.identifier}, count {len(self.data)}, final {self.final}"
@dataclass(frozen=True)
@streamable
class PlotSyncPlotList(Streamable):
identifier: PlotSyncIdentifier
data: List[Plot]
final: bool
def __str__(self) -> str:
return f"PlotSyncPlotList: identifier {self.identifier}, count {len(self.data)}, final {self.final}"
@dataclass(frozen=True)
@streamable
class PlotSyncDone(Streamable):
identifier: PlotSyncIdentifier
duration: uint64
def __str__(self) -> str:
return f"PlotSyncDone: identifier {self.identifier}, duration {self.duration}"
@dataclass(frozen=True)
@streamable
class PlotSyncError(Streamable):
code: int16
message: str
expected_identifier: Optional[PlotSyncIdentifier]
def __str__(self) -> str:
return f"PlotSyncError: code {self.code}, message {self.message}, expected_identifier {self.expected_identifier}"
@dataclass(frozen=True)
@streamable
class PlotSyncResponse(Streamable):
identifier: PlotSyncIdentifier
message_type: int16
error: Optional[PlotSyncError]
def __str__(self) -> str:
return f"PlotSyncResponse: identifier {self.identifier}, message_type {self.message_type}, error {self.error}"
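Every plot sync message embeds a `PlotSyncIdentifier`, and the receiver only accepts a message whose `sync_id` and `message_id` match its expected values. A hypothetical, simplified mirror of `Receiver._validate_identifier` (plain `int` fields instead of `uint64`, returning a bool instead of raising):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class PlotSyncIdentifier:
    timestamp: int
    sync_id: int
    message_id: int

# Hypothetical mirror of Receiver._validate_identifier: both ids have to match
# what the receiver expects for the current sync cycle and message sequence.
def identifier_matches(received: PlotSyncIdentifier, expected_sync_id: int, expected_message_id: int) -> bool:
    return received.sync_id == expected_sync_id and received.message_id == expected_message_id
```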


@ -86,6 +86,14 @@ class ProtocolMessageTypes(Enum):
new_signage_point_harvester = 66
request_plots = 67
respond_plots = 68
plot_sync_start = 78
plot_sync_loaded = 79
plot_sync_removed = 80
plot_sync_invalid = 81
plot_sync_keys_missing = 82
plot_sync_duplicates = 83
plot_sync_done = 84
plot_sync_response = 85
# More wallet protocol
coin_state_update = 69


@ -5,7 +5,7 @@ from typing import List, Tuple
from chia.util.ints import uint8, uint16
from chia.util.streamable import Streamable, streamable
protocol_version = "0.0.33"
protocol_version = "0.0.34"
"""
Handshake when establishing a connection between two servers.


@ -97,6 +97,14 @@ rate_limits_other = {
ProtocolMessageTypes.farm_new_block: RLSettings(200, 200),
ProtocolMessageTypes.request_plots: RLSettings(10, 10 * 1024 * 1024),
ProtocolMessageTypes.respond_plots: RLSettings(10, 100 * 1024 * 1024),
ProtocolMessageTypes.plot_sync_start: RLSettings(1000, 100 * 1024 * 1024),
ProtocolMessageTypes.plot_sync_loaded: RLSettings(1000, 100 * 1024 * 1024),
ProtocolMessageTypes.plot_sync_removed: RLSettings(1000, 100 * 1024 * 1024),
ProtocolMessageTypes.plot_sync_invalid: RLSettings(1000, 100 * 1024 * 1024),
ProtocolMessageTypes.plot_sync_keys_missing: RLSettings(1000, 100 * 1024 * 1024),
ProtocolMessageTypes.plot_sync_duplicates: RLSettings(1000, 100 * 1024 * 1024),
ProtocolMessageTypes.plot_sync_done: RLSettings(1000, 100 * 1024 * 1024),
ProtocolMessageTypes.plot_sync_response: RLSettings(3000, 100 * 1024 * 1024),
ProtocolMessageTypes.coin_state_update: RLSettings(1000, 100 * 1024 * 1024),
ProtocolMessageTypes.register_interest_in_puzzle_hash: RLSettings(1000, 100 * 1024 * 1024),
ProtocolMessageTypes.respond_to_ph_update: RLSettings(1000, 100 * 1024 * 1024),


@ -92,6 +92,7 @@ kwargs = dict(
"chia.farmer",
"chia.harvester",
"chia.introducer",
"chia.plot_sync",
"chia.plotters",
"chia.plotting",
"chia.pools",


@ -8,6 +8,9 @@ import pytest_asyncio
import tempfile
from tests.setup_nodes import setup_node_and_wallet, setup_n_nodes, setup_two_nodes
from pathlib import Path
from typing import Any, AsyncIterator, Dict, List, Tuple
from chia.server.start_service import Service
# Set spawn after stdlib imports, but before other imports
from chia.clvm.spend_sim import SimClient, SpendSim
@ -39,6 +42,7 @@ from pathlib import Path
from chia.util.keyring_wrapper import KeyringWrapper
from tests.block_tools import BlockTools, test_constants, create_block_tools, create_block_tools_async
from tests.util.keyring import TempKeyring
from tests.setup_nodes import setup_farmer_multi_harvester
@pytest.fixture(scope="session")
@ -403,6 +407,24 @@ async def two_nodes_one_block(bt, wallet_a):
yield _
@pytest_asyncio.fixture(scope="function")
async def farmer_one_harvester(tmp_path: Path, bt: BlockTools) -> AsyncIterator[Tuple[List[Service], Service]]:
async for _ in setup_farmer_multi_harvester(bt, 1, tmp_path, test_constants):
yield _
@pytest_asyncio.fixture(scope="function")
async def farmer_two_harvester(tmp_path: Path, bt: BlockTools) -> AsyncIterator[Tuple[List[Service], Service]]:
async for _ in setup_farmer_multi_harvester(bt, 2, tmp_path, test_constants):
yield _
@pytest_asyncio.fixture(scope="function")
async def farmer_three_harvester(tmp_path: Path, bt: BlockTools) -> AsyncIterator[Tuple[List[Service], Service]]:
async for _ in setup_farmer_multi_harvester(bt, 3, tmp_path, test_constants):
yield _
# TODO: Ideally, the db_version should be the (parameterized) db_version
# fixture, to test all versions of the database schema. This doesn't work
# because of a hack in shutting down the full node, which means you cannot run


@ -112,7 +112,6 @@ async def test_farmer_get_harvesters(harvester_farmer_environment):
harvester_rpc_api,
harvester_rpc_client,
) = harvester_farmer_environment
farmer_api = farmer_service._api
harvester = harvester_service._node
num_plots = 0
@ -125,11 +124,6 @@ async def test_farmer_get_harvesters(harvester_farmer_environment):
await time_out_assert(10, non_zero_plots)
# Reset cache and force updates cache every second to make sure the farmer gets the most recent data
update_interval_before = farmer_api.farmer.update_harvester_cache_interval
farmer_api.farmer.update_harvester_cache_interval = 1
farmer_api.farmer.harvester_cache = {}
async def test_get_harvesters():
harvester.plot_manager.trigger_refresh()
await time_out_assert(5, harvester.plot_manager.needs_refresh, value=False)
@ -144,10 +138,6 @@ async def test_farmer_get_harvesters(harvester_farmer_environment):
await time_out_assert_custom_interval(30, 1, test_get_harvesters)
# Reset cache and reset update interval to avoid hitting the rate limit
farmer_api.farmer.update_harvester_cache_interval = update_interval_before
farmer_api.farmer.harvester_cache = {}
@pytest.mark.asyncio
async def test_farmer_signage_point_endpoints(harvester_farmer_environment):


@ -0,0 +1,2 @@
parallel = True
checkout_blocks_and_plots = True


@ -0,0 +1,90 @@
import logging
from typing import List
import pytest
from blspy import G1Element
from chia.plot_sync.delta import Delta, DeltaType, PathListDelta, PlotListDelta
from chia.protocols.harvester_protocol import Plot
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint8, uint64
log = logging.getLogger(__name__)
def dummy_plot(path: str) -> Plot:
return Plot(path, uint8(32), bytes32(b"\00" * 32), G1Element(), None, G1Element(), uint64(0), uint64(0))
@pytest.mark.parametrize(
["delta"],
[
pytest.param(PathListDelta(), id="path list"),
pytest.param(PlotListDelta(), id="plot list"),
],
)
def test_list_delta(delta: DeltaType) -> None:
assert delta.empty()
if type(delta) == PathListDelta:
assert delta.additions == []
elif type(delta) == PlotListDelta:
assert delta.additions == {}
else:
assert False
assert delta.removals == []
assert delta.empty()
if type(delta) == PathListDelta:
delta.additions.append("0")
elif type(delta) == PlotListDelta:
delta.additions["0"] = dummy_plot("0")
else:
assert False, "Invalid delta type"
assert not delta.empty()
delta.removals.append("0")
assert not delta.empty()
delta.additions.clear()
assert not delta.empty()
delta.clear()
assert delta.empty()
@pytest.mark.parametrize(
["old", "new", "result"],
[
[[], [], PathListDelta()],
[["1"], ["0"], PathListDelta(["0"], ["1"])],
[["1", "2", "3"], ["1", "2", "3"], PathListDelta([], [])],
[["2", "1", "3"], ["2", "3", "1"], PathListDelta([], [])],
[["2"], ["2", "3", "1"], PathListDelta(["3", "1"], [])],
[["2"], ["1", "3"], PathListDelta(["1", "3"], ["2"])],
[["1"], ["1", "2", "3"], PathListDelta(["2", "3"], [])],
[[], ["1", "2", "3"], PathListDelta(["1", "2", "3"], [])],
[["-1"], ["1", "2", "3"], PathListDelta(["1", "2", "3"], ["-1"])],
[["-1", "1"], ["2", "3"], PathListDelta(["2", "3"], ["-1", "1"])],
[["-1", "1", "2"], ["2", "3"], PathListDelta(["3"], ["-1", "1"])],
[["-1", "2", "3"], ["2", "3"], PathListDelta([], ["-1"])],
[["-1", "2", "3", "-2"], ["2", "3"], PathListDelta([], ["-1", "-2"])],
[["-2", "2", "3", "-1"], ["2", "3"], PathListDelta([], ["-2", "-1"])],
],
)
def test_path_list_delta_from_lists(old: List[str], new: List[str], result: PathListDelta) -> None:
assert PathListDelta.from_lists(old, new) == result
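The parameter table above pins down the semantics of `PathListDelta.from_lists`: additions are the paths present in `new` but not in `old`, removals the reverse, each preserving input order. A hypothetical reimplementation consistent with those cases (returning a plain tuple rather than a `PathListDelta`):

```python
from typing import List, Tuple

# Hypothetical reimplementation of PathListDelta.from_lists: order-preserving
# set difference in both directions (quadratic scan, fine for illustration).
def path_list_delta_from_lists(old: List[str], new: List[str]) -> Tuple[List[str], List[str]]:
    additions = [path for path in new if path not in old]
    removals = [path for path in old if path not in new]
    return additions, removals
```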
def test_delta_empty() -> None:
delta: Delta = Delta()
all_deltas: List[DeltaType] = [delta.valid, delta.invalid, delta.keys_missing, delta.duplicates]
assert delta.empty()
for d1 in all_deltas:
delta.valid.additions["0"] = dummy_plot("0")
delta.invalid.additions.append("0")
delta.keys_missing.additions.append("0")
delta.duplicates.additions.append("0")
assert not delta.empty()
for d2 in all_deltas:
if d2 is not d1:
d2.clear()
assert not delta.empty()
assert not delta.empty()
d1.clear()
assert delta.empty()


@ -0,0 +1,537 @@
from dataclasses import dataclass, field
from pathlib import Path
from shutil import copy
from typing import List, Optional, Tuple
import pytest
import pytest_asyncio
from blspy import G1Element
from chia.farmer.farmer_api import Farmer
from chia.harvester.harvester_api import Harvester
from chia.plot_sync.delta import Delta, PathListDelta, PlotListDelta
from chia.plot_sync.receiver import Receiver
from chia.plot_sync.sender import Sender
from chia.plot_sync.util import State
from chia.plotting.manager import PlotManager
from chia.plotting.util import add_plot_directory, remove_plot_directory
from chia.protocols.harvester_protocol import Plot
from chia.server.start_service import Service
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.config import create_default_chia_config
from chia.util.ints import uint8, uint64
from tests.block_tools import BlockTools
from tests.plot_sync.util import start_harvester_service
from tests.plotting.test_plot_manager import MockPlotInfo, TestDirectory
from tests.plotting.util import get_test_plots
from tests.time_out_assert import time_out_assert
def synced(sender: Sender, receiver: Receiver, previous_last_sync_id: int) -> bool:
return (
sender._last_sync_id != previous_last_sync_id
and sender._last_sync_id == receiver._last_sync_id != 0
and receiver.state() == State.idle
and not sender._lock.locked()
)
def assert_path_list_matches(expected_list: List[str], actual_list: List[str]) -> None:
assert len(expected_list) == len(actual_list)
for item in expected_list:
assert str(item) in actual_list
@dataclass
class ExpectedResult:
valid_count: int = 0
valid_delta: PlotListDelta = field(default_factory=PlotListDelta)
invalid_count: int = 0
invalid_delta: PathListDelta = field(default_factory=PathListDelta)
keys_missing_count: int = 0
keys_missing_delta: PathListDelta = field(default_factory=PathListDelta)
duplicates_count: int = 0
duplicates_delta: PathListDelta = field(default_factory=PathListDelta)
callback_passed: bool = False
def add_valid(self, list_plots: List[MockPlotInfo]) -> None:
def create_mock_plot(info: MockPlotInfo) -> Plot:
return Plot(
info.prover.get_filename(),
uint8(0),
bytes32(b"\x00" * 32),
None,
None,
G1Element(),
uint64(0),
uint64(0),
)
self.valid_count += len(list_plots)
self.valid_delta.additions.update({x.prover.get_filename(): create_mock_plot(x) for x in list_plots})
def remove_valid(self, list_paths: List[Path]) -> None:
self.valid_count -= len(list_paths)
self.valid_delta.removals += [str(x) for x in list_paths]
def add_invalid(self, list_paths: List[Path]) -> None:
self.invalid_count += len(list_paths)
self.invalid_delta.additions += [str(x) for x in list_paths]
def remove_invalid(self, list_paths: List[Path]) -> None:
self.invalid_count -= len(list_paths)
self.invalid_delta.removals += [str(x) for x in list_paths]
def add_keys_missing(self, list_paths: List[Path]) -> None:
self.keys_missing_count += len(list_paths)
self.keys_missing_delta.additions += [str(x) for x in list_paths]
def remove_keys_missing(self, list_paths: List[Path]) -> None:
self.keys_missing_count -= len(list_paths)
self.keys_missing_delta.removals += [str(x) for x in list_paths]
def add_duplicates(self, list_paths: List[Path]) -> None:
self.duplicates_count += len(list_paths)
self.duplicates_delta.additions += [str(x) for x in list_paths]
def remove_duplicates(self, list_paths: List[Path]) -> None:
self.duplicates_count -= len(list_paths)
self.duplicates_delta.removals += [str(x) for x in list_paths]
@dataclass
class Environment:
root_path: Path
harvester_services: List[Service]
farmer_service: Service
harvesters: List[Harvester]
farmer: Farmer
dir_1: TestDirectory
dir_2: TestDirectory
dir_3: TestDirectory
dir_4: TestDirectory
dir_invalid: TestDirectory
dir_keys_missing: TestDirectory
dir_duplicates: TestDirectory
expected: List[ExpectedResult]
def get_harvester(self, peer_id: bytes32) -> Optional[Harvester]:
for harvester in self.harvesters:
assert harvester.server is not None
if harvester.server.node_id == peer_id:
return harvester
return None
def add_directory(self, harvester_index: int, directory: TestDirectory, state: State = State.loaded) -> None:
add_plot_directory(self.harvesters[harvester_index].root_path, str(directory.path))
if state == State.loaded:
self.expected[harvester_index].add_valid(directory.plot_info_list())
elif state == State.invalid:
self.expected[harvester_index].add_invalid(directory.path_list())
elif state == State.keys_missing:
self.expected[harvester_index].add_keys_missing(directory.path_list())
elif state == State.duplicates:
self.expected[harvester_index].add_duplicates(directory.path_list())
else:
assert False, "Invalid state"
def remove_directory(self, harvester_index: int, directory: TestDirectory, state: State = State.removed) -> None:
remove_plot_directory(self.harvesters[harvester_index].root_path, str(directory.path))
if state == State.removed:
self.expected[harvester_index].remove_valid(directory.path_list())
elif state == State.invalid:
self.expected[harvester_index].remove_invalid(directory.path_list())
elif state == State.keys_missing:
self.expected[harvester_index].remove_keys_missing(directory.path_list())
elif state == State.duplicates:
self.expected[harvester_index].remove_duplicates(directory.path_list())
else:
assert False, "Invalid state"
def add_all_directories(self, harvester_index: int) -> None:
self.add_directory(harvester_index, self.dir_1)
self.add_directory(harvester_index, self.dir_2)
self.add_directory(harvester_index, self.dir_3)
self.add_directory(harvester_index, self.dir_4)
self.add_directory(harvester_index, self.dir_keys_missing, State.keys_missing)
self.add_directory(harvester_index, self.dir_invalid, State.invalid)
# Note: This does not add dir_duplicates since it's important that the duplicated plots are loaded after
# the original ones.
# self.add_directory(harvester_index, self.dir_duplicates, State.duplicates)
def remove_all_directories(self, harvester_index: int) -> None:
self.remove_directory(harvester_index, self.dir_1)
self.remove_directory(harvester_index, self.dir_2)
self.remove_directory(harvester_index, self.dir_3)
self.remove_directory(harvester_index, self.dir_4)
self.remove_directory(harvester_index, self.dir_keys_missing, State.keys_missing)
self.remove_directory(harvester_index, self.dir_invalid, State.invalid)
self.remove_directory(harvester_index, self.dir_duplicates, State.duplicates)
async def plot_sync_callback(self, peer_id: bytes32, delta: Delta) -> None:
harvester: Optional[Harvester] = self.get_harvester(peer_id)
assert harvester is not None
expected = self.expected[self.harvesters.index(harvester)]
assert len(expected.valid_delta.additions) == len(delta.valid.additions)
for path, plot_info in expected.valid_delta.additions.items():
assert path in delta.valid.additions
plot = harvester.plot_manager.plots.get(Path(path), None)
assert plot is not None
assert plot.prover.get_filename() == delta.valid.additions[path].filename
assert plot.prover.get_size() == delta.valid.additions[path].size
assert plot.prover.get_id() == delta.valid.additions[path].plot_id
assert plot.pool_public_key == delta.valid.additions[path].pool_public_key
assert plot.pool_contract_puzzle_hash == delta.valid.additions[path].pool_contract_puzzle_hash
assert plot.plot_public_key == delta.valid.additions[path].plot_public_key
assert plot.file_size == delta.valid.additions[path].file_size
assert int(plot.time_modified) == delta.valid.additions[path].time_modified
assert_path_list_matches(expected.valid_delta.removals, delta.valid.removals)
assert_path_list_matches(expected.invalid_delta.additions, delta.invalid.additions)
assert_path_list_matches(expected.invalid_delta.removals, delta.invalid.removals)
assert_path_list_matches(expected.keys_missing_delta.additions, delta.keys_missing.additions)
assert_path_list_matches(expected.keys_missing_delta.removals, delta.keys_missing.removals)
assert_path_list_matches(expected.duplicates_delta.additions, delta.duplicates.additions)
assert_path_list_matches(expected.duplicates_delta.removals, delta.duplicates.removals)
expected.valid_delta.clear()
expected.invalid_delta.clear()
expected.keys_missing_delta.clear()
expected.duplicates_delta.clear()
expected.callback_passed = True
async def run_sync_test(self) -> None:
plot_manager: PlotManager
assert len(self.harvesters) == len(self.expected)
last_sync_ids: List[uint64] = []
# Run the test in two steps, first trigger the refresh on both harvesters
for harvester in self.harvesters:
plot_manager = harvester.plot_manager
assert harvester.server is not None
receiver = self.farmer.plot_sync_receivers[harvester.server.node_id]
# Always reset the passed flag before a new run
self.expected[self.harvesters.index(harvester)].callback_passed = False
receiver._update_callback = self.plot_sync_callback
assert harvester.plot_sync_sender._last_sync_id == receiver._last_sync_id
last_sync_ids.append(harvester.plot_sync_sender._last_sync_id)
plot_manager.start_refreshing()
plot_manager.trigger_refresh()
# Then wait for them to be synced with the farmer and validate them
for harvester in self.harvesters:
plot_manager = harvester.plot_manager
assert harvester.server is not None
receiver = self.farmer.plot_sync_receivers[harvester.server.node_id]
await time_out_assert(10, plot_manager.needs_refresh, value=False)
harvester_index = self.harvesters.index(harvester)
await time_out_assert(
10, synced, True, harvester.plot_sync_sender, receiver, last_sync_ids[harvester_index]
)
expected = self.expected[harvester_index]
assert plot_manager.plot_count() == len(receiver.plots()) == expected.valid_count
assert len(plot_manager.failed_to_open_filenames) == len(receiver.invalid()) == expected.invalid_count
assert len(plot_manager.no_key_filenames) == len(receiver.keys_missing()) == expected.keys_missing_count
assert len(plot_manager.get_duplicates()) == len(receiver.duplicates()) == expected.duplicates_count
assert expected.callback_passed
assert expected.valid_delta.empty()
assert expected.invalid_delta.empty()
assert expected.keys_missing_delta.empty()
assert expected.duplicates_delta.empty()
for path, plot_info in plot_manager.plots.items():
assert str(path) in receiver.plots()
assert plot_info.prover.get_filename() == receiver.plots()[str(path)].filename
assert plot_info.prover.get_size() == receiver.plots()[str(path)].size
assert plot_info.prover.get_id() == receiver.plots()[str(path)].plot_id
assert plot_info.pool_public_key == receiver.plots()[str(path)].pool_public_key
assert plot_info.pool_contract_puzzle_hash == receiver.plots()[str(path)].pool_contract_puzzle_hash
assert plot_info.plot_public_key == receiver.plots()[str(path)].plot_public_key
assert plot_info.file_size == receiver.plots()[str(path)].file_size
assert int(plot_info.time_modified) == receiver.plots()[str(path)].time_modified
for path in plot_manager.failed_to_open_filenames:
assert str(path) in receiver.invalid()
for path in plot_manager.no_key_filenames:
assert str(path) in receiver.keys_missing()
for path in plot_manager.get_duplicates():
assert str(path) in receiver.duplicates()
async def handshake_done(self, index: int) -> bool:
return (
self.harvesters[index].plot_manager._refresh_thread is not None
and len(self.harvesters[index].plot_manager.farmer_public_keys) > 0
)
@pytest_asyncio.fixture(scope="function")
async def environment(
bt: BlockTools, tmp_path: Path, farmer_two_harvester: Tuple[List[Service], Service]
) -> Environment:
def new_test_dir(name: str, plot_list: List[Path]) -> TestDirectory:
return TestDirectory(tmp_path / "plots" / name, plot_list)
plots: List[Path] = get_test_plots()
plots_invalid: List[Path] = get_test_plots()[0:3]
plots_keys_missing: List[Path] = get_test_plots("not_in_keychain")
# Create 4 directories where dir_n contains n plots
directories: List[TestDirectory] = []
offset: int = 0
while len(directories) < 4:
dir_number = len(directories) + 1
directories.append(new_test_dir(f"{dir_number}", plots[offset : offset + dir_number]))
offset += dir_number
dir_invalid: TestDirectory = new_test_dir("invalid", plots_invalid)
dir_keys_missing: TestDirectory = new_test_dir("keys_missing", plots_keys_missing)
dir_duplicates: TestDirectory = new_test_dir("duplicates", directories[3].plots)
create_default_chia_config(tmp_path)
# Invalidate the plots in `dir_invalid`
for path in dir_invalid.path_list():
with open(path, "wb") as file:
file.write(bytes(100))
harvester_services: List[Service]
farmer_service: Service
harvester_services, farmer_service = farmer_two_harvester
farmer: Farmer = farmer_service._node
harvesters: List[Harvester] = [await start_harvester_service(service) for service in harvester_services]
for harvester in harvesters:
harvester.plot_manager.set_public_keys(
bt.plot_manager.farmer_public_keys.copy(), bt.plot_manager.pool_public_keys.copy()
)
assert len(farmer.plot_sync_receivers) == 2
return Environment(
tmp_path,
harvester_services,
farmer_service,
harvesters,
farmer,
directories[0],
directories[1],
directories[2],
directories[3],
dir_invalid,
dir_keys_missing,
dir_duplicates,
[ExpectedResult() for _ in harvesters],
)
@pytest.mark.asyncio
async def test_sync_valid(environment: Environment) -> None:
env: Environment = environment
env.add_directory(0, env.dir_1)
env.add_directory(1, env.dir_2)
await env.run_sync_test()
# Run two more times to make sure we still get the same results in repeated refresh intervals
env.expected[0].valid_delta.clear()
env.expected[1].valid_delta.clear()
await env.run_sync_test()
await env.run_sync_test()
env.add_directory(0, env.dir_3)
env.add_directory(1, env.dir_4)
await env.run_sync_test()
while len(env.dir_3.path_list()):
drop_plot = env.dir_3.path_list()[0]
drop_plot.unlink()
env.dir_3.drop(drop_plot)
env.expected[0].remove_valid([drop_plot])
await env.run_sync_test()
env.remove_directory(0, env.dir_3)
await env.run_sync_test()
env.remove_directory(1, env.dir_4)
await env.run_sync_test()
env.remove_directory(0, env.dir_1)
env.remove_directory(1, env.dir_2)
await env.run_sync_test()
@pytest.mark.asyncio
async def test_sync_invalid(environment: Environment) -> None:
env: Environment = environment
assert len(env.farmer.plot_sync_receivers) == 2
# Use dir_3 and dir_4 in this test because the invalid plots are copies from dir_1 + dir_2
env.add_directory(0, env.dir_3)
env.add_directory(0, env.dir_invalid, State.invalid)
env.add_directory(1, env.dir_4)
await env.run_sync_test()
# Run two more times to make sure we still get the same results in repeated refresh intervals
await env.run_sync_test()
await env.run_sync_test()
# Drop all but two of the invalid plots
assert len(env.dir_invalid) > 2
for _ in range(len(env.dir_invalid) - 2):
drop_plot = env.dir_invalid.path_list()[0]
drop_plot.unlink()
env.dir_invalid.drop(drop_plot)
env.expected[0].remove_invalid([drop_plot])
await env.run_sync_test()
assert len(env.dir_invalid) == 2
# Add the directory to the second harvester too
env.add_directory(1, env.dir_invalid, State.invalid)
await env.run_sync_test()
# Recover one of the remaining invalid plots
for path in get_test_plots():
if path.name == env.dir_invalid.path_list()[0].name:
copy(path, env.dir_invalid.path)
for i in range(len(env.harvesters)):
env.expected[i].add_valid([env.dir_invalid.plot_info_list()[0]])
env.expected[i].remove_invalid([env.dir_invalid.path_list()[0]])
env.harvesters[i].plot_manager.refresh_parameter.retry_invalid_seconds = 0
await env.run_sync_test()
for i in [0, 1]:
remove_plot_directory(env.harvesters[i].root_path, str(env.dir_invalid.path))
env.expected[i].remove_valid([env.dir_invalid.path_list()[0]])
env.expected[i].remove_invalid([env.dir_invalid.path_list()[1]])
await env.run_sync_test()
@pytest.mark.asyncio
async def test_sync_keys_missing(environment: Environment) -> None:
env: Environment = environment
env.add_directory(0, env.dir_1)
env.add_directory(0, env.dir_keys_missing, State.keys_missing)
env.add_directory(1, env.dir_2)
await env.run_sync_test()
# Run two more times to make sure we still get the same results in repeated refresh intervals
await env.run_sync_test()
await env.run_sync_test()
# Drop all but 2 plots with missing keys and test sync in between
assert len(env.dir_keys_missing) > 2
for _ in range(len(env.dir_keys_missing) - 2):
drop_plot = env.dir_keys_missing.path_list()[0]
drop_plot.unlink()
env.dir_keys_missing.drop(drop_plot)
env.expected[0].remove_keys_missing([drop_plot])
await env.run_sync_test()
assert len(env.dir_keys_missing) == 2
# Add the plots with missing keys to the other harvester
env.add_directory(0, env.dir_3)
env.add_directory(1, env.dir_keys_missing, State.keys_missing)
await env.run_sync_test()
# Add the missing keys to the first harvester's plot manager
env.harvesters[0].plot_manager.farmer_public_keys.append(G1Element())
env.harvesters[0].plot_manager.pool_public_keys.append(G1Element())
# And validate they become valid now
env.expected[0].add_valid(env.dir_keys_missing.plot_info_list())
env.expected[0].remove_keys_missing(env.dir_keys_missing.path_list())
await env.run_sync_test()
# Drop the valid plots from one harvester and the keys missing plots from the other harvester
env.remove_directory(0, env.dir_keys_missing)
env.remove_directory(1, env.dir_keys_missing, State.keys_missing)
await env.run_sync_test()
@pytest.mark.asyncio
async def test_sync_duplicates(environment: Environment) -> None:
env: Environment = environment
# dir_4 and then dir_duplicates contain the same plots. Load dir_4 first to make sure the plots seen as duplicates
# are from dir_duplicates.
env.add_directory(0, env.dir_4)
await env.run_sync_test()
env.add_directory(0, env.dir_duplicates, State.duplicates)
env.add_directory(1, env.dir_2)
await env.run_sync_test()
# Run two more times to make sure we still get the same results in repeated refresh intervals
await env.run_sync_test()
await env.run_sync_test()
# Drop all but two duplicates and test sync in between
assert len(env.dir_duplicates) > 2
for _ in range(len(env.dir_duplicates) - 2):
drop_plot = env.dir_duplicates.path_list()[0]
drop_plot.unlink()
env.dir_duplicates.drop(drop_plot)
env.expected[0].remove_duplicates([drop_plot])
await env.run_sync_test()
assert len(env.dir_duplicates) == 2
# Removing dir_4 now causes the plots in dir_duplicates to be loaded instead
env.remove_directory(0, env.dir_4)
env.expected[0].remove_duplicates(env.dir_duplicates.path_list())
env.expected[0].add_valid(env.dir_duplicates.plot_info_list())
await env.run_sync_test()
async def add_and_validate_all_directories(env: Environment) -> None:
# Add all available directories to both harvesters and make sure they load and get synced
env.add_all_directories(0)
env.add_all_directories(1)
await env.run_sync_test()
env.add_directory(0, env.dir_duplicates, State.duplicates)
env.add_directory(1, env.dir_duplicates, State.duplicates)
await env.run_sync_test()
async def remove_and_validate_all_directories(env: Environment) -> None:
# Remove all available directories from both harvesters and make sure they are removed and get synced
env.remove_all_directories(0)
env.remove_all_directories(1)
await env.run_sync_test()
@pytest.mark.asyncio
async def test_add_and_remove_all_directories(environment: Environment) -> None:
await add_and_validate_all_directories(environment)
await remove_and_validate_all_directories(environment)
@pytest.mark.asyncio
async def test_harvester_restart(environment: Environment) -> None:
env: Environment = environment
# Load all directories for both harvesters
await add_and_validate_all_directories(env)
# Stop the harvester and make sure the receiver gets dropped on the farmer and refreshing gets stopped
env.harvester_services[0].stop()
await env.harvester_services[0].wait_closed()
assert len(env.farmer.plot_sync_receivers) == 1
assert not env.harvesters[0].plot_manager._refreshing_enabled
assert not env.harvesters[0].plot_manager.needs_refresh()
# Start the harvester, wait for the handshake and make sure the receiver comes back
await env.harvester_services[0].start()
await time_out_assert(5, env.handshake_done, True, 0)
assert len(env.farmer.plot_sync_receivers) == 2
# Remove the duplicates dir to avoid conflicts with the original plots
env.remove_directory(0, env.dir_duplicates)
# Reset the expected data for harvester 0 and re-add all directories because of the restart
env.expected[0] = ExpectedResult()
env.add_all_directories(0)
# Run the refresh two times and make sure everything recovers and stays recovered after harvester restart
await env.run_sync_test()
env.add_directory(0, env.dir_duplicates, State.duplicates)
await env.run_sync_test()
@pytest.mark.asyncio
async def test_farmer_restart(environment: Environment) -> None:
env: Environment = environment
# Load all directories for both harvesters
await add_and_validate_all_directories(env)
last_sync_ids: List[uint64] = []
for i in range(0, len(env.harvesters)):
last_sync_ids.append(env.harvesters[i].plot_sync_sender._last_sync_id)
# Stop the farmer and make sure both receivers get dropped and refreshing gets stopped on the harvesters
env.farmer_service.stop()
await env.farmer_service.wait_closed()
assert len(env.farmer.plot_sync_receivers) == 0
assert not env.harvesters[0].plot_manager._refreshing_enabled
assert not env.harvesters[1].plot_manager._refreshing_enabled
# Start the farmer, wait for the handshake and make sure the receivers come back
await env.farmer_service.start()
await time_out_assert(5, env.handshake_done, True, 0)
await time_out_assert(5, env.handshake_done, True, 1)
assert len(env.farmer.plot_sync_receivers) == 2
# Do not use run_sync_test here; for a more realistic test scenario, just wait for the harvesters to be synced.
# The handshake should trigger the re-sync.
for i in range(0, len(env.harvesters)):
harvester: Harvester = env.harvesters[i]
assert harvester.server is not None
receiver = env.farmer.plot_sync_receivers[harvester.server.node_id]
await time_out_assert(10, synced, True, harvester.plot_sync_sender, receiver, last_sync_ids[i])
# Validate the sync
for harvester in env.harvesters:
plot_manager: PlotManager = harvester.plot_manager
assert harvester.server is not None
receiver = env.farmer.plot_sync_receivers[harvester.server.node_id]
expected = env.expected[env.harvesters.index(harvester)]
assert plot_manager.plot_count() == len(receiver.plots()) == expected.valid_count
assert len(plot_manager.failed_to_open_filenames) == len(receiver.invalid()) == expected.invalid_count
assert len(plot_manager.no_key_filenames) == len(receiver.keys_missing()) == expected.keys_missing_count
assert len(plot_manager.get_duplicates()) == len(receiver.duplicates()) == expected.duplicates_count

import logging
import time
from secrets import token_bytes
from typing import Any, Callable, List, Tuple, Type, Union
import pytest
from blspy import G1Element
from chia.plot_sync.delta import Delta
from chia.plot_sync.receiver import Receiver
from chia.plot_sync.util import ErrorCodes, State
from chia.protocols.harvester_protocol import (
Plot,
PlotSyncDone,
PlotSyncIdentifier,
PlotSyncPathList,
PlotSyncPlotList,
PlotSyncResponse,
PlotSyncStart,
)
from chia.server.ws_connection import NodeType
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint8, uint32, uint64
from chia.util.streamable import _T_Streamable
from tests.plot_sync.util import get_dummy_connection
log = logging.getLogger(__name__)
next_message_id = uint64(0)
def assert_default_values(receiver: Receiver) -> None:
assert receiver.state() == State.idle
assert receiver.expected_sync_id() == 0
assert receiver.expected_message_id() == 0
assert receiver.last_sync_id() == 0
assert receiver.last_sync_time() == 0
assert receiver.plots() == {}
assert receiver.invalid() == []
assert receiver.keys_missing() == []
assert receiver.duplicates() == []
async def dummy_callback(_: bytes32, __: Delta) -> None:
pass
class SyncStepData:
state: State
function: Any
payload_type: Any
args: Any
def __init__(
self, state: State, function: Callable[[_T_Streamable], Any], payload_type: Type[_T_Streamable], *args: Any
) -> None:
self.state = state
self.function = function
self.payload_type = payload_type
self.args = args
def plot_sync_identifier(current_sync_id: uint64, message_id: uint64) -> PlotSyncIdentifier:
return PlotSyncIdentifier(uint64(0), current_sync_id, message_id)
def create_payload(payload_type: Any, start: bool, *args: Any) -> Any:
global next_message_id
if start:
next_message_id = uint64(0)
next_identifier = plot_sync_identifier(uint64(1), next_message_id)
next_message_id = uint64(next_message_id + 1)
return payload_type(next_identifier, *args)
def assert_error_response(plot_sync: Receiver, error_code: ErrorCodes) -> None:
connection = plot_sync.connection()
assert connection is not None
message = connection.last_sent_message
assert message is not None
response: PlotSyncResponse = PlotSyncResponse.from_bytes(message.data)
assert response.error is not None
assert response.error.code == error_code.value
def pre_function_validate(receiver: Receiver, data: Union[List[Plot], List[str]], expected_state: State) -> None:
if expected_state == State.loaded:
for plot_info in data:
assert type(plot_info) == Plot
assert plot_info.filename not in receiver.plots()
elif expected_state == State.removed:
for path in data:
assert path in receiver.plots()
elif expected_state == State.invalid:
for path in data:
assert path not in receiver.invalid()
elif expected_state == State.keys_missing:
for path in data:
assert path not in receiver.keys_missing()
elif expected_state == State.duplicates:
for path in data:
assert path not in receiver.duplicates()
def post_function_validate(receiver: Receiver, data: Union[List[Plot], List[str]], expected_state: State) -> None:
if expected_state == State.loaded:
for plot_info in data:
assert type(plot_info) == Plot
assert plot_info.filename in receiver._delta.valid.additions
elif expected_state == State.removed:
for path in data:
assert path in receiver._delta.valid.removals
elif expected_state == State.invalid:
for path in data:
assert path in receiver._delta.invalid.additions
elif expected_state == State.keys_missing:
for path in data:
assert path in receiver._delta.keys_missing.additions
elif expected_state == State.duplicates:
for path in data:
assert path in receiver._delta.duplicates.additions
@pytest.mark.asyncio
async def run_sync_step(receiver: Receiver, sync_step: SyncStepData, expected_state: State) -> None:
assert receiver.state() == expected_state
last_sync_time_before = receiver._last_sync_time
# For the list types, invoke the trigger function in batches
if sync_step.payload_type == PlotSyncPlotList or sync_step.payload_type == PlotSyncPathList:
step_data, _ = sync_step.args
assert len(step_data) == 10
# Invoke batches of 1, 2, 3 and 4 items and validate the data against the plot store before and after
indexes = [0, 1, 3, 6, 10]
for i in range(0, len(indexes) - 1):
invoke_data = step_data[indexes[i] : indexes[i + 1]]
pre_function_validate(receiver, invoke_data, expected_state)
await sync_step.function(
create_payload(sync_step.payload_type, False, invoke_data, i == (len(indexes) - 2))
)
post_function_validate(receiver, invoke_data, expected_state)
else:
# For Start/Done just invoke it.
await sync_step.function(create_payload(sync_step.payload_type, sync_step.state == State.idle, *sync_step.args))
# Make sure we moved to the next state
assert receiver.state() != expected_state
if sync_step.payload_type == PlotSyncDone:
assert receiver._last_sync_time != last_sync_time_before
else:
assert receiver._last_sync_time == last_sync_time_before
def plot_sync_setup() -> Tuple[Receiver, List[SyncStepData]]:
harvester_connection = get_dummy_connection(NodeType.HARVESTER)
receiver = Receiver(harvester_connection, dummy_callback) # type:ignore[arg-type]
# Create example plot data
path_list = [str(x) for x in range(0, 40)]
plot_info_list = [
Plot(
filename=str(x),
size=uint8(0),
plot_id=bytes32(token_bytes(32)),
pool_contract_puzzle_hash=None,
pool_public_key=None,
plot_public_key=G1Element(),
file_size=uint64(0),
time_modified=uint64(0),
)
for x in path_list
]
# Manually add the plots we want to remove in tests
receiver._plots = {plot_info.filename: plot_info for plot_info in plot_info_list[0:10]}
sync_steps: List[SyncStepData] = [
SyncStepData(State.idle, receiver.sync_started, PlotSyncStart, False, uint64(0), uint32(len(plot_info_list))),
SyncStepData(State.loaded, receiver.process_loaded, PlotSyncPlotList, plot_info_list[10:20], True),
SyncStepData(State.removed, receiver.process_removed, PlotSyncPathList, path_list[0:10], True),
SyncStepData(State.invalid, receiver.process_invalid, PlotSyncPathList, path_list[20:30], True),
SyncStepData(State.keys_missing, receiver.process_keys_missing, PlotSyncPathList, path_list[30:40], True),
SyncStepData(State.duplicates, receiver.process_duplicates, PlotSyncPathList, path_list[10:20], True),
SyncStepData(State.done, receiver.sync_done, PlotSyncDone, uint64(0)),
]
return receiver, sync_steps
def test_default_values() -> None:
assert_default_values(Receiver(get_dummy_connection(NodeType.HARVESTER), dummy_callback)) # type:ignore[arg-type]
@pytest.mark.asyncio
async def test_reset() -> None:
receiver, sync_steps = plot_sync_setup()
connection_before = receiver.connection()
# Assign some dummy values
receiver._sync_state = State.done
receiver._expected_sync_id = uint64(1)
receiver._expected_message_id = uint64(1)
receiver._last_sync_id = uint64(1)
receiver._last_sync_time = time.time()
receiver._invalid = ["1"]
receiver._keys_missing = ["1"]
receiver._delta.valid.additions = receiver.plots().copy()
receiver._delta.valid.removals = ["1"]
receiver._delta.invalid.additions = ["1"]
receiver._delta.invalid.removals = ["1"]
receiver._delta.keys_missing.additions = ["1"]
receiver._delta.keys_missing.removals = ["1"]
receiver._delta.duplicates.additions = ["1"]
receiver._delta.duplicates.removals = ["1"]
# Call `reset` and make sure all expected values are set back to their defaults.
receiver.reset()
assert_default_values(receiver)
assert receiver._delta == Delta()
# Connection should remain
assert receiver.connection() == connection_before
@pytest.mark.asyncio
async def test_to_dict() -> None:
receiver, sync_steps = plot_sync_setup()
plot_sync_dict_1 = receiver.to_dict()
assert "plots" in plot_sync_dict_1 and len(plot_sync_dict_1["plots"]) == 10
assert "failed_to_open_filenames" in plot_sync_dict_1 and len(plot_sync_dict_1["failed_to_open_filenames"]) == 0
assert "no_key_filenames" in plot_sync_dict_1 and len(plot_sync_dict_1["no_key_filenames"]) == 0
assert "last_sync_time" not in plot_sync_dict_1
assert plot_sync_dict_1["connection"] == {
"node_id": receiver.connection().peer_node_id,
"host": receiver.connection().peer_host,
"port": receiver.connection().peer_port,
}
# We should get equal dicts
plot_sync_dict_2 = receiver.to_dict()
assert plot_sync_dict_1 == plot_sync_dict_2
dict_2_paths = [x.filename for x in plot_sync_dict_2["plots"]]
for plot_info in sync_steps[State.loaded].args[0]:
assert plot_info.filename not in dict_2_paths
# Walk through all states from idle to done and run them with the test data
for state in State:
await run_sync_step(receiver, sync_steps[state], state)
plot_sync_dict_3 = receiver.to_dict()
dict_3_paths = [x.filename for x in plot_sync_dict_3["plots"]]
for plot_info in sync_steps[State.loaded].args[0]:
assert plot_info.filename in dict_3_paths
for path in sync_steps[State.removed].args[0]:
assert path not in plot_sync_dict_3["plots"]
for path in sync_steps[State.invalid].args[0]:
assert path in plot_sync_dict_3["failed_to_open_filenames"]
for path in sync_steps[State.keys_missing].args[0]:
assert path in plot_sync_dict_3["no_key_filenames"]
for path in sync_steps[State.duplicates].args[0]:
assert path in plot_sync_dict_3["duplicates"]
assert plot_sync_dict_3["last_sync_time"] > 0
@pytest.mark.asyncio
async def test_sync_flow() -> None:
receiver, sync_steps = plot_sync_setup()
for plot_info in sync_steps[State.loaded].args[0]:
assert plot_info.filename not in receiver.plots()
for path in sync_steps[State.removed].args[0]:
assert path in receiver.plots()
for path in sync_steps[State.invalid].args[0]:
assert path not in receiver.invalid()
for path in sync_steps[State.keys_missing].args[0]:
assert path not in receiver.keys_missing()
for path in sync_steps[State.duplicates].args[0]:
assert path not in receiver.duplicates()
# Walk through all states from idle to done and run them with the test data
for state in State:
await run_sync_step(receiver, sync_steps[state], state)
for plot_info in sync_steps[State.loaded].args[0]:
assert plot_info.filename in receiver.plots()
for path in sync_steps[State.removed].args[0]:
assert path not in receiver.plots()
for path in sync_steps[State.invalid].args[0]:
assert path in receiver.invalid()
for path in sync_steps[State.keys_missing].args[0]:
assert path in receiver.keys_missing()
for path in sync_steps[State.duplicates].args[0]:
assert path in receiver.duplicates()
# We should be in idle state again
assert receiver.state() == State.idle
@pytest.mark.asyncio
async def test_invalid_ids() -> None:
receiver, sync_steps = plot_sync_setup()
for state in State:
assert receiver.state() == state
current_step = sync_steps[state]
if receiver.state() == State.idle:
# Set last_sync_id for the tests below
receiver._last_sync_id = uint64(1)
# Test "last_sync_id doesn't match"
invalid_last_sync_id_param = PlotSyncStart(
plot_sync_identifier(uint64(0), uint64(0)), False, uint64(2), uint32(0)
)
await current_step.function(invalid_last_sync_id_param)
assert_error_response(receiver, ErrorCodes.invalid_last_sync_id)
# Test "last_sync_id == new_sync_id"
invalid_sync_id_match_param = PlotSyncStart(
plot_sync_identifier(uint64(1), uint64(0)), False, uint64(1), uint32(0)
)
await current_step.function(invalid_sync_id_match_param)
assert_error_response(receiver, ErrorCodes.sync_ids_match)
# Reset the last_sync_id to the default
receiver._last_sync_id = uint64(0)
else:
# Test invalid sync_id
invalid_sync_id_param = current_step.payload_type(
plot_sync_identifier(uint64(10), uint64(receiver.expected_message_id())), *current_step.args
)
await current_step.function(invalid_sync_id_param)
assert_error_response(receiver, ErrorCodes.invalid_identifier)
# Test invalid message_id
invalid_message_id_param = current_step.payload_type(
plot_sync_identifier(receiver.expected_sync_id(), uint64(receiver.expected_message_id() + 1)),
*current_step.args,
)
await current_step.function(invalid_message_id_param)
assert_error_response(receiver, ErrorCodes.invalid_identifier)
payload = create_payload(current_step.payload_type, state == State.idle, *current_step.args)
await current_step.function(payload)
@pytest.mark.parametrize(
["state_to_fail", "expected_error_code"],
[
pytest.param(State.loaded, ErrorCodes.plot_already_available, id="already available plots"),
pytest.param(State.invalid, ErrorCodes.plot_already_available, id="already available paths"),
pytest.param(State.removed, ErrorCodes.plot_not_available, id="not available"),
],
)
@pytest.mark.asyncio
async def test_plot_errors(state_to_fail: State, expected_error_code: ErrorCodes) -> None:
receiver, sync_steps = plot_sync_setup()
for state in State:
assert receiver.state() == state
current_step = sync_steps[state]
if state == state_to_fail:
plot_infos, _ = current_step.args
await current_step.function(create_payload(current_step.payload_type, False, plot_infos, False))
identifier = plot_sync_identifier(receiver.expected_sync_id(), receiver.expected_message_id())
invalid_payload = current_step.payload_type(identifier, plot_infos, True)
await current_step.function(invalid_payload)
if state == state_to_fail:
assert_error_response(receiver, expected_error_code)
return
else:
await current_step.function(
create_payload(current_step.payload_type, state == State.idle, *current_step.args)
)
assert False, "Didn't fail in the expected state"
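The flow tests above assert that one full sync cycle walks every `State` from `idle` to `done` exactly once and then wraps back to `idle`. A minimal standalone sketch of that progression (the `SyncState` enum below is purely illustrative, mirroring the member order implied by the tests — `sync_steps` is indexed by `State`, so values run 0..6 — not the real `chia.plot_sync.util.State` class):

```python
# Illustrative stand-in for the plot sync state progression exercised above;
# not the real chia.plot_sync.util.State class.
from enum import IntEnum


class SyncState(IntEnum):
    idle = 0
    loaded = 1
    removed = 2
    invalid = 3
    keys_missing = 4
    duplicates = 5
    done = 6


def next_state(state: SyncState) -> SyncState:
    # After `done` the cycle wraps back to `idle`, matching the final
    # `assert receiver.state() == State.idle` check in test_sync_flow.
    return SyncState((state + 1) % len(SyncState))


# One full cycle visits every state exactly once and ends where it started.
walk = []
state = SyncState.idle
for _ in range(len(SyncState)):
    walk.append(state)
    state = next_state(state)
assert state == SyncState.idle
```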

import pytest
from chia.plot_sync.exceptions import AlreadyStartedError, InvalidConnectionTypeError
from chia.plot_sync.sender import ExpectedResponse, Sender
from chia.plot_sync.util import Constants
from chia.protocols.harvester_protocol import PlotSyncIdentifier, PlotSyncResponse
from chia.server.ws_connection import NodeType, ProtocolMessageTypes
from chia.util.ints import int16, uint64
from tests.block_tools import BlockTools
from tests.plot_sync.util import get_dummy_connection, plot_sync_identifier


def test_default_values(bt: BlockTools) -> None:
    sender = Sender(bt.plot_manager)
    assert sender._plot_manager == bt.plot_manager
    assert sender._connection is None
    assert sender._sync_id == uint64(0)
    assert sender._next_message_id == uint64(0)
    assert sender._messages == []
    assert sender._last_sync_id == uint64(0)
    assert not sender._stop_requested
    assert sender._task is None
    assert not sender._lock.locked()
    assert sender._response is None


def test_set_connection_values(bt: BlockTools) -> None:
    farmer_connection = get_dummy_connection(NodeType.FARMER)
    sender = Sender(bt.plot_manager)
    # Test invalid NodeType values
    for connection_type in NodeType:
        if connection_type != NodeType.FARMER:
            pytest.raises(
                InvalidConnectionTypeError,
                sender.set_connection,
                get_dummy_connection(connection_type, farmer_connection.peer_node_id),
            )
    # Test that setting a valid connection works
    sender.set_connection(farmer_connection)  # type:ignore[arg-type]
    assert sender._connection is not None
    assert sender._connection == farmer_connection  # type: ignore[comparison-overlap]


@pytest.mark.asyncio
async def test_start_stop_send_task(bt: BlockTools) -> None:
    sender = Sender(bt.plot_manager)
    # Make sure starting/restarting works
    for _ in range(2):
        assert sender._task is None
        await sender.start()
        assert sender._task is not None
        with pytest.raises(AlreadyStartedError):
            await sender.start()
        assert not sender._stop_requested
        sender.stop()
        assert sender._stop_requested
        await sender.await_closed()
        assert not sender._stop_requested
        assert sender._task is None


def test_set_response(bt: BlockTools) -> None:
    sender = Sender(bt.plot_manager)

    def new_expected_response(sync_id: int, message_id: int, message_type: ProtocolMessageTypes) -> ExpectedResponse:
        return ExpectedResponse(message_type, plot_sync_identifier(uint64(sync_id), uint64(message_id)))

    def new_response_message(sync_id: int, message_id: int, message_type: ProtocolMessageTypes) -> PlotSyncResponse:
        return PlotSyncResponse(
            plot_sync_identifier(uint64(sync_id), uint64(message_id)), int16(int(message_type.value)), None
        )

    response_message = new_response_message(0, 1, ProtocolMessageTypes.plot_sync_start)
    assert sender._response is None
    # Should trigger unexpected response because `Sender._response` is `None`
    assert not sender.set_response(response_message)
    # Set `Sender._response` and make sure the response gets assigned properly
    sender._response = new_expected_response(0, 1, ProtocolMessageTypes.plot_sync_start)
    assert sender._response.message is None
    assert sender.set_response(response_message)
    assert sender._response.message is not None
    # Should trigger unexpected response because we already received the message for the currently expected response
    assert not sender.set_response(response_message)
    # Test expired message
    expected_response = new_expected_response(1, 0, ProtocolMessageTypes.plot_sync_start)
    sender._response = expected_response
    expired_identifier = PlotSyncIdentifier(
        uint64(expected_response.identifier.timestamp - Constants.message_timeout - 1),
        expected_response.identifier.sync_id,
        expected_response.identifier.message_id,
    )
    expired_message = PlotSyncResponse(expired_identifier, int16(int(ProtocolMessageTypes.plot_sync_start.value)), None)
    assert not sender.set_response(expired_message)
    # Test invalid sync-id
    sender._response = new_expected_response(2, 0, ProtocolMessageTypes.plot_sync_start)
    assert not sender.set_response(new_response_message(3, 0, ProtocolMessageTypes.plot_sync_start))
    # Test invalid message-id
    sender._response = new_expected_response(2, 1, ProtocolMessageTypes.plot_sync_start)
    assert not sender.set_response(new_response_message(2, 2, ProtocolMessageTypes.plot_sync_start))
    # Test invalid message-type
    sender._response = new_expected_response(3, 0, ProtocolMessageTypes.plot_sync_start)
    assert not sender.set_response(new_response_message(3, 0, ProtocolMessageTypes.plot_sync_loaded))
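The rejection cases exercised by `test_set_response` above — no pending request, duplicate delivery, an expired timestamp, and a mismatched sync id, message id, or message type — can be sketched as a standalone matcher. All names below (`Identifier`, `Expected`, `set_response`, `MESSAGE_TIMEOUT`) are illustrative stand-ins, not the actual `Sender` API:

```python
from dataclasses import dataclass
from typing import Optional

MESSAGE_TIMEOUT = 60  # stand-in for Constants.message_timeout (assumed value)


@dataclass
class Identifier:
    timestamp: int
    sync_id: int
    message_id: int


@dataclass
class Expected:
    message_type: int
    identifier: Identifier
    message: Optional[object] = None


def set_response(expected: Optional[Expected], response_type: int, identifier: Identifier, payload: object) -> bool:
    """Accept a response only if it matches the currently awaited request."""
    if expected is None or expected.message is not None:
        return False  # nothing awaited, or the awaited response already arrived
    if identifier.timestamp < expected.identifier.timestamp - MESSAGE_TIMEOUT:
        return False  # expired: the response is older than the allowed window
    if identifier.sync_id != expected.identifier.sync_id:
        return False  # response belongs to a different sync run
    if identifier.message_id != expected.identifier.message_id:
        return False  # response answers a different message
    if response_type != expected.message_type:
        return False  # response has the wrong message type
    expected.message = payload
    return True
```

Each `assert not sender.set_response(...)` in the test corresponds to one of the early returns in this sketch.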

tests/plot_sync/test_sync_simulated.py (new file)

@@ -0,0 +1,433 @@
import asyncio
import functools
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from secrets import token_bytes
from typing import Any, Dict, List, Optional, Set, Tuple

import pytest
from blspy import G1Element

from chia.farmer.farmer_api import Farmer
from chia.harvester.harvester_api import Harvester
from chia.plot_sync.receiver import Receiver
from chia.plot_sync.sender import Sender
from chia.plot_sync.util import Constants
from chia.plotting.manager import PlotManager
from chia.plotting.util import PlotInfo
from chia.protocols.harvester_protocol import PlotSyncError, PlotSyncResponse
from chia.server.start_service import Service
from chia.server.ws_connection import ProtocolMessageTypes, WSChiaConnection, make_msg
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.generator_tools import list_to_batches
from chia.util.ints import int16, uint64
from tests.plot_sync.util import start_harvester_service
from tests.time_out_assert import time_out_assert

log = logging.getLogger(__name__)


class ErrorSimulation(Enum):
    DropEveryFourthMessage = 1
    DropThreeMessages = 2
    RespondTooLateEveryFourthMessage = 3
    RespondTwice = 4
    NonRecoverableError = 5
    NotConnected = 6


@dataclass
class TestData:
    harvester: Harvester
    plot_sync_sender: Sender
    plot_sync_receiver: Receiver
    event_loop: asyncio.AbstractEventLoop
    plots: Dict[Path, PlotInfo] = field(default_factory=dict)
    invalid: List[PlotInfo] = field(default_factory=list)
    keys_missing: List[PlotInfo] = field(default_factory=list)
    duplicates: List[PlotInfo] = field(default_factory=list)

    async def run(
        self,
        *,
        loaded: List[PlotInfo],
        removed: List[PlotInfo],
        invalid: List[PlotInfo],
        keys_missing: List[PlotInfo],
        duplicates: List[PlotInfo],
        initial: bool,
    ) -> None:
        for plot_info in loaded:
            assert plot_info.prover.get_filename() not in self.plots
        for plot_info in removed:
            assert plot_info.prover.get_filename() in self.plots
        self.invalid = invalid
        self.keys_missing = keys_missing
        self.duplicates = duplicates

        removed_paths: List[Path] = [p.prover.get_filename() for p in removed] if removed is not None else []
        invalid_dict: Dict[Path, int] = {p.prover.get_filename(): 0 for p in self.invalid}
        keys_missing_set: Set[Path] = set([p.prover.get_filename() for p in self.keys_missing])
        duplicates_set: Set[str] = set([p.prover.get_filename() for p in self.duplicates])
        # Inject invalid, key-missing and duplicated plots into the `PlotManager` of the harvester so that the
        # callback calls below can use them to sync them to the farmer.
        self.harvester.plot_manager.failed_to_open_filenames = invalid_dict
        self.harvester.plot_manager.no_key_filenames = keys_missing_set
        for plot_info in loaded:
            plot_path = Path(plot_info.prover.get_filename())
            self.harvester.plot_manager.plot_filename_paths[plot_path.name] = (str(plot_path.parent), set())
        for duplicate in duplicates_set:
            plot_path = Path(duplicate)
            assert plot_path.name in self.harvester.plot_manager.plot_filename_paths
            self.harvester.plot_manager.plot_filename_paths[plot_path.name][1].add(str(plot_path.parent))

        batch_size = self.harvester.plot_manager.refresh_parameter.batch_size
        # Used to capture the sync id in `run_internal`
        sync_id: Optional[uint64] = None

        def run_internal() -> None:
            nonlocal sync_id
            # Simulate one plot manager refresh cycle by calling the sender methods directly.
            self.harvester.plot_sync_sender.sync_start(len(loaded), initial)
            sync_id = self.plot_sync_sender._sync_id
            if len(loaded) == 0:
                self.harvester.plot_sync_sender.process_batch([], 0)
            for remaining, batch in list_to_batches(loaded, batch_size):
                self.harvester.plot_sync_sender.process_batch(batch, remaining)
            self.harvester.plot_sync_sender.sync_done(removed_paths, 0)

        await self.event_loop.run_in_executor(None, run_internal)

        async def sync_done() -> bool:
            assert sync_id is not None
            return self.plot_sync_receiver.last_sync_id() == self.plot_sync_sender._last_sync_id == sync_id

        await time_out_assert(60, sync_done)

        for plot_info in loaded:
            self.plots[plot_info.prover.get_filename()] = plot_info
        for plot_info in removed:
            del self.plots[plot_info.prover.get_filename()]

    def validate_plot_sync(self) -> None:
        assert len(self.plots) == len(self.plot_sync_receiver.plots())
        assert len(self.invalid) == len(self.plot_sync_receiver.invalid())
        assert len(self.keys_missing) == len(self.plot_sync_receiver.keys_missing())
        for plot_info in self.plots.values():
            assert plot_info.prover.get_filename() not in self.plot_sync_receiver.invalid()
            assert plot_info.prover.get_filename() not in self.plot_sync_receiver.keys_missing()
            assert plot_info.prover.get_filename() in self.plot_sync_receiver.plots()
            synced_plot = self.plot_sync_receiver.plots()[plot_info.prover.get_filename()]
            assert plot_info.prover.get_filename() == synced_plot.filename
            assert plot_info.pool_public_key == synced_plot.pool_public_key
            assert plot_info.pool_contract_puzzle_hash == synced_plot.pool_contract_puzzle_hash
            assert plot_info.plot_public_key == synced_plot.plot_public_key
            assert plot_info.file_size == synced_plot.file_size
            assert uint64(int(plot_info.time_modified)) == synced_plot.time_modified
        for plot_info in self.invalid:
            assert plot_info.prover.get_filename() not in self.plot_sync_receiver.plots()
            assert plot_info.prover.get_filename() in self.plot_sync_receiver.invalid()
            assert plot_info.prover.get_filename() not in self.plot_sync_receiver.keys_missing()
            assert plot_info.prover.get_filename() not in self.plot_sync_receiver.duplicates()
        for plot_info in self.keys_missing:
            assert plot_info.prover.get_filename() not in self.plot_sync_receiver.plots()
            assert plot_info.prover.get_filename() not in self.plot_sync_receiver.invalid()
            assert plot_info.prover.get_filename() in self.plot_sync_receiver.keys_missing()
            assert plot_info.prover.get_filename() not in self.plot_sync_receiver.duplicates()
        for plot_info in self.duplicates:
            assert plot_info.prover.get_filename() not in self.plot_sync_receiver.invalid()
            assert plot_info.prover.get_filename() not in self.plot_sync_receiver.keys_missing()
            assert plot_info.prover.get_filename() in self.plot_sync_receiver.duplicates()


@dataclass
class TestRunner:
    test_data: List[TestData]

    def __init__(
        self, harvesters: List[Harvester], farmer: Farmer, event_loop: asyncio.events.AbstractEventLoop
    ) -> None:
        self.test_data = []
        for harvester in harvesters:
            assert harvester.server is not None
            self.test_data.append(
                TestData(
                    harvester,
                    harvester.plot_sync_sender,
                    farmer.plot_sync_receivers[harvester.server.node_id],
                    event_loop,
                )
            )

    async def run(
        self,
        index: int,
        *,
        loaded: List[PlotInfo],
        removed: List[PlotInfo],
        invalid: List[PlotInfo],
        keys_missing: List[PlotInfo],
        duplicates: List[PlotInfo],
        initial: bool,
    ) -> None:
        await self.test_data[index].run(
            loaded=loaded,
            removed=removed,
            invalid=invalid,
            keys_missing=keys_missing,
            duplicates=duplicates,
            initial=initial,
        )
        for data in self.test_data:
            data.validate_plot_sync()


async def skip_processing(self: Any, _: WSChiaConnection, message_type: ProtocolMessageTypes, message: Any) -> bool:
    self.message_counter += 1
    if self.simulate_error == ErrorSimulation.DropEveryFourthMessage:
        if self.message_counter % 4 == 0:
            return True
    if self.simulate_error == ErrorSimulation.DropThreeMessages:
        if 2 < self.message_counter < 6:
            return True
    if self.simulate_error == ErrorSimulation.RespondTooLateEveryFourthMessage:
        if self.message_counter % 4 == 0:
            await asyncio.sleep(Constants.message_timeout + 1)
        return False
    if self.simulate_error == ErrorSimulation.RespondTwice:
        await self.connection().send_message(
            make_msg(
                ProtocolMessageTypes.plot_sync_response,
                PlotSyncResponse(message.identifier, int16(message_type.value), None),
            )
        )
    if self.simulate_error == ErrorSimulation.NonRecoverableError and self.message_counter > 1:
        await self.connection().send_message(
            make_msg(
                ProtocolMessageTypes.plot_sync_response,
                PlotSyncResponse(
                    message.identifier, int16(message_type.value), PlotSyncError(int16(0), "non recoverable", None)
                ),
            )
        )
        self.simulate_error = 0
        return True
    return False


async def _testable_process(
    self: Any, peer: WSChiaConnection, message_type: ProtocolMessageTypes, message: Any
) -> None:
    if await skip_processing(self, peer, message_type, message):
        return
    await self.original_process(peer, message_type, message)


async def create_test_runner(
    harvester_services: List[Service], farmer: Farmer, event_loop: asyncio.events.AbstractEventLoop
) -> TestRunner:
    assert len(farmer.plot_sync_receivers) == 0
    harvesters: List[Harvester] = [await start_harvester_service(service) for service in harvester_services]
    for receiver in farmer.plot_sync_receivers.values():
        receiver.simulate_error = 0  # type: ignore[attr-defined]
        receiver.message_counter = 0  # type: ignore[attr-defined]
        receiver.original_process = receiver._process  # type: ignore[attr-defined]
        receiver._process = functools.partial(_testable_process, receiver)  # type: ignore[assignment]
    return TestRunner(harvesters, farmer, event_loop)


def create_example_plots(count: int) -> List[PlotInfo]:
    @dataclass
    class DiskProver:
        file_name: str
        plot_id: bytes32
        size: int

        def get_filename(self) -> str:
            return self.file_name

        def get_id(self) -> bytes32:
            return self.plot_id

        def get_size(self) -> int:
            return self.size

    return [
        PlotInfo(
            prover=DiskProver(f"{x}", bytes32(token_bytes(32)), x % 255),
            pool_public_key=None,
            pool_contract_puzzle_hash=None,
            plot_public_key=G1Element(),
            file_size=uint64(0),
            time_modified=time.time(),
        )
        for x in range(0, count)
    ]


@pytest.mark.asyncio
async def test_sync_simulated(
    farmer_three_harvester: Tuple[List[Service], Service], event_loop: asyncio.events.AbstractEventLoop
) -> None:
    harvester_services: List[Service]
    farmer_service: Service
    harvester_services, farmer_service = farmer_three_harvester
    farmer: Farmer = farmer_service._node
    test_runner: TestRunner = await create_test_runner(harvester_services, farmer, event_loop)
    plots = create_example_plots(31000)

    await test_runner.run(
        0, loaded=plots[0:10000], removed=[], invalid=[], keys_missing=[], duplicates=plots[0:1000], initial=True
    )
    await test_runner.run(
        1,
        loaded=plots[10000:20000],
        removed=[],
        invalid=plots[30000:30100],
        keys_missing=[],
        duplicates=[],
        initial=True,
    )
    await test_runner.run(
        2,
        loaded=plots[20000:30000],
        removed=[],
        invalid=[],
        keys_missing=plots[30100:30200],
        duplicates=[],
        initial=True,
    )
    await test_runner.run(
        0,
        loaded=[],
        removed=[],
        invalid=plots[30300:30400],
        keys_missing=plots[30400:30453],
        duplicates=[],
        initial=False,
    )
    await test_runner.run(0, loaded=[], removed=[], invalid=[], keys_missing=[], duplicates=[], initial=False)
    await test_runner.run(
        0, loaded=[], removed=plots[5000:10000], invalid=[], keys_missing=[], duplicates=[], initial=False
    )
    await test_runner.run(
        1, loaded=[], removed=plots[10000:20000], invalid=[], keys_missing=[], duplicates=[], initial=False
    )
    await test_runner.run(
        2, loaded=[], removed=plots[20000:29000], invalid=[], keys_missing=[], duplicates=[], initial=False
    )
    await test_runner.run(
        0, loaded=[], removed=plots[0:5000], invalid=[], keys_missing=[], duplicates=[], initial=False
    )
    await test_runner.run(
        2,
        loaded=plots[5000:10000],
        removed=plots[29000:30000],
        invalid=plots[30000:30500],
        keys_missing=plots[30500:31000],
        duplicates=plots[5000:6000],
        initial=False,
    )
    await test_runner.run(
        2, loaded=[], removed=plots[5000:10000], invalid=[], keys_missing=[], duplicates=[], initial=False
    )
    assert len(farmer.plot_sync_receivers) == 3
    for plot_sync in farmer.plot_sync_receivers.values():
        assert len(plot_sync.plots()) == 0


@pytest.mark.parametrize(
    "simulate_error",
    [
        ErrorSimulation.DropEveryFourthMessage,
        ErrorSimulation.DropThreeMessages,
        ErrorSimulation.RespondTooLateEveryFourthMessage,
        ErrorSimulation.RespondTwice,
    ],
)
@pytest.mark.asyncio
async def test_farmer_error_simulation(
    farmer_one_harvester: Tuple[List[Service], Service],
    event_loop: asyncio.events.AbstractEventLoop,
    simulate_error: ErrorSimulation,
) -> None:
    Constants.message_timeout = 5
    harvester_services: List[Service]
    farmer_service: Service
    harvester_services, farmer_service = farmer_one_harvester
    test_runner: TestRunner = await create_test_runner(harvester_services, farmer_service._node, event_loop)
    batch_size = test_runner.test_data[0].harvester.plot_manager.refresh_parameter.batch_size
    plots = create_example_plots(batch_size + 3)
    receiver = test_runner.test_data[0].plot_sync_receiver
    receiver.simulate_error = simulate_error  # type: ignore[attr-defined]
    await test_runner.run(
        0,
        loaded=plots[0 : batch_size + 1],
        removed=[],
        invalid=[plots[batch_size + 1]],
        keys_missing=[plots[batch_size + 2]],
        duplicates=[],
        initial=True,
    )


@pytest.mark.parametrize("simulate_error", [ErrorSimulation.NonRecoverableError, ErrorSimulation.NotConnected])
@pytest.mark.asyncio
async def test_sync_reset_cases(
    farmer_one_harvester: Tuple[List[Service], Service],
    event_loop: asyncio.events.AbstractEventLoop,
    simulate_error: ErrorSimulation,
) -> None:
    harvester_services: List[Service]
    farmer_service: Service
    harvester_services, farmer_service = farmer_one_harvester
    test_runner: TestRunner = await create_test_runner(harvester_services, farmer_service._node, event_loop)
    test_data: TestData = test_runner.test_data[0]
    plot_manager: PlotManager = test_data.harvester.plot_manager
    plots = create_example_plots(30)
    # Inject some data into the `PlotManager` of the harvester so that we can validate the reset worked and
    # triggered a fresh sync of all available data of the plot manager
    for plot_info in plots[0:10]:
        test_data.plots[plot_info.prover.get_filename()] = plot_info
    plot_manager.plots = test_data.plots
    test_data.invalid = plots[10:20]
    test_data.keys_missing = plots[20:30]
    test_data.plot_sync_receiver.simulate_error = simulate_error  # type: ignore[attr-defined]
    sender: Sender = test_runner.test_data[0].plot_sync_sender
    started_sync_id: uint64 = uint64(0)
    plot_manager.failed_to_open_filenames = {p.prover.get_filename(): 0 for p in test_data.invalid}
    plot_manager.no_key_filenames = set([p.prover.get_filename() for p in test_data.keys_missing])

    async def wait_for_reset() -> bool:
        assert started_sync_id != 0
        return sender._sync_id != started_sync_id != 0

    async def sync_done() -> bool:
        assert started_sync_id != 0
        return test_data.plot_sync_receiver.last_sync_id() == sender._last_sync_id == started_sync_id

    # Send start and capture the sync_id
    sender.sync_start(len(plots), True)
    started_sync_id = sender._sync_id
    # Sleep 2 seconds to make sure we have a different sync_id after the reset which gets triggered
    await asyncio.sleep(2)
    saved_connection = sender._connection
    if simulate_error == ErrorSimulation.NotConnected:
        sender._connection = None
    sender.process_batch(plots, 0)
    await time_out_assert(60, wait_for_reset)
    started_sync_id = sender._sync_id
    sender._connection = saved_connection
    await time_out_assert(60, sync_done)
    test_runner.test_data[0].validate_plot_sync()
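`run_internal` above feeds `Sender.process_batch` from `list_to_batches`, which yields `(remaining, batch)` pairs. A plausible re-implementation of that contract follows — a sketch assuming the behavior implied by the call site, not necessarily the exact `chia.util.generator_tools` code:

```python
from typing import Iterator, List, Tuple, TypeVar

T = TypeVar("T")


def list_to_batches(items: List[T], batch_size: int) -> Iterator[Tuple[int, List[T]]]:
    # Yield (number of items still left after this batch, batch) pairs, so the
    # consumer can report progress the way run_internal() does above.
    if batch_size <= 0:
        raise ValueError("batch_size must be greater than 0")
    for start in range(0, len(items), batch_size):
        batch = items[start : start + batch_size]
        yield len(items) - (start + len(batch)), batch
```

Under this contract the final batch always yields a `remaining` of 0, which is how the receiver can tell the last batch from the intermediate ones.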

tests/plot_sync/util.py (new file)

@@ -0,0 +1,53 @@
import time
from dataclasses import dataclass
from secrets import token_bytes
from typing import Optional

from chia.harvester.harvester_api import Harvester
from chia.plot_sync.sender import Sender
from chia.protocols.harvester_protocol import PlotSyncIdentifier
from chia.server.start_service import Service
from chia.server.ws_connection import Message, NodeType
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint64
from tests.time_out_assert import time_out_assert


@dataclass
class WSChiaConnectionDummy:
    connection_type: NodeType
    peer_node_id: bytes32
    peer_host: str = "localhost"
    peer_port: int = 0
    last_sent_message: Optional[Message] = None

    async def send_message(self, message: Message) -> None:
        self.last_sent_message = message


def get_dummy_connection(node_type: NodeType, peer_id: Optional[bytes32] = None) -> WSChiaConnectionDummy:
    return WSChiaConnectionDummy(node_type, bytes32(token_bytes(32)) if peer_id is None else peer_id)


def plot_sync_identifier(current_sync_id: uint64, message_id: uint64) -> PlotSyncIdentifier:
    return PlotSyncIdentifier(uint64(int(time.time())), current_sync_id, message_id)


async def start_harvester_service(harvester_service: Service) -> Harvester:
    # Set the `last_refresh_time` of the plot manager to avoid initial plot loading
    harvester: Harvester = harvester_service._node
    harvester.plot_manager.last_refresh_time = time.time()
    await harvester_service.start()
    harvester.plot_manager.stop_refreshing()  # type: ignore[no-untyped-call]  # TODO: add typing in PlotManager
    assert harvester.plot_sync_sender._sync_id == 0
    assert harvester.plot_sync_sender._next_message_id == 0
    assert harvester.plot_sync_sender._last_sync_id == 0
    assert harvester.plot_sync_sender._messages == []

    def wait_for_farmer_connection(plot_sync_sender: Sender) -> bool:
        return plot_sync_sender._connection is not None

    await time_out_assert(10, wait_for_farmer_connection, True, harvester.plot_sync_sender)
    return harvester
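`start_harvester_service` waits for the farmer connection via `time_out_assert`, which polls a predicate until it returns the expected value or a timeout expires. A minimal synchronous sketch of that polling pattern — `time_out_assert_sync` is a hypothetical name; the real helper in `tests.time_out_assert` is async and accepts extra call arguments:

```python
import time
from typing import Any, Callable


def time_out_assert_sync(
    timeout: float, predicate: Callable[..., Any], expected: Any = True, *args: Any, interval: float = 0.05
) -> None:
    # Poll predicate(*args) until it equals `expected`; fail on timeout.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if predicate(*args) == expected:
            return
        time.sleep(interval)
    raise AssertionError(f"condition not met within {timeout} seconds")
```

The point of the pattern is that the tests above never sleep for a fixed duration when waiting on the sync protocol; they poll the observable state (`last_sync_id`, `_connection`) and fail fast once the deadline passes.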


@@ -236,7 +236,7 @@ async def test_plot_refreshing(test_plot_environment):
             trigger=trigger_remove_plot,
             test_path=drop_path,
             expect_loaded=[],
-            expect_removed=[drop_path],
+            expect_removed=[],
             expect_processed=len(env.dir_1) + len(env.dir_2) + len(dir_duplicates),
             expect_duplicates=len(dir_duplicates),
             expected_directories=3,
@@ -262,7 +262,7 @@ async def test_plot_refreshing(test_plot_environment):
             trigger=remove_plot_directory,
             test_path=dir_duplicates.path,
             expect_loaded=[],
-            expect_removed=dir_duplicates.path_list(),
+            expect_removed=[],
             expect_processed=len(env.dir_1) + len(env.dir_2),
             expect_duplicates=0,
             expected_directories=2,
@@ -316,7 +316,7 @@ async def test_plot_refreshing(test_plot_environment):
             trigger=trigger_remove_plot,
             test_path=drop_path,
             expect_loaded=[],
-            expect_removed=[drop_path],
+            expect_removed=[],
             expect_processed=len(env.dir_1) + len(env.dir_2) + len(dir_duplicates),
             expect_duplicates=len(env.dir_1),
             expected_directories=3,
@@ -357,6 +357,17 @@ async def test_plot_refreshing(test_plot_environment):
     )
 
 
+@pytest.mark.asyncio
+async def test_initial_refresh_flag(test_plot_environment: TestEnvironment) -> None:
+    env: TestEnvironment = test_plot_environment
+    assert env.refresh_tester.plot_manager.initial_refresh()
+    for _ in range(2):
+        await env.refresh_tester.run(PlotRefreshResult())
+        assert not env.refresh_tester.plot_manager.initial_refresh()
+        env.refresh_tester.plot_manager.reset()
+        assert env.refresh_tester.plot_manager.initial_refresh()
+
+
 @pytest.mark.asyncio
 async def test_invalid_plots(test_plot_environment):
     env: TestEnvironment = test_plot_environment


@@ -1,12 +1,15 @@
 import asyncio
 import logging
 from secrets import token_bytes
-from typing import Dict, List
+from typing import AsyncIterator, Dict, List, Tuple
+from pathlib import Path
 from chia.consensus.constants import ConsensusConstants
+from chia.cmds.init_funcs import init
 from chia.full_node.full_node_api import FullNodeAPI
 from chia.server.start_service import Service
 from chia.server.start_wallet import service_kwargs_for_wallet
+from chia.util.config import load_config, save_config
 from chia.util.hash import std_hash
 from chia.util.ints import uint16, uint32
 from chia.util.keychain import bytes_to_mnemonic
@@ -294,7 +297,7 @@ async def setup_harvester_farmer(bt: BlockTools, consensus_constants: ConsensusConstants
     harvester_rpc_port = find_available_listen_port("harvester rpc")
     node_iters = [
         setup_harvester(
-            bt,
+            bt.root_path,
             bt.config["self_hostname"],
             harvester_port,
             harvester_rpc_port,
@@ -320,6 +323,62 @@ async def setup_harvester_farmer(bt: BlockTools, consensus_constants: ConsensusConstants
     await _teardown_nodes(node_iters)
+
+
+async def setup_farmer_multi_harvester(
+    block_tools: BlockTools,
+    harvester_count: int,
+    temp_dir: Path,
+    consensus_constants: ConsensusConstants,
+) -> AsyncIterator[Tuple[List[Service], Service]]:
+    farmer_port = find_available_listen_port("farmer")
+    farmer_rpc_port = find_available_listen_port("farmer rpc")
+    node_iterators = [
+        setup_farmer(
+            block_tools, block_tools.config["self_hostname"], farmer_port, farmer_rpc_port, consensus_constants
+        )
+    ]
+    for i in range(0, harvester_count):
+        root_path: Path = temp_dir / str(i)
+        init(None, root_path)
+        init(block_tools.root_path / "config" / "ssl" / "ca", root_path)
+        config = load_config(root_path, "config.yaml")
+        config["logging"]["log_stdout"] = True
+        config["selected_network"] = "testnet0"
+        config["harvester"]["selected_network"] = "testnet0"
+        harvester_port = find_available_listen_port("harvester")
+        harvester_rpc_port = find_available_listen_port("harvester rpc")
+        save_config(root_path, "config.yaml", config)
+        node_iterators.append(
+            setup_harvester(
+                root_path,
+                block_tools.config["self_hostname"],
+                harvester_port,
+                harvester_rpc_port,
+                farmer_port,
+                consensus_constants,
+                False,
+            )
+        )
+    farmer_service = await node_iterators[0].__anext__()
+    harvester_services = []
+    for node in node_iterators[1:]:
+        harvester_service = await node.__anext__()
+        harvester_services.append(harvester_service)
+
+    yield harvester_services, farmer_service
+
+    for harvester_service in harvester_services:
+        harvester_service.stop()
+        await harvester_service.wait_closed()
+    farmer_service.stop()
+    await farmer_service.wait_closed()
+    await _teardown_nodes(node_iterators)
 
 
 async def setup_full_system(
     consensus_constants: ConsensusConstants,
     shared_b_tools: BlockTools,
@@ -353,7 +412,7 @@ async def setup_full_system(
     node_iters = [
         setup_introducer(shared_b_tools, introducer_port),
         setup_harvester(
-            shared_b_tools,
+            shared_b_tools.root_path,
             shared_b_tools.config["self_hostname"],
             harvester_port,
             harvester_rpc_port,


@@ -2,6 +2,7 @@ import asyncio
 import logging
 import signal
 import sqlite3
+from pathlib import Path
 from secrets import token_bytes
 from typing import AsyncGenerator, Optional
@@ -16,8 +17,8 @@ from chia.server.start_timelord import service_kwargs_for_timelord
 from chia.server.start_wallet import service_kwargs_for_wallet
 from chia.simulator.start_simulator import service_kwargs_for_full_node_simulator
 from chia.timelord.timelord_launcher import kill_processes, spawn_process
-from chia.types.peer_info import PeerInfo
 from chia.util.bech32m import encode_puzzle_hash
+from chia.util.config import load_config, save_config
 from chia.util.ints import uint16
 from chia.util.keychain import bytes_to_mnemonic
 from tests.block_tools import BlockTools
@@ -184,7 +185,7 @@ async def setup_wallet_node(
 async def setup_harvester(
-    b_tools: BlockTools,
+    root_path: Path,
     self_hostname: str,
     port,
     rpc_port,
@@ -192,15 +193,14 @@ async def setup_harvester(
     consensus_constants: ConsensusConstants,
     start_service: bool = True,
 ):
-    config = b_tools.config["harvester"]
-    config["port"] = port
-    config["rpc_port"] = rpc_port
-    kwargs = service_kwargs_for_harvester(b_tools.root_path, config, consensus_constants)
+    config = load_config(root_path, "config.yaml")
+    config["harvester"]["port"] = port
+    config["harvester"]["rpc_port"] = rpc_port
+    config["harvester"]["farmer_peer"]["host"] = self_hostname
+    config["harvester"]["farmer_peer"]["port"] = farmer_port
+    save_config(root_path, "config.yaml", config)
+    kwargs = service_kwargs_for_harvester(root_path, config["harvester"], consensus_constants)
     kwargs.update(
         server_listen_ports=[port],
         advertised_port=port,
-        connect_peers=[PeerInfo(self_hostname, farmer_port)],
         parse_cli_args=False,
         connect_to_daemon=False,
         service_name_prefix="test_",