From 373fd13ed6bdfe3f23ec733821a3eeffafde8d8d Mon Sep 17 00:00:00 2001 From: Rostislav Date: Wed, 27 Nov 2019 14:14:05 +0100 Subject: [PATCH] Fix bug with on_connect for target hosts specified as a domain name Make 'on_connect' a property of Connection object when it's created, so that it can be unambiguously associated with a specific connection immediately after it's created. Create a type definition for on_connect functions named OnConnectFunc. --- src/server/connection.py | 8 +++++-- src/server/server.py | 52 +++++++++++++++++----------------------- 2 files changed, 28 insertions(+), 32 deletions(-) diff --git a/src/server/connection.py b/src/server/connection.py index d65ee4a24057..b136d4c8f1db 100644 --- a/src/server/connection.py +++ b/src/server/connection.py @@ -2,9 +2,9 @@ import logging import random import time from asyncio import StreamReader, StreamWriter -from typing import Any, List, Optional +from typing import Any, AsyncGenerator, Callable, List, Optional -from src.server.outbound_message import Message, NodeType +from src.server.outbound_message import Message, NodeType, OutboundMessage from src.types.peer_info import PeerInfo from src.util import cbor from src.util.ints import uint16 @@ -13,6 +13,8 @@ from src.util.ints import uint16 LENGTH_BYTES: int = 4 log = logging.getLogger(__name__) +OnConnectFunc = Optional[Callable[[], AsyncGenerator[OutboundMessage, None]]] + class Connection: """ @@ -28,6 +30,7 @@ class Connection: sr: StreamReader, sw: StreamWriter, server_port: int, + on_connect: OnConnectFunc, ): self.local_type = local_type self.connection_type = connection_type @@ -41,6 +44,7 @@ class Connection: self.peer_port = self.writer.get_extra_info("peername")[1] self.peer_server_port: Optional[int] = None self.node_id = None + self.on_connect = on_connect # Connection metrics self.creation_type = time.time() diff --git a/src/server/server.py b/src/server/server.py index c0ddfa87bdb4..bba25e79e578 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -4,12 +4,12 @@ import random import time import os from yaml import safe_load -from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple +from typing import Any, AsyncGenerator, List, Optional, Tuple from aiter import aiter_forker, iter_to_aiter, join_aiters, map_aiter, push_aiter from aiter.server import start_server_aiter from definitions import ROOT_DIR from src.protocols.shared_protocol import Handshake, HandshakeAck, protocol_version -from src.server.connection import Connection, PeerConnections +from src.server.connection import Connection, OnConnectFunc, PeerConnections from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage from src.types.peer_info import PeerInfo from src.util import partial_func @@ -51,9 +51,8 @@ class ChiaServer: # Aiter used to broadcase messages _outbound_aiter: push_aiter - # These will get called after a handshake is performed - _on_connect_callbacks: Dict[PeerInfo, Callable] = {} - _on_connect_generic_callback: Optional[Callable] = None + # Called for inbound connections after successful handshake + _on_inbound_connect: OnConnectFunc = None def __init__(self, port: int, api: Any, local_type: NodeType): self._port = port # TCP port to identify our node @@ -69,9 +68,7 @@ class ChiaServer: async def start_server( self, host: str, - on_connect: Optional[ - Callable[[], AsyncGenerator[OutboundMessage, None]] - ] = None, + on_connect: OnConnectFunc = None, ) -> bool: """ Launches a listening server on host and port specified, to connect to NodeType nodes. On each @@ -86,12 +83,12 @@ class ChiaServer: self._port, host=None, reuse_address=True ) if on_connect is not None: - self._on_connect_generic_callback = on_connect + self._on_inbound_connect = on_connect def add_connection_type( srw: Tuple[asyncio.StreamReader, asyncio.StreamWriter] - ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: - return (srw[0], srw[1]) + ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter, None]: + return (srw[0], srw[1], None) srwt_aiter = map_aiter(add_connection_type, aiter) @@ -104,9 +101,7 @@ class ChiaServer: async def start_client( self, target_node: PeerInfo, - on_connect: Optional[ - Callable[[], AsyncGenerator[OutboundMessage, None]] - ] = None, + on_connect: OnConnectFunc = None, ) -> bool: """ Tries to connect to the target node, adding one connection into the pipeline, if successful. @@ -146,14 +141,13 @@ class ChiaServer: ) self.global_connections.peers.remove(target_node) return False - if on_connect is not None: - self._on_connect_callbacks[target_node] = on_connect - asyncio.create_task(self._add_to_srwt_aiter(iter_to_aiter([(reader, writer)]))) + asyncio.create_task(self._add_to_srwt_aiter(iter_to_aiter([(reader, writer, on_connect)]))) return True async def _add_to_srwt_aiter( self, - aiter: AsyncGenerator[Tuple[asyncio.StreamReader, asyncio.StreamWriter], None], + aiter: AsyncGenerator[Tuple[asyncio.StreamReader, asyncio.StreamWriter, + OnConnectFunc], None], ): """ Adds all swrt from aiter into the instance variable srwt_aiter, adding them to the pipeline. @@ -258,14 +252,16 @@ class ChiaServer: return asyncio.get_running_loop().create_task(serve_forever()) async def stream_reader_writer_to_connection( - self, swrt: Tuple[asyncio.StreamReader, asyncio.StreamWriter], server_port: int + self, + swrt: Tuple[asyncio.StreamReader, asyncio.StreamWriter, OnConnectFunc], + server_port: int ) -> Connection: """ - Maps a pair of (StreamReader, StreamWriter) to a Connection object, + Maps a tuple of (StreamReader, StreamWriter, on_connect) to a Connection object, which also stores the type of connection (str). It is also added to the global list. """ - sr, sw = swrt - con = Connection(self._local_type, None, sr, sw, server_port) + sr, sw, on_connect = swrt + con = Connection(self._local_type, None, sr, sw, server_port, on_connect) log.info(f"Connection with {con.get_peername()} established") return con @@ -276,14 +272,10 @@ class ChiaServer: """ Async generator which calls the on_connect async generator method, and yields any outbound messages. """ - peer = PeerInfo(connection.peer_host, connection.peer_port) - if peer in self._on_connect_callbacks: - on_connect = self._on_connect_callbacks[peer] - async for outbound_message in on_connect(): - yield connection, outbound_message - if self._on_connect_generic_callback: - async for outbound_message in self._on_connect_generic_callback(): - yield connection, outbound_message + for func in connection.on_connect, self._on_inbound_connect: + if func: + async for outbound_message in func(): + yield connection, outbound_message async def perform_handshake( self, connection: Connection