1 1.1 christos """ 2 1.1 christos Copyright (C) Internet Systems Consortium, Inc. ("ISC") 3 1.1 christos 4 1.1 christos SPDX-License-Identifier: MPL-2.0 5 1.1 christos 6 1.1 christos This Source Code Form is subject to the terms of the Mozilla Public 7 1.1 christos License, v. 2.0. If a copy of the MPL was not distributed with this 8 1.1 christos file, you can obtain one at https://mozilla.org/MPL/2.0/. 9 1.1 christos 10 1.1 christos See the COPYRIGHT file distributed with this work for additional 11 1.1 christos information regarding copyright ownership. 12 1.1 christos """ 13 1.1 christos 14 1.1.1.5 christos from collections.abc import AsyncGenerator, Callable, Coroutine, Sequence 15 1.1 christos from dataclasses import dataclass, field 16 1.1.1.5 christos from typing import Any, cast 17 1.1 christos 18 1.1 christos import abc 19 1.1 christos import asyncio 20 1.1.1.4 christos import contextlib 21 1.1.1.4 christos import copy 22 1.1 christos import enum 23 1.1 christos import functools 24 1.1 christos import logging 25 1.1 christos import os 26 1.1 christos import pathlib 27 1.1 christos import re 28 1.1 christos import signal 29 1.1 christos import struct 30 1.1 christos import sys 31 1.1 christos 32 1.1.1.4 christos import dns.exception 33 1.1 christos import dns.flags 34 1.1 christos import dns.message 35 1.1 christos import dns.name 36 1.1 christos import dns.node 37 1.1 christos import dns.rcode 38 1.1.1.4 christos import dns.rdata 39 1.1 christos import dns.rdataclass 40 1.1.1.4 christos import dns.rdataset 41 1.1 christos import dns.rdatatype 42 1.1 christos import dns.rrset 43 1.1.1.4 christos import dns.tsig 44 1.1 christos import dns.zone 45 1.1 christos 46 1.1 christos _UdpHandler = Callable[ 47 1.1.1.5 christos [bytes, tuple[str, int], asyncio.DatagramTransport], Coroutine[Any, Any, None] 48 1.1 christos ] 49 1.1 christos 50 1.1 christos 51 1.1 christos _TcpHandler = Callable[ 52 1.1 christos [asyncio.StreamReader, asyncio.StreamWriter], Coroutine[Any, Any, None] 53 1.1 christos ] 54 1.1 christos 55 1.1 christos 56 1.1 christos class _AsyncUdpHandler(asyncio.DatagramProtocol): 57 1.1 christos """ 58 1.1 christos Protocol implementation for handling UDP traffic using asyncio. 59 1.1 christos """ 60 1.1 christos 61 1.1 christos def __init__( 62 1.1 christos self, 63 1.1 christos handler: _UdpHandler, 64 1.1 christos ) -> None: 65 1.1.1.5 christos self._transport: asyncio.DatagramTransport | None = None 66 1.1 christos self._handler: _UdpHandler = handler 67 1.1 christos 68 1.1 christos def connection_made(self, transport: asyncio.BaseTransport) -> None: 69 1.1 christos """ 70 1.1 christos Called by asyncio when a connection is made. 71 1.1 christos """ 72 1.1 christos self._transport = cast(asyncio.DatagramTransport, transport) 73 1.1 christos 74 1.1.1.5 christos def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: 75 1.1 christos """ 76 1.1 christos Called by asyncio when a datagram is received. 77 1.1 christos """ 78 1.1 christos assert self._transport 79 1.1 christos handler_coroutine = self._handler(data, addr, self._transport) 80 1.1 christos try: 81 1.1 christos # Python >= 3.7 82 1.1 christos asyncio.create_task(handler_coroutine) 83 1.1 christos except AttributeError: 84 1.1 christos # Python < 3.7 85 1.1 christos loop = asyncio.get_event_loop() 86 1.1 christos loop.create_task(handler_coroutine) 87 1.1 christos 88 1.1 christos 89 1.1 christos class AsyncServer: 90 1.1 christos """ 91 1.1 christos A generic asynchronous server which may handle UDP and/or TCP traffic. 92 1.1 christos 93 1.1 christos Once the server is executed as asyncio coroutine, it will keep running 94 1.1 christos until a SIGINT/SIGTERM signal is received. 95 1.1 christos """ 96 1.1 christos 97 1.1 christos def __init__( 98 1.1 christos self, 99 1.1.1.5 christos udp_handler: _UdpHandler | None, 100 1.1.1.5 christos tcp_handler: _TcpHandler | None, 101 1.1.1.5 christos pidfile: str | None = None, 102 1.1 christos ) -> None: 103 1.1 christos logging.basicConfig( 104 1.1 christos format="%(asctime)s %(levelname)8s %(message)s", 105 1.1 christos level=os.environ.get("ANS_LOG_LEVEL", "INFO").upper(), 106 1.1 christos ) 107 1.1 christos try: 108 1.1 christos ipv4_address = sys.argv[1] 109 1.1 christos except IndexError: 110 1.1 christos ipv4_address = self._get_ipv4_address_from_directory_name() 111 1.1 christos 112 1.1 christos last_ipv4_address_octet = ipv4_address.split(".")[-1] 113 1.1 christos ipv6_address = f"fd92:7065:b8e:ffff::{last_ipv4_address_octet}" 114 1.1 christos 115 1.1 christos try: 116 1.1 christos port = int(sys.argv[2]) 117 1.1 christos except IndexError: 118 1.1 christos port = int(os.environ.get("PORT", 5300)) 119 1.1 christos 120 1.1 christos logging.info("Setting up IPv4 listener at %s:%d", ipv4_address, port) 121 1.1 christos logging.info("Setting up IPv6 listener at [%s]:%d", ipv6_address, port) 122 1.1 christos 123 1.1.1.5 christos self._ip_addresses: tuple[str, str] = (ipv4_address, ipv6_address) 124 1.1 christos self._port: int = port 125 1.1.1.5 christos self._udp_handler: _UdpHandler | None = udp_handler 126 1.1.1.5 christos self._tcp_handler: _TcpHandler | None = tcp_handler 127 1.1.1.5 christos self._pidfile: str | None = pidfile 128 1.1.1.5 christos self._work_done: asyncio.Future | None = None 129 1.1.1.4 christos 130 1.1 christos def _get_ipv4_address_from_directory_name(self) -> str: 131 1.1 christos containing_directory = pathlib.Path().absolute().stem 132 1.1 christos match_result = re.match(r"ans(?P<index>\d+)", containing_directory) 133 1.1 christos if not match_result: 134 1.1 christos raise RuntimeError("Unable to auto-determine the IPv4 address to use") 135 1.1 christos 136 1.1 christos return f"10.53.0.{match_result.group('index')}" 137 1.1 christos 138 1.1 christos def run(self) -> None: 139 1.1 christos """ 140 1.1 christos Start the server in an asynchronous coroutine. 141 1.1 christos """ 142 1.1 christos coroutine = self._run 143 1.1 christos try: 144 1.1 christos # Python >= 3.7 145 1.1 christos asyncio.run(coroutine()) 146 1.1 christos except AttributeError: 147 1.1 christos # Python < 3.7 148 1.1 christos loop = asyncio.get_event_loop() 149 1.1 christos loop.run_until_complete(coroutine()) 150 1.1 christos 151 1.1 christos async def _run(self) -> None: 152 1.1.1.2 christos self._setup_exception_handler() 153 1.1 christos self._setup_signals() 154 1.1 christos assert self._work_done 155 1.1 christos await self._listen_udp() 156 1.1 christos await self._listen_tcp() 157 1.1 christos self._write_pidfile() 158 1.1 christos await self._work_done 159 1.1 christos self._cleanup_pidfile() 160 1.1 christos 161 1.1 christos def _get_asyncio_loop(self) -> asyncio.AbstractEventLoop: 162 1.1 christos try: 163 1.1 christos # Python >= 3.7 164 1.1 christos loop = asyncio.get_running_loop() 165 1.1 christos except AttributeError: 166 1.1 christos # Python < 3.7 167 1.1 christos loop = asyncio.get_event_loop() 168 1.1 christos return loop 169 1.1 christos 170 1.1.1.2 christos def _setup_exception_handler(self) -> None: 171 1.1 christos loop = self._get_asyncio_loop() 172 1.1 christos self._work_done = loop.create_future() 173 1.1.1.2 christos loop.set_exception_handler(self._handle_exception) 174 1.1.1.2 christos 175 1.1.1.2 christos def _handle_exception( 176 1.1.1.5 christos self, _: asyncio.AbstractEventLoop, context: dict[str, Any] 177 1.1.1.2 christos ) -> None: 178 1.1.1.2 christos assert self._work_done 179 1.1.1.2 christos exception = context.get("exception", RuntimeError(context["message"])) 180 1.1.1.4 christos try: 181 1.1.1.4 christos self._work_done.set_exception(exception) 182 1.1.1.4 christos except asyncio.InvalidStateError: 183 1.1.1.4 christos pass 184 1.1.1.2 christos 185 1.1.1.2 christos def _setup_signals(self) -> None: 186 1.1.1.2 christos loop = self._get_asyncio_loop() 187 1.1 christos loop.add_signal_handler(signal.SIGINT, functools.partial(self._signal_done)) 188 1.1 christos loop.add_signal_handler(signal.SIGTERM, functools.partial(self._signal_done)) 189 1.1 christos 190 1.1 christos def _signal_done(self) -> None: 191 1.1 christos assert self._work_done 192 1.1.1.4 christos try: 193 1.1.1.4 christos self._work_done.set_result(True) 194 1.1.1.4 christos except asyncio.InvalidStateError: 195 1.1.1.4 christos pass 196 1.1 christos 197 1.1 christos async def _listen_udp(self) -> None: 198 1.1 christos if not self._udp_handler: 199 1.1 christos return 200 1.1 christos loop = self._get_asyncio_loop() 201 1.1 christos for ip_address in self._ip_addresses: 202 1.1 christos await loop.create_datagram_endpoint( 203 1.1 christos lambda: _AsyncUdpHandler(cast(_UdpHandler, self._udp_handler)), 204 1.1 christos (ip_address, self._port), 205 1.1 christos ) 206 1.1 christos 207 1.1 christos async def _listen_tcp(self) -> None: 208 1.1 christos if not self._tcp_handler: 209 1.1 christos return 210 1.1 christos for ip_address in self._ip_addresses: 211 1.1 christos await asyncio.start_server( 212 1.1 christos self._tcp_handler, host=ip_address, port=self._port 213 1.1 christos ) 214 1.1 christos 215 1.1 christos def _write_pidfile(self) -> None: 216 1.1 christos if not self._pidfile: 217 1.1 christos return 218 1.1 christos logging.info("Writing PID to %s", self._pidfile) 219 1.1 christos with open(self._pidfile, "w", encoding="ascii") as pidfile: 220 1.1 christos print(f"{os.getpid()}", file=pidfile) 221 1.1 christos 222 1.1 christos def _cleanup_pidfile(self) -> None: 223 1.1 christos if not self._pidfile: 224 1.1 christos return 225 1.1 christos logging.info("Removing %s", self._pidfile) 226 1.1 christos os.unlink(self._pidfile) 227 1.1 christos 228 1.1 christos 229 1.1 christos class DnsProtocol(enum.Enum): 230 1.1 christos UDP = enum.auto() 231 1.1 christos TCP = enum.auto() 232 1.1 christos 233 1.1 christos 234 1.1.1.2 christos @dataclass(frozen=True) 235 1.1.1.2 christos class Peer: 236 1.1.1.2 christos """ 237 1.1.1.2 christos Pretty-printed connection endpoint. 238 1.1.1.2 christos """ 239 1.1.1.2 christos 240 1.1.1.2 christos host: str 241 1.1.1.2 christos port: int 242 1.1.1.2 christos 243 1.1.1.2 christos def __str__(self) -> str: 244 1.1.1.2 christos host = f"[{self.host}]" if ":" in self.host else self.host 245 1.1.1.2 christos return f"{host}:{self.port}" 246 1.1.1.2 christos 247 1.1.1.2 christos 248 1.1 christos @dataclass 249 1.1 christos class QueryContext: 250 1.1 christos """ 251 1.1 christos Context for the incoming query which may be used for preparing the response. 252 1.1 christos """ 253 1.1 christos 254 1.1 christos query: dns.message.Message 255 1.1 christos response: dns.message.Message 256 1.1.1.5 christos socket: Peer 257 1.1.1.2 christos peer: Peer 258 1.1 christos protocol: DnsProtocol 259 1.1.1.5 christos zone: dns.zone.Zone | None = field(default=None, init=False) 260 1.1.1.5 christos soa: dns.rrset.RRset | None = field(default=None, init=False) 261 1.1.1.5 christos node: dns.node.Node | None = field(default=None, init=False) 262 1.1.1.5 christos answer: dns.rdataset.Rdataset | None = field(default=None, init=False) 263 1.1.1.5 christos alias: dns.name.Name | None = field(default=None, init=False) 264 1.1.1.5 christos _initialized_response: dns.message.Message | None = field(default=None, init=False) 265 1.1.1.5 christos _initialized_response_with_zone_data: dns.message.Message | None = field( 266 1.1.1.4 christos default=None, init=False 267 1.1.1.4 christos ) 268 1.1 christos 269 1.1 christos @property 270 1.1 christos def qname(self) -> dns.name.Name: 271 1.1 christos return self.query.question[0].name 272 1.1 christos 273 1.1 christos @property 274 1.1.1.3 christos def current_qname(self) -> dns.name.Name: 275 1.1.1.3 christos return self.alias or self.qname 276 1.1.1.3 christos 277 1.1.1.3 christos @property 278 1.1.1.4 christos def qclass(self) -> dns.rdataclass.RdataClass: 279 1.1 christos return self.query.question[0].rdclass 280 1.1 christos 281 1.1 christos @property 282 1.1.1.4 christos def qtype(self) -> dns.rdatatype.RdataType: 283 1.1 christos return self.query.question[0].rdtype 284 1.1 christos 285 1.1.1.4 christos def prepare_new_response( 286 1.1.1.4 christos self, /, with_zone_data: bool = True 287 1.1.1.4 christos ) -> dns.message.Message: 288 1.1.1.4 christos if with_zone_data: 289 1.1.1.4 christos assert self._initialized_response_with_zone_data 290 1.1.1.4 christos self.response = copy.deepcopy(self._initialized_response_with_zone_data) 291 1.1.1.4 christos else: 292 1.1.1.4 christos assert self._initialized_response 293 1.1.1.4 christos self.response = copy.deepcopy(self._initialized_response) 294 1.1.1.4 christos return self.response 295 1.1.1.4 christos 296 1.1.1.4 christos def save_initialized_response(self, /, with_zone_data: bool) -> None: 297 1.1.1.4 christos if with_zone_data: 298 1.1.1.4 christos self._initialized_response_with_zone_data = copy.deepcopy(self.response) 299 1.1.1.4 christos else: 300 1.1.1.4 christos self._initialized_response = copy.deepcopy(self.response) 301 1.1.1.4 christos 302 1.1 christos 303 1.1 christos @dataclass 304 1.1 christos class ResponseAction(abc.ABC): 305 1.1 christos """ 306 1.1 christos Base class for actions that can be taken in response to a query. 307 1.1 christos """ 308 1.1 christos 309 1.1 christos @abc.abstractmethod 310 1.1.1.5 christos async def perform(self) -> dns.message.Message | bytes | None: 311 1.1 christos """ 312 1.1 christos This method is expected to carry out arbitrary actions (e.g. wait for a 313 1.1 christos specific amount of time, modify the answer, etc.) and then return the 314 1.1 christos DNS response to send (a dns.message.Message, a raw bytes object, or 315 1.1 christos None, which prevents any response from being sent). 316 1.1 christos """ 317 1.1 christos raise NotImplementedError 318 1.1 christos 319 1.1 christos 320 1.1 christos @dataclass 321 1.1 christos class DnsResponseSend(ResponseAction): 322 1.1 christos """ 323 1.1 christos Action which yields a dns.message.Message response. 324 1.1 christos 325 1.1 christos The response may be sent with a delay if requested. 326 1.1 christos 327 1.1 christos Depending on the value of the `authoritative` property, this class may set 328 1.1 christos the AA bit in the response (True), clear it (False), or not touch it at all 329 1.1 christos (None). 330 1.1 christos """ 331 1.1 christos 332 1.1 christos response: dns.message.Message 333 1.1.1.5 christos authoritative: bool | None = None 334 1.1 christos delay: float = 0.0 335 1.1.1.5 christos acknowledge_hand_rolled_response: bool = False 336 1.1 christos 337 1.1.1.5 christos async def perform(self) -> dns.message.Message | bytes | None: 338 1.1 christos """ 339 1.1 christos Yield a potentially delayed response that is a dns.message.Message. 340 1.1 christos """ 341 1.1 christos assert isinstance(self.response, dns.message.Message) 342 1.1.1.5 christos if not ( 343 1.1.1.5 christos _is_asyncserver_response(self.response) 344 1.1.1.5 christos or self.acknowledge_hand_rolled_response 345 1.1.1.5 christos ): 346 1.1.1.5 christos error = "The response you are trying to send was not created using " 347 1.1.1.5 christos error += "AsyncDnsServer's response preparation methods. " 348 1.1.1.5 christos error += "This will break features such as automatic AA flag " 349 1.1.1.5 christos error += "and RCODE handling. If you need a fresh copy of a " 350 1.1.1.5 christos error += "response, use `QueryContext.prepare_new_response` " 351 1.1.1.5 christos error += "instead of `dns.message.make_response`. " 352 1.1.1.5 christos error += "To acknowledge this and proceed anyway, set " 353 1.1.1.5 christos error += "`acknowledge_hand_rolled_response=True` in " 354 1.1.1.5 christos error += "DnsResponseSend's constructor." 355 1.1.1.5 christos raise RuntimeError(error) 356 1.1.1.5 christos 357 1.1 christos if self.authoritative is not None: 358 1.1 christos if self.authoritative: 359 1.1 christos self.response.flags |= dns.flags.AA 360 1.1 christos else: 361 1.1 christos self.response.flags &= ~dns.flags.AA 362 1.1 christos if self.delay > 0: 363 1.1 christos logging.info( 364 1.1 christos "Delaying response (ID=%d) by %d ms", 365 1.1 christos self.response.id, 366 1.1 christos self.delay * 1000, 367 1.1 christos ) 368 1.1 christos await asyncio.sleep(self.delay) 369 1.1 christos return self.response 370 1.1 christos 371 1.1 christos 372 1.1 christos @dataclass 373 1.1 christos class BytesResponseSend(ResponseAction): 374 1.1 christos """ 375 1.1 christos Action which yields a raw response that is a sequence of bytes. 376 1.1 christos 377 1.1 christos The response may be sent with a delay if requested. 378 1.1 christos """ 379 1.1 christos 380 1.1 christos response: bytes 381 1.1 christos delay: float = 0.0 382 1.1 christos 383 1.1.1.5 christos async def perform(self) -> dns.message.Message | bytes | None: 384 1.1 christos """ 385 1.1 christos Yield a potentially delayed response that is a sequence of bytes. 386 1.1 christos """ 387 1.1 christos assert isinstance(self.response, bytes) 388 1.1 christos if self.delay > 0: 389 1.1 christos logging.info("Delaying raw response by %d ms", self.delay * 1000) 390 1.1 christos await asyncio.sleep(self.delay) 391 1.1 christos return self.response 392 1.1 christos 393 1.1 christos 394 1.1 christos @dataclass 395 1.1 christos class ResponseDrop(ResponseAction): 396 1.1 christos """ 397 1.1 christos Action which does nothing - as if a packet was dropped. 398 1.1 christos """ 399 1.1 christos 400 1.1.1.5 christos async def perform(self) -> dns.message.Message | bytes | None: 401 1.1 christos return None 402 1.1 christos 403 1.1 christos 404 1.1.1.4 christos class _ConnectionTeardownRequested(Exception): 405 1.1.1.4 christos pass 406 1.1.1.4 christos 407 1.1.1.4 christos 408 1.1.1.4 christos @dataclass 409 1.1.1.5 christos class CloseConnection(ResponseAction): 410 1.1.1.4 christos """ 411 1.1.1.5 christos Action which makes the server close the connection (TCP only). 412 1.1.1.4 christos 413 1.1.1.4 christos The connection may be closed with a delay if requested. 414 1.1.1.4 christos """ 415 1.1.1.4 christos 416 1.1.1.4 christos delay: float = 0.0 417 1.1.1.4 christos 418 1.1.1.5 christos async def perform(self) -> dns.message.Message | bytes | None: 419 1.1.1.4 christos if self.delay > 0: 420 1.1.1.4 christos logging.info("Waiting %.1fs before closing TCP connection", self.delay) 421 1.1.1.4 christos await asyncio.sleep(self.delay) 422 1.1.1.4 christos raise _ConnectionTeardownRequested 423 1.1.1.4 christos 424 1.1.1.4 christos 425 1.1.1.4 christos class ConnectionHandler(abc.ABC): 426 1.1.1.4 christos """ 427 1.1.1.4 christos Base class for TCP connection handlers. 428 1.1.1.4 christos 429 1.1.1.4 christos An installed connection handler is called when a new TCP connection is 430 1.1.1.4 christos established. It may be used to perform arbitrary actions before 431 1.1.1.4 christos AsyncDnsServer processes DNS queries. 432 1.1.1.4 christos """ 433 1.1.1.4 christos 434 1.1.1.4 christos @abc.abstractmethod 435 1.1.1.4 christos async def handle( 436 1.1.1.4 christos self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer: Peer 437 1.1.1.4 christos ) -> None: 438 1.1.1.4 christos """ 439 1.1.1.4 christos Handle the connection with the provided reader and writer. 440 1.1.1.4 christos """ 441 1.1.1.4 christos raise NotImplementedError 442 1.1.1.4 christos 443 1.1.1.4 christos 444 1.1.1.4 christos def block_reading(peer: Peer, writer_not_the_reader: asyncio.StreamWriter) -> None: 445 1.1.1.4 christos """ 446 1.1.1.4 christos Block reads for the reader associated with the provided writer. 447 1.1.1.4 christos 448 1.1.1.4 christos Yes, pass the writer, not the reader. See the comments below for details. 449 1.1.1.4 christos """ 450 1.1.1.4 christos 451 1.1.1.4 christos try: 452 1.1.1.4 christos # Python >= 3.7 453 1.1.1.4 christos loop = asyncio.get_running_loop() 454 1.1.1.4 christos except AttributeError: 455 1.1.1.4 christos # Python < 3.7 456 1.1.1.4 christos loop = asyncio.get_event_loop() 457 1.1.1.4 christos 458 1.1.1.4 christos logging.info("Blocking reads from %s", peer) 459 1.1.1.4 christos 460 1.1.1.4 christos # This is Micha's submission for the Ugliest Hack of the Year contest. 461 1.1.1.4 christos # (The alternative was implementing an asyncio transport from scratch.) 462 1.1.1.4 christos # 463 1.1.1.4 christos # In order to prevent the client socket from being read from, simply 464 1.1.1.4 christos # not calling `reader.read()` is not enough, because asyncio buffers 465 1.1.1.4 christos # incoming data itself on the transport level. However, `StreamReader` 466 1.1.1.4 christos # does not expose the underlying transport as a property. Therefore, 467 1.1.1.4 christos # cheat by extracting it from `StreamWriter` as it is the same 468 1.1.1.4 christos # bidirectional transport as for the read side (a `Transport`, which is 469 1.1.1.4 christos # a subclass of both `ReadTransport` and `WriteTransport`) and call 470 1.1.1.4 christos # `ReadTransport.pause_reading()` to remove the underlying socket from 471 1.1.1.4 christos # the set of descriptors monitored by the selector, thereby preventing 472 1.1.1.4 christos # any reads from happening on the client socket. However... 473 1.1.1.4 christos loop.call_soon(writer_not_the_reader.transport.pause_reading) # type: ignore 474 1.1.1.4 christos 475 1.1.1.4 christos # ...due to `AsyncDnsServer._handle_tcp()` being a coroutine, by the 476 1.1.1.4 christos # time it gets executed, asyncio transport code will already have added 477 1.1.1.4 christos # the client socket to the set of descriptors monitored by the 478 1.1.1.4 christos # selector. Therefore, if the client starts sending data immediately, 479 1.1.1.4 christos # a read from the socket will have already been scheduled by the time 480 1.1.1.4 christos # this handler gets executed. There is no way to prevent that from 481 1.1.1.4 christos # happening, so work around it by abusing the fact that the transport 482 1.1.1.4 christos # at hand is specifically an instance of `_SelectorSocketTransport` 483 1.1.1.4 christos # (from asyncio.selector_events) and set the size of its read buffer to 484 1.1.1.4 christos # just a single byte. This does give asyncio enough time to read that 485 1.1.1.4 christos # single byte from the client socket's buffer before that socket is 486 1.1.1.4 christos # removed from the set of monitored descriptors, but prevents the 487 1.1.1.4 christos # one-off read from emptying the client socket buffer _entirely_, which 488 1.1.1.4 christos # is enough to trigger sending an RST segment when the connection is 489 1.1.1.4 christos # closed shortly afterwards. 490 1.1.1.4 christos writer_not_the_reader.transport.max_size = 1 # type: ignore 491 1.1.1.4 christos 492 1.1.1.4 christos 493 1.1.1.4 christos @dataclass 494 1.1.1.4 christos class IgnoreAllConnections(ConnectionHandler): 495 1.1.1.4 christos """ 496 1.1.1.4 christos A connection handler that makes the server not read anything from the 497 1.1.1.4 christos client socket, effectively ignoring all incoming connections. 498 1.1.1.4 christos """ 499 1.1.1.4 christos 500 1.1.1.5 christos _connections: set[asyncio.StreamWriter] = field(default_factory=set) 501 1.1.1.4 christos 502 1.1.1.4 christos async def handle( 503 1.1.1.4 christos self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer: Peer 504 1.1.1.4 christos ) -> None: 505 1.1.1.4 christos block_reading(peer, writer) 506 1.1.1.4 christos # Due to the way various asyncio-related objects (tasks, streams, 507 1.1.1.4 christos # transports, selectors) are referencing each other, pausing reads for 508 1.1.1.4 christos # a TCP transport (which in practice means removing the client socket 509 1.1.1.4 christos # from the set of descriptors monitored by a selector) can cause the 510 1.1.1.4 christos # client task (AsyncDnsServer._handle_tcp()) to be prematurely 511 1.1.1.4 christos # garbage-collected, causing asyncio code to raise a "Task was 512 1.1.1.4 christos # destroyed but it is pending!" exception. Prevent that from happening 513 1.1.1.4 christos # by keeping a reference to each incoming TCP connection to protect its 514 1.1.1.4 christos # related asyncio objects from getting garbage-collected. This 515 1.1.1.4 christos # prevents AsyncDnsServer from closing any of the ignored TCP 516 1.1.1.4 christos # connections indefinitely, which is obviously a pretty brain-dead idea 517 1.1.1.4 christos # for a production-grade DNS server, but AsyncDnsServer was never meant 518 1.1.1.4 christos # to be one and this hack reliably solves the problem at hand. 519 1.1.1.4 christos self._connections.add(writer) 520 1.1.1.4 christos 521 1.1.1.4 christos 522 1.1.1.4 christos @dataclass 523 1.1.1.4 christos class ConnectionReset(ConnectionHandler): 524 1.1.1.4 christos """ 525 1.1.1.4 christos A connection handler that makes the server close the connection without 526 1.1.1.4 christos reading anything from the client socket. 527 1.1.1.4 christos 528 1.1.1.4 christos The connection may be closed with a delay if requested. 529 1.1.1.4 christos 530 1.1.1.4 christos The sole purpose of this handler is to trigger a connection reset, i.e. to 531 1.1.1.4 christos make the server send an RST segment; this happens when the server closes a 532 1.1.1.4 christos client's socket while there is still unread data in that socket's buffer. 533 1.1.1.4 christos If closing the connection _after_ the query is read by the server is enough 534 1.1.1.5 christos for a given use case, the CloseConnection response handler should be used 535 1.1.1.5 christos instead. 536 1.1.1.4 christos """ 537 1.1.1.4 christos 538 1.1.1.4 christos delay: float = 0.0 539 1.1.1.4 christos 540 1.1.1.4 christos async def handle( 541 1.1.1.4 christos self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer: Peer 542 1.1.1.4 christos ) -> None: 543 1.1.1.4 christos block_reading(peer, writer) 544 1.1.1.4 christos 545 1.1.1.4 christos if self.delay > 0: 546 1.1.1.4 christos logging.info( 547 1.1.1.4 christos "Waiting %.1fs before closing TCP connection from %s", self.delay, peer 548 1.1.1.4 christos ) 549 1.1.1.4 christos await asyncio.sleep(self.delay) 550 1.1.1.4 christos 551 1.1.1.4 christos raise _ConnectionTeardownRequested 552 1.1.1.4 christos 553 1.1.1.4 christos 554 1.1 christos class ResponseHandler(abc.ABC): 555 1.1 christos """ 556 1.1 christos Base class for generic response handlers. 557 1.1 christos 558 1.1 christos If a query passes the `match()` function logic, then it is handled by this 559 1.1 christos response handler and response(s) may be generated by the `get_responses()` 560 1.1 christos method. 561 1.1 christos """ 562 1.1 christos 563 1.1.1.2 christos # pylint: disable=unused-argument 564 1.1 christos def match(self, qctx: QueryContext) -> bool: 565 1.1 christos """ 566 1.1.1.2 christos Matching logic - the first handler whose `match()` method returns True 567 1.1.1.2 christos is used for handling the query. 568 1.1.1.2 christos 569 1.1.1.2 christos The default for each handler is to handle all queries. 570 1.1 christos """ 571 1.1 christos return True 572 1.1 christos 573 1.1 christos @abc.abstractmethod 574 1.1 christos async def get_responses( 575 1.1 christos self, qctx: QueryContext 576 1.1 christos ) -> AsyncGenerator[ResponseAction, None]: 577 1.1 christos """ 578 1.1 christos Custom handler which may produce response(s) to matching queries. 579 1.1 christos 580 1.1 christos The response prepared from zone data is passed to this method in 581 1.1 christos qctx.response. 582 1.1 christos """ 583 1.1 christos yield DnsResponseSend(qctx.response) 584 1.1 christos 585 1.1.1.2 christos def __str__(self) -> str: 586 1.1.1.2 christos return self.__class__.__name__ 587 1.1.1.2 christos 588 1.1.1.2 christos 589 1.1.1.2 christos class IgnoreAllQueries(ResponseHandler): 590 1.1.1.2 christos """ 591 1.1.1.2 christos Do not respond to any queries sent to the server. 592 1.1.1.2 christos """ 593 1.1.1.2 christos 594 1.1.1.2 christos async def get_responses( 595 1.1.1.2 christos self, qctx: QueryContext 596 1.1.1.2 christos ) -> AsyncGenerator[ResponseAction, None]: 597 1.1.1.2 christos yield ResponseDrop() 598 1.1.1.2 christos 599 1.1 christos 600 1.1.1.4 christos class QnameHandler(ResponseHandler): 601 1.1.1.4 christos """ 602 1.1.1.4 christos Base class used for deriving custom QNAME handlers. 603 1.1.1.4 christos 604 1.1.1.4 christos The derived class must specify a list of `qnames` that it wants to handle. 605 1.1.1.4 christos Queries for exactly these QNAMEs will then be passed to the 606 1.1.1.4 christos `get_response()` method in the derived class. 607 1.1.1.4 christos """ 608 1.1.1.4 christos 609 1.1.1.4 christos @property 610 1.1.1.4 christos @abc.abstractmethod 611 1.1.1.5 christos def qnames(self) -> list[str]: 612 1.1.1.4 christos """ 613 1.1.1.4 christos A list of QNAMEs handled by this class. 614 1.1.1.4 christos """ 615 1.1.1.4 christos raise NotImplementedError 616 1.1.1.4 christos 617 1.1.1.4 christos def __init__(self) -> None: 618 1.1.1.5 christos self._qnames: list[dns.name.Name] = [dns.name.from_text(d) for d in self.qnames] 619 1.1.1.4 christos 620 1.1.1.4 christos def __str__(self) -> str: 621 1.1.1.4 christos return f"{self.__class__.__name__}(QNAMEs: {', '.join(self.qnames)})" 622 1.1.1.4 christos 623 1.1.1.4 christos def match(self, qctx: QueryContext) -> bool: 624 1.1.1.4 christos """ 625 1.1.1.4 christos Handle queries whose QNAME matches any of the QNAMEs handled by this 626 1.1.1.4 christos class. 627 1.1.1.4 christos """ 628 1.1.1.4 christos return qctx.qname in self._qnames 629 1.1.1.4 christos 630 1.1.1.4 christos 631 1.1.1.5 christos class QnameQtypeHandler(QnameHandler): 632 1.1.1.5 christos """ 633 1.1.1.5 christos Handle queries for which both of the following conditions are true: 634 1.1.1.5 christos 635 1.1.1.5 christos - the query's QNAME is present in `self.qnames`, 636 1.1.1.5 christos - the query's QTYPE is present in `self.qtypes`. 637 1.1.1.5 christos """ 638 1.1.1.5 christos 639 1.1.1.5 christos @property 640 1.1.1.5 christos @abc.abstractmethod 641 1.1.1.5 christos def qtypes(self) -> list[dns.rdatatype.RdataType]: 642 1.1.1.5 christos """ 643 1.1.1.5 christos A list of QTYPEs handled by this class. 644 1.1.1.5 christos """ 645 1.1.1.5 christos raise NotImplementedError 646 1.1.1.5 christos 647 1.1.1.5 christos def __init__(self) -> None: 648 1.1.1.5 christos super().__init__() 649 1.1.1.5 christos self._qtypes: list[dns.rdatatype.RdataType] = self.qtypes 650 1.1.1.5 christos 651 1.1.1.5 christos def __str__(self) -> str: 652 1.1.1.5 christos return f"{self.__class__.__name__}(QNAMEs: {', '.join(self.qnames)}; QTYPEs: {', '.join(map(str, self.qtypes))})" 653 1.1.1.5 christos 654 1.1.1.5 christos def match(self, qctx: QueryContext) -> bool: 655 1.1.1.5 christos """ 656 1.1.1.5 christos Handle queries whose QNAME and QTYPE match any of the QNAMEs and 657 1.1.1.5 christos QTYPEs handled by this class. 658 1.1.1.5 christos """ 659 1.1.1.5 christos return qctx.qtype in self._qtypes and super().match(qctx) 660 1.1.1.5 christos 661 1.1.1.5 christos 662 1.1.1.5 christos class StaticResponseHandler(ResponseHandler): 663 1.1.1.5 christos """ 664 1.1.1.5 christos Base class used for deriving custom static response handlers. 665 1.1.1.5 christos 666 1.1.1.5 christos The derived class can specify the RRsets to be included in the answer, 667 1.1.1.5 christos authority, and additional sections of the response, whether to set the AA 668 1.1.1.5 christos bit in the response, and a delay before sending the response. 669 1.1.1.5 christos 670 1.1.1.5 christos The default implementation of `get_responses()` uses these properties to 671 1.1.1.5 christos prepare and yield a single response. 672 1.1.1.5 christos """ 673 1.1.1.5 christos 674 1.1.1.5 christos @property 675 1.1.1.5 christos def rcode(self) -> dns.rcode.Rcode | None: 676 1.1.1.5 christos """ 677 1.1.1.5 christos Optional RCODE to be set in the response. 678 1.1.1.5 christos """ 679 1.1.1.5 christos return None 680 1.1.1.5 christos 681 1.1.1.5 christos @property 682 1.1.1.5 christos def answer(self) -> Sequence[dns.rrset.RRset]: 683 1.1.1.5 christos """ 684 1.1.1.5 christos RRsets to be included in the answer section of the response. 685 1.1.1.5 christos """ 686 1.1.1.5 christos return [] 687 1.1.1.5 christos 688 1.1.1.5 christos @property 689 1.1.1.5 christos def authority(self) -> Sequence[dns.rrset.RRset]: 690 1.1.1.5 christos """ 691 1.1.1.5 christos RRsets to be included in the authority section of the response. 692 1.1.1.5 christos """ 693 1.1.1.5 christos return [] 694 1.1.1.5 christos 695 1.1.1.5 christos @property 696 1.1.1.5 christos def additional(self) -> Sequence[dns.rrset.RRset]: 697 1.1.1.5 christos """ 698 1.1.1.5 christos RRsets to be included in the additional section of the response. 699 1.1.1.5 christos """ 700 1.1.1.5 christos return [] 701 1.1.1.5 christos 702 1.1.1.5 christos @property 703 1.1.1.5 christos def authoritative(self) -> bool | None: 704 1.1.1.5 christos """ 705 1.1.1.5 christos Whether to set the AA bit in the response. 706 1.1.1.5 christos """ 707 1.1.1.5 christos return None 708 1.1.1.5 christos 709 1.1.1.5 christos @property 710 1.1.1.5 christos def delay(self) -> float: 711 1.1.1.5 christos """ 712 1.1.1.5 christos Delay before sending the response. 713 1.1.1.5 christos """ 714 1.1.1.5 christos return 0.0 715 1.1.1.5 christos 716 1.1.1.5 christos async def get_responses( 717 1.1.1.5 christos self, qctx: QueryContext 718 1.1.1.5 christos ) -> AsyncGenerator[DnsResponseSend, None]: 719 1.1.1.5 christos qctx.prepare_new_response(with_zone_data=False) 720 1.1.1.5 christos qctx.response.answer.extend(self.answer) 721 1.1.1.5 christos qctx.response.authority.extend(self.authority) 722 1.1.1.5 christos qctx.response.additional.extend(self.additional) 723 1.1.1.5 christos if self.rcode is not None: 724 1.1.1.5 christos qctx.response.set_rcode(self.rcode) 725 1.1.1.5 christos yield DnsResponseSend( 726 1.1.1.5 christos qctx.response, authoritative=self.authoritative, delay=self.delay 727 1.1.1.5 christos ) 728 1.1.1.5 christos 729 1.1.1.5 christos 730 1.1 christos class DomainHandler(ResponseHandler): 731 1.1 christos """ 732 1.1 christos Base class used for deriving custom domain handlers. 733 1.1 christos 734 1.1 christos The derived class must specify a list of `domains` that it wants to handle. 735 1.1 christos Queries for any of these domains (and their subdomains) will then be passed 736 1.1 christos to the `get_response()` method in the derived class. 737 1.1.1.5 christos 738 1.1.1.5 christos The most specific matching domain is stored in the `matched_domain` attribute. 739 1.1 christos """ 740 1.1 christos 741 1.1 christos @property 742 1.1 christos @abc.abstractmethod 743 1.1.1.5 christos def domains(self) -> list[str]: 744 1.1 christos """ 745 1.1 christos A list of domain names handled by this class. 746 1.1 christos """ 747 1.1 christos raise NotImplementedError 748 1.1 christos 749 1.1 christos def __init__(self) -> None: 750 1.1.1.5 christos self._domains: list[dns.name.Name] = sorted( 751 1.1.1.5 christos [dns.name.from_text(d) for d in self.domains], reverse=True 752 1.1.1.5 christos ) 753 1.1.1.5 christos self._matched_domain: dns.name.Name | None = None 754 1.1.1.5 christos 755 1.1.1.5 christos @property 756 1.1.1.5 christos def matched_domain(self) -> dns.name.Name: 757 1.1.1.5 christos assert self._matched_domain is not None 758 1.1.1.5 christos return self._matched_domain 759 1.1 christos 760 1.1 christos def __str__(self) -> str: 761 1.1 christos return f"{self.__class__.__name__}(domains: {', '.join(self.domains)})" 762 1.1 christos 763 1.1 christos def match(self, qctx: QueryContext) -> bool: 764 1.1 christos """ 765 1.1 christos Handle queries whose QNAME matches any of the domains handled by this 766 1.1 christos class. 767 1.1 christos """ 768 1.1.1.5 christos self._matched_domain = None 769 1.1 christos for domain in self._domains: 770 1.1 christos if qctx.qname.is_subdomain(domain): 771 1.1.1.5 christos self._matched_domain = domain 772 1.1 christos return True 773 1.1 christos return False 774 1.1 christos 775 1.1 christos 776 1.1.1.5 christos class ForwarderHandler(ResponseHandler): 777 1.1.1.5 christos """ 778 1.1.1.5 christos A handler forwarding all received queries to another DNS server with an 779 1.1.1.5 christos optional delay and then relaying the responses back to the original client. 780 1.1.1.5 christos 781 1.1.1.5 christos Queries are currently always forwarded via UDP. 782 1.1.1.5 christos """ 783 1.1.1.5 christos 784 1.1.1.5 christos @property 785 1.1.1.5 christos @abc.abstractmethod 786 1.1.1.5 christos def target(self) -> str: 787 1.1.1.5 christos """ 788 1.1.1.5 christos The address of the DNS server to forward queries to. 789 1.1.1.5 christos """ 790 1.1.1.5 christos raise NotImplementedError 791 1.1.1.5 christos 792 1.1.1.5 christos @property 793 1.1.1.5 christos def port(self) -> int: 794 1.1.1.5 christos """ 795 1.1.1.5 christos The port of the DNS server to forward queries to. 796 1.1.1.5 christos 797 1.1.1.5 christos The default value of 0 causes the same port as the one used by this 798 1.1.1.5 christos server for listening to be used. 799 1.1.1.5 christos """ 800 1.1.1.5 christos return 0 801 1.1.1.5 christos 802 1.1.1.5 christos @property 803 1.1.1.5 christos def delay(self) -> float: 804 1.1.1.5 christos """ 805 1.1.1.5 christos The number of seconds to wait before forwarding each query. 806 1.1.1.5 christos """ 807 1.1.1.5 christos return 0.0 808 1.1.1.5 christos 809 1.1.1.5 christos def __str__(self) -> str: 810 1.1.1.5 christos return f"{self.__class__.__name__}(target: {self.target}:{self.port})" 811 1.1.1.5 christos 812 1.1.1.5 christos class ForwarderProtocol(asyncio.DatagramProtocol): 813 1.1.1.5 christos def __init__(self, query: bytes, response: asyncio.Future) -> None: 814 1.1.1.5 christos self._query = query 815 1.1.1.5 christos self._response = response 816 1.1.1.5 christos 817 1.1.1.5 christos def connection_made(self, transport: asyncio.BaseTransport) -> None: 818 1.1.1.5 christos logging.debug("[OUT] %s", self._query.hex()) 819 1.1.1.5 christos cast(asyncio.DatagramTransport, transport).sendto(self._query) 820 1.1.1.5 christos 821 1.1.1.5 christos def datagram_received(self, data: bytes, _: tuple[str, int]) -> None: 822 1.1.1.5 christos logging.debug("[IN] %s", data.hex()) 823 1.1.1.5 christos self._response.set_result(data) 824 1.1.1.5 christos 825 1.1.1.5 christos async def get_responses( 826 1.1.1.5 christos self, qctx: QueryContext 827 1.1.1.5 christos ) -> AsyncGenerator[ResponseAction, None]: 828 1.1.1.5 christos loop = asyncio.get_running_loop() 829 1.1.1.5 christos response = loop.create_future() 830 1.1.1.5 christos forwarding_target = f"{self.target}:{self.port or qctx.socket.port}" 831 1.1.1.5 christos 832 1.1.1.5 christos if self.delay > 0: 833 1.1.1.5 christos logging.info( 834 1.1.1.5 christos "Waiting %.1fs before forwarding %s query from %s to %s over UDP", 835 1.1.1.5 christos self.delay, 836 1.1.1.5 christos qctx.protocol.name, 837 1.1.1.5 christos qctx.peer, 838 1.1.1.5 christos forwarding_target, 839 1.1.1.5 christos ) 840 1.1.1.5 christos await asyncio.sleep(self.delay) 841 1.1.1.5 christos 842 1.1.1.5 christos logging.info( 843 1.1.1.5 christos "Forwarding %s query from %s to %s over UDP", 844 1.1.1.5 christos qctx.protocol.name, 845 1.1.1.5 christos qctx.peer, 846 1.1.1.5 christos forwarding_target, 847 1.1.1.5 christos ) 848 1.1.1.5 christos 849 1.1.1.5 christos transport, _ = await loop.create_datagram_endpoint( 850 1.1.1.5 christos lambda: self.ForwarderProtocol(qctx.query.to_wire(), response), 851 1.1.1.5 christos local_addr=(qctx.socket.host, 0), 852 1.1.1.5 christos remote_addr=(self.target, self.port or qctx.socket.port), 853 1.1.1.5 christos ) 854 1.1.1.5 christos 855 1.1.1.5 christos try: 856 1.1.1.5 christos await response 857 1.1.1.5 christos finally: 858 1.1.1.5 christos transport.close() 859 1.1.1.5 christos 860 1.1.1.5 christos logging.info( 861 1.1.1.5 christos "Relaying UDP response from %s to %s over %s", 862 1.1.1.5 christos forwarding_target, 863 1.1.1.5 christos qctx.peer, 864 1.1.1.5 christos qctx.protocol.name, 865 1.1.1.5 christos ) 866 1.1.1.5 christos 867 1.1.1.5 christos try: 868 1.1.1.5 christos message = _DnsMessageWithTsigDisabled.from_wire(response.result()) 869 1.1.1.5 christos yield DnsResponseSend(message, acknowledge_hand_rolled_response=True) 870 1.1.1.5 christos except dns.exception.DNSException: 871 1.1.1.5 christos logging.warning( 872 1.1.1.5 christos "Failed to parse response from %s as a DNS message, relaying it as raw bytes", 873 1.1.1.5 christos forwarding_target, 874 1.1.1.5 christos ) 875 1.1.1.5 christos yield BytesResponseSend(response.result()) 876 1.1.1.5 christos 877 1.1.1.5 christos 878 1.1 christos @dataclass 879 1.1 christos class _ZoneTreeNode: 880 1.1 christos """ 881 1.1 christos A node representing a zone with one origin. 882 1.1 christos """ 883 1.1 christos 884 1.1.1.5 christos zone: dns.zone.Zone | None 885 1.1.1.5 christos children: list["_ZoneTreeNode"] = field(default_factory=list) 886 1.1 christos 887 1.1 christos 888 1.1 christos class _ZoneTree: 889 1.1 christos """ 890 1.1 christos Tree with independent zones. 891 1.1 christos 892 1.1 christos This zone tree is used as a backing structure for the DNS server. The 893 1.1 christos individual zones are independent to allow the (single) server to serve both 894 1.1 christos the parent zone and a child zone if needed. 895 1.1 christos """ 896 1.1 christos 897 1.1 christos def __init__(self) -> None: 898 1.1 christos self._root: _ZoneTreeNode = _ZoneTreeNode(None) 899 1.1 christos 900 1.1 christos def add(self, zone: dns.zone.Zone) -> None: 901 1.1 christos """ 902 1.1 christos Add a zone to the tree and rearrange sub-zones if necessary. 903 1.1 christos """ 904 1.1 christos assert zone.origin 905 1.1 christos best_match = self._find_best_match(zone.origin, self._root) 906 1.1 christos added_node = _ZoneTreeNode(zone) 907 1.1 christos self._move_children(best_match, added_node) 908 1.1 christos best_match.children.append(added_node) 909 1.1 christos 910 1.1 christos def _find_best_match( 911 1.1 christos self, name: dns.name.Name, start_node: _ZoneTreeNode 912 1.1 christos ) -> _ZoneTreeNode: 913 1.1 christos for child in start_node.children: 914 1.1 christos assert child.zone 915 1.1 christos assert child.zone.origin 916 1.1 christos if name.is_subdomain(child.zone.origin): 917 1.1 christos return self._find_best_match(name, child) 918 1.1 christos return start_node 919 1.1 christos 920 1.1 christos def _move_children(self, node_from: _ZoneTreeNode, node_to: _ZoneTreeNode) -> None: 921 1.1 christos assert node_to.zone 922 1.1 christos assert node_to.zone.origin 923 1.1 christos 924 1.1 christos children_to_move = [] 925 1.1 christos for child in node_from.children: 926 1.1 christos assert child.zone 927 1.1 christos assert child.zone.origin 928 1.1 christos if child.zone.origin.is_subdomain(node_to.zone.origin): 929 1.1 christos children_to_move.append(child) 930 1.1 christos 931 1.1 christos for child in children_to_move: 932 1.1 christos node_from.children.remove(child) 933 1.1 christos node_to.children.append(child) 934 1.1 christos 935 1.1.1.5 christos def find_best_zone(self, name: dns.name.Name) -> dns.zone.Zone | None: 936 1.1 christos """ 937 1.1 christos Return the closest matching zone (if any) for the domain name. 938 1.1 christos """ 939 1.1 christos node = self._find_best_match(name, self._root) 940 1.1 christos return node.zone if node != self._root else None 941 1.1 christos 942 1.1 christos 943 1.1.1.4 christos class _DnsMessageWithTsigDisabled(dns.message.Message): 944 1.1.1.4 christos """ 945 1.1.1.4 christos A wrapper for `dns.message.Message` that works around a dnspython bug 946 1.1.1.4 christos causing exceptions to be raised when `make_response()` or `to_wire()` are 947 1.1.1.4 christos called for a message created using `dns.message.from_wire(keyring=False)`. 948 1.1.1.4 christos 949 1.1.1.4 christos See https://github.com/rthalley/dnspython/issues/1205 for more details. 950 1.1.1.4 christos """ 951 1.1.1.4 christos 952 1.1.1.4 christos class _DisableTsigHandling(contextlib.ContextDecorator): 953 1.1.1.5 christos def __init__(self, message: dns.message.Message | None = None) -> None: 954 1.1.1.4 christos self.original_tsig_sign = dns.tsig.sign 955 1.1.1.4 christos self.original_tsig_validate = dns.tsig.validate 956 1.1.1.4 christos if message: 957 1.1.1.4 christos self.tsig = message.tsig 958 1.1.1.4 christos 959 1.1.1.4 christos def __enter__(self) -> None: 960 1.1.1.4 christos """ 961 1.1.1.4 christos Override the `dns.tsig.sign` and `dns.tsig.validate` functions to prevent them 962 1.1.1.4 christos from failing on messages initialized with `dns.message.from_wire(keyring=False)`. 963 1.1.1.4 christos """ 964 1.1.1.4 christos 965 1.1.1.5 christos def sign(*_: Any, **__: Any) -> tuple[dns.rdata.Rdata, None]: 966 1.1.1.4 christos assert self.tsig 967 1.1.1.4 christos return self.tsig[0], None 968 1.1.1.4 christos 969 1.1.1.4 christos def validate(*_: Any, **__: Any) -> None: 970 1.1.1.4 christos return None 971 1.1.1.4 christos 972 1.1.1.4 christos dns.tsig.sign = sign 973 1.1.1.4 christos dns.tsig.validate = validate 974 1.1.1.4 christos 975 1.1.1.4 christos def __exit__(self, *_: Any, **__: Any) -> None: 976 1.1.1.4 christos dns.tsig.sign = self.original_tsig_sign 977 1.1.1.4 christos dns.tsig.validate = self.original_tsig_validate 978 1.1.1.4 christos 979 1.1.1.4 christos @classmethod 980 1.1.1.4 christos def from_wire(cls, wire: bytes) -> "_DnsMessageWithTsigDisabled": 981 1.1.1.4 christos with cls._DisableTsigHandling(): 982 1.1.1.4 christos message = dns.message.from_wire(wire, keyring=False) 983 1.1.1.4 christos message.__class__ = _DnsMessageWithTsigDisabled 984 1.1.1.4 christos 985 1.1.1.4 christos return cast(_DnsMessageWithTsigDisabled, message) 986 1.1.1.4 christos 987 1.1.1.4 christos @property 988 1.1.1.4 christos def had_tsig(self) -> bool: 989 1.1.1.4 christos """ 990 1.1.1.4 christos Override the `had_tsig()` method to always return False, to prevent 991 1.1.1.4 christos `make_response()` from crashing. 992 1.1.1.4 christos """ 993 1.1.1.4 christos return False 994 1.1.1.4 christos 995 1.1.1.4 christos def to_wire(self, *args: Any, **kwargs: Any) -> bytes: 996 1.1.1.4 christos """ 997 1.1.1.4 christos Override the `to_wire()` method to prevent it from trying to sign 998 1.1.1.4 christos the message with TSIG. 999 1.1.1.4 christos """ 1000 1.1.1.4 christos with self._DisableTsigHandling(self): 1001 1.1.1.4 christos return super().to_wire(*args, **kwargs) 1002 1.1.1.4 christos 1003 1.1.1.4 christos 1004 1.1.1.4 christos class _NoKeyringType: 1005 1.1.1.4 christos pass 1006 1.1.1.4 christos 1007 1.1.1.4 christos 1008 1.1.1.5 christos _ASYNCSERVER_RESPONSE_MARKER = "__is_asyncserver_response__" 1009 1.1.1.5 christos 1010 1.1.1.5 christos 1011 1.1.1.5 christos def _make_asyncserver_response(query: dns.message.Message) -> dns.message.Message: 1012 1.1.1.5 christos response = dns.message.make_response(query) 1013 1.1.1.5 christos setattr(response, _ASYNCSERVER_RESPONSE_MARKER, True) 1014 1.1.1.5 christos return response 1015 1.1.1.5 christos 1016 1.1.1.5 christos 1017 1.1.1.5 christos def _is_asyncserver_response(message: dns.message.Message) -> bool: 1018 1.1.1.5 christos return getattr(message, _ASYNCSERVER_RESPONSE_MARKER, False) 1019 1.1.1.5 christos 1020 1.1.1.5 christos 1021 1.1 christos class AsyncDnsServer(AsyncServer): 1022 1.1 christos """ 1023 1.1 christos DNS server which responds to queries based on zone data and/or custom 1024 1.1 christos handlers. 1025 1.1 christos 1026 1.1 christos The server may use custom handlers which allow arbitrary query processing. 1027 1.1 christos These don't need to be standards-compliant and can be used for testing all 1028 1.1 christos sorts of scenarios, including delaying responses, synthesizing them based 1029 1.1 christos on query contents etc. 1030 1.1 christos 1031 1.1 christos The server also loads any zone files (*.db) found in its directory and 1032 1.1 christos serves them. Responses prepared using zone data can then be modified, 1033 1.1 christos replaced, or suppressed by query handlers. Query handlers can also generate 1034 1.1 christos response from scratch, without using zone data at all. 1035 1.1 christos """ 1036 1.1 christos 1037 1.1.1.4 christos def __init__( 1038 1.1.1.4 christos self, 1039 1.1.1.4 christos /, 1040 1.1.1.4 christos default_rcode: dns.rcode.Rcode = dns.rcode.REFUSED, 1041 1.1.1.4 christos default_aa: bool = False, 1042 1.1.1.5 christos keyring: ( 1043 1.1.1.5 christos dict[dns.name.Name, dns.tsig.Key] | None | _NoKeyringType 1044 1.1.1.5 christos ) = _NoKeyringType(), 1045 1.1.1.4 christos acknowledge_manual_dname_handling: bool = False, 1046 1.1.1.4 christos ) -> None: 1047 1.1 christos super().__init__(self._handle_udp, self._handle_tcp, "ans.pid") 1048 1.1 christos 1049 1.1 christos self._zone_tree: _ZoneTree = _ZoneTree() 1050 1.1.1.5 christos self._connection_handler: ConnectionHandler | None = None 1051 1.1.1.5 christos self._response_handlers: list[ResponseHandler] = [] 1052 1.1.1.4 christos self._default_rcode = default_rcode 1053 1.1.1.4 christos self._default_aa = default_aa 1054 1.1.1.4 christos self._keyring = keyring 1055 1.1.1.3 christos self._acknowledge_manual_dname_handling = acknowledge_manual_dname_handling 1056 1.1 christos 1057 1.1.1.3 christos self._load_zones() 1058 1.1 christos 1059 1.1.1.2 christos def install_response_handler( 1060 1.1.1.2 christos self, handler: ResponseHandler, prepend: bool = False 1061 1.1.1.2 christos ) -> None: 1062 1.1 christos """ 1063 1.1.1.2 christos Add a response handler that will be used to handle matching queries. 1064 1.1 christos 1065 1.1 christos Response handlers can modify, replace, or suppress the answers prepared 1066 1.1 christos from zone file contents. 1067 1.1.1.2 christos 1068 1.1.1.2 christos The provided handler is installed at the end of the response handler 1069 1.1.1.2 christos list unless `prepend` is set to True, in which case it is installed at 1070 1.1.1.2 christos the beginning of the response handler list. 1071 1.1 christos """ 1072 1.1 christos logging.info("Installing response handler: %s", handler) 1073 1.1.1.2 christos if prepend: 1074 1.1.1.2 christos self._response_handlers.insert(0, handler) 1075 1.1.1.2 christos else: 1076 1.1.1.2 christos self._response_handlers.append(handler) 1077 1.1.1.2 christos 1078 1.1.1.5 christos def install_response_handlers(self, *handlers: ResponseHandler) -> None: 1079 1.1.1.4 christos for handler in handlers: 1080 1.1.1.4 christos self.install_response_handler(handler) 1081 1.1.1.4 christos 1082 1.1.1.5 christos def replace_response_handlers(self, *new_handlers: ResponseHandler) -> None: 1083 1.1.1.5 christos """ 1084 1.1.1.5 christos Uninstall all currently installed handlers and install the provided ones. 1085 1.1.1.5 christos """ 1086 1.1.1.5 christos logging.info("Uninstalling response handlers: %s", str(self._response_handlers)) 1087 1.1.1.5 christos self._response_handlers.clear() 1088 1.1.1.5 christos self.install_response_handlers(*new_handlers) 1089 1.1.1.5 christos 1090 1.1.1.2 christos def uninstall_response_handler(self, handler: ResponseHandler) -> None: 1091 1.1.1.2 christos """ 1092 1.1.1.2 christos Remove the specified handler from the list of response handlers. 1093 1.1.1.2 christos """ 1094 1.1.1.2 christos logging.info("Uninstalling response handler: %s", handler) 1095 1.1.1.2 christos self._response_handlers.remove(handler) 1096 1.1 christos 1097 1.1.1.4 christos def install_connection_handler(self, handler: ConnectionHandler) -> None: 1098 1.1.1.4 christos """ 1099 1.1.1.4 christos Install a connection handler that will be called when a new TCP 1100 1.1.1.4 christos connection is established. 1101 1.1.1.4 christos """ 1102 1.1.1.4 christos if self._connection_handler: 1103 1.1.1.4 christos raise RuntimeError("Only one connection handler can be installed") 1104 1.1.1.4 christos self._connection_handler = handler 1105 1.1.1.4 christos 1106 1.1 christos def _load_zones(self) -> None: 1107 1.1 christos for entry in os.scandir(): 1108 1.1 christos entry_path = pathlib.Path(entry.path) 1109 1.1 christos if entry_path.suffix != ".db": 1110 1.1 christos continue 1111 1.1.1.3 christos zone = self._load_zone(entry_path) 1112 1.1 christos self._zone_tree.add(zone) 1113 1.1 christos 1114 1.1.1.3 christos def _load_zone(self, zone_file_path: pathlib.Path) -> dns.zone.Zone: 1115 1.1.1.3 christos logging.info("Loading zone file %s", zone_file_path) 1116 1.1.1.4 christos zone = self._load_zone_file(zone_file_path) 1117 1.1.1.3 christos self._abort_if_dname_found_unless_acknowledged(zone) 1118 1.1.1.3 christos return zone 1119 1.1.1.3 christos 1120 1.1.1.4 christos def _load_zone_file(self, zone_file_path: pathlib.Path) -> dns.zone.Zone: 1121 1.1.1.4 christos try: 1122 1.1.1.4 christos zone = self._load_zone_file_with_origin(zone_file_path) 1123 1.1.1.4 christos except dns.zone.UnknownOrigin: 1124 1.1.1.4 christos zone = self._load_zone_file_without_origin(zone_file_path) 1125 1.1.1.4 christos 1126 1.1.1.4 christos return zone 1127 1.1.1.4 christos 1128 1.1.1.4 christos def _load_zone_file_with_origin( 1129 1.1.1.4 christos self, zone_file_path: pathlib.Path 1130 1.1.1.4 christos ) -> dns.zone.Zone: 1131 1.1.1.4 christos zone = dns.zone.from_file(str(zone_file_path), origin=None, relativize=False) 1132 1.1.1.4 christos if zone.origin != dns.name.root: 1133 1.1.1.4 christos error = "only the root zone may use $ORIGIN in the zone file; " 1134 1.1.1.4 christos error += "for every other zone, its origin is determined by " 1135 1.1.1.4 christos error += "the name of the file it is loaded from" 1136 1.1.1.4 christos raise ValueError(error) 1137 1.1.1.4 christos return zone 1138 1.1.1.4 christos 1139 1.1.1.4 christos def _load_zone_file_without_origin( 1140 1.1.1.4 christos self, zone_file_path: pathlib.Path 1141 1.1.1.4 christos ) -> dns.zone.Zone: 1142 1.1.1.4 christos origin = dns.name.from_text(zone_file_path.stem) 1143 1.1.1.4 christos return dns.zone.from_file(str(zone_file_path), origin=origin, relativize=False) 1144 1.1.1.4 christos 1145 1.1.1.3 christos def _abort_if_dname_found_unless_acknowledged(self, zone: dns.zone.Zone) -> None: 1146 1.1.1.3 christos if self._acknowledge_manual_dname_handling: 1147 1.1.1.3 christos return 1148 1.1.1.3 christos 1149 1.1.1.3 christos error = f'DNAME records found in zone "{zone.origin}"; ' 1150 1.1.1.3 christos error += "this server does not handle DNAME in a standards-compliant way; " 1151 1.1.1.3 christos error += "pass `acknowledge_manual_dname_handling=True` to the " 1152 1.1.1.3 christos error += "AsyncDnsServer constructor to acknowledge this and load zone anyway" 1153 1.1.1.3 christos 1154 1.1.1.3 christos for node in zone.nodes.values(): 1155 1.1.1.3 christos for rdataset in node: 1156 1.1.1.3 christos if rdataset.rdtype == dns.rdatatype.DNAME: 1157 1.1.1.3 christos raise ValueError(error) 1158 1.1.1.3 christos 1159 1.1 christos async def _handle_udp( 1160 1.1.1.5 christos self, wire: bytes, addr: tuple[str, int], transport: asyncio.DatagramTransport 1161 1.1 christos ) -> None: 1162 1.1 christos logging.debug("Received UDP message: %s", wire.hex()) 1163 1.1.1.5 christos socket_info = transport.get_extra_info("sockname") 1164 1.1.1.5 christos socket = Peer(socket_info[0], socket_info[1]) 1165 1.1.1.2 christos peer = Peer(addr[0], addr[1]) 1166 1.1.1.5 christos responses = self._handle_query(wire, socket, peer, DnsProtocol.UDP) 1167 1.1 christos async for response in responses: 1168 1.1.1.3 christos logging.debug("Sending UDP message: %s", response.hex()) 1169 1.1.1.2 christos transport.sendto(response, addr) 1170 1.1 christos 1171 1.1 christos async def _handle_tcp( 1172 1.1 christos self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter 1173 1.1 christos ) -> None: 1174 1.1.1.2 christos peer_info = writer.get_extra_info("peername") 1175 1.1.1.2 christos peer = Peer(peer_info[0], peer_info[1]) 1176 1.1.1.2 christos logging.debug("Accepted TCP connection from %s", peer) 1177 1.1 christos 1178 1.1.1.4 christos try: 1179 1.1.1.4 christos if self._connection_handler: 1180 1.1.1.4 christos await self._connection_handler.handle(reader, writer, peer) 1181 1.1.1.4 christos while True: 1182 1.1.1.2 christos wire = await self._read_tcp_query(reader, peer) 1183 1.1.1.2 christos if not wire: 1184 1.1.1.2 christos break 1185 1.1.1.2 christos await self._send_tcp_response(writer, peer, wire) 1186 1.1.1.4 christos except _ConnectionTeardownRequested: 1187 1.1.1.4 christos pass 1188 1.1.1.4 christos except ConnectionResetError: 1189 1.1.1.4 christos logging.error("TCP connection from %s reset by peer", peer) 1190 1.1.1.4 christos return 1191 1.1 christos 1192 1.1.1.2 christos logging.debug("Closing TCP connection from %s", peer) 1193 1.1 christos writer.close() 1194 1.1.1.2 christos try: 1195 1.1.1.2 christos # Python >= 3.7 1196 1.1.1.2 christos await writer.wait_closed() 1197 1.1.1.2 christos except AttributeError: 1198 1.1.1.2 christos # Python < 3.7 1199 1.1.1.2 christos pass 1200 1.1.1.2 christos 1201 1.1.1.2 christos async def _read_tcp_query( 1202 1.1.1.2 christos self, reader: asyncio.StreamReader, peer: Peer 1203 1.1.1.5 christos ) -> bytes | None: 1204 1.1.1.2 christos wire_length = await self._read_tcp_query_wire_length(reader, peer) 1205 1.1.1.2 christos if not wire_length: 1206 1.1.1.2 christos return None 1207 1.1.1.2 christos 1208 1.1.1.2 christos return await self._read_tcp_query_wire(reader, peer, wire_length) 1209 1.1.1.2 christos 1210 1.1.1.2 christos async def _read_tcp_query_wire_length( 1211 1.1.1.2 christos self, reader: asyncio.StreamReader, peer: Peer 1212 1.1.1.5 christos ) -> int | None: 1213 1.1.1.2 christos logging.debug("Receiving TCP message length from %s...", peer) 1214 1.1.1.2 christos 1215 1.1.1.2 christos wire_length_bytes = await self._read_tcp_octets(reader, peer, 2) 1216 1.1.1.2 christos if not wire_length_bytes: 1217 1.1.1.2 christos return None 1218 1.1 christos 1219 1.1.1.2 christos (wire_length,) = struct.unpack("!H", wire_length_bytes) 1220 1.1.1.2 christos 1221 1.1.1.2 christos return wire_length 1222 1.1.1.2 christos 1223 1.1.1.2 christos async def _read_tcp_query_wire( 1224 1.1.1.2 christos self, reader: asyncio.StreamReader, peer: Peer, wire_length: int 1225 1.1.1.5 christos ) -> bytes | None: 1226 1.1.1.2 christos logging.debug("Receiving TCP message (%d octets) from %s...", wire_length, peer) 1227 1.1.1.2 christos 1228 1.1.1.2 christos wire = await self._read_tcp_octets(reader, peer, wire_length) 1229 1.1.1.2 christos if not wire: 1230 1.1.1.2 christos return None 1231 1.1.1.2 christos 1232 1.1.1.2 christos logging.debug("Received complete TCP message from %s: %s", peer, wire.hex()) 1233 1.1.1.2 christos 1234 1.1.1.2 christos return wire 1235 1.1.1.2 christos 1236 1.1.1.2 christos async def _read_tcp_octets( 1237 1.1.1.2 christos self, reader: asyncio.StreamReader, peer: Peer, expected: int 1238 1.1.1.5 christos ) -> bytes | None: 1239 1.1.1.2 christos buffer = b"" 1240 1.1.1.2 christos 1241 1.1.1.2 christos while len(buffer) < expected: 1242 1.1.1.2 christos chunk = await reader.read(expected - len(buffer)) 1243 1.1.1.2 christos if not chunk: 1244 1.1.1.2 christos if buffer: 1245 1.1.1.2 christos logging.debug( 1246 1.1.1.2 christos "Received short TCP message (%d octets) from %s: %s", 1247 1.1.1.2 christos len(buffer), 1248 1.1.1.2 christos peer, 1249 1.1.1.2 christos buffer.hex(), 1250 1.1.1.2 christos ) 1251 1.1.1.2 christos else: 1252 1.1.1.2 christos logging.debug("Received disconnect from %s", peer) 1253 1.1.1.2 christos return None 1254 1.1 christos 1255 1.1.1.2 christos logging.debug("Received %d TCP octets from %s", len(chunk), peer) 1256 1.1.1.2 christos buffer += chunk 1257 1.1.1.2 christos 1258 1.1.1.2 christos return buffer 1259 1.1.1.2 christos 1260 1.1.1.2 christos async def _send_tcp_response( 1261 1.1.1.2 christos self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes 1262 1.1 christos ) -> None: 1263 1.1.1.5 christos socket_info = writer.get_extra_info("sockname") 1264 1.1.1.5 christos socket = Peer(socket_info[0], socket_info[1]) 1265 1.1.1.5 christos responses = self._handle_query(wire, socket, peer, DnsProtocol.TCP) 1266 1.1.1.2 christos async for response in responses: 1267 1.1.1.3 christos logging.debug("Sending TCP response: %s", response.hex()) 1268 1.1.1.2 christos writer.write(response) 1269 1.1.1.2 christos await writer.drain() 1270 1.1.1.2 christos 1271 1.1.1.5 christos def _log_query(self, qctx: QueryContext) -> None: 1272 1.1 christos logging.info( 1273 1.1.1.5 christos "Received %s/%s/%s (ID=%d) query from %s on %s (%s)", 1274 1.1 christos qctx.qname.to_text(omit_final_dot=True), 1275 1.1 christos dns.rdataclass.to_text(qctx.qclass), 1276 1.1 christos dns.rdatatype.to_text(qctx.qtype), 1277 1.1 christos qctx.query.id, 1278 1.1.1.5 christos qctx.peer, 1279 1.1.1.5 christos qctx.socket, 1280 1.1.1.5 christos qctx.protocol.name, 1281 1.1 christos ) 1282 1.1 christos logging.debug( 1283 1.1 christos "\n".join([f"[IN] {l}" for l in [""] + str(qctx.query).splitlines()]) 1284 1.1 christos ) 1285 1.1 christos 1286 1.1 christos def _log_response( 1287 1.1.1.5 christos self, qctx: QueryContext, response: dns.message.Message | bytes | None 1288 1.1 christos ) -> None: 1289 1.1 christos if not response: 1290 1.1 christos logging.info( 1291 1.1.1.5 christos "Not sending a response to query (ID=%d) from %s on %s (%s)", 1292 1.1 christos qctx.query.id, 1293 1.1.1.5 christos qctx.peer, 1294 1.1.1.5 christos qctx.socket, 1295 1.1.1.5 christos qctx.protocol.name, 1296 1.1 christos ) 1297 1.1 christos return 1298 1.1 christos 1299 1.1 christos if isinstance(response, dns.message.Message): 1300 1.1 christos try: 1301 1.1 christos qname = response.question[0].name.to_text(omit_final_dot=True) 1302 1.1 christos qclass = dns.rdataclass.to_text(response.question[0].rdclass) 1303 1.1 christos qtype = dns.rdatatype.to_text(response.question[0].rdtype) 1304 1.1 christos except IndexError: 1305 1.1 christos qname = "<empty>" 1306 1.1 christos qclass = "-" 1307 1.1 christos qtype = "-" 1308 1.1 christos 1309 1.1 christos logging.info( 1310 1.1.1.5 christos "Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s on %s (%s)", 1311 1.1 christos qname, 1312 1.1 christos qclass, 1313 1.1 christos qtype, 1314 1.1 christos response.id, 1315 1.1 christos len(response.question), 1316 1.1 christos len(response.answer), 1317 1.1 christos len(response.authority), 1318 1.1 christos len(response.additional), 1319 1.1 christos qctx.query.id, 1320 1.1.1.5 christos qctx.peer, 1321 1.1.1.5 christos qctx.socket, 1322 1.1.1.5 christos qctx.protocol.name, 1323 1.1 christos ) 1324 1.1 christos logging.debug( 1325 1.1 christos "\n".join([f"[OUT] {l}" for l in [""] + str(response).splitlines()]) 1326 1.1 christos ) 1327 1.1 christos return 1328 1.1 christos 1329 1.1 christos logging.info( 1330 1.1.1.5 christos "Sending response (%d bytes) to a query (ID=%d) from %s on %s (%s)", 1331 1.1 christos len(response), 1332 1.1 christos qctx.query.id, 1333 1.1.1.5 christos qctx.peer, 1334 1.1.1.5 christos qctx.socket, 1335 1.1.1.5 christos qctx.protocol.name, 1336 1.1 christos ) 1337 1.1 christos logging.debug("[OUT] %s", response.hex()) 1338 1.1 christos 1339 1.1 christos async def _handle_query( 1340 1.1.1.5 christos self, wire: bytes, socket: Peer, peer: Peer, protocol: DnsProtocol 1341 1.1 christos ) -> AsyncGenerator[bytes, None]: 1342 1.1 christos """ 1343 1.1 christos Yield wire data to send as a response over the established transport. 1344 1.1 christos """ 1345 1.1.1.2 christos try: 1346 1.1.1.4 christos query = self._parse_message(wire) 1347 1.1.1.2 christos except dns.exception.DNSException as exc: 1348 1.1.1.2 christos logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc) 1349 1.1.1.2 christos return 1350 1.1.1.5 christos response_stub = _make_asyncserver_response(query) 1351 1.1.1.5 christos qctx = QueryContext(query, response_stub, socket, peer, protocol) 1352 1.1.1.5 christos self._log_query(qctx) 1353 1.1 christos responses = self._prepare_responses(qctx) 1354 1.1 christos async for response in responses: 1355 1.1.1.5 christos self._log_response(qctx, response) 1356 1.1 christos if response: 1357 1.1 christos if isinstance(response, dns.message.Message): 1358 1.1 christos response = response.to_wire(max_size=65535) 1359 1.1 christos if protocol == DnsProtocol.UDP: 1360 1.1 christos yield response 1361 1.1 christos else: 1362 1.1 christos response_length = struct.pack("!H", len(response)) 1363 1.1 christos yield response_length + response 1364 1.1 christos 1365 1.1.1.4 christos def _parse_message(self, wire: bytes) -> dns.message.Message: 1366 1.1.1.4 christos try: 1367 1.1.1.4 christos if isinstance(self._keyring, _NoKeyringType): 1368 1.1.1.4 christos keyring = None 1369 1.1.1.4 christos else: 1370 1.1.1.4 christos keyring = self._keyring 1371 1.1.1.4 christos return dns.message.from_wire(wire, keyring=keyring) 1372 1.1.1.4 christos except dns.message.UnknownTSIGKey as exc: 1373 1.1.1.4 christos if isinstance(self._keyring, _NoKeyringType): 1374 1.1.1.4 christos error = "TSIG-signed query received but no `keyring` was provided; " 1375 1.1.1.4 christos error += "either provide a keyring (in which case the server will " 1376 1.1.1.4 christos error += "ignore any TSIG-invalid queries), or set `keyring=None` " 1377 1.1.1.4 christos error += "explicitly to disable TSIG validation altogether. " 1378 1.1.1.4 christos error += "This requires some hacking around a dnspython bug, " 1379 1.1.1.4 christos error += "so there may be unexpected side effects." 1380 1.1.1.4 christos raise ValueError(error) from exc 1381 1.1.1.4 christos if self._keyring is None: 1382 1.1.1.4 christos return _DnsMessageWithTsigDisabled.from_wire(wire) 1383 1.1.1.4 christos raise 1384 1.1.1.4 christos 1385 1.1 christos async def _prepare_responses( 1386 1.1 christos self, qctx: QueryContext 1387 1.1.1.5 christos ) -> AsyncGenerator[dns.message.Message | bytes | None, None]: 1388 1.1 christos """ 1389 1.1 christos Yield response(s) either from response handlers or zone data. 1390 1.1 christos """ 1391 1.1.1.4 christos qctx.response.set_rcode(self._default_rcode) 1392 1.1.1.4 christos if self._default_aa: 1393 1.1.1.4 christos qctx.response.flags |= dns.flags.AA 1394 1.1.1.4 christos qctx.save_initialized_response(with_zone_data=False) 1395 1.1.1.4 christos 1396 1.1 christos self._prepare_response_from_zone_data(qctx) 1397 1.1.1.4 christos qctx.save_initialized_response(with_zone_data=True) 1398 1.1 christos 1399 1.1 christos response_handled = False 1400 1.1 christos async for action in self._run_response_handlers(qctx): 1401 1.1 christos yield await action.perform() 1402 1.1 christos response_handled = True 1403 1.1 christos 1404 1.1 christos if not response_handled: 1405 1.1.1.2 christos logging.debug("Responding based on zone data") 1406 1.1 christos yield qctx.response 1407 1.1 christos 1408 1.1 christos def _prepare_response_from_zone_data(self, qctx: QueryContext) -> None: 1409 1.1 christos """ 1410 1.1 christos Prepare a response to the query based on the available zone data. 1411 1.1 christos 1412 1.1 christos The functionality is split across smaller functions that modify the 1413 1.1 christos query context until a proper response is formed. 1414 1.1 christos """ 1415 1.1 christos if self._refused_response(qctx): 1416 1.1 christos return 1417 1.1 christos 1418 1.1 christos if self._delegation_response(qctx): 1419 1.1 christos return 1420 1.1 christos 1421 1.1 christos qctx.response.flags |= dns.flags.AA 1422 1.1 christos 1423 1.1 christos if self._ent_response(qctx): 1424 1.1 christos return 1425 1.1 christos 1426 1.1 christos if self._nxdomain_response(qctx): 1427 1.1 christos return 1428 1.1 christos 1429 1.1.1.3 christos if self._cname_response(qctx): 1430 1.1.1.3 christos return 1431 1.1.1.3 christos 1432 1.1 christos if self._nodata_response(qctx): 1433 1.1 christos return 1434 1.1 christos 1435 1.1 christos self._noerror_response(qctx) 1436 1.1 christos 1437 1.1 christos def _refused_response(self, qctx: QueryContext) -> bool: 1438 1.1.1.3 christos zone = self._zone_tree.find_best_zone(qctx.current_qname) 1439 1.1.1.3 christos if zone: 1440 1.1.1.3 christos qctx.zone = zone 1441 1.1 christos return False 1442 1.1 christos 1443 1.1.1.4 christos # RCODE is already set to self._default_rcode, i.e. REFUSED by default; 1444 1.1.1.4 christos # it should also not be changed when following a CNAME chain 1445 1.1 christos return True 1446 1.1 christos 1447 1.1 christos def _delegation_response(self, qctx: QueryContext) -> bool: 1448 1.1 christos assert qctx.zone 1449 1.1 christos 1450 1.1.1.3 christos name = qctx.current_qname 1451 1.1 christos delegation = None 1452 1.1 christos 1453 1.1 christos while name != qctx.zone.origin: 1454 1.1 christos node = qctx.zone.get_node(name) 1455 1.1 christos if node: 1456 1.1 christos delegation = node.get_rdataset(qctx.qclass, dns.rdatatype.NS) 1457 1.1 christos if delegation: 1458 1.1 christos break 1459 1.1 christos name = name.parent() 1460 1.1 christos 1461 1.1 christos if not delegation: 1462 1.1 christos return False 1463 1.1 christos 1464 1.1 christos delegation_rrset = dns.rrset.RRset(name, qctx.qclass, dns.rdatatype.NS) 1465 1.1 christos delegation_rrset.update(delegation) 1466 1.1 christos 1467 1.1 christos qctx.response.set_rcode(dns.rcode.NOERROR) 1468 1.1 christos qctx.response.authority.append(delegation_rrset) 1469 1.1 christos 1470 1.1 christos self._delegation_response_additional(qctx) 1471 1.1 christos 1472 1.1 christos return True 1473 1.1 christos 1474 1.1 christos def _delegation_response_additional(self, qctx: QueryContext) -> None: 1475 1.1 christos assert qctx.zone 1476 1.1 christos assert qctx.response.authority[0] 1477 1.1 christos 1478 1.1 christos for nameserver in qctx.response.authority[0]: 1479 1.1 christos if not nameserver.target.is_subdomain(qctx.response.authority[0].name): 1480 1.1 christos continue 1481 1.1 christos glue_a = qctx.zone.get_rrset(nameserver.target, dns.rdatatype.A) 1482 1.1 christos if glue_a: 1483 1.1 christos qctx.response.additional.append(glue_a) 1484 1.1 christos glue_aaaa = qctx.zone.get_rrset(nameserver.target, dns.rdatatype.AAAA) 1485 1.1 christos if glue_aaaa: 1486 1.1 christos qctx.response.additional.append(glue_aaaa) 1487 1.1 christos 1488 1.1 christos def _ent_response(self, qctx: QueryContext) -> bool: 1489 1.1 christos assert qctx.zone 1490 1.1 christos assert qctx.zone.origin 1491 1.1 christos 1492 1.1 christos qctx.soa = qctx.zone.find_rrset(qctx.zone.origin, dns.rdatatype.SOA) 1493 1.1 christos assert qctx.soa 1494 1.1 christos 1495 1.1.1.3 christos qctx.node = qctx.zone.get_node(qctx.current_qname) 1496 1.1 christos if qctx.node or not any( 1497 1.1.1.3 christos n for n in qctx.zone.nodes if n.is_subdomain(qctx.current_qname) 1498 1.1 christos ): 1499 1.1 christos return False 1500 1.1 christos 1501 1.1 christos qctx.response.set_rcode(dns.rcode.NOERROR) 1502 1.1 christos qctx.response.authority.append(qctx.soa) 1503 1.1 christos return True 1504 1.1 christos 1505 1.1 christos def _nxdomain_response(self, qctx: QueryContext) -> bool: 1506 1.1 christos assert qctx.soa 1507 1.1 christos 1508 1.1 christos if qctx.node: 1509 1.1 christos return False 1510 1.1 christos 1511 1.1 christos qctx.response.set_rcode(dns.rcode.NXDOMAIN) 1512 1.1 christos qctx.response.authority.append(qctx.soa) 1513 1.1 christos return True 1514 1.1 christos 1515 1.1.1.3 christos def _cname_response(self, qctx: QueryContext) -> bool: 1516 1.1.1.3 christos assert qctx.node 1517 1.1.1.3 christos 1518 1.1.1.3 christos cname = qctx.node.get_rdataset(qctx.qclass, dns.rdatatype.CNAME) 1519 1.1.1.3 christos if not cname: 1520 1.1.1.3 christos return False 1521 1.1.1.3 christos 1522 1.1.1.4 christos qctx.response.set_rcode(dns.rcode.NOERROR) 1523 1.1.1.3 christos cname_rrset = dns.rrset.RRset(qctx.current_qname, qctx.qclass, cname.rdtype) 1524 1.1.1.3 christos cname_rrset.update(cname) 1525 1.1.1.3 christos qctx.response.answer.append(cname_rrset) 1526 1.1.1.3 christos 1527 1.1.1.3 christos qctx.alias = cname[0].target 1528 1.1.1.3 christos self._prepare_response_from_zone_data(qctx) 1529 1.1.1.3 christos return True 1530 1.1.1.3 christos 1531 1.1 christos def _nodata_response(self, qctx: QueryContext) -> bool: 1532 1.1 christos assert qctx.node 1533 1.1 christos assert qctx.soa 1534 1.1 christos 1535 1.1 christos qctx.answer = qctx.node.get_rdataset(qctx.qclass, qctx.qtype) 1536 1.1 christos if qctx.answer: 1537 1.1 christos return False 1538 1.1 christos 1539 1.1 christos qctx.response.set_rcode(dns.rcode.NOERROR) 1540 1.1.1.3 christos if not qctx.response.answer: 1541 1.1.1.3 christos qctx.response.authority.append(qctx.soa) 1542 1.1 christos return True 1543 1.1 christos 1544 1.1 christos def _noerror_response(self, qctx: QueryContext) -> None: 1545 1.1 christos assert qctx.answer 1546 1.1 christos 1547 1.1.1.3 christos answer_rrset = dns.rrset.RRset(qctx.current_qname, qctx.qclass, qctx.qtype) 1548 1.1 christos answer_rrset.update(qctx.answer) 1549 1.1 christos 1550 1.1 christos qctx.response.set_rcode(dns.rcode.NOERROR) 1551 1.1 christos qctx.response.answer.append(answer_rrset) 1552 1.1 christos 1553 1.1 christos async def _run_response_handlers( 1554 1.1 christos self, qctx: QueryContext 1555 1.1 christos ) -> AsyncGenerator[ResponseAction, None]: 1556 1.1 christos """ 1557 1.1 christos Yield response(s) to the query from a matching query handler. 1558 1.1 christos """ 1559 1.1 christos for handler in self._response_handlers: 1560 1.1 christos if handler.match(qctx): 1561 1.1.1.2 christos logging.debug("Matched response handler: %s", handler) 1562 1.1 christos async for response in handler.get_responses(qctx): 1563 1.1 christos yield response 1564 1.1 christos return 1565 1.1.1.2 christos 1566 1.1.1.2 christos 1567 1.1.1.2 christos class ControllableAsyncDnsServer(AsyncDnsServer): 1568 1.1.1.2 christos """ 1569 1.1.1.2 christos An AsyncDnsServer whose behavior can be dynamically changed by sending TXT 1570 1.1.1.2 christos queries to a "magic" domain. 1571 1.1.1.2 christos """ 1572 1.1.1.2 christos 1573 1.1.1.2 christos _CONTROL_DOMAIN = "_control." 1574 1.1.1.2 christos 1575 1.1.1.4 christos @functools.cached_property 1576 1.1.1.4 christos def _control_domain(self) -> dns.name.Name: 1577 1.1.1.4 christos return dns.name.from_text(self._CONTROL_DOMAIN) 1578 1.1.1.4 christos 1579 1.1.1.4 christos @functools.cached_property 1580 1.1.1.5 christos def _commands(self) -> dict[dns.name.Name, "ControlCommand"]: 1581 1.1.1.4 christos return {} 1582 1.1.1.4 christos 1583 1.1.1.5 christos def install_control_commands(self, *commands: "ControlCommand") -> None: 1584 1.1.1.4 christos for command in commands: 1585 1.1.1.4 christos self.install_control_command(command) 1586 1.1.1.4 christos 1587 1.1.1.4 christos def install_control_command(self, command: "ControlCommand") -> None: 1588 1.1.1.4 christos command_subdomain = dns.name.Name([command.control_subdomain]) 1589 1.1.1.4 christos control_subdomain = command_subdomain.concatenate(self._control_domain) 1590 1.1.1.4 christos try: 1591 1.1.1.4 christos existing_command = self._commands[control_subdomain] 1592 1.1.1.4 christos except KeyError: 1593 1.1.1.4 christos self._commands[control_subdomain] = command 1594 1.1.1.4 christos else: 1595 1.1.1.4 christos raise RuntimeError( 1596 1.1.1.4 christos f"{control_subdomain} already handled by {existing_command}" 1597 1.1.1.4 christos ) 1598 1.1.1.2 christos 1599 1.1.1.2 christos async def _prepare_responses( 1600 1.1.1.2 christos self, qctx: QueryContext 1601 1.1.1.5 christos ) -> AsyncGenerator[dns.message.Message | bytes | None, None]: 1602 1.1.1.2 christos """ 1603 1.1.1.2 christos Detect and handle control queries, falling back to normal processing 1604 1.1.1.2 christos for non-control queries. 1605 1.1.1.2 christos """ 1606 1.1.1.2 christos control_response = self._handle_control_command(qctx) 1607 1.1.1.2 christos if control_response: 1608 1.1.1.2 christos yield await DnsResponseSend(response=control_response).perform() 1609 1.1.1.2 christos return 1610 1.1.1.2 christos 1611 1.1.1.2 christos async for response in super()._prepare_responses(qctx): 1612 1.1.1.2 christos yield response 1613 1.1.1.2 christos 1614 1.1.1.5 christos def _handle_control_command(self, qctx: QueryContext) -> dns.message.Message | None: 1615 1.1.1.2 christos """ 1616 1.1.1.2 christos Detect and handle control queries. 1617 1.1.1.2 christos 1618 1.1.1.2 christos A control query must be of type TXT; if it is not, a FORMERR response 1619 1.1.1.2 christos is sent back. 1620 1.1.1.2 christos 1621 1.1.1.2 christos The list of commands that the server should respond to is passed to its 1622 1.1.1.2 christos constructor. If the server is unable to handle the control query using 1623 1.1.1.2 christos any of the enabled commands, an NXDOMAIN response is sent. 1624 1.1.1.2 christos 1625 1.1.1.2 christos Otherwise, the relevant command's handler is expected to provide the 1626 1.1.1.2 christos response via qctx.response and/or return a string that is converted to 1627 1.1.1.2 christos a TXT RRset inserted into the ANSWER section of the response to the 1628 1.1.1.2 christos control query. The RCODE for a command-provided response defaults to 1629 1.1.1.2 christos NOERROR, but can be overridden by the command's handler. 1630 1.1.1.2 christos """ 1631 1.1.1.2 christos if not qctx.qname.is_subdomain(self._control_domain): 1632 1.1.1.2 christos return None 1633 1.1.1.2 christos 1634 1.1.1.2 christos if qctx.qtype != dns.rdatatype.TXT: 1635 1.1.1.2 christos logging.error("Non-TXT control query %s from %s", qctx.qname, qctx.peer) 1636 1.1.1.2 christos qctx.response.set_rcode(dns.rcode.FORMERR) 1637 1.1.1.2 christos return qctx.response 1638 1.1.1.2 christos 1639 1.1.1.2 christos control_subdomain = dns.name.Name(qctx.qname.labels[-3:]) 1640 1.1.1.2 christos try: 1641 1.1.1.2 christos command = self._commands[control_subdomain] 1642 1.1.1.2 christos except KeyError: 1643 1.1.1.2 christos logging.error("Unhandled control query %s from %s", qctx.qname, qctx.peer) 1644 1.1.1.2 christos qctx.response.set_rcode(dns.rcode.NXDOMAIN) 1645 1.1.1.2 christos return qctx.response 1646 1.1.1.2 christos 1647 1.1.1.2 christos logging.info("Received control query %s from %s", qctx.qname, qctx.peer) 1648 1.1.1.2 christos logging.debug("Handling control query %s using %s", qctx.qname, command) 1649 1.1.1.2 christos qctx.response.set_rcode(dns.rcode.NOERROR) 1650 1.1.1.2 christos qctx.response.flags |= dns.flags.AA 1651 1.1.1.2 christos 1652 1.1.1.2 christos command_qname = qctx.qname.relativize(control_subdomain) 1653 1.1.1.2 christos try: 1654 1.1.1.2 christos command_args = [l.decode("ascii") for l in command_qname.labels] 1655 1.1.1.2 christos except UnicodeDecodeError: 1656 1.1.1.2 christos logging.error("Non-ASCII control query %s from %s", qctx.qname, qctx.peer) 1657 1.1.1.2 christos qctx.response.set_rcode(dns.rcode.FORMERR) 1658 1.1.1.2 christos return qctx.response 1659 1.1.1.2 christos 1660 1.1.1.2 christos command_response = command.handle(command_args, self, qctx) 1661 1.1.1.2 christos if command_response: 1662 1.1.1.2 christos command_response_rrset = dns.rrset.from_text( 1663 1.1.1.2 christos qctx.qname, 0, qctx.qclass, dns.rdatatype.TXT, f'"{command_response}"' 1664 1.1.1.2 christos ) 1665 1.1.1.2 christos qctx.response.answer.append(command_response_rrset) 1666 1.1.1.2 christos 1667 1.1.1.2 christos return qctx.response 1668 1.1.1.2 christos 1669 1.1.1.2 christos 1670 1.1.1.2 christos class ControlCommand(abc.ABC): 1671 1.1.1.2 christos """ 1672 1.1.1.2 christos Base class for control commands. 1673 1.1.1.2 christos 1674 1.1.1.2 christos The derived class must define the control query subdomain that it handles 1675 1.1.1.2 christos and the callback that handles the control queries. 1676 1.1.1.2 christos """ 1677 1.1.1.2 christos 1678 1.1.1.2 christos @property 1679 1.1.1.2 christos @abc.abstractmethod 1680 1.1.1.2 christos def control_subdomain(self) -> str: 1681 1.1.1.2 christos """ 1682 1.1.1.2 christos The subdomain of the control domain handled by this command. Needs to 1683 1.1.1.2 christos be defined as a string by the derived class. 1684 1.1.1.2 christos """ 1685 1.1.1.2 christos raise NotImplementedError 1686 1.1.1.2 christos 1687 1.1.1.2 christos @abc.abstractmethod 1688 1.1.1.2 christos def handle( 1689 1.1.1.5 christos self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext 1690 1.1.1.5 christos ) -> str | None: 1691 1.1.1.2 christos """ 1692 1.1.1.2 christos This method is expected to carry out arbitrary actions in response to a 1693 1.1.1.2 christos control query. Note that it is invoked synchronously (it is not a 1694 1.1.1.2 christos coroutine). 1695 1.1.1.2 christos 1696 1.1.1.2 christos `args` is a list of arguments for the command extracted from the 1697 1.1.1.2 christos control query's QNAME; these arguments (and therefore the QNAME as 1698 1.1.1.2 christos well) must only contain ASCII characters. For example, if a command's 1699 1.1.1.2 christos subdomain is `my-command`, control query `foo.bar.my-command._control.` 1700 1.1.1.2 christos causes `args` to be set to `["foo", "bar"]` while control query 1701 1.1.1.2 christos `my-command._control.` causes `args` to be set to `[]`. 1702 1.1.1.2 christos 1703 1.1.1.2 christos `server` is the server instance that received the control query. This 1704 1.1.1.2 christos method can change the server's behavior by altering its response 1705 1.1.1.2 christos handler list using the appropriate methods. 1706 1.1.1.2 christos 1707 1.1.1.2 christos `qctx` is the query context for the control query. By operating on 1708 1.1.1.2 christos qctx.response, this method can prepare the DNS response sent to 1709 1.1.1.2 christos the client in response to the control query. Alternatively (or in 1710 1.1.1.2 christos addition to the above), it can also return a string; if it does, the 1711 1.1.1.2 christos returned string is converted to a TXT RRset that is inserted into the 1712 1.1.1.2 christos ANSWER section of the response to the control query. 1713 1.1.1.2 christos """ 1714 1.1.1.2 christos raise NotImplementedError 1715 1.1.1.2 christos 1716 1.1.1.2 christos def __str__(self) -> str: 1717 1.1.1.2 christos return self.__class__.__name__ 1718 1.1.1.2 christos 1719 1.1.1.2 christos 1720 1.1.1.2 christos class ToggleResponsesCommand(ControlCommand): 1721 1.1.1.2 christos """ 1722 1.1.1.2 christos Disable/enable sending responses from the server. 1723 1.1.1.2 christos """ 1724 1.1.1.2 christos 1725 1.1.1.2 christos control_subdomain = "send-responses" 1726 1.1.1.2 christos 1727 1.1.1.2 christos def __init__(self) -> None: 1728 1.1.1.5 christos self._current_handler: IgnoreAllQueries | None = None 1729 1.1.1.2 christos 1730 1.1.1.2 christos def handle( 1731 1.1.1.5 christos self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext 1732 1.1.1.5 christos ) -> str | None: 1733 1.1.1.2 christos if len(args) != 1: 1734 1.1.1.2 christos logging.error("Invalid %s query %s", self, qctx.qname) 1735 1.1.1.2 christos qctx.response.set_rcode(dns.rcode.SERVFAIL) 1736 1.1.1.2 christos return "invalid query; use exactly one of 'enable' or 'disable' in QNAME" 1737 1.1.1.2 christos 1738 1.1.1.2 christos mode = args[0] 1739 1.1.1.2 christos 1740 1.1.1.2 christos if mode == "disable": 1741 1.1.1.2 christos if self._current_handler: 1742 1.1.1.2 christos return "sending responses already disabled" 1743 1.1.1.2 christos self._current_handler = IgnoreAllQueries() 1744 1.1.1.2 christos server.install_response_handler(self._current_handler, prepend=True) 1745 1.1.1.2 christos return "sending responses disabled" 1746 1.1.1.2 christos 1747 1.1.1.2 christos if mode == "enable": 1748 1.1.1.2 christos if not self._current_handler: 1749 1.1.1.2 christos return "sending responses already enabled" 1750 1.1.1.2 christos server.uninstall_response_handler(self._current_handler) 1751 1.1.1.2 christos self._current_handler = None 1752 1.1.1.2 christos return "sending responses enabled" 1753 1.1.1.2 christos 1754 1.1.1.2 christos logging.error("Unrecognized response sending mode '%s'", mode) 1755 1.1.1.2 christos qctx.response.set_rcode(dns.rcode.SERVFAIL) 1756 1.1.1.2 christos return f"unrecognized response sending mode '{mode}'" 1757 1.1.1.5 christos 1758 1.1.1.5 christos 1759 1.1.1.5 christos class SwitchControlCommand(ControlCommand): 1760 1.1.1.5 christos """ 1761 1.1.1.5 christos Switch the server's response handlers based on the control query. 1762 1.1.1.5 christos 1763 1.1.1.5 christos A sequence of response handlers is associated with each key. When a 1764 1.1.1.5 christos control query is received, the server's response handlers are replaced 1765 1.1.1.5 christos with the sequence associated with the key extracted from the control 1766 1.1.1.5 christos query. 1767 1.1.1.5 christos """ 1768 1.1.1.5 christos 1769 1.1.1.5 christos control_subdomain = "switch" 1770 1.1.1.5 christos 1771 1.1.1.5 christos def __init__(self, handler_mapping: dict[str, Sequence[ResponseHandler]]): 1772 1.1.1.5 christos self._handler_mapping = handler_mapping 1773 1.1.1.5 christos 1774 1.1.1.5 christos def handle( 1775 1.1.1.5 christos self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext 1776 1.1.1.5 christos ) -> str | None: 1777 1.1.1.5 christos if len(args) != 1 or args[0] not in self._handler_mapping: 1778 1.1.1.5 christos logging.error("Invalid %s query %s", self, qctx.qname) 1779 1.1.1.5 christos qctx.response.set_rcode(dns.rcode.SERVFAIL) 1780 1.1.1.5 christos return f"invalid query; exactly one of {list(self._handler_mapping.keys())} is expected in QNAME" 1781 1.1.1.5 christos 1782 1.1.1.5 christos server.replace_response_handlers(*self._handler_mapping[args[0]]) 1783 1.1.1.5 christos return f"switched to handler set '{args[0]}'" 1784