From afb83e7923cb74bbca39a85209cb79e4e0733bb7 Mon Sep 17 00:00:00 2001 From: Yostra Date: Wed, 21 Oct 2020 01:19:40 -0700 Subject: [PATCH] rpcs --- src/farmer_api.py | 5 ++++- src/full_node/full_node_api.py | 9 ++++++++- src/harvester_api.py | 5 ++++- src/introducer_api.py | 5 ++++- src/rpc/full_node_rpc_api.py | 8 +++++--- src/rpc/rpc_server.py | 14 +++++++------- src/rpc/wallet_rpc_api.py | 8 ++------ src/server/server.py | 10 +++++++++- src/server/ws_connection.py | 11 +++++++++-- src/timelord_api.py | 5 +++++ tests/rpc/test_full_node_rpc.py | 10 ++++++---- tests/rpc/test_wallet_rpc.py | 9 ++++++--- tests/setup_nodes.py | 6 +++--- 13 files changed, 72 insertions(+), 33 deletions(-) diff --git a/src/farmer_api.py b/src/farmer_api.py index d396c385e60cb..c13a77c138f3c 100644 --- a/src/farmer_api.py +++ b/src/farmer_api.py @@ -1,5 +1,5 @@ import asyncio -from typing import Optional +from typing import Optional, Callable from blspy import AugSchemeMPL, G2Element @@ -21,6 +21,9 @@ class FarmerAPI: def __init__(self, farmer): self.farmer = farmer + def _set_state_changed_callback(self, callback: Callable): + self.farmer.state_changed_callback = callback + @api_request async def challenge_response( self, diff --git a/src/full_node/full_node_api.py b/src/full_node/full_node_api.py index f20dfef43ff57..d868a066592a2 100644 --- a/src/full_node/full_node_api.py +++ b/src/full_node/full_node_api.py @@ -1,7 +1,7 @@ import asyncio import time from pathlib import Path -from typing import AsyncGenerator, Dict, List, Optional, Tuple +from typing import AsyncGenerator, Dict, List, Optional, Tuple, Callable from chiabip158 import PyBIP158 from chiapos import Verifier from blspy import G2Element, AugSchemeMPL @@ -53,6 +53,13 @@ class FullNodeAPI: def __init__(self, full_node): self.full_node = full_node + def _set_state_changed_callback(self, callback: Callable): + self.full_node.state_changed_callback = callback + + @property + def server(self): + return self.full_node.server + @api_request async def request_peers( self, request: full_node_protocol.RequestPeers, peer: WSChiaConnection diff --git a/src/harvester_api.py b/src/harvester_api.py index 8718d5830d5d4..006979a942fd8 100644 --- a/src/harvester_api.py +++ b/src/harvester_api.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional +from typing import Optional, Callable from blspy import AugSchemeMPL, G2Element from chiapos import DiskProver @@ -20,6 +20,9 @@ class HarvesterAPI: def __init__(self, harvester): self.harvester = harvester + def _set_state_changed_callback(self, callback: Callable): + self.harvester.state_changed_callback = callback + @api_request async def harvester_handshake( self, diff --git a/src/introducer_api.py b/src/introducer_api.py index 06c23a3cfab19..c4e761be4e548 100644 --- a/src/introducer_api.py +++ b/src/introducer_api.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Callable from src.introducer import Introducer from src.server.outbound_message import Message @@ -15,6 +15,9 @@ class IntroducerAPI: def __init__(self, introducer): self.introducer = introducer + def _set_state_changed_callback(self, callback: Callable): + self.introducer.state_changed_callback = callback + @api_request async def request_peers( self, diff --git a/src/rpc/full_node_rpc_api.py b/src/rpc/full_node_rpc_api.py index affdb065a1d2a..63c068d37973a 100644 --- a/src/rpc/full_node_rpc_api.py +++ b/src/rpc/full_node_rpc_api.py @@ -1,7 +1,8 @@ from src.full_node.full_node import FullNode from typing import Callable, List, Optional, Dict -# from src.types.header import Header +from src.full_node.full_node_api import FullNodeAPI +#from src.types.header import Header from src.types.full_block import FullBlock from src.util.ints import uint32, uint64, uint128 from src.types.sized_bytes import bytes32 @@ -12,8 +13,9 @@ from src.util.ws_message import create_payload class FullNodeRpcApi: - def __init__(self, full_node: FullNode): - self.service = full_node + def __init__(self, full_node_api: FullNodeAPI): + self.service = full_node_api + self.full_node = full_node_api.full_node self.service_name = "chia_full_node" self.cached_blockchain_state: Optional[Dict] = None diff --git a/src/rpc/rpc_server.py b/src/rpc/rpc_server.py index e226187af12cf..94a40a1ec2d23 100644 --- a/src/rpc/rpc_server.py +++ b/src/rpc/rpc_server.py @@ -87,9 +87,9 @@ class RpcServer: return inner async def get_connections(self, request: Dict) -> Dict: - if self.rpc_api.service.global_connections is None: + if self.rpc_api.service.server is None: raise ValueError("Global connections is not set") - connections = self.rpc_api.service.global_connections.get_connections() + connections = self.rpc_api.service.server.get_connections() con_info = [ { "type": con.connection_type, @@ -98,7 +98,7 @@ class RpcServer: "peer_host": con.peer_host, "peer_port": con.peer_port, "peer_server_port": con.peer_server_port, - "node_id": con.node_id, + "node_id": con.peer_node_id, "creation_time": con.creation_time, "bytes_read": con.bytes_read, "bytes_written": con.bytes_written, @@ -123,17 +123,17 @@ class RpcServer: async def close_connection(self, request: Dict): node_id = hexstr_to_bytes(request["node_id"]) - if self.rpc_api.service.global_connections is None: + if self.rpc_api.service.server is None: raise aiohttp.web.HTTPInternalServerError() connections_to_close = [ c - for c in self.rpc_api.service.global_connections.get_connections() - if c.node_id == node_id + for c in self.rpc_api.service.server.get_connections() + if c.peer_node_id == node_id ] if len(connections_to_close) == 0: raise ValueError(f"Connection with node_id {node_id.hex()} does not exist") for connection in connections_to_close: - self.rpc_api.service.global_connections.close(connection) + await connection.close() return {} async def stop_node(self, request): diff --git a/src/rpc/wallet_rpc_api.py b/src/rpc/wallet_rpc_api.py index 2c9f7127021a6..c4441944e6460 100644 --- a/src/rpc/wallet_rpc_api.py +++ b/src/rpc/wallet_rpc_api.py @@ -287,13 +287,9 @@ class WalletRpcApi: async def farm_block(self, request): raw_puzzle_hash = decode_puzzle_hash(request["address"]) request = FarmNewBlockProtocol(raw_puzzle_hash) - msg = OutboundMessage( - NodeType.FULL_NODE, - Message("farm_new_block", request), - Delivery.BROADCAST, - ) + msg = Message("farm_new_block", request) - self.service.server.push_message(msg) + await self.service.server.send_to_all([msg], NodeType.FULL_NODE) return {} ########################################################################################## diff --git a/src/server/server.py b/src/server/server.py index c244447134798..138175a55407d 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -99,6 +99,7 @@ class ChiaServer: False, request.remote, self.incoming_messages, + self.connection_closed, close_event, ) handshake = await connection.perform_handshake( @@ -159,6 +160,7 @@ class ChiaServer: False, target_node.host, self.incoming_messages, + self.connection_closed, session=session, ) handshake = await connection.perform_handshake( @@ -183,7 +185,7 @@ class ChiaServer: return False - def connection_disconnected(self, connection: WSChiaConnection): + def connection_closed(self, connection: WSChiaConnection): self.log.info(f"connection with disconnected: {connection.peer_host}") if connection.peer_node_id in self.global_connections: self.global_connections.pop(connection.peer_node_id) @@ -267,6 +269,12 @@ class ChiaServer: return result + def get_connections(self) -> List[WSChiaConnection]: + result = [] + for id, connection in self.global_connections.items(): + result.append(connection) + return result + async def close_all_connections(self): for id, connection in self.global_connections.items(): try: diff --git a/src/server/ws_connection.py b/src/server/ws_connection.py index 240bbb92fa11e..05f4fd38266b7 100644 --- a/src/server/ws_connection.py +++ b/src/server/ws_connection.py @@ -40,6 +40,7 @@ class WSChiaConnection: is_feeler: bool, # Special type of connection, that disconnects after the handshake. peer_host, incoming_queue, + close_callback: Callable, close_event=None, session=None, ): @@ -50,6 +51,12 @@ class WSChiaConnection: self.local_port = server_port # Remote properties self.peer_host = peer_host + + connection_host, connection_port = self.ws._writer.transport.get_extra_info( + "peername" + ) + + self.peer_port = connection_port self.peer_server_port: Optional[int] = None self.peer_node_id = None @@ -76,6 +83,7 @@ class WSChiaConnection: self.active = False # once handshake is successful this will be changed to True self.close_event = close_event self.session = session + self.close_callback = close_callback async def perform_handshake( self, network_id, protocol_version, node_id, server_port, local_type @@ -146,11 +154,10 @@ class WSChiaConnection: self.outbound_task.cancel() if self.session is not None: await self.session.close() - await self.closed() - async def closed(self): if self.close_event is not None: self.close_event.set() + self.close_callback(self) async def outbound_handler(self): try: diff --git a/src/timelord_api.py b/src/timelord_api.py index a22470236c372..cf5e8a74071b6 100644 --- a/src/timelord_api.py +++ b/src/timelord_api.py @@ -1,3 +1,5 @@ +from typing import Callable + from src.protocols import timelord_protocol from src.server.ws_connection import WSChiaConnection from src.timelord import Timelord @@ -10,6 +12,9 @@ class TimelordAPI: def __init__(self, timelord): self.timelord = timelord + def _set_state_changed_callback(self, callback: Callable): + self.timelord.state_changed_callback = callback + @api_request async def challenge_start( self, challenge_start: timelord_protocol.ChallengeStart, peer: WSChiaConnection diff --git a/tests/rpc/test_full_node_rpc.py b/tests/rpc/test_full_node_rpc.py index cc2c19a6afdd2..425c0eedca494 100644 --- a/tests/rpc/test_full_node_rpc.py +++ b/tests/rpc/test_full_node_rpc.py @@ -23,10 +23,12 @@ class TestRpc: blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10) for i in range(1, num_blocks): - async for _ in full_node_1.respond_unfinished_block(full_node_protocol.RespondUnfinishedBlock(blocks[i])): - pass - async for _ in full_node_1.respond_block(full_node_protocol.RespondBlock(blocks[i])): - pass + await full_node_1.respond_unfinished_block( + full_node_protocol.RespondUnfinishedBlock(blocks[i]), None + ) + await full_node_1.full_node._respond_block( + full_node_protocol.RespondBlock(blocks[i]) + ) def stop_node_cb(): full_node_1._close() diff --git a/tests/rpc/test_wallet_rpc.py b/tests/rpc/test_wallet_rpc.py index 729e8f52da3ef..84298e39a2e5f 100644 --- a/tests/rpc/test_wallet_rpc.py +++ b/tests/rpc/test_wallet_rpc.py @@ -27,7 +27,8 @@ class TestWalletRpc: test_rpc_port = uint16(21529) num_blocks = 5 full_nodes, wallets = two_wallet_nodes - full_node_1, server_1 = full_nodes[0] + full_node_api = full_nodes[0] + full_node_server = full_node_api.full_node.server wallet_node, server_2 = wallets[0] wallet_node_2, server_3 = wallets[1] wallet = wallet_node.wallet_state_manager.main_wallet @@ -35,10 +36,12 @@ class TestWalletRpc: ph = await wallet.get_new_puzzlehash() ph_2 = await wallet_2.get_new_puzzlehash() - await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) + await server_2.start_client( + PeerInfo("localhost", uint16(full_node_server._port)), None + ) for i in range(0, num_blocks): - await full_node_1.farm_new_block(FarmNewBlockProtocol(ph)) + await full_node_api.farm_new_block(FarmNewBlockProtocol(ph), None) initial_funds = sum( [ diff --git a/tests/setup_nodes.py b/tests/setup_nodes.py index 5ac34da64e47a..0f74004f49d4c 100644 --- a/tests/setup_nodes.py +++ b/tests/setup_nodes.py @@ -263,10 +263,10 @@ async def setup_two_nodes(consensus_constants: ConsensusConstants): setup_full_node(consensus_constants, "blockchain_test_2.db", 21235, simulator=False), ] - fn1, s1 = await node_iters[0].__anext__() - fn2, s2 = await node_iters[1].__anext__() + fn1 = await node_iters[0].__anext__() + fn2 = await node_iters[1].__anext__() - yield fn1, fn2, s1, s2 + yield (fn1, fn2, fn1.full_node.server, fn2.full_node.server) await _teardown_nodes(node_iters)