mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-11-11 01:28:17 +03:00
702 lines
30 KiB
Python
702 lines
30 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import logging
|
|
import math
|
|
import time
|
|
import traceback
|
|
from dataclasses import dataclass, field
|
|
from secrets import token_bytes
|
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, Union
|
|
|
|
from aiohttp import ClientSession, WSCloseCode, WSMessage, WSMsgType
|
|
from aiohttp.client import ClientWebSocketResponse
|
|
from aiohttp.web import WebSocketResponse
|
|
from typing_extensions import Protocol, final
|
|
|
|
from chia.cmds.init_funcs import chia_full_version_str
|
|
from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
|
from chia.protocols.protocol_state_machine import message_response_ok
|
|
from chia.protocols.protocol_timing import API_EXCEPTION_BAN_SECONDS, INTERNAL_PROTOCOL_ERROR_BAN_SECONDS
|
|
from chia.protocols.shared_protocol import Capability, Handshake
|
|
from chia.server.capabilities import known_active_capabilities
|
|
from chia.server.outbound_message import Message, NodeType, make_msg
|
|
from chia.server.rate_limits import RateLimiter
|
|
from chia.types.blockchain_format.sized_bytes import bytes32
|
|
from chia.types.peer_info import PeerInfo
|
|
from chia.util.api_decorators import get_metadata
|
|
from chia.util.errors import Err, ProtocolError
|
|
from chia.util.ints import uint8, uint16
|
|
from chia.util.log_exceptions import log_exceptions
|
|
|
|
# Each message is prepended with LENGTH_BYTES bytes specifying the length
|
|
from chia.util.network import class_for_type, is_localhost
|
|
from chia.util.streamable import Streamable
|
|
|
|
# Max size 2^(8*4) which is around 4GiB
|
|
LENGTH_BYTES: int = 4
|
|
|
|
WebSocket = Union[WebSocketResponse, ClientWebSocketResponse]
|
|
ConnectionCallback = Callable[["WSChiaConnection"], Awaitable[None]]
|
|
|
|
|
|
def create_default_last_message_time_dict() -> Dict[ProtocolMessageTypes, float]:
|
|
return {message_type: -math.inf for message_type in ProtocolMessageTypes}
|
|
|
|
|
|
class ConnectionClosedCallbackProtocol(Protocol):
|
|
def __call__(
|
|
self,
|
|
connection: WSChiaConnection,
|
|
ban_time: int,
|
|
closed_connection: bool = ...,
|
|
) -> None:
|
|
...
|
|
|
|
|
|
@final
|
|
@dataclass
|
|
class WSChiaConnection:
|
|
"""
|
|
Represents a connection to another node. Local host and port are ours, while peer host and
|
|
port are the host and port of the peer that we are connected to. Node_id and connection_type are
|
|
set after the handshake is performed in this connection.
|
|
"""
|
|
|
|
ws: WebSocket = field(repr=False)
|
|
api: Any = field(repr=False)
|
|
local_type: NodeType
|
|
local_port: int
|
|
local_capabilities_for_handshake: List[Tuple[uint16, str]] = field(repr=False)
|
|
local_capabilities: List[Capability]
|
|
peer_host: str
|
|
peer_port: uint16
|
|
peer_node_id: bytes32
|
|
log: logging.Logger = field(repr=False)
|
|
|
|
close_callback: Optional[ConnectionClosedCallbackProtocol] = field(repr=False)
|
|
outbound_rate_limiter: RateLimiter
|
|
inbound_rate_limiter: RateLimiter
|
|
|
|
# connection properties
|
|
is_outbound: bool
|
|
|
|
# Messaging
|
|
received_message_callback: Optional[ConnectionCallback] = field(repr=False)
|
|
incoming_queue: asyncio.Queue[Message] = field(default_factory=asyncio.Queue, repr=False)
|
|
outgoing_queue: asyncio.Queue[Message] = field(default_factory=asyncio.Queue, repr=False)
|
|
api_tasks: Dict[bytes32, asyncio.Task[None]] = field(default_factory=dict, repr=False)
|
|
# Contains task ids of api tasks which should not be canceled
|
|
execute_tasks: Set[bytes32] = field(default_factory=set, repr=False)
|
|
|
|
# ChiaConnection metrics
|
|
creation_time: float = field(default_factory=time.time)
|
|
bytes_read: int = 0
|
|
bytes_written: int = 0
|
|
last_message_time: float = 0
|
|
|
|
peer_server_port: Optional[uint16] = None
|
|
inbound_task: Optional[asyncio.Task[None]] = field(default=None, repr=False)
|
|
incoming_message_task: Optional[asyncio.Task[None]] = field(default=None, repr=False)
|
|
outbound_task: Optional[asyncio.Task[None]] = field(default=None, repr=False)
|
|
active: bool = False # once handshake is successful this will be changed to True
|
|
_close_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
|
|
session: Optional[ClientSession] = field(default=None, repr=False)
|
|
|
|
pending_requests: Dict[uint16, asyncio.Event] = field(default_factory=dict, repr=False)
|
|
request_results: Dict[uint16, Message] = field(default_factory=dict, repr=False)
|
|
closed: bool = False
|
|
connection_type: Optional[NodeType] = None
|
|
request_nonce: uint16 = uint16(0)
|
|
peer_capabilities: List[Capability] = field(default_factory=list)
|
|
# Used by the Chia Seeder.
|
|
version: str = field(default_factory=str)
|
|
protocol_version: str = field(default_factory=str)
|
|
|
|
log_rate_limit_last_time: Dict[ProtocolMessageTypes, float] = field(
|
|
default_factory=create_default_last_message_time_dict,
|
|
repr=False,
|
|
)
|
|
|
|
@classmethod
|
|
def create(
|
|
cls,
|
|
local_type: NodeType,
|
|
ws: WebSocket,
|
|
api: Any,
|
|
server_port: int,
|
|
log: logging.Logger,
|
|
is_outbound: bool,
|
|
received_message_callback: Optional[ConnectionCallback],
|
|
peer_host: str,
|
|
close_callback: Optional[ConnectionClosedCallbackProtocol],
|
|
peer_id: bytes32,
|
|
inbound_rate_limit_percent: int,
|
|
outbound_rate_limit_percent: int,
|
|
local_capabilities_for_handshake: List[Tuple[uint16, str]],
|
|
session: Optional[ClientSession] = None,
|
|
) -> WSChiaConnection:
|
|
assert ws._writer is not None
|
|
peername = ws._writer.transport.get_extra_info("peername")
|
|
|
|
if peername is None:
|
|
raise ValueError(f"Was not able to get peername from {peer_host}")
|
|
|
|
if is_outbound:
|
|
request_nonce = uint16(0)
|
|
else:
|
|
# Different nonce to reduce chances of overlap. Each peer will increment the nonce by one for each
|
|
# request. The receiving peer (not is_outbound), will use 2^15 to 2^16 - 1
|
|
request_nonce = uint16(2**15)
|
|
|
|
return cls(
|
|
ws=ws,
|
|
api=api,
|
|
local_type=local_type,
|
|
local_port=server_port,
|
|
local_capabilities_for_handshake=local_capabilities_for_handshake,
|
|
local_capabilities=known_active_capabilities(local_capabilities_for_handshake),
|
|
peer_host=peer_host,
|
|
peer_port=peername[1],
|
|
peer_node_id=peer_id,
|
|
log=log,
|
|
close_callback=close_callback,
|
|
request_nonce=request_nonce,
|
|
outbound_rate_limiter=RateLimiter(incoming=False, percentage_of_limit=outbound_rate_limit_percent),
|
|
inbound_rate_limiter=RateLimiter(incoming=True, percentage_of_limit=inbound_rate_limit_percent),
|
|
is_outbound=is_outbound,
|
|
received_message_callback=received_message_callback,
|
|
session=session,
|
|
)
|
|
|
|
def _get_extra_info(self, name: str) -> Optional[Any]:
|
|
writer = self.ws._writer
|
|
assert writer is not None, "websocket's ._writer is None, was .prepare() called?"
|
|
transport = writer.transport
|
|
if transport is None:
|
|
return None
|
|
try:
|
|
return transport.get_extra_info(name)
|
|
except AttributeError:
|
|
# "/usr/lib/python3.11/asyncio/sslproto.py", line 91, in get_extra_info
|
|
# return self._ssl_protocol._get_extra_info(name, default)
|
|
# AttributeError: 'NoneType' object has no attribute '_get_extra_info'
|
|
return None
|
|
|
|
async def perform_handshake(
|
|
self,
|
|
network_id: str,
|
|
protocol_version: str,
|
|
server_port: int,
|
|
local_type: NodeType,
|
|
) -> None:
|
|
outbound_handshake = make_msg(
|
|
ProtocolMessageTypes.handshake,
|
|
Handshake(
|
|
network_id,
|
|
protocol_version,
|
|
chia_full_version_str(),
|
|
uint16(server_port),
|
|
uint8(local_type.value),
|
|
self.local_capabilities_for_handshake,
|
|
),
|
|
)
|
|
if self.is_outbound:
|
|
await self._send_message(outbound_handshake)
|
|
inbound_handshake_msg = await self._read_one_message()
|
|
if inbound_handshake_msg is None:
|
|
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
|
inbound_handshake = Handshake.from_bytes(inbound_handshake_msg.data)
|
|
|
|
# Handle case of invalid ProtocolMessageType
|
|
try:
|
|
message_type: ProtocolMessageTypes = ProtocolMessageTypes(inbound_handshake_msg.type)
|
|
except Exception:
|
|
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
|
|
|
if message_type != ProtocolMessageTypes.handshake:
|
|
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
|
|
|
if inbound_handshake.network_id != network_id:
|
|
raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
|
|
|
|
self.version = inbound_handshake.software_version
|
|
self.protocol_version = inbound_handshake.protocol_version
|
|
self.peer_server_port = inbound_handshake.server_port
|
|
self.connection_type = NodeType(inbound_handshake.node_type)
|
|
# "1" means capability is enabled
|
|
self.peer_capabilities = known_active_capabilities(inbound_handshake.capabilities)
|
|
else:
|
|
try:
|
|
message = await self._read_one_message()
|
|
except Exception:
|
|
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
|
|
|
if message is None:
|
|
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
|
|
|
# Handle case of invalid ProtocolMessageType
|
|
try:
|
|
message_type = ProtocolMessageTypes(message.type)
|
|
except Exception:
|
|
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
|
|
|
if message_type != ProtocolMessageTypes.handshake:
|
|
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
|
|
|
inbound_handshake = Handshake.from_bytes(message.data)
|
|
if inbound_handshake.network_id != network_id:
|
|
raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
|
|
await self._send_message(outbound_handshake)
|
|
self.peer_server_port = inbound_handshake.server_port
|
|
self.connection_type = NodeType(inbound_handshake.node_type)
|
|
# "1" means capability is enabled
|
|
self.peer_capabilities = known_active_capabilities(inbound_handshake.capabilities)
|
|
|
|
self.outbound_task = asyncio.create_task(self.outbound_handler())
|
|
self.inbound_task = asyncio.create_task(self.inbound_handler())
|
|
self.incoming_message_task = asyncio.create_task(self.incoming_message_handler())
|
|
|
|
async def close(
|
|
self,
|
|
ban_time: int = 0,
|
|
ws_close_code: WSCloseCode = WSCloseCode.OK,
|
|
error: Optional[Err] = None,
|
|
) -> None:
|
|
"""
|
|
Closes the connection, and finally calls the close_callback on the server, so the connection gets removed
|
|
from the global list.
|
|
"""
|
|
if self.closed:
|
|
# always try to call the callback even for closed connections
|
|
with log_exceptions(self.log, consume=True):
|
|
self.log.debug(f"Closing already closed connection for {self.peer_host}")
|
|
if self.close_callback is not None:
|
|
self.close_callback(self, ban_time, closed_connection=True)
|
|
self._close_event.set()
|
|
return None
|
|
self.closed = True
|
|
|
|
if error is None:
|
|
message = b""
|
|
else:
|
|
message = str(int(error.value)).encode("utf-8")
|
|
|
|
try:
|
|
if self.inbound_task is not None:
|
|
self.inbound_task.cancel()
|
|
if self.incoming_message_task is not None:
|
|
self.incoming_message_task.cancel()
|
|
if self.outbound_task is not None:
|
|
self.outbound_task.cancel()
|
|
if self.ws is not None and self.ws.closed is False:
|
|
await self.ws.close(code=ws_close_code, message=message)
|
|
if self.session is not None:
|
|
await self.session.close()
|
|
self.cancel_pending_requests()
|
|
self.cancel_tasks()
|
|
except Exception:
|
|
error_stack = traceback.format_exc()
|
|
self.log.warning(f"Exception closing socket: {error_stack}")
|
|
raise
|
|
finally:
|
|
with log_exceptions(self.log, consume=True):
|
|
if self.close_callback is not None:
|
|
self.close_callback(self, ban_time, closed_connection=False)
|
|
self._close_event.set()
|
|
|
|
async def wait_until_closed(self) -> None:
|
|
await self._close_event.wait()
|
|
|
|
async def ban_peer_bad_protocol(self, log_err_msg: str) -> None:
|
|
"""Ban peer for protocol violation"""
|
|
ban_seconds = INTERNAL_PROTOCOL_ERROR_BAN_SECONDS
|
|
self.log.error(f"Banning peer for {ban_seconds} seconds: {self.peer_host} {log_err_msg}")
|
|
await self.close(ban_seconds, WSCloseCode.PROTOCOL_ERROR, Err.INVALID_PROTOCOL_MESSAGE)
|
|
|
|
def cancel_pending_requests(self) -> None:
|
|
for message_id, event in self.pending_requests.items():
|
|
try:
|
|
event.set()
|
|
except Exception as e:
|
|
self.log.error(f"Failed setting event for {message_id}: {e} {traceback.format_exc()}")
|
|
|
|
def cancel_tasks(self) -> None:
|
|
for task_id, task in self.api_tasks.copy().items():
|
|
if task_id in self.execute_tasks:
|
|
continue
|
|
task.cancel()
|
|
|
|
async def outbound_handler(self) -> None:
|
|
try:
|
|
while not self.closed:
|
|
msg = await self.outgoing_queue.get()
|
|
if msg is not None:
|
|
await self._send_message(msg)
|
|
except asyncio.CancelledError:
|
|
pass
|
|
except Exception as e:
|
|
expected = False
|
|
if isinstance(e, (BrokenPipeError, ConnectionResetError, TimeoutError)):
|
|
expected = True
|
|
elif isinstance(e, OSError):
|
|
if e.errno in {113}:
|
|
expected = True
|
|
|
|
if expected:
|
|
self.log.warning(f"{e} {self.peer_host}")
|
|
else:
|
|
error_stack = traceback.format_exc()
|
|
self.log.error(f"Exception: {e} with {self.peer_host}")
|
|
self.log.error(f"Exception Stack: {error_stack}")
|
|
|
|
async def _api_call(self, full_message: Message, task_id: bytes32) -> None:
|
|
start_time = time.time()
|
|
message_type = ""
|
|
try:
|
|
if self.received_message_callback is not None:
|
|
await self.received_message_callback(self)
|
|
self.log.debug(
|
|
f"<- {ProtocolMessageTypes(full_message.type).name} from peer {self.peer_node_id} {self.peer_host}"
|
|
)
|
|
message_type = ProtocolMessageTypes(full_message.type).name
|
|
|
|
f = getattr(self.api, message_type, None)
|
|
|
|
if f is None:
|
|
self.log.error(f"Non existing function: {message_type}")
|
|
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])
|
|
|
|
metadata = get_metadata(function=f)
|
|
if metadata is None:
|
|
self.log.error(f"Peer trying to call non api function {message_type}")
|
|
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [message_type])
|
|
|
|
# If api is not ready ignore the request
|
|
if hasattr(self.api, "api_ready"):
|
|
if self.api.api_ready is False:
|
|
return None
|
|
|
|
timeout: Optional[int] = 600
|
|
if metadata.execute_task:
|
|
# Don't timeout on methods with execute_task decorator, these need to run fully
|
|
self.execute_tasks.add(task_id)
|
|
timeout = None
|
|
|
|
if metadata.peer_required:
|
|
coroutine = f(full_message.data, self)
|
|
else:
|
|
coroutine = f(full_message.data)
|
|
|
|
async def wrapped_coroutine() -> Optional[Message]:
|
|
try:
|
|
# hinting Message here is compensating for difficulty around hinting of the callbacks
|
|
result: Message = await coroutine
|
|
return result
|
|
except asyncio.CancelledError:
|
|
pass
|
|
except Exception as e:
|
|
tb = traceback.format_exc()
|
|
self.log.error(f"Exception: {e}, {self.get_peer_logging()}. {tb}")
|
|
raise
|
|
return None
|
|
|
|
response: Optional[Message] = await asyncio.wait_for(wrapped_coroutine(), timeout=timeout)
|
|
self.log.debug(
|
|
f"Time taken to process {message_type} from {self.peer_node_id} is "
|
|
f"{time.time() - start_time} seconds"
|
|
)
|
|
|
|
if response is not None:
|
|
response_message = Message(response.type, full_message.id, response.data)
|
|
await self.send_message(response_message)
|
|
# todo uncomment when enabling none response capability
|
|
# check that this call needs a reply
|
|
# elif message_requires_reply(ProtocolMessageTypes(full_message.type)) and self.has_capability(
|
|
# Capability.NONE_RESPONSE
|
|
# ):
|
|
# # this peer can accept None reply's, send empty msg back, so it doesn't wait for timeout
|
|
# response_message = Message(uint8(ProtocolMessageTypes.none_response.value), full_message.id, b"")
|
|
# await self.send_message(response_message)
|
|
except TimeoutError:
|
|
self.log.error(f"Timeout error for: {message_type}")
|
|
except Exception as e:
|
|
if not self.closed:
|
|
tb = traceback.format_exc()
|
|
self.log.error(f"Exception: {e} {type(e)}, closing connection {self.get_peer_logging()}. {tb}")
|
|
else:
|
|
self.log.debug(f"Exception: {e} while closing connection")
|
|
# TODO: actually throw one of the errors from errors.py and pass this to close
|
|
await self.close(API_EXCEPTION_BAN_SECONDS, WSCloseCode.PROTOCOL_ERROR, Err.UNKNOWN)
|
|
finally:
|
|
if task_id in self.api_tasks:
|
|
self.api_tasks.pop(task_id)
|
|
if task_id in self.execute_tasks:
|
|
self.execute_tasks.remove(task_id)
|
|
|
|
async def incoming_message_handler(self) -> None:
|
|
while True:
|
|
message = await self.incoming_queue.get()
|
|
task_id: bytes32 = bytes32(token_bytes(32))
|
|
api_task = asyncio.create_task(self._api_call(message, task_id))
|
|
self.api_tasks[task_id] = api_task
|
|
|
|
async def inbound_handler(self) -> None:
|
|
try:
|
|
while not self.closed:
|
|
message = await self._read_one_message()
|
|
if message is not None:
|
|
if message.id in self.pending_requests:
|
|
self.request_results[message.id] = message
|
|
event = self.pending_requests[message.id]
|
|
event.set()
|
|
else:
|
|
await self.incoming_queue.put(message)
|
|
else:
|
|
continue
|
|
except asyncio.CancelledError:
|
|
self.log.debug("Inbound_handler task cancelled")
|
|
except Exception as e:
|
|
error_stack = traceback.format_exc()
|
|
self.log.error(f"Exception: {e}")
|
|
self.log.error(f"Exception Stack: {error_stack}")
|
|
|
|
async def send_message(self, message: Message) -> bool:
|
|
"""Send message sends a message with no tracking / callback."""
|
|
if self.closed:
|
|
return False
|
|
await self.outgoing_queue.put(message)
|
|
return True
|
|
|
|
async def call_api(
|
|
self,
|
|
request_method: Callable[..., Awaitable[Optional[Message]]],
|
|
message: Streamable,
|
|
timeout: int = 60,
|
|
) -> Any:
|
|
if self.connection_type is None:
|
|
raise ValueError("handshake not done yet")
|
|
request_metadata = get_metadata(request_method)
|
|
assert request_metadata is not None, f"ApiMetadata unavailable for {request_method}"
|
|
attribute = getattr(class_for_type(self.connection_type), request_metadata.request_type.name, None)
|
|
if attribute is None:
|
|
raise AttributeError(
|
|
f"Node type {self.connection_type} does not have method {request_metadata.request_type.name}"
|
|
)
|
|
|
|
request = Message(uint8(request_metadata.request_type.value), None, bytes(message))
|
|
request_start_t = time.time()
|
|
response = await self.send_request(request, timeout)
|
|
self.log.debug(
|
|
f"Time for request {request_metadata.request_type.name}: {self.get_peer_logging()} = "
|
|
f"{time.time() - request_start_t}, None? {response is None}"
|
|
)
|
|
# todo or response.type == ProtocolMessageTypes.none_response.value when enabling none response
|
|
if response is None or response.data == b"":
|
|
return None
|
|
sent_message_type = ProtocolMessageTypes(request.type)
|
|
recv_message_type = ProtocolMessageTypes(response.type)
|
|
if not message_response_ok(sent_message_type, recv_message_type):
|
|
# peer protocol violation
|
|
error_message = f"WSConnection.invoke sent message {sent_message_type.name} "
|
|
f"but received {recv_message_type.name}"
|
|
await self.ban_peer_bad_protocol(error_message)
|
|
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [error_message])
|
|
|
|
recv_method = getattr(class_for_type(self.local_type), recv_message_type.name)
|
|
receive_metadata = get_metadata(recv_method)
|
|
assert receive_metadata is not None, f"ApiMetadata unavailable for {recv_method}"
|
|
return receive_metadata.message_class.from_bytes(response.data)
|
|
|
|
async def send_request(self, message_no_id: Message, timeout: int) -> Optional[Message]:
|
|
"""Sends a message and waits for a response."""
|
|
if self.closed:
|
|
return None
|
|
|
|
# We will wait for this event, it will be set either by the response, or the timeout
|
|
event = asyncio.Event()
|
|
|
|
# The request nonce is an integer between 0 and 2**16 - 1, which is used to match requests to responses
|
|
# If is_outbound, 0 <= nonce < 2^15, else 2^15 <= nonce < 2^16
|
|
request_id = self.request_nonce
|
|
if self.is_outbound:
|
|
self.request_nonce = uint16(self.request_nonce + 1) if self.request_nonce != (2**15 - 1) else uint16(0)
|
|
else:
|
|
self.request_nonce = (
|
|
uint16(self.request_nonce + 1) if self.request_nonce != (2**16 - 1) else uint16(2**15)
|
|
)
|
|
|
|
message = Message(message_no_id.type, request_id, message_no_id.data)
|
|
assert message.id is not None
|
|
self.pending_requests[message.id] = event
|
|
await self.outgoing_queue.put(message)
|
|
|
|
# Either the result is available below or not, no need to detect the timeout error
|
|
with contextlib.suppress(asyncio.TimeoutError):
|
|
await asyncio.wait_for(event.wait(), timeout=timeout)
|
|
|
|
self.pending_requests.pop(message.id)
|
|
result: Optional[Message] = None
|
|
if message.id in self.request_results:
|
|
result = self.request_results[message.id]
|
|
assert result is not None
|
|
self.log.debug(f"<- {ProtocolMessageTypes(result.type).name} from: {self.peer_host}:{self.peer_port}")
|
|
self.request_results.pop(message.id)
|
|
|
|
return result
|
|
|
|
async def send_messages(self, messages: List[Message]) -> None:
|
|
if self.closed:
|
|
return None
|
|
for message in messages:
|
|
await self.outgoing_queue.put(message)
|
|
|
|
async def _wait_and_retry(self, msg: Message) -> None:
|
|
try:
|
|
await asyncio.sleep(1)
|
|
await self.outgoing_queue.put(msg)
|
|
except Exception as e:
|
|
self.log.debug(f"Exception {e} while waiting to retry sending rate limited message")
|
|
return None
|
|
|
|
async def _send_message(self, message: Message) -> None:
|
|
encoded: bytes = bytes(message)
|
|
size = len(encoded)
|
|
assert len(encoded) < (2 ** (LENGTH_BYTES * 8))
|
|
if not self.outbound_rate_limiter.process_msg_and_check(
|
|
message, self.local_capabilities, self.peer_capabilities
|
|
):
|
|
if not is_localhost(self.peer_host):
|
|
message_type = ProtocolMessageTypes(message.type)
|
|
last_time = self.log_rate_limit_last_time[message_type]
|
|
now = time.monotonic()
|
|
self.log_rate_limit_last_time[message_type] = now
|
|
if now - last_time >= 60:
|
|
msg = f"Rate limiting ourselves. message type: {message_type.name}, peer: {self.peer_host}"
|
|
self.log.debug(msg)
|
|
|
|
# TODO: fix this special case. This function has rate limits which are too low.
|
|
if ProtocolMessageTypes(message.type) != ProtocolMessageTypes.respond_peers:
|
|
asyncio.create_task(self._wait_and_retry(message))
|
|
|
|
return None
|
|
else:
|
|
self.log.debug(
|
|
f"Not rate limiting ourselves. message type: {ProtocolMessageTypes(message.type).name}, "
|
|
f"peer: {self.peer_host}"
|
|
)
|
|
|
|
await self.ws.send_bytes(encoded)
|
|
self.log.debug(f"-> {ProtocolMessageTypes(message.type).name} to peer {self.peer_host} {self.peer_node_id}")
|
|
self.bytes_written += size
|
|
|
|
async def _read_one_message(self) -> Optional[Message]:
|
|
try:
|
|
message: WSMessage = await self.ws.receive(30)
|
|
except asyncio.TimeoutError:
|
|
# self.ws._closed if we didn't receive a ping / pong
|
|
if self.ws.closed:
|
|
asyncio.create_task(self.close())
|
|
await asyncio.sleep(3)
|
|
return None
|
|
return None
|
|
|
|
if self.connection_type is not None:
|
|
connection_type_str = NodeType(self.connection_type).name.lower()
|
|
else:
|
|
connection_type_str = ""
|
|
if message.type == WSMsgType.CLOSING:
|
|
self.log.debug(
|
|
f"Closing connection to {connection_type_str} {self.peer_host}:"
|
|
f"{self.peer_server_port}/"
|
|
f"{self.peer_port}"
|
|
)
|
|
asyncio.create_task(self.close())
|
|
await asyncio.sleep(3)
|
|
elif message.type == WSMsgType.CLOSE:
|
|
self.log.debug(
|
|
f"Peer closed connection {connection_type_str} {self.peer_host}:"
|
|
f"{self.peer_server_port}/"
|
|
f"{self.peer_port}"
|
|
)
|
|
asyncio.create_task(self.close())
|
|
await asyncio.sleep(3)
|
|
elif message.type == WSMsgType.CLOSED:
|
|
if not self.closed:
|
|
asyncio.create_task(self.close())
|
|
await asyncio.sleep(3)
|
|
return None
|
|
elif message.type == WSMsgType.BINARY:
|
|
data = message.data
|
|
full_message_loaded: Message = Message.from_bytes(data)
|
|
self.bytes_read += len(data)
|
|
self.last_message_time = time.time()
|
|
try:
|
|
message_type = ProtocolMessageTypes(full_message_loaded.type).name
|
|
except Exception:
|
|
message_type = "Unknown"
|
|
if not self.inbound_rate_limiter.process_msg_and_check(
|
|
full_message_loaded, self.local_capabilities, self.peer_capabilities
|
|
):
|
|
if self.local_type == NodeType.FULL_NODE and not is_localhost(self.peer_host):
|
|
self.log.error(
|
|
f"Peer has been rate limited and will be disconnected: {self.peer_host}, "
|
|
f"message: {message_type}"
|
|
)
|
|
# Only full node disconnects peers, to prevent abuse and crashing timelords, farmers, etc
|
|
asyncio.create_task(self.close(300))
|
|
await asyncio.sleep(3)
|
|
return None
|
|
else:
|
|
self.log.debug(
|
|
f"Peer surpassed rate limit {self.peer_host}, message: {message_type}, "
|
|
f"port {self.peer_port} but not disconnecting"
|
|
)
|
|
return full_message_loaded
|
|
return full_message_loaded
|
|
elif message.type == WSMsgType.ERROR:
|
|
self.log.error(f"WebSocket Error: {message}")
|
|
if message.data.code == WSCloseCode.MESSAGE_TOO_BIG:
|
|
asyncio.create_task(self.close(300))
|
|
else:
|
|
asyncio.create_task(self.close())
|
|
await asyncio.sleep(3)
|
|
|
|
else:
|
|
self.log.error(f"Unexpected WebSocket message type: {message}")
|
|
asyncio.create_task(self.close())
|
|
await asyncio.sleep(3)
|
|
return None
|
|
|
|
# Used by the Chia Seeder.
|
|
def get_version(self) -> str:
|
|
return self.version
|
|
|
|
def get_tls_version(self) -> str:
|
|
ssl_obj = self._get_extra_info("ssl_object")
|
|
if ssl_obj is not None:
|
|
return str(ssl_obj.version())
|
|
else:
|
|
return "unknown"
|
|
|
|
def get_peer_info(self) -> Optional[PeerInfo]:
|
|
result = self._get_extra_info("peername")
|
|
if result is None:
|
|
return None
|
|
connection_host = result[0]
|
|
port = self.peer_server_port if self.peer_server_port is not None else self.peer_port
|
|
return PeerInfo(connection_host, port)
|
|
|
|
def get_peer_logging(self) -> PeerInfo:
|
|
info: Optional[PeerInfo] = self.get_peer_info()
|
|
if info is None:
|
|
# in this case, we will use self.peer_host which is friendlier for logging
|
|
port = self.peer_server_port if self.peer_server_port is not None else self.peer_port
|
|
return PeerInfo(self.peer_host, port)
|
|
else:
|
|
return info
|
|
|
|
def has_capability(self, capability: Capability) -> bool:
|
|
return capability in self.peer_capabilities
|