mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-08-16 14:20:47 +03:00
Refactor Seeder & Crawler code + add tests (#15781)
* Cleanup seeder & mypy all files (holy crap, mother of all tech debt, this was horrid)
* Make UDP Protocol class a dataclass & separate out functions.
* Add TCP protocol class to DNS server
* Fix mypy types & other cleanup; also fix a couple bugs
* Add seeder and crawler tests
* Change log levels
* Re-add db lock timeout (oops)
* Add EDNS & EDNS tests
* Fix repeated shutdown on close signal
* Fix binding to use both IPv6 and IPv4 (whyyyyyyy)
* Add IPv6 and IPv4 tests + add IPv4 if Windows
This commit is contained in:
parent
c9e6c9d5cc
commit
0d36874caa
@ -1,37 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import ipaddress
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Dict, List
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from chia.seeder.peer_record import PeerRecord, PeerReliability
|
||||
from chia.util.ints import uint64
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrawlStore:
|
||||
crawl_db: aiosqlite.Connection
|
||||
last_timestamp: int
|
||||
lock: asyncio.Lock
|
||||
|
||||
host_to_records: Dict
|
||||
host_to_selected_time: Dict
|
||||
host_to_reliability: Dict
|
||||
banned_peers: int
|
||||
ignored_peers: int
|
||||
reliable_peers: int
|
||||
host_to_records: Dict[str, PeerRecord] = field(default_factory=dict) # peer_id: PeerRecord
|
||||
host_to_selected_time: Dict[str, float] = field(default_factory=dict) # peer_id: timestamp (as a float)
|
||||
host_to_reliability: Dict[str, PeerReliability] = field(default_factory=dict) # peer_id: PeerReliability
|
||||
banned_peers: int = 0
|
||||
ignored_peers: int = 0
|
||||
reliable_peers: int = 0
|
||||
|
||||
@classmethod
|
||||
async def create(cls, connection: aiosqlite.Connection):
|
||||
self = cls()
|
||||
async def create(cls, connection: aiosqlite.Connection) -> CrawlStore:
|
||||
self = cls(connection)
|
||||
|
||||
self.crawl_db = connection
|
||||
await self.crawl_db.execute(
|
||||
(
|
||||
"CREATE TABLE IF NOT EXISTS peer_records("
|
||||
@ -82,21 +79,16 @@ class CrawlStore:
|
||||
await self.crawl_db.execute("CREATE INDEX IF NOT EXISTS ignore_till on peer_reliability(ignore_till)")
|
||||
|
||||
await self.crawl_db.commit()
|
||||
self.last_timestamp = 0
|
||||
self.ignored_peers = 0
|
||||
self.banned_peers = 0
|
||||
self.reliable_peers = 0
|
||||
self.host_to_selected_time = {}
|
||||
await self.unload_from_db()
|
||||
return self
|
||||
|
||||
def maybe_add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability):
|
||||
def maybe_add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability) -> None:
|
||||
if peer_record.peer_id not in self.host_to_records:
|
||||
self.host_to_records[peer_record.peer_id] = peer_record
|
||||
if peer_reliability.peer_id not in self.host_to_reliability:
|
||||
self.host_to_reliability[peer_reliability.peer_id] = peer_reliability
|
||||
|
||||
async def add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability, save_db: bool = False):
|
||||
async def add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability, save_db: bool = False) -> None:
|
||||
if not save_db:
|
||||
self.host_to_records[peer_record.peer_id] = peer_record
|
||||
self.host_to_reliability[peer_reliability.peer_id] = peer_reliability
|
||||
@ -152,41 +144,41 @@ class CrawlStore:
|
||||
async def get_peer_reliability(self, peer_id: str) -> PeerReliability:
|
||||
return self.host_to_reliability[peer_id]
|
||||
|
||||
async def peer_failed_to_connect(self, peer: PeerRecord):
|
||||
now = int(time.time())
|
||||
async def peer_failed_to_connect(self, peer: PeerRecord) -> None:
|
||||
now = uint64(time.time())
|
||||
age_timestamp = int(max(peer.last_try_timestamp, peer.connected_timestamp))
|
||||
if age_timestamp == 0:
|
||||
age_timestamp = now - 1000
|
||||
replaced = dataclasses.replace(peer, try_count=peer.try_count + 1, last_try_timestamp=now)
|
||||
replaced = replace(peer, try_count=peer.try_count + 1, last_try_timestamp=now)
|
||||
reliability = await self.get_peer_reliability(peer.peer_id)
|
||||
if reliability is None:
|
||||
reliability = PeerReliability(peer.peer_id)
|
||||
reliability.update(False, now - age_timestamp)
|
||||
await self.add_peer(replaced, reliability)
|
||||
|
||||
async def peer_connected(self, peer: PeerRecord, tls_version: str):
|
||||
now = int(time.time())
|
||||
async def peer_connected(self, peer: PeerRecord, tls_version: str) -> None:
|
||||
now = uint64(time.time())
|
||||
age_timestamp = int(max(peer.last_try_timestamp, peer.connected_timestamp))
|
||||
if age_timestamp == 0:
|
||||
age_timestamp = now - 1000
|
||||
replaced = dataclasses.replace(peer, connected=True, connected_timestamp=now, tls_version=tls_version)
|
||||
replaced = replace(peer, connected=True, connected_timestamp=now, tls_version=tls_version)
|
||||
reliability = await self.get_peer_reliability(peer.peer_id)
|
||||
if reliability is None:
|
||||
reliability = PeerReliability(peer.peer_id)
|
||||
reliability.update(True, now - age_timestamp)
|
||||
await self.add_peer(replaced, reliability)
|
||||
|
||||
async def update_best_timestamp(self, host: str, timestamp):
|
||||
async def update_best_timestamp(self, host: str, timestamp: uint64) -> None:
|
||||
if host not in self.host_to_records:
|
||||
return
|
||||
record = self.host_to_records[host]
|
||||
replaced = dataclasses.replace(record, best_timestamp=timestamp)
|
||||
replaced = replace(record, best_timestamp=timestamp)
|
||||
if host not in self.host_to_reliability:
|
||||
return
|
||||
reliability = self.host_to_reliability[host]
|
||||
await self.add_peer(replaced, reliability)
|
||||
|
||||
async def peer_connected_hostname(self, host: str, connected: bool = True, tls_version: str = "unknown"):
|
||||
async def peer_connected_hostname(self, host: str, connected: bool = True, tls_version: str = "unknown") -> None:
|
||||
if host not in self.host_to_records:
|
||||
return
|
||||
record = self.host_to_records[host]
|
||||
@ -195,7 +187,7 @@ class CrawlStore:
|
||||
else:
|
||||
await self.peer_failed_to_connect(record)
|
||||
|
||||
async def get_peers_to_crawl(self, min_batch_size, max_batch_size) -> List[PeerRecord]:
|
||||
async def get_peers_to_crawl(self, min_batch_size: int, max_batch_size: int) -> List[PeerRecord]:
|
||||
now = int(time.time())
|
||||
records = []
|
||||
records_v6 = []
|
||||
@ -270,20 +262,20 @@ class CrawlStore:
|
||||
def get_reliable_peers(self) -> int:
|
||||
return self.reliable_peers
|
||||
|
||||
async def load_to_db(self):
|
||||
log.error("Saving peers to DB...")
|
||||
async def load_to_db(self) -> None:
|
||||
log.info("Saving peers to DB...")
|
||||
for peer_id in list(self.host_to_reliability.keys()):
|
||||
if peer_id in self.host_to_reliability and peer_id in self.host_to_records:
|
||||
reliability = self.host_to_reliability[peer_id]
|
||||
record = self.host_to_records[peer_id]
|
||||
await self.add_peer(record, reliability, True)
|
||||
await self.crawl_db.commit()
|
||||
log.error(" - Done saving peers to DB")
|
||||
log.info(" - Done saving peers to DB")
|
||||
|
||||
async def unload_from_db(self):
|
||||
async def unload_from_db(self) -> None:
|
||||
self.host_to_records = {}
|
||||
self.host_to_reliability = {}
|
||||
log.error("Loading peer reliability records...")
|
||||
log.info("Loading peer reliability records...")
|
||||
cursor = await self.crawl_db.execute(
|
||||
"SELECT * from peer_reliability",
|
||||
)
|
||||
@ -312,9 +304,9 @@ class CrawlStore:
|
||||
row[18],
|
||||
row[19],
|
||||
)
|
||||
self.host_to_reliability[row[0]] = reliability
|
||||
log.error(" - Done loading peer reliability records...")
|
||||
log.error("Loading peer records...")
|
||||
self.host_to_reliability[reliability.peer_id] = reliability
|
||||
log.info(" - Done loading peer reliability records...")
|
||||
log.info("Loading peer records...")
|
||||
cursor = await self.crawl_db.execute(
|
||||
"SELECT * from peer_records",
|
||||
)
|
||||
@ -324,24 +316,23 @@ class CrawlStore:
|
||||
peer = PeerRecord(
|
||||
row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7], row[8], row[9], row[10], row[11]
|
||||
)
|
||||
self.host_to_records[row[0]] = peer
|
||||
log.error(" - Done loading peer records...")
|
||||
self.host_to_records[peer.peer_id] = peer
|
||||
log.info(" - Done loading peer records...")
|
||||
|
||||
# Crawler -> DNS.
|
||||
async def load_reliable_peers_to_db(self):
|
||||
async def load_reliable_peers_to_db(self) -> None:
|
||||
peers = []
|
||||
for peer_id in self.host_to_reliability:
|
||||
reliability = self.host_to_reliability[peer_id]
|
||||
for peer_id, reliability in self.host_to_reliability.items():
|
||||
if reliability.is_reliable():
|
||||
peers.append(peer_id)
|
||||
self.reliable_peers = len(peers)
|
||||
log.error("Deleting old good_peers from DB...")
|
||||
log.info("Deleting old good_peers from DB...")
|
||||
cursor = await self.crawl_db.execute(
|
||||
"DELETE from good_peers",
|
||||
)
|
||||
await cursor.close()
|
||||
log.error(" - Done deleting old good_peers...")
|
||||
log.error("Saving new good_peers to DB...")
|
||||
log.info(" - Done deleting old good_peers...")
|
||||
log.info("Saving new good_peers to DB...")
|
||||
for peer in peers:
|
||||
cursor = await self.crawl_db.execute(
|
||||
"INSERT OR REPLACE INTO good_peers VALUES(?)",
|
||||
@ -349,9 +340,9 @@ class CrawlStore:
|
||||
)
|
||||
await cursor.close()
|
||||
await self.crawl_db.commit()
|
||||
log.error(" - Done saving new good_peers to DB...")
|
||||
log.info(" - Done saving new good_peers to DB...")
|
||||
|
||||
def load_host_to_version(self):
|
||||
def load_host_to_version(self) -> tuple[dict[str, str], dict[str, uint64]]:
|
||||
versions = {}
|
||||
handshake = {}
|
||||
|
||||
@ -366,19 +357,30 @@ class CrawlStore:
|
||||
versions[host] = record.version
|
||||
handshake[host] = record.handshake_time
|
||||
|
||||
return (versions, handshake)
|
||||
return versions, handshake
|
||||
|
||||
def load_best_peer_reliability(self):
|
||||
def load_best_peer_reliability(self) -> dict[str, uint64]:
|
||||
best_timestamp = {}
|
||||
for host, record in self.host_to_records.items():
|
||||
if record.best_timestamp > time.time() - 5 * 24 * 3600:
|
||||
best_timestamp[host] = record.best_timestamp
|
||||
return best_timestamp
|
||||
|
||||
async def update_version(self, host, version, now):
|
||||
async def update_version(self, host: str, version: str, timestamp_now: uint64) -> None:
|
||||
record = self.host_to_records.get(host, None)
|
||||
reliability = self.host_to_reliability.get(host, None)
|
||||
if record is None or reliability is None:
|
||||
return
|
||||
record.update_version(version, now)
|
||||
record.update_version(version, timestamp_now)
|
||||
await self.add_peer(record, reliability)
|
||||
|
||||
async def get_good_peers(self) -> list[str]: # This is for the DNS server
|
||||
cursor = await self.crawl_db.execute(
|
||||
"SELECT * from good_peers",
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
await cursor.close()
|
||||
result = [row[0] for row in rows]
|
||||
if len(result) > 0:
|
||||
random.shuffle(result) # mix up the peers
|
||||
return result
|
||||
|
@ -6,15 +6,16 @@ import logging
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from chia.consensus.constants import ConsensusConstants
|
||||
from chia.full_node.coin_store import CoinStore
|
||||
from chia.full_node.full_node_api import FullNodeAPI
|
||||
from chia.protocols import full_node_protocol
|
||||
from chia.protocols.full_node_protocol import RespondPeers
|
||||
from chia.rpc.rpc_server import StateChangedProtocol, default_get_connections
|
||||
from chia.seeder.crawl_store import CrawlStore
|
||||
from chia.seeder.peer_record import PeerRecord, PeerReliability
|
||||
@ -29,20 +30,28 @@ from chia.util.path import path_from_root
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Crawler:
|
||||
sync_store: Any
|
||||
coin_store: CoinStore
|
||||
connection: Optional[aiosqlite.Connection]
|
||||
config: Dict
|
||||
_server: Optional[ChiaServer]
|
||||
crawl_store: Optional[CrawlStore]
|
||||
log: logging.Logger
|
||||
constants: ConsensusConstants
|
||||
_shut_down: bool
|
||||
config: Dict[str, Any]
|
||||
root_path: Path
|
||||
peer_count: int
|
||||
with_peak: set
|
||||
minimum_version_count: int
|
||||
constants: ConsensusConstants
|
||||
print_status: bool = True
|
||||
state_changed_callback: Optional[StateChangedProtocol] = None
|
||||
_server: Optional[ChiaServer] = None
|
||||
crawl_task: Optional[asyncio.Task[None]] = None
|
||||
crawl_store: Optional[CrawlStore] = None
|
||||
log: logging.Logger = log
|
||||
_shut_down: bool = False
|
||||
peer_count: int = 0
|
||||
with_peak: Set[PeerInfo] = field(default_factory=set)
|
||||
seen_nodes: Set[str] = field(default_factory=set)
|
||||
minimum_version_count: int = 0
|
||||
peers_retrieved: List[RespondPeers] = field(default_factory=list)
|
||||
host_to_version: Dict[str, str] = field(default_factory=dict)
|
||||
versions: Dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0))
|
||||
version_cache: List[Tuple[str, str]] = field(default_factory=list)
|
||||
handshake_time: Dict[str, uint64] = field(default_factory=dict)
|
||||
best_timestamp_per_peer: Dict[str, uint64] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def server(self) -> ChiaServer:
|
||||
@ -53,38 +62,16 @@ class Crawler:
|
||||
|
||||
return self._server
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Dict,
|
||||
root_path: Path,
|
||||
consensus_constants: ConsensusConstants,
|
||||
name: str = None,
|
||||
):
|
||||
self.initialized = False
|
||||
self.root_path = root_path
|
||||
self.config = config
|
||||
self.connection = None
|
||||
self._server = None
|
||||
self._shut_down = False # Set to true to close all infinite loops
|
||||
self.constants = consensus_constants
|
||||
self.state_changed_callback: Optional[StateChangedProtocol] = None
|
||||
self.crawl_store = None
|
||||
self.log = log
|
||||
self.peer_count = 0
|
||||
self.with_peak = set()
|
||||
self.peers_retrieved: List[Any] = []
|
||||
self.host_to_version: Dict[str, str] = {}
|
||||
self.version_cache: List[Tuple[str, str]] = []
|
||||
self.handshake_time: Dict[str, int] = {}
|
||||
self.best_timestamp_per_peer: Dict[str, int] = {}
|
||||
crawler_db_path: str = config.get("crawler_db_path", "crawler.db")
|
||||
self.db_path = path_from_root(root_path, crawler_db_path)
|
||||
def __post_init__(self) -> None:
|
||||
# get db path
|
||||
crawler_db_path: str = self.config.get("crawler_db_path", "crawler.db")
|
||||
self.db_path = path_from_root(self.root_path, crawler_db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.bootstrap_peers = config["bootstrap_peers"]
|
||||
self.minimum_height = config["minimum_height"]
|
||||
self.other_peers_port = config["other_peers_port"]
|
||||
self.versions: Dict[str, int] = defaultdict(lambda: 0)
|
||||
self.minimum_version_count = self.config.get("minimum_version_count", 100)
|
||||
# load data from config
|
||||
self.bootstrap_peers = self.config["bootstrap_peers"]
|
||||
self.minimum_height = self.config["minimum_height"]
|
||||
self.other_peers_port = self.config["other_peers_port"]
|
||||
self.minimum_version_count: int = self.config.get("minimum_version_count", 100)
|
||||
if self.minimum_version_count < 1:
|
||||
self.log.warning(
|
||||
f"Crawler configuration minimum_version_count expected to be greater than zero: "
|
||||
@ -97,11 +84,16 @@ class Crawler:
|
||||
def get_connections(self, request_node_type: Optional[NodeType]) -> List[Dict[str, Any]]:
|
||||
return default_get_connections(server=self.server, request_node_type=request_node_type)
|
||||
|
||||
async def create_client(self, peer_info, on_connect):
|
||||
async def create_client(
|
||||
self, peer_info: PeerInfo, on_connect: Callable[[WSChiaConnection], Awaitable[None]]
|
||||
) -> bool:
|
||||
return await self.server.start_client(peer_info, on_connect)
|
||||
|
||||
async def connect_task(self, peer):
|
||||
async def peer_action(peer: WSChiaConnection):
|
||||
async def connect_task(self, peer: PeerRecord) -> None:
|
||||
if self.crawl_store is None:
|
||||
raise ValueError("Not Connected to DB")
|
||||
|
||||
async def peer_action(peer: WSChiaConnection) -> None:
|
||||
peer_info = peer.get_peer_info()
|
||||
version = peer.get_version()
|
||||
if peer_info is not None and version is not None:
|
||||
@ -134,41 +126,34 @@ class Crawler:
|
||||
if not connected:
|
||||
await self.crawl_store.peer_failed_to_connect(peer)
|
||||
except Exception as e:
|
||||
self.log.info(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
|
||||
self.log.warning(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
|
||||
await self.crawl_store.peer_failed_to_connect(peer)
|
||||
|
||||
async def _start(self):
|
||||
async def _start(self) -> None:
|
||||
# We override the default peer_connect_timeout when running from the crawler
|
||||
crawler_peer_timeout = self.config.get("peer_connect_timeout", 2)
|
||||
self.server.config["peer_connect_timeout"] = crawler_peer_timeout
|
||||
|
||||
self.task = asyncio.create_task(self.crawl())
|
||||
|
||||
async def crawl(self):
|
||||
# Ensure the state_changed callback is set up before moving on
|
||||
# Sometimes, the daemon connection + state changed callback isn't up and ready
|
||||
# by the time we get to the first _state_changed call, so this just ensures it's there before moving on
|
||||
while self.state_changed_callback is None:
|
||||
self.log.info("Waiting for state changed callback...")
|
||||
await asyncio.sleep(0.1)
|
||||
# Connect to the DB
|
||||
self.crawl_store: CrawlStore = await CrawlStore.create(await aiosqlite.connect(self.db_path))
|
||||
# Bootstrap the initial peers
|
||||
await self.load_bootstrap_peers()
|
||||
self.crawl_task = asyncio.create_task(self.crawl())
|
||||
|
||||
async def load_bootstrap_peers(self) -> None:
|
||||
assert self.crawl_store is not None
|
||||
try:
|
||||
self.connection = await aiosqlite.connect(self.db_path)
|
||||
self.crawl_store = await CrawlStore.create(self.connection)
|
||||
self.log.info("Started")
|
||||
self.log.warning("Bootstrapping initial peers...")
|
||||
t_start = time.time()
|
||||
total_nodes = 0
|
||||
self.seen_nodes = set()
|
||||
tried_nodes = set()
|
||||
for peer in self.bootstrap_peers:
|
||||
new_peer = PeerRecord(
|
||||
peer,
|
||||
peer,
|
||||
self.other_peers_port,
|
||||
False,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
uint64(0),
|
||||
uint32(0),
|
||||
uint64(0),
|
||||
uint64(int(time.time())),
|
||||
uint64(0),
|
||||
"undefined",
|
||||
@ -180,14 +165,26 @@ class Crawler:
|
||||
|
||||
self.host_to_version, self.handshake_time = self.crawl_store.load_host_to_version()
|
||||
self.best_timestamp_per_peer = self.crawl_store.load_best_peer_reliability()
|
||||
self.versions = defaultdict(lambda: 0)
|
||||
for host, version in self.host_to_version.items():
|
||||
self.versions[version] += 1
|
||||
|
||||
self._state_changed("loaded_initial_peers")
|
||||
self.log.warning(f"Bootstrapped initial peers in {time.time() - t_start} seconds")
|
||||
except Exception as e:
|
||||
self.log.error(f"Error bootstrapping initial peers: {e}")
|
||||
|
||||
while True:
|
||||
self.with_peak = set()
|
||||
async def crawl(self) -> None:
|
||||
# Ensure the state_changed callback is set up before moving on
|
||||
# Sometimes, the daemon connection + state changed callback isn't up and ready
|
||||
# by the time we get to the first _state_changed call, so this just ensures it's there before moving on
|
||||
while self.state_changed_callback is None:
|
||||
self.log.info("Waiting for state changed callback...")
|
||||
await asyncio.sleep(0.1)
|
||||
assert self.crawl_store is not None
|
||||
t_start = time.time()
|
||||
total_nodes = 0
|
||||
tried_nodes = set()
|
||||
try:
|
||||
while not self._shut_down:
|
||||
peers_to_crawl = await self.crawl_store.get_peers_to_crawl(25000, 250000)
|
||||
tasks = set()
|
||||
for peer in peers_to_crawl:
|
||||
@ -200,7 +197,6 @@ class Crawler:
|
||||
if len(tasks) >= 250:
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
tasks = set(filter(lambda t: not t.done(), tasks))
|
||||
|
||||
if len(tasks) > 0:
|
||||
await asyncio.wait(tasks, timeout=30)
|
||||
|
||||
@ -231,16 +227,15 @@ class Crawler:
|
||||
tls_version="unknown",
|
||||
)
|
||||
new_peer_reliability = PeerReliability(response_peer.host)
|
||||
if self.crawl_store is not None:
|
||||
self.crawl_store.maybe_add_peer(new_peer, new_peer_reliability)
|
||||
self.crawl_store.maybe_add_peer(new_peer, new_peer_reliability)
|
||||
await self.crawl_store.update_best_timestamp(
|
||||
response_peer.host,
|
||||
self.best_timestamp_per_peer[response_peer.host],
|
||||
)
|
||||
for host, version in self.version_cache:
|
||||
self.handshake_time[host] = int(time.time())
|
||||
self.handshake_time[host] = uint64(time.time())
|
||||
self.host_to_version[host] = version
|
||||
await self.crawl_store.update_version(host, version, int(time.time()))
|
||||
await self.crawl_store.update_version(host, version, uint64(time.time()))
|
||||
|
||||
to_remove = set()
|
||||
now = int(time.time())
|
||||
@ -264,94 +259,46 @@ class Crawler:
|
||||
self.versions = defaultdict(lambda: 0)
|
||||
for host, version in self.host_to_version.items():
|
||||
self.versions[version] += 1
|
||||
self.version_cache = []
|
||||
self.peers_retrieved = []
|
||||
|
||||
# clear caches
|
||||
self.version_cache: List[Tuple[str, str]] = []
|
||||
self.peers_retrieved = []
|
||||
self.server.banned_peers = {}
|
||||
self.with_peak = set()
|
||||
|
||||
if len(peers_to_crawl) == 0:
|
||||
continue
|
||||
|
||||
# Try up to 5 times to write to the DB in case there is a lock that causes a timeout
|
||||
for i in range(1, 5):
|
||||
try:
|
||||
await self.crawl_store.load_to_db()
|
||||
await self.crawl_store.load_reliable_peers_to_db()
|
||||
except Exception as e:
|
||||
self.log.error(f"Exception while saving to DB: {e}.")
|
||||
self.log.error("Waiting 5 seconds before retry...")
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
break
|
||||
total_records = self.crawl_store.get_total_records()
|
||||
ipv6_count = self.crawl_store.get_ipv6_peers()
|
||||
self.log.error("***")
|
||||
self.log.error("Finished batch:")
|
||||
self.log.error(f"Total IPs stored in DB: {total_records}.")
|
||||
self.log.error(f"Total IPV6 addresses stored in DB: {ipv6_count}")
|
||||
self.log.error(f"Total connections attempted since crawler started: {total_nodes}.")
|
||||
self.log.error(f"Total unique nodes attempted since crawler started: {len(tried_nodes)}.")
|
||||
t_now = time.time()
|
||||
t_delta = int(t_now - t_start)
|
||||
if t_delta > 0:
|
||||
self.log.error(f"Avg connections per second: {total_nodes // t_delta}.")
|
||||
# Periodically print detailed stats.
|
||||
reliable_peers = self.crawl_store.get_reliable_peers()
|
||||
self.log.error(f"High quality reachable nodes, used by DNS introducer in replies: {reliable_peers}")
|
||||
banned_peers = self.crawl_store.get_banned_peers()
|
||||
ignored_peers = self.crawl_store.get_ignored_peers()
|
||||
available_peers = len(self.host_to_version)
|
||||
addresses_count = len(self.best_timestamp_per_peer)
|
||||
total_records = self.crawl_store.get_total_records()
|
||||
ipv6_addresses_count = 0
|
||||
for host in self.best_timestamp_per_peer.keys():
|
||||
try:
|
||||
ipaddress.IPv6Address(host)
|
||||
ipv6_addresses_count += 1
|
||||
except ipaddress.AddressValueError:
|
||||
continue
|
||||
self.log.error(
|
||||
"IPv4 addresses gossiped with timestamp in the last 5 days with respond_peers messages: "
|
||||
f"{addresses_count - ipv6_addresses_count}."
|
||||
)
|
||||
self.log.error(
|
||||
"IPv6 addresses gossiped with timestamp in the last 5 days with respond_peers messages: "
|
||||
f"{ipv6_addresses_count}."
|
||||
)
|
||||
ipv6_available_peers = 0
|
||||
for host in self.host_to_version.keys():
|
||||
try:
|
||||
ipaddress.IPv6Address(host)
|
||||
ipv6_available_peers += 1
|
||||
except ipaddress.AddressValueError:
|
||||
continue
|
||||
self.log.error(
|
||||
f"Total IPv4 nodes reachable in the last 5 days: {available_peers - ipv6_available_peers}."
|
||||
)
|
||||
self.log.error(f"Total IPv6 nodes reachable in the last 5 days: {ipv6_available_peers}.")
|
||||
self.log.error("Version distribution among reachable in the last 5 days (at least 100 nodes):")
|
||||
for version, count in sorted(self.versions.items(), key=lambda kv: kv[1], reverse=True):
|
||||
if count >= self.minimum_version_count:
|
||||
self.log.error(f"Version: {version} - Count: {count}")
|
||||
self.log.error(f"Banned addresses in the DB: {banned_peers}")
|
||||
self.log.error(f"Temporary ignored addresses in the DB: {ignored_peers}")
|
||||
self.log.error(
|
||||
"Peers to crawl from in the next batch (total IPs - ignored - banned): "
|
||||
f"{total_records - banned_peers - ignored_peers}"
|
||||
)
|
||||
self.log.error("***")
|
||||
|
||||
await self.save_to_db()
|
||||
await self.print_summary(t_start, total_nodes, tried_nodes)
|
||||
await asyncio.sleep(15) # 15 seconds between db updates
|
||||
self._state_changed("crawl_batch_completed")
|
||||
except Exception as e:
|
||||
self.log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
|
||||
|
||||
def set_server(self, server: ChiaServer):
|
||||
async def save_to_db(self) -> None:
|
||||
# Try up to 5 times to write to the DB in case there is a lock that causes a timeout
|
||||
if self.crawl_store is None:
|
||||
raise ValueError("Not Connected to DB")
|
||||
for i in range(1, 5):
|
||||
try:
|
||||
await self.crawl_store.load_to_db()
|
||||
await self.crawl_store.load_reliable_peers_to_db()
|
||||
return
|
||||
except Exception as e:
|
||||
self.log.error(f"Exception while saving to DB: {e}.")
|
||||
self.log.error("Waiting 5 seconds before retry...")
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
|
||||
def set_server(self, server: ChiaServer) -> None:
|
||||
self._server = server
|
||||
|
||||
def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None):
|
||||
def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None) -> None:
|
||||
if self.state_changed_callback is not None:
|
||||
self.state_changed_callback(change, change_data)
|
||||
|
||||
async def new_peak(self, request: full_node_protocol.NewPeak, peer: WSChiaConnection):
|
||||
async def new_peak(self, request: full_node_protocol.NewPeak, peer: WSChiaConnection) -> None:
|
||||
try:
|
||||
peer_info = peer.get_peer_info()
|
||||
tls_version = peer.get_tls_version()
|
||||
@ -359,6 +306,11 @@ class Crawler:
|
||||
tls_version = "unknown"
|
||||
if peer_info is None:
|
||||
return
|
||||
# validate peer ip address:
|
||||
try:
|
||||
ipaddress.ip_address(peer_info.host)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid peer ip address: {peer_info.host}")
|
||||
if request.height >= self.minimum_height:
|
||||
if self.crawl_store is not None:
|
||||
await self.crawl_store.peer_connected_hostname(peer_info.host, True, tls_version)
|
||||
@ -366,12 +318,79 @@ class Crawler:
|
||||
except Exception as e:
|
||||
self.log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
|
||||
|
||||
async def on_connect(self, connection: WSChiaConnection):
|
||||
async def on_connect(self, connection: WSChiaConnection) -> None:
|
||||
pass
|
||||
|
||||
def _close(self):
|
||||
def _close(self) -> None:
|
||||
self._shut_down = True
|
||||
|
||||
async def _await_closed(self):
|
||||
if self.connection is not None:
|
||||
await self.connection.close()
|
||||
async def _await_closed(self) -> None:
|
||||
if self.crawl_task is not None:
|
||||
try:
|
||||
await asyncio.wait_for(self.crawl_task, timeout=10) # wait 10 seconds before giving up
|
||||
except asyncio.TimeoutError:
|
||||
self.log.error("Crawl task did not exit in time, killing task.")
|
||||
self.crawl_task.cancel()
|
||||
if self.crawl_store is not None:
|
||||
self.log.info("Closing connection to DB.")
|
||||
await self.crawl_store.crawl_db.close()
|
||||
|
||||
async def print_summary(self, t_start: float, total_nodes: int, tried_nodes: Set[str]) -> None:
|
||||
assert self.crawl_store is not None # this is only ever called from the crawl task
|
||||
if not self.print_status:
|
||||
return
|
||||
total_records = self.crawl_store.get_total_records()
|
||||
ipv6_count = self.crawl_store.get_ipv6_peers()
|
||||
self.log.warning("***")
|
||||
self.log.warning("Finished batch:")
|
||||
self.log.warning(f"Total IPs stored in DB: {total_records}.")
|
||||
self.log.warning(f"Total IPV6 addresses stored in DB: {ipv6_count}")
|
||||
self.log.warning(f"Total connections attempted since crawler started: {total_nodes}.")
|
||||
self.log.warning(f"Total unique nodes attempted since crawler started: {len(tried_nodes)}.")
|
||||
t_now = time.time()
|
||||
t_delta = int(t_now - t_start)
|
||||
if t_delta > 0:
|
||||
self.log.warning(f"Avg connections per second: {total_nodes // t_delta}.")
|
||||
# Periodically print detailed stats.
|
||||
reliable_peers = self.crawl_store.get_reliable_peers()
|
||||
self.log.warning(f"High quality reachable nodes, used by DNS introducer in replies: {reliable_peers}")
|
||||
banned_peers = self.crawl_store.get_banned_peers()
|
||||
ignored_peers = self.crawl_store.get_ignored_peers()
|
||||
available_peers = len(self.host_to_version)
|
||||
addresses_count = len(self.best_timestamp_per_peer)
|
||||
total_records = self.crawl_store.get_total_records()
|
||||
ipv6_addresses_count = 0
|
||||
for host in self.best_timestamp_per_peer.keys():
|
||||
try:
|
||||
ipaddress.IPv6Address(host)
|
||||
ipv6_addresses_count += 1
|
||||
except ipaddress.AddressValueError:
|
||||
continue
|
||||
self.log.warning(
|
||||
"IPv4 addresses gossiped with timestamp in the last 5 days with respond_peers messages: "
|
||||
f"{addresses_count - ipv6_addresses_count}."
|
||||
)
|
||||
self.log.warning(
|
||||
"IPv6 addresses gossiped with timestamp in the last 5 days with respond_peers messages: "
|
||||
f"{ipv6_addresses_count}."
|
||||
)
|
||||
ipv6_available_peers = 0
|
||||
for host in self.host_to_version.keys():
|
||||
try:
|
||||
ipaddress.IPv6Address(host)
|
||||
ipv6_available_peers += 1
|
||||
except ipaddress.AddressValueError:
|
||||
continue
|
||||
self.log.warning(f"Total IPv4 nodes reachable in the last 5 days: {available_peers - ipv6_available_peers}.")
|
||||
self.log.warning(f"Total IPv6 nodes reachable in the last 5 days: {ipv6_available_peers}.")
|
||||
self.log.warning("Version distribution among reachable in the last 5 days (at least 100 nodes):")
|
||||
for version, count in sorted(self.versions.items(), key=lambda kv: kv[1], reverse=True):
|
||||
if count >= self.minimum_version_count:
|
||||
self.log.warning(f"Version: {version} - Count: {count}")
|
||||
self.log.warning(f"Banned addresses in the DB: {banned_peers}")
|
||||
self.log.warning(f"Temporary ignored addresses in the DB: {ignored_peers}")
|
||||
self.log.warning(
|
||||
"Peers to crawl from in the next batch (total IPs - ignored - banned): "
|
||||
f"{total_records - banned_peers - ignored_peers}"
|
||||
)
|
||||
self.log.warning("***")
|
||||
|
@ -1,294 +1,534 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import logging
|
||||
import random
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
from asyncio import AbstractEventLoop
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from ipaddress import IPv4Address, IPv6Address, ip_address
|
||||
from multiprocessing import freeze_support
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
import aiosqlite
|
||||
from dnslib import AAAA, CNAME, MX, NS, QTYPE, RR, SOA, A, DNSHeader, DNSRecord
|
||||
from dnslib import AAAA, EDNS0, NS, QTYPE, RCODE, RD, RR, SOA, A, DNSError, DNSHeader, DNSQuestion, DNSRecord
|
||||
|
||||
from chia.util.chia_logging import initialize_logging
|
||||
from chia.util.config import load_config
|
||||
from chia.seeder.crawl_store import CrawlStore
|
||||
from chia.util.chia_logging import initialize_service_logging
|
||||
from chia.util.config import load_config, load_config_cli
|
||||
from chia.util.default_root import DEFAULT_ROOT_PATH
|
||||
from chia.util.path import path_from_root
|
||||
|
||||
SERVICE_NAME = "seeder"
|
||||
log = logging.getLogger(__name__)
|
||||
DnsCallback = Callable[[DNSRecord], Awaitable[DNSRecord]]
|
||||
|
||||
|
||||
# DNS snippet taken from: https://gist.github.com/pklaus/b5a7876d4d2cf7271873
|
||||
|
||||
|
||||
class DomainName(str):
    """
    A string subclass that builds sub-domain names through attribute access.

    ``DomainName("example.com.").ns1`` evaluates to ``DomainName("ns1.example.com.")``.
    """

    def __getattr__(self, item: str) -> DomainName:
        """
        Return ``item`` prepended to this name as a new label.

        Raises AttributeError for dunder names, so that protocol probes are answered as "not defined".
        """
        # __getattr__ only runs for names that normal lookup did not find. Dunder names are probed by
        # protocols such as copy (e.g. __deepcopy__); answering those with a DomainName would hand the
        # caller a non-callable "hook", so they are reported as missing instead.
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(item)
        return DomainName(item + "." + self)  # DomainName.NS becomes DomainName("NS.DomainName")
|
||||
|
||||
|
||||
D = None
|
||||
ns = None
|
||||
IP = "127.0.0.1"
|
||||
TTL = None
|
||||
soa_record = None
|
||||
ns_records: List[Any] = []
|
||||
@dataclass(frozen=True)
class PeerList:
    """Frozen pair of peer address lists, one for IPv4 and one for IPv6."""

    ipv4: List[IPv4Address]
    ipv6: List[IPv6Address]

    @property
    def no_peers(self) -> bool:
        """True when both the IPv4 and the IPv6 list are empty."""
        return not (self.ipv4 or self.ipv6)
|
||||
|
||||
|
||||
class EchoServerProtocol(asyncio.DatagramProtocol):
|
||||
def __init__(self, callback):
|
||||
self.data_queue = asyncio.Queue()
|
||||
self.callback = callback
|
||||
asyncio.ensure_future(self.respond())
|
||||
@dataclass
|
||||
class UDPDNSServerProtocol(asyncio.DatagramProtocol):
|
||||
"""
|
||||
This is a really simple UDP Server, that converts all requests to DNSRecord objects and passes them to the callback.
|
||||
"""
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
callback: DnsCallback
|
||||
transport: Optional[asyncio.DatagramTransport] = field(init=False, default=None)
|
||||
data_queue: asyncio.Queue[tuple[DNSRecord, tuple[str, int]]] = field(default_factory=asyncio.Queue)
|
||||
queue_task: Optional[asyncio.Task[None]] = field(init=False, default=None)
|
||||
|
||||
def datagram_received(self, data, addr):
|
||||
asyncio.ensure_future(self.handler(data, addr))
|
||||
def start(self) -> None:
    """Schedule respond() as a task on the running event loop and keep the task in queue_task."""
    responder = self.respond()
    self.queue_task = asyncio.create_task(responder)  # This starts the dns respond loop.
|
||||
|
||||
async def respond(self):
|
||||
while True:
|
||||
async def stop(self) -> None:
    """
    Shut the UDP side down: cancel the responder task, wait for it to finish, then close the transport.

    Both steps are skipped when the corresponding attribute is still None.
    """
    if self.queue_task is not None:
        self.queue_task.cancel()
        try:
            await self.queue_task
        except asyncio.CancelledError:  # we dont care
            pass
    if self.transport is not None:
        self.transport.close()
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
    """Remember the transport handed to this protocol so that replies can be sent through it."""
    # The parameter is typed as BaseTransport while the attribute is declared with the datagram
    # transport type, which is why the assignment carries an ignore.
    self.transport = transport  # type: ignore[assignment]
|
||||
|
||||
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
    """
    Parse one incoming datagram and, when it is a valid DNS request, schedule handler() for it.

    Datagrams that parse_dns_request rejects (returns None for) are dropped without a reply.
    """
    log.debug(f"Received UDP DNS request from {addr}.")
    parsed: Optional[DNSRecord] = parse_dns_request(data)
    if parsed is not None:
        # NOTE(review): the task returned by create_task is not stored anywhere; asyncio documents
        # that callers should keep a reference to created tasks - confirm this is acceptable here.
        asyncio.create_task(self.handler(parsed, addr))
|
||||
|
||||
async def respond(self) -> None:
    """
    Send queued replies until the transport starts closing.

    Waits (polling every 0.1s) for the transport to be set, then repeatedly takes a (reply, caller) pair
    from data_queue, packs the reply and sends it to the caller. A reply larger than the allowed size is
    truncated first. Exceptions raised while handling one reply are logged and the loop continues.
    """
    log.info("UDP DNS responder started.")
    while self.transport is None:  # we wait for the transport to be set.
        await asyncio.sleep(0.1)
    while not self.transport.is_closing():
        try:
            edns_max_size = 0
            reply, caller = await self.data_queue.get()
            # An OPT record in the additional section carries the EDNS payload size.
            if len(reply.ar) > 0 and reply.ar[0].rtype == QTYPE.OPT:
                edns_max_size = reply.ar[0].edns_len

            reply_packed = reply.pack()

            if len(reply_packed) > max(512, edns_max_size):  # 512 is the default max size for DNS:
                log.debug(f"DNS response to {caller} is too large, truncating.")
                reply_packed = reply.truncate().pack()

            self.transport.sendto(reply_packed, caller)
            log.debug(f"Sent UDP DNS response to {caller}, of size {len(reply_packed)}.")
        except Exception as e:
            log.error(f"Exception while responding to UDP DNS request: {e}. Traceback: {traceback.format_exc()}.")
    log.info("UDP DNS responder stopped.")
|
||||
|
||||
async def handler(self, data, caller):
|
||||
async def handler(self, data: DNSRecord, caller: tuple[str, int]) -> None:
    """Produce the reply for one parsed request and put it, with the caller's address, on data_queue."""
    # get_dns_reply runs the callback and hands back a DNSRecord response.
    reply = await get_dns_reply(self.callback, data)
    queue_item = (reply, caller)
    await self.data_queue.put(queue_item)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TCPDNSServerProtocol(asyncio.BufferedProtocol):
|
||||
"""
|
||||
This TCP server is a little more complicated, because we need to handle the length field, however it still
|
||||
converts all requests to DNSRecord objects and passes them to the callback, after receiving the full message.
|
||||
"""
|
||||
|
||||
callback: DnsCallback
|
||||
transport: Optional[asyncio.Transport] = field(init=False, default=None)
|
||||
peer_info: str = field(init=False, default="")
|
||||
expected_length: int = 0
|
||||
buffer: bytearray = field(init=False, default_factory=lambda: bytearray(2))
|
||||
futures: List[asyncio.Future[None]] = field(init=False, default_factory=list)
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
    """
    Called for every new TCP connection: stores the transport and records the peer as "host:port".
    """
    # The parameter is typed as BaseTransport while the attribute is declared with the stream
    # transport type, which is why the assignment carries an ignore.
    self.transport = transport  # type: ignore[assignment]
    peername = transport.get_extra_info("peername")
    host, port = peername[0], peername[1]
    self.peer_info = f"{host}:{port}"
    log.debug(f"TCP connection established with {self.peer_info}.")
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception]) -> None:
    """
    Called when the connection is lost or closed: logs the event, cancels every tracked future and
    puts the receive state (futures, buffer, expected_length) back to its initial values.
    """
    if exc is None:
        log.debug(f"TCP DNS connection closed with {self.peer_info}.")
    else:
        log.debug(f"TCP DNS connection lost with {self.peer_info}. Exception: {exc}.")
    # reset the state of the protocol.
    for pending in self.futures:
        pending.cancel()
    self.futures = []
    self.buffer = bytearray(2)
    self.expected_length = 0
|
||||
|
||||
def get_buffer(self, sizehint: int) -> memoryview:
    """
    Hand out a writable view over the current receive buffer.

    sizehint is not used; the view always covers the whole of self.buffer. After data has been
    written into the view, buffer_updated is called.
    """
    view = memoryview(self.buffer)
    return view
|
||||
|
||||
def buffer_updated(self, nbytes: int) -> None:
    """
    Called after data was written into the buffer. Reads the 2 byte length field, then the message of
    that length, parses the message into a DNS request and schedules handle_and_respond for it.

    An unparseable message closes the transport. On exit the buffer is replaced by a fresh one sized for
    whatever is expected next: 2 bytes for a length field, otherwise expected_length bytes.
    """
    # NOTE(review): nbytes is never consulted; the logic treats the buffer as completely filled on
    # every call - TODO confirm that partially filled buffers cannot occur here.
    while len(self.buffer) != 0 and self.transport is not None:
        if not self.expected_length:
            # Length field received (This is the first part of the message)
            self.expected_length = int.from_bytes(self.buffer, byteorder="big")
            self.buffer = self.buffer[2:]  # Remove the length field from the buffer.

        if len(self.buffer) >= self.expected_length:
            # This is the rest of the message (after the length field)
            end = self.expected_length
            message = self.buffer[:end]
            self.buffer = self.buffer[end:]  # Remove the message from the buffer
            self.expected_length = 0  # Reset the expected length

            dns_request: Optional[DNSRecord] = parse_dns_request(message)
            if dns_request is None:  # Invalid Request, so we disconnect and don't send anything back.
                self.transport.close()
                return
            pending = asyncio.create_task(self.handle_and_respond(dns_request))
            self.futures.append(pending)

    # Reset the buffer if empty.
    next_size = self.expected_length if self.expected_length != 0 else 2
    self.buffer = bytearray(next_size)
|
||||
|
||||
def eof_received(self) -> Optional[bool]:
    """
    Called when the client closes its side of the connection.

    Returns False (close the connection) when no requests are being tracked, otherwise schedules
    wait_for_futures and returns True so the connection stays open until the replies are out.
    """
    if not self.futures:
        log.info(f"Received early EOF from {self.peer_info}, closing connection.")
        return False

    # Successful requests are still tracked from here on.
    if self.expected_length != 0:  # Incomplete requests
        log.warning(
            f"Received incomplete TCP DNS request of length {self.expected_length} from {self.peer_info}, "
            f"closing connection after dns replies are sent."
        )
    asyncio.create_task(self.wait_for_futures())
    return True  # Keep connection open, until the futures are done.
|
||||
|
||||
async def wait_for_futures(self) -> None:
    """
    Waits for all the futures to complete, and then closes the connection.

    The wait is capped at 10 seconds; a timeout is logged as a warning and the transport is closed anyway.
    """
    try:
        await asyncio.wait_for(asyncio.gather(*self.futures), timeout=10)
    except asyncio.TimeoutError:
        log.warning(f"Timed out waiting for DNS replies to be sent to {self.peer_info}.")
    if self.transport is not None:
        self.transport.close()
|
||||
|
||||
async def handle_and_respond(self, data: DNSRecord) -> None:
    """
    Build the reply for one request and write it, length-prefixed, to the TCP transport.

    Nothing is written when the transport is unset or already closing. Exceptions raised while
    writing are logged and not propagated.
    """
    r_data = await get_dns_reply(self.callback, data)  # process the request, returning a DNSRecord response.
    try:
        # If the client closed the connection, we don't want to send anything.
        if self.transport is not None and not self.transport.is_closing():
            self.transport.write(dns_response_to_tcp(r_data))  # send data back to the client
        log.debug(f"Sent DNS response for {data.q.qname}, to {self.peer_info}.")
    except Exception as e:
        log.error(f"Exception while responding to TCP DNS request: {e}. Traceback: {traceback.format_exc()}.")
|
||||
|
||||
|
||||
def dns_response_to_tcp(data: DNSRecord) -> bytes:
    """
    Serialise a DNSRecord response for TCP: the packed record preceded by its length as a 2 byte
    big-endian integer.
    """
    payload = data.pack()
    length_prefix = len(payload).to_bytes(2, byteorder="big")
    framed = length_prefix + payload
    return bytes(framed)
|
||||
|
||||
|
||||
def create_dns_reply(dns_request: DNSRecord) -> DNSRecord:
    """
    Build an answer-less response for dns_request: same id and question, with the header flags set.
    """
    # QR means query response, AA means authoritative answer, RA means recursion available
    header = DNSHeader(id=dns_request.header.id, qr=1, aa=1, ra=0)
    return DNSRecord(header, q=dns_request.q)
|
||||
|
||||
|
||||
def parse_dns_request(data: bytes) -> Optional[DNSRecord]:
    """
    Parse raw bytes into a DNSRecord.

    Returns None, after logging a warning, when parsing raises DNSError.
    """
    try:
        return DNSRecord.parse(data)
    except DNSError as e:
        log.warning(f"Received invalid DNS request: {e}. Traceback: {traceback.format_exc()}.")
        return None
|
||||
|
||||
|
||||
async def get_dns_reply(callback: DnsCallback, dns_request: DNSRecord) -> DNSRecord:
    """
    Await callback(dns_request) and return its reply; if the callback raises, log the error and
    return a header-only reply whose rcode is SERVFAIL.
    """
    try:
        return await callback(dns_request)
    except Exception as e:
        log.error(f"Exception during DNS record processing: {e}. Traceback: {traceback.format_exc()}.")
        # we return an empty response with an error code
        failure = create_dns_reply(dns_request)  # This is an empty response, with only the header set.
        failure.header.rcode = RCODE.SERVFAIL
        return failure
|
||||
|
||||
|
||||
@dataclass
|
||||
class DNSServer:
|
||||
reliable_peers_v4: List[str]
|
||||
reliable_peers_v6: List[str]
|
||||
lock: asyncio.Lock
|
||||
pointer: int
|
||||
crawl_db: aiosqlite.Connection
|
||||
config: Dict[str, Any]
|
||||
root_path: Path
|
||||
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
shutdown_event: asyncio.Event = field(default_factory=asyncio.Event)
|
||||
crawl_store: Optional[CrawlStore] = field(init=False, default=None)
|
||||
reliable_task: Optional[asyncio.Task[None]] = field(init=False, default=None)
|
||||
shutdown_task: Optional[asyncio.Task[None]] = field(init=False, default=None)
|
||||
udp_transport_ipv4: Optional[asyncio.DatagramTransport] = field(init=False, default=None)
|
||||
udp_protocol_ipv4: Optional[UDPDNSServerProtocol] = field(init=False, default=None)
|
||||
udp_transport_ipv6: Optional[asyncio.DatagramTransport] = field(init=False, default=None)
|
||||
udp_protocol_ipv6: Optional[UDPDNSServerProtocol] = field(init=False, default=None)
|
||||
# TODO: After 3.10 is dropped change to asyncio.Server
|
||||
tcp_server: Optional[asyncio.base_events.Server] = field(init=False, default=None)
|
||||
# these are all set in __post_init__
|
||||
dns_port: int = field(init=False)
|
||||
db_path: Path = field(init=False)
|
||||
domain: DomainName = field(init=False)
|
||||
ns1: DomainName = field(init=False)
|
||||
ns_records: List[RR] = field(init=False)
|
||||
ttl: int = field(init=False)
|
||||
soa_record: RR = field(init=False)
|
||||
reliable_peers_v4: List[IPv4Address] = field(default_factory=list)
|
||||
reliable_peers_v6: List[IPv6Address] = field(default_factory=list)
|
||||
pointer_v4: int = 0
|
||||
pointer_v6: int = 0
|
||||
|
||||
def __init__(self, config: Dict, root_path: Path):
|
||||
self.reliable_peers_v4 = []
|
||||
self.reliable_peers_v6 = []
|
||||
self.lock = asyncio.Lock()
|
||||
self.pointer_v4 = 0
|
||||
self.pointer_v6 = 0
|
||||
|
||||
crawler_db_path: str = config.get("crawler_db_path", "crawler.db")
|
||||
self.db_path = path_from_root(root_path, crawler_db_path)
|
||||
def __post_init__(self) -> None:
    """
    We initialize all the variables set to field(init=False) here.

    Reads dns_port (default 53) and crawler_db_path (default "crawler.db") from self.config, creates the
    parent directory of the database path, and builds the domain, nameserver, NS and SOA data from the
    "domain_name", "nameserver", "ttl" and "soa" config entries. A missing entry among the latter raises
    KeyError.
    """
    # From Config
    self.dns_port: int = self.config.get("dns_port", 53)
    # DB Path
    crawler_db_path: str = self.config.get("crawler_db_path", "crawler.db")
    self.db_path: Path = path_from_root(self.root_path, crawler_db_path)
    self.db_path.parent.mkdir(parents=True, exist_ok=True)
    # DNS info
    self.domain: DomainName = DomainName(self.config["domain_name"])
    if not self.domain.endswith("."):
        self.domain = DomainName(self.domain + ".")  # Make sure the domain ends with a period, as per RFC 1035.
    self.ns1: DomainName = DomainName(self.config["nameserver"])
    self.ns_records: List[NS] = [NS(self.ns1)]
    self.ttl: int = self.config["ttl"]
    self.soa_record: SOA = SOA(
        mname=self.ns1,  # primary name server
        rname=self.config["soa"]["rname"],  # email of the domain administrator
        times=(
            self.config["soa"]["serial_number"],
            self.config["soa"]["refresh"],
            self.config["soa"]["retry"],
            self.config["soa"]["expire"],
            self.config["soa"]["minimum"],
        ),
    )
|
||||
|
||||
@asynccontextmanager
async def run(self) -> AsyncIterator[None]:
    """
    Async context manager that runs the DNS server for the duration of the ``async with`` body.

    On entry it installs the signal handlers, opens the crawl store, starts the peer update task and
    binds the UDP endpoint(s) and the TCP server. On exit, including exit by exception, it calls stop().
    """
    log.info("Starting DNS server.")
    # Get a reference to the event loop as we plan to use low-level APIs.
    loop = asyncio.get_running_loop()
    await self.setup_signal_handlers(loop)

    # Set up the crawl store and the peer update task.
    self.crawl_store = await CrawlStore.create(await aiosqlite.connect(self.db_path, timeout=120))
    self.reliable_task = asyncio.create_task(self.periodically_get_reliable_peers())

    # One protocol instance will be created for each udp transport, so that we can accept ipv4 and ipv6
    self.udp_transport_ipv6, self.udp_protocol_ipv6 = await loop.create_datagram_endpoint(
        lambda: UDPDNSServerProtocol(self.dns_response), local_addr=("::0", self.dns_port)
    )
    self.udp_protocol_ipv6.start()  # start ipv6 udp transmit task

    # in case the port is 0 we need all protocols on the same port.
    self.dns_port = self.udp_transport_ipv6.get_extra_info("sockname")[1]  # get the actual port we are listening to

    if sys.platform.startswith("win32") or sys.platform.startswith("cygwin"):
        # Windows does not support dual stack sockets, so we need to create a new socket for ipv4.
        self.udp_transport_ipv4, self.udp_protocol_ipv4 = await loop.create_datagram_endpoint(
            lambda: UDPDNSServerProtocol(self.dns_response), local_addr=("0.0.0.0", self.dns_port)
        )
        self.udp_protocol_ipv4.start()  # start ipv4 udp transmit task

    # One tcp server will handle both ipv4 and ipv6 on both linux and windows.
    self.tcp_server = await loop.create_server(
        lambda: TCPDNSServerProtocol(self.dns_response), ["::0", "0.0.0.0"], self.dns_port
    )

    log.info("DNS server started.")
    try:
        yield
    finally:  # catches any errors and properly shuts down the server
        await self.stop()
        log.info("DNS server stopped.")
|
||||
|
||||
async def setup_signal_handlers(self, loop: AbstractEventLoop) -> None:
    """
    Register _accept_signal as the loop's handler for SIGINT and SIGTERM.

    When the loop raises NotImplementedError for signal handlers, a warning is logged instead.
    """
    try:
        for signum in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(signum, self._accept_signal)
    except NotImplementedError:
        log.warning("signal handlers unsupported on this platform")
|
||||
|
||||
def _accept_signal(self) -> None:  # pragma: no cover
    """Signal callback: schedule stop() as a task, unless a shutdown task already exists."""
    if self.shutdown_task is not None:
        # we are already shutting down, so we ignore the signal
        return
    self.shutdown_task = asyncio.create_task(self.stop())
|
||||
|
||||
async def stop(self) -> None:
    """
    Tear the server down: cancel the peer update task, close the crawl database, stop the UDP
    protocol(s), close the TCP server and finally set shutdown_event. Parts that were never set up
    (still None) are skipped.
    """
    log.info("Stopping DNS server...")
    peer_task = self.reliable_task
    if peer_task is not None:
        peer_task.cancel()  # cancel the peer update task
    store = self.crawl_store
    if store is not None:
        await store.crawl_db.close()
    if self.udp_protocol_ipv6 is not None:
        # stop responding to and accepting udp requests (ipv6) & ipv4 if linux.
        await self.udp_protocol_ipv6.stop()
    if self.udp_protocol_ipv4 is not None:
        # stop responding to and accepting udp requests (ipv4) if windows.
        await self.udp_protocol_ipv4.stop()
    server = self.tcp_server
    if server is not None:
        server.close()  # stop accepting new tcp requests (ipv4 and ipv6)
        await server.wait_closed()  # wait for existing TCP requests to finish (ipv4 and ipv6)
    self.shutdown_event.set()
|
||||
|
||||
async def periodically_get_reliable_peers(self) -> None:
|
||||
sleep_interval = 0
|
||||
while True:
|
||||
while not self.shutdown_event.is_set() and self.crawl_store is not None:
|
||||
try:
|
||||
# TODO: double check this. It shouldn't take this long to connect.
|
||||
crawl_db = await aiosqlite.connect(self.db_path, timeout=600)
|
||||
cursor = await crawl_db.execute(
|
||||
"SELECT * from good_peers",
|
||||
)
|
||||
new_reliable_peers = []
|
||||
rows = await cursor.fetchall()
|
||||
await cursor.close()
|
||||
await crawl_db.close()
|
||||
for row in rows:
|
||||
new_reliable_peers.append(row[0])
|
||||
if len(new_reliable_peers) > 0:
|
||||
random.shuffle(new_reliable_peers)
|
||||
async with self.lock:
|
||||
self.reliable_peers_v4 = []
|
||||
self.reliable_peers_v6 = []
|
||||
for peer in new_reliable_peers:
|
||||
ipv4 = True
|
||||
try:
|
||||
_ = ipaddress.IPv4Address(peer)
|
||||
except ValueError:
|
||||
ipv4 = False
|
||||
if ipv4:
|
||||
self.reliable_peers_v4.append(peer)
|
||||
else:
|
||||
try:
|
||||
_ = ipaddress.IPv6Address(peer)
|
||||
except ValueError:
|
||||
continue
|
||||
self.reliable_peers_v6.append(peer)
|
||||
self.pointer_v4 = 0
|
||||
self.pointer_v6 = 0
|
||||
log.error(
|
||||
new_reliable_peers = await self.crawl_store.get_good_peers()
|
||||
except Exception as e:
|
||||
log.error(f"Error loading reliable peers from database: {e}. Traceback: {traceback.format_exc()}.")
|
||||
continue
|
||||
if len(new_reliable_peers) == 0:
|
||||
log.info("No reliable peers found in database, waiting for db to be populated.")
|
||||
await asyncio.sleep(2) # sleep for 2 seconds, because the db has not been populated yet.
|
||||
continue
|
||||
async with self.lock:
|
||||
self.reliable_peers_v4 = []
|
||||
self.reliable_peers_v6 = []
|
||||
self.pointer_v4 = 0
|
||||
self.pointer_v6 = 0
|
||||
for peer in new_reliable_peers:
|
||||
try:
|
||||
validated_peer = ip_address(peer)
|
||||
if validated_peer.version == 4:
|
||||
self.reliable_peers_v4.append(validated_peer)
|
||||
elif validated_peer.version == 6:
|
||||
self.reliable_peers_v6.append(validated_peer)
|
||||
except ValueError:
|
||||
log.error(f"Invalid peer: {peer}")
|
||||
continue
|
||||
log.info(
|
||||
f"Number of reliable peers discovered in dns server:"
|
||||
f" IPv4 count - {len(self.reliable_peers_v4)}"
|
||||
f" IPv6 count - {len(self.reliable_peers_v6)}"
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
|
||||
|
||||
sleep_interval = min(15, sleep_interval + 1)
|
||||
await asyncio.sleep(sleep_interval * 60)
|
||||
|
||||
async def get_peers_to_respond(self, ipv4_count, ipv6_count):
|
||||
peers = []
|
||||
async def get_peers_to_respond(self, ipv4_count: int, ipv6_count: int) -> PeerList:
    """
    Pick up to ipv4_count IPv4 and ipv6_count IPv6 reliable peers, under self.lock.

    For each family: a count of 0 yields an empty list; when no more peers are known than requested,
    the whole list is returned; otherwise a window of ``count`` peers starting at the family's pointer
    is returned (wrapping around) and the pointer is advanced so the next call continues from there.
    """
    async with self.lock:
        # Append IPv4.
        ipv4_peers: List[IPv4Address] = []
        size = len(self.reliable_peers_v4)
        if ipv4_count > 0 and size <= ipv4_count:
            ipv4_peers = self.reliable_peers_v4
        elif ipv4_count > 0:
            ipv4_peers = [
                self.reliable_peers_v4[i % size] for i in range(self.pointer_v4, self.pointer_v4 + ipv4_count)
            ]
            self.pointer_v4 = (self.pointer_v4 + ipv4_count) % size  # mark where we left off
        # Append IPv6.
        ipv6_peers: List[IPv6Address] = []
        size = len(self.reliable_peers_v6)
        if ipv6_count > 0 and size <= ipv6_count:
            ipv6_peers = self.reliable_peers_v6
        elif ipv6_count > 0:
            ipv6_peers = [
                self.reliable_peers_v6[i % size] for i in range(self.pointer_v6, self.pointer_v6 + ipv6_count)
            ]
            self.pointer_v6 = (self.pointer_v6 + ipv6_count) % size  # mark where we left off
        return PeerList(ipv4_peers, ipv6_peers)
|
||||
|
||||
async def dns_response(self, data):
|
||||
try:
|
||||
request = DNSRecord.parse(data)
|
||||
IPs = [MX(D.mail), soa_record] + ns_records
|
||||
ipv4_count = 0
|
||||
ipv6_count = 0
|
||||
if request.q.qtype == 1:
|
||||
ipv4_count = 32
|
||||
elif request.q.qtype == 28:
|
||||
ipv6_count = 32
|
||||
elif request.q.qtype == 255:
|
||||
ipv4_count = 16
|
||||
ipv6_count = 16
|
||||
else:
|
||||
ipv4_count = 32
|
||||
peers = await self.get_peers_to_respond(ipv4_count, ipv6_count)
|
||||
if len(peers) == 0:
|
||||
return None
|
||||
for peer in peers:
|
||||
ipv4 = True
|
||||
try:
|
||||
_ = ipaddress.IPv4Address(peer)
|
||||
except ValueError:
|
||||
ipv4 = False
|
||||
if ipv4:
|
||||
IPs.append(A(peer))
|
||||
else:
|
||||
try:
|
||||
_ = ipaddress.IPv6Address(peer)
|
||||
except ValueError:
|
||||
continue
|
||||
IPs.append(AAAA(peer))
|
||||
reply = DNSRecord(DNSHeader(id=request.header.id, qr=1, aa=len(IPs), ra=1), q=request.q)
|
||||
async def dns_response(self, request: DNSRecord) -> DNSRecord:
    """
    This function is called when a DNS request is received, and it returns a DNS response.
    It does not catch any errors as it is called from within a try-except block.

    Names outside self.domain are answered with REFUSED. Otherwise the SOA and NS records plus the
    selected peers (as A / AAAA records) are candidates, and those matching the question type are added
    as answers. NS and SOA records are always added to the authority section.
    """
    reply = create_dns_reply(request)
    dns_question: DNSQuestion = request.q  # this is the question / request
    question_type: int = dns_question.qtype  # the type of the record being requested
    qname = dns_question.qname  # the name being queried / requested
    # ADD EDNS0 to response if supported
    if len(request.ar) > 0 and request.ar[0].rtype == QTYPE.OPT:  # OPT Means EDNS
        udp_len = min(4096, request.ar[0].edns_len)
        edns_reply = EDNS0(udp_len=udp_len)
        reply.add_ar(edns_reply)
    # DNS labels are mixed case with DNS resolvers that implement the use of bit 0x20 to improve
    # transaction identity. See https://datatracker.ietf.org/doc/html/draft-vixie-dnsext-dns0x20-00
    qname_str = str(qname).lower()
    if qname_str != self.domain and not qname_str.endswith("." + self.domain):
        # we don't answer for other domains (we have the not recursive bit set)
        log.warning(f"Invalid request for {qname_str}, returning REFUSED.")
        reply.header.rcode = RCODE.REFUSED
        return reply

    ttl: int = self.ttl
    # we add these to the list as it will allow us to respond to ns and soa requests
    ips: List[RD] = [self.soa_record] + self.ns_records
    ipv4_count = 0
    ipv6_count = 0
    # Record types are plain integers, so they are compared by value rather than by identity.
    if question_type == QTYPE.A:
        ipv4_count = 32
    elif question_type == QTYPE.AAAA:
        ipv6_count = 32
    elif question_type == QTYPE.ANY:
        ipv4_count = 16
        ipv6_count = 16
    else:
        ipv4_count = 32
    peers: PeerList = await self.get_peers_to_respond(ipv4_count, ipv6_count)
    if peers.no_peers:
        log.error("No peers found, returning SOA and NS records only.")
        ttl = 60  # 1 minute as we should have some peers very soon
    # we always return the SOA and NS records, so we continue even if there are no peers
    ips.extend([A(str(peer)) for peer in peers.ipv4])
    ips.extend([AAAA(str(peer)) for peer in peers.ipv6])

    records: Dict[DomainName, List[RD]] = {  # this is where we can add other records we want to serve
        self.domain: ips,
    }

    valid_domain = False
    for domain_name, domain_responses in records.items():
        if domain_name == qname_str:  # if the dns name is the same as the requested name
            valid_domain = True
            for response in domain_responses:
                # The rdata class name (A, AAAA, NS, SOA, ...) doubles as the QTYPE attribute name.
                rqt: int = getattr(QTYPE, response.__class__.__name__)
                if question_type == rqt or question_type == QTYPE.ANY:
                    reply.add_answer(RR(rname=qname, rtype=rqt, rclass=1, ttl=ttl, rdata=response))
    if not valid_domain and len(reply.rr) == 0:  # if we didn't find any records to return
        reply.header.rcode = RCODE.NXDOMAIN
    # always put nameservers and the SOA records
    for nameserver in self.ns_records:
        reply.add_auth(RR(rname=self.domain, rtype=QTYPE.NS, rclass=1, ttl=ttl, rdata=nameserver))
    reply.add_auth(RR(rname=self.domain, rtype=QTYPE.SOA, rclass=1, ttl=ttl, rdata=self.soa_record))
    return reply
|
||||
|
||||
|
||||
async def serve_dns(config: Dict, root_path: Path):
|
||||
dns_server = DNSServer(config, root_path)
|
||||
await dns_server.start()
|
||||
|
||||
# TODO: Make this cleaner?
|
||||
while True:
|
||||
await asyncio.sleep(3600)
|
||||
async def run_dns_server(dns_server: DNSServer) -> None: # pragma: no cover
|
||||
async with dns_server.run():
|
||||
await dns_server.shutdown_event.wait() # this is released on SIGINT or SIGTERM or any unhandled exception
|
||||
|
||||
|
||||
async def kill_processes():
|
||||
# TODO: implement.
|
||||
pass
|
||||
def create_dns_server_service(config: Dict[str, Any], root_path: Path) -> DNSServer:
    """Build a DNSServer from the SERVICE_NAME section of the full config and the given root path."""
    return DNSServer(config[SERVICE_NAME], root_path)
|
||||
|
||||
|
||||
def signal_received():
|
||||
asyncio.create_task(kill_processes())
|
||||
|
||||
|
||||
async def async_main(config, root_path):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
try:
|
||||
loop.add_signal_handler(signal.SIGINT, signal_received)
|
||||
loop.add_signal_handler(signal.SIGTERM, signal_received)
|
||||
except NotImplementedError:
|
||||
log.info("signal handlers unsupported")
|
||||
|
||||
await serve_dns(config, root_path)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None: # pragma: no cover
|
||||
freeze_support()
|
||||
root_path = DEFAULT_ROOT_PATH
|
||||
config = load_config(root_path, "config.yaml", SERVICE_NAME)
|
||||
initialize_logging(SERVICE_NAME, config["logging"], root_path)
|
||||
global D
|
||||
global ns
|
||||
global TTL
|
||||
global soa_record
|
||||
global ns_records
|
||||
D = DomainName(config["domain_name"])
|
||||
ns = DomainName(config["nameserver"])
|
||||
TTL = config["ttl"]
|
||||
soa_record = SOA(
|
||||
mname=ns, # primary name server
|
||||
rname=config["soa"]["rname"], # email of the domain administrator
|
||||
times=(
|
||||
config["soa"]["serial_number"],
|
||||
config["soa"]["refresh"],
|
||||
config["soa"]["retry"],
|
||||
config["soa"]["expire"],
|
||||
config["soa"]["minimum"],
|
||||
),
|
||||
)
|
||||
ns_records = [NS(ns)]
|
||||
# TODO: refactor to avoid the double load
|
||||
config = load_config(DEFAULT_ROOT_PATH, "config.yaml")
|
||||
service_config = load_config_cli(DEFAULT_ROOT_PATH, "config.yaml", SERVICE_NAME)
|
||||
config[SERVICE_NAME] = service_config
|
||||
initialize_service_logging(service_name=SERVICE_NAME, config=config)
|
||||
|
||||
asyncio.run(async_main(config=config, root_path=root_path))
|
||||
dns_server = create_dns_server_service(config, root_path)
|
||||
asyncio.run(run_dns_server(dns_server))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -24,23 +24,19 @@ class PeerRecord(Streamable):
|
||||
handshake_time: uint64
|
||||
tls_version: str
|
||||
|
||||
def update_version(self, version, now):
|
||||
def update_version(self, version: str, now: uint64) -> None:
|
||||
if version != "undefined":
|
||||
object.__setattr__(self, "version", version)
|
||||
object.__setattr__(self, "handshake_time", uint64(now))
|
||||
object.__setattr__(self, "handshake_time", now)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PeerStat:
|
||||
weight: float
|
||||
count: float
|
||||
reliability: float
|
||||
|
||||
def __init__(self, weight, count, reliability):
|
||||
self.weight = weight
|
||||
self.count = count
|
||||
self.reliability = reliability
|
||||
|
||||
def update(self, is_reachable: bool, age: int, tau: int):
|
||||
def update(self, is_reachable: bool, age: int, tau: int) -> None:
|
||||
f = math.exp(-age / tau)
|
||||
self.reliability = self.reliability * f + (1.0 - f if is_reachable else 0.0)
|
||||
self.count = self.count * f + 1.0
|
||||
@ -81,7 +77,7 @@ class PeerReliability:
|
||||
stat_1m_reliability: float = 0.0,
|
||||
tries: int = 0,
|
||||
successes: int = 0,
|
||||
):
|
||||
) -> None:
|
||||
self.peer_id = peer_id
|
||||
self.ignore_till = ignore_till
|
||||
self.ban_till = ban_till
|
||||
@ -132,7 +128,7 @@ class PeerReliability:
|
||||
return 2 * 3600
|
||||
return 0
|
||||
|
||||
def update(self, is_reachable: bool, age: int):
|
||||
def update(self, is_reachable: bool, age: int) -> None:
|
||||
self.stat_2h.update(is_reachable, age, 2 * 3600)
|
||||
self.stat_8h.update(is_reachable, age, 8 * 3600)
|
||||
self.stat_1d.update(is_reachable, age, 24 * 3600)
|
||||
|
@ -4,7 +4,7 @@ import logging
|
||||
import pathlib
|
||||
import sys
|
||||
from multiprocessing import freeze_support
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from chia.consensus.constants import ConsensusConstants
|
||||
from chia.consensus.default_constants import DEFAULT_CONSTANTS
|
||||
@ -26,24 +26,25 @@ log = logging.getLogger(__name__)
|
||||
|
||||
def create_full_node_crawler_service(
|
||||
root_path: pathlib.Path,
|
||||
config: Dict,
|
||||
config: Dict[str, Any],
|
||||
consensus_constants: ConsensusConstants,
|
||||
connect_to_daemon: bool = True,
|
||||
) -> Service[Crawler, CrawlerAPI]:
|
||||
service_config = config[SERVICE_NAME]
|
||||
crawler_config = service_config["crawler"]
|
||||
|
||||
crawler = Crawler(
|
||||
service_config,
|
||||
root_path=root_path,
|
||||
consensus_constants=consensus_constants,
|
||||
constants=consensus_constants,
|
||||
)
|
||||
api = CrawlerAPI(crawler)
|
||||
|
||||
network_id = service_config["selected_network"]
|
||||
|
||||
rpc_info: Optional[RpcInfo] = None
|
||||
if service_config.get("start_rpc_server", True):
|
||||
rpc_info = (CrawlerRpcApi, service_config.get("rpc_port", 8561))
|
||||
if crawler_config.get("start_rpc_server", True):
|
||||
rpc_info = (CrawlerRpcApi, crawler_config.get("rpc_port", 8561))
|
||||
|
||||
return Service(
|
||||
root_path=root_path,
|
||||
|
@ -24,6 +24,7 @@ from chia.introducer.introducer_api import IntroducerAPI
|
||||
from chia.protocols.shared_protocol import Capability, capabilities
|
||||
from chia.seeder.crawler import Crawler
|
||||
from chia.seeder.crawler_api import CrawlerAPI
|
||||
from chia.seeder.dns_server import DNSServer, create_dns_server_service
|
||||
from chia.seeder.start_crawler import create_full_node_crawler_service
|
||||
from chia.server.start_farmer import create_farmer_service
|
||||
from chia.server.start_full_node import create_full_node_service
|
||||
@ -32,15 +33,17 @@ from chia.server.start_introducer import create_introducer_service
|
||||
from chia.server.start_service import Service
|
||||
from chia.server.start_timelord import create_timelord_service
|
||||
from chia.server.start_wallet import create_wallet_service
|
||||
from chia.simulator.block_tools import BlockTools
|
||||
from chia.simulator.block_tools import BlockTools, test_constants
|
||||
from chia.simulator.keyring import TempKeyring
|
||||
from chia.simulator.ssl_certs import get_next_nodes_certs_and_keys, get_next_private_ca_cert_and_key
|
||||
from chia.simulator.start_simulator import create_full_node_simulator_service
|
||||
from chia.ssl.create_ssl import create_all_ssl
|
||||
from chia.timelord.timelord import Timelord
|
||||
from chia.timelord.timelord_api import TimelordAPI
|
||||
from chia.timelord.timelord_launcher import kill_processes, spawn_process
|
||||
from chia.types.peer_info import UnresolvedPeerInfo
|
||||
from chia.util.bech32m import encode_puzzle_hash
|
||||
from chia.util.config import config_path_for_filename, lock_and_load_config, save_config
|
||||
from chia.util.config import config_path_for_filename, load_config, lock_and_load_config, save_config
|
||||
from chia.util.ints import uint16
|
||||
from chia.util.keychain import bytes_to_mnemonic
|
||||
from chia.util.lock import Lockfile
|
||||
@ -175,24 +178,36 @@ async def setup_full_node(
|
||||
|
||||
|
||||
async def setup_crawler(
|
||||
bt: BlockTools,
|
||||
root_path_populated_with_config: Path, database_uri: str
|
||||
) -> AsyncGenerator[Service[Crawler, CrawlerAPI], None]:
|
||||
config = bt.config
|
||||
create_all_ssl(
|
||||
root_path=root_path_populated_with_config,
|
||||
private_ca_crt_and_key=get_next_private_ca_cert_and_key().collateral.cert_and_key,
|
||||
node_certs_and_keys=get_next_nodes_certs_and_keys().collateral.certs_and_keys,
|
||||
)
|
||||
config = load_config(root_path_populated_with_config, "config.yaml")
|
||||
service_config = config["seeder"]
|
||||
|
||||
service_config["selected_network"] = "testnet0"
|
||||
service_config["port"] = 0
|
||||
service_config["start_rpc_server"] = False
|
||||
service_config["crawler"]["start_rpc_server"] = False
|
||||
service_config["other_peers_port"] = 58444
|
||||
service_config["crawler_db_path"] = database_uri
|
||||
|
||||
overrides = service_config["network_overrides"]["constants"][service_config["selected_network"]]
|
||||
updated_constants = bt.constants.replace_str_to_bytes(**overrides)
|
||||
bt.change_config(config)
|
||||
updated_constants = test_constants.replace_str_to_bytes(**overrides)
|
||||
|
||||
service = create_full_node_crawler_service(
|
||||
bt.root_path,
|
||||
root_path_populated_with_config,
|
||||
config,
|
||||
updated_constants,
|
||||
connect_to_daemon=False,
|
||||
)
|
||||
await service.start()
|
||||
|
||||
if not service_config["crawler"]["start_rpc_server"]: # otherwise the loops don't work.
|
||||
service._node.state_changed_callback = lambda x, y: None
|
||||
|
||||
try:
|
||||
yield service
|
||||
finally:
|
||||
@ -200,6 +215,24 @@ async def setup_crawler(
|
||||
await service.wait_closed()
|
||||
|
||||
|
||||
async def setup_seeder(root_path_populated_with_config: Path, database_uri: str) -> AsyncGenerator[DNSServer, None]:
    """
    Create a DNS server from the on-disk config, adjusted for testing, and yield it while it is running.

    The "seeder" section of config.yaml is overridden in place before the server is built: the network is set
    to testnet0, the dns port to 0, and the crawler db path to `database_uri`.
    """
    config = load_config(root_path_populated_with_config, "config.yaml")
    seeder_config = config["seeder"]

    seeder_config["selected_network"] = "testnet0"
    domain_name = seeder_config["domain_name"]
    if domain_name.endswith("."):  # remove the trailing . so that we can test that logic.
        seeder_config["domain_name"] = domain_name[:-1]
    seeder_config["dns_port"] = 0
    seeder_config["crawler_db_path"] = database_uri

    dns_server = create_dns_server_service(config, root_path_populated_with_config)
    # the server is only up for as long as the consumer holds on to the yielded value
    async with dns_server.run():
        yield dns_server
|
||||
|
||||
|
||||
# Note: convert these setup functions to fixtures, or push it one layer up,
|
||||
# keeping these usable independently?
|
||||
async def setup_wallet_node(
|
||||
|
@ -139,6 +139,8 @@ seeder:
|
||||
port: 8444
|
||||
# Most full nodes on the network run on this port. (i.e. 8444 for mainnet, 58444 for testnet).
|
||||
other_peers_port: 8444
|
||||
# The port the DNS server listens on (change this if port 53 is already being used for DNS on this machine).
|
||||
dns_port: 53
|
||||
# This will override the default full_node.peer_connect_timeout for the crawler full node
|
||||
peer_connect_timeout: 2
|
||||
# Path to crawler DB. Defaults to $CHIA_ROOT/crawler.db
|
||||
@ -154,7 +156,7 @@ seeder:
|
||||
nameserver: "example.com."
|
||||
ttl: 300
|
||||
soa:
|
||||
rname: "hostmaster.example.com."
|
||||
rname: "hostmaster.example.com" # all @ symbols need to be replaced with . in dns records.
|
||||
serial_number: 1619105223
|
||||
refresh: 10800
|
||||
retry: 10800
|
||||
|
@ -16,11 +16,6 @@ chia.rpc.rpc_client
|
||||
chia.rpc.util
|
||||
chia.rpc.wallet_rpc_api
|
||||
chia.rpc.wallet_rpc_client
|
||||
chia.seeder.crawl_store
|
||||
chia.seeder.crawler
|
||||
chia.seeder.dns_server
|
||||
chia.seeder.peer_record
|
||||
chia.seeder.start_crawler
|
||||
chia.simulator.full_node_simulator
|
||||
chia.simulator.keyring
|
||||
chia.simulator.setup_services
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
import datetime
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import sysconfig
|
||||
import tempfile
|
||||
from enum import Enum
|
||||
@ -32,6 +33,7 @@ from chia.rpc.harvester_rpc_client import HarvesterRpcClient
|
||||
from chia.rpc.wallet_rpc_client import WalletRpcClient
|
||||
from chia.seeder.crawler import Crawler
|
||||
from chia.seeder.crawler_api import CrawlerAPI
|
||||
from chia.seeder.dns_server import DNSServer
|
||||
from chia.server.server import ChiaServer
|
||||
from chia.server.start_service import Service
|
||||
from chia.simulator.full_node_simulator import FullNodeSimulator
|
||||
@ -44,7 +46,7 @@ from chia.simulator.setup_nodes import (
|
||||
setup_simulators_and_wallets_service,
|
||||
setup_two_nodes,
|
||||
)
|
||||
from chia.simulator.setup_services import setup_crawler, setup_daemon, setup_introducer, setup_timelord
|
||||
from chia.simulator.setup_services import setup_crawler, setup_daemon, setup_introducer, setup_seeder, setup_timelord
|
||||
from chia.simulator.time_out_assert import time_out_assert
|
||||
from chia.simulator.wallet_tools import WalletTool
|
||||
from chia.types.peer_info import PeerInfo
|
||||
@ -872,11 +874,19 @@ async def timelord_service(bt):
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def crawler_service(bt: BlockTools) -> AsyncIterator[Service[Crawler, CrawlerAPI]]:
|
||||
async for service in setup_crawler(bt):
|
||||
async def crawler_service(
|
||||
root_path_populated_with_config: Path, database_uri: str
|
||||
) -> AsyncIterator[Service[Crawler, CrawlerAPI]]:
|
||||
async for service in setup_crawler(root_path_populated_with_config, database_uri):
|
||||
yield service
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
async def seeder_service(root_path_populated_with_config: Path, database_uri: str) -> AsyncIterator[DNSServer]:
    """
    Function-scoped fixture that hands out the DNSServer yielded by setup_seeder.
    """
    async for dns_server in setup_seeder(root_path_populated_with_config, database_uri):
        yield dns_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def tmp_chia_root(tmp_path):
|
||||
"""
|
||||
@ -980,3 +990,8 @@ async def harvester_farmer_environment(
|
||||
harvester_rpc_cl.close()
|
||||
await farmer_rpc_cl.await_closed()
|
||||
await harvester_rpc_cl.await_closed()
|
||||
|
||||
|
||||
@pytest.fixture(name="database_uri")
def database_uri_fixture() -> str:
    """
    Return a `file:` URI for a randomly named in-memory database using a shared cache.
    """
    db_suffix = random.randint(0, 99999999)
    return f"file:db_{db_suffix}?mode=memory&cache=shared"
|
||||
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, AsyncIterable, Awaitable, Callable, Dict, Iterator
|
||||
@ -54,11 +53,6 @@ def create_example_fixture(request: SubRequest) -> Callable[[DataStore, bytes32]
|
||||
return request.param # type: ignore[no-any-return]
|
||||
|
||||
|
||||
@pytest.fixture(name="database_uri")
|
||||
def database_uri_fixture() -> str:
|
||||
return f"file:db_{random.randint(0, 99999999)}?mode=memory&cache=shared"
|
||||
|
||||
|
||||
@pytest.fixture(name="tree_id", scope="function")
|
||||
def tree_id_fixture() -> bytes32:
|
||||
base = b"a tree id"
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
@ -11,13 +12,14 @@ from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
||||
from chia.protocols.wallet_protocol import RequestChildren
|
||||
from chia.seeder.crawler import Crawler
|
||||
from chia.seeder.crawler_api import CrawlerAPI
|
||||
from chia.seeder.peer_record import PeerRecord, PeerReliability
|
||||
from chia.server.outbound_message import make_msg
|
||||
from chia.server.start_service import Service
|
||||
from chia.simulator.setup_nodes import SimulatorsAndWalletsServices
|
||||
from chia.simulator.time_out_assert import time_out_assert
|
||||
from chia.types.blockchain_format.sized_bytes import bytes32
|
||||
from chia.types.peer_info import PeerInfo
|
||||
from chia.util.ints import uint16, uint32, uint128
|
||||
from chia.util.ints import uint16, uint32, uint64, uint128
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -68,3 +70,43 @@ async def test_valid_message(
|
||||
)
|
||||
assert await connection.send_message(msg)
|
||||
await time_out_assert(10, peer_added)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_crawler_to_db(
    crawler_service: Service[Crawler, CrawlerAPI], one_node: SimulatorsAndWalletsServices
) -> None:
    """
    Integration-style test of the whole path: a peer pointing at a running full node is added to the
    crawler's store, and we then wait for the store to report it as a good peer.
    """
    [full_node_service], _, _ = one_node
    node_port = uint32(full_node_service._node.server.get_port())
    store = crawler_service._node.crawl_store
    assert store is not None
    host = "127.0.0.1"

    # build the record for the local full node, with a matching reliability entry
    record = PeerRecord(
        host,
        host,
        node_port,
        False,
        uint64(0),
        uint32(0),
        uint64(0),
        uint64(int(time.time())),
        uint64(0),
        "undefined",
        uint64(0),
        tls_version="unknown",
    )
    reliability = PeerReliability(host, tries=1, successes=1)

    # store the peer and check that the in-memory view holds exactly what we added
    await store.add_peer(record, reliability)
    assert record == store.host_to_records[host]

    # the peer has to show up as a good peer within 20 seconds
    await time_out_assert(20, store.get_good_peers, [host])
|
||||
|
@ -9,7 +9,7 @@ from chia.seeder.crawler import Crawler
|
||||
class TestCrawlerRpc:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_ips_after_timestamp(self, bt):
|
||||
crawler = Crawler(bt.config.get("seeder", {}), bt.root_path, consensus_constants=bt.constants)
|
||||
crawler = Crawler(bt.config.get("seeder", {}), bt.root_path, constants=bt.constants)
|
||||
crawler_rpc_api = CrawlerRpcApi(crawler)
|
||||
|
||||
# Should raise ValueError when `after` is not supplied
|
||||
|
275
tests/core/test_seeder.py
Normal file
275
tests/core/test_seeder.py
Normal file
@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
from socket import AF_INET6, SOCK_STREAM
|
||||
from typing import List, Tuple, cast
|
||||
|
||||
import dns
|
||||
import pytest
|
||||
|
||||
from chia.seeder.dns_server import DNSServer
|
||||
from chia.seeder.peer_record import PeerRecord, PeerReliability
|
||||
from chia.simulator.time_out_assert import time_out_assert
|
||||
from chia.util.ints import uint32, uint64
|
||||
|
||||
timeout = 0.5
|
||||
|
||||
|
||||
def generate_test_combs() -> List[Tuple[bool, str, dns.rdatatype.RdataType]]:
    """
    Build every (use_tcp, target_address, request_type) triple the DNS tests are parametrized with.

    That is tcp + udp, over the ipv6 + ipv4 loopback addresses, for every request type we support.
    Addresses vary slowest and request types fastest.
    """
    transports = (True, False)
    loopback_addresses = ("::1", "127.0.0.1")
    record_types = (
        dns.rdatatype.A,
        dns.rdatatype.AAAA,
        dns.rdatatype.ANY,
        dns.rdatatype.NS,
        dns.rdatatype.SOA,
    )
    return [(tcp, addr, req) for addr in loopback_addresses for tcp in transports for req in record_types]
|
||||
|
||||
|
||||
all_test_combinations = generate_test_combs()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class FakeDnsPacket:
    """
    Wrapper around a real dns message that serializes to a truncated wire payload.

    The tests cast it to dns.message.Message and send it as the "invalid packet" case.
    """

    real_packet: dns.message.Message  # the genuine message whose wire bytes get cut

    def __getattr__(self, item: object) -> None:
        # This is definitely cheating, but it works: python only calls __getattr__ for attributes that
        # are not found normally, so anything else that gets looked up on the fake simply reads as None.
        return None

    def to_wire(self) -> bytes:
        # drop the first 23 bytes of the real packet's wire encoding
        wire = self.real_packet.to_wire()
        return wire[23:]
|
||||
|
||||
|
||||
async def make_dns_query(
    use_tcp: bool, target_address: str, port: int, dns_message: dns.message.Message, d_timeout: float = timeout
) -> dns.message.Message:
    """
    Send `dns_message` to target_address:port and return the response.

    TCP is used when `use_tcp` is true and UDP otherwise; `d_timeout` is forwarded as the query timeout.
    """
    send_query = dns.asyncquery.tcp if use_tcp else dns.asyncquery.udp
    return await send_query(q=dns_message, where=target_address, timeout=d_timeout, port=port)
|
||||
|
||||
|
||||
def get_addresses(num_subnets: int = 10, hosts_per_subnet: int = 250) -> Tuple[List[IPv4Address], List[IPv6Address]]:
    """
    Generate fake peer addresses: `hosts_per_subnet` ipv4 and ipv6 addresses for each of `num_subnets` subnets.

    With the defaults this is 2500 ipv4 and 2500 ipv6 addresses. Both lists are ordered by subnet, then host.

    Args:
        num_subnets: number of subnets to generate; subnet numbers start at 0.
        hosts_per_subnet: number of hosts per subnet; host numbers start at 1.

    Returns:
        (ipv4 addresses, ipv6 addresses), as ipaddress objects.

    Raises:
        ipaddress.AddressValueError: if a subnet or host number does not fit the address format, i.e. above 255
            for ipv4. For ipv6 the decimal digits are written straight into the last two groups, so at most
            4 digits fit.
    """
    numbering = [(subnet, host) for subnet in range(num_subnets) for host in range(1, hosts_per_subnet + 1)]
    ipv4 = [IPv4Address(f"192.168.{subnet}.{host}") for subnet, host in numbering]
    # the numbers are reused as-is for the last two ipv6 groups; they only have to be unique, not meaningful
    ipv6 = [IPv6Address(f"2001:db8::{subnet}:{host}") for subnet, host in numbering]
    return ipv4, ipv6
|
||||
|
||||
|
||||
def assert_standard_results(
    std_query_answer: List[dns.rrset.RRset], request_type: dns.rdatatype.RdataType, num_ns: int
) -> None:
    """
    Assert that the answer section of a successful query has the expected shape for `request_type`.

    A and AAAA queries expect a single record set with 32 addresses, NS a single set with `num_ns` records,
    and anything else is treated as SOA (a single set with 1 record). ANY expects 4 record sets: 16 A,
    16 AAAA, `num_ns` NS, and 1 record in whatever set remains.
    """
    if request_type == dns.rdatatype.ANY:
        assert len(std_query_answer) == 4  # 4 kinds of answers
        any_sizes = {dns.rdatatype.A: 16, dns.rdatatype.AAAA: 16, dns.rdatatype.NS: num_ns}
        for answer in std_query_answer:
            assert len(answer) == any_sizes.get(answer.rdtype, 1)
        return

    # request type -> (record type of the only answer, number of records in it); soa is the fallback
    single_answers = {
        dns.rdatatype.A: (dns.rdatatype.A, 32),  # 32 ipv4 addresses
        dns.rdatatype.AAAA: (dns.rdatatype.AAAA, 32),  # 32 ipv6 addresses
        dns.rdatatype.NS: (dns.rdatatype.NS, num_ns),  # ns records
    }
    expected_type, expected_count = single_answers.get(request_type, (dns.rdatatype.SOA, 1))
    assert len(std_query_answer) == 1  # only 1 kind of answer
    only_answer = std_query_answer[0]
    assert only_answer.rdtype == expected_type
    assert len(only_answer) == expected_count
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("use_tcp, target_address, request_type", all_test_combinations)
async def test_error_conditions(
    seeder_service: DNSServer, use_tcp: bool, target_address: str, request_type: dns.rdatatype.RdataType
) -> None:
    """
    We check having no peers, an invalid packet, an early EOF, and a packet then an EOF halfway through (tcp only).
    We also check for a dns record that does not exist, and a dns record outside the domain.
    """
    port = seeder_service.dns_port
    domain = seeder_service.domain  # default is: seeder.example.com
    num_ns = len(seeder_service.ns_records)

    # No peers
    no_peers = dns.message.make_query(domain, request_type)
    no_peers_response = await make_dns_query(use_tcp, target_address, port, no_peers)
    assert no_peers_response.rcode() == dns.rcode.NOERROR

    if request_type in (dns.rdatatype.A, dns.rdatatype.AAAA):
        assert len(no_peers_response.answer) == 0  # no response, as expected
    elif request_type == dns.rdatatype.ANY:  # ns + soa
        assert len(no_peers_response.answer) == 2
        for answer in no_peers_response.answer:
            if answer.rdtype == dns.rdatatype.NS:
                assert len(answer.items) == num_ns
            else:
                assert len(answer.items) == 1
    elif request_type == dns.rdatatype.NS:
        assert len(no_peers_response.answer) == 1  # ns
        assert no_peers_response.answer[0].rdtype == dns.rdatatype.NS
        assert len(no_peers_response.answer[0].items) == num_ns
    else:
        assert len(no_peers_response.answer) == 1  # soa
        assert no_peers_response.answer[0].rdtype == dns.rdatatype.SOA
        assert len(no_peers_response.answer[0].items) == 1
    # Authority Records
    # NOTE(review): the entries are checked below as whole record sets (one ns set, one soa set), so this
    # count presumably only holds while there is a single nameserver -- confirm for multi-ns configs.
    assert len(no_peers_response.authority) == num_ns + 1  # ns + soa
    for record_list in no_peers_response.authority:
        if record_list.rdtype == dns.rdatatype.NS:
            assert len(record_list.items) == num_ns
        else:
            assert len(record_list.items) == 1  # soa

    # Invalid packet (this is kinda a pain)
    invalid_packet = cast(dns.message.Message, FakeDnsPacket(dns.message.make_query(domain, request_type)))
    with pytest.raises(EOFError if use_tcp else dns.exception.Timeout):  # UDP will time out, TCP will EOF
        await make_dns_query(use_tcp, target_address, port, invalid_packet)

    # early EOF packet
    if use_tcp:
        backend = dns.asyncbackend.get_default_backend()
        tcp_socket = await backend.make_socket(  # type: ignore[no-untyped-call]
            AF_INET6, SOCK_STREAM, 0, ("::", 0), ("::1", port), timeout=timeout
        )
        async with tcp_socket as socket:
            await socket.close()
            with pytest.raises(EOFError):
                # receive_tcp takes an absolute expiration time, not a relative timeout
                await dns.asyncquery.receive_tcp(tcp_socket, time.time() + timeout)

    # Packet then length header then EOF
    if use_tcp:
        backend = dns.asyncbackend.get_default_backend()
        tcp_socket = await backend.make_socket(  # type: ignore[no-untyped-call]
            AF_INET6, SOCK_STREAM, 0, ("::", 0), ("::1", port), timeout=timeout
        )
        async with tcp_socket as socket:  # send packet then length then eof
            r = await dns.asyncquery.tcp(q=no_peers, where="::1", timeout=timeout, sock=socket)
            assert r.answer == no_peers_response.answer
            # send 120, as the first 2 bytes / the length of the packet, so that the server expects more.
            # the backend socket's sendall takes a relative timeout in seconds, not an absolute time
            await socket.sendall(int(120).to_bytes(2, byteorder="big"), timeout)
            await socket.close()
            with pytest.raises(EOFError):
                # receive_tcp takes an absolute expiration time, not a relative timeout
                await dns.asyncquery.receive_tcp(tcp_socket, time.time() + timeout)

    # Record does not exist
    record_does_not_exist = dns.message.make_query("doesnotexist." + domain, request_type)
    record_does_not_exist_response = await make_dns_query(use_tcp, target_address, port, record_does_not_exist)
    assert record_does_not_exist_response.rcode() == dns.rcode.NXDOMAIN
    assert len(record_does_not_exist_response.answer) == 0
    assert len(record_does_not_exist_response.authority) == num_ns + 1  # ns + soa

    # Record outside domain
    record_outside_domain = dns.message.make_query("chia.net", request_type)
    record_outside_domain_response = await make_dns_query(use_tcp, target_address, port, record_outside_domain)
    assert record_outside_domain_response.rcode() == dns.rcode.REFUSED
    assert len(record_outside_domain_response.answer) == 0
    assert len(record_outside_domain_response.authority) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("use_tcp, target_address, request_type", all_test_combinations)
async def test_dns_queries(
    seeder_service: DNSServer, use_tcp: bool, target_address: str, request_type: dns.rdatatype.RdataType
) -> None:
    """
    We hand the server 5000 peers directly, then repeat one kind of query many times for the given
    protocol / address combination and validate every response.
    """
    dns_port = seeder_service.dns_port
    domain_name = seeder_service.domain  # default is: seeder.example.com
    ns_count = len(seeder_service.ns_records)

    # add 5000 peers (2500 ipv4, 2500 ipv6)
    v4_peers, v6_peers = get_addresses()
    seeder_service.reliable_peers_v4 = v4_peers
    seeder_service.reliable_peers_v6 = v6_peers

    # now we query for each type of record a lot of times and make sure we get the right number of responses
    for _ in range(150):
        # we need to generate a new request id, so the query is rebuilt on every round.
        request = dns.message.make_query(domain_name, request_type, use_edns=True)
        response = await make_dns_query(use_tcp, target_address, dns_port, request)
        assert response.rcode() == dns.rcode.NOERROR
        assert_standard_results(response.answer, request_type, ns_count)

        # Assert Authority Records
        assert len(response.authority) == ns_count + 1  # ns + soa
        for authority_set in response.authority:
            expected_items = ns_count if authority_set.rdtype == dns.rdatatype.NS else 1  # 1 is the soa
            assert len(authority_set.items) == expected_items
    # Validate EDNS
    plain_query = dns.message.make_query(domain_name, dns.rdatatype.ANY, use_edns=False)
    with pytest.raises(dns.query.BadResponse):  # response is truncated without EDNS
        await make_dns_query(False, target_address, dns_port, plain_query)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_db_processing(seeder_service: DNSServer) -> None:
    """
    We add 1000 peers through the crawl store / db, wait for the server to pick them up, then run every
    kind of query once over both the TCP and UDP protocols.
    """
    dns_port = seeder_service.dns_port
    domain_name = seeder_service.domain  # default is: seeder.example.com
    ns_count = len(seeder_service.ns_records)
    store = seeder_service.crawl_store
    assert store is not None

    v4_addresses, v6_addresses = get_addresses(2)  # get 1000 addresses
    # add peers to the in memory part of the db, alternating ipv4 and ipv6.
    for v4_address, v6_address in zip(v4_addresses, v6_addresses):
        for host in (str(v4_address), str(v6_address)):
            record = PeerRecord(
                host,
                host,
                uint32(58444),
                False,
                uint64(0),
                uint32(0),
                uint64(0),
                uint64(int(time.time())),
                uint64(0),
                "undefined",
                uint64(0),
                tls_version="unknown",
            )
            reliability = PeerReliability(host, tries=3, successes=3)  # make sure the peer starts as reliable.
            store.maybe_add_peer(record, reliability)  # we don't fully add it because we don't need to.

    # Write these new peers to db.
    await store.load_reliable_peers_to_db()

    # wait for the new db to be read.
    await time_out_assert(30, lambda: seeder_service.reliable_peers_v4 != [])

    # now we check all the combinations once (not a stupid amount of times)
    for use_tcp, target_address, request_type in all_test_combinations:
        # we need to generate a new request id, so every combination gets its own query.
        request = dns.message.make_query(domain_name, request_type, use_edns=True)
        response = await make_dns_query(use_tcp, target_address, dns_port, request)
        assert response.rcode() == dns.rcode.NOERROR
        assert_standard_results(response.answer, request_type, ns_count)
|
Loading…
Reference in New Issue
Block a user