mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-09-21 08:31:52 +03:00
Merge branch 'master' of github.com:Chia-Network/chia-blockchain into speedy
This commit is contained in:
commit
0c86698507
@ -2,9 +2,9 @@ import logging
|
||||
import random
|
||||
import time
|
||||
from asyncio import StreamReader, StreamWriter
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, AsyncGenerator, Callable, List, Optional
|
||||
|
||||
from src.server.outbound_message import Message, NodeType
|
||||
from src.server.outbound_message import Message, NodeType, OutboundMessage
|
||||
from src.types.peer_info import PeerInfo
|
||||
from src.util import cbor
|
||||
from src.util.ints import uint16
|
||||
@ -13,6 +13,8 @@ from src.util.ints import uint16
|
||||
LENGTH_BYTES: int = 4
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
OnConnectFunc = Optional[Callable[[], AsyncGenerator[OutboundMessage, None]]]
|
||||
|
||||
|
||||
class Connection:
|
||||
"""
|
||||
@ -28,6 +30,7 @@ class Connection:
|
||||
sr: StreamReader,
|
||||
sw: StreamWriter,
|
||||
server_port: int,
|
||||
on_connect: OnConnectFunc,
|
||||
):
|
||||
self.local_type = local_type
|
||||
self.connection_type = connection_type
|
||||
@ -40,6 +43,7 @@ class Connection:
|
||||
self.peer_port = self.writer.get_extra_info("peername")[1]
|
||||
self.peer_server_port: Optional[int] = None
|
||||
self.node_id = None
|
||||
self.on_connect = on_connect
|
||||
|
||||
# Connection metrics
|
||||
self.creation_type = time.time()
|
||||
|
@ -4,12 +4,12 @@ import random
|
||||
import time
|
||||
import os
|
||||
from yaml import safe_load
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, AsyncGenerator, List, Optional, Tuple
|
||||
from aiter import aiter_forker, iter_to_aiter, join_aiters, map_aiter, push_aiter
|
||||
from aiter.server import start_server_aiter
|
||||
from definitions import ROOT_DIR
|
||||
from src.protocols.shared_protocol import Handshake, HandshakeAck, protocol_version
|
||||
from src.server.connection import Connection, PeerConnections
|
||||
from src.server.connection import Connection, OnConnectFunc, PeerConnections
|
||||
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
|
||||
from src.types.peer_info import PeerInfo
|
||||
from src.util import partial_func
|
||||
@ -51,9 +51,8 @@ class ChiaServer:
|
||||
# Aiter used to broadcase messages
|
||||
_outbound_aiter: push_aiter
|
||||
|
||||
# These will get called after a handshake is performed
|
||||
_on_connect_callbacks: Dict[PeerInfo, Callable] = {}
|
||||
_on_connect_generic_callback: Optional[Callable] = None
|
||||
# Called for inbound connections after successful handshake
|
||||
_on_inbound_connect: OnConnectFunc = None
|
||||
|
||||
def __init__(self, port: int, api: Any, local_type: NodeType):
|
||||
self._port = port # TCP port to identify our node
|
||||
@ -69,9 +68,7 @@ class ChiaServer:
|
||||
async def start_server(
|
||||
self,
|
||||
host: str,
|
||||
on_connect: Optional[
|
||||
Callable[[], AsyncGenerator[OutboundMessage, None]]
|
||||
] = None,
|
||||
on_connect: OnConnectFunc = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Launches a listening server on host and port specified, to connect to NodeType nodes. On each
|
||||
@ -86,12 +83,12 @@ class ChiaServer:
|
||||
self._port, host=None, reuse_address=True
|
||||
)
|
||||
if on_connect is not None:
|
||||
self._on_connect_generic_callback = on_connect
|
||||
self._on_inbound_connect = on_connect
|
||||
|
||||
def add_connection_type(
|
||||
srw: Tuple[asyncio.StreamReader, asyncio.StreamWriter]
|
||||
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
|
||||
return (srw[0], srw[1])
|
||||
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter, None]:
|
||||
return (srw[0], srw[1], None)
|
||||
|
||||
srwt_aiter = map_aiter(add_connection_type, aiter)
|
||||
|
||||
@ -104,9 +101,7 @@ class ChiaServer:
|
||||
async def start_client(
|
||||
self,
|
||||
target_node: PeerInfo,
|
||||
on_connect: Optional[
|
||||
Callable[[], AsyncGenerator[OutboundMessage, None]]
|
||||
] = None,
|
||||
on_connect: OnConnectFunc = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Tries to connect to the target node, adding one connection into the pipeline, if successful.
|
||||
@ -146,14 +141,13 @@ class ChiaServer:
|
||||
)
|
||||
self.global_connections.peers.remove(target_node)
|
||||
return False
|
||||
if on_connect is not None:
|
||||
self._on_connect_callbacks[target_node] = on_connect
|
||||
asyncio.create_task(self._add_to_srwt_aiter(iter_to_aiter([(reader, writer)])))
|
||||
asyncio.create_task(self._add_to_srwt_aiter(iter_to_aiter([(reader, writer, on_connect)])))
|
||||
return True
|
||||
|
||||
async def _add_to_srwt_aiter(
|
||||
self,
|
||||
aiter: AsyncGenerator[Tuple[asyncio.StreamReader, asyncio.StreamWriter], None],
|
||||
aiter: AsyncGenerator[Tuple[asyncio.StreamReader, asyncio.StreamWriter,
|
||||
OnConnectFunc], None],
|
||||
):
|
||||
"""
|
||||
Adds all swrt from aiter into the instance variable srwt_aiter, adding them to the pipeline.
|
||||
@ -258,14 +252,16 @@ class ChiaServer:
|
||||
return asyncio.get_running_loop().create_task(serve_forever())
|
||||
|
||||
async def stream_reader_writer_to_connection(
|
||||
self, swrt: Tuple[asyncio.StreamReader, asyncio.StreamWriter], server_port: int
|
||||
self,
|
||||
swrt: Tuple[asyncio.StreamReader, asyncio.StreamWriter, OnConnectFunc],
|
||||
server_port: int
|
||||
) -> Connection:
|
||||
"""
|
||||
Maps a pair of (StreamReader, StreamWriter) to a Connection object,
|
||||
Maps a tuple of (StreamReader, StreamWriter, on_connect) to a Connection object,
|
||||
which also stores the type of connection (str). It is also added to the global list.
|
||||
"""
|
||||
sr, sw = swrt
|
||||
con = Connection(self._local_type, None, sr, sw, server_port)
|
||||
sr, sw, on_connect = swrt
|
||||
con = Connection(self._local_type, None, sr, sw, server_port, on_connect)
|
||||
|
||||
log.info(f"Connection with {con.get_peername()} established")
|
||||
return con
|
||||
@ -276,14 +272,10 @@ class ChiaServer:
|
||||
"""
|
||||
Async generator which calls the on_connect async generator method, and yields any outbound messages.
|
||||
"""
|
||||
peer = PeerInfo(connection.peer_host, connection.peer_port)
|
||||
if peer in self._on_connect_callbacks:
|
||||
on_connect = self._on_connect_callbacks[peer]
|
||||
async for outbound_message in on_connect():
|
||||
yield connection, outbound_message
|
||||
if self._on_connect_generic_callback:
|
||||
async for outbound_message in self._on_connect_generic_callback():
|
||||
yield connection, outbound_message
|
||||
for func in connection.on_connect, self._on_inbound_connect:
|
||||
if func:
|
||||
async for outbound_message in func():
|
||||
yield connection, outbound_message
|
||||
|
||||
async def perform_handshake(
|
||||
self, connection: Connection
|
||||
|
Loading…
Reference in New Issue
Block a user