server: More use of ChiaServer.get_connections (#13573)

* Drop `get_full_node_connections` -> use `get_connections`

* Drop `connection_by_type`

I think the housekeeping required to have it isn't worth it since we
don't handle huge numbers of connections and `get_connections` should be
less enough overhead. Im open for complains though.
This commit is contained in:
dustinface 2022-10-05 18:21:38 +02:00 committed by GitHub
parent aed8c61969
commit 7a5d9579ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 26 additions and 40 deletions

View File

@ -10,6 +10,7 @@ from chia.full_node.full_node import FullNode
from chia.full_node.generator import setup_generator_args from chia.full_node.generator import setup_generator_args
from chia.full_node.mempool_check_conditions import get_puzzle_and_solution_for_coin from chia.full_node.mempool_check_conditions import get_puzzle_and_solution_for_coin
from chia.rpc.rpc_server import Endpoint, EndpointResult from chia.rpc.rpc_server import Endpoint, EndpointResult
from chia.server.outbound_message import NodeType
from chia.types.blockchain_format.coin import Coin from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.program import Program, SerializedProgram from chia.types.blockchain_format.program import Program, SerializedProgram
from chia.types.blockchain_format.sized_bytes import bytes32 from chia.types.blockchain_format.sized_bytes import bytes32
@ -188,7 +189,7 @@ class FullNodeRpcApi:
mempool_min_fee_5m = 0 mempool_min_fee_5m = 0
mempool_max_total_cost = 0 mempool_max_total_cost = 0
if self.service.server is not None: if self.service.server is not None:
is_connected = len(self.service.server.get_full_node_connections()) > 0 or "simulator" in str( is_connected = len(self.service.server.get_connections(NodeType.FULL_NODE)) > 0 or "simulator" in str(
self.service.config.get("selected_network") self.service.config.get("selected_network")
) )
else: else:

View File

@ -447,7 +447,7 @@ class WalletRpcApi:
return {"network_name": network_name, "network_prefix": address_prefix} return {"network_name": network_name, "network_prefix": address_prefix}
async def push_tx(self, request: Dict) -> EndpointResult: async def push_tx(self, request: Dict) -> EndpointResult:
nodes = self.service.server.get_full_node_connections() nodes = self.service.server.get_connections(NodeType.FULL_NODE)
if len(nodes) == 0: if len(nodes) == 0:
raise ValueError("Wallet is not currently connected to any full node peers") raise ValueError("Wallet is not currently connected to any full node peers")
await self.service.push_tx(SpendBundle.from_bytes(hexstr_to_bytes(request["spend_bundle"]))) await self.service.push_tx(SpendBundle.from_bytes(hexstr_to_bytes(request["spend_bundle"])))

View File

@ -456,7 +456,7 @@ class FullNodeDiscovery:
await asyncio.sleep(cleanup_interval) await asyncio.sleep(cleanup_interval)
# Perform the cleanup only if we have at least 3 connections. # Perform the cleanup only if we have at least 3 connections.
full_node_connected = self.server.get_full_node_connections() full_node_connected = self.server.get_connections(NodeType.FULL_NODE)
connected = [c.get_peer_info() for c in full_node_connected] connected = [c.get_peer_info() for c in full_node_connected]
connected = [c for c in connected if c is not None] connected = [c for c in connected if c is not None]
if self.address_manager is not None and len(connected) >= 3: if self.address_manager is not None and len(connected) >= 3:
@ -642,7 +642,7 @@ class FullNodePeers(FullNodeDiscovery):
if not relay_peer_info.is_valid(): if not relay_peer_info.is_valid():
continue continue
# https://en.bitcoin.it/wiki/Satoshi_Client_Node_Discovery#Address_Relay # https://en.bitcoin.it/wiki/Satoshi_Client_Node_Discovery#Address_Relay
connections = self.server.get_full_node_connections() connections = self.server.get_connections(NodeType.FULL_NODE)
hashes = [] hashes = []
cur_day = int(time.time()) // (24 * 60 * 60) cur_day = int(time.time()) // (24 * 60 * 60)
for connection in connections: for connection in connections:

View File

@ -136,16 +136,6 @@ class ChiaServer:
# Keeps track of all connections to and from this node. # Keeps track of all connections to and from this node.
self.all_connections: Dict[bytes32, WSChiaConnection] = {} self.all_connections: Dict[bytes32, WSChiaConnection] = {}
self.connection_by_type: Dict[NodeType, Dict[bytes32, WSChiaConnection]] = {
NodeType.FULL_NODE: {},
NodeType.DATA_LAYER: {},
NodeType.WALLET: {},
NodeType.HARVESTER: {},
NodeType.FARMER: {},
NodeType.TIMELORD: {},
NodeType.INTRODUCER: {},
}
self._port = port # TCP port to identify our node self._port = port # TCP port to identify our node
self._local_type: NodeType = local_type self._local_type: NodeType = local_type
self._local_capabilities_for_handshake = capabilities self._local_capabilities_for_handshake = capabilities
@ -376,7 +366,6 @@ class ChiaServer:
await con.close() await con.close()
self.all_connections[connection.peer_node_id] = connection self.all_connections[connection.peer_node_id] = connection
if connection.connection_type is not None: if connection.connection_type is not None:
self.connection_by_type[connection.connection_type][connection.peer_node_id] = connection
if on_connect is not None: if on_connect is not None:
await on_connect(connection) await on_connect(connection)
else: else:
@ -525,10 +514,7 @@ class ChiaServer:
if connection.peer_node_id in self.all_connections: if connection.peer_node_id in self.all_connections:
self.all_connections.pop(connection.peer_node_id) self.all_connections.pop(connection.peer_node_id)
if connection.connection_type is not None: if connection.connection_type is None:
if connection.peer_node_id in self.connection_by_type[connection.connection_type]:
self.connection_by_type[connection.connection_type].pop(connection.peer_node_id)
else:
# This means the handshake was never finished with this peer # This means the handshake was never finished with this peer
self.log.debug( self.log.debug(
f"Invalid connection type for connection {connection.peer_host}," f"Invalid connection type for connection {connection.peer_host},"
@ -706,15 +692,12 @@ class ChiaServer:
def get_full_node_outgoing_connections(self) -> List[WSChiaConnection]: def get_full_node_outgoing_connections(self) -> List[WSChiaConnection]:
result = [] result = []
connections = self.get_full_node_connections() connections = self.get_connections(NodeType.FULL_NODE)
for connection in connections: for connection in connections:
if connection.is_outbound: if connection.is_outbound:
result.append(connection) result.append(connection)
return result return result
def get_full_node_connections(self) -> List[WSChiaConnection]:
return list(self.connection_by_type[NodeType.FULL_NODE].values())
def get_connections(self, node_type: Optional[NodeType] = None) -> List[WSChiaConnection]: def get_connections(self, node_type: Optional[NodeType] = None) -> List[WSChiaConnection]:
result = [] result = []
for _, connection in self.all_connections.items(): for _, connection in self.all_connections.items():
@ -795,7 +778,7 @@ class ChiaServer:
def accept_inbound_connections(self, node_type: NodeType) -> bool: def accept_inbound_connections(self, node_type: NodeType) -> bool:
if not self._local_type == NodeType.FULL_NODE: if not self._local_type == NodeType.FULL_NODE:
return True return True
inbound_count = len([conn for _, conn in self.connection_by_type[node_type].items() if not conn.is_outbound]) inbound_count = len([conn for conn in self.get_connections(node_type) if not conn.is_outbound])
if node_type == NodeType.FULL_NODE: if node_type == NodeType.FULL_NODE:
return inbound_count < self.config["target_peer_count"] - self.config["target_outbound_peer_count"] return inbound_count < self.config["target_peer_count"] - self.config["target_outbound_peer_count"]
if node_type == NodeType.WALLET: if node_type == NodeType.WALLET:

View File

@ -352,7 +352,7 @@ class WalletNode:
for msg, sent_peers in await self._messages_to_resend(): for msg, sent_peers in await self._messages_to_resend():
if self._shut_down or self._server is None or self._wallet_state_manager is None: if self._shut_down or self._server is None or self._wallet_state_manager is None:
return None return None
full_nodes = self.server.get_full_node_connections() full_nodes = self.server.get_connections(NodeType.FULL_NODE)
for peer in full_nodes: for peer in full_nodes:
if peer.peer_node_id in sent_peers: if peer.peer_node_id in sent_peers:
continue continue
@ -404,14 +404,14 @@ class WalletNode:
# state updates until we are sure that we subscribed to everything that we need to. Otherwise, # state updates until we are sure that we subscribed to everything that we need to. Otherwise,
# we might not be able to process some state. # we might not be able to process some state.
coin_ids: List[bytes32] = item.data coin_ids: List[bytes32] = item.data
for peer in self.server.get_full_node_connections(): for peer in self.server.get_connections(NodeType.FULL_NODE):
coin_states: List[CoinState] = await subscribe_to_coin_updates(coin_ids, peer, uint32(0)) coin_states: List[CoinState] = await subscribe_to_coin_updates(coin_ids, peer, uint32(0))
if len(coin_states) > 0: if len(coin_states) > 0:
async with self.wallet_state_manager.lock: async with self.wallet_state_manager.lock:
await self.receive_state_from_peer(coin_states, peer) await self.receive_state_from_peer(coin_states, peer)
elif item.item_type == NewPeakQueueTypes.PUZZLE_HASH_SUBSCRIPTION: elif item.item_type == NewPeakQueueTypes.PUZZLE_HASH_SUBSCRIPTION:
puzzle_hashes: List[bytes32] = item.data puzzle_hashes: List[bytes32] = item.data
for peer in self.server.get_full_node_connections(): for peer in self.server.get_connections(NodeType.FULL_NODE):
# Puzzle hash subscription # Puzzle hash subscription
coin_states: List[CoinState] = await subscribe_to_phs(puzzle_hashes, peer, uint32(0)) coin_states: List[CoinState] = await subscribe_to_phs(puzzle_hashes, peer, uint32(0))
if len(coin_states) > 0: if len(coin_states) > 0:
@ -821,10 +821,10 @@ class WalletNode:
async def get_coins_with_puzzle_hash(self, puzzle_hash) -> List[CoinState]: async def get_coins_with_puzzle_hash(self, puzzle_hash) -> List[CoinState]:
# TODO Use trusted peer, otherwise try untrusted # TODO Use trusted peer, otherwise try untrusted
all_nodes = self.server.connection_by_type[NodeType.FULL_NODE] all_nodes = self.server.get_connections(NodeType.FULL_NODE)
if len(all_nodes.keys()) == 0: if len(all_nodes) == 0:
raise ValueError("Not connected to the full node") raise ValueError("Not connected to the full node")
first_node = list(all_nodes.values())[0] first_node = all_nodes[0]
msg = wallet_protocol.RegisterForPhUpdates(puzzle_hash, uint32(0)) msg = wallet_protocol.RegisterForPhUpdates(puzzle_hash, uint32(0))
coin_state: Optional[RespondToPhUpdates] = await first_node.register_interest_in_puzzle_hash(msg) coin_state: Optional[RespondToPhUpdates] = await first_node.register_interest_in_puzzle_hash(msg)
# TODO validate state if received from untrusted peer # TODO validate state if received from untrusted peer
@ -901,7 +901,7 @@ class WalletNode:
synced: List[WSChiaConnection] = [] synced: List[WSChiaConnection] = []
trusted: List[WSChiaConnection] = [] trusted: List[WSChiaConnection] = []
neither: List[WSChiaConnection] = [] neither: List[WSChiaConnection] = []
all_nodes: List[WSChiaConnection] = self.server.get_full_node_connections().copy() all_nodes: List[WSChiaConnection] = self.server.get_connections(NodeType.FULL_NODE)
random.shuffle(all_nodes) random.shuffle(all_nodes)
for node in all_nodes: for node in all_nodes:
we_synced_to_it = node.peer_node_id in self.synced_peers we_synced_to_it = node.peer_node_id in self.synced_peers
@ -921,8 +921,9 @@ class WalletNode:
return return
# Close connection of non-trusted peers # Close connection of non-trusted peers
if len(self.server.get_full_node_connections()) > 1: full_node_connections = self.server.get_connections(NodeType.FULL_NODE)
for peer in self.server.get_full_node_connections(): if len(full_node_connections) > 1:
for peer in full_node_connections:
if not self.is_trusted(peer): if not self.is_trusted(peer):
await peer.close() await peer.close()
@ -933,7 +934,7 @@ class WalletNode:
async def check_for_synced_trusted_peer(self, header_block: HeaderBlock, request_time: uint64) -> bool: async def check_for_synced_trusted_peer(self, header_block: HeaderBlock, request_time: uint64) -> bool:
if self._server is None: if self._server is None:
return False return False
for peer in self.server.get_full_node_connections(): for peer in self.server.get_connections(NodeType.FULL_NODE):
if self.is_trusted(peer) and await self.is_peer_synced(peer, header_block, request_time): if self.is_trusted(peer) and await self.is_peer_synced(peer, header_block, request_time):
return True return True
return False return False
@ -1469,7 +1470,7 @@ class WalletNode:
if end == 0: if end == 0:
self.log.error("Error finding sub epoch") self.log.error("Error finding sub epoch")
return False return False
all_peers_c = self.server.get_full_node_connections() all_peers_c = self.server.get_connections(NodeType.FULL_NODE)
all_peers = [(con, self.is_trusted(con)) for con in all_peers_c] all_peers = [(con, self.is_trusted(con)) for con in all_peers_c]
blocks: Optional[List[HeaderBlock]] = await fetch_header_blocks_in_range( blocks: Optional[List[HeaderBlock]] = await fetch_header_blocks_in_range(
start, end, peer_request_cache, all_peers start, end, peer_request_cache, all_peers
@ -1615,6 +1616,6 @@ class WalletNode:
ProtocolMessageTypes.send_transaction, ProtocolMessageTypes.send_transaction,
wallet_protocol.SendTransaction(spend_bundle), wallet_protocol.SendTransaction(spend_bundle),
) )
full_nodes = self.server.get_full_node_connections() full_nodes = self.server.get_connections(NodeType.FULL_NODE)
for peer in full_nodes: for peer in full_nodes:
await peer.send_message(msg) await peer.send_message(msg)

View File

@ -19,6 +19,7 @@ from chia.pools.pool_puzzles import SINGLETON_LAUNCHER_HASH, solution_to_pool_st
from chia.pools.pool_wallet import PoolWallet from chia.pools.pool_wallet import PoolWallet
from chia.protocols import wallet_protocol from chia.protocols import wallet_protocol
from chia.protocols.wallet_protocol import CoinState from chia.protocols.wallet_protocol import CoinState
from chia.server.outbound_message import NodeType
from chia.server.server import ChiaServer from chia.server.server import ChiaServer
from chia.server.ws_connection import WSChiaConnection from chia.server.ws_connection import WSChiaConnection
from chia.types.blockchain_format.coin import Coin from chia.types.blockchain_format.coin import Coin
@ -487,7 +488,7 @@ class WalletStateManager:
self.pending_tx_callback() self.pending_tx_callback()
async def synced(self): async def synced(self):
if len(self.server.get_full_node_connections()) == 0: if len(self.server.get_connections(NodeType.FULL_NODE)) == 0:
return False return False
latest = await self.blockchain.get_peak_block() latest = await self.blockchain.get_peak_block()

View File

@ -185,7 +185,7 @@ async def test_daemon_simulation(self_hostname, daemon_simulation):
await server1.start_client(PeerInfo(self_hostname, uint16(node2_port))) await server1.start_client(PeerInfo(self_hostname, uint16(node2_port)))
async def num_connections(): async def num_connections():
count = len(node2.server.connection_by_type[NodeType.FULL_NODE].items()) count = len(node2.server.get_connections(NodeType.FULL_NODE))
return count return count
await time_out_assert_custom_interval(60, 1, num_connections, 1) await time_out_assert_custom_interval(60, 1, num_connections, 1)

View File

@ -19,7 +19,7 @@ from chia.protocols.full_node_protocol import RespondTransaction
from chia.protocols.protocol_message_types import ProtocolMessageTypes from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.wallet_protocol import SendTransaction, TransactionAck from chia.protocols.wallet_protocol import SendTransaction, TransactionAck
from chia.server.address_manager import AddressManager from chia.server.address_manager import AddressManager
from chia.server.outbound_message import Message from chia.server.outbound_message import Message, NodeType
from chia.simulator.simulator_protocol import FarmNewBlockProtocol from chia.simulator.simulator_protocol import FarmNewBlockProtocol
from chia.types.blockchain_format.classgroup import ClassgroupElement from chia.types.blockchain_format.classgroup import ClassgroupElement
from chia.types.blockchain_format.program import Program, SerializedProgram from chia.types.blockchain_format.program import Program, SerializedProgram
@ -400,7 +400,7 @@ class TestFullNodeProtocol:
full_node_i = nodes[i] full_node_i = nodes[i]
server_i = full_node_i.full_node.server server_i = full_node_i.full_node.server
await server_i.start_client(PeerInfo(self_hostname, uint16(server_1._port))) await server_i.start_client(PeerInfo(self_hostname, uint16(server_1._port)))
assert len(server_1.get_full_node_connections()) == 2 assert len(server_1.get_connections(NodeType.FULL_NODE)) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_request_peers(self, wallet_nodes, self_hostname): async def test_request_peers(self, wallet_nodes, self_hostname):