mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-11-13 12:09:25 +03:00
Basic full node networking tests
This commit is contained in:
parent
e1c1f6c341
commit
273fe0355b
@ -492,7 +492,7 @@ class FullNode:
|
||||
added_eos.reward_chain.end_of_slot_vdf.challenge,
|
||||
)
|
||||
msg = Message("new_signage_point_or_end_of_sub_slot", broadcast)
|
||||
self.server.send_to_all([msg], NodeType.FullNode)
|
||||
await self.server.send_to_all([msg], NodeType.FullNode)
|
||||
|
||||
if new_peak.height % 1000 == 0:
|
||||
# Occasionally clear the seen list to keep it small
|
||||
@ -511,7 +511,7 @@ class FullNode:
|
||||
sub_block.reward_chain_sub_block.get_unfinished().get_hash(),
|
||||
),
|
||||
)
|
||||
self.server.send_to_all([msg], NodeType.FULL_NODE)
|
||||
await self.server.send_to_all([msg], NodeType.FULL_NODE)
|
||||
|
||||
# Tell wallets about the new peak
|
||||
msg = Message(
|
||||
@ -523,7 +523,7 @@ class FullNode:
|
||||
fork_height,
|
||||
),
|
||||
)
|
||||
self.server.send_to_all([msg], NodeType.WALLET)
|
||||
await self.server.send_to_all([msg], NodeType.WALLET)
|
||||
|
||||
elif added == ReceiveBlockResult.ADDED_AS_ORPHAN:
|
||||
self.log.info(f"Received orphan block of height {sub_block.height}")
|
||||
@ -607,12 +607,12 @@ class FullNode:
|
||||
)
|
||||
|
||||
msg = Message("new_unfinished_sub_block", timelord_request)
|
||||
self.server.send_to_all([msg], NodeType.TIMELORD)
|
||||
await self.server.send_to_all([msg], NodeType.TIMELORD)
|
||||
|
||||
full_node_request = full_node_protocol.NewUnfinishedSubBlock(block.reward_chain_sub_block.get_hash())
|
||||
msg = Message("new_unfinished_sub_block", full_node_request)
|
||||
if peer is not None:
|
||||
self.server.send_to_all_except([msg], NodeType.FULL_NODE, peer.peer_node_id)
|
||||
await self.server.send_to_all_except([msg], NodeType.FULL_NODE, peer.peer_node_id)
|
||||
else:
|
||||
self.server.send_to_all([msg], NodeType.FULL_NODE)
|
||||
await self.server.send_to_all([msg], NodeType.FULL_NODE)
|
||||
self._state_changed("sub_block")
|
||||
|
@ -307,7 +307,7 @@ class FullNodeAPI:
|
||||
request.reward_chain_vdf.challenge_hash,
|
||||
)
|
||||
msg = Message("new_signage_point_or_end_of_sub_slot", broadcast)
|
||||
self.server.send_to_all_except([msg], NodeType.FULL_NODE, peer.peer_node_id)
|
||||
await self.server.send_to_all_except([msg], NodeType.FULL_NODE, peer.peer_node_id)
|
||||
|
||||
return
|
||||
|
||||
@ -328,7 +328,7 @@ class FullNodeAPI:
|
||||
request.end_of_slot_bundle.reward_chain.end_of_slot_vdf.challenge_hash,
|
||||
)
|
||||
msg = Message("new_signage_point_or_end_of_sub_slot", broadcast)
|
||||
self.server.send_to_all_except([msg], NodeType.FULL_NODE, peer.peer_node_id)
|
||||
await self.server.send_to_all_except([msg], NodeType.FULL_NODE, peer.peer_node_id)
|
||||
|
||||
return
|
||||
|
||||
|
@ -3,7 +3,7 @@ import logging
|
||||
import ssl
|
||||
from asyncio import Queue
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Dict, Tuple, Callable, Optional
|
||||
from typing import Any, List, Dict, Tuple, Callable, Optional, Set
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
@ -29,21 +29,15 @@ def ssl_context_for_server(
|
||||
private_cert_path: Path, private_key_path: Path, require_cert: bool = False
|
||||
) -> Optional[ssl.SSLContext]:
|
||||
ssl_context = ssl._create_unverified_context(purpose=ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_context.load_cert_chain(
|
||||
certfile=str(private_cert_path), keyfile=str(private_key_path)
|
||||
)
|
||||
ssl_context.load_cert_chain(certfile=str(private_cert_path), keyfile=str(private_key_path))
|
||||
ssl_context.load_verify_locations(str(private_cert_path))
|
||||
ssl_context.verify_mode = ssl.CERT_REQUIRED if require_cert else ssl.CERT_NONE
|
||||
return ssl_context
|
||||
|
||||
|
||||
def ssl_context_for_client(
|
||||
private_cert_path: Path, private_key_path: Path, auth: bool
|
||||
) -> Optional[ssl.SSLContext]:
|
||||
def ssl_context_for_client(private_cert_path: Path, private_key_path: Path, auth: bool) -> Optional[ssl.SSLContext]:
|
||||
ssl_context = ssl._create_unverified_context(purpose=ssl.Purpose.SERVER_AUTH)
|
||||
ssl_context.load_cert_chain(
|
||||
certfile=str(private_cert_path), keyfile=str(private_key_path)
|
||||
)
|
||||
ssl_context.load_cert_chain(certfile=str(private_cert_path), keyfile=str(private_key_path))
|
||||
if auth:
|
||||
ssl_context.verify_mode = ssl.CERT_REQUIRED
|
||||
ssl_context.load_verify_locations(str(private_cert_path))
|
||||
@ -67,6 +61,7 @@ class ChiaServer:
|
||||
# Keeps track of all connections to and from this node.
|
||||
self.all_connections: Dict[bytes32, WSChiaConnection] = {}
|
||||
self.full_nodes: Dict[bytes32, WSChiaConnection] = {}
|
||||
self.tasks: Set[asyncio.Task] = set()
|
||||
|
||||
self.connection_by_type: Dict[NodeType, Dict[str, WSChiaConnection]] = {}
|
||||
self._port = port # TCP port to identify our node
|
||||
@ -91,9 +86,7 @@ class ChiaServer:
|
||||
self.root_path = root_path
|
||||
self.config = config
|
||||
self.on_connect: Optional[Callable] = None
|
||||
self.incoming_messages: Queue[
|
||||
Tuple[Payload, WSChiaConnection]
|
||||
] = asyncio.Queue()
|
||||
self.incoming_messages: Queue[Tuple[Payload, WSChiaConnection]] = asyncio.Queue()
|
||||
self.shut_down_event = asyncio.Event()
|
||||
|
||||
if self._local_type is NodeType.INTRODUCER:
|
||||
@ -108,6 +101,10 @@ class ChiaServer:
|
||||
self.app: Optional[Application] = None
|
||||
self.site: Optional[TCPSite] = None
|
||||
|
||||
self.connection_close_task: Optional[asyncio.Task] = None
|
||||
self.site_shutdown_task: Optional[asyncio.Task] = None
|
||||
self.app_shut_down_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def start_server(self, on_connect: Callable = None):
|
||||
self.app = web.Application()
|
||||
self.on_connect = on_connect
|
||||
@ -118,9 +115,7 @@ class ChiaServer:
|
||||
self.runner = web.AppRunner(self.app, access_log=None)
|
||||
await self.runner.setup()
|
||||
require_cert = self._local_type not in (NodeType.FULL_NODE, NodeType.INTRODUCER)
|
||||
ssl_context = ssl_context_for_server(
|
||||
self._private_cert_path, self._private_key_path, require_cert
|
||||
)
|
||||
ssl_context = ssl_context_for_server(self._private_cert_path, self._private_key_path, require_cert)
|
||||
if self._local_type not in [NodeType.WALLET, NodeType.HARVESTER]:
|
||||
self.site = web.TCPSite(
|
||||
self.runner,
|
||||
@ -159,10 +154,7 @@ class ChiaServer:
|
||||
|
||||
assert handshake is True
|
||||
await self.connection_added(connection, self.on_connect)
|
||||
if (
|
||||
self._local_type is NodeType.INTRODUCER
|
||||
and connection.connection_type is NodeType.FULL_NODE
|
||||
):
|
||||
if self._local_type is NodeType.INTRODUCER and connection.connection_type is NodeType.FULL_NODE:
|
||||
self.introducer_peers.add(connection.get_peer_info())
|
||||
except Exception as e:
|
||||
error_stack = traceback.format_exc()
|
||||
@ -191,18 +183,14 @@ class ChiaServer:
|
||||
Tries to connect to the target node, adding one connection into the pipeline, if successful.
|
||||
An on connect method can also be specified, and this will be saved into the instance variables.
|
||||
"""
|
||||
ssl_context = ssl_context_for_client(
|
||||
self._private_cert_path, self._private_key_path, auth
|
||||
)
|
||||
ssl_context = ssl_context_for_client(self._private_cert_path, self._private_key_path, auth)
|
||||
session = None
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
session = aiohttp.ClientSession(timeout=timeout)
|
||||
url = f"wss://{target_node.host}:{target_node.port}/ws"
|
||||
self.log.info(f"Connecting: {url}")
|
||||
ws = await session.ws_connect(
|
||||
url, autoclose=False, autoping=True, ssl=ssl_context
|
||||
)
|
||||
ws = await session.ws_connect(url, autoclose=False, autoping=True, ssl=ssl_context)
|
||||
if ws is not None:
|
||||
connection = WSChiaConnection(
|
||||
self._local_type,
|
||||
@ -246,82 +234,59 @@ class ChiaServer:
|
||||
async def incoming_api_task(self):
|
||||
self.tasks = set()
|
||||
while True:
|
||||
payload, connection = await self.incoming_messages.get()
|
||||
if payload is None or connection is None:
|
||||
payload_inc, connection_inc = await self.incoming_messages.get()
|
||||
if payload_inc is None or connection_inc is None:
|
||||
continue
|
||||
|
||||
async def api_call(payload: Payload, connection: WSChiaConnection):
|
||||
try:
|
||||
full_message = payload.msg
|
||||
connection.log.info(
|
||||
f"<- {full_message.function} from peer {connection.peer_node_id}"
|
||||
)
|
||||
if len(
|
||||
full_message.function
|
||||
) == 0 or full_message.function.startswith("_"):
|
||||
connection.log.info(f"<- {full_message.function} from peer {connection.peer_node_id}")
|
||||
if len(full_message.function) == 0 or full_message.function.startswith("_"):
|
||||
# This prevents remote calling of private methods that start with "_"
|
||||
self.log.error(
|
||||
f"Non existing function: {full_message.function}"
|
||||
)
|
||||
raise ProtocolError(
|
||||
Err.INVALID_PROTOCOL_MESSAGE, [full_message.function]
|
||||
)
|
||||
self.log.error(f"Non existing function: {full_message.function}")
|
||||
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [full_message.function])
|
||||
|
||||
f = getattr(self.api, full_message.function, None)
|
||||
|
||||
if f is None:
|
||||
self.log.error(
|
||||
f"Non existing function: {full_message.function}"
|
||||
)
|
||||
raise ProtocolError(
|
||||
Err.INVALID_PROTOCOL_MESSAGE, [full_message.function]
|
||||
)
|
||||
self.log.error(f"Non existing function: {full_message.function}")
|
||||
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [full_message.function])
|
||||
|
||||
if hasattr(f, "peer_required"):
|
||||
response: Optional[Message] = await f(
|
||||
full_message.data, connection
|
||||
)
|
||||
response: Optional[Message] = await f(full_message.data, connection)
|
||||
else:
|
||||
response = await f(full_message.data)
|
||||
|
||||
if response is not None:
|
||||
id = payload.id
|
||||
response_payload = Payload(response, id)
|
||||
payload_id = payload.id
|
||||
response_payload = Payload(response, payload_id)
|
||||
await connection.reply_to_request(response_payload)
|
||||
|
||||
except Exception as e:
|
||||
tb = traceback.format_exc()
|
||||
connection.log.error(
|
||||
f"Exception: {e}, closing connection {connection}. {tb}"
|
||||
)
|
||||
connection.log.error(f"Exception: {e}, closing connection {connection}. {tb}")
|
||||
await connection.close()
|
||||
|
||||
asyncio.create_task(api_call(payload, connection))
|
||||
asyncio.create_task(api_call(payload_inc, connection_inc))
|
||||
|
||||
async def send_to_others(
|
||||
self, messages: List[Message], type: NodeType, origin_peer: WSChiaConnection
|
||||
):
|
||||
for id, connection in self.all_connections.items():
|
||||
if id == origin_peer.peer_node_id:
|
||||
async def send_to_others(self, messages: List[Message], node_type: NodeType, origin_peer: WSChiaConnection):
|
||||
for node_id, connection in self.all_connections.items():
|
||||
if node_id == origin_peer.peer_node_id:
|
||||
continue
|
||||
if connection.connection_type is type:
|
||||
if connection.connection_type is node_type:
|
||||
for message in messages:
|
||||
await connection.send_message(message)
|
||||
|
||||
async def send_to_all(self, messages: List[Message], type: NodeType):
|
||||
for id, connection in self.all_connections.items():
|
||||
if connection.connection_type is type:
|
||||
async def send_to_all(self, messages: List[Message], node_type: NodeType):
|
||||
for _, connection in self.all_connections.items():
|
||||
if connection.connection_type is node_type:
|
||||
for message in messages:
|
||||
await connection.send_message(message)
|
||||
|
||||
async def send_to_all_except(
|
||||
self, messages: List[Message], type: NodeType, exclude: bytes32
|
||||
):
|
||||
for id, connection in self.all_connections.items():
|
||||
if (
|
||||
connection.connection_type is type
|
||||
and connection.peer_node_id != exclude
|
||||
):
|
||||
async def send_to_all_except(self, messages: List[Message], node_type: NodeType, exclude: bytes32):
|
||||
for _, connection in self.all_connections.items():
|
||||
if connection.connection_type is node_type and connection.peer_node_id != exclude:
|
||||
for message in messages:
|
||||
await connection.send_message(message)
|
||||
|
||||
@ -333,7 +298,7 @@ class ChiaServer:
|
||||
|
||||
def get_outgoing_connections(self) -> List[WSChiaConnection]:
|
||||
result = []
|
||||
for id, connection in self.all_connections.items():
|
||||
for _, connection in self.all_connections.items():
|
||||
if connection.is_outbound:
|
||||
result.append(connection)
|
||||
|
||||
@ -341,7 +306,7 @@ class ChiaServer:
|
||||
|
||||
def get_full_node_connections(self) -> List[WSChiaConnection]:
|
||||
result = []
|
||||
for id, connection in self.all_connections.items():
|
||||
for _, connection in self.all_connections.items():
|
||||
if connection.connection_type is NodeType.FULL_NODE:
|
||||
result.append(connection)
|
||||
|
||||
@ -349,22 +314,22 @@ class ChiaServer:
|
||||
|
||||
def get_connections(self) -> List[WSChiaConnection]:
|
||||
result = []
|
||||
for id, connection in self.all_connections.items():
|
||||
for _, connection in self.all_connections.items():
|
||||
result.append(connection)
|
||||
return result
|
||||
|
||||
async def close_all_connections(self):
|
||||
keys = [a for a, b in self.all_connections.items()]
|
||||
for id in keys:
|
||||
for node_id in keys:
|
||||
try:
|
||||
if id in self.all_connections:
|
||||
connection = self.all_connections[id]
|
||||
if node_id in self.all_connections:
|
||||
connection = self.all_connections[node_id]
|
||||
await connection.close()
|
||||
except Exception as e:
|
||||
self.log.error(f"exeption while closing connection {e}")
|
||||
|
||||
def close_all(self):
|
||||
self.connection_close_taks = asyncio.create_task(self.close_all_connections())
|
||||
self.connection_close_task = asyncio.create_task(self.close_all_connections())
|
||||
self.site_shutdown_task = asyncio.create_task(self.runner.cleanup())
|
||||
self.app_shut_down_task = asyncio.create_task(self.app.shutdown())
|
||||
|
||||
@ -375,9 +340,12 @@ class ChiaServer:
|
||||
async def await_closed(self):
|
||||
self.log.info("Await Closed")
|
||||
await self.shut_down_event.wait()
|
||||
await self.connection_close_taks
|
||||
await self.app_shut_down_task
|
||||
await self.site_shutdown_task
|
||||
if self.connection_close_task is not None:
|
||||
await self.connection_close_task
|
||||
if self.app_shut_down_task is not None:
|
||||
await self.app_shut_down_task
|
||||
if self.site_shutdown_task is not None:
|
||||
await self.site_shutdown_task
|
||||
|
||||
async def get_peer_info(self) -> Optional[PeerInfo]:
|
||||
ip = None
|
||||
|
@ -20,7 +20,6 @@ from src.util.errors import Err, ProtocolError
|
||||
from src.util.network import class_for_type
|
||||
|
||||
LENGTH_BYTES: int = 4
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
OnConnectFunc = Optional[Callable[[], AsyncGenerator[OutboundMessage, None]]]
|
||||
|
||||
@ -54,9 +53,7 @@ class WSChiaConnection:
|
||||
# Remote properties
|
||||
self.peer_host = peer_host
|
||||
|
||||
connection_host, connection_port = self.ws._writer.transport.get_extra_info(
|
||||
"peername"
|
||||
)
|
||||
connection_host, connection_port = self.ws._writer.transport.get_extra_info("peername")
|
||||
|
||||
self.peer_port = connection_port
|
||||
self.peer_server_port: Optional[uint16] = None
|
||||
@ -75,9 +72,7 @@ class WSChiaConnection:
|
||||
self.last_message_time: float = 0
|
||||
|
||||
# Messaging
|
||||
self.incoming_queue: asyncio.Queue[
|
||||
Tuple[Payload, WSChiaConnection]
|
||||
] = incoming_queue
|
||||
self.incoming_queue: asyncio.Queue[Tuple[Payload, WSChiaConnection]] = incoming_queue
|
||||
self.outgoing_queue: asyncio.Queue[Payload] = asyncio.Queue()
|
||||
|
||||
self.inbound_task = None
|
||||
@ -90,10 +85,9 @@ class WSChiaConnection:
|
||||
self.pending_requests: Dict[bytes32, asyncio.Event] = {}
|
||||
self.request_results: Dict[bytes32, Payload] = {}
|
||||
self.closed = False
|
||||
self.connection_type = None
|
||||
|
||||
async def perform_handshake(
|
||||
self, network_id, protocol_version, node_id, server_port, local_type
|
||||
):
|
||||
async def perform_handshake(self, network_id, protocol_version, node_id, server_port, local_type):
|
||||
self.log.info("Doing handshake")
|
||||
if self.is_outbound:
|
||||
self.log.info("Outbound handshake")
|
||||
@ -111,11 +105,7 @@ class WSChiaConnection:
|
||||
await self._send_message(payload)
|
||||
payload = await self._read_one_message()
|
||||
inbound_handshake = Handshake(**payload.msg.data)
|
||||
if (
|
||||
payload.msg.function != "handshake"
|
||||
or not inbound_handshake
|
||||
or not inbound_handshake.node_type
|
||||
):
|
||||
if payload.msg.function != "handshake" or not inbound_handshake or not inbound_handshake.node_type:
|
||||
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
||||
self.peer_node_id = inbound_handshake.node_id
|
||||
self.peer_server_port = int(inbound_handshake.server_port)
|
||||
@ -124,11 +114,7 @@ class WSChiaConnection:
|
||||
self.log.info("Inbound handshake")
|
||||
payload = await self._read_one_message()
|
||||
inbound_handshake = Handshake(**payload.msg.data)
|
||||
if (
|
||||
payload.msg.function != "handshake"
|
||||
or not inbound_handshake
|
||||
or not inbound_handshake.node_type
|
||||
):
|
||||
if payload.msg.function != "handshake" or not inbound_handshake or not inbound_handshake.node_type:
|
||||
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
||||
outbound_handshake = Message(
|
||||
"handshake",
|
||||
@ -213,9 +199,7 @@ class WSChiaConnection:
|
||||
msg = Message(attr_name, args[0])
|
||||
result = await self.create_request(msg)
|
||||
if result is not None:
|
||||
ret_attr = getattr(
|
||||
class_for_type(self.local_type), result.function, None
|
||||
)
|
||||
ret_attr = getattr(class_for_type(self.local_type), result.function, None)
|
||||
|
||||
req_annotations = ret_attr.__annotations__
|
||||
req = None
|
||||
@ -236,27 +220,25 @@ class WSChiaConnection:
|
||||
return None
|
||||
|
||||
event = asyncio.Event()
|
||||
id = token_bytes(8)
|
||||
payload = Payload(message, id)
|
||||
payload = Payload(message, token_bytes(8))
|
||||
|
||||
self.pending_requests[payload.id] = event
|
||||
await self.outgoing_queue.put(payload)
|
||||
|
||||
async def time_out(req_id, req_timeout):
|
||||
await asyncio.sleep(req_timeout)
|
||||
if id in self.pending_requests:
|
||||
event = self.pending_requests[req_id]
|
||||
event.set()
|
||||
if req_id in self.pending_requests:
|
||||
self.pending_requests[req_id].set()
|
||||
|
||||
asyncio.create_task(time_out(id, timeout))
|
||||
asyncio.create_task(time_out(payload.id, timeout))
|
||||
await event.wait()
|
||||
|
||||
self.pending_requests.pop(id)
|
||||
self.pending_requests.pop(payload.id)
|
||||
result: Optional[Message] = None
|
||||
if id in self.request_results:
|
||||
result_payload: Payload = self.request_results[id]
|
||||
if payload.id in self.request_results:
|
||||
result_payload: Payload = self.request_results[payload.id]
|
||||
result = result_payload.msg
|
||||
self.request_results.pop(id)
|
||||
self.request_results.pop(payload.id)
|
||||
|
||||
return result
|
||||
|
||||
@ -274,9 +256,7 @@ class WSChiaConnection:
|
||||
|
||||
async def _send_message(self, payload: Payload):
|
||||
self.log.info(f"-> {payload.msg.function}")
|
||||
encoded: bytes = cbor.dumps(
|
||||
{"f": payload.msg.function, "d": payload.msg.data, "i": payload.id}
|
||||
)
|
||||
encoded: bytes = cbor.dumps({"f": payload.msg.function, "d": payload.msg.data, "i": payload.id})
|
||||
size = len(encoded)
|
||||
assert len(encoded) < (2 ** (LENGTH_BYTES * 8))
|
||||
await self.ws.send_bytes(encoded)
|
||||
@ -303,7 +283,5 @@ class WSChiaConnection:
|
||||
return None
|
||||
|
||||
def get_peer_info(self):
|
||||
connection_host, connection_port = self.ws._writer.transport.get_extra_info(
|
||||
"peername"
|
||||
)
|
||||
connection_host, connection_port = self.ws._writer.transport.get_extra_info("peername")
|
||||
return PeerInfo(connection_host, self.peer_server_port)
|
||||
|
@ -1,12 +1,18 @@
|
||||
import asyncio
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
import random
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict
|
||||
from secrets import token_bytes
|
||||
|
||||
from src.full_node.full_node_api import FullNodeAPI
|
||||
from src.protocols import full_node_protocol as fnp, wallet_protocol
|
||||
from src.server.outbound_message import NodeType
|
||||
from src.server.server import ssl_context_for_client, ChiaServer
|
||||
from src.server.ws_connection import WSChiaConnection
|
||||
from src.types.coin import hash_coin_list
|
||||
from src.types.mempool_inclusion_status import MempoolInclusionStatus
|
||||
from src.types.peer_info import TimestampedPeerInfo, PeerInfo
|
||||
@ -29,6 +35,9 @@ from tests.setup_nodes import setup_two_nodes, test_constants, bt
|
||||
from src.util.wallet_tools import WalletTool
|
||||
from src.util.clvm import int_to_bytes
|
||||
from tests.time_out_assert import time_out_assert, time_out_assert_custom_interval
|
||||
from src.protocols.shared_protocol import protocol_version
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_block_path(full_node: FullNodeAPI):
|
||||
@ -40,6 +49,23 @@ async def get_block_path(full_node: FullNodeAPI):
|
||||
return blocks_list
|
||||
|
||||
|
||||
async def add_dummy_connection(server: ChiaServer, dummy_port: int) -> asyncio.Queue:
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
session = aiohttp.ClientSession(timeout=timeout)
|
||||
incoming_queue = asyncio.Queue()
|
||||
ssl_context = ssl_context_for_client(server._private_cert_path, server._private_key_path, False)
|
||||
url = f"wss://127.0.0.1:{server._port}/ws"
|
||||
ws = await session.ws_connect(url, autoclose=False, autoping=True, ssl=ssl_context)
|
||||
wsc = WSChiaConnection(
|
||||
NodeType.FULL_NODE, ws, server._port, log, True, False, "127.0.0.1", incoming_queue, lambda x: x
|
||||
)
|
||||
handshake = await wsc.perform_handshake(
|
||||
server._network_id, protocol_version, std_hash(b"123"), dummy_port, NodeType.FULL_NODE
|
||||
)
|
||||
assert handshake is True
|
||||
return incoming_queue
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
loop = asyncio.get_event_loop()
|
||||
@ -105,10 +131,33 @@ class TestFullNodeProtocol:
|
||||
async def test_basic_chain(self, two_nodes):
|
||||
full_node_1, full_node_2, server_1, server_2 = two_nodes
|
||||
|
||||
blocks = bt.get_consecutive_blocks(100)
|
||||
for block in blocks:
|
||||
incoming_queue = await add_dummy_connection(server_1, 12312)
|
||||
|
||||
async def has_mempool_tx():
|
||||
if incoming_queue.qsize() == 0:
|
||||
return False
|
||||
res = set()
|
||||
while incoming_queue.qsize() > 0:
|
||||
res.add((await incoming_queue.get())[0].msg.function)
|
||||
return res == {"request_mempool_transactions"}
|
||||
|
||||
await time_out_assert(10, has_mempool_tx, True)
|
||||
|
||||
blocks = bt.get_consecutive_blocks(1)
|
||||
for block in blocks[:1]:
|
||||
await full_node_1.respond_sub_block(fnp.RespondSubBlock(block))
|
||||
assert full_node_1.full_node.blockchain.get_peak().height == 99
|
||||
|
||||
async def has_new_peak():
|
||||
if incoming_queue.qsize() == 0:
|
||||
return False
|
||||
res = set()
|
||||
while incoming_queue.qsize() > 0:
|
||||
res.add((await incoming_queue.get())[0].msg.function)
|
||||
return res == {"new_peak"}
|
||||
|
||||
await time_out_assert(10, has_new_peak, True)
|
||||
|
||||
assert full_node_1.full_node.blockchain.get_peak().height == 0
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
|
Loading…
Reference in New Issue
Block a user