mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-09-21 00:24:37 +03:00
Start refactoring using aiter
This commit is contained in:
parent
ff2915642e
commit
919fbca7fc
3
.gitmodules
vendored
3
.gitmodules
vendored
@ -4,3 +4,6 @@
|
||||
[submodule "lib/bip158/lib/pybind11"]
|
||||
path = lib/bip158/lib/pybind11
|
||||
url = https://github.com/pybind/pybind11.git
|
||||
[submodule "aiter"]
|
||||
path = lib/aiter
|
||||
url = https://github.com/richardkiss/aiter.git
|
||||
|
1
lib/aiter
Submodule
1
lib/aiter
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit b69ce7166f28e73a193b6f694ecf441c99240145
|
@ -1 +1 @@
|
||||
Subproject commit b044611edd6271c233589a79a18dcefe377fa983
|
||||
Subproject commit af9744d29a9c347a215599985d93f231aaabc0de
|
11
src/server/connection.py
Normal file
11
src/server/connection.py
Normal file
@ -0,0 +1,11 @@
|
||||
from asyncio import StreamReader, StreamWriter
|
||||
|
||||
|
||||
class Connection:
|
||||
def __init__(self, connection_type: str, sr: StreamReader, sw: StreamWriter):
|
||||
self.connection_type = connection_type
|
||||
self.reader = sr
|
||||
self.writer = sw
|
||||
|
||||
def get_peername(self):
|
||||
return self.writer.get_extra_info("peername")
|
97
src/server/new_server.py
Normal file
97
src/server/new_server.py
Normal file
@ -0,0 +1,97 @@
|
||||
import logging
|
||||
import asyncio
|
||||
import functools
|
||||
from src.util import cbor
|
||||
from lib.aiter.aiter.server import start_server_aiter
|
||||
from lib.aiter import parallel_map_aiter, map_aiter, join_aiters
|
||||
from src.server.connection import Connection
|
||||
from src.server.peer_connections import PeerConnections
|
||||
|
||||
# Each message is prepended with LENGTH_BYTES bytes specifying the length
|
||||
LENGTH_BYTES: int = 5
|
||||
log = logging.getLogger(__name__)
|
||||
global_connections: PeerConnections = PeerConnections([])
|
||||
|
||||
|
||||
async def stream_reader_writer_to_connection(connection_type, pair):
|
||||
sr, sw = pair
|
||||
return Connection(connection_type, sr, sw)
|
||||
|
||||
|
||||
async def connection_to_message(connection):
|
||||
"""
|
||||
Async generator which yields complete binary messages from connections,
|
||||
along with a streamwriter to send back responses.
|
||||
"""
|
||||
|
||||
try:
|
||||
while not connection.reader.at_eof():
|
||||
size = await connection.reader.readexactly(LENGTH_BYTES)
|
||||
full_message_length = int.from_bytes(size, "big")
|
||||
full_message = await connection.reader.readexactly(full_message_length)
|
||||
yield (connection, full_message)
|
||||
except asyncio.IncompleteReadError:
|
||||
log.warn(f"Received EOF from {connection.get_peername()}, closing connection.")
|
||||
finally:
|
||||
# Removes the connection from the global list, so we don't try to send things to it
|
||||
with await global_connections.get_lock():
|
||||
await global_connections.remove(connection)
|
||||
connection.writer.close()
|
||||
return
|
||||
|
||||
|
||||
async def handle_message(api, pair):
|
||||
"""
|
||||
Async generator which takes messages, parses, them, executes the right
|
||||
api function, and yields responses (to same connection, propagated, etc).
|
||||
"""
|
||||
connection, message = pair
|
||||
decoded = cbor.loads(message)
|
||||
function: str = decoded["function"]
|
||||
function_data: bytes = decoded["data"]
|
||||
f = getattr(api, function)
|
||||
if f is not None:
|
||||
async for outbound_message in f(function_data):
|
||||
yield connection, outbound_message
|
||||
else:
|
||||
log.error(f'Invalid message: {function} from {connection.get_peername()}')
|
||||
|
||||
|
||||
async def expand_outbound_messages(pair):
|
||||
"""
|
||||
Expands each of the outbound messages into it's own message.
|
||||
"""
|
||||
connection, outbound_message = pair
|
||||
if not outbound_message.broadcast:
|
||||
if outbound_message.respond:
|
||||
yield connection, outbound_message
|
||||
else:
|
||||
with await global_connections.get_lock():
|
||||
for peer in await global_connections.get_connections():
|
||||
if peer.connection_type == outbound_message.peer_type:
|
||||
if peer == connection:
|
||||
if outbound_message.respond:
|
||||
yield connection, outbound_message
|
||||
else:
|
||||
yield connection, outbound_message
|
||||
|
||||
|
||||
async def start_chia_server(host, port, api, connection_type):
|
||||
server, aiter = await start_server_aiter(port, host=host)
|
||||
connections_aiter = map_aiter(
|
||||
functools.partial(stream_reader_writer_to_connection, connection_type),
|
||||
aiter)
|
||||
messages_aiter = join_aiters(parallel_map_aiter(connection_to_message, 100, connections_aiter))
|
||||
responses_aiter = join_aiters(parallel_map_aiter(
|
||||
functools.partial(handle_message, api),
|
||||
100, messages_aiter))
|
||||
|
||||
outbound_messages_aiter = join_aiters(parallel_map_aiter(
|
||||
expand_outbound_messages, 100, responses_aiter))
|
||||
|
||||
async for connection, outbound_message in outbound_messages_aiter:
|
||||
log.info(f"Sending {outbound_message.function} to peer {connection.get_peername()}")
|
||||
encoded: bytes = cbor.dumps({"function": outbound_message.function, "data": outbound_message.data})
|
||||
assert(len(encoded) < (2**(LENGTH_BYTES*8)))
|
||||
connection.writer.write(len(encoded).to_bytes(LENGTH_BYTES, "big") + encoded)
|
||||
await connection.writer.drain()
|
11
src/server/outbound_message.py
Normal file
11
src/server/outbound_message.py
Normal file
@ -0,0 +1,11 @@
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutboundMessage:
|
||||
peer_type: str
|
||||
function: str
|
||||
data: Any
|
||||
respond: bool
|
||||
broadcast: bool
|
@ -1,18 +1,18 @@
|
||||
from src.server.chia_connection import ChiaConnection
|
||||
from src.server.connection import Connection
|
||||
from asyncio import Lock
|
||||
from typing import List
|
||||
|
||||
|
||||
class PeerConnections():
|
||||
def __init__(self, all_connections: List[ChiaConnection] = []):
|
||||
class PeerConnections:
|
||||
def __init__(self, all_connections: List[Connection] = []):
|
||||
self._connections_lock = Lock()
|
||||
self._all_connections = all_connections
|
||||
|
||||
async def add(self, connection: ChiaConnection):
|
||||
async def add(self, connection: Connection):
|
||||
async with self._connections_lock:
|
||||
self._all_connections.append(connection)
|
||||
|
||||
async def remove(self, connection: ChiaConnection):
|
||||
async def remove(self, connection: Connection):
|
||||
async with self._connections_lock:
|
||||
self._all_connections.remove(connection)
|
||||
|
||||
|
@ -3,6 +3,7 @@ import logging
|
||||
from src import full_node
|
||||
from src.server.server import start_server, retry_connection
|
||||
from src.server.peer_connections import PeerConnections
|
||||
from src.server.new_server import start_chia_server
|
||||
|
||||
|
||||
logging.basicConfig(format='FullNode %(name)-23s: %(levelname)-8s %(message)s', level=logging.INFO)
|
||||
@ -16,9 +17,7 @@ async def main():
|
||||
timelord_con_fut = retry_connection(full_node, full_node.timelord_ip, full_node.timelord_port,
|
||||
"timelord", global_connections)
|
||||
# Starts the full node server (which full nodes can connect to)
|
||||
server = asyncio.create_task(start_server(full_node, '127.0.0.1',
|
||||
full_node.full_node_port, global_connections,
|
||||
"full_node"))
|
||||
await start_chia_server("127.0.0.1", full_node.full_node_port, full_node, "full_node")
|
||||
|
||||
# Both connections to farmer and timelord have been started
|
||||
await asyncio.gather(farmer_con_fut, timelord_con_fut)
|
||||
|
Loading…
Reference in New Issue
Block a user