Refactor Seeder & Crawler code + add tests (#15781)

* Cleanup seeder & mypy all files

holy crap, mother of all tech debt, this was horrid

* Make UDP Protocol class a dataclass & separate out functions.

* add TCP protocol class to DNS server

* Fix mypy types & other cleanup

also fix a couple of bugs

* Add seeder and crawler tests

* change log levels

* re-add db lock timeout

oops

* add EDNS & EDNS tests

* fix repeated shutdown on close signal

* fix binding to use both IPv6 and IPv4

whyyyyyyy

* add IPv6 and IPv4 tests + add an IPv4 socket on Windows
Jack Nelson, 2023-08-11 06:25:19 -04:00, committed via GitHub
parent c9e6c9d5cc
commit 0d36874caa
13 changed files with 1082 additions and 468 deletions

View File: chia/seeder/crawl_store.py

@@ -1,37 +1,34 @@
from __future__ import annotations
import asyncio
import dataclasses
import ipaddress
import logging
import random
import time
from dataclasses import dataclass, field, replace
from typing import Dict, List
import aiosqlite
from chia.seeder.peer_record import PeerRecord, PeerReliability
from chia.util.ints import uint64
log = logging.getLogger(__name__)
@dataclass
class CrawlStore:
crawl_db: aiosqlite.Connection
last_timestamp: int
lock: asyncio.Lock
host_to_records: Dict
host_to_selected_time: Dict
host_to_reliability: Dict
banned_peers: int
ignored_peers: int
reliable_peers: int
host_to_records: Dict[str, PeerRecord] = field(default_factory=dict) # peer_id: PeerRecord
host_to_selected_time: Dict[str, float] = field(default_factory=dict) # peer_id: timestamp (as a float)
host_to_reliability: Dict[str, PeerReliability] = field(default_factory=dict) # peer_id: PeerReliability
banned_peers: int = 0
ignored_peers: int = 0
reliable_peers: int = 0
@classmethod
async def create(cls, connection: aiosqlite.Connection):
self = cls()
async def create(cls, connection: aiosqlite.Connection) -> CrawlStore:
self = cls(connection)
self.crawl_db = connection
await self.crawl_db.execute(
(
"CREATE TABLE IF NOT EXISTS peer_records("
@@ -82,21 +79,16 @@ class CrawlStore:
await self.crawl_db.execute("CREATE INDEX IF NOT EXISTS ignore_till on peer_reliability(ignore_till)")
await self.crawl_db.commit()
self.last_timestamp = 0
self.ignored_peers = 0
self.banned_peers = 0
self.reliable_peers = 0
self.host_to_selected_time = {}
await self.unload_from_db()
return self
def maybe_add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability):
def maybe_add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability) -> None:
if peer_record.peer_id not in self.host_to_records:
self.host_to_records[peer_record.peer_id] = peer_record
if peer_reliability.peer_id not in self.host_to_reliability:
self.host_to_reliability[peer_reliability.peer_id] = peer_reliability
async def add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability, save_db: bool = False):
async def add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability, save_db: bool = False) -> None:
if not save_db:
self.host_to_records[peer_record.peer_id] = peer_record
self.host_to_reliability[peer_reliability.peer_id] = peer_reliability
@@ -152,41 +144,41 @@ class CrawlStore:
async def get_peer_reliability(self, peer_id: str) -> PeerReliability:
return self.host_to_reliability[peer_id]
async def peer_failed_to_connect(self, peer: PeerRecord):
now = int(time.time())
async def peer_failed_to_connect(self, peer: PeerRecord) -> None:
now = uint64(time.time())
age_timestamp = int(max(peer.last_try_timestamp, peer.connected_timestamp))
if age_timestamp == 0:
age_timestamp = now - 1000
replaced = dataclasses.replace(peer, try_count=peer.try_count + 1, last_try_timestamp=now)
replaced = replace(peer, try_count=peer.try_count + 1, last_try_timestamp=now)
reliability = await self.get_peer_reliability(peer.peer_id)
if reliability is None:
reliability = PeerReliability(peer.peer_id)
reliability.update(False, now - age_timestamp)
await self.add_peer(replaced, reliability)
async def peer_connected(self, peer: PeerRecord, tls_version: str):
now = int(time.time())
async def peer_connected(self, peer: PeerRecord, tls_version: str) -> None:
now = uint64(time.time())
age_timestamp = int(max(peer.last_try_timestamp, peer.connected_timestamp))
if age_timestamp == 0:
age_timestamp = now - 1000
replaced = dataclasses.replace(peer, connected=True, connected_timestamp=now, tls_version=tls_version)
replaced = replace(peer, connected=True, connected_timestamp=now, tls_version=tls_version)
reliability = await self.get_peer_reliability(peer.peer_id)
if reliability is None:
reliability = PeerReliability(peer.peer_id)
reliability.update(True, now - age_timestamp)
await self.add_peer(replaced, reliability)
async def update_best_timestamp(self, host: str, timestamp):
async def update_best_timestamp(self, host: str, timestamp: uint64) -> None:
if host not in self.host_to_records:
return
record = self.host_to_records[host]
replaced = dataclasses.replace(record, best_timestamp=timestamp)
replaced = replace(record, best_timestamp=timestamp)
if host not in self.host_to_reliability:
return
reliability = self.host_to_reliability[host]
await self.add_peer(replaced, reliability)
async def peer_connected_hostname(self, host: str, connected: bool = True, tls_version: str = "unknown"):
async def peer_connected_hostname(self, host: str, connected: bool = True, tls_version: str = "unknown") -> None:
if host not in self.host_to_records:
return
record = self.host_to_records[host]
@@ -195,7 +187,7 @@ class CrawlStore:
else:
await self.peer_failed_to_connect(record)
async def get_peers_to_crawl(self, min_batch_size, max_batch_size) -> List[PeerRecord]:
async def get_peers_to_crawl(self, min_batch_size: int, max_batch_size: int) -> List[PeerRecord]:
now = int(time.time())
records = []
records_v6 = []
@@ -270,20 +262,20 @@ class CrawlStore:
def get_reliable_peers(self) -> int:
return self.reliable_peers
async def load_to_db(self):
log.error("Saving peers to DB...")
async def load_to_db(self) -> None:
log.info("Saving peers to DB...")
for peer_id in list(self.host_to_reliability.keys()):
if peer_id in self.host_to_reliability and peer_id in self.host_to_records:
reliability = self.host_to_reliability[peer_id]
record = self.host_to_records[peer_id]
await self.add_peer(record, reliability, True)
await self.crawl_db.commit()
log.error(" - Done saving peers to DB")
log.info(" - Done saving peers to DB")
async def unload_from_db(self):
async def unload_from_db(self) -> None:
self.host_to_records = {}
self.host_to_reliability = {}
log.error("Loading peer reliability records...")
log.info("Loading peer reliability records...")
cursor = await self.crawl_db.execute(
"SELECT * from peer_reliability",
)
@@ -312,9 +304,9 @@ class CrawlStore:
row[18],
row[19],
)
self.host_to_reliability[row[0]] = reliability
log.error(" - Done loading peer reliability records...")
log.error("Loading peer records...")
self.host_to_reliability[reliability.peer_id] = reliability
log.info(" - Done loading peer reliability records...")
log.info("Loading peer records...")
cursor = await self.crawl_db.execute(
"SELECT * from peer_records",
)
@@ -324,24 +316,23 @@ class CrawlStore:
peer = PeerRecord(
row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7], row[8], row[9], row[10], row[11]
)
self.host_to_records[row[0]] = peer
log.error(" - Done loading peer records...")
self.host_to_records[peer.peer_id] = peer
log.info(" - Done loading peer records...")
# Crawler -> DNS.
async def load_reliable_peers_to_db(self):
async def load_reliable_peers_to_db(self) -> None:
peers = []
for peer_id in self.host_to_reliability:
reliability = self.host_to_reliability[peer_id]
for peer_id, reliability in self.host_to_reliability.items():
if reliability.is_reliable():
peers.append(peer_id)
self.reliable_peers = len(peers)
log.error("Deleting old good_peers from DB...")
log.info("Deleting old good_peers from DB...")
cursor = await self.crawl_db.execute(
"DELETE from good_peers",
)
await cursor.close()
log.error(" - Done deleting old good_peers...")
log.error("Saving new good_peers to DB...")
log.info(" - Done deleting old good_peers...")
log.info("Saving new good_peers to DB...")
for peer in peers:
cursor = await self.crawl_db.execute(
"INSERT OR REPLACE INTO good_peers VALUES(?)",
@@ -349,9 +340,9 @@ class CrawlStore:
)
await cursor.close()
await self.crawl_db.commit()
log.error(" - Done saving new good_peers to DB...")
log.info(" - Done saving new good_peers to DB...")
def load_host_to_version(self):
def load_host_to_version(self) -> tuple[dict[str, str], dict[str, uint64]]:
versions = {}
handshake = {}
@@ -366,19 +357,30 @@ class CrawlStore:
versions[host] = record.version
handshake[host] = record.handshake_time
return (versions, handshake)
return versions, handshake
def load_best_peer_reliability(self):
def load_best_peer_reliability(self) -> dict[str, uint64]:
best_timestamp = {}
for host, record in self.host_to_records.items():
if record.best_timestamp > time.time() - 5 * 24 * 3600:
best_timestamp[host] = record.best_timestamp
return best_timestamp
async def update_version(self, host, version, now):
async def update_version(self, host: str, version: str, timestamp_now: uint64) -> None:
record = self.host_to_records.get(host, None)
reliability = self.host_to_reliability.get(host, None)
if record is None or reliability is None:
return
record.update_version(version, now)
record.update_version(version, timestamp_now)
await self.add_peer(record, reliability)
async def get_good_peers(self) -> list[str]: # This is for the DNS server
cursor = await self.crawl_db.execute(
"SELECT * from good_peers",
)
rows = await cursor.fetchall()
await cursor.close()
result = [row[0] for row in rows]
if len(result) > 0:
random.shuffle(result) # mix up the peers
return result
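
For context: throughout crawl_store.py, peer records are treated as immutable, so replace() from dataclasses returns an updated copy instead of mutating in place, which is why peer_failed_to_connect and peer_connected build a `replaced` record. A minimal self-contained sketch (not part of the commit):

    from dataclasses import dataclass, replace

    @dataclass(frozen=True)
    class Rec:
        peer_id: str
        try_count: int = 0

    r = Rec("203.0.113.7")
    print(replace(r, try_count=r.try_count + 1))  # Rec(peer_id='203.0.113.7', try_count=1)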

View File: chia/seeder/crawler.py

@@ -6,15 +6,16 @@ import logging
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
import aiosqlite
from chia.consensus.constants import ConsensusConstants
from chia.full_node.coin_store import CoinStore
from chia.full_node.full_node_api import FullNodeAPI
from chia.protocols import full_node_protocol
from chia.protocols.full_node_protocol import RespondPeers
from chia.rpc.rpc_server import StateChangedProtocol, default_get_connections
from chia.seeder.crawl_store import CrawlStore
from chia.seeder.peer_record import PeerRecord, PeerReliability
@@ -29,20 +30,28 @@ from chia.util.path import path_from_root
log = logging.getLogger(__name__)
@dataclass
class Crawler:
sync_store: Any
coin_store: CoinStore
connection: Optional[aiosqlite.Connection]
config: Dict
_server: Optional[ChiaServer]
crawl_store: Optional[CrawlStore]
log: logging.Logger
constants: ConsensusConstants
_shut_down: bool
config: Dict[str, Any]
root_path: Path
peer_count: int
with_peak: set
minimum_version_count: int
constants: ConsensusConstants
print_status: bool = True
state_changed_callback: Optional[StateChangedProtocol] = None
_server: Optional[ChiaServer] = None
crawl_task: Optional[asyncio.Task[None]] = None
crawl_store: Optional[CrawlStore] = None
log: logging.Logger = log
_shut_down: bool = False
peer_count: int = 0
with_peak: Set[PeerInfo] = field(default_factory=set)
seen_nodes: Set[str] = field(default_factory=set)
minimum_version_count: int = 0
peers_retrieved: List[RespondPeers] = field(default_factory=list)
host_to_version: Dict[str, str] = field(default_factory=dict)
versions: Dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0))
version_cache: List[Tuple[str, str]] = field(default_factory=list)
handshake_time: Dict[str, uint64] = field(default_factory=dict)
best_timestamp_per_peer: Dict[str, uint64] = field(default_factory=dict)
@property
def server(self) -> ChiaServer:
@@ -53,38 +62,16 @@ class Crawler:
return self._server
def __init__(
self,
config: Dict,
root_path: Path,
consensus_constants: ConsensusConstants,
name: str = None,
):
self.initialized = False
self.root_path = root_path
self.config = config
self.connection = None
self._server = None
self._shut_down = False # Set to true to close all infinite loops
self.constants = consensus_constants
self.state_changed_callback: Optional[StateChangedProtocol] = None
self.crawl_store = None
self.log = log
self.peer_count = 0
self.with_peak = set()
self.peers_retrieved: List[Any] = []
self.host_to_version: Dict[str, str] = {}
self.version_cache: List[Tuple[str, str]] = []
self.handshake_time: Dict[str, int] = {}
self.best_timestamp_per_peer: Dict[str, int] = {}
crawler_db_path: str = config.get("crawler_db_path", "crawler.db")
self.db_path = path_from_root(root_path, crawler_db_path)
def __post_init__(self) -> None:
# get db path
crawler_db_path: str = self.config.get("crawler_db_path", "crawler.db")
self.db_path = path_from_root(self.root_path, crawler_db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.bootstrap_peers = config["bootstrap_peers"]
self.minimum_height = config["minimum_height"]
self.other_peers_port = config["other_peers_port"]
self.versions: Dict[str, int] = defaultdict(lambda: 0)
self.minimum_version_count = self.config.get("minimum_version_count", 100)
# load data from config
self.bootstrap_peers = self.config["bootstrap_peers"]
self.minimum_height = self.config["minimum_height"]
self.other_peers_port = self.config["other_peers_port"]
self.minimum_version_count: int = self.config.get("minimum_version_count", 100)
if self.minimum_version_count < 1:
self.log.warning(
f"Crawler configuration minimum_version_count expected to be greater than zero: "
@@ -97,11 +84,16 @@ class Crawler:
def get_connections(self, request_node_type: Optional[NodeType]) -> List[Dict[str, Any]]:
return default_get_connections(server=self.server, request_node_type=request_node_type)
async def create_client(self, peer_info, on_connect):
async def create_client(
self, peer_info: PeerInfo, on_connect: Callable[[WSChiaConnection], Awaitable[None]]
) -> bool:
return await self.server.start_client(peer_info, on_connect)
async def connect_task(self, peer):
async def peer_action(peer: WSChiaConnection):
async def connect_task(self, peer: PeerRecord) -> None:
if self.crawl_store is None:
raise ValueError("Not Connected to DB")
async def peer_action(peer: WSChiaConnection) -> None:
peer_info = peer.get_peer_info()
version = peer.get_version()
if peer_info is not None and version is not None:
@@ -134,41 +126,34 @@ class Crawler:
if not connected:
await self.crawl_store.peer_failed_to_connect(peer)
except Exception as e:
self.log.info(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
self.log.warning(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
await self.crawl_store.peer_failed_to_connect(peer)
async def _start(self):
async def _start(self) -> None:
# We override the default peer_connect_timeout when running from the crawler
crawler_peer_timeout = self.config.get("peer_connect_timeout", 2)
self.server.config["peer_connect_timeout"] = crawler_peer_timeout
self.task = asyncio.create_task(self.crawl())
async def crawl(self):
# Ensure the state_changed callback is set up before moving on
# Sometimes, the daemon connection + state changed callback isn't up and ready
# by the time we get to the first _state_changed call, so this just ensures it's there before moving on
while self.state_changed_callback is None:
self.log.info("Waiting for state changed callback...")
await asyncio.sleep(0.1)
# Connect to the DB
self.crawl_store: CrawlStore = await CrawlStore.create(await aiosqlite.connect(self.db_path))
# Bootstrap the initial peers
await self.load_bootstrap_peers()
self.crawl_task = asyncio.create_task(self.crawl())
async def load_bootstrap_peers(self) -> None:
assert self.crawl_store is not None
try:
self.connection = await aiosqlite.connect(self.db_path)
self.crawl_store = await CrawlStore.create(self.connection)
self.log.info("Started")
self.log.warning("Bootstrapping initial peers...")
t_start = time.time()
total_nodes = 0
self.seen_nodes = set()
tried_nodes = set()
for peer in self.bootstrap_peers:
new_peer = PeerRecord(
peer,
peer,
self.other_peers_port,
False,
0,
0,
0,
uint64(0),
uint32(0),
uint64(0),
uint64(int(time.time())),
uint64(0),
"undefined",
@@ -180,14 +165,26 @@ class Crawler:
self.host_to_version, self.handshake_time = self.crawl_store.load_host_to_version()
self.best_timestamp_per_peer = self.crawl_store.load_best_peer_reliability()
self.versions = defaultdict(lambda: 0)
for host, version in self.host_to_version.items():
self.versions[version] += 1
self._state_changed("loaded_initial_peers")
self.log.warning(f"Bootstrapped initial peers in {time.time() - t_start} seconds")
except Exception as e:
self.log.error(f"Error bootstrapping initial peers: {e}")
while True:
self.with_peak = set()
async def crawl(self) -> None:
# Ensure the state_changed callback is set up before moving on
# Sometimes, the daemon connection + state changed callback isn't up and ready
# by the time we get to the first _state_changed call, so this just ensures it's there before moving on
while self.state_changed_callback is None:
self.log.info("Waiting for state changed callback...")
await asyncio.sleep(0.1)
assert self.crawl_store is not None
t_start = time.time()
total_nodes = 0
tried_nodes = set()
try:
while not self._shut_down:
peers_to_crawl = await self.crawl_store.get_peers_to_crawl(25000, 250000)
tasks = set()
for peer in peers_to_crawl:
@@ -200,7 +197,6 @@ class Crawler:
if len(tasks) >= 250:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
tasks = set(filter(lambda t: not t.done(), tasks))
if len(tasks) > 0:
await asyncio.wait(tasks, timeout=30)
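
For context, the loop above caps in-flight connection tasks at 250 and refills the set as tasks finish. A standalone sketch of the same throttling pattern (not part of the commit; names are illustrative):

    import asyncio

    async def run_bounded(coros, limit=250):
        tasks = set()
        for coro in coros:
            tasks.add(asyncio.create_task(coro))
            if len(tasks) >= limit:
                # Wait until at least one task finishes, then keep only the pending ones.
                done, tasks = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        if tasks:
            await asyncio.wait(tasks, timeout=30)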
@@ -231,16 +227,15 @@ class Crawler:
tls_version="unknown",
)
new_peer_reliability = PeerReliability(response_peer.host)
if self.crawl_store is not None:
self.crawl_store.maybe_add_peer(new_peer, new_peer_reliability)
self.crawl_store.maybe_add_peer(new_peer, new_peer_reliability)
await self.crawl_store.update_best_timestamp(
response_peer.host,
self.best_timestamp_per_peer[response_peer.host],
)
for host, version in self.version_cache:
self.handshake_time[host] = int(time.time())
self.handshake_time[host] = uint64(time.time())
self.host_to_version[host] = version
await self.crawl_store.update_version(host, version, int(time.time()))
await self.crawl_store.update_version(host, version, uint64(time.time()))
to_remove = set()
now = int(time.time())
@@ -264,94 +259,46 @@ class Crawler:
self.versions = defaultdict(lambda: 0)
for host, version in self.host_to_version.items():
self.versions[version] += 1
self.version_cache = []
self.peers_retrieved = []
# clear caches
self.version_cache: List[Tuple[str, str]] = []
self.peers_retrieved = []
self.server.banned_peers = {}
self.with_peak = set()
if len(peers_to_crawl) == 0:
continue
# Try up to 5 times to write to the DB in case there is a lock that causes a timeout
for i in range(1, 5):
try:
await self.crawl_store.load_to_db()
await self.crawl_store.load_reliable_peers_to_db()
except Exception as e:
self.log.error(f"Exception while saving to DB: {e}.")
self.log.error("Waiting 5 seconds before retry...")
await asyncio.sleep(5)
continue
break
total_records = self.crawl_store.get_total_records()
ipv6_count = self.crawl_store.get_ipv6_peers()
self.log.error("***")
self.log.error("Finished batch:")
self.log.error(f"Total IPs stored in DB: {total_records}.")
self.log.error(f"Total IPV6 addresses stored in DB: {ipv6_count}")
self.log.error(f"Total connections attempted since crawler started: {total_nodes}.")
self.log.error(f"Total unique nodes attempted since crawler started: {len(tried_nodes)}.")
t_now = time.time()
t_delta = int(t_now - t_start)
if t_delta > 0:
self.log.error(f"Avg connections per second: {total_nodes // t_delta}.")
# Periodically print detailed stats.
reliable_peers = self.crawl_store.get_reliable_peers()
self.log.error(f"High quality reachable nodes, used by DNS introducer in replies: {reliable_peers}")
banned_peers = self.crawl_store.get_banned_peers()
ignored_peers = self.crawl_store.get_ignored_peers()
available_peers = len(self.host_to_version)
addresses_count = len(self.best_timestamp_per_peer)
total_records = self.crawl_store.get_total_records()
ipv6_addresses_count = 0
for host in self.best_timestamp_per_peer.keys():
try:
ipaddress.IPv6Address(host)
ipv6_addresses_count += 1
except ipaddress.AddressValueError:
continue
self.log.error(
"IPv4 addresses gossiped with timestamp in the last 5 days with respond_peers messages: "
f"{addresses_count - ipv6_addresses_count}."
)
self.log.error(
"IPv6 addresses gossiped with timestamp in the last 5 days with respond_peers messages: "
f"{ipv6_addresses_count}."
)
ipv6_available_peers = 0
for host in self.host_to_version.keys():
try:
ipaddress.IPv6Address(host)
ipv6_available_peers += 1
except ipaddress.AddressValueError:
continue
self.log.error(
f"Total IPv4 nodes reachable in the last 5 days: {available_peers - ipv6_available_peers}."
)
self.log.error(f"Total IPv6 nodes reachable in the last 5 days: {ipv6_available_peers}.")
self.log.error("Version distribution among reachable in the last 5 days (at least 100 nodes):")
for version, count in sorted(self.versions.items(), key=lambda kv: kv[1], reverse=True):
if count >= self.minimum_version_count:
self.log.error(f"Version: {version} - Count: {count}")
self.log.error(f"Banned addresses in the DB: {banned_peers}")
self.log.error(f"Temporary ignored addresses in the DB: {ignored_peers}")
self.log.error(
"Peers to crawl from in the next batch (total IPs - ignored - banned): "
f"{total_records - banned_peers - ignored_peers}"
)
self.log.error("***")
await self.save_to_db()
await self.print_summary(t_start, total_nodes, tried_nodes)
await asyncio.sleep(15) # 15 seconds between db updates
self._state_changed("crawl_batch_completed")
except Exception as e:
self.log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
def set_server(self, server: ChiaServer):
async def save_to_db(self) -> None:
# Try up to 5 times to write to the DB in case there is a lock that causes a timeout
if self.crawl_store is None:
raise ValueError("Not Connected to DB")
for i in range(1, 5):
try:
await self.crawl_store.load_to_db()
await self.crawl_store.load_reliable_peers_to_db()
return
except Exception as e:
self.log.error(f"Exception while saving to DB: {e}.")
self.log.error("Waiting 5 seconds before retry...")
await asyncio.sleep(5)
continue
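
For context, save_to_db retries the write a fixed number of times (note that range(1, 5) yields four attempts, not five). The same bounded-retry idea as a generic helper (not part of the commit; names are illustrative):

    import asyncio

    async def retry(op, attempts=5, delay=5.0):
        # Run an async operation, retrying on any exception with a fixed delay.
        for attempt in range(1, attempts + 1):
            try:
                await op()
                return
            except Exception:
                if attempt == attempts:
                    raise
                await asyncio.sleep(delay)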
def set_server(self, server: ChiaServer) -> None:
self._server = server
def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None):
def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None) -> None:
if self.state_changed_callback is not None:
self.state_changed_callback(change, change_data)
async def new_peak(self, request: full_node_protocol.NewPeak, peer: WSChiaConnection):
async def new_peak(self, request: full_node_protocol.NewPeak, peer: WSChiaConnection) -> None:
try:
peer_info = peer.get_peer_info()
tls_version = peer.get_tls_version()
@@ -359,6 +306,11 @@ class Crawler:
tls_version = "unknown"
if peer_info is None:
return
# validate peer ip address:
try:
ipaddress.ip_address(peer_info.host)
except ValueError:
raise ValueError(f"Invalid peer ip address: {peer_info.host}")
if request.height >= self.minimum_height:
if self.crawl_store is not None:
await self.crawl_store.peer_connected_hostname(peer_info.host, True, tls_version)
@@ -366,12 +318,79 @@ class Crawler:
except Exception as e:
self.log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
async def on_connect(self, connection: WSChiaConnection):
async def on_connect(self, connection: WSChiaConnection) -> None:
pass
def _close(self):
def _close(self) -> None:
self._shut_down = True
async def _await_closed(self):
if self.connection is not None:
await self.connection.close()
async def _await_closed(self) -> None:
if self.crawl_task is not None:
try:
await asyncio.wait_for(self.crawl_task, timeout=10) # wait 10 seconds before giving up
except asyncio.TimeoutError:
self.log.error("Crawl task did not exit in time, killing task.")
self.crawl_task.cancel()
if self.crawl_store is not None:
self.log.info("Closing connection to DB.")
await self.crawl_store.crawl_db.close()
async def print_summary(self, t_start: float, total_nodes: int, tried_nodes: Set[str]) -> None:
assert self.crawl_store is not None # this is only ever called from the crawl task
if not self.print_status:
return
total_records = self.crawl_store.get_total_records()
ipv6_count = self.crawl_store.get_ipv6_peers()
self.log.warning("***")
self.log.warning("Finished batch:")
self.log.warning(f"Total IPs stored in DB: {total_records}.")
self.log.warning(f"Total IPV6 addresses stored in DB: {ipv6_count}")
self.log.warning(f"Total connections attempted since crawler started: {total_nodes}.")
self.log.warning(f"Total unique nodes attempted since crawler started: {len(tried_nodes)}.")
t_now = time.time()
t_delta = int(t_now - t_start)
if t_delta > 0:
self.log.warning(f"Avg connections per second: {total_nodes // t_delta}.")
# Periodically print detailed stats.
reliable_peers = self.crawl_store.get_reliable_peers()
self.log.warning(f"High quality reachable nodes, used by DNS introducer in replies: {reliable_peers}")
banned_peers = self.crawl_store.get_banned_peers()
ignored_peers = self.crawl_store.get_ignored_peers()
available_peers = len(self.host_to_version)
addresses_count = len(self.best_timestamp_per_peer)
total_records = self.crawl_store.get_total_records()
ipv6_addresses_count = 0
for host in self.best_timestamp_per_peer.keys():
try:
ipaddress.IPv6Address(host)
ipv6_addresses_count += 1
except ipaddress.AddressValueError:
continue
self.log.warning(
"IPv4 addresses gossiped with timestamp in the last 5 days with respond_peers messages: "
f"{addresses_count - ipv6_addresses_count}."
)
self.log.warning(
"IPv6 addresses gossiped with timestamp in the last 5 days with respond_peers messages: "
f"{ipv6_addresses_count}."
)
ipv6_available_peers = 0
for host in self.host_to_version.keys():
try:
ipaddress.IPv6Address(host)
ipv6_available_peers += 1
except ipaddress.AddressValueError:
continue
self.log.warning(f"Total IPv4 nodes reachable in the last 5 days: {available_peers - ipv6_available_peers}.")
self.log.warning(f"Total IPv6 nodes reachable in the last 5 days: {ipv6_available_peers}.")
self.log.warning("Version distribution among reachable in the last 5 days (at least 100 nodes):")
for version, count in sorted(self.versions.items(), key=lambda kv: kv[1], reverse=True):
if count >= self.minimum_version_count:
self.log.warning(f"Version: {version} - Count: {count}")
self.log.warning(f"Banned addresses in the DB: {banned_peers}")
self.log.warning(f"Temporary ignored addresses in the DB: {ignored_peers}")
self.log.warning(
"Peers to crawl from in the next batch (total IPs - ignored - banned): "
f"{total_records - banned_peers - ignored_peers}"
)
self.log.warning("***")

View File: chia/seeder/dns_server.py

@@ -1,294 +1,534 @@
from __future__ import annotations
import asyncio
import ipaddress
import logging
import random
import signal
import sys
import traceback
from asyncio import AbstractEventLoop
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from ipaddress import IPv4Address, IPv6Address, ip_address
from multiprocessing import freeze_support
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional
import aiosqlite
from dnslib import AAAA, CNAME, MX, NS, QTYPE, RR, SOA, A, DNSHeader, DNSRecord
from dnslib import AAAA, EDNS0, NS, QTYPE, RCODE, RD, RR, SOA, A, DNSError, DNSHeader, DNSQuestion, DNSRecord
from chia.util.chia_logging import initialize_logging
from chia.util.config import load_config
from chia.seeder.crawl_store import CrawlStore
from chia.util.chia_logging import initialize_service_logging
from chia.util.config import load_config, load_config_cli
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.path import path_from_root
SERVICE_NAME = "seeder"
log = logging.getLogger(__name__)
DnsCallback = Callable[[DNSRecord], Awaitable[DNSRecord]]
# DNS snippet taken from: https://gist.github.com/pklaus/b5a7876d4d2cf7271873
class DomainName(str):
def __getattr__(self, item):
return DomainName(item + "." + self)
def __getattr__(self, item: str) -> DomainName:
return DomainName(item + "." + self) # DomainName.NS becomes DomainName("NS.DomainName")
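
A quick usage sketch of the __getattr__ trick (not part of the commit; the domain is illustrative):

    class DomainName(str):
        def __getattr__(self, item: str) -> "DomainName":
            return DomainName(item + "." + self)

    d = DomainName("seeder.example.com.")
    print(d.ns1)  # ns1.seeder.example.com.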
D = None
ns = None
IP = "127.0.0.1"
TTL = None
soa_record = None
ns_records: List[Any] = []
@dataclass(frozen=True)
class PeerList:
ipv4: List[IPv4Address]
ipv6: List[IPv6Address]
@property
def no_peers(self) -> bool:
return not self.ipv4 and not self.ipv6
class EchoServerProtocol(asyncio.DatagramProtocol):
def __init__(self, callback):
self.data_queue = asyncio.Queue()
self.callback = callback
asyncio.ensure_future(self.respond())
@dataclass
class UDPDNSServerProtocol(asyncio.DatagramProtocol):
"""
This is a really simple UDP Server, that converts all requests to DNSRecord objects and passes them to the callback.
"""
def connection_made(self, transport):
self.transport = transport
callback: DnsCallback
transport: Optional[asyncio.DatagramTransport] = field(init=False, default=None)
data_queue: asyncio.Queue[tuple[DNSRecord, tuple[str, int]]] = field(default_factory=asyncio.Queue)
queue_task: Optional[asyncio.Task[None]] = field(init=False, default=None)
def datagram_received(self, data, addr):
asyncio.ensure_future(self.handler(data, addr))
def start(self) -> None:
self.queue_task = asyncio.create_task(self.respond()) # This starts the dns respond loop.
async def respond(self):
while True:
async def stop(self) -> None:
if self.queue_task is not None:
self.queue_task.cancel()
try:
resp, caller = await self.data_queue.get()
self.transport.sendto(resp, caller)
await self.queue_task
except asyncio.CancelledError: # we dont care
pass
if self.transport is not None:
self.transport.close()
def connection_made(self, transport: asyncio.BaseTransport) -> None:
# we use the #ignore because transport is a subclass of BaseTransport, but we need the real type.
self.transport = transport # type: ignore[assignment]
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
log.debug(f"Received UDP DNS request from {addr}.")
dns_request: Optional[DNSRecord] = parse_dns_request(data)
if dns_request is None: # Invalid Request, we can just drop it and move on.
return
asyncio.create_task(self.handler(dns_request, addr))
async def respond(self) -> None:
log.info("UDP DNS responder started.")
while self.transport is None: # we wait for the transport to be set.
await asyncio.sleep(0.1)
while not self.transport.is_closing():
try:
edns_max_size = 0
reply, caller = await self.data_queue.get()
if len(reply.ar) > 0 and reply.ar[0].rtype == QTYPE.OPT:
edns_max_size = reply.ar[0].edns_len
reply_packed = reply.pack()
if len(reply_packed) > max(512, edns_max_size): # 512 is the default max size for DNS:
log.debug(f"DNS response to {caller} is too large, truncating.")
reply_packed = reply.truncate().pack()
self.transport.sendto(reply_packed, caller)
log.debug(f"Sent UDP DNS response to {caller}, of size {len(reply_packed)}.")
except Exception as e:
log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
log.error(f"Exception while responding to UDP DNS request: {e}. Traceback: {traceback.format_exc()}.")
log.info("UDP DNS responder stopped.")
async def handler(self, data, caller):
async def handler(self, data: DNSRecord, caller: tuple[str, int]) -> None:
r_data = await get_dns_reply(self.callback, data) # process the request, returning a DNSRecord response.
await self.data_queue.put((r_data, caller))
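
For reference, the size cap applied in respond() reflects the DNS rule that plain UDP replies max out at 512 bytes unless the client advertised a larger EDNS buffer (RFC 6891). A minimal sketch of the rule (not part of the commit; the helper name is hypothetical):

    def max_udp_reply_size(edns_max_size: int) -> int:
        # 512 bytes is the classic DNS-over-UDP limit; EDNS raises it.
        return max(512, edns_max_size)

    assert max_udp_reply_size(0) == 512      # no OPT record in the request
    assert max_udp_reply_size(4096) == 4096  # client advertised a 4096-byte buffer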
@dataclass
class TCPDNSServerProtocol(asyncio.BufferedProtocol):
"""
This TCP server is a little more complicated, because we need to handle the length field, however it still
converts all requests to DNSRecord objects and passes them to the callback, after receiving the full message.
"""
callback: DnsCallback
transport: Optional[asyncio.Transport] = field(init=False, default=None)
peer_info: str = field(init=False, default="")
expected_length: int = 0
buffer: bytearray = field(init=False, default_factory=lambda: bytearray(2))
futures: List[asyncio.Future[None]] = field(init=False, default_factory=list)
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""
This is called whenever we get a new connection.
"""
# we use the #ignore because transport is a subclass of BaseTransport, but we need the real type.
self.transport = transport # type: ignore[assignment]
peer_info = transport.get_extra_info("peername")
self.peer_info = f"{peer_info[0]}:{peer_info[1]}"
log.debug(f"TCP connection established with {self.peer_info}.")
def connection_lost(self, exc: Optional[Exception]) -> None:
"""
This is called whenever a connection is lost, or closed.
"""
if exc is not None:
log.debug(f"TCP DNS connection lost with {self.peer_info}. Exception: {exc}.")
else:
log.debug(f"TCP DNS connection closed with {self.peer_info}.")
# reset the state of the protocol.
for future in self.futures:
future.cancel()
self.futures = []
self.buffer = bytearray(2)
self.expected_length = 0
def get_buffer(self, sizehint: int) -> memoryview:
"""
This is the first function called after connection_made, it returns a buffer that the tcp server will write to.
Once a buffer is written to, buffer_updated is called.
"""
return memoryview(self.buffer)
def buffer_updated(self, nbytes: int) -> None:
"""
This is called whenever the buffer is written to, and it loops through the buffer, grouping them into messages
and then dns records.
"""
while not len(self.buffer) == 0 and self.transport is not None:
if not self.expected_length:
# Length field received (This is the first part of the message)
self.expected_length = int.from_bytes(self.buffer, byteorder="big")
self.buffer = self.buffer[2:] # Remove the length field from the buffer.
if len(self.buffer) >= self.expected_length:
# This is the rest of the message (after the length field)
message = self.buffer[: self.expected_length]
self.buffer = self.buffer[self.expected_length :] # Remove the message from the buffer
self.expected_length = 0 # Reset the expected length
dns_request: Optional[DNSRecord] = parse_dns_request(message)
if dns_request is None: # Invalid Request, so we disconnect and don't send anything back.
self.transport.close()
return
self.futures.append(asyncio.create_task(self.handle_and_respond(dns_request)))
self.buffer = bytearray(2 if self.expected_length == 0 else self.expected_length) # Reset the buffer if empty.
def eof_received(self) -> Optional[bool]:
"""
This is called when the client closes the connection, False or None means we close the connection.
True means we keep the connection open.
"""
if len(self.futures) > 0: # Successful requests
if self.expected_length != 0: # Incomplete requests
log.warning(
f"Received incomplete TCP DNS request of length {self.expected_length} from {self.peer_info}, "
f"closing connection after dns replies are sent."
)
asyncio.create_task(self.wait_for_futures())
return True # Keep connection open, until the futures are done.
log.info(f"Received early EOF from {self.peer_info}, closing connection.")
return False
async def wait_for_futures(self) -> None:
"""
Waits for all the futures to complete, and then closes the connection.
"""
try:
data = await self.callback(data)
if data is None:
return
await self.data_queue.put((data, caller))
await asyncio.wait_for(asyncio.gather(*self.futures), timeout=10)
except asyncio.TimeoutError:
log.warning(f"Timed out waiting for DNS replies to be sent to {self.peer_info}.")
if self.transport is not None:
self.transport.close()
async def handle_and_respond(self, data: DNSRecord) -> None:
r_data = await get_dns_reply(self.callback, data) # process the request, returning a DNSRecord response.
try:
# If the client closed the connection, we don't want to send anything.
if self.transport is not None and not self.transport.is_closing():
self.transport.write(dns_response_to_tcp(r_data)) # send data back to the client
log.debug(f"Sent DNS response for {data.q.qname}, to {self.peer_info}.")
except Exception as e:
log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
log.error(f"Exception while responding to TCP DNS request: {e}. Traceback: {traceback.format_exc()}.")
def dns_response_to_tcp(data: DNSRecord) -> bytes:
"""
Converts a DNSRecord response to a TCP DNS response, by adding a 2 byte length field to the start.
"""
dns_response = data.pack()
dns_response_length = len(dns_response).to_bytes(2, byteorder="big")
return bytes(dns_response_length + dns_response)
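
dns_response_to_tcp adds the two-byte big-endian length prefix that DNS over TCP requires (RFC 1035 section 4.2.2); buffer_updated strips it off on the receive side. A self-contained sketch of splitting a TCP byte stream back into DNS messages (not part of the commit; the helper name is hypothetical):

    def split_tcp_dns_stream(stream: bytes) -> list:
        messages = []
        while len(stream) >= 2:
            expected_length = int.from_bytes(stream[:2], byteorder="big")
            if len(stream) - 2 < expected_length:
                break  # incomplete message; wait for more bytes
            messages.append(stream[2 : 2 + expected_length])
            stream = stream[2 + expected_length:]
        return messages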
def create_dns_reply(dns_request: DNSRecord) -> DNSRecord:
"""
Creates a DNS response with the correct header and section flags set.
"""
# QR means query response, AA means authoritative answer, RA means recursion available
return DNSRecord(DNSHeader(id=dns_request.header.id, qr=1, aa=1, ra=0), q=dns_request.q)
def parse_dns_request(data: bytes) -> Optional[DNSRecord]:
"""
Parses the DNS request, and returns a DNSRecord object, or None if the request is invalid.
"""
dns_request: Optional[DNSRecord] = None
try:
dns_request = DNSRecord.parse(data)
except DNSError as e:
log.warning(f"Received invalid DNS request: {e}. Traceback: {traceback.format_exc()}.")
return dns_request
async def get_dns_reply(callback: DnsCallback, dns_request: DNSRecord) -> DNSRecord:
"""
This function calls the callback, and returns SERVFAIL if the callback raises an exception.
"""
try:
dns_reply = await callback(dns_request)
except Exception as e:
log.error(f"Exception during DNS record processing: {e}. Traceback: {traceback.format_exc()}.")
# we return an empty response with an error code
dns_reply = create_dns_reply(dns_request) # This is an empty response, with only the header set.
dns_reply.header.rcode = RCODE.SERVFAIL
return dns_reply
@dataclass
class DNSServer:
reliable_peers_v4: List[str]
reliable_peers_v6: List[str]
lock: asyncio.Lock
pointer: int
crawl_db: aiosqlite.Connection
config: Dict[str, Any]
root_path: Path
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
shutdown_event: asyncio.Event = field(default_factory=asyncio.Event)
crawl_store: Optional[CrawlStore] = field(init=False, default=None)
reliable_task: Optional[asyncio.Task[None]] = field(init=False, default=None)
shutdown_task: Optional[asyncio.Task[None]] = field(init=False, default=None)
udp_transport_ipv4: Optional[asyncio.DatagramTransport] = field(init=False, default=None)
udp_protocol_ipv4: Optional[UDPDNSServerProtocol] = field(init=False, default=None)
udp_transport_ipv6: Optional[asyncio.DatagramTransport] = field(init=False, default=None)
udp_protocol_ipv6: Optional[UDPDNSServerProtocol] = field(init=False, default=None)
# TODO: After 3.10 is dropped change to asyncio.Server
tcp_server: Optional[asyncio.base_events.Server] = field(init=False, default=None)
# these are all set in __post_init__
dns_port: int = field(init=False)
db_path: Path = field(init=False)
domain: DomainName = field(init=False)
ns1: DomainName = field(init=False)
ns_records: List[RR] = field(init=False)
ttl: int = field(init=False)
soa_record: RR = field(init=False)
reliable_peers_v4: List[IPv4Address] = field(default_factory=list)
reliable_peers_v6: List[IPv6Address] = field(default_factory=list)
pointer_v4: int = 0
pointer_v6: int = 0
def __init__(self, config: Dict, root_path: Path):
self.reliable_peers_v4 = []
self.reliable_peers_v6 = []
self.lock = asyncio.Lock()
self.pointer_v4 = 0
self.pointer_v6 = 0
crawler_db_path: str = config.get("crawler_db_path", "crawler.db")
self.db_path = path_from_root(root_path, crawler_db_path)
def __post_init__(self) -> None:
"""
We initialize all the variables set to field(init=False) here.
"""
# From Config
self.dns_port: int = self.config.get("dns_port", 53)
# DB Path
crawler_db_path: str = self.config.get("crawler_db_path", "crawler.db")
self.db_path: Path = path_from_root(self.root_path, crawler_db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
async def start(self):
# self.crawl_db = await aiosqlite.connect(self.db_path)
# Get a reference to the event loop as we plan to use
# low-level APIs.
loop = asyncio.get_running_loop()
# One protocol instance will be created to serve all
# client requests.
self.transport, self.protocol = await loop.create_datagram_endpoint(
lambda: EchoServerProtocol(self.dns_response), local_addr=("::0", 53)
# DNS info
self.domain: DomainName = DomainName(self.config["domain_name"])
if not self.domain.endswith("."):
self.domain = DomainName(self.domain + ".") # Make sure the domain ends with a period, as per RFC 1035.
self.ns1: DomainName = DomainName(self.config["nameserver"])
self.ns_records: List[NS] = [NS(self.ns1)]
self.ttl: int = self.config["ttl"]
self.soa_record: SOA = SOA(
mname=self.ns1, # primary name server
rname=self.config["soa"]["rname"], # email of the domain administrator
times=(
self.config["soa"]["serial_number"],
self.config["soa"]["refresh"],
self.config["soa"]["retry"],
self.config["soa"]["expire"],
self.config["soa"]["minimum"],
),
)
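
DNSServer illustrates the dataclass pattern used across this PR: fields marked field(init=False) are derived from config in __post_init__ rather than passed by the caller. A minimal sketch of the pattern (not part of the commit):

    from dataclasses import dataclass, field
    from typing import Any, Dict

    @dataclass
    class Example:
        config: Dict[str, Any]
        port: int = field(init=False)

        def __post_init__(self) -> None:
            self.port = self.config.get("dns_port", 53)

    print(Example({"dns_port": 5353}).port)  # 5353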
@asynccontextmanager
async def run(self) -> AsyncIterator[None]:
log.info("Starting DNS server.")
# Get a reference to the event loop as we plan to use low-level APIs.
loop = asyncio.get_running_loop()
await self.setup_signal_handlers(loop)
# Set up the crawl store and the peer update task.
self.crawl_store = await CrawlStore.create(await aiosqlite.connect(self.db_path, timeout=120))
self.reliable_task = asyncio.create_task(self.periodically_get_reliable_peers())
async def periodically_get_reliable_peers(self):
# One protocol instance will be created for each udp transport, so that we can accept ipv4 and ipv6
self.udp_transport_ipv6, self.udp_protocol_ipv6 = await loop.create_datagram_endpoint(
lambda: UDPDNSServerProtocol(self.dns_response), local_addr=("::0", self.dns_port)
)
self.udp_protocol_ipv6.start() # start ipv6 udp transmit task
# in case the port is 0 we need all protocols on the same port.
self.dns_port = self.udp_transport_ipv6.get_extra_info("sockname")[1] # get the actual port we are listening to
if sys.platform.startswith("win32") or sys.platform.startswith("cygwin"):
# Windows does not support dual stack sockets, so we need to create a new socket for ipv4.
self.udp_transport_ipv4, self.udp_protocol_ipv4 = await loop.create_datagram_endpoint(
lambda: UDPDNSServerProtocol(self.dns_response), local_addr=("0.0.0.0", self.dns_port)
)
self.udp_protocol_ipv4.start() # start ipv4 udp transmit task
# One tcp server will handle both ipv4 and ipv6 on both linux and windows.
self.tcp_server = await loop.create_server(
lambda: TCPDNSServerProtocol(self.dns_response), ["::0", "0.0.0.0"], self.dns_port
)
log.info("DNS server started.")
try:
yield
finally: # catches any errors and properly shuts down the server
await self.stop()
log.info("DNS server stopped.")
async def setup_signal_handlers(self, loop: AbstractEventLoop) -> None:
try:
loop.add_signal_handler(signal.SIGINT, self._accept_signal)
loop.add_signal_handler(signal.SIGTERM, self._accept_signal)
except NotImplementedError:
log.warning("signal handlers unsupported on this platform")
def _accept_signal(self) -> None: # pragma: no cover
if self.shutdown_task is None: # otherwise we are already shutting down, so we ignore the signal
self.shutdown_task = asyncio.create_task(self.stop())
async def stop(self) -> None:
log.info("Stopping DNS server...")
if self.reliable_task is not None:
self.reliable_task.cancel() # cancel the peer update task
if self.crawl_store is not None:
await self.crawl_store.crawl_db.close()
if self.udp_protocol_ipv6 is not None:
await self.udp_protocol_ipv6.stop() # stop responding to and accepting udp requests (ipv6) & ipv4 if linux.
if self.udp_protocol_ipv4 is not None:
await self.udp_protocol_ipv4.stop() # stop responding to and accepting udp requests (ipv4) if windows.
if self.tcp_server is not None:
self.tcp_server.close() # stop accepting new tcp requests (ipv4 and ipv6)
await self.tcp_server.wait_closed() # wait for existing TCP requests to finish (ipv4 and ipv6)
self.shutdown_event.set()
async def periodically_get_reliable_peers(self) -> None:
sleep_interval = 0
while True:
while not self.shutdown_event.is_set() and self.crawl_store is not None:
try:
# TODO: double check this. It shouldn't take this long to connect.
crawl_db = await aiosqlite.connect(self.db_path, timeout=600)
cursor = await crawl_db.execute(
"SELECT * from good_peers",
)
new_reliable_peers = []
rows = await cursor.fetchall()
await cursor.close()
await crawl_db.close()
for row in rows:
new_reliable_peers.append(row[0])
if len(new_reliable_peers) > 0:
random.shuffle(new_reliable_peers)
async with self.lock:
self.reliable_peers_v4 = []
self.reliable_peers_v6 = []
for peer in new_reliable_peers:
ipv4 = True
try:
_ = ipaddress.IPv4Address(peer)
except ValueError:
ipv4 = False
if ipv4:
self.reliable_peers_v4.append(peer)
else:
try:
_ = ipaddress.IPv6Address(peer)
except ValueError:
continue
self.reliable_peers_v6.append(peer)
self.pointer_v4 = 0
self.pointer_v6 = 0
log.error(
new_reliable_peers = await self.crawl_store.get_good_peers()
except Exception as e:
log.error(f"Error loading reliable peers from database: {e}. Traceback: {traceback.format_exc()}.")
continue
if len(new_reliable_peers) == 0:
log.info("No reliable peers found in database, waiting for db to be populated.")
await asyncio.sleep(2) # sleep for 2 seconds, because the db has not been populated yet.
continue
async with self.lock:
self.reliable_peers_v4 = []
self.reliable_peers_v6 = []
self.pointer_v4 = 0
self.pointer_v6 = 0
for peer in new_reliable_peers:
try:
validated_peer = ip_address(peer)
if validated_peer.version == 4:
self.reliable_peers_v4.append(validated_peer)
elif validated_peer.version == 6:
self.reliable_peers_v6.append(validated_peer)
except ValueError:
log.error(f"Invalid peer: {peer}")
continue
log.info(
f"Number of reliable peers discovered in dns server:"
f" IPv4 count - {len(self.reliable_peers_v4)}"
f" IPv6 count - {len(self.reliable_peers_v6)}"
)
except Exception as e:
log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
sleep_interval = min(15, sleep_interval + 1)
await asyncio.sleep(sleep_interval * 60)
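
ip_address() returns either an IPv4Address or an IPv6Address, so splitting peers by family reduces to checking .version, as the loop above does. A tiny self-contained sketch (not part of the commit):

    from ipaddress import ip_address

    v4, v6 = [], []
    for peer in ["203.0.113.7", "2001:db8::1", "not-an-ip"]:
        try:
            addr = ip_address(peer)
        except ValueError:
            continue  # skip invalid entries, as the server does
        (v4 if addr.version == 4 else v6).append(addr)
    print(len(v4), len(v6))  # 1 1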
async def get_peers_to_respond(self, ipv4_count, ipv6_count):
peers = []
async def get_peers_to_respond(self, ipv4_count: int, ipv6_count: int) -> PeerList:
async with self.lock:
# Append IPv4.
ipv4_peers: List[IPv4Address] = []
size = len(self.reliable_peers_v4)
if ipv4_count > 0 and size <= ipv4_count:
peers = self.reliable_peers_v4
ipv4_peers = self.reliable_peers_v4
elif ipv4_count > 0:
peers = [self.reliable_peers_v4[i % size] for i in range(self.pointer_v4, self.pointer_v4 + ipv4_count)]
self.pointer_v4 = (self.pointer_v4 + ipv4_count) % size
ipv4_peers = [
self.reliable_peers_v4[i % size] for i in range(self.pointer_v4, self.pointer_v4 + ipv4_count)
]
self.pointer_v4 = (self.pointer_v4 + ipv4_count) % size # mark where we left off
# Append IPv6.
ipv6_peers: List[IPv6Address] = []
size = len(self.reliable_peers_v6)
if ipv6_count > 0 and size <= ipv6_count:
peers = peers + self.reliable_peers_v6
ipv6_peers = self.reliable_peers_v6
elif ipv6_count > 0:
peers = peers + [
ipv6_peers = [
self.reliable_peers_v6[i % size] for i in range(self.pointer_v6, self.pointer_v6 + ipv6_count)
]
self.pointer_v6 = (self.pointer_v6 + ipv6_count) % size
return peers
self.pointer_v6 = (self.pointer_v6 + ipv6_count) % size # mark where we left off
return PeerList(ipv4_peers, ipv6_peers)
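
get_peers_to_respond walks each peer list with a wrapping pointer so successive queries see a rotating window of peers. The same rotation in isolation (not part of the commit; the helper name is hypothetical):

    def rotate(items, pointer, count):
        size = len(items)
        if count <= 0 or size == 0:
            return [], pointer
        if size <= count:
            return items, pointer  # fewer items than requested: return them all
        selected = [items[i % size] for i in range(pointer, pointer + count)]
        return selected, (pointer + count) % size  # mark where we left off

    peers, ptr = rotate(["a", "b", "c"], pointer=2, count=2)
    print(peers, ptr)  # ['c', 'a'] 1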
async def dns_response(self, data):
try:
request = DNSRecord.parse(data)
IPs = [MX(D.mail), soa_record] + ns_records
ipv4_count = 0
ipv6_count = 0
if request.q.qtype == 1:
ipv4_count = 32
elif request.q.qtype == 28:
ipv6_count = 32
elif request.q.qtype == 255:
ipv4_count = 16
ipv6_count = 16
else:
ipv4_count = 32
peers = await self.get_peers_to_respond(ipv4_count, ipv6_count)
if len(peers) == 0:
return None
for peer in peers:
ipv4 = True
try:
_ = ipaddress.IPv4Address(peer)
except ValueError:
ipv4 = False
if ipv4:
IPs.append(A(peer))
else:
try:
_ = ipaddress.IPv6Address(peer)
except ValueError:
continue
IPs.append(AAAA(peer))
reply = DNSRecord(DNSHeader(id=request.header.id, qr=1, aa=len(IPs), ra=1), q=request.q)
async def dns_response(self, request: DNSRecord) -> DNSRecord:
"""
This function is called when a DNS request is received, and it returns a DNS response.
It does not catch any errors as it is called from within a try-except block.
"""
reply = create_dns_reply(request)
dns_question: DNSQuestion = request.q # this is the question / request
question_type: int = dns_question.qtype # the type of the record being requested
qname = dns_question.qname # the name being queried / requested
# ADD EDNS0 to response if supported
if len(request.ar) > 0 and request.ar[0].rtype == QTYPE.OPT: # OPT Means EDNS
udp_len = min(4096, request.ar[0].edns_len)
edns_reply = EDNS0(udp_len=udp_len)
reply.add_ar(edns_reply)
# DNS labels are mixed case with DNS resolvers that implement the use of bit 0x20 to improve
# transaction identity. See https://datatracker.ietf.org/doc/html/draft-vixie-dnsext-dns0x20-00
qname_str = str(qname).lower()
if qname_str != self.domain and not qname_str.endswith("." + self.domain):
# we don't answer for other domains (we have the not recursive bit set)
log.warning(f"Invalid request for {qname_str}, returning REFUSED.")
reply.header.rcode = RCODE.REFUSED
return reply
records = {
D: IPs,
D.ns1: [A(IP)], # MX and NS records must never point to a CNAME alias (RFC 2181 section 10.3)
D.ns2: [A(IP)],
D.mail: [A(IP)],
D.andrei: [CNAME(D)],
}
ttl: int = self.ttl
# we add these to the list as it will allow us to respond to ns and soa requests
ips: List[RD] = [self.soa_record] + self.ns_records
ipv4_count = 0
ipv6_count = 0
if question_type is QTYPE.A:
ipv4_count = 32
elif question_type is QTYPE.AAAA:
ipv6_count = 32
elif question_type is QTYPE.ANY:
ipv4_count = 16
ipv6_count = 16
else:
ipv4_count = 32
peers: PeerList = await self.get_peers_to_respond(ipv4_count, ipv6_count)
if peers.no_peers:
log.error("No peers found, returning SOA and NS records only.")
ttl = 60 # 1 minute as we should have some peers very soon
# we always return the SOA and NS records, so we continue even if there are no peers
ips.extend([A(str(peer)) for peer in peers.ipv4])
ips.extend([AAAA(str(peer)) for peer in peers.ipv6])
qname = request.q.qname
# DNS labels are mixed case with DNS resolvers that implement the use of bit 0x20 to improve
# transaction identity. See https://datatracker.ietf.org/doc/html/draft-vixie-dnsext-dns0x20-00
qn = str(qname).lower()
qtype = request.q.qtype
qt = QTYPE[qtype]
if qn == D or qn.endswith("." + D):
for name, rrs in records.items():
if name == qn:
for rdata in rrs:
rqt = rdata.__class__.__name__
if qt in ["*", rqt] or (qt == "ANY" and (rqt == "A" or rqt == "AAAA")):
reply.add_answer(
RR(rname=qname, rtype=getattr(QTYPE, rqt), rclass=1, ttl=TTL, rdata=rdata)
)
records: Dict[DomainName, List[RD]] = { # this is where we can add other records we want to serve
self.domain: ips,
}
for rdata in ns_records:
reply.add_ar(RR(rname=D, rtype=QTYPE.NS, rclass=1, ttl=TTL, rdata=rdata))
reply.add_auth(RR(rname=D, rtype=QTYPE.SOA, rclass=1, ttl=TTL, rdata=soa_record))
return reply.pack()
except Exception as e:
log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
valid_domain = False
for domain_name, domain_responses in records.items():
if domain_name == qname_str: # if the dns name is the same as the requested name
valid_domain = True
for response in domain_responses:
rqt: int = getattr(QTYPE, response.__class__.__name__)
if question_type == rqt or question_type == QTYPE.ANY:
reply.add_answer(RR(rname=qname, rtype=rqt, rclass=1, ttl=ttl, rdata=response))
if not valid_domain and len(reply.rr) == 0: # if we didn't find any records to return
reply.header.rcode = RCODE.NXDOMAIN
# always put nameservers and the SOA records
for nameserver in self.ns_records:
reply.add_auth(RR(rname=self.domain, rtype=QTYPE.NS, rclass=1, ttl=ttl, rdata=nameserver))
reply.add_auth(RR(rname=self.domain, rtype=QTYPE.SOA, rclass=1, ttl=ttl, rdata=self.soa_record))
return reply
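
On the wire, the OPT pseudo-record that dns_response checks for arrives in the additional section of the query. A hedged sketch of building such a query with dnslib, mirroring the server-side check (not part of the commit; the domain is illustrative):

    from dnslib import DNSRecord, EDNS0, QTYPE

    query = DNSRecord.question("seeder.example.com.", "A")
    query.add_ar(EDNS0(udp_len=4096))      # advertise a 4096-byte UDP buffer
    assert query.ar[0].rtype == QTYPE.OPT  # the same check dns_response performs
    print(query.ar[0].edns_len)            # 4096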
async def serve_dns(config: Dict, root_path: Path):
dns_server = DNSServer(config, root_path)
await dns_server.start()
# TODO: Make this cleaner?
while True:
await asyncio.sleep(3600)
async def run_dns_server(dns_server: DNSServer) -> None: # pragma: no cover
async with dns_server.run():
await dns_server.shutdown_event.wait() # this is released on SIGINT or SIGTERM or any unhandled exception
async def kill_processes():
# TODO: implement.
pass
def create_dns_server_service(config: Dict[str, Any], root_path: Path) -> DNSServer:
service_config = config[SERVICE_NAME]
return DNSServer(service_config, root_path)
def signal_received():
asyncio.create_task(kill_processes())
async def async_main(config, root_path):
loop = asyncio.get_running_loop()
try:
loop.add_signal_handler(signal.SIGINT, signal_received)
loop.add_signal_handler(signal.SIGTERM, signal_received)
except NotImplementedError:
log.info("signal handlers unsupported")
await serve_dns(config, root_path)
def main():
def main() -> None: # pragma: no cover
freeze_support()
root_path = DEFAULT_ROOT_PATH
config = load_config(root_path, "config.yaml", SERVICE_NAME)
initialize_logging(SERVICE_NAME, config["logging"], root_path)
global D
global ns
global TTL
global soa_record
global ns_records
D = DomainName(config["domain_name"])
ns = DomainName(config["nameserver"])
TTL = config["ttl"]
soa_record = SOA(
mname=ns, # primary name server
rname=config["soa"]["rname"], # email of the domain administrator
times=(
config["soa"]["serial_number"],
config["soa"]["refresh"],
config["soa"]["retry"],
config["soa"]["expire"],
config["soa"]["minimum"],
),
)
ns_records = [NS(ns)]
# TODO: refactor to avoid the double load
config = load_config(DEFAULT_ROOT_PATH, "config.yaml")
service_config = load_config_cli(DEFAULT_ROOT_PATH, "config.yaml", SERVICE_NAME)
config[SERVICE_NAME] = service_config
initialize_service_logging(service_name=SERVICE_NAME, config=config)
asyncio.run(async_main(config=config, root_path=root_path))
dns_server = create_dns_server_service(config, root_path)
asyncio.run(run_dns_server(dns_server))
if __name__ == "__main__":

View File: chia/seeder/peer_record.py

@@ -24,23 +24,19 @@ class PeerRecord(Streamable):
handshake_time: uint64
tls_version: str
def update_version(self, version, now):
def update_version(self, version: str, now: uint64) -> None:
if version != "undefined":
object.__setattr__(self, "version", version)
object.__setattr__(self, "handshake_time", uint64(now))
object.__setattr__(self, "handshake_time", now)
@dataclass
class PeerStat:
weight: float
count: float
reliability: float
def __init__(self, weight, count, reliability):
self.weight = weight
self.count = count
self.reliability = reliability
def update(self, is_reachable: bool, age: int, tau: int):
def update(self, is_reachable: bool, age: int, tau: int) -> None:
f = math.exp(-age / tau)
self.reliability = self.reliability * f + (1.0 - f if is_reachable else 0.0)
self.count = self.count * f + 1.0
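
PeerStat.update is an exponentially weighted moving average: each observation is blended in with weight 1 - f, where f = exp(-age / tau) decays the old state. A worked sketch with tau = 2 hours, matching stat_2h (not part of the commit):

    import math

    reliability = count = 0.0
    for age, reachable in [(3600, True), (3600, True), (3600, False)]:
        f = math.exp(-age / (2 * 3600))
        reliability = reliability * f + (1.0 - f if reachable else 0.0)
        count = count * f + 1.0
    print(round(reliability, 3), round(count, 3))  # 0.383 1.974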
@@ -81,7 +77,7 @@ class PeerReliability:
stat_1m_reliability: float = 0.0,
tries: int = 0,
successes: int = 0,
):
) -> None:
self.peer_id = peer_id
self.ignore_till = ignore_till
self.ban_till = ban_till
@@ -132,7 +128,7 @@ class PeerReliability:
return 2 * 3600
return 0
def update(self, is_reachable: bool, age: int):
def update(self, is_reachable: bool, age: int) -> None:
self.stat_2h.update(is_reachable, age, 2 * 3600)
self.stat_8h.update(is_reachable, age, 8 * 3600)
self.stat_1d.update(is_reachable, age, 24 * 3600)

View File: chia/seeder/start_crawler.py

@@ -4,7 +4,7 @@ import logging
import pathlib
import sys
from multiprocessing import freeze_support
from typing import Dict, Optional
from typing import Any, Dict, Optional
from chia.consensus.constants import ConsensusConstants
from chia.consensus.default_constants import DEFAULT_CONSTANTS
@@ -26,24 +26,25 @@ log = logging.getLogger(__name__)
def create_full_node_crawler_service(
root_path: pathlib.Path,
config: Dict,
config: Dict[str, Any],
consensus_constants: ConsensusConstants,
connect_to_daemon: bool = True,
) -> Service[Crawler, CrawlerAPI]:
service_config = config[SERVICE_NAME]
crawler_config = service_config["crawler"]
crawler = Crawler(
service_config,
root_path=root_path,
consensus_constants=consensus_constants,
constants=consensus_constants,
)
api = CrawlerAPI(crawler)
network_id = service_config["selected_network"]
rpc_info: Optional[RpcInfo] = None
if service_config.get("start_rpc_server", True):
rpc_info = (CrawlerRpcApi, service_config.get("rpc_port", 8561))
if crawler_config.get("start_rpc_server", True):
rpc_info = (CrawlerRpcApi, crawler_config.get("rpc_port", 8561))
return Service(
root_path=root_path,

View File

@@ -24,6 +24,7 @@ from chia.introducer.introducer_api import IntroducerAPI
from chia.protocols.shared_protocol import Capability, capabilities
from chia.seeder.crawler import Crawler
from chia.seeder.crawler_api import CrawlerAPI
from chia.seeder.dns_server import DNSServer, create_dns_server_service
from chia.seeder.start_crawler import create_full_node_crawler_service
from chia.server.start_farmer import create_farmer_service
from chia.server.start_full_node import create_full_node_service
@@ -32,15 +33,17 @@ from chia.server.start_introducer import create_introducer_service
from chia.server.start_service import Service
from chia.server.start_timelord import create_timelord_service
from chia.server.start_wallet import create_wallet_service
from chia.simulator.block_tools import BlockTools
from chia.simulator.block_tools import BlockTools, test_constants
from chia.simulator.keyring import TempKeyring
from chia.simulator.ssl_certs import get_next_nodes_certs_and_keys, get_next_private_ca_cert_and_key
from chia.simulator.start_simulator import create_full_node_simulator_service
from chia.ssl.create_ssl import create_all_ssl
from chia.timelord.timelord import Timelord
from chia.timelord.timelord_api import TimelordAPI
from chia.timelord.timelord_launcher import kill_processes, spawn_process
from chia.types.peer_info import UnresolvedPeerInfo
from chia.util.bech32m import encode_puzzle_hash
from chia.util.config import config_path_for_filename, lock_and_load_config, save_config
from chia.util.config import config_path_for_filename, load_config, lock_and_load_config, save_config
from chia.util.ints import uint16
from chia.util.keychain import bytes_to_mnemonic
from chia.util.lock import Lockfile
@@ -175,24 +178,36 @@ async def setup_full_node(
async def setup_crawler(
bt: BlockTools,
root_path_populated_with_config: Path, database_uri: str
) -> AsyncGenerator[Service[Crawler, CrawlerAPI], None]:
config = bt.config
create_all_ssl(
root_path=root_path_populated_with_config,
private_ca_crt_and_key=get_next_private_ca_cert_and_key().collateral.cert_and_key,
node_certs_and_keys=get_next_nodes_certs_and_keys().collateral.certs_and_keys,
)
config = load_config(root_path_populated_with_config, "config.yaml")
service_config = config["seeder"]
service_config["selected_network"] = "testnet0"
service_config["port"] = 0
service_config["start_rpc_server"] = False
service_config["crawler"]["start_rpc_server"] = False
service_config["other_peers_port"] = 58444
service_config["crawler_db_path"] = database_uri
overrides = service_config["network_overrides"]["constants"][service_config["selected_network"]]
updated_constants = bt.constants.replace_str_to_bytes(**overrides)
bt.change_config(config)
updated_constants = test_constants.replace_str_to_bytes(**overrides)
service = create_full_node_crawler_service(
bt.root_path,
root_path_populated_with_config,
config,
updated_constants,
connect_to_daemon=False,
)
await service.start()
if not service_config["crawler"]["start_rpc_server"]: # otherwise the loops don't work.
service._node.state_changed_callback = lambda x, y: None
try:
yield service
finally:
@@ -200,6 +215,24 @@ async def setup_crawler(
await service.wait_closed()
async def setup_seeder(root_path_populated_with_config: Path, database_uri: str) -> AsyncGenerator[DNSServer, None]:
config = load_config(root_path_populated_with_config, "config.yaml")
service_config = config["seeder"]
service_config["selected_network"] = "testnet0"
if service_config["domain_name"].endswith("."): # remove the trailing . so that we can test that logic.
service_config["domain_name"] = service_config["domain_name"][:-1]
service_config["dns_port"] = 0
service_config["crawler_db_path"] = database_uri
service = create_dns_server_service(
config,
root_path_populated_with_config,
)
async with service.run():
yield service
# Note: convert these setup functions to fixtures, or push it one layer up,
# keeping these usable independently?
async def setup_wallet_node(

View File

@@ -139,6 +139,8 @@ seeder:
port: 8444
# Most full nodes on the network run on this port. (i.e. 8444 for mainnet, 58444 for testnet).
other_peers_port: 8444
# What port to run the DNS server on, (this is useful if you are already using port 53 for DNS).
dns_port: 53
# This will override the default full_node.peer_connect_timeout for the crawler full node
peer_connect_timeout: 2
# Path to crawler DB. Defaults to $CHIA_ROOT/crawler.db
@@ -154,7 +156,7 @@ seeder:
nameserver: "example.com."
ttl: 300
soa:
rname: "hostmaster.example.com."
rname: "hostmaster.example.com" # all @ symbols need to be replaced with . in dns records.
serial_number: 1619105223
refresh: 10800
retry: 10800

View File

@@ -16,11 +16,6 @@ chia.rpc.rpc_client
chia.rpc.util
chia.rpc.wallet_rpc_api
chia.rpc.wallet_rpc_client
chia.seeder.crawl_store
chia.seeder.crawler
chia.seeder.dns_server
chia.seeder.peer_record
chia.seeder.start_crawler
chia.simulator.full_node_simulator
chia.simulator.keyring
chia.simulator.setup_services

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import datetime
import multiprocessing
import os
import random
import sysconfig
import tempfile
from enum import Enum
@@ -32,6 +33,7 @@ from chia.rpc.harvester_rpc_client import HarvesterRpcClient
from chia.rpc.wallet_rpc_client import WalletRpcClient
from chia.seeder.crawler import Crawler
from chia.seeder.crawler_api import CrawlerAPI
from chia.seeder.dns_server import DNSServer
from chia.server.server import ChiaServer
from chia.server.start_service import Service
from chia.simulator.full_node_simulator import FullNodeSimulator
@@ -44,7 +46,7 @@ from chia.simulator.setup_nodes import (
setup_simulators_and_wallets_service,
setup_two_nodes,
)
from chia.simulator.setup_services import setup_crawler, setup_daemon, setup_introducer, setup_timelord
from chia.simulator.setup_services import setup_crawler, setup_daemon, setup_introducer, setup_seeder, setup_timelord
from chia.simulator.time_out_assert import time_out_assert
from chia.simulator.wallet_tools import WalletTool
from chia.types.peer_info import PeerInfo
@@ -872,11 +874,19 @@ async def timelord_service(bt):
@pytest_asyncio.fixture(scope="function")
async def crawler_service(bt: BlockTools) -> AsyncIterator[Service[Crawler, CrawlerAPI]]:
async for service in setup_crawler(bt):
async def crawler_service(
root_path_populated_with_config: Path, database_uri: str
) -> AsyncIterator[Service[Crawler, CrawlerAPI]]:
async for service in setup_crawler(root_path_populated_with_config, database_uri):
yield service
@pytest_asyncio.fixture(scope="function")
async def seeder_service(root_path_populated_with_config: Path, database_uri: str) -> AsyncIterator[DNSServer]:
async for seeder in setup_seeder(root_path_populated_with_config, database_uri):
yield seeder
@pytest.fixture(scope="function")
def tmp_chia_root(tmp_path):
"""
@@ -980,3 +990,8 @@ async def harvester_farmer_environment(
harvester_rpc_cl.close()
await farmer_rpc_cl.await_closed()
await harvester_rpc_cl.await_closed()
@pytest.fixture(name="database_uri")
def database_uri_fixture() -> str:
return f"file:db_{random.randint(0, 99999999)}?mode=memory&cache=shared"

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import os
import pathlib
import random
import sys
import time
from typing import Any, AsyncIterable, Awaitable, Callable, Dict, Iterator
@@ -54,11 +53,6 @@ def create_example_fixture(request: SubRequest) -> Callable[[DataStore, bytes32]
return request.param # type: ignore[no-any-return]
@pytest.fixture(name="database_uri")
def database_uri_fixture() -> str:
return f"file:db_{random.randint(0, 99999999)}?mode=memory&cache=shared"
@pytest.fixture(name="tree_id", scope="function")
def tree_id_fixture() -> bytes32:
base = b"a tree id"

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import logging
import time
from typing import cast
import pytest
@@ -11,13 +12,14 @@ from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.wallet_protocol import RequestChildren
from chia.seeder.crawler import Crawler
from chia.seeder.crawler_api import CrawlerAPI
from chia.seeder.peer_record import PeerRecord, PeerReliability
from chia.server.outbound_message import make_msg
from chia.server.start_service import Service
from chia.simulator.setup_nodes import SimulatorsAndWalletsServices
from chia.simulator.time_out_assert import time_out_assert
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.peer_info import PeerInfo
from chia.util.ints import uint16, uint32, uint128
from chia.util.ints import uint16, uint32, uint64, uint128
@pytest.mark.asyncio
@@ -68,3 +70,43 @@ async def test_valid_message(
)
assert await connection.send_message(msg)
await time_out_assert(10, peer_added)
@pytest.mark.asyncio
async def test_crawler_to_db(
crawler_service: Service[Crawler, CrawlerAPI], one_node: SimulatorsAndWalletsServices
) -> None:
"""
This is more of an integration test, but it exercises the whole process: we add a node to the crawler,
save it to the db, and validate the stored data.
"""
[full_node_service], _, _ = one_node
full_node = full_node_service._node
crawler = crawler_service._node
crawl_store = crawler.crawl_store
assert crawl_store is not None
peer_address = "127.0.0.1"
# create peer records
peer_record = PeerRecord(
peer_address,
peer_address,
uint32(full_node.server.get_port()),
False,
uint64(0),
uint32(0),
uint64(0),
uint64(int(time.time())),
uint64(0),
"undefined",
uint64(0),
tls_version="unknown",
)
peer_reliability = PeerReliability(peer_address, tries=1, successes=1)
# add peer to the db & mark it as connected
await crawl_store.add_peer(peer_record, peer_reliability)
assert peer_record == crawl_store.host_to_records[peer_address]
# validate the db data
await time_out_assert(20, crawl_store.get_good_peers, [peer_address])

View File

@@ -9,7 +9,7 @@ from chia.seeder.crawler import Crawler
class TestCrawlerRpc:
@pytest.mark.asyncio
async def test_get_ips_after_timestamp(self, bt):
crawler = Crawler(bt.config.get("seeder", {}), bt.root_path, consensus_constants=bt.constants)
crawler = Crawler(bt.config.get("seeder", {}), bt.root_path, constants=bt.constants)
crawler_rpc_api = CrawlerRpcApi(crawler)
# Should raise ValueError when `after` is not supplied

tests/core/test_seeder.py Normal file
View File

@@ -0,0 +1,275 @@
from __future__ import annotations
import time
from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address
from socket import AF_INET6, SOCK_STREAM
from typing import List, Tuple, cast
import dns
import pytest
from chia.seeder.dns_server import DNSServer
from chia.seeder.peer_record import PeerRecord, PeerReliability
from chia.simulator.time_out_assert import time_out_assert
from chia.util.ints import uint32, uint64
timeout = 0.5
def generate_test_combs() -> List[Tuple[bool, str, dns.rdatatype.RdataType]]:
"""
Generates all the combinations of tests we want to run.
"""
output = []
use_tcp = [True, False]
target_address = ["::1", "127.0.0.1"]
request_types = [dns.rdatatype.A, dns.rdatatype.AAAA, dns.rdatatype.ANY, dns.rdatatype.NS, dns.rdatatype.SOA]
# (use_tcp, target_address, request_type): UDP + TCP and IPv6 + IPv4 for every request type we support.
for addr in target_address:
for tcp in use_tcp:
for req in request_types:
output.append((tcp, addr, req))
return output
all_test_combinations = generate_test_combs()
@dataclass(frozen=True)
class FakeDnsPacket:
real_packet: dns.message.Message
def to_wire(self) -> bytes:
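# dropping the first 23 bytes destroys the 12-byte DNS header (and part of the question), so the server cannot parse this packet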
return self.real_packet.to_wire()[23:]
def __getattr__(self, item: object) -> None:
# This is definitely cheating, but it works
return None
async def make_dns_query(
use_tcp: bool, target_address: str, port: int, dns_message: dns.message.Message, d_timeout: float = timeout
) -> dns.message.Message:
"""
Makes a DNS query for the given domain name using the given protocol type.
"""
if use_tcp:
return await dns.asyncquery.tcp(q=dns_message, where=target_address, timeout=d_timeout, port=port)
return await dns.asyncquery.udp(q=dns_message, where=target_address, timeout=d_timeout, port=port)
def get_addresses(num_subnets: int = 10) -> Tuple[List[IPv4Address], List[IPv6Address]]:
ipv4 = []
ipv6 = []
# with the default 10 subnets this generates 2500 ipv4 and 2500 ipv6 peers; they are only ever used as strings
for s in range(num_subnets):
for i in range(1, 251):  # I'm being lazy, as we can only have 255 hosts per subnet
ipv4.append(IPv4Address(f"192.168.{s}.{i}"))
ipv6.append(IPv6Address(f"2001:db8::{s}:{i}"))
return ipv4, ipv6
def assert_standard_results(
std_query_answer: List[dns.rrset.RRset], request_type: dns.rdatatype.RdataType, num_ns: int
) -> None:
if request_type == dns.rdatatype.A:
assert len(std_query_answer) == 1 # only 1 kind of answer
a_answer = std_query_answer[0]
assert a_answer.rdtype == dns.rdatatype.A
assert len(a_answer) == 32 # 32 ipv4 addresses
elif request_type == dns.rdatatype.AAAA:
assert len(std_query_answer) == 1 # only 1 kind of answer
aaaa_answer = std_query_answer[0]
assert aaaa_answer.rdtype == dns.rdatatype.AAAA
assert len(aaaa_answer) == 32 # 32 ipv6 addresses
elif request_type == dns.rdatatype.ANY:
assert len(std_query_answer) == 4 # 4 kinds of answers
for answer in std_query_answer:
if answer.rdtype == dns.rdatatype.A:
assert len(answer) == 16
elif answer.rdtype == dns.rdatatype.AAAA:
assert len(answer) == 16
elif answer.rdtype == dns.rdatatype.NS:
assert len(answer) == num_ns
else:
assert len(answer) == 1
elif request_type == dns.rdatatype.NS:
assert len(std_query_answer) == 1 # only 1 kind of answer
ns_answer = std_query_answer[0]
assert ns_answer.rdtype == dns.rdatatype.NS
assert len(ns_answer) == num_ns # ns records
else:
assert len(std_query_answer) == 1 # soa
soa_answer = std_query_answer[0]
assert soa_answer.rdtype == dns.rdatatype.SOA
assert len(soa_answer) == 1
@pytest.mark.asyncio
@pytest.mark.parametrize("use_tcp, target_address, request_type", all_test_combinations)
async def test_error_conditions(
seeder_service: DNSServer, use_tcp: bool, target_address: str, request_type: dns.rdatatype.RdataType
) -> None:
"""
We check the no-peers case, an invalid packet, an early EOF, and a packet followed by an EOF halfway through (TCP only).
We also check a DNS record that does not exist, and a DNS record outside the domain.
"""
port = seeder_service.dns_port
domain = seeder_service.domain # default is: seeder.example.com
num_ns = len(seeder_service.ns_records)
# No peers
no_peers = dns.message.make_query(domain, request_type)
no_peers_response = await make_dns_query(use_tcp, target_address, port, no_peers)
assert no_peers_response.rcode() == dns.rcode.NOERROR
if request_type == dns.rdatatype.A or request_type == dns.rdatatype.AAAA:
assert len(no_peers_response.answer) == 0 # no response, as expected
elif request_type == dns.rdatatype.ANY: # ns + soa
assert len(no_peers_response.answer) == 2
for answer in no_peers_response.answer:
if answer.rdtype == dns.rdatatype.NS:
assert len(answer.items) == num_ns
else:
assert len(answer.items) == 1
elif request_type == dns.rdatatype.NS:
assert len(no_peers_response.answer) == 1 # ns
assert no_peers_response.answer[0].rdtype == dns.rdatatype.NS
assert len(no_peers_response.answer[0].items) == num_ns
else:
assert len(no_peers_response.answer) == 1 # soa
assert no_peers_response.answer[0].rdtype == dns.rdatatype.SOA
assert len(no_peers_response.answer[0].items) == 1
# Authority Records
assert len(no_peers_response.authority) == num_ns + 1 # ns + soa
for record_list in no_peers_response.authority:
if record_list.rdtype == dns.rdatatype.NS:
assert len(record_list.items) == num_ns
else:
assert len(record_list.items) == 1 # soa
# Invalid packet (this is kinda a pain)
invalid_packet = cast(dns.message.Message, FakeDnsPacket(dns.message.make_query(domain, request_type)))
with pytest.raises(EOFError if use_tcp else dns.exception.Timeout): # UDP will time out, TCP will EOF
await make_dns_query(use_tcp, target_address, port, invalid_packet)
# early EOF packet
if use_tcp:
backend = dns.asyncbackend.get_default_backend()
tcp_socket = await backend.make_socket( # type: ignore[no-untyped-call]
AF_INET6, SOCK_STREAM, 0, ("::", 0), ("::1", port), timeout=timeout
)
async with tcp_socket as socket:
await socket.close()
with pytest.raises(EOFError):
await dns.asyncquery.receive_tcp(tcp_socket, timeout)
# Packet then length header then EOF
if use_tcp:
backend = dns.asyncbackend.get_default_backend()
tcp_socket = await backend.make_socket( # type: ignore[no-untyped-call]
AF_INET6, SOCK_STREAM, 0, ("::", 0), ("::1", port), timeout=timeout
)
async with tcp_socket as socket: # send packet then length then eof
r = await dns.asyncquery.tcp(q=no_peers, where="::1", timeout=timeout, sock=socket)
assert r.answer == no_peers_response.answer
# send 120 as the 2-byte length prefix of the next packet, so that the server expects 120 more bytes.
await socket.sendall(int(120).to_bytes(2, byteorder="big"), int(time.time() + timeout))
await socket.close()
with pytest.raises(EOFError):
await dns.asyncquery.receive_tcp(tcp_socket, timeout)
# Record does not exist
record_does_not_exist = dns.message.make_query("doesnotexist." + domain, request_type)
record_does_not_exist_response = await make_dns_query(use_tcp, target_address, port, record_does_not_exist)
assert record_does_not_exist_response.rcode() == dns.rcode.NXDOMAIN
assert len(record_does_not_exist_response.answer) == 0
assert len(record_does_not_exist_response.authority) == num_ns + 1 # ns + soa
# Record outside domain
record_outside_domain = dns.message.make_query("chia.net", request_type)
record_outside_domain_response = await make_dns_query(use_tcp, target_address, port, record_outside_domain)
assert record_outside_domain_response.rcode() == dns.rcode.REFUSED
assert len(record_outside_domain_response.answer) == 0
assert len(record_outside_domain_response.authority) == 0
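The length-header trick above leans on DNS-over-TCP framing: RFC 1035 prefixes every message with a
two-byte big-endian length, so a header promising 120 bytes leaves the server waiting for a body that
never arrives, and the test then closes the socket so the follow-up read raises the expected EOFError.
A minimal sketch of the framing (hypothetical helper, not part of this change):

import struct

def frame_tcp_dns(payload: bytes) -> bytes:
    # DNS over TCP sends a 2-byte big-endian length before each message (RFC 1035, section 4.2.2)
    return struct.pack("!H", len(payload)) + payload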
@pytest.mark.asyncio
@pytest.mark.parametrize("use_tcp, target_address, request_type", all_test_combinations)
async def test_dns_queries(
seeder_service: DNSServer, use_tcp: bool, target_address: str, request_type: dns.rdatatype.RdataType
) -> None:
"""
We add 5000 peers directly, then try every kind of query many times over both the TCP and UDP protocols.
"""
port = seeder_service.dns_port
domain = seeder_service.domain # default is: seeder.example.com
num_ns = len(seeder_service.ns_records)
# add 5000 peers (2500 ipv4, 2500 ipv6)
seeder_service.reliable_peers_v4, seeder_service.reliable_peers_v6 = get_addresses()
# now we query for each type of record a lot of times and make sure we get the right number of responses
for i in range(150):
query = dns.message.make_query(domain, request_type, use_edns=True) # we need to generate a new request id.
std_query_response = await make_dns_query(use_tcp, target_address, port, query)
assert std_query_response.rcode() == dns.rcode.NOERROR
assert_standard_results(std_query_response.answer, request_type, num_ns)
# Assert Authority Records
assert len(std_query_response.authority) == num_ns + 1 # ns + soa
for record_list in std_query_response.authority:
if record_list.rdtype == dns.rdatatype.NS:
assert len(record_list.items) == num_ns
else:
assert len(record_list.items) == 1 # soa
# Validate EDNS
e_query = dns.message.make_query(domain, dns.rdatatype.ANY, use_edns=False)
with pytest.raises(dns.query.BadResponse): # response is truncated without EDNS
await make_dns_query(False, target_address, port, e_query)
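The EDNS check works because classic DNS over UDP caps a response at 512 bytes; the ANY answer built here
is far larger, so without EDNS the reply comes back cut down and dnspython rejects it as a BadResponse.
EDNS0 adds an OPT record through which the client advertises a bigger UDP payload. A sketch of the two
query shapes, assuming dnspython (the 4096-byte payload is an illustrative value, not taken from this change):

import dns.message
import dns.rdatatype

plain = dns.message.make_query("seeder.example.com", dns.rdatatype.ANY)  # replies limited to 512 bytes
edns = dns.message.make_query("seeder.example.com", dns.rdatatype.ANY, use_edns=0, payload=4096)  # OPT record advertises room for 4096 bytes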
@pytest.mark.asyncio
async def test_db_processing(seeder_service: DNSServer) -> None:
"""
We add 1000 peers through the db, then try every kind of query over both the TCP and UDP protocols.
"""
port = seeder_service.dns_port
domain = seeder_service.domain # default is: seeder.example.com
num_ns = len(seeder_service.ns_records)
crawl_store = seeder_service.crawl_store
assert crawl_store is not None
ipv4, ipv6 = get_addresses(2) # get 1000 addresses
# add peers to the in memory part of the db.
for peer in [str(peer) for pair in zip(ipv4, ipv6) for peer in pair]:
new_peer = PeerRecord(
peer,
peer,
uint32(58444),
False,
uint64(0),
uint32(0),
uint64(0),
uint64(int(time.time())),
uint64(0),
"undefined",
uint64(0),
tls_version="unknown",
)
new_peer_reliability = PeerReliability(peer, tries=3, successes=3) # make sure the peer starts as reliable.
crawl_store.maybe_add_peer(new_peer, new_peer_reliability) # we don't fully add it because we don't need to.
# Write these new peers to db.
await crawl_store.load_reliable_peers_to_db()
# wait for the new db to be read.
await time_out_assert(30, lambda: seeder_service.reliable_peers_v4 != [])
# now we check all the combinations once (not a stupid amount of times)
for use_tcp, target_address, request_type in all_test_combinations:
query = dns.message.make_query(domain, request_type, use_edns=True) # we need to generate a new request id.
std_query_response = await make_dns_query(use_tcp, target_address, port, query)
assert std_query_response.rcode() == dns.rcode.NOERROR
assert_standard_results(std_query_response.answer, request_type, num_ns)