Basic full node networking tests

This commit is contained in:
Mariano Sorgente 2020-11-30 17:54:18 +09:00 committed by Yostra
parent e1c1f6c341
commit 273fe0355b
5 changed files with 127 additions and 132 deletions

View File

@ -492,7 +492,7 @@ class FullNode:
added_eos.reward_chain.end_of_slot_vdf.challenge,
)
msg = Message("new_signage_point_or_end_of_sub_slot", broadcast)
self.server.send_to_all([msg], NodeType.FullNode)
await self.server.send_to_all([msg], NodeType.FullNode)
if new_peak.height % 1000 == 0:
# Occasionally clear the seen list to keep it small
@ -511,7 +511,7 @@ class FullNode:
sub_block.reward_chain_sub_block.get_unfinished().get_hash(),
),
)
self.server.send_to_all([msg], NodeType.FULL_NODE)
await self.server.send_to_all([msg], NodeType.FULL_NODE)
# Tell wallets about the new peak
msg = Message(
@ -523,7 +523,7 @@ class FullNode:
fork_height,
),
)
self.server.send_to_all([msg], NodeType.WALLET)
await self.server.send_to_all([msg], NodeType.WALLET)
elif added == ReceiveBlockResult.ADDED_AS_ORPHAN:
self.log.info(f"Received orphan block of height {sub_block.height}")
@ -607,12 +607,12 @@ class FullNode:
)
msg = Message("new_unfinished_sub_block", timelord_request)
self.server.send_to_all([msg], NodeType.TIMELORD)
await self.server.send_to_all([msg], NodeType.TIMELORD)
full_node_request = full_node_protocol.NewUnfinishedSubBlock(block.reward_chain_sub_block.get_hash())
msg = Message("new_unfinished_sub_block", full_node_request)
if peer is not None:
self.server.send_to_all_except([msg], NodeType.FULL_NODE, peer.peer_node_id)
await self.server.send_to_all_except([msg], NodeType.FULL_NODE, peer.peer_node_id)
else:
self.server.send_to_all([msg], NodeType.FULL_NODE)
await self.server.send_to_all([msg], NodeType.FULL_NODE)
self._state_changed("sub_block")

View File

@ -307,7 +307,7 @@ class FullNodeAPI:
request.reward_chain_vdf.challenge_hash,
)
msg = Message("new_signage_point_or_end_of_sub_slot", broadcast)
self.server.send_to_all_except([msg], NodeType.FULL_NODE, peer.peer_node_id)
await self.server.send_to_all_except([msg], NodeType.FULL_NODE, peer.peer_node_id)
return
@ -328,7 +328,7 @@ class FullNodeAPI:
request.end_of_slot_bundle.reward_chain.end_of_slot_vdf.challenge_hash,
)
msg = Message("new_signage_point_or_end_of_sub_slot", broadcast)
self.server.send_to_all_except([msg], NodeType.FULL_NODE, peer.peer_node_id)
await self.server.send_to_all_except([msg], NodeType.FULL_NODE, peer.peer_node_id)
return

View File

@ -3,7 +3,7 @@ import logging
import ssl
from asyncio import Queue
from pathlib import Path
from typing import Any, List, Dict, Tuple, Callable, Optional
from typing import Any, List, Dict, Tuple, Callable, Optional, Set
import aiohttp
from aiohttp import web
@ -29,21 +29,15 @@ def ssl_context_for_server(
private_cert_path: Path, private_key_path: Path, require_cert: bool = False
) -> Optional[ssl.SSLContext]:
ssl_context = ssl._create_unverified_context(purpose=ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(
certfile=str(private_cert_path), keyfile=str(private_key_path)
)
ssl_context.load_cert_chain(certfile=str(private_cert_path), keyfile=str(private_key_path))
ssl_context.load_verify_locations(str(private_cert_path))
ssl_context.verify_mode = ssl.CERT_REQUIRED if require_cert else ssl.CERT_NONE
return ssl_context
def ssl_context_for_client(
private_cert_path: Path, private_key_path: Path, auth: bool
) -> Optional[ssl.SSLContext]:
def ssl_context_for_client(private_cert_path: Path, private_key_path: Path, auth: bool) -> Optional[ssl.SSLContext]:
ssl_context = ssl._create_unverified_context(purpose=ssl.Purpose.SERVER_AUTH)
ssl_context.load_cert_chain(
certfile=str(private_cert_path), keyfile=str(private_key_path)
)
ssl_context.load_cert_chain(certfile=str(private_cert_path), keyfile=str(private_key_path))
if auth:
ssl_context.verify_mode = ssl.CERT_REQUIRED
ssl_context.load_verify_locations(str(private_cert_path))
@ -67,6 +61,7 @@ class ChiaServer:
# Keeps track of all connections to and from this node.
self.all_connections: Dict[bytes32, WSChiaConnection] = {}
self.full_nodes: Dict[bytes32, WSChiaConnection] = {}
self.tasks: Set[asyncio.Task] = set()
self.connection_by_type: Dict[NodeType, Dict[str, WSChiaConnection]] = {}
self._port = port # TCP port to identify our node
@ -91,9 +86,7 @@ class ChiaServer:
self.root_path = root_path
self.config = config
self.on_connect: Optional[Callable] = None
self.incoming_messages: Queue[
Tuple[Payload, WSChiaConnection]
] = asyncio.Queue()
self.incoming_messages: Queue[Tuple[Payload, WSChiaConnection]] = asyncio.Queue()
self.shut_down_event = asyncio.Event()
if self._local_type is NodeType.INTRODUCER:
@ -108,6 +101,10 @@ class ChiaServer:
self.app: Optional[Application] = None
self.site: Optional[TCPSite] = None
self.connection_close_task: Optional[asyncio.Task] = None
self.site_shutdown_task: Optional[asyncio.Task] = None
self.app_shut_down_task: Optional[asyncio.Task] = None
async def start_server(self, on_connect: Callable = None):
self.app = web.Application()
self.on_connect = on_connect
@ -118,9 +115,7 @@ class ChiaServer:
self.runner = web.AppRunner(self.app, access_log=None)
await self.runner.setup()
require_cert = self._local_type not in (NodeType.FULL_NODE, NodeType.INTRODUCER)
ssl_context = ssl_context_for_server(
self._private_cert_path, self._private_key_path, require_cert
)
ssl_context = ssl_context_for_server(self._private_cert_path, self._private_key_path, require_cert)
if self._local_type not in [NodeType.WALLET, NodeType.HARVESTER]:
self.site = web.TCPSite(
self.runner,
@ -159,10 +154,7 @@ class ChiaServer:
assert handshake is True
await self.connection_added(connection, self.on_connect)
if (
self._local_type is NodeType.INTRODUCER
and connection.connection_type is NodeType.FULL_NODE
):
if self._local_type is NodeType.INTRODUCER and connection.connection_type is NodeType.FULL_NODE:
self.introducer_peers.add(connection.get_peer_info())
except Exception as e:
error_stack = traceback.format_exc()
@ -191,18 +183,14 @@ class ChiaServer:
Tries to connect to the target node, adding one connection into the pipeline, if successful.
An on connect method can also be specified, and this will be saved into the instance variables.
"""
ssl_context = ssl_context_for_client(
self._private_cert_path, self._private_key_path, auth
)
ssl_context = ssl_context_for_client(self._private_cert_path, self._private_key_path, auth)
session = None
try:
timeout = aiohttp.ClientTimeout(total=10)
session = aiohttp.ClientSession(timeout=timeout)
url = f"wss://{target_node.host}:{target_node.port}/ws"
self.log.info(f"Connecting: {url}")
ws = await session.ws_connect(
url, autoclose=False, autoping=True, ssl=ssl_context
)
ws = await session.ws_connect(url, autoclose=False, autoping=True, ssl=ssl_context)
if ws is not None:
connection = WSChiaConnection(
self._local_type,
@ -246,82 +234,59 @@ class ChiaServer:
async def incoming_api_task(self):
self.tasks = set()
while True:
payload, connection = await self.incoming_messages.get()
if payload is None or connection is None:
payload_inc, connection_inc = await self.incoming_messages.get()
if payload_inc is None or connection_inc is None:
continue
async def api_call(payload: Payload, connection: WSChiaConnection):
try:
full_message = payload.msg
connection.log.info(
f"<- {full_message.function} from peer {connection.peer_node_id}"
)
if len(
full_message.function
) == 0 or full_message.function.startswith("_"):
connection.log.info(f"<- {full_message.function} from peer {connection.peer_node_id}")
if len(full_message.function) == 0 or full_message.function.startswith("_"):
# This prevents remote calling of private methods that start with "_"
self.log.error(
f"Non existing function: {full_message.function}"
)
raise ProtocolError(
Err.INVALID_PROTOCOL_MESSAGE, [full_message.function]
)
self.log.error(f"Non existing function: {full_message.function}")
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [full_message.function])
f = getattr(self.api, full_message.function, None)
if f is None:
self.log.error(
f"Non existing function: {full_message.function}"
)
raise ProtocolError(
Err.INVALID_PROTOCOL_MESSAGE, [full_message.function]
)
self.log.error(f"Non existing function: {full_message.function}")
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [full_message.function])
if hasattr(f, "peer_required"):
response: Optional[Message] = await f(
full_message.data, connection
)
response: Optional[Message] = await f(full_message.data, connection)
else:
response = await f(full_message.data)
if response is not None:
id = payload.id
response_payload = Payload(response, id)
payload_id = payload.id
response_payload = Payload(response, payload_id)
await connection.reply_to_request(response_payload)
except Exception as e:
tb = traceback.format_exc()
connection.log.error(
f"Exception: {e}, closing connection {connection}. {tb}"
)
connection.log.error(f"Exception: {e}, closing connection {connection}. {tb}")
await connection.close()
asyncio.create_task(api_call(payload, connection))
asyncio.create_task(api_call(payload_inc, connection_inc))
async def send_to_others(
self, messages: List[Message], type: NodeType, origin_peer: WSChiaConnection
):
for id, connection in self.all_connections.items():
if id == origin_peer.peer_node_id:
async def send_to_others(self, messages: List[Message], node_type: NodeType, origin_peer: WSChiaConnection):
for node_id, connection in self.all_connections.items():
if node_id == origin_peer.peer_node_id:
continue
if connection.connection_type is type:
if connection.connection_type is node_type:
for message in messages:
await connection.send_message(message)
async def send_to_all(self, messages: List[Message], type: NodeType):
for id, connection in self.all_connections.items():
if connection.connection_type is type:
async def send_to_all(self, messages: List[Message], node_type: NodeType):
for _, connection in self.all_connections.items():
if connection.connection_type is node_type:
for message in messages:
await connection.send_message(message)
async def send_to_all_except(
self, messages: List[Message], type: NodeType, exclude: bytes32
):
for id, connection in self.all_connections.items():
if (
connection.connection_type is type
and connection.peer_node_id != exclude
):
async def send_to_all_except(self, messages: List[Message], node_type: NodeType, exclude: bytes32):
for _, connection in self.all_connections.items():
if connection.connection_type is node_type and connection.peer_node_id != exclude:
for message in messages:
await connection.send_message(message)
@ -333,7 +298,7 @@ class ChiaServer:
def get_outgoing_connections(self) -> List[WSChiaConnection]:
result = []
for id, connection in self.all_connections.items():
for _, connection in self.all_connections.items():
if connection.is_outbound:
result.append(connection)
@ -341,7 +306,7 @@ class ChiaServer:
def get_full_node_connections(self) -> List[WSChiaConnection]:
result = []
for id, connection in self.all_connections.items():
for _, connection in self.all_connections.items():
if connection.connection_type is NodeType.FULL_NODE:
result.append(connection)
@ -349,22 +314,22 @@ class ChiaServer:
def get_connections(self) -> List[WSChiaConnection]:
result = []
for id, connection in self.all_connections.items():
for _, connection in self.all_connections.items():
result.append(connection)
return result
async def close_all_connections(self):
keys = [a for a, b in self.all_connections.items()]
for id in keys:
for node_id in keys:
try:
if id in self.all_connections:
connection = self.all_connections[id]
if node_id in self.all_connections:
connection = self.all_connections[node_id]
await connection.close()
except Exception as e:
self.log.error(f"exeption while closing connection {e}")
def close_all(self):
self.connection_close_taks = asyncio.create_task(self.close_all_connections())
self.connection_close_task = asyncio.create_task(self.close_all_connections())
self.site_shutdown_task = asyncio.create_task(self.runner.cleanup())
self.app_shut_down_task = asyncio.create_task(self.app.shutdown())
@ -375,9 +340,12 @@ class ChiaServer:
async def await_closed(self):
self.log.info("Await Closed")
await self.shut_down_event.wait()
await self.connection_close_taks
await self.app_shut_down_task
await self.site_shutdown_task
if self.connection_close_task is not None:
await self.connection_close_task
if self.app_shut_down_task is not None:
await self.app_shut_down_task
if self.site_shutdown_task is not None:
await self.site_shutdown_task
async def get_peer_info(self) -> Optional[PeerInfo]:
ip = None

View File

@ -20,7 +20,6 @@ from src.util.errors import Err, ProtocolError
from src.util.network import class_for_type
LENGTH_BYTES: int = 4
log = logging.getLogger(__name__)
OnConnectFunc = Optional[Callable[[], AsyncGenerator[OutboundMessage, None]]]
@ -54,9 +53,7 @@ class WSChiaConnection:
# Remote properties
self.peer_host = peer_host
connection_host, connection_port = self.ws._writer.transport.get_extra_info(
"peername"
)
connection_host, connection_port = self.ws._writer.transport.get_extra_info("peername")
self.peer_port = connection_port
self.peer_server_port: Optional[uint16] = None
@ -75,9 +72,7 @@ class WSChiaConnection:
self.last_message_time: float = 0
# Messaging
self.incoming_queue: asyncio.Queue[
Tuple[Payload, WSChiaConnection]
] = incoming_queue
self.incoming_queue: asyncio.Queue[Tuple[Payload, WSChiaConnection]] = incoming_queue
self.outgoing_queue: asyncio.Queue[Payload] = asyncio.Queue()
self.inbound_task = None
@ -90,10 +85,9 @@ class WSChiaConnection:
self.pending_requests: Dict[bytes32, asyncio.Event] = {}
self.request_results: Dict[bytes32, Payload] = {}
self.closed = False
self.connection_type = None
async def perform_handshake(
self, network_id, protocol_version, node_id, server_port, local_type
):
async def perform_handshake(self, network_id, protocol_version, node_id, server_port, local_type):
self.log.info("Doing handshake")
if self.is_outbound:
self.log.info("Outbound handshake")
@ -111,11 +105,7 @@ class WSChiaConnection:
await self._send_message(payload)
payload = await self._read_one_message()
inbound_handshake = Handshake(**payload.msg.data)
if (
payload.msg.function != "handshake"
or not inbound_handshake
or not inbound_handshake.node_type
):
if payload.msg.function != "handshake" or not inbound_handshake or not inbound_handshake.node_type:
raise ProtocolError(Err.INVALID_HANDSHAKE)
self.peer_node_id = inbound_handshake.node_id
self.peer_server_port = int(inbound_handshake.server_port)
@ -124,11 +114,7 @@ class WSChiaConnection:
self.log.info("Inbound handshake")
payload = await self._read_one_message()
inbound_handshake = Handshake(**payload.msg.data)
if (
payload.msg.function != "handshake"
or not inbound_handshake
or not inbound_handshake.node_type
):
if payload.msg.function != "handshake" or not inbound_handshake or not inbound_handshake.node_type:
raise ProtocolError(Err.INVALID_HANDSHAKE)
outbound_handshake = Message(
"handshake",
@ -213,9 +199,7 @@ class WSChiaConnection:
msg = Message(attr_name, args[0])
result = await self.create_request(msg)
if result is not None:
ret_attr = getattr(
class_for_type(self.local_type), result.function, None
)
ret_attr = getattr(class_for_type(self.local_type), result.function, None)
req_annotations = ret_attr.__annotations__
req = None
@ -236,27 +220,25 @@ class WSChiaConnection:
return None
event = asyncio.Event()
id = token_bytes(8)
payload = Payload(message, id)
payload = Payload(message, token_bytes(8))
self.pending_requests[payload.id] = event
await self.outgoing_queue.put(payload)
async def time_out(req_id, req_timeout):
await asyncio.sleep(req_timeout)
if id in self.pending_requests:
event = self.pending_requests[req_id]
event.set()
if req_id in self.pending_requests:
self.pending_requests[req_id].set()
asyncio.create_task(time_out(id, timeout))
asyncio.create_task(time_out(payload.id, timeout))
await event.wait()
self.pending_requests.pop(id)
self.pending_requests.pop(payload.id)
result: Optional[Message] = None
if id in self.request_results:
result_payload: Payload = self.request_results[id]
if payload.id in self.request_results:
result_payload: Payload = self.request_results[payload.id]
result = result_payload.msg
self.request_results.pop(id)
self.request_results.pop(payload.id)
return result
@ -274,9 +256,7 @@ class WSChiaConnection:
async def _send_message(self, payload: Payload):
self.log.info(f"-> {payload.msg.function}")
encoded: bytes = cbor.dumps(
{"f": payload.msg.function, "d": payload.msg.data, "i": payload.id}
)
encoded: bytes = cbor.dumps({"f": payload.msg.function, "d": payload.msg.data, "i": payload.id})
size = len(encoded)
assert len(encoded) < (2 ** (LENGTH_BYTES * 8))
await self.ws.send_bytes(encoded)
@ -303,7 +283,5 @@ class WSChiaConnection:
return None
def get_peer_info(self):
connection_host, connection_port = self.ws._writer.transport.get_extra_info(
"peername"
)
connection_host, connection_port = self.ws._writer.transport.get_extra_info("peername")
return PeerInfo(connection_host, self.peer_server_port)

View File

@ -1,12 +1,18 @@
import asyncio
import aiohttp
import pytest
import random
import time
import logging
from typing import Dict
from secrets import token_bytes
from src.full_node.full_node_api import FullNodeAPI
from src.protocols import full_node_protocol as fnp, wallet_protocol
from src.server.outbound_message import NodeType
from src.server.server import ssl_context_for_client, ChiaServer
from src.server.ws_connection import WSChiaConnection
from src.types.coin import hash_coin_list
from src.types.mempool_inclusion_status import MempoolInclusionStatus
from src.types.peer_info import TimestampedPeerInfo, PeerInfo
@ -29,6 +35,9 @@ from tests.setup_nodes import setup_two_nodes, test_constants, bt
from src.util.wallet_tools import WalletTool
from src.util.clvm import int_to_bytes
from tests.time_out_assert import time_out_assert, time_out_assert_custom_interval
from src.protocols.shared_protocol import protocol_version
log = logging.getLogger(__name__)
async def get_block_path(full_node: FullNodeAPI):
@ -40,6 +49,23 @@ async def get_block_path(full_node: FullNodeAPI):
return blocks_list
async def add_dummy_connection(server: ChiaServer, dummy_port: int) -> asyncio.Queue:
timeout = aiohttp.ClientTimeout(total=10)
session = aiohttp.ClientSession(timeout=timeout)
incoming_queue = asyncio.Queue()
ssl_context = ssl_context_for_client(server._private_cert_path, server._private_key_path, False)
url = f"wss://127.0.0.1:{server._port}/ws"
ws = await session.ws_connect(url, autoclose=False, autoping=True, ssl=ssl_context)
wsc = WSChiaConnection(
NodeType.FULL_NODE, ws, server._port, log, True, False, "127.0.0.1", incoming_queue, lambda x: x
)
handshake = await wsc.perform_handshake(
server._network_id, protocol_version, std_hash(b"123"), dummy_port, NodeType.FULL_NODE
)
assert handshake is True
return incoming_queue
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
@ -105,10 +131,33 @@ class TestFullNodeProtocol:
async def test_basic_chain(self, two_nodes):
full_node_1, full_node_2, server_1, server_2 = two_nodes
blocks = bt.get_consecutive_blocks(100)
for block in blocks:
incoming_queue = await add_dummy_connection(server_1, 12312)
async def has_mempool_tx():
if incoming_queue.qsize() == 0:
return False
res = set()
while incoming_queue.qsize() > 0:
res.add((await incoming_queue.get())[0].msg.function)
return res == {"request_mempool_transactions"}
await time_out_assert(10, has_mempool_tx, True)
blocks = bt.get_consecutive_blocks(1)
for block in blocks[:1]:
await full_node_1.respond_sub_block(fnp.RespondSubBlock(block))
assert full_node_1.full_node.blockchain.get_peak().height == 99
async def has_new_peak():
if incoming_queue.qsize() == 0:
return False
res = set()
while incoming_queue.qsize() > 0:
res.add((await incoming_queue.get())[0].msg.function)
return res == {"new_peak"}
await time_out_assert(10, has_new_peak, True)
assert full_node_1.full_node.blockchain.get_peak().height == 0
# @pytest.mark.asyncio