This commit is contained in:
Yostra 2020-10-21 01:19:40 -07:00
parent c874bf3ad2
commit afb83e7923
13 changed files with 72 additions and 33 deletions

View File

@ -1,5 +1,5 @@
import asyncio
from typing import Optional
from typing import Optional, Callable
from blspy import AugSchemeMPL, G2Element
@ -21,6 +21,9 @@ class FarmerAPI:
def __init__(self, farmer):
self.farmer = farmer
def _set_state_changed_callback(self, callback: Callable):
self.farmer.state_changed_callback = callback
@api_request
async def challenge_response(
self,

View File

@ -1,7 +1,7 @@
import asyncio
import time
from pathlib import Path
from typing import AsyncGenerator, Dict, List, Optional, Tuple
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Callable
from chiabip158 import PyBIP158
from chiapos import Verifier
from blspy import G2Element, AugSchemeMPL
@ -53,6 +53,13 @@ class FullNodeAPI:
def __init__(self, full_node):
self.full_node = full_node
def _set_state_changed_callback(self, callback: Callable):
self.full_node.state_changed_callback = callback
@property
def server(self):
return self.full_node.server
@api_request
async def request_peers(
self, request: full_node_protocol.RequestPeers, peer: WSChiaConnection

View File

@ -1,5 +1,5 @@
from pathlib import Path
from typing import Optional
from typing import Optional, Callable
from blspy import AugSchemeMPL, G2Element
from chiapos import DiskProver
@ -20,6 +20,9 @@ class HarvesterAPI:
def __init__(self, harvester):
self.harvester = harvester
def _set_state_changed_callback(self, callback: Callable):
self.harvester.state_changed_callback = callback
@api_request
async def harvester_handshake(
self,

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Callable
from src.introducer import Introducer
from src.server.outbound_message import Message
@ -15,6 +15,9 @@ class IntroducerAPI:
def __init__(self, introducer):
self.introducer = introducer
def _set_state_changed_callback(self, callback: Callable):
self.introducer.state_changed_callback = callback
@api_request
async def request_peers(
self,

View File

@ -1,7 +1,8 @@
from src.full_node.full_node import FullNode
from typing import Callable, List, Optional, Dict
# from src.types.header import Header
from src.full_node.full_node_api import FullNodeAPI
#from src.types.header import Header
from src.types.full_block import FullBlock
from src.util.ints import uint32, uint64, uint128
from src.types.sized_bytes import bytes32
@ -12,8 +13,9 @@ from src.util.ws_message import create_payload
class FullNodeRpcApi:
def __init__(self, full_node: FullNode):
self.service = full_node
def __init__(self, full_node_api: FullNodeAPI):
self.service = full_node_api
self.full_node = full_node_api.full_node
self.service_name = "chia_full_node"
self.cached_blockchain_state: Optional[Dict] = None

View File

@ -87,9 +87,9 @@ class RpcServer:
return inner
async def get_connections(self, request: Dict) -> Dict:
if self.rpc_api.service.global_connections is None:
if self.rpc_api.service.server is None:
raise ValueError("Global connections is not set")
connections = self.rpc_api.service.global_connections.get_connections()
connections = self.rpc_api.service.server.get_connections()
con_info = [
{
"type": con.connection_type,
@ -98,7 +98,7 @@ class RpcServer:
"peer_host": con.peer_host,
"peer_port": con.peer_port,
"peer_server_port": con.peer_server_port,
"node_id": con.node_id,
"node_id": con.peer_node_id,
"creation_time": con.creation_time,
"bytes_read": con.bytes_read,
"bytes_written": con.bytes_written,
@ -123,17 +123,17 @@ class RpcServer:
async def close_connection(self, request: Dict):
node_id = hexstr_to_bytes(request["node_id"])
if self.rpc_api.service.global_connections is None:
if self.rpc_api.service.server is None:
raise aiohttp.web.HTTPInternalServerError()
connections_to_close = [
c
for c in self.rpc_api.service.global_connections.get_connections()
if c.node_id == node_id
for c in self.rpc_api.service.server.get_connections()
if c.peer_node_id == node_id
]
if len(connections_to_close) == 0:
raise ValueError(f"Connection with node_id {node_id.hex()} does not exist")
for connection in connections_to_close:
self.rpc_api.service.global_connections.close(connection)
await connection.close()
return {}
async def stop_node(self, request):

View File

@ -287,13 +287,9 @@ class WalletRpcApi:
async def farm_block(self, request):
raw_puzzle_hash = decode_puzzle_hash(request["address"])
request = FarmNewBlockProtocol(raw_puzzle_hash)
msg = OutboundMessage(
NodeType.FULL_NODE,
Message("farm_new_block", request),
Delivery.BROADCAST,
)
msg = Message("farm_new_block", request)
self.service.server.push_message(msg)
await self.service.server.send_to_all([msg], NodeType.FULL_NODE)
return {}
##########################################################################################

View File

@ -99,6 +99,7 @@ class ChiaServer:
False,
request.remote,
self.incoming_messages,
self.connection_closed,
close_event,
)
handshake = await connection.perform_handshake(
@ -159,6 +160,7 @@ class ChiaServer:
False,
target_node.host,
self.incoming_messages,
self.connection_closed,
session=session,
)
handshake = await connection.perform_handshake(
@ -183,7 +185,7 @@ class ChiaServer:
return False
def connection_disconnected(self, connection: WSChiaConnection):
def connection_closed(self, connection: WSChiaConnection):
self.log.info(f"connection with disconnected: {connection.peer_host}")
if connection.peer_node_id in self.global_connections:
self.global_connections.pop(connection.peer_node_id)
@ -267,6 +269,12 @@ class ChiaServer:
return result
def get_connections(self) -> List[WSChiaConnection]:
result = []
for id, connection in self.global_connections.items():
result.append(connection)
return result
async def close_all_connections(self):
for id, connection in self.global_connections.items():
try:

View File

@ -40,6 +40,7 @@ class WSChiaConnection:
is_feeler: bool, # Special type of connection, that disconnects after the handshake.
peer_host,
incoming_queue,
close_callback: Callable,
close_event=None,
session=None,
):
@ -50,6 +51,12 @@ class WSChiaConnection:
self.local_port = server_port
# Remote properties
self.peer_host = peer_host
connection_host, connection_port = self.ws._writer.transport.get_extra_info(
"peername"
)
self.peer_port = connection_port
self.peer_server_port: Optional[int] = None
self.peer_node_id = None
@ -76,6 +83,7 @@ class WSChiaConnection:
self.active = False # once handshake is successful this will be changed to True
self.close_event = close_event
self.session = session
self.close_callback = close_callback
async def perform_handshake(
self, network_id, protocol_version, node_id, server_port, local_type
@ -146,11 +154,10 @@ class WSChiaConnection:
self.outbound_task.cancel()
if self.session is not None:
await self.session.close()
await self.closed()
async def closed(self):
if self.close_event is not None:
self.close_event.set()
self.close_callback(self)
async def outbound_handler(self):
try:

View File

@ -1,3 +1,5 @@
from typing import Callable
from src.protocols import timelord_protocol
from src.server.ws_connection import WSChiaConnection
from src.timelord import Timelord
@ -10,6 +12,9 @@ class TimelordAPI:
def __init__(self, timelord):
self.timelord = timelord
def _set_state_changed_callback(self, callback: Callable):
self.timelord.state_changed_callback = callback
@api_request
async def challenge_start(
self, challenge_start: timelord_protocol.ChallengeStart, peer: WSChiaConnection

View File

@ -23,10 +23,12 @@ class TestRpc:
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
for i in range(1, num_blocks):
async for _ in full_node_1.respond_unfinished_block(full_node_protocol.RespondUnfinishedBlock(blocks[i])):
pass
async for _ in full_node_1.respond_block(full_node_protocol.RespondBlock(blocks[i])):
pass
await full_node_1.respond_unfinished_block(
full_node_protocol.RespondUnfinishedBlock(blocks[i]), None
)
await full_node_1.full_node._respond_block(
full_node_protocol.RespondBlock(blocks[i])
)
def stop_node_cb():
full_node_1._close()

View File

@ -27,7 +27,8 @@ class TestWalletRpc:
test_rpc_port = uint16(21529)
num_blocks = 5
full_nodes, wallets = two_wallet_nodes
full_node_1, server_1 = full_nodes[0]
full_node_api = full_nodes[0]
full_node_server = full_node_api.full_node.server
wallet_node, server_2 = wallets[0]
wallet_node_2, server_3 = wallets[1]
wallet = wallet_node.wallet_state_manager.main_wallet
@ -35,10 +36,12 @@ class TestWalletRpc:
ph = await wallet.get_new_puzzlehash()
ph_2 = await wallet_2.get_new_puzzlehash()
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
await server_2.start_client(
PeerInfo("localhost", uint16(full_node_server._port)), None
)
for i in range(0, num_blocks):
await full_node_1.farm_new_block(FarmNewBlockProtocol(ph))
await full_node_api.farm_new_block(FarmNewBlockProtocol(ph), None)
initial_funds = sum(
[

View File

@ -263,10 +263,10 @@ async def setup_two_nodes(consensus_constants: ConsensusConstants):
setup_full_node(consensus_constants, "blockchain_test_2.db", 21235, simulator=False),
]
fn1, s1 = await node_iters[0].__anext__()
fn2, s2 = await node_iters[1].__anext__()
fn1 = await node_iters[0].__anext__()
fn2 = await node_iters[1].__anext__()
yield fn1, fn2, s1, s2
yield (fn1, fn2, fn1.full_node.server, fn2.full_node.server)
await _teardown_nodes(node_iters)