Start refactoring using aiter

This commit is contained in:
Mariano Sorgente 2019-08-15 17:31:19 -03:00
parent ff2915642e
commit 919fbca7fc
8 changed files with 131 additions and 9 deletions

3
.gitmodules vendored
View File

@ -4,3 +4,6 @@
[submodule "lib/bip158/lib/pybind11"]
path = lib/bip158/lib/pybind11
url = https://github.com/pybind/pybind11.git
[submodule "aiter"]
path = lib/aiter
url = https://github.com/richardkiss/aiter.git

1
lib/aiter Submodule

@ -0,0 +1 @@
Subproject commit b69ce7166f28e73a193b6f694ecf441c99240145

@ -1 +1 @@
Subproject commit b044611edd6271c233589a79a18dcefe377fa983
Subproject commit af9744d29a9c347a215599985d93f231aaabc0de

11
src/server/connection.py Normal file
View File

@ -0,0 +1,11 @@
from asyncio import StreamReader, StreamWriter
class Connection:
def __init__(self, connection_type: str, sr: StreamReader, sw: StreamWriter):
self.connection_type = connection_type
self.reader = sr
self.writer = sw
def get_peername(self):
return self.writer.get_extra_info("peername")

97
src/server/new_server.py Normal file
View File

@ -0,0 +1,97 @@
import logging
import asyncio
import functools
from src.util import cbor
from lib.aiter.aiter.server import start_server_aiter
from lib.aiter import parallel_map_aiter, map_aiter, join_aiters
from src.server.connection import Connection
from src.server.peer_connections import PeerConnections
# Each message is prepended with LENGTH_BYTES bytes specifying the length
LENGTH_BYTES: int = 5
log = logging.getLogger(__name__)
global_connections: PeerConnections = PeerConnections([])
async def stream_reader_writer_to_connection(connection_type, pair):
sr, sw = pair
return Connection(connection_type, sr, sw)
async def connection_to_message(connection):
"""
Async generator which yields complete binary messages from connections,
along with a streamwriter to send back responses.
"""
try:
while not connection.reader.at_eof():
size = await connection.reader.readexactly(LENGTH_BYTES)
full_message_length = int.from_bytes(size, "big")
full_message = await connection.reader.readexactly(full_message_length)
yield (connection, full_message)
except asyncio.IncompleteReadError:
log.warn(f"Received EOF from {connection.get_peername()}, closing connection.")
finally:
# Removes the connection from the global list, so we don't try to send things to it
with await global_connections.get_lock():
await global_connections.remove(connection)
connection.writer.close()
return
async def handle_message(api, pair):
"""
Async generator which takes messages, parses, them, executes the right
api function, and yields responses (to same connection, propagated, etc).
"""
connection, message = pair
decoded = cbor.loads(message)
function: str = decoded["function"]
function_data: bytes = decoded["data"]
f = getattr(api, function)
if f is not None:
async for outbound_message in f(function_data):
yield connection, outbound_message
else:
log.error(f'Invalid message: {function} from {connection.get_peername()}')
async def expand_outbound_messages(pair):
"""
Expands each of the outbound messages into it's own message.
"""
connection, outbound_message = pair
if not outbound_message.broadcast:
if outbound_message.respond:
yield connection, outbound_message
else:
with await global_connections.get_lock():
for peer in await global_connections.get_connections():
if peer.connection_type == outbound_message.peer_type:
if peer == connection:
if outbound_message.respond:
yield connection, outbound_message
else:
yield connection, outbound_message
async def start_chia_server(host, port, api, connection_type):
server, aiter = await start_server_aiter(port, host=host)
connections_aiter = map_aiter(
functools.partial(stream_reader_writer_to_connection, connection_type),
aiter)
messages_aiter = join_aiters(parallel_map_aiter(connection_to_message, 100, connections_aiter))
responses_aiter = join_aiters(parallel_map_aiter(
functools.partial(handle_message, api),
100, messages_aiter))
outbound_messages_aiter = join_aiters(parallel_map_aiter(
expand_outbound_messages, 100, responses_aiter))
async for connection, outbound_message in outbound_messages_aiter:
log.info(f"Sending {outbound_message.function} to peer {connection.get_peername()}")
encoded: bytes = cbor.dumps({"function": outbound_message.function, "data": outbound_message.data})
assert(len(encoded) < (2**(LENGTH_BYTES*8)))
connection.writer.write(len(encoded).to_bytes(LENGTH_BYTES, "big") + encoded)
await connection.writer.drain()

View File

@ -0,0 +1,11 @@
from typing import Any
from dataclasses import dataclass
@dataclass
class OutboundMessage:
peer_type: str
function: str
data: Any
respond: bool
broadcast: bool

View File

@ -1,18 +1,18 @@
from src.server.chia_connection import ChiaConnection
from src.server.connection import Connection
from asyncio import Lock
from typing import List
class PeerConnections():
def __init__(self, all_connections: List[ChiaConnection] = []):
class PeerConnections:
def __init__(self, all_connections: List[Connection] = []):
self._connections_lock = Lock()
self._all_connections = all_connections
async def add(self, connection: ChiaConnection):
async def add(self, connection: Connection):
async with self._connections_lock:
self._all_connections.append(connection)
async def remove(self, connection: ChiaConnection):
async def remove(self, connection: Connection):
async with self._connections_lock:
self._all_connections.remove(connection)

View File

@ -3,6 +3,7 @@ import logging
from src import full_node
from src.server.server import start_server, retry_connection
from src.server.peer_connections import PeerConnections
from src.server.new_server import start_chia_server
logging.basicConfig(format='FullNode %(name)-23s: %(levelname)-8s %(message)s', level=logging.INFO)
@ -16,9 +17,7 @@ async def main():
timelord_con_fut = retry_connection(full_node, full_node.timelord_ip, full_node.timelord_port,
"timelord", global_connections)
# Starts the full node server (which full nodes can connect to)
server = asyncio.create_task(start_server(full_node, '127.0.0.1',
full_node.full_node_port, global_connections,
"full_node"))
await start_chia_server("127.0.0.1", full_node.full_node_port, full_node, "full_node")
# Both connections to farmer and timelord have been started
await asyncio.gather(farmer_con_fut, timelord_con_fut)