mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-11-13 03:12:24 +03:00
421 lines
18 KiB
Python
421 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
import logging
|
|
import random
|
|
import time
|
|
from secrets import token_bytes
|
|
from typing import Any, Callable, List, Tuple, Type, Union
|
|
|
|
import pytest
|
|
from blspy import G1Element
|
|
|
|
from chia.plot_sync.delta import Delta
|
|
from chia.plot_sync.receiver import Receiver, Sync
|
|
from chia.plot_sync.util import ErrorCodes, State
|
|
from chia.protocols.harvester_protocol import (
|
|
Plot,
|
|
PlotSyncDone,
|
|
PlotSyncIdentifier,
|
|
PlotSyncPathList,
|
|
PlotSyncPlotList,
|
|
PlotSyncResponse,
|
|
PlotSyncStart,
|
|
)
|
|
from chia.server.outbound_message import NodeType
|
|
from chia.types.blockchain_format.sized_bytes import bytes32
|
|
from chia.util.ints import uint8, uint32, uint64
|
|
from chia.util.misc import get_list_or_len
|
|
from chia.util.streamable import _T_Streamable
|
|
from tests.plot_sync.util import get_dummy_connection
|
|
|
|
# Module-level logger for this test module.
log: logging.Logger = logging.getLogger(__name__)

# Message id that `create_payload` stamps on the next payload it builds; `create_payload`
# rewinds it to 0 when asked to start a new sync and advances it by one per payload.
next_message_id = uint64(0)
|
|
|
|
|
|
def assert_default_values(receiver: Receiver) -> None:
    """Assert that ``receiver`` carries no sync progress and no plot data.

    The current and the last sync must both equal a freshly constructed ``Sync()``, the plot
    map must be empty and the invalid / keys-missing / duplicates lists must all be empty.
    """
    for sync_getter in (receiver.current_sync, receiver.last_sync):
        assert sync_getter() == Sync()
    assert receiver.plots() == {}
    for list_getter in (receiver.invalid, receiver.keys_missing, receiver.duplicates):
        assert list_getter() == []
|
|
|
|
|
|
async def dummy_callback(_: bytes32, __: Delta) -> None:
    """No-op delta callback handed to ``Receiver`` in these tests; both arguments are ignored."""
    return None
|
|
|
|
|
|
class SyncStepData:
    """Description of one step of a plot sync, consumed by `run_sync_step` and the tests.

    Attributes are stored exactly as passed to the constructor:
        state: receiver state in which this step is expected to run.
        function: receiver coroutine method that handles the step's payload.
        payload_type: protocol message class built for the step (see `create_payload`).
        args: remaining positional arguments, used to build the step's payload.
    """

    state: State
    function: Any
    payload_type: Any
    args: Any

    def __init__(
        self, state: State, function: Callable[[_T_Streamable], Any], payload_type: Type[_T_Streamable], *args: Any
    ) -> None:
        self.state, self.function, self.payload_type, self.args = state, function, payload_type, args
|
|
|
|
|
|
def plot_sync_identifier(current_sync_id: uint64, message_id: uint64) -> PlotSyncIdentifier:
    """Build a ``PlotSyncIdentifier`` with a zero timestamp for the given sync id and message id."""
    zero_timestamp = uint64(0)
    return PlotSyncIdentifier(zero_timestamp, current_sync_id, message_id)
|
|
|
|
|
|
def create_payload(payload_type: Any, start: bool, *args: Any) -> Any:
    """Instantiate ``payload_type`` for the sync with id 1, followed by ``args``.

    The message id comes from the module-level ``next_message_id`` counter, which is advanced
    by one afterwards. With ``start=True`` the counter is first rewound to 0; `run_sync_step`
    does this for the idle-state (``PlotSyncStart``) step.
    """
    global next_message_id
    if start:
        next_message_id = uint64(0)
    identifier_for_payload = plot_sync_identifier(uint64(1), next_message_id)
    next_message_id = uint64(next_message_id + 1)
    return payload_type(identifier_for_payload, *args)
|
|
|
|
|
|
def assert_error_response(plot_sync: Receiver, error_code: ErrorCodes) -> None:
    """Assert that the last message sent on ``plot_sync``'s connection carries ``error_code``.

    The sent message data is decoded as a ``PlotSyncResponse``, whose ``error`` must be set and
    whose ``error.code`` must equal ``error_code.value``.
    """
    connection = plot_sync.connection()
    assert connection is not None
    # `last_sent_message` is not part of WSChiaConnection; it only exists on the
    # WSChiaConnectionDummy used for testing, hence the type: ignore.
    last_message = connection.last_sent_message  # type: ignore[attr-defined]
    assert last_message is not None
    error = PlotSyncResponse.from_bytes(last_message.data).error
    assert error is not None
    assert error.code == error_code.value
|
|
|
|
|
|
def pre_function_validate(receiver: Receiver, data: Union[List[Plot], List[str]], expected_state: State) -> None:
    """Check ``receiver`` before a batch of ``data`` is processed in ``expected_state``.

    - ``State.loaded``: ``data`` must hold ``Plot`` objects, none of whose filenames is in
      ``receiver.plots()`` yet.
    - ``State.removed``: every path in ``data`` must currently be in ``receiver.plots()``.
    - ``State.invalid`` / ``State.keys_missing`` / ``State.duplicates``: no path in ``data`` may
      already be in the receiver's corresponding list.
    Other states are not checked.
    """
    # Each getter is called once per check instead of once per item; nothing mutates the
    # receiver while we iterate.
    if expected_state == State.loaded:
        plots = receiver.plots()
        for plot_info in data:
            # isinstance instead of `type(...) ==`: idiomatic and lets mypy narrow the Union.
            assert isinstance(plot_info, Plot)
            assert plot_info.filename not in plots
    elif expected_state == State.removed:
        plots = receiver.plots()
        for path in data:
            assert path in plots
    elif expected_state == State.invalid:
        invalid = receiver.invalid()
        for path in data:
            assert path not in invalid
    elif expected_state == State.keys_missing:
        keys_missing = receiver.keys_missing()
        for path in data:
            assert path not in keys_missing
    elif expected_state == State.duplicates:
        duplicates = receiver.duplicates()
        for path in data:
            assert path not in duplicates
|
|
|
|
|
|
def post_function_validate(receiver: Receiver, data: Union[List[Plot], List[str]], expected_state: State) -> None:
    """Check that a processed batch of ``data`` was recorded in the receiver's in-progress sync delta.

    - ``State.loaded``: each ``Plot``'s filename must be in ``delta.valid.additions``.
    - ``State.removed``: each path must be in ``delta.valid.removals``.
    - ``State.invalid`` / ``State.keys_missing`` / ``State.duplicates``: each path must be in the
      ``additions`` of the matching delta section.
    Other states are not checked.
    """
    delta = receiver._current_sync.delta
    if expected_state == State.loaded:
        for plot_info in data:
            # isinstance instead of `type(...) ==`: idiomatic and lets mypy narrow the Union.
            assert isinstance(plot_info, Plot)
            assert plot_info.filename in delta.valid.additions
    elif expected_state == State.removed:
        for path in data:
            assert path in delta.valid.removals
    elif expected_state == State.invalid:
        for path in data:
            assert path in delta.invalid.additions
    elif expected_state == State.keys_missing:
        for path in data:
            assert path in delta.keys_missing.additions
    elif expected_state == State.duplicates:
        for path in data:
            assert path in delta.duplicates.additions
|
|
|
|
|
|
async def run_sync_step(receiver: Receiver, sync_step: SyncStepData) -> None:
    """Feed ``sync_step`` to ``receiver`` and validate the resulting state transition.

    Plot/path list steps expect exactly 10 items, which are sent in four batches of 1, 2, 3 and
    4 items. Only the last batch passes ``True`` as the trailing flag. Around each batch, the data
    is validated with `pre_function_validate` / `post_function_validate`. ``plots_processed`` must
    grow by the batch size, except for removals, where it must stay unchanged. Start/done steps
    are sent as a single payload.

    Afterwards the receiver must have left ``sync_step.state``. ``_last_sync.time_done`` must have
    changed if and only if the step was a ``PlotSyncDone``.

    This is a helper awaited by the tests, not a test itself, so it carries no pytest marker.
    """
    assert receiver.current_sync().state == sync_step.state
    last_sync_time_before = receiver._last_sync.time_done
    # For the list types invoke the trigger function in batches
    if sync_step.payload_type in (PlotSyncPlotList, PlotSyncPathList):
        step_data, _ = sync_step.args
        assert len(step_data) == 10
        # Slices [0:1], [1:3], [3:6], [6:10] -> batches of 1, 2, 3 and 4 items
        indexes = [0, 1, 3, 6, 10]
        batch_bounds = list(zip(indexes, indexes[1:]))
        for batch_number, (begin, end) in enumerate(batch_bounds):
            is_last_batch = batch_number == len(batch_bounds) - 1
            plots_processed_before = receiver.current_sync().plots_processed
            invoke_data = step_data[begin:end]
            pre_function_validate(receiver, invoke_data, sync_step.state)
            await sync_step.function(create_payload(sync_step.payload_type, False, invoke_data, is_last_batch))
            post_function_validate(receiver, invoke_data, sync_step.state)
            if sync_step.state == State.removed:
                # Removals must not advance plots_processed
                assert receiver.current_sync().plots_processed == plots_processed_before
            else:
                assert receiver.current_sync().plots_processed == plots_processed_before + len(invoke_data)
    else:
        # For Start/Done just invoke it; only the idle (start) step resets the message id counter
        await sync_step.function(create_payload(sync_step.payload_type, sync_step.state == State.idle, *sync_step.args))
    # Make sure we moved to the next state
    assert receiver.current_sync().state != sync_step.state
    if sync_step.payload_type == PlotSyncDone:
        assert receiver._last_sync.time_done != last_sync_time_before
        assert receiver.last_sync().plots_processed == receiver.last_sync().plots_total
    else:
        assert receiver._last_sync.time_done == last_sync_time_before
|
|
|
|
|
|
def plot_sync_setup() -> Tuple[Receiver, List[SyncStepData]]:
    """Create a ``Receiver`` on a dummy harvester connection, plus the sync steps the tests run.

    Forty dummy plots named "0".."39" are generated. Plots "0".."9" are seeded into the receiver
    (with a matching ``_total_plot_size``) so the removal step has something to remove.

    The returned steps are ordered so that ``sync_steps[state]`` is the step for ``state``:
    - start a sync for 40 plots;
    - load "10".."19" and remove "0".."9";
    - report "20".."29" as invalid and "30".."39" as keys missing;
    - report "10".."19" as duplicates;
    - finish.
    """
    receiver = Receiver(get_dummy_connection(NodeType.HARVESTER), dummy_callback)  # type:ignore[arg-type]

    # Example plot data
    path_list = [str(index) for index in range(40)]

    def make_plot(name: str) -> Plot:
        # Random plot id and file size; all other fields are fixed dummy values.
        return Plot(
            filename=name,
            size=uint8(0),
            plot_id=bytes32(token_bytes(32)),
            pool_contract_puzzle_hash=None,
            pool_public_key=None,
            plot_public_key=G1Element(),
            file_size=uint64(random.randint(0, 100)),
            time_modified=uint64(0),
        )

    plot_info_list = [make_plot(path) for path in path_list]

    # Seed the receiver with the plots the removal step is going to remove
    receiver._plots = {plot.filename: plot for plot in plot_info_list[:10]}
    receiver._total_plot_size = sum(plot.file_size for plot in receiver._plots.values())

    total_plot_count = uint32(len(plot_info_list))
    sync_steps: List[SyncStepData] = [
        SyncStepData(State.idle, receiver.sync_started, PlotSyncStart, False, uint64(0), total_plot_count),
        SyncStepData(State.loaded, receiver.process_loaded, PlotSyncPlotList, plot_info_list[10:20], True),
        SyncStepData(State.removed, receiver.process_removed, PlotSyncPathList, path_list[0:10], True),
        SyncStepData(State.invalid, receiver.process_invalid, PlotSyncPathList, path_list[20:30], True),
        SyncStepData(State.keys_missing, receiver.process_keys_missing, PlotSyncPathList, path_list[30:40], True),
        SyncStepData(State.duplicates, receiver.process_duplicates, PlotSyncPathList, path_list[10:20], True),
        SyncStepData(State.done, receiver.sync_done, PlotSyncDone, uint64(0)),
    ]
    return receiver, sync_steps
|
|
|
|
|
|
def test_default_values() -> None:
    """A freshly constructed ``Receiver`` must report default/empty sync state."""
    harvester_connection = get_dummy_connection(NodeType.HARVESTER)
    fresh_receiver = Receiver(harvester_connection, dummy_callback)  # type:ignore[arg-type]
    assert_default_values(fresh_receiver)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_reset() -> None:
    """``Receiver.reset`` must restore every sync value to its default while keeping the connection."""
    # Only the pre-seeded receiver is needed here; the sync steps are unused.
    receiver, _ = plot_sync_setup()
    connection_before = receiver.connection()
    # Assign some dummy values
    receiver._current_sync.state = State.done
    receiver._current_sync.sync_id = uint64(1)
    receiver._current_sync.next_message_id = uint64(1)
    receiver._current_sync.plots_processed = uint32(1)
    receiver._current_sync.plots_total = uint32(1)
    receiver._current_sync.delta.valid.additions = receiver.plots().copy()
    receiver._current_sync.delta.valid.removals = ["1"]
    receiver._current_sync.delta.invalid.additions = ["1"]
    receiver._current_sync.delta.invalid.removals = ["1"]
    receiver._current_sync.delta.keys_missing.additions = ["1"]
    receiver._current_sync.delta.keys_missing.removals = ["1"]
    receiver._current_sync.delta.duplicates.additions = ["1"]
    receiver._current_sync.delta.duplicates.removals = ["1"]
    receiver._current_sync.time_done = time.time()
    # Also populate the last sync (shallow copy of the current one)
    receiver._last_sync = dataclasses.replace(receiver._current_sync)
    receiver._invalid = ["1"]
    receiver._keys_missing = ["1"]
    receiver._duplicates = ["1"]

    receiver._last_sync.sync_id = uint64(1)
    # Call `reset` and make sure all expected values are set back to their defaults.
    receiver.reset()
    assert_default_values(receiver)
    assert receiver._current_sync.delta == Delta()
    # Connection should remain
    assert receiver.connection() == connection_before
|
|
|
|
|
|
@pytest.mark.parametrize("counts_only", [True, False])
@pytest.mark.asyncio
async def test_to_dict(counts_only: bool) -> None:
    """Check ``Receiver.to_dict`` at four points: before, during and after a full sync, and at the
    start of a repeated sync.

    ``counts_only`` is forwarded to ``to_dict``. The ``get_list_or_len`` calls normalise the
    expected values so that the same assertions hold in either mode.
    """
    receiver, sync_steps = plot_sync_setup()
    plot_sync_dict_1 = receiver.to_dict(counts_only)

    assert get_list_or_len(plot_sync_dict_1["plots"], not counts_only) == 10
    assert get_list_or_len(plot_sync_dict_1["failed_to_open_filenames"], not counts_only) == 0
    assert get_list_or_len(plot_sync_dict_1["no_key_filenames"], not counts_only) == 0
    assert get_list_or_len(plot_sync_dict_1["duplicates"], not counts_only) == 0
    assert plot_sync_dict_1["total_plot_size"] == sum(plot.file_size for plot in receiver.plots().values())
    assert plot_sync_dict_1["syncing"] is None
    assert plot_sync_dict_1["last_sync_time"] is None
    connection = receiver.connection()
    assert plot_sync_dict_1["connection"] == {
        "node_id": connection.peer_node_id,
        "host": connection.peer_host,
        "port": connection.peer_port,
    }

    # We should get equal dicts
    assert plot_sync_dict_1 == receiver.to_dict(counts_only)
    # But unequal dicts with the opposite counts_only value
    assert plot_sync_dict_1 != receiver.to_dict(not counts_only)

    expected_plot_files_processed: int = 0
    expected_plot_files_total: int = sync_steps[State.idle].args[2]

    # Walk through all states from idle to done and run them with the test data and validate the sync progress
    for state in State:
        await run_sync_step(receiver, sync_steps[state])

        # Start/done steps carry no files and removals don't count as processed files
        if state not in (State.idle, State.removed, State.done):
            expected_plot_files_processed += len(sync_steps[state].args[0])

        sync_data = receiver.to_dict()["syncing"]
        if state == State.done:
            expected_sync_data = None
        else:
            expected_sync_data = {
                "initial": True,
                "plot_files_processed": expected_plot_files_processed,
                "plot_files_total": expected_plot_files_total,
            }
        assert sync_data == expected_sync_data

    plot_sync_dict_3 = receiver.to_dict(counts_only)
    assert get_list_or_len(sync_steps[State.loaded].args[0], counts_only) == plot_sync_dict_3["plots"]
    assert (
        get_list_or_len(sync_steps[State.invalid].args[0], counts_only) == plot_sync_dict_3["failed_to_open_filenames"]
    )
    assert get_list_or_len(sync_steps[State.keys_missing].args[0], counts_only) == plot_sync_dict_3["no_key_filenames"]
    assert get_list_or_len(sync_steps[State.duplicates].args[0], counts_only) == plot_sync_dict_3["duplicates"]

    assert plot_sync_dict_3["total_plot_size"] == sum(plot.file_size for plot in receiver.plots().values())
    assert plot_sync_dict_3["last_sync_time"] > 0
    assert plot_sync_dict_3["syncing"] is None

    # Trigger a repeated plot sync
    await receiver.sync_started(
        PlotSyncStart(
            PlotSyncIdentifier(uint64(time.time()), uint64(receiver.last_sync().sync_id + 1), uint64(0)),
            False,
            receiver.last_sync().sync_id,
            uint32(1),
        )
    )
    assert receiver.to_dict()["syncing"] == {
        "initial": False,
        "plot_files_processed": 0,
        "plot_files_total": 1,
    }
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_sync_flow() -> None:
    """Run a complete sync and check where each step's data ends up in the receiver.

    Before the sync, only the plots scheduled for removal are known. Afterwards:
    - loaded plots are present and removed paths are gone;
    - invalid, keys-missing and duplicate paths are reported in their respective lists;
    - the receiver is back in ``State.idle``.
    """
    receiver, sync_steps = plot_sync_setup()

    loaded_plots = sync_steps[State.loaded].args[0]
    removed_paths = sync_steps[State.removed].args[0]
    invalid_paths = sync_steps[State.invalid].args[0]
    keys_missing_paths = sync_steps[State.keys_missing].args[0]
    duplicate_paths = sync_steps[State.duplicates].args[0]

    # State before the sync
    for plot_info in loaded_plots:
        assert plot_info.filename not in receiver.plots()
    for path in removed_paths:
        assert path in receiver.plots()
    for path in invalid_paths:
        assert path not in receiver.invalid()
    for path in keys_missing_paths:
        assert path not in receiver.keys_missing()
    for path in duplicate_paths:
        assert path not in receiver.duplicates()

    # Walk through all states from idle to done and run them with the test data
    for state in State:
        await run_sync_step(receiver, sync_steps[state])

    # State after the sync
    for plot_info in loaded_plots:
        assert plot_info.filename in receiver.plots()
    for path in removed_paths:
        assert path not in receiver.plots()
    for path in invalid_paths:
        assert path in receiver.invalid()
    for path in keys_missing_paths:
        assert path in receiver.keys_missing()
    for path in duplicate_paths:
        assert path in receiver.duplicates()

    # We should be in idle state again
    assert receiver.current_sync().state == State.idle
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_invalid_ids() -> None:
    """Walk through every sync state and check that messages with wrong ids are rejected.

    In ``State.idle`` two malformed ``PlotSyncStart`` messages are sent:
    - one whose last sync id doesn't match the receiver's;
    - one whose new sync id equals the last sync id.

    In every other state, two payloads are sent: one with a wrong sync id and one that skips a
    message id.

    Each malformed message must be answered with the expected error code. Afterwards the valid
    payload for the state is sent so the loop can move on to the next state.
    """
    receiver, sync_steps = plot_sync_setup()
    for state in State:
        assert receiver.current_sync().state == state
        current_step = sync_steps[state]
        if receiver.current_sync().state == State.idle:
            # Set last_sync_id for the tests below
            receiver._last_sync.sync_id = uint64(1)
            # Test "sync_started last doesn't match": last sync id 2 is sent while the receiver holds 1
            invalid_last_sync_id_param = PlotSyncStart(
                plot_sync_identifier(uint64(0), uint64(0)), False, uint64(2), uint32(0)
            )
            await current_step.function(invalid_last_sync_id_param)
            assert_error_response(receiver, ErrorCodes.invalid_last_sync_id)
            # Test "last_sync_id == new_sync_id": identifier sync id 1 equals the last sync id 1
            invalid_sync_id_match_param = PlotSyncStart(
                plot_sync_identifier(uint64(1), uint64(0)), False, uint64(1), uint32(0)
            )
            await current_step.function(invalid_sync_id_match_param)
            assert_error_response(receiver, ErrorCodes.sync_ids_match)
            # Reset the last_sync_id to the default
            receiver._last_sync.sync_id = uint64(0)
        else:
            # Test invalid sync_id: 10 differs from the sync id 1 that `create_payload` uses
            invalid_sync_id_param = current_step.payload_type(
                plot_sync_identifier(uint64(10), uint64(receiver.current_sync().next_message_id)), *current_step.args
            )
            await current_step.function(invalid_sync_id_param)
            assert_error_response(receiver, ErrorCodes.invalid_identifier)
            # Test invalid message_id: correct sync id, but one message id is skipped
            invalid_message_id_param = current_step.payload_type(
                plot_sync_identifier(
                    receiver.current_sync().sync_id, uint64(receiver.current_sync().next_message_id + 1)
                ),
                *current_step.args,
            )
            await current_step.function(invalid_message_id_param)
            assert_error_response(receiver, ErrorCodes.invalid_identifier)
        # Send the valid payload for this state to advance to the next one
        payload = create_payload(current_step.payload_type, state == State.idle, *current_step.args)
        await current_step.function(payload)
|
|
|
|
|
|
@pytest.mark.parametrize(
    ["state_to_fail", "expected_error_code"],
    [
        pytest.param(State.loaded, ErrorCodes.plot_already_available, id="already available plots"),
        pytest.param(State.invalid, ErrorCodes.plot_already_available, id="already available paths"),
        pytest.param(State.removed, ErrorCodes.plot_not_available, id="not available"),
    ],
)
@pytest.mark.asyncio
async def test_plot_errors(state_to_fail: State, expected_error_code: ErrorCodes) -> None:
    """Sending a step's data twice in ``state_to_fail`` must be answered with ``expected_error_code``.

    All states before ``state_to_fail`` are processed normally. In ``state_to_fail``, the step's
    data is first sent as a non-final batch. It is then sent again, flagged final and with a valid
    identifier, and that second message must be rejected. Reaching the end of the state loop
    without hitting ``state_to_fail`` fails the test.
    """
    receiver, sync_steps = plot_sync_setup()
    for state in State:
        assert receiver.current_sync().state == state
        current_step = sync_steps[state]
        if state != state_to_fail:
            await current_step.function(
                create_payload(current_step.payload_type, state == State.idle, *current_step.args)
            )
            continue
        # Plots for the loaded step, paths for the path-list steps
        step_data, _ = current_step.args
        await current_step.function(create_payload(current_step.payload_type, False, step_data, False))
        # Resend the same data with a correct identifier so the rejection can only come from the data itself
        identifier = plot_sync_identifier(receiver.current_sync().sync_id, receiver.current_sync().next_message_id)
        invalid_payload = current_step.payload_type(identifier, step_data, True)
        await current_step.function(invalid_payload)
        assert_error_response(receiver, expected_error_code)
        return
    # pytest.fail instead of `assert False`, which is stripped when running under `python -O`
    pytest.fail("Didn't fail in the expected state")
|