mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-09-20 16:08:51 +03:00
server: More use of ChiaServer.get_connections
(#13573)
* Drop `get_full_node_connections` -> use `get_connections` * Drop `connection_by_type` I think the housekeeping required to have it isn't worth it since we don't handle huge numbers of connections and `get_connections` should be less enough overhead. Im open for complains though.
This commit is contained in:
parent
aed8c61969
commit
7a5d9579ef
@ -10,6 +10,7 @@ from chia.full_node.full_node import FullNode
|
|||||||
from chia.full_node.generator import setup_generator_args
|
from chia.full_node.generator import setup_generator_args
|
||||||
from chia.full_node.mempool_check_conditions import get_puzzle_and_solution_for_coin
|
from chia.full_node.mempool_check_conditions import get_puzzle_and_solution_for_coin
|
||||||
from chia.rpc.rpc_server import Endpoint, EndpointResult
|
from chia.rpc.rpc_server import Endpoint, EndpointResult
|
||||||
|
from chia.server.outbound_message import NodeType
|
||||||
from chia.types.blockchain_format.coin import Coin
|
from chia.types.blockchain_format.coin import Coin
|
||||||
from chia.types.blockchain_format.program import Program, SerializedProgram
|
from chia.types.blockchain_format.program import Program, SerializedProgram
|
||||||
from chia.types.blockchain_format.sized_bytes import bytes32
|
from chia.types.blockchain_format.sized_bytes import bytes32
|
||||||
@ -188,7 +189,7 @@ class FullNodeRpcApi:
|
|||||||
mempool_min_fee_5m = 0
|
mempool_min_fee_5m = 0
|
||||||
mempool_max_total_cost = 0
|
mempool_max_total_cost = 0
|
||||||
if self.service.server is not None:
|
if self.service.server is not None:
|
||||||
is_connected = len(self.service.server.get_full_node_connections()) > 0 or "simulator" in str(
|
is_connected = len(self.service.server.get_connections(NodeType.FULL_NODE)) > 0 or "simulator" in str(
|
||||||
self.service.config.get("selected_network")
|
self.service.config.get("selected_network")
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -447,7 +447,7 @@ class WalletRpcApi:
|
|||||||
return {"network_name": network_name, "network_prefix": address_prefix}
|
return {"network_name": network_name, "network_prefix": address_prefix}
|
||||||
|
|
||||||
async def push_tx(self, request: Dict) -> EndpointResult:
|
async def push_tx(self, request: Dict) -> EndpointResult:
|
||||||
nodes = self.service.server.get_full_node_connections()
|
nodes = self.service.server.get_connections(NodeType.FULL_NODE)
|
||||||
if len(nodes) == 0:
|
if len(nodes) == 0:
|
||||||
raise ValueError("Wallet is not currently connected to any full node peers")
|
raise ValueError("Wallet is not currently connected to any full node peers")
|
||||||
await self.service.push_tx(SpendBundle.from_bytes(hexstr_to_bytes(request["spend_bundle"])))
|
await self.service.push_tx(SpendBundle.from_bytes(hexstr_to_bytes(request["spend_bundle"])))
|
||||||
|
@ -456,7 +456,7 @@ class FullNodeDiscovery:
|
|||||||
await asyncio.sleep(cleanup_interval)
|
await asyncio.sleep(cleanup_interval)
|
||||||
|
|
||||||
# Perform the cleanup only if we have at least 3 connections.
|
# Perform the cleanup only if we have at least 3 connections.
|
||||||
full_node_connected = self.server.get_full_node_connections()
|
full_node_connected = self.server.get_connections(NodeType.FULL_NODE)
|
||||||
connected = [c.get_peer_info() for c in full_node_connected]
|
connected = [c.get_peer_info() for c in full_node_connected]
|
||||||
connected = [c for c in connected if c is not None]
|
connected = [c for c in connected if c is not None]
|
||||||
if self.address_manager is not None and len(connected) >= 3:
|
if self.address_manager is not None and len(connected) >= 3:
|
||||||
@ -642,7 +642,7 @@ class FullNodePeers(FullNodeDiscovery):
|
|||||||
if not relay_peer_info.is_valid():
|
if not relay_peer_info.is_valid():
|
||||||
continue
|
continue
|
||||||
# https://en.bitcoin.it/wiki/Satoshi_Client_Node_Discovery#Address_Relay
|
# https://en.bitcoin.it/wiki/Satoshi_Client_Node_Discovery#Address_Relay
|
||||||
connections = self.server.get_full_node_connections()
|
connections = self.server.get_connections(NodeType.FULL_NODE)
|
||||||
hashes = []
|
hashes = []
|
||||||
cur_day = int(time.time()) // (24 * 60 * 60)
|
cur_day = int(time.time()) // (24 * 60 * 60)
|
||||||
for connection in connections:
|
for connection in connections:
|
||||||
|
@ -136,16 +136,6 @@ class ChiaServer:
|
|||||||
# Keeps track of all connections to and from this node.
|
# Keeps track of all connections to and from this node.
|
||||||
self.all_connections: Dict[bytes32, WSChiaConnection] = {}
|
self.all_connections: Dict[bytes32, WSChiaConnection] = {}
|
||||||
|
|
||||||
self.connection_by_type: Dict[NodeType, Dict[bytes32, WSChiaConnection]] = {
|
|
||||||
NodeType.FULL_NODE: {},
|
|
||||||
NodeType.DATA_LAYER: {},
|
|
||||||
NodeType.WALLET: {},
|
|
||||||
NodeType.HARVESTER: {},
|
|
||||||
NodeType.FARMER: {},
|
|
||||||
NodeType.TIMELORD: {},
|
|
||||||
NodeType.INTRODUCER: {},
|
|
||||||
}
|
|
||||||
|
|
||||||
self._port = port # TCP port to identify our node
|
self._port = port # TCP port to identify our node
|
||||||
self._local_type: NodeType = local_type
|
self._local_type: NodeType = local_type
|
||||||
self._local_capabilities_for_handshake = capabilities
|
self._local_capabilities_for_handshake = capabilities
|
||||||
@ -376,7 +366,6 @@ class ChiaServer:
|
|||||||
await con.close()
|
await con.close()
|
||||||
self.all_connections[connection.peer_node_id] = connection
|
self.all_connections[connection.peer_node_id] = connection
|
||||||
if connection.connection_type is not None:
|
if connection.connection_type is not None:
|
||||||
self.connection_by_type[connection.connection_type][connection.peer_node_id] = connection
|
|
||||||
if on_connect is not None:
|
if on_connect is not None:
|
||||||
await on_connect(connection)
|
await on_connect(connection)
|
||||||
else:
|
else:
|
||||||
@ -525,10 +514,7 @@ class ChiaServer:
|
|||||||
|
|
||||||
if connection.peer_node_id in self.all_connections:
|
if connection.peer_node_id in self.all_connections:
|
||||||
self.all_connections.pop(connection.peer_node_id)
|
self.all_connections.pop(connection.peer_node_id)
|
||||||
if connection.connection_type is not None:
|
if connection.connection_type is None:
|
||||||
if connection.peer_node_id in self.connection_by_type[connection.connection_type]:
|
|
||||||
self.connection_by_type[connection.connection_type].pop(connection.peer_node_id)
|
|
||||||
else:
|
|
||||||
# This means the handshake was never finished with this peer
|
# This means the handshake was never finished with this peer
|
||||||
self.log.debug(
|
self.log.debug(
|
||||||
f"Invalid connection type for connection {connection.peer_host},"
|
f"Invalid connection type for connection {connection.peer_host},"
|
||||||
@ -706,15 +692,12 @@ class ChiaServer:
|
|||||||
|
|
||||||
def get_full_node_outgoing_connections(self) -> List[WSChiaConnection]:
|
def get_full_node_outgoing_connections(self) -> List[WSChiaConnection]:
|
||||||
result = []
|
result = []
|
||||||
connections = self.get_full_node_connections()
|
connections = self.get_connections(NodeType.FULL_NODE)
|
||||||
for connection in connections:
|
for connection in connections:
|
||||||
if connection.is_outbound:
|
if connection.is_outbound:
|
||||||
result.append(connection)
|
result.append(connection)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_full_node_connections(self) -> List[WSChiaConnection]:
|
|
||||||
return list(self.connection_by_type[NodeType.FULL_NODE].values())
|
|
||||||
|
|
||||||
def get_connections(self, node_type: Optional[NodeType] = None) -> List[WSChiaConnection]:
|
def get_connections(self, node_type: Optional[NodeType] = None) -> List[WSChiaConnection]:
|
||||||
result = []
|
result = []
|
||||||
for _, connection in self.all_connections.items():
|
for _, connection in self.all_connections.items():
|
||||||
@ -795,7 +778,7 @@ class ChiaServer:
|
|||||||
def accept_inbound_connections(self, node_type: NodeType) -> bool:
|
def accept_inbound_connections(self, node_type: NodeType) -> bool:
|
||||||
if not self._local_type == NodeType.FULL_NODE:
|
if not self._local_type == NodeType.FULL_NODE:
|
||||||
return True
|
return True
|
||||||
inbound_count = len([conn for _, conn in self.connection_by_type[node_type].items() if not conn.is_outbound])
|
inbound_count = len([conn for conn in self.get_connections(node_type) if not conn.is_outbound])
|
||||||
if node_type == NodeType.FULL_NODE:
|
if node_type == NodeType.FULL_NODE:
|
||||||
return inbound_count < self.config["target_peer_count"] - self.config["target_outbound_peer_count"]
|
return inbound_count < self.config["target_peer_count"] - self.config["target_outbound_peer_count"]
|
||||||
if node_type == NodeType.WALLET:
|
if node_type == NodeType.WALLET:
|
||||||
|
@ -352,7 +352,7 @@ class WalletNode:
|
|||||||
for msg, sent_peers in await self._messages_to_resend():
|
for msg, sent_peers in await self._messages_to_resend():
|
||||||
if self._shut_down or self._server is None or self._wallet_state_manager is None:
|
if self._shut_down or self._server is None or self._wallet_state_manager is None:
|
||||||
return None
|
return None
|
||||||
full_nodes = self.server.get_full_node_connections()
|
full_nodes = self.server.get_connections(NodeType.FULL_NODE)
|
||||||
for peer in full_nodes:
|
for peer in full_nodes:
|
||||||
if peer.peer_node_id in sent_peers:
|
if peer.peer_node_id in sent_peers:
|
||||||
continue
|
continue
|
||||||
@ -404,14 +404,14 @@ class WalletNode:
|
|||||||
# state updates until we are sure that we subscribed to everything that we need to. Otherwise,
|
# state updates until we are sure that we subscribed to everything that we need to. Otherwise,
|
||||||
# we might not be able to process some state.
|
# we might not be able to process some state.
|
||||||
coin_ids: List[bytes32] = item.data
|
coin_ids: List[bytes32] = item.data
|
||||||
for peer in self.server.get_full_node_connections():
|
for peer in self.server.get_connections(NodeType.FULL_NODE):
|
||||||
coin_states: List[CoinState] = await subscribe_to_coin_updates(coin_ids, peer, uint32(0))
|
coin_states: List[CoinState] = await subscribe_to_coin_updates(coin_ids, peer, uint32(0))
|
||||||
if len(coin_states) > 0:
|
if len(coin_states) > 0:
|
||||||
async with self.wallet_state_manager.lock:
|
async with self.wallet_state_manager.lock:
|
||||||
await self.receive_state_from_peer(coin_states, peer)
|
await self.receive_state_from_peer(coin_states, peer)
|
||||||
elif item.item_type == NewPeakQueueTypes.PUZZLE_HASH_SUBSCRIPTION:
|
elif item.item_type == NewPeakQueueTypes.PUZZLE_HASH_SUBSCRIPTION:
|
||||||
puzzle_hashes: List[bytes32] = item.data
|
puzzle_hashes: List[bytes32] = item.data
|
||||||
for peer in self.server.get_full_node_connections():
|
for peer in self.server.get_connections(NodeType.FULL_NODE):
|
||||||
# Puzzle hash subscription
|
# Puzzle hash subscription
|
||||||
coin_states: List[CoinState] = await subscribe_to_phs(puzzle_hashes, peer, uint32(0))
|
coin_states: List[CoinState] = await subscribe_to_phs(puzzle_hashes, peer, uint32(0))
|
||||||
if len(coin_states) > 0:
|
if len(coin_states) > 0:
|
||||||
@ -821,10 +821,10 @@ class WalletNode:
|
|||||||
|
|
||||||
async def get_coins_with_puzzle_hash(self, puzzle_hash) -> List[CoinState]:
|
async def get_coins_with_puzzle_hash(self, puzzle_hash) -> List[CoinState]:
|
||||||
# TODO Use trusted peer, otherwise try untrusted
|
# TODO Use trusted peer, otherwise try untrusted
|
||||||
all_nodes = self.server.connection_by_type[NodeType.FULL_NODE]
|
all_nodes = self.server.get_connections(NodeType.FULL_NODE)
|
||||||
if len(all_nodes.keys()) == 0:
|
if len(all_nodes) == 0:
|
||||||
raise ValueError("Not connected to the full node")
|
raise ValueError("Not connected to the full node")
|
||||||
first_node = list(all_nodes.values())[0]
|
first_node = all_nodes[0]
|
||||||
msg = wallet_protocol.RegisterForPhUpdates(puzzle_hash, uint32(0))
|
msg = wallet_protocol.RegisterForPhUpdates(puzzle_hash, uint32(0))
|
||||||
coin_state: Optional[RespondToPhUpdates] = await first_node.register_interest_in_puzzle_hash(msg)
|
coin_state: Optional[RespondToPhUpdates] = await first_node.register_interest_in_puzzle_hash(msg)
|
||||||
# TODO validate state if received from untrusted peer
|
# TODO validate state if received from untrusted peer
|
||||||
@ -901,7 +901,7 @@ class WalletNode:
|
|||||||
synced: List[WSChiaConnection] = []
|
synced: List[WSChiaConnection] = []
|
||||||
trusted: List[WSChiaConnection] = []
|
trusted: List[WSChiaConnection] = []
|
||||||
neither: List[WSChiaConnection] = []
|
neither: List[WSChiaConnection] = []
|
||||||
all_nodes: List[WSChiaConnection] = self.server.get_full_node_connections().copy()
|
all_nodes: List[WSChiaConnection] = self.server.get_connections(NodeType.FULL_NODE)
|
||||||
random.shuffle(all_nodes)
|
random.shuffle(all_nodes)
|
||||||
for node in all_nodes:
|
for node in all_nodes:
|
||||||
we_synced_to_it = node.peer_node_id in self.synced_peers
|
we_synced_to_it = node.peer_node_id in self.synced_peers
|
||||||
@ -921,8 +921,9 @@ class WalletNode:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Close connection of non-trusted peers
|
# Close connection of non-trusted peers
|
||||||
if len(self.server.get_full_node_connections()) > 1:
|
full_node_connections = self.server.get_connections(NodeType.FULL_NODE)
|
||||||
for peer in self.server.get_full_node_connections():
|
if len(full_node_connections) > 1:
|
||||||
|
for peer in full_node_connections:
|
||||||
if not self.is_trusted(peer):
|
if not self.is_trusted(peer):
|
||||||
await peer.close()
|
await peer.close()
|
||||||
|
|
||||||
@ -933,7 +934,7 @@ class WalletNode:
|
|||||||
async def check_for_synced_trusted_peer(self, header_block: HeaderBlock, request_time: uint64) -> bool:
|
async def check_for_synced_trusted_peer(self, header_block: HeaderBlock, request_time: uint64) -> bool:
|
||||||
if self._server is None:
|
if self._server is None:
|
||||||
return False
|
return False
|
||||||
for peer in self.server.get_full_node_connections():
|
for peer in self.server.get_connections(NodeType.FULL_NODE):
|
||||||
if self.is_trusted(peer) and await self.is_peer_synced(peer, header_block, request_time):
|
if self.is_trusted(peer) and await self.is_peer_synced(peer, header_block, request_time):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
@ -1469,7 +1470,7 @@ class WalletNode:
|
|||||||
if end == 0:
|
if end == 0:
|
||||||
self.log.error("Error finding sub epoch")
|
self.log.error("Error finding sub epoch")
|
||||||
return False
|
return False
|
||||||
all_peers_c = self.server.get_full_node_connections()
|
all_peers_c = self.server.get_connections(NodeType.FULL_NODE)
|
||||||
all_peers = [(con, self.is_trusted(con)) for con in all_peers_c]
|
all_peers = [(con, self.is_trusted(con)) for con in all_peers_c]
|
||||||
blocks: Optional[List[HeaderBlock]] = await fetch_header_blocks_in_range(
|
blocks: Optional[List[HeaderBlock]] = await fetch_header_blocks_in_range(
|
||||||
start, end, peer_request_cache, all_peers
|
start, end, peer_request_cache, all_peers
|
||||||
@ -1615,6 +1616,6 @@ class WalletNode:
|
|||||||
ProtocolMessageTypes.send_transaction,
|
ProtocolMessageTypes.send_transaction,
|
||||||
wallet_protocol.SendTransaction(spend_bundle),
|
wallet_protocol.SendTransaction(spend_bundle),
|
||||||
)
|
)
|
||||||
full_nodes = self.server.get_full_node_connections()
|
full_nodes = self.server.get_connections(NodeType.FULL_NODE)
|
||||||
for peer in full_nodes:
|
for peer in full_nodes:
|
||||||
await peer.send_message(msg)
|
await peer.send_message(msg)
|
||||||
|
@ -19,6 +19,7 @@ from chia.pools.pool_puzzles import SINGLETON_LAUNCHER_HASH, solution_to_pool_st
|
|||||||
from chia.pools.pool_wallet import PoolWallet
|
from chia.pools.pool_wallet import PoolWallet
|
||||||
from chia.protocols import wallet_protocol
|
from chia.protocols import wallet_protocol
|
||||||
from chia.protocols.wallet_protocol import CoinState
|
from chia.protocols.wallet_protocol import CoinState
|
||||||
|
from chia.server.outbound_message import NodeType
|
||||||
from chia.server.server import ChiaServer
|
from chia.server.server import ChiaServer
|
||||||
from chia.server.ws_connection import WSChiaConnection
|
from chia.server.ws_connection import WSChiaConnection
|
||||||
from chia.types.blockchain_format.coin import Coin
|
from chia.types.blockchain_format.coin import Coin
|
||||||
@ -487,7 +488,7 @@ class WalletStateManager:
|
|||||||
self.pending_tx_callback()
|
self.pending_tx_callback()
|
||||||
|
|
||||||
async def synced(self):
|
async def synced(self):
|
||||||
if len(self.server.get_full_node_connections()) == 0:
|
if len(self.server.get_connections(NodeType.FULL_NODE)) == 0:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
latest = await self.blockchain.get_peak_block()
|
latest = await self.blockchain.get_peak_block()
|
||||||
|
@ -185,7 +185,7 @@ async def test_daemon_simulation(self_hostname, daemon_simulation):
|
|||||||
await server1.start_client(PeerInfo(self_hostname, uint16(node2_port)))
|
await server1.start_client(PeerInfo(self_hostname, uint16(node2_port)))
|
||||||
|
|
||||||
async def num_connections():
|
async def num_connections():
|
||||||
count = len(node2.server.connection_by_type[NodeType.FULL_NODE].items())
|
count = len(node2.server.get_connections(NodeType.FULL_NODE))
|
||||||
return count
|
return count
|
||||||
|
|
||||||
await time_out_assert_custom_interval(60, 1, num_connections, 1)
|
await time_out_assert_custom_interval(60, 1, num_connections, 1)
|
||||||
|
@ -19,7 +19,7 @@ from chia.protocols.full_node_protocol import RespondTransaction
|
|||||||
from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
||||||
from chia.protocols.wallet_protocol import SendTransaction, TransactionAck
|
from chia.protocols.wallet_protocol import SendTransaction, TransactionAck
|
||||||
from chia.server.address_manager import AddressManager
|
from chia.server.address_manager import AddressManager
|
||||||
from chia.server.outbound_message import Message
|
from chia.server.outbound_message import Message, NodeType
|
||||||
from chia.simulator.simulator_protocol import FarmNewBlockProtocol
|
from chia.simulator.simulator_protocol import FarmNewBlockProtocol
|
||||||
from chia.types.blockchain_format.classgroup import ClassgroupElement
|
from chia.types.blockchain_format.classgroup import ClassgroupElement
|
||||||
from chia.types.blockchain_format.program import Program, SerializedProgram
|
from chia.types.blockchain_format.program import Program, SerializedProgram
|
||||||
@ -400,7 +400,7 @@ class TestFullNodeProtocol:
|
|||||||
full_node_i = nodes[i]
|
full_node_i = nodes[i]
|
||||||
server_i = full_node_i.full_node.server
|
server_i = full_node_i.full_node.server
|
||||||
await server_i.start_client(PeerInfo(self_hostname, uint16(server_1._port)))
|
await server_i.start_client(PeerInfo(self_hostname, uint16(server_1._port)))
|
||||||
assert len(server_1.get_full_node_connections()) == 2
|
assert len(server_1.get_connections(NodeType.FULL_NODE)) == 2
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_request_peers(self, wallet_nodes, self_hostname):
|
async def test_request_peers(self, wallet_nodes, self_hostname):
|
||||||
|
Loading…
Reference in New Issue
Block a user