mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-11-09 17:36:14 +03:00
3a5c1a5bd5
* hint chia.harvester.harvester * hint chia.harvester.harvester_api
331 lines
14 KiB
Python
331 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
import traceback
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any, Generic, Iterable, List, Optional, Tuple, Type, TypeVar
|
|
|
|
from typing_extensions import Protocol
|
|
|
|
from chia.plot_sync.exceptions import AlreadyStartedError, InvalidConnectionTypeError
|
|
from chia.plot_sync.util import Constants
|
|
from chia.plotting.manager import PlotManager
|
|
from chia.plotting.util import PlotInfo
|
|
from chia.protocols.harvester_protocol import (
|
|
Plot,
|
|
PlotSyncDone,
|
|
PlotSyncIdentifier,
|
|
PlotSyncPathList,
|
|
PlotSyncPlotList,
|
|
PlotSyncResponse,
|
|
PlotSyncStart,
|
|
)
|
|
from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
|
from chia.server.outbound_message import NodeType, make_msg
|
|
from chia.server.ws_connection import WSChiaConnection
|
|
from chia.util.generator_tools import list_to_batches
|
|
from chia.util.ints import int16, uint32, uint64
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def _convert_plot_info_list(plot_infos: List[PlotInfo]) -> List[Plot]:
|
|
converted: List[Plot] = []
|
|
for plot_info in plot_infos:
|
|
converted.append(
|
|
Plot(
|
|
filename=plot_info.prover.get_filename(),
|
|
size=plot_info.prover.get_size(),
|
|
plot_id=plot_info.prover.get_id(),
|
|
pool_public_key=plot_info.pool_public_key,
|
|
pool_contract_puzzle_hash=plot_info.pool_contract_puzzle_hash,
|
|
plot_public_key=plot_info.plot_public_key,
|
|
file_size=uint64(plot_info.file_size),
|
|
time_modified=uint64(int(plot_info.time_modified)),
|
|
)
|
|
)
|
|
return converted
|
|
|
|
|
|
class PayloadType(Protocol):
|
|
def __init__(self, identifier: PlotSyncIdentifier, *args: object) -> None:
|
|
...
|
|
|
|
def __bytes__(self) -> bytes:
|
|
pass
|
|
|
|
|
|
T = TypeVar("T", bound=PayloadType)
|
|
|
|
|
|
@dataclass
|
|
class MessageGenerator(Generic[T]):
|
|
sync_id: uint64
|
|
message_type: ProtocolMessageTypes
|
|
message_id: uint64
|
|
payload_type: Type[T]
|
|
args: Iterable[object]
|
|
|
|
def generate(self) -> Tuple[PlotSyncIdentifier, T]:
|
|
identifier = PlotSyncIdentifier(uint64(int(time.time())), self.sync_id, self.message_id)
|
|
payload = self.payload_type(identifier, *self.args)
|
|
return identifier, payload
|
|
|
|
|
|
@dataclass
|
|
class ExpectedResponse:
|
|
message_type: ProtocolMessageTypes
|
|
identifier: PlotSyncIdentifier
|
|
message: Optional[PlotSyncResponse] = None
|
|
|
|
def __str__(self) -> str:
|
|
return (
|
|
f"expected_message_type: {self.message_type.name}, "
|
|
f"expected_identifier: {self.identifier}, message {self.message}"
|
|
)
|
|
|
|
|
|
class Sender:
|
|
_plot_manager: PlotManager
|
|
_connection: Optional[WSChiaConnection]
|
|
_sync_id: uint64
|
|
_next_message_id: uint64
|
|
_messages: List[MessageGenerator[PayloadType]]
|
|
_last_sync_id: uint64
|
|
_stop_requested = False
|
|
_task: Optional[asyncio.Task[None]]
|
|
_response: Optional[ExpectedResponse]
|
|
|
|
def __init__(self, plot_manager: PlotManager) -> None:
|
|
self._plot_manager = plot_manager
|
|
self._connection = None
|
|
self._sync_id = uint64(0)
|
|
self._next_message_id = uint64(0)
|
|
self._messages = []
|
|
self._last_sync_id = uint64(0)
|
|
self._stop_requested = False
|
|
self._task = None
|
|
self._response = None
|
|
|
|
def __str__(self) -> str:
|
|
return f"sync_id {self._sync_id}, next_message_id {self._next_message_id}, messages {len(self._messages)}"
|
|
|
|
async def start(self) -> None:
|
|
if self._task is not None and self._stop_requested:
|
|
await self.await_closed()
|
|
if self._task is None:
|
|
self._task = asyncio.create_task(self._run())
|
|
if not self._plot_manager.initial_refresh() or self._sync_id != 0:
|
|
self._reset()
|
|
else:
|
|
raise AlreadyStartedError()
|
|
|
|
def stop(self) -> None:
|
|
self._stop_requested = True
|
|
|
|
async def await_closed(self) -> None:
|
|
if self._task is not None:
|
|
await self._task
|
|
self._task = None
|
|
self._reset()
|
|
self._stop_requested = False
|
|
|
|
def set_connection(self, connection: WSChiaConnection) -> None:
|
|
assert connection.connection_type is not None
|
|
if connection.connection_type != NodeType.FARMER:
|
|
raise InvalidConnectionTypeError(connection.connection_type, NodeType.HARVESTER)
|
|
self._connection = connection
|
|
|
|
def bump_next_message_id(self) -> None:
|
|
self._next_message_id = uint64(self._next_message_id + 1)
|
|
|
|
def _reset(self) -> None:
|
|
log.debug(f"_reset {self}")
|
|
self._last_sync_id = uint64(0)
|
|
self._sync_id = uint64(0)
|
|
self._next_message_id = uint64(0)
|
|
self._messages.clear()
|
|
if self._task is not None:
|
|
self.sync_start(self._plot_manager.plot_count(), True)
|
|
for remaining, batch in list_to_batches(
|
|
list(self._plot_manager.plots.values()), self._plot_manager.refresh_parameter.batch_size
|
|
):
|
|
self.process_batch(batch, remaining)
|
|
self.sync_done([], 0)
|
|
|
|
async def _wait_for_response(self) -> bool:
|
|
start = time.time()
|
|
assert self._response is not None
|
|
while time.time() - start < Constants.message_timeout and self._response.message is None:
|
|
await asyncio.sleep(0.1)
|
|
return self._response.message is not None
|
|
|
|
def set_response(self, response: PlotSyncResponse) -> bool:
|
|
if self._response is None or self._response.message is not None:
|
|
log.warning(f"set_response skip unexpected response: {response}")
|
|
return False
|
|
if time.time() - float(response.identifier.timestamp) > Constants.message_timeout:
|
|
log.warning(f"set_response skip expired response: {response}")
|
|
return False
|
|
if response.identifier.sync_id != self._response.identifier.sync_id:
|
|
log.warning(
|
|
"set_response unexpected sync-id: {response.identifier.sync_id}/{self._response.identifier.sync_id}"
|
|
)
|
|
return False
|
|
if response.identifier.message_id != self._response.identifier.message_id:
|
|
log.warning(
|
|
"set_response unexpected message-id: "
|
|
f"{response.identifier.message_id}/{self._response.identifier.message_id}"
|
|
)
|
|
return False
|
|
if response.message_type != int16(self._response.message_type.value):
|
|
log.warning(
|
|
"set_response unexpected message-type: {response.message_type}/{self._response.message_type.value}"
|
|
)
|
|
return False
|
|
log.debug(f"set_response valid {response}")
|
|
self._response.message = response
|
|
return True
|
|
|
|
def _add_message(self, message_type: ProtocolMessageTypes, payload_type: Any, *args: Any) -> None:
|
|
assert self._sync_id != 0
|
|
message_id = uint64(len(self._messages))
|
|
self._messages.append(MessageGenerator(self._sync_id, message_type, message_id, payload_type, args))
|
|
|
|
async def _send_next_message(self) -> bool:
|
|
def failed(message: str) -> bool:
|
|
# By forcing a reset we try to get back into a normal state if some not recoverable failure came up.
|
|
log.warning(message)
|
|
self._reset()
|
|
return False
|
|
|
|
assert len(self._messages) >= self._next_message_id
|
|
message_generator = self._messages[self._next_message_id]
|
|
identifier, payload = message_generator.generate()
|
|
if self._sync_id == 0 or identifier.sync_id != self._sync_id or identifier.message_id != self._next_message_id:
|
|
return failed(f"Invalid message generator {message_generator} for {self}")
|
|
|
|
self._response = ExpectedResponse(message_generator.message_type, identifier)
|
|
log.debug(f"_send_next_message send {message_generator.message_type.name}: {payload}")
|
|
if self._connection is None or not await self._connection.send_message(
|
|
make_msg(message_generator.message_type, payload)
|
|
):
|
|
return failed(f"Send failed {self._connection}")
|
|
if not await self._wait_for_response():
|
|
log.info(f"_send_next_message didn't receive response {self._response}")
|
|
return False
|
|
|
|
assert self._response.message is not None
|
|
if self._response.message.error is not None:
|
|
recovered = False
|
|
expected = self._response.message.error.expected_identifier
|
|
# If we have a recoverable error there is a `expected_identifier` included
|
|
if expected is not None:
|
|
# If the receiver has a zero sync/message id and we already sent all messages from the current event
|
|
# we most likely missed the response to the done message. We can finalize the sync and move on here.
|
|
all_sent = (
|
|
self._messages[-1].message_type == ProtocolMessageTypes.plot_sync_done
|
|
and self._next_message_id == len(self._messages) - 1
|
|
)
|
|
if expected.sync_id == expected.message_id == 0 and all_sent:
|
|
self._finalize_sync()
|
|
recovered = True
|
|
elif self._sync_id == expected.sync_id and expected.message_id < len(self._messages):
|
|
self._next_message_id = expected.message_id
|
|
recovered = True
|
|
if not recovered:
|
|
return failed(f"Not recoverable error {self._response.message}")
|
|
return True
|
|
|
|
if self._response.message_type == ProtocolMessageTypes.plot_sync_done:
|
|
self._finalize_sync()
|
|
else:
|
|
self.bump_next_message_id()
|
|
|
|
return True
|
|
|
|
def _add_list_batched(self, message_type: ProtocolMessageTypes, payload_type: Any, data: List[Any]) -> None:
|
|
if len(data) == 0:
|
|
self._add_message(message_type, payload_type, [], True)
|
|
return
|
|
for remaining, batch in list_to_batches(data, self._plot_manager.refresh_parameter.batch_size):
|
|
self._add_message(message_type, payload_type, batch, remaining == 0)
|
|
|
|
def sync_start(self, count: float, initial: bool) -> None:
|
|
log.debug(f"sync_start {self}: count {count}, initial {initial}")
|
|
while self.sync_active():
|
|
if self._stop_requested:
|
|
log.debug("sync_start aborted")
|
|
return
|
|
time.sleep(0.1)
|
|
sync_id = int(time.time())
|
|
# Make sure we have unique sync-id's even if we restart refreshing within a second (i.e. in tests)
|
|
if sync_id == self._last_sync_id:
|
|
sync_id = sync_id + 1
|
|
log.debug(f"sync_start {sync_id}")
|
|
self._sync_id = uint64(sync_id)
|
|
self._add_message(
|
|
ProtocolMessageTypes.plot_sync_start, PlotSyncStart, initial, self._last_sync_id, uint32(int(count))
|
|
)
|
|
|
|
def process_batch(self, loaded: List[PlotInfo], remaining: int) -> None:
|
|
log.debug(f"process_batch {self}: loaded {len(loaded)}, remaining {remaining}")
|
|
if len(loaded) > 0 or remaining == 0:
|
|
converted = _convert_plot_info_list(loaded)
|
|
self._add_message(ProtocolMessageTypes.plot_sync_loaded, PlotSyncPlotList, converted, remaining == 0)
|
|
|
|
def sync_done(self, removed: List[Path], duration: float) -> None:
|
|
log.debug(f"sync_done {self}: removed {len(removed)}, duration {duration}")
|
|
removed_list = [str(x) for x in removed]
|
|
self._add_list_batched(
|
|
ProtocolMessageTypes.plot_sync_removed,
|
|
PlotSyncPathList,
|
|
removed_list,
|
|
)
|
|
failed_to_open_list = [str(x) for x in list(self._plot_manager.failed_to_open_filenames)]
|
|
self._add_list_batched(ProtocolMessageTypes.plot_sync_invalid, PlotSyncPathList, failed_to_open_list)
|
|
no_key_list = [str(x) for x in self._plot_manager.no_key_filenames]
|
|
self._add_list_batched(ProtocolMessageTypes.plot_sync_keys_missing, PlotSyncPathList, no_key_list)
|
|
duplicates_list = self._plot_manager.get_duplicates().copy()
|
|
self._add_list_batched(ProtocolMessageTypes.plot_sync_duplicates, PlotSyncPathList, duplicates_list)
|
|
self._add_message(ProtocolMessageTypes.plot_sync_done, PlotSyncDone, uint64(int(duration)))
|
|
|
|
def _finalize_sync(self) -> None:
|
|
log.debug(f"_finalize_sync {self}")
|
|
assert self._sync_id != 0
|
|
self._last_sync_id = self._sync_id
|
|
self._next_message_id = uint64(0)
|
|
self._messages.clear()
|
|
# Do this at the end since `_sync_id` is used as sync active indicator.
|
|
self._sync_id = uint64(0)
|
|
|
|
def sync_active(self) -> bool:
|
|
return self._sync_id != 0
|
|
|
|
def connected(self) -> bool:
|
|
return self._connection is not None
|
|
|
|
async def _run(self) -> None:
|
|
"""
|
|
This is the sender task responsible to send new messages during sync as they come into Sender._messages
|
|
triggered by the plot manager callback.
|
|
"""
|
|
while not self._stop_requested:
|
|
try:
|
|
while not self.connected() or not self.sync_active():
|
|
if self._stop_requested:
|
|
return
|
|
await asyncio.sleep(0.1)
|
|
while not self._stop_requested and self.sync_active():
|
|
if self._next_message_id >= len(self._messages):
|
|
await asyncio.sleep(0.1)
|
|
continue
|
|
if not await self._send_next_message():
|
|
await asyncio.sleep(Constants.message_timeout)
|
|
except Exception as e:
|
|
log.error(f"Exception: {e} {traceback.format_exc()}")
|
|
self._reset()
|