Home | History | Annotate | Line # | Download | only in isctest
      1      1.1  christos """
      2      1.1  christos Copyright (C) Internet Systems Consortium, Inc. ("ISC")
      3      1.1  christos 
      4      1.1  christos SPDX-License-Identifier: MPL-2.0
      5      1.1  christos 
      6      1.1  christos This Source Code Form is subject to the terms of the Mozilla Public
      7      1.1  christos License, v. 2.0.  If a copy of the MPL was not distributed with this
      8      1.1  christos file, you can obtain one at https://mozilla.org/MPL/2.0/.
      9      1.1  christos 
     10      1.1  christos See the COPYRIGHT file distributed with this work for additional
     11      1.1  christos information regarding copyright ownership.
     12      1.1  christos """
     13      1.1  christos 
     14  1.1.1.5  christos from collections.abc import AsyncGenerator, Callable, Coroutine, Sequence
     15      1.1  christos from dataclasses import dataclass, field
     16  1.1.1.5  christos from typing import Any, cast
     17      1.1  christos 
     18      1.1  christos import abc
     19      1.1  christos import asyncio
     20  1.1.1.4  christos import contextlib
     21  1.1.1.4  christos import copy
     22      1.1  christos import enum
     23      1.1  christos import functools
     24      1.1  christos import logging
     25      1.1  christos import os
     26      1.1  christos import pathlib
     27      1.1  christos import re
     28      1.1  christos import signal
     29      1.1  christos import struct
     30      1.1  christos import sys
     31      1.1  christos 
     32  1.1.1.4  christos import dns.exception
     33      1.1  christos import dns.flags
     34      1.1  christos import dns.message
     35      1.1  christos import dns.name
     36      1.1  christos import dns.node
     37      1.1  christos import dns.rcode
     38  1.1.1.4  christos import dns.rdata
     39      1.1  christos import dns.rdataclass
     40  1.1.1.4  christos import dns.rdataset
     41      1.1  christos import dns.rdatatype
     42      1.1  christos import dns.rrset
     43  1.1.1.4  christos import dns.tsig
     44      1.1  christos import dns.zone
     45      1.1  christos 
# Coroutine function invoked for every received UDP datagram.  It is given the
# datagram payload, the address the datagram came from and the datagram
# transport it was received on.
_UdpHandler = Callable[
    [
        bytes,
        tuple[str, int],
        asyncio.DatagramTransport,
    ],
    Coroutine[Any, Any, None],
]


# Coroutine function invoked for every accepted TCP connection.  It is given
# the stream reader and writer of that connection; it is installed as the
# client callback of asyncio.start_server().
_TcpHandler = Callable[
    [
        asyncio.StreamReader,
        asyncio.StreamWriter,
    ],
    Coroutine[Any, Any, None],
]
     54      1.1  christos 
     55      1.1  christos 
class _AsyncUdpHandler(asyncio.DatagramProtocol):
    """
    Protocol implementation for handling UDP traffic using asyncio.

    Every received datagram is passed to the handler coroutine function and
    the resulting coroutine is scheduled as a separate asyncio task, so
    `datagram_received()` itself returns immediately.
    """

    def __init__(
        self,
        handler: _UdpHandler,
    ) -> None:
        self._transport: asyncio.DatagramTransport | None = None
        self._handler: _UdpHandler = handler
        # The event loop only keeps weak references to tasks.  Keep strong
        # references to in-flight handler tasks so that none of them can be
        # garbage-collected before it finishes; each task removes itself from
        # this set once it is done.
        self._tasks: set[asyncio.Task] = set()

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """
        Called by asyncio when a connection is made.

        The transport is remembered so that it can be handed to the handler
        together with every received datagram.
        """
        self._transport = cast(asyncio.DatagramTransport, transport)

    def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
        """
        Called by asyncio when a datagram is received.

        Schedules the handler for `data` (received from `addr`) as a new task.
        """
        assert self._transport
        handler_coroutine = self._handler(data, addr, self._transport)
        # asyncio.create_task() exists on every supported Python version
        # (>= 3.7), so no fallback is needed.  The returned task must be
        # referenced until it completes (see __init__).
        task = asyncio.create_task(handler_coroutine)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
     87      1.1  christos 
     88      1.1  christos 
class AsyncServer:
    """
    A generic asynchronous server which may handle UDP and/or TCP traffic.

    Once the server is executed as asyncio coroutine, it will keep running
    until a SIGINT/SIGTERM signal is received.  An unhandled exception that is
    reported to the event loop's exception handler also stops the server; it
    is then raised from `run()`.

    Listening addresses and port:

    - IPv4 address: the first command-line argument or, if absent, 10.53.0.N,
      where N is taken from the name of the current directory (`ansN`),
    - IPv6 address: fd92:7065:b8e:ffff::X, where X is the last octet of the
      IPv4 address,
    - port: the second command-line argument or, if absent, the PORT
      environment variable (5300 if unset).
    """

    def __init__(
        self,
        udp_handler: _UdpHandler | None,
        tcp_handler: _TcpHandler | None,
        pidfile: str | None = None,
    ) -> None:
        logging.basicConfig(
            format="%(asctime)s %(levelname)8s  %(message)s",
            level=os.environ.get("ANS_LOG_LEVEL", "INFO").upper(),
        )
        try:
            ipv4_address = sys.argv[1]
        except IndexError:
            ipv4_address = self._get_ipv4_address_from_directory_name()

        last_ipv4_address_octet = ipv4_address.split(".")[-1]
        ipv6_address = f"fd92:7065:b8e:ffff::{last_ipv4_address_octet}"

        try:
            port = int(sys.argv[2])
        except IndexError:
            port = int(os.environ.get("PORT", 5300))

        logging.info("Setting up IPv4 listener at %s:%d", ipv4_address, port)
        logging.info("Setting up IPv6 listener at [%s]:%d", ipv6_address, port)

        self._ip_addresses: tuple[str, str] = (ipv4_address, ipv6_address)
        self._port: int = port
        self._udp_handler: _UdpHandler | None = udp_handler
        self._tcp_handler: _TcpHandler | None = tcp_handler
        self._pidfile: str | None = pidfile
        # Completed (with a result or an exception) to stop the server;
        # created by _setup_exception_handler().
        self._work_done: asyncio.Future | None = None

    def _get_ipv4_address_from_directory_name(self) -> str:
        """
        Derive the IPv4 address from the current directory name (`ansN` ->
        `10.53.0.N`).

        Raises RuntimeError if the directory name does not start with `ans`
        followed by digits.
        """
        containing_directory = pathlib.Path().absolute().stem
        match_result = re.match(r"ans(?P<index>\d+)", containing_directory)
        if not match_result:
            raise RuntimeError("Unable to auto-determine the IPv4 address to use")

        return f"10.53.0.{match_result.group('index')}"

    def run(self) -> None:
        """
        Start the server in an asynchronous coroutine and block until it stops.

        If the server is stopped by an unhandled exception, that exception
        propagates to the caller.
        """
        # asyncio.run() exists on every Python version this module can run on
        # (it relies on 3.10+ `X | None` annotations evaluated at runtime).
        # Do not wrap it in `except AttributeError`: that would also catch
        # AttributeErrors raised by the server itself and then try to start
        # the server a second time.
        asyncio.run(self._run())

    async def _run(self) -> None:
        self._setup_exception_handler()
        self._setup_signals()
        assert self._work_done
        await self._listen_udp()
        await self._listen_tcp()
        self._write_pidfile()
        # If the server is stopped by an exception, the await below raises and
        # the PID file is left in place.  NOTE(review): presumably this lets
        # the test framework notice that the server died; confirm before
        # moving the cleanup into a `finally` clause.
        await self._work_done
        self._cleanup_pidfile()

    def _get_asyncio_loop(self) -> asyncio.AbstractEventLoop:
        """
        Return the running event loop (RuntimeError if none is running).
        """
        return asyncio.get_running_loop()

    def _setup_exception_handler(self) -> None:
        loop = self._get_asyncio_loop()
        self._work_done = loop.create_future()
        loop.set_exception_handler(self._handle_exception)

    def _handle_exception(
        self, _: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        """
        Event loop exception handler: stop the server with the reported error.

        Only the first reported error is kept; later ones are ignored because
        `_work_done` is already done by then.
        """
        assert self._work_done
        exception = context.get("exception", RuntimeError(context["message"]))
        with contextlib.suppress(asyncio.InvalidStateError):
            self._work_done.set_exception(exception)

    def _setup_signals(self) -> None:
        loop = self._get_asyncio_loop()
        loop.add_signal_handler(signal.SIGINT, self._signal_done)
        loop.add_signal_handler(signal.SIGTERM, self._signal_done)

    def _signal_done(self) -> None:
        """
        Signal handler: request a clean shutdown (ignored if already stopping).
        """
        assert self._work_done
        with contextlib.suppress(asyncio.InvalidStateError):
            self._work_done.set_result(True)

    async def _listen_udp(self) -> None:
        if not self._udp_handler:
            return
        loop = self._get_asyncio_loop()
        for ip_address in self._ip_addresses:
            await loop.create_datagram_endpoint(
                lambda: _AsyncUdpHandler(cast(_UdpHandler, self._udp_handler)),
                (ip_address, self._port),
            )

    async def _listen_tcp(self) -> None:
        if not self._tcp_handler:
            return
        for ip_address in self._ip_addresses:
            await asyncio.start_server(
                self._tcp_handler, host=ip_address, port=self._port
            )

    def _write_pidfile(self) -> None:
        if not self._pidfile:
            return
        logging.info("Writing PID to %s", self._pidfile)
        with open(self._pidfile, "w", encoding="ascii") as pidfile:
            print(f"{os.getpid()}", file=pidfile)

    def _cleanup_pidfile(self) -> None:
        if not self._pidfile:
            return
        logging.info("Removing %s", self._pidfile)
        os.unlink(self._pidfile)
    227      1.1  christos 
    228      1.1  christos 
class DnsProtocol(enum.Enum):
    """
    Transport protocol a DNS message is carried over.
    """

    UDP = 1
    TCP = 2
    232      1.1  christos 
    233      1.1  christos 
@dataclass(frozen=True)
class Peer:
    """
    Connection endpoint (host and port) that prints as `host:port`.

    Hosts containing a colon (IPv6 addresses) are wrapped in square brackets,
    e.g. `[fd92:7065:b8e:ffff::1]:5300`.
    """

    host: str
    port: int

    def __str__(self) -> str:
        if ":" not in self.host:
            return f"{self.host}:{self.port}"
        return f"[{self.host}]:{self.port}"
    246  1.1.1.2  christos 
    247  1.1.1.2  christos 
@dataclass
class QueryContext:
    """
    Context for the incoming query which may be used for preparing the response.

    Only `query`, `response`, `socket`, `peer` and `protocol` are constructor
    arguments; the other public fields default to None.  The two private
    fields hold deep-copied snapshots of `response` taken by
    `save_initialized_response()`, from which `prepare_new_response()`
    restores `response`.
    """

    query: dns.message.Message
    response: dns.message.Message
    socket: Peer
    peer: Peer
    protocol: DnsProtocol
    zone: dns.zone.Zone | None = field(default=None, init=False)
    soa: dns.rrset.RRset | None = field(default=None, init=False)
    node: dns.node.Node | None = field(default=None, init=False)
    answer: dns.rdataset.Rdataset | None = field(default=None, init=False)
    alias: dns.name.Name | None = field(default=None, init=False)
    _initialized_response: dns.message.Message | None = field(default=None, init=False)
    _initialized_response_with_zone_data: dns.message.Message | None = field(
        default=None, init=False
    )

    @property
    def qname(self) -> dns.name.Name:
        """Name from the first entry of the query's question section."""
        first_question = self.query.question[0]
        return first_question.name

    @property
    def current_qname(self) -> dns.name.Name:
        """`alias` when it is set (and truthy), otherwise `qname`."""
        if self.alias:
            return self.alias
        return self.qname

    @property
    def qclass(self) -> dns.rdataclass.RdataClass:
        """Class from the first entry of the query's question section."""
        first_question = self.query.question[0]
        return first_question.rdclass

    @property
    def qtype(self) -> dns.rdatatype.RdataType:
        """Type from the first entry of the query's question section."""
        first_question = self.query.question[0]
        return first_question.rdtype

    def prepare_new_response(
        self, /, with_zone_data: bool = True
    ) -> dns.message.Message:
        """
        Reset `response` to a deep copy of a saved snapshot and return it.

        `with_zone_data` selects which snapshot to copy; that snapshot must
        have been stored by `save_initialized_response()` beforehand.
        """
        if with_zone_data:
            template = self._initialized_response_with_zone_data
        else:
            template = self._initialized_response
        assert template
        self.response = copy.deepcopy(template)
        return self.response

    def save_initialized_response(self, /, with_zone_data: bool) -> None:
        """
        Store a deep copy of the current `response` as the snapshot for the
        `with_zone_data` variant selected by the argument.
        """
        snapshot = copy.deepcopy(self.response)
        if with_zone_data:
            self._initialized_response_with_zone_data = snapshot
        else:
            self._initialized_response = snapshot
    301  1.1.1.4  christos 
    302      1.1  christos 
@dataclass
class ResponseAction(abc.ABC):
    """
    Abstract base for the actions the server can take in reply to a query.
    """

    @abc.abstractmethod
    async def perform(self) -> dns.message.Message | bytes | None:
        """
        Carry out the action and return what should be sent back.

        Implementations may do arbitrary work first (e.g. sleep for a while or
        modify the answer).  The return value is the DNS response to send: a
        dns.message.Message, a raw bytes object, or None to send no response
        at all.
        """
        raise NotImplementedError()
    318      1.1  christos 
    319      1.1  christos 
@dataclass
class DnsResponseSend(ResponseAction):
    """
    Action which yields a dns.message.Message response, optionally delayed by
    `delay` seconds.

    `authoritative` controls the AA bit of the response: True sets it, False
    clears it and None leaves it untouched.

    The response is expected to come from AsyncDnsServer's response
    preparation methods (e.g. `QueryContext.prepare_new_response`); any other
    response makes `perform()` raise RuntimeError unless
    `acknowledge_hand_rolled_response` is True.
    """

    response: dns.message.Message
    authoritative: bool | None = None
    delay: float = 0.0
    acknowledge_hand_rolled_response: bool = False

    async def perform(self) -> dns.message.Message | bytes | None:
        """
        Check and adjust the response, wait if requested, then return it.

        Raises RuntimeError for a hand-rolled response that was not explicitly
        acknowledged.
        """
        assert isinstance(self.response, dns.message.Message)
        hand_rolled = not _is_asyncserver_response(self.response)
        if hand_rolled and not self.acknowledge_hand_rolled_response:
            raise RuntimeError(
                "The response you are trying to send was not created using "
                "AsyncDnsServer's response preparation methods. "
                "This will break features such as automatic AA flag "
                "and RCODE handling. If you need a fresh copy of a "
                "response, use `QueryContext.prepare_new_response` "
                "instead of `dns.message.make_response`. "
                "To acknowledge this and proceed anyway, set "
                "`acknowledge_hand_rolled_response=True` in "
                "DnsResponseSend's constructor."
            )

        if self.authoritative is not None:
            aa_bit = dns.flags.AA
            if self.authoritative:
                self.response.flags |= aa_bit
            else:
                self.response.flags &= ~aa_bit

        if self.delay > 0:
            delay_ms = self.delay * 1000
            logging.info(
                "Delaying response (ID=%d) by %d ms", self.response.id, delay_ms
            )
            await asyncio.sleep(self.delay)
        return self.response
    370      1.1  christos 
    371      1.1  christos 
@dataclass
class BytesResponseSend(ResponseAction):
    """
    Action which yields a raw response given as bytes, optionally delayed by
    `delay` seconds.

    The bytes are returned as-is; no parsing or validation happens here.
    """

    response: bytes
    delay: float = 0.0

    async def perform(self) -> dns.message.Message | bytes | None:
        """
        Return `response`, after sleeping for `delay` seconds if positive.
        """
        assert isinstance(self.response, bytes)
        if self.delay > 0:
            delay_ms = self.delay * 1000
            logging.info("Delaying raw response by %d ms", delay_ms)
            await asyncio.sleep(self.delay)
        return self.response
    392      1.1  christos 
    393      1.1  christos 
@dataclass
class ResponseDrop(ResponseAction):
    """
    Action which sends nothing back, as if the query packet had been dropped.
    """

    async def perform(self) -> dns.message.Message | bytes | None:
        """
        Produce no response at all.
        """
        return None
    402      1.1  christos 
    403      1.1  christos 
class _ConnectionTeardownRequested(Exception):
    """
    Raised by actions and connection handlers to request that the server
    close the current TCP connection.
    """
    406  1.1.1.4  christos 
    407  1.1.1.4  christos 
@dataclass
class CloseConnection(ResponseAction):
    """
    Action which makes the server close the connection (TCP only), optionally
    after waiting for `delay` seconds.

    `perform()` never returns a response; it signals the teardown by raising
    `_ConnectionTeardownRequested`.
    """

    delay: float = 0.0

    async def perform(self) -> dns.message.Message | bytes | None:
        """
        Wait `delay` seconds (if positive), then request connection teardown.
        """
        wait_seconds = self.delay
        if wait_seconds > 0:
            logging.info("Waiting %.1fs before closing TCP connection", wait_seconds)
            await asyncio.sleep(wait_seconds)
        raise _ConnectionTeardownRequested()
    423  1.1.1.4  christos 
    424  1.1.1.4  christos 
class ConnectionHandler(abc.ABC):
    """
    Abstract base for hooks run on every newly established TCP connection.

    An installed handler receives the connection's stream reader and writer
    plus the peer's endpoint before AsyncDnsServer processes any DNS queries,
    and may perform arbitrary actions with them.
    """

    @abc.abstractmethod
    async def handle(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer: Peer
    ) -> None:
        """
        Act on the connection represented by `reader` and `writer`, which was
        opened by `peer`.
        """
        raise NotImplementedError()
    442  1.1.1.4  christos 
    443  1.1.1.4  christos 
def block_reading(peer: Peer, writer_not_the_reader: asyncio.StreamWriter) -> None:
    """
    Block reads for the reader associated with the provided writer.

    Yes, pass the writer, not the reader. See the comments below for details.

    `peer` is only used for logging.  The actual pause of reading is scheduled
    for the next event loop iteration via `call_soon()`, while the transport's
    read size is reduced to a single byte immediately.

    Must be called while an event loop is running (RuntimeError otherwise).
    """

    # asyncio.get_running_loop() is available on every supported Python
    # version (>= 3.7); the former `except AttributeError` fallback for older
    # versions could never trigger.
    loop = asyncio.get_running_loop()

    logging.info("Blocking reads from %s", peer)

    # This is Micha's submission for the Ugliest Hack of the Year contest.
    # (The alternative was implementing an asyncio transport from scratch.)
    #
    # In order to prevent the client socket from being read from, simply
    # not calling `reader.read()` is not enough, because asyncio buffers
    # incoming data itself on the transport level.  However, `StreamReader`
    # does not expose the underlying transport as a property.  Therefore,
    # cheat by extracting it from `StreamWriter` as it is the same
    # bidirectional transport as for the read side (a `Transport`, which is
    # a subclass of both `ReadTransport` and `WriteTransport`) and call
    # `ReadTransport.pause_reading()` to remove the underlying socket from
    # the set of descriptors monitored by the selector, thereby preventing
    # any reads from happening on the client socket.  However...
    loop.call_soon(writer_not_the_reader.transport.pause_reading)  # type: ignore

    # ...due to `AsyncDnsServer._handle_tcp()` being a coroutine, by the
    # time it gets executed, asyncio transport code will already have added
    # the client socket to the set of descriptors monitored by the
    # selector.  Therefore, if the client starts sending data immediately,
    # a read from the socket will have already been scheduled by the time
    # this handler gets executed.  There is no way to prevent that from
    # happening, so work around it by abusing the fact that the transport
    # at hand is specifically an instance of `_SelectorSocketTransport`
    # (from asyncio.selector_events) and set the size of its read buffer to
    # just a single byte.  This does give asyncio enough time to read that
    # single byte from the client socket's buffer before that socket is
    # removed from the set of monitored descriptors, but prevents the
    # one-off read from emptying the client socket buffer _entirely_, which
    # is enough to trigger sending an RST segment when the connection is
    # closed shortly afterwards.
    writer_not_the_reader.transport.max_size = 1  # type: ignore
    491  1.1.1.4  christos 
    492  1.1.1.4  christos 
@dataclass
class IgnoreAllConnections(ConnectionHandler):
    """
    A connection handler that makes the server not read anything from the
    client socket, effectively ignoring all incoming connections.

    The writer of every handled connection is kept in `_connections` and never
    released, so AsyncDnsServer never closes an ignored connection.
    """

    _connections: set[asyncio.StreamWriter] = field(default_factory=set)

    async def handle(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer: Peer
    ) -> None:
        """
        Stop reading from the connection and hold on to it indefinitely.
        """
        block_reading(peer, writer)
        # Pausing reads takes the client socket out of the selector's set of
        # monitored descriptors.  Because tasks, streams, transports and
        # selectors reference one another, that can leave the client task
        # (AsyncDnsServer._handle_tcp()) unreferenced, so it could be
        # garbage-collected early and asyncio would report "Task was destroyed
        # but it is pending!".  Remembering the writer keeps all of the
        # related asyncio objects alive.  The price is that ignored
        # connections are never closed - unacceptable for a production DNS
        # server, which AsyncDnsServer was never meant to be.
        ignored = self._connections
        ignored.add(writer)
    520  1.1.1.4  christos 
    521  1.1.1.4  christos 
@dataclass
class ConnectionReset(ConnectionHandler):
    """
    A connection handler that closes the connection without reading anything
    from the client socket, optionally after waiting for `delay` seconds.

    Its only purpose is to make the server send an RST segment, which happens
    when a client socket is closed while unread data remains in its buffer.
    When closing the connection after the query has been read is good enough,
    use the CloseConnection response handler instead.
    """

    delay: float = 0.0

    async def handle(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer: Peer
    ) -> None:
        """
        Stop reading from the connection, optionally wait, then request its
        teardown by raising `_ConnectionTeardownRequested`.
        """
        block_reading(peer, writer)

        pause = self.delay
        if pause > 0:
            logging.info(
                "Waiting %.1fs before closing TCP connection from %s", pause, peer
            )
            await asyncio.sleep(pause)

        raise _ConnectionTeardownRequested()
    552  1.1.1.4  christos 
    553  1.1.1.4  christos 
    554      1.1  christos class ResponseHandler(abc.ABC):
    555      1.1  christos     """
    556      1.1  christos     Base class for generic response handlers.
    557      1.1  christos 
    558      1.1  christos     If a query passes the `match()` function logic, then it is handled by this
    559      1.1  christos     response handler and response(s) may be generated by the `get_responses()`
    560      1.1  christos     method.
    561      1.1  christos     """
    562      1.1  christos 
    563  1.1.1.2  christos     # pylint: disable=unused-argument
    564      1.1  christos     def match(self, qctx: QueryContext) -> bool:
    565      1.1  christos         """
    566  1.1.1.2  christos         Matching logic - the first handler whose `match()` method returns True
    567  1.1.1.2  christos         is used for handling the query.
    568  1.1.1.2  christos 
    569  1.1.1.2  christos         The default for each handler is to handle all queries.
    570      1.1  christos         """
    571      1.1  christos         return True
    572      1.1  christos 
    573      1.1  christos     @abc.abstractmethod
    574      1.1  christos     async def get_responses(
    575      1.1  christos         self, qctx: QueryContext
    576      1.1  christos     ) -> AsyncGenerator[ResponseAction, None]:
    577      1.1  christos         """
    578      1.1  christos         Custom handler which may produce response(s) to matching queries.
    579      1.1  christos 
    580      1.1  christos         The response prepared from zone data is passed to this method in
    581      1.1  christos         qctx.response.
    582      1.1  christos         """
    583      1.1  christos         yield DnsResponseSend(qctx.response)
    584      1.1  christos 
    585  1.1.1.2  christos     def __str__(self) -> str:
    586  1.1.1.2  christos         return self.__class__.__name__
    587  1.1.1.2  christos 
    588  1.1.1.2  christos 
    589  1.1.1.2  christos class IgnoreAllQueries(ResponseHandler):
    590  1.1.1.2  christos     """
    591  1.1.1.2  christos     Do not respond to any queries sent to the server.
    592  1.1.1.2  christos     """
    593  1.1.1.2  christos 
    594  1.1.1.2  christos     async def get_responses(
    595  1.1.1.2  christos         self, qctx: QueryContext
    596  1.1.1.2  christos     ) -> AsyncGenerator[ResponseAction, None]:
    597  1.1.1.2  christos         yield ResponseDrop()
    598  1.1.1.2  christos 
    599      1.1  christos 
    600  1.1.1.4  christos class QnameHandler(ResponseHandler):
    601  1.1.1.4  christos     """
    602  1.1.1.4  christos     Base class used for deriving custom QNAME handlers.
    603  1.1.1.4  christos 
    604  1.1.1.4  christos     The derived class must specify a list of `qnames` that it wants to handle.
    605  1.1.1.4  christos     Queries for exactly these QNAMEs will then be passed to the
    606  1.1.1.4  christos     `get_response()` method in the derived class.
    607  1.1.1.4  christos     """
    608  1.1.1.4  christos 
    609  1.1.1.4  christos     @property
    610  1.1.1.4  christos     @abc.abstractmethod
    611  1.1.1.5  christos     def qnames(self) -> list[str]:
    612  1.1.1.4  christos         """
    613  1.1.1.4  christos         A list of QNAMEs handled by this class.
    614  1.1.1.4  christos         """
    615  1.1.1.4  christos         raise NotImplementedError
    616  1.1.1.4  christos 
    617  1.1.1.4  christos     def __init__(self) -> None:
    618  1.1.1.5  christos         self._qnames: list[dns.name.Name] = [dns.name.from_text(d) for d in self.qnames]
    619  1.1.1.4  christos 
    620  1.1.1.4  christos     def __str__(self) -> str:
    621  1.1.1.4  christos         return f"{self.__class__.__name__}(QNAMEs: {', '.join(self.qnames)})"
    622  1.1.1.4  christos 
    623  1.1.1.4  christos     def match(self, qctx: QueryContext) -> bool:
    624  1.1.1.4  christos         """
    625  1.1.1.4  christos         Handle queries whose QNAME matches any of the QNAMEs handled by this
    626  1.1.1.4  christos         class.
    627  1.1.1.4  christos         """
    628  1.1.1.4  christos         return qctx.qname in self._qnames
    629  1.1.1.4  christos 
    630  1.1.1.4  christos 
    631  1.1.1.5  christos class QnameQtypeHandler(QnameHandler):
    632  1.1.1.5  christos     """
    633  1.1.1.5  christos     Handle queries for which both of the following conditions are true:
    634  1.1.1.5  christos 
    635  1.1.1.5  christos     - the query's QNAME is present in `self.qnames`,
    636  1.1.1.5  christos     - the query's QTYPE is present in `self.qtypes`.
    637  1.1.1.5  christos     """
    638  1.1.1.5  christos 
    639  1.1.1.5  christos     @property
    640  1.1.1.5  christos     @abc.abstractmethod
    641  1.1.1.5  christos     def qtypes(self) -> list[dns.rdatatype.RdataType]:
    642  1.1.1.5  christos         """
    643  1.1.1.5  christos         A list of QTYPEs handled by this class.
    644  1.1.1.5  christos         """
    645  1.1.1.5  christos         raise NotImplementedError
    646  1.1.1.5  christos 
    647  1.1.1.5  christos     def __init__(self) -> None:
    648  1.1.1.5  christos         super().__init__()
    649  1.1.1.5  christos         self._qtypes: list[dns.rdatatype.RdataType] = self.qtypes
    650  1.1.1.5  christos 
    651  1.1.1.5  christos     def __str__(self) -> str:
    652  1.1.1.5  christos         return f"{self.__class__.__name__}(QNAMEs: {', '.join(self.qnames)}; QTYPEs: {', '.join(map(str, self.qtypes))})"
    653  1.1.1.5  christos 
    654  1.1.1.5  christos     def match(self, qctx: QueryContext) -> bool:
    655  1.1.1.5  christos         """
    656  1.1.1.5  christos         Handle queries whose QNAME and QTYPE match any of the QNAMEs and
    657  1.1.1.5  christos         QTYPEs handled by this class.
    658  1.1.1.5  christos         """
    659  1.1.1.5  christos         return qctx.qtype in self._qtypes and super().match(qctx)
    660  1.1.1.5  christos 
    661  1.1.1.5  christos 
    662  1.1.1.5  christos class StaticResponseHandler(ResponseHandler):
    663  1.1.1.5  christos     """
    664  1.1.1.5  christos     Base class used for deriving custom static response handlers.
    665  1.1.1.5  christos 
    666  1.1.1.5  christos     The derived class can specify the RRsets to be included in the answer,
    667  1.1.1.5  christos     authority, and additional sections of the response, whether to set the AA
    668  1.1.1.5  christos     bit in the response, and a delay before sending the response.
    669  1.1.1.5  christos 
    670  1.1.1.5  christos     The default implementation of `get_responses()` uses these properties to
    671  1.1.1.5  christos     prepare and yield a single response.
    672  1.1.1.5  christos     """
    673  1.1.1.5  christos 
    674  1.1.1.5  christos     @property
    675  1.1.1.5  christos     def rcode(self) -> dns.rcode.Rcode | None:
    676  1.1.1.5  christos         """
    677  1.1.1.5  christos         Optional RCODE to be set in the response.
    678  1.1.1.5  christos         """
    679  1.1.1.5  christos         return None
    680  1.1.1.5  christos 
    681  1.1.1.5  christos     @property
    682  1.1.1.5  christos     def answer(self) -> Sequence[dns.rrset.RRset]:
    683  1.1.1.5  christos         """
    684  1.1.1.5  christos         RRsets to be included in the answer section of the response.
    685  1.1.1.5  christos         """
    686  1.1.1.5  christos         return []
    687  1.1.1.5  christos 
    688  1.1.1.5  christos     @property
    689  1.1.1.5  christos     def authority(self) -> Sequence[dns.rrset.RRset]:
    690  1.1.1.5  christos         """
    691  1.1.1.5  christos         RRsets to be included in the authority section of the response.
    692  1.1.1.5  christos         """
    693  1.1.1.5  christos         return []
    694  1.1.1.5  christos 
    695  1.1.1.5  christos     @property
    696  1.1.1.5  christos     def additional(self) -> Sequence[dns.rrset.RRset]:
    697  1.1.1.5  christos         """
    698  1.1.1.5  christos         RRsets to be included in the additional section of the response.
    699  1.1.1.5  christos         """
    700  1.1.1.5  christos         return []
    701  1.1.1.5  christos 
    702  1.1.1.5  christos     @property
    703  1.1.1.5  christos     def authoritative(self) -> bool | None:
    704  1.1.1.5  christos         """
    705  1.1.1.5  christos         Whether to set the AA bit in the response.
    706  1.1.1.5  christos         """
    707  1.1.1.5  christos         return None
    708  1.1.1.5  christos 
    709  1.1.1.5  christos     @property
    710  1.1.1.5  christos     def delay(self) -> float:
    711  1.1.1.5  christos         """
    712  1.1.1.5  christos         Delay before sending the response.
    713  1.1.1.5  christos         """
    714  1.1.1.5  christos         return 0.0
    715  1.1.1.5  christos 
    716  1.1.1.5  christos     async def get_responses(
    717  1.1.1.5  christos         self, qctx: QueryContext
    718  1.1.1.5  christos     ) -> AsyncGenerator[DnsResponseSend, None]:
    719  1.1.1.5  christos         qctx.prepare_new_response(with_zone_data=False)
    720  1.1.1.5  christos         qctx.response.answer.extend(self.answer)
    721  1.1.1.5  christos         qctx.response.authority.extend(self.authority)
    722  1.1.1.5  christos         qctx.response.additional.extend(self.additional)
    723  1.1.1.5  christos         if self.rcode is not None:
    724  1.1.1.5  christos             qctx.response.set_rcode(self.rcode)
    725  1.1.1.5  christos         yield DnsResponseSend(
    726  1.1.1.5  christos             qctx.response, authoritative=self.authoritative, delay=self.delay
    727  1.1.1.5  christos         )
    728  1.1.1.5  christos 
    729  1.1.1.5  christos 
    730      1.1  christos class DomainHandler(ResponseHandler):
    731      1.1  christos     """
    732      1.1  christos     Base class used for deriving custom domain handlers.
    733      1.1  christos 
    734      1.1  christos     The derived class must specify a list of `domains` that it wants to handle.
    735      1.1  christos     Queries for any of these domains (and their subdomains) will then be passed
    736      1.1  christos     to the `get_response()` method in the derived class.
    737  1.1.1.5  christos 
    738  1.1.1.5  christos     The most specific matching domain is stored in the `matched_domain` attribute.
    739      1.1  christos     """
    740      1.1  christos 
    741      1.1  christos     @property
    742      1.1  christos     @abc.abstractmethod
    743  1.1.1.5  christos     def domains(self) -> list[str]:
    744      1.1  christos         """
    745      1.1  christos         A list of domain names handled by this class.
    746      1.1  christos         """
    747      1.1  christos         raise NotImplementedError
    748      1.1  christos 
    749      1.1  christos     def __init__(self) -> None:
    750  1.1.1.5  christos         self._domains: list[dns.name.Name] = sorted(
    751  1.1.1.5  christos             [dns.name.from_text(d) for d in self.domains], reverse=True
    752  1.1.1.5  christos         )
    753  1.1.1.5  christos         self._matched_domain: dns.name.Name | None = None
    754  1.1.1.5  christos 
    755  1.1.1.5  christos     @property
    756  1.1.1.5  christos     def matched_domain(self) -> dns.name.Name:
    757  1.1.1.5  christos         assert self._matched_domain is not None
    758  1.1.1.5  christos         return self._matched_domain
    759      1.1  christos 
    760      1.1  christos     def __str__(self) -> str:
    761      1.1  christos         return f"{self.__class__.__name__}(domains: {', '.join(self.domains)})"
    762      1.1  christos 
    763      1.1  christos     def match(self, qctx: QueryContext) -> bool:
    764      1.1  christos         """
    765      1.1  christos         Handle queries whose QNAME matches any of the domains handled by this
    766      1.1  christos         class.
    767      1.1  christos         """
    768  1.1.1.5  christos         self._matched_domain = None
    769      1.1  christos         for domain in self._domains:
    770      1.1  christos             if qctx.qname.is_subdomain(domain):
    771  1.1.1.5  christos                 self._matched_domain = domain
    772      1.1  christos                 return True
    773      1.1  christos         return False
    774      1.1  christos 
    775      1.1  christos 
    776  1.1.1.5  christos class ForwarderHandler(ResponseHandler):
    777  1.1.1.5  christos     """
    778  1.1.1.5  christos     A handler forwarding all received queries to another DNS server with an
    779  1.1.1.5  christos     optional delay and then relaying the responses back to the original client.
    780  1.1.1.5  christos 
    781  1.1.1.5  christos     Queries are currently always forwarded via UDP.
    782  1.1.1.5  christos     """
    783  1.1.1.5  christos 
    784  1.1.1.5  christos     @property
    785  1.1.1.5  christos     @abc.abstractmethod
    786  1.1.1.5  christos     def target(self) -> str:
    787  1.1.1.5  christos         """
    788  1.1.1.5  christos         The address of the DNS server to forward queries to.
    789  1.1.1.5  christos         """
    790  1.1.1.5  christos         raise NotImplementedError
    791  1.1.1.5  christos 
    792  1.1.1.5  christos     @property
    793  1.1.1.5  christos     def port(self) -> int:
    794  1.1.1.5  christos         """
    795  1.1.1.5  christos         The port of the DNS server to forward queries to.
    796  1.1.1.5  christos 
    797  1.1.1.5  christos         The default value of 0 causes the same port as the one used by this
    798  1.1.1.5  christos         server for listening to be used.
    799  1.1.1.5  christos         """
    800  1.1.1.5  christos         return 0
    801  1.1.1.5  christos 
    802  1.1.1.5  christos     @property
    803  1.1.1.5  christos     def delay(self) -> float:
    804  1.1.1.5  christos         """
    805  1.1.1.5  christos         The number of seconds to wait before forwarding each query.
    806  1.1.1.5  christos         """
    807  1.1.1.5  christos         return 0.0
    808  1.1.1.5  christos 
    809  1.1.1.5  christos     def __str__(self) -> str:
    810  1.1.1.5  christos         return f"{self.__class__.__name__}(target: {self.target}:{self.port})"
    811  1.1.1.5  christos 
    812  1.1.1.5  christos     class ForwarderProtocol(asyncio.DatagramProtocol):
    813  1.1.1.5  christos         def __init__(self, query: bytes, response: asyncio.Future) -> None:
    814  1.1.1.5  christos             self._query = query
    815  1.1.1.5  christos             self._response = response
    816  1.1.1.5  christos 
    817  1.1.1.5  christos         def connection_made(self, transport: asyncio.BaseTransport) -> None:
    818  1.1.1.5  christos             logging.debug("[OUT] %s", self._query.hex())
    819  1.1.1.5  christos             cast(asyncio.DatagramTransport, transport).sendto(self._query)
    820  1.1.1.5  christos 
    821  1.1.1.5  christos         def datagram_received(self, data: bytes, _: tuple[str, int]) -> None:
    822  1.1.1.5  christos             logging.debug("[IN] %s", data.hex())
    823  1.1.1.5  christos             self._response.set_result(data)
    824  1.1.1.5  christos 
    825  1.1.1.5  christos     async def get_responses(
    826  1.1.1.5  christos         self, qctx: QueryContext
    827  1.1.1.5  christos     ) -> AsyncGenerator[ResponseAction, None]:
    828  1.1.1.5  christos         loop = asyncio.get_running_loop()
    829  1.1.1.5  christos         response = loop.create_future()
    830  1.1.1.5  christos         forwarding_target = f"{self.target}:{self.port or qctx.socket.port}"
    831  1.1.1.5  christos 
    832  1.1.1.5  christos         if self.delay > 0:
    833  1.1.1.5  christos             logging.info(
    834  1.1.1.5  christos                 "Waiting %.1fs before forwarding %s query from %s to %s over UDP",
    835  1.1.1.5  christos                 self.delay,
    836  1.1.1.5  christos                 qctx.protocol.name,
    837  1.1.1.5  christos                 qctx.peer,
    838  1.1.1.5  christos                 forwarding_target,
    839  1.1.1.5  christos             )
    840  1.1.1.5  christos             await asyncio.sleep(self.delay)
    841  1.1.1.5  christos 
    842  1.1.1.5  christos         logging.info(
    843  1.1.1.5  christos             "Forwarding %s query from %s to %s over UDP",
    844  1.1.1.5  christos             qctx.protocol.name,
    845  1.1.1.5  christos             qctx.peer,
    846  1.1.1.5  christos             forwarding_target,
    847  1.1.1.5  christos         )
    848  1.1.1.5  christos 
    849  1.1.1.5  christos         transport, _ = await loop.create_datagram_endpoint(
    850  1.1.1.5  christos             lambda: self.ForwarderProtocol(qctx.query.to_wire(), response),
    851  1.1.1.5  christos             local_addr=(qctx.socket.host, 0),
    852  1.1.1.5  christos             remote_addr=(self.target, self.port or qctx.socket.port),
    853  1.1.1.5  christos         )
    854  1.1.1.5  christos 
    855  1.1.1.5  christos         try:
    856  1.1.1.5  christos             await response
    857  1.1.1.5  christos         finally:
    858  1.1.1.5  christos             transport.close()
    859  1.1.1.5  christos 
    860  1.1.1.5  christos         logging.info(
    861  1.1.1.5  christos             "Relaying UDP response from %s to %s over %s",
    862  1.1.1.5  christos             forwarding_target,
    863  1.1.1.5  christos             qctx.peer,
    864  1.1.1.5  christos             qctx.protocol.name,
    865  1.1.1.5  christos         )
    866  1.1.1.5  christos 
    867  1.1.1.5  christos         try:
    868  1.1.1.5  christos             message = _DnsMessageWithTsigDisabled.from_wire(response.result())
    869  1.1.1.5  christos             yield DnsResponseSend(message, acknowledge_hand_rolled_response=True)
    870  1.1.1.5  christos         except dns.exception.DNSException:
    871  1.1.1.5  christos             logging.warning(
    872  1.1.1.5  christos                 "Failed to parse response from %s as a DNS message, relaying it as raw bytes",
    873  1.1.1.5  christos                 forwarding_target,
    874  1.1.1.5  christos             )
    875  1.1.1.5  christos             yield BytesResponseSend(response.result())
    876  1.1.1.5  christos 
    877  1.1.1.5  christos 
    878      1.1  christos @dataclass
    879      1.1  christos class _ZoneTreeNode:
    880      1.1  christos     """
    881      1.1  christos     A node representing a zone with one origin.
    882      1.1  christos     """
    883      1.1  christos 
    884  1.1.1.5  christos     zone: dns.zone.Zone | None
    885  1.1.1.5  christos     children: list["_ZoneTreeNode"] = field(default_factory=list)
    886      1.1  christos 
    887      1.1  christos 
    888      1.1  christos class _ZoneTree:
    889      1.1  christos     """
    890      1.1  christos     Tree with independent zones.
    891      1.1  christos 
    892      1.1  christos     This zone tree is used as a backing structure for the DNS server. The
    893      1.1  christos     individual zones are independent to allow the (single) server to serve both
    894      1.1  christos     the parent zone and a child zone if needed.
    895      1.1  christos     """
    896      1.1  christos 
    897      1.1  christos     def __init__(self) -> None:
    898      1.1  christos         self._root: _ZoneTreeNode = _ZoneTreeNode(None)
    899      1.1  christos 
    900      1.1  christos     def add(self, zone: dns.zone.Zone) -> None:
    901      1.1  christos         """
    902      1.1  christos         Add a zone to the tree and rearrange sub-zones if necessary.
    903      1.1  christos         """
    904      1.1  christos         assert zone.origin
    905      1.1  christos         best_match = self._find_best_match(zone.origin, self._root)
    906      1.1  christos         added_node = _ZoneTreeNode(zone)
    907      1.1  christos         self._move_children(best_match, added_node)
    908      1.1  christos         best_match.children.append(added_node)
    909      1.1  christos 
    910      1.1  christos     def _find_best_match(
    911      1.1  christos         self, name: dns.name.Name, start_node: _ZoneTreeNode
    912      1.1  christos     ) -> _ZoneTreeNode:
    913      1.1  christos         for child in start_node.children:
    914      1.1  christos             assert child.zone
    915      1.1  christos             assert child.zone.origin
    916      1.1  christos             if name.is_subdomain(child.zone.origin):
    917      1.1  christos                 return self._find_best_match(name, child)
    918      1.1  christos         return start_node
    919      1.1  christos 
    920      1.1  christos     def _move_children(self, node_from: _ZoneTreeNode, node_to: _ZoneTreeNode) -> None:
    921      1.1  christos         assert node_to.zone
    922      1.1  christos         assert node_to.zone.origin
    923      1.1  christos 
    924      1.1  christos         children_to_move = []
    925      1.1  christos         for child in node_from.children:
    926      1.1  christos             assert child.zone
    927      1.1  christos             assert child.zone.origin
    928      1.1  christos             if child.zone.origin.is_subdomain(node_to.zone.origin):
    929      1.1  christos                 children_to_move.append(child)
    930      1.1  christos 
    931      1.1  christos         for child in children_to_move:
    932      1.1  christos             node_from.children.remove(child)
    933      1.1  christos             node_to.children.append(child)
    934      1.1  christos 
    935  1.1.1.5  christos     def find_best_zone(self, name: dns.name.Name) -> dns.zone.Zone | None:
    936      1.1  christos         """
    937      1.1  christos         Return the closest matching zone (if any) for the domain name.
    938      1.1  christos         """
    939      1.1  christos         node = self._find_best_match(name, self._root)
    940      1.1  christos         return node.zone if node != self._root else None
    941      1.1  christos 
    942      1.1  christos 
    943  1.1.1.4  christos class _DnsMessageWithTsigDisabled(dns.message.Message):
    944  1.1.1.4  christos     """
    945  1.1.1.4  christos     A wrapper for `dns.message.Message` that works around a dnspython bug
    946  1.1.1.4  christos     causing exceptions to be raised when `make_response()` or `to_wire()` are
    947  1.1.1.4  christos     called for a message created using `dns.message.from_wire(keyring=False)`.
    948  1.1.1.4  christos 
    949  1.1.1.4  christos     See https://github.com/rthalley/dnspython/issues/1205 for more details.
    950  1.1.1.4  christos     """
    951  1.1.1.4  christos 
    952  1.1.1.4  christos     class _DisableTsigHandling(contextlib.ContextDecorator):
    953  1.1.1.5  christos         def __init__(self, message: dns.message.Message | None = None) -> None:
    954  1.1.1.4  christos             self.original_tsig_sign = dns.tsig.sign
    955  1.1.1.4  christos             self.original_tsig_validate = dns.tsig.validate
    956  1.1.1.4  christos             if message:
    957  1.1.1.4  christos                 self.tsig = message.tsig
    958  1.1.1.4  christos 
    959  1.1.1.4  christos         def __enter__(self) -> None:
    960  1.1.1.4  christos             """
    961  1.1.1.4  christos             Override the `dns.tsig.sign` and `dns.tsig.validate` functions to prevent them
    962  1.1.1.4  christos             from failing on messages initialized with `dns.message.from_wire(keyring=False)`.
    963  1.1.1.4  christos             """
    964  1.1.1.4  christos 
    965  1.1.1.5  christos             def sign(*_: Any, **__: Any) -> tuple[dns.rdata.Rdata, None]:
    966  1.1.1.4  christos                 assert self.tsig
    967  1.1.1.4  christos                 return self.tsig[0], None
    968  1.1.1.4  christos 
    969  1.1.1.4  christos             def validate(*_: Any, **__: Any) -> None:
    970  1.1.1.4  christos                 return None
    971  1.1.1.4  christos 
    972  1.1.1.4  christos             dns.tsig.sign = sign
    973  1.1.1.4  christos             dns.tsig.validate = validate
    974  1.1.1.4  christos 
    975  1.1.1.4  christos         def __exit__(self, *_: Any, **__: Any) -> None:
    976  1.1.1.4  christos             dns.tsig.sign = self.original_tsig_sign
    977  1.1.1.4  christos             dns.tsig.validate = self.original_tsig_validate
    978  1.1.1.4  christos 
    979  1.1.1.4  christos     @classmethod
    980  1.1.1.4  christos     def from_wire(cls, wire: bytes) -> "_DnsMessageWithTsigDisabled":
    981  1.1.1.4  christos         with cls._DisableTsigHandling():
    982  1.1.1.4  christos             message = dns.message.from_wire(wire, keyring=False)
    983  1.1.1.4  christos             message.__class__ = _DnsMessageWithTsigDisabled
    984  1.1.1.4  christos 
    985  1.1.1.4  christos         return cast(_DnsMessageWithTsigDisabled, message)
    986  1.1.1.4  christos 
    987  1.1.1.4  christos     @property
    988  1.1.1.4  christos     def had_tsig(self) -> bool:
    989  1.1.1.4  christos         """
    990  1.1.1.4  christos         Override the `had_tsig()` method to always return False, to prevent
    991  1.1.1.4  christos         `make_response()` from crashing.
    992  1.1.1.4  christos         """
    993  1.1.1.4  christos         return False
    994  1.1.1.4  christos 
    995  1.1.1.4  christos     def to_wire(self, *args: Any, **kwargs: Any) -> bytes:
    996  1.1.1.4  christos         """
    997  1.1.1.4  christos         Override the `to_wire()` method to prevent it from trying to sign
    998  1.1.1.4  christos         the message with TSIG.
    999  1.1.1.4  christos         """
   1000  1.1.1.4  christos         with self._DisableTsigHandling(self):
   1001  1.1.1.4  christos             return super().to_wire(*args, **kwargs)
   1002  1.1.1.4  christos 
   1003  1.1.1.4  christos 
   1004  1.1.1.4  christos class _NoKeyringType:
   1005  1.1.1.4  christos     pass
   1006  1.1.1.4  christos 
   1007  1.1.1.4  christos 
   1008  1.1.1.5  christos _ASYNCSERVER_RESPONSE_MARKER = "__is_asyncserver_response__"
   1009  1.1.1.5  christos 
   1010  1.1.1.5  christos 
   1011  1.1.1.5  christos def _make_asyncserver_response(query: dns.message.Message) -> dns.message.Message:
   1012  1.1.1.5  christos     response = dns.message.make_response(query)
   1013  1.1.1.5  christos     setattr(response, _ASYNCSERVER_RESPONSE_MARKER, True)
   1014  1.1.1.5  christos     return response
   1015  1.1.1.5  christos 
   1016  1.1.1.5  christos 
   1017  1.1.1.5  christos def _is_asyncserver_response(message: dns.message.Message) -> bool:
   1018  1.1.1.5  christos     return getattr(message, _ASYNCSERVER_RESPONSE_MARKER, False)
   1019  1.1.1.5  christos 
   1020  1.1.1.5  christos 
   1021      1.1  christos class AsyncDnsServer(AsyncServer):
   1022      1.1  christos     """
   1023      1.1  christos     DNS server which responds to queries based on zone data and/or custom
   1024      1.1  christos     handlers.
   1025      1.1  christos 
   1026      1.1  christos     The server may use custom handlers which allow arbitrary query processing.
   1027      1.1  christos     These don't need to be standards-compliant and can be used for testing all
   1028      1.1  christos     sorts of scenarios, including delaying responses, synthesizing them based
   1029      1.1  christos     on query contents etc.
   1030      1.1  christos 
   1031      1.1  christos     The server also loads any zone files (*.db) found in its directory and
   1032      1.1  christos     serves them. Responses prepared using zone data can then be modified,
   1033      1.1  christos     replaced, or suppressed by query handlers. Query handlers can also generate
   1034      1.1  christos     response from scratch, without using zone data at all.
   1035      1.1  christos     """
   1036      1.1  christos 
   1037  1.1.1.4  christos     def __init__(
   1038  1.1.1.4  christos         self,
   1039  1.1.1.4  christos         /,
   1040  1.1.1.4  christos         default_rcode: dns.rcode.Rcode = dns.rcode.REFUSED,
   1041  1.1.1.4  christos         default_aa: bool = False,
   1042  1.1.1.5  christos         keyring: (
   1043  1.1.1.5  christos             dict[dns.name.Name, dns.tsig.Key] | None | _NoKeyringType
   1044  1.1.1.5  christos         ) = _NoKeyringType(),
   1045  1.1.1.4  christos         acknowledge_manual_dname_handling: bool = False,
   1046  1.1.1.4  christos     ) -> None:
   1047      1.1  christos         super().__init__(self._handle_udp, self._handle_tcp, "ans.pid")
   1048      1.1  christos 
   1049      1.1  christos         self._zone_tree: _ZoneTree = _ZoneTree()
   1050  1.1.1.5  christos         self._connection_handler: ConnectionHandler | None = None
   1051  1.1.1.5  christos         self._response_handlers: list[ResponseHandler] = []
   1052  1.1.1.4  christos         self._default_rcode = default_rcode
   1053  1.1.1.4  christos         self._default_aa = default_aa
   1054  1.1.1.4  christos         self._keyring = keyring
   1055  1.1.1.3  christos         self._acknowledge_manual_dname_handling = acknowledge_manual_dname_handling
   1056      1.1  christos 
   1057  1.1.1.3  christos         self._load_zones()
   1058      1.1  christos 
   1059  1.1.1.2  christos     def install_response_handler(
   1060  1.1.1.2  christos         self, handler: ResponseHandler, prepend: bool = False
   1061  1.1.1.2  christos     ) -> None:
   1062      1.1  christos         """
   1063  1.1.1.2  christos         Add a response handler that will be used to handle matching queries.
   1064      1.1  christos 
   1065      1.1  christos         Response handlers can modify, replace, or suppress the answers prepared
   1066      1.1  christos         from zone file contents.
   1067  1.1.1.2  christos 
   1068  1.1.1.2  christos         The provided handler is installed at the end of the response handler
   1069  1.1.1.2  christos         list unless `prepend` is set to True, in which case it is installed at
   1070  1.1.1.2  christos         the beginning of the response handler list.
   1071      1.1  christos         """
   1072      1.1  christos         logging.info("Installing response handler: %s", handler)
   1073  1.1.1.2  christos         if prepend:
   1074  1.1.1.2  christos             self._response_handlers.insert(0, handler)
   1075  1.1.1.2  christos         else:
   1076  1.1.1.2  christos             self._response_handlers.append(handler)
   1077  1.1.1.2  christos 
   1078  1.1.1.5  christos     def install_response_handlers(self, *handlers: ResponseHandler) -> None:
   1079  1.1.1.4  christos         for handler in handlers:
   1080  1.1.1.4  christos             self.install_response_handler(handler)
   1081  1.1.1.4  christos 
   1082  1.1.1.5  christos     def replace_response_handlers(self, *new_handlers: ResponseHandler) -> None:
   1083  1.1.1.5  christos         """
   1084  1.1.1.5  christos         Uninstall all currently installed handlers and install the provided ones.
   1085  1.1.1.5  christos         """
   1086  1.1.1.5  christos         logging.info("Uninstalling response handlers: %s", str(self._response_handlers))
   1087  1.1.1.5  christos         self._response_handlers.clear()
   1088  1.1.1.5  christos         self.install_response_handlers(*new_handlers)
   1089  1.1.1.5  christos 
   1090  1.1.1.2  christos     def uninstall_response_handler(self, handler: ResponseHandler) -> None:
   1091  1.1.1.2  christos         """
   1092  1.1.1.2  christos         Remove the specified handler from the list of response handlers.
   1093  1.1.1.2  christos         """
   1094  1.1.1.2  christos         logging.info("Uninstalling response handler: %s", handler)
   1095  1.1.1.2  christos         self._response_handlers.remove(handler)
   1096      1.1  christos 
   1097  1.1.1.4  christos     def install_connection_handler(self, handler: ConnectionHandler) -> None:
   1098  1.1.1.4  christos         """
   1099  1.1.1.4  christos         Install a connection handler that will be called when a new TCP
   1100  1.1.1.4  christos         connection is established.
   1101  1.1.1.4  christos         """
   1102  1.1.1.4  christos         if self._connection_handler:
   1103  1.1.1.4  christos             raise RuntimeError("Only one connection handler can be installed")
   1104  1.1.1.4  christos         self._connection_handler = handler
   1105  1.1.1.4  christos 
   1106      1.1  christos     def _load_zones(self) -> None:
   1107      1.1  christos         for entry in os.scandir():
   1108      1.1  christos             entry_path = pathlib.Path(entry.path)
   1109      1.1  christos             if entry_path.suffix != ".db":
   1110      1.1  christos                 continue
   1111  1.1.1.3  christos             zone = self._load_zone(entry_path)
   1112      1.1  christos             self._zone_tree.add(zone)
   1113      1.1  christos 
   1114  1.1.1.3  christos     def _load_zone(self, zone_file_path: pathlib.Path) -> dns.zone.Zone:
   1115  1.1.1.3  christos         logging.info("Loading zone file %s", zone_file_path)
   1116  1.1.1.4  christos         zone = self._load_zone_file(zone_file_path)
   1117  1.1.1.3  christos         self._abort_if_dname_found_unless_acknowledged(zone)
   1118  1.1.1.3  christos         return zone
   1119  1.1.1.3  christos 
   1120  1.1.1.4  christos     def _load_zone_file(self, zone_file_path: pathlib.Path) -> dns.zone.Zone:
   1121  1.1.1.4  christos         try:
   1122  1.1.1.4  christos             zone = self._load_zone_file_with_origin(zone_file_path)
   1123  1.1.1.4  christos         except dns.zone.UnknownOrigin:
   1124  1.1.1.4  christos             zone = self._load_zone_file_without_origin(zone_file_path)
   1125  1.1.1.4  christos 
   1126  1.1.1.4  christos         return zone
   1127  1.1.1.4  christos 
   1128  1.1.1.4  christos     def _load_zone_file_with_origin(
   1129  1.1.1.4  christos         self, zone_file_path: pathlib.Path
   1130  1.1.1.4  christos     ) -> dns.zone.Zone:
   1131  1.1.1.4  christos         zone = dns.zone.from_file(str(zone_file_path), origin=None, relativize=False)
   1132  1.1.1.4  christos         if zone.origin != dns.name.root:
   1133  1.1.1.4  christos             error = "only the root zone may use $ORIGIN in the zone file; "
   1134  1.1.1.4  christos             error += "for every other zone, its origin is determined by "
   1135  1.1.1.4  christos             error += "the name of the file it is loaded from"
   1136  1.1.1.4  christos             raise ValueError(error)
   1137  1.1.1.4  christos         return zone
   1138  1.1.1.4  christos 
   1139  1.1.1.4  christos     def _load_zone_file_without_origin(
   1140  1.1.1.4  christos         self, zone_file_path: pathlib.Path
   1141  1.1.1.4  christos     ) -> dns.zone.Zone:
   1142  1.1.1.4  christos         origin = dns.name.from_text(zone_file_path.stem)
   1143  1.1.1.4  christos         return dns.zone.from_file(str(zone_file_path), origin=origin, relativize=False)
   1144  1.1.1.4  christos 
   1145  1.1.1.3  christos     def _abort_if_dname_found_unless_acknowledged(self, zone: dns.zone.Zone) -> None:
   1146  1.1.1.3  christos         if self._acknowledge_manual_dname_handling:
   1147  1.1.1.3  christos             return
   1148  1.1.1.3  christos 
   1149  1.1.1.3  christos         error = f'DNAME records found in zone "{zone.origin}"; '
   1150  1.1.1.3  christos         error += "this server does not handle DNAME in a standards-compliant way; "
   1151  1.1.1.3  christos         error += "pass `acknowledge_manual_dname_handling=True` to the "
   1152  1.1.1.3  christos         error += "AsyncDnsServer constructor to acknowledge this and load zone anyway"
   1153  1.1.1.3  christos 
   1154  1.1.1.3  christos         for node in zone.nodes.values():
   1155  1.1.1.3  christos             for rdataset in node:
   1156  1.1.1.3  christos                 if rdataset.rdtype == dns.rdatatype.DNAME:
   1157  1.1.1.3  christos                     raise ValueError(error)
   1158  1.1.1.3  christos 
   1159      1.1  christos     async def _handle_udp(
   1160  1.1.1.5  christos         self, wire: bytes, addr: tuple[str, int], transport: asyncio.DatagramTransport
   1161      1.1  christos     ) -> None:
   1162      1.1  christos         logging.debug("Received UDP message: %s", wire.hex())
   1163  1.1.1.5  christos         socket_info = transport.get_extra_info("sockname")
   1164  1.1.1.5  christos         socket = Peer(socket_info[0], socket_info[1])
   1165  1.1.1.2  christos         peer = Peer(addr[0], addr[1])
   1166  1.1.1.5  christos         responses = self._handle_query(wire, socket, peer, DnsProtocol.UDP)
   1167      1.1  christos         async for response in responses:
   1168  1.1.1.3  christos             logging.debug("Sending UDP message: %s", response.hex())
   1169  1.1.1.2  christos             transport.sendto(response, addr)
   1170      1.1  christos 
   1171      1.1  christos     async def _handle_tcp(
   1172      1.1  christos         self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
   1173      1.1  christos     ) -> None:
   1174  1.1.1.2  christos         peer_info = writer.get_extra_info("peername")
   1175  1.1.1.2  christos         peer = Peer(peer_info[0], peer_info[1])
   1176  1.1.1.2  christos         logging.debug("Accepted TCP connection from %s", peer)
   1177      1.1  christos 
   1178  1.1.1.4  christos         try:
   1179  1.1.1.4  christos             if self._connection_handler:
   1180  1.1.1.4  christos                 await self._connection_handler.handle(reader, writer, peer)
   1181  1.1.1.4  christos             while True:
   1182  1.1.1.2  christos                 wire = await self._read_tcp_query(reader, peer)
   1183  1.1.1.2  christos                 if not wire:
   1184  1.1.1.2  christos                     break
   1185  1.1.1.2  christos                 await self._send_tcp_response(writer, peer, wire)
   1186  1.1.1.4  christos         except _ConnectionTeardownRequested:
   1187  1.1.1.4  christos             pass
   1188  1.1.1.4  christos         except ConnectionResetError:
   1189  1.1.1.4  christos             logging.error("TCP connection from %s reset by peer", peer)
   1190  1.1.1.4  christos             return
   1191      1.1  christos 
   1192  1.1.1.2  christos         logging.debug("Closing TCP connection from %s", peer)
   1193      1.1  christos         writer.close()
   1194  1.1.1.2  christos         try:
   1195  1.1.1.2  christos             # Python >= 3.7
   1196  1.1.1.2  christos             await writer.wait_closed()
   1197  1.1.1.2  christos         except AttributeError:
   1198  1.1.1.2  christos             # Python < 3.7
   1199  1.1.1.2  christos             pass
   1200  1.1.1.2  christos 
   1201  1.1.1.2  christos     async def _read_tcp_query(
   1202  1.1.1.2  christos         self, reader: asyncio.StreamReader, peer: Peer
   1203  1.1.1.5  christos     ) -> bytes | None:
   1204  1.1.1.2  christos         wire_length = await self._read_tcp_query_wire_length(reader, peer)
   1205  1.1.1.2  christos         if not wire_length:
   1206  1.1.1.2  christos             return None
   1207  1.1.1.2  christos 
   1208  1.1.1.2  christos         return await self._read_tcp_query_wire(reader, peer, wire_length)
   1209  1.1.1.2  christos 
   1210  1.1.1.2  christos     async def _read_tcp_query_wire_length(
   1211  1.1.1.2  christos         self, reader: asyncio.StreamReader, peer: Peer
   1212  1.1.1.5  christos     ) -> int | None:
   1213  1.1.1.2  christos         logging.debug("Receiving TCP message length from %s...", peer)
   1214  1.1.1.2  christos 
   1215  1.1.1.2  christos         wire_length_bytes = await self._read_tcp_octets(reader, peer, 2)
   1216  1.1.1.2  christos         if not wire_length_bytes:
   1217  1.1.1.2  christos             return None
   1218      1.1  christos 
   1219  1.1.1.2  christos         (wire_length,) = struct.unpack("!H", wire_length_bytes)
   1220  1.1.1.2  christos 
   1221  1.1.1.2  christos         return wire_length
   1222  1.1.1.2  christos 
   1223  1.1.1.2  christos     async def _read_tcp_query_wire(
   1224  1.1.1.2  christos         self, reader: asyncio.StreamReader, peer: Peer, wire_length: int
   1225  1.1.1.5  christos     ) -> bytes | None:
   1226  1.1.1.2  christos         logging.debug("Receiving TCP message (%d octets) from %s...", wire_length, peer)
   1227  1.1.1.2  christos 
   1228  1.1.1.2  christos         wire = await self._read_tcp_octets(reader, peer, wire_length)
   1229  1.1.1.2  christos         if not wire:
   1230  1.1.1.2  christos             return None
   1231  1.1.1.2  christos 
   1232  1.1.1.2  christos         logging.debug("Received complete TCP message from %s: %s", peer, wire.hex())
   1233  1.1.1.2  christos 
   1234  1.1.1.2  christos         return wire
   1235  1.1.1.2  christos 
   1236  1.1.1.2  christos     async def _read_tcp_octets(
   1237  1.1.1.2  christos         self, reader: asyncio.StreamReader, peer: Peer, expected: int
   1238  1.1.1.5  christos     ) -> bytes | None:
   1239  1.1.1.2  christos         buffer = b""
   1240  1.1.1.2  christos 
   1241  1.1.1.2  christos         while len(buffer) < expected:
   1242  1.1.1.2  christos             chunk = await reader.read(expected - len(buffer))
   1243  1.1.1.2  christos             if not chunk:
   1244  1.1.1.2  christos                 if buffer:
   1245  1.1.1.2  christos                     logging.debug(
   1246  1.1.1.2  christos                         "Received short TCP message (%d octets) from %s: %s",
   1247  1.1.1.2  christos                         len(buffer),
   1248  1.1.1.2  christos                         peer,
   1249  1.1.1.2  christos                         buffer.hex(),
   1250  1.1.1.2  christos                     )
   1251  1.1.1.2  christos                 else:
   1252  1.1.1.2  christos                     logging.debug("Received disconnect from %s", peer)
   1253  1.1.1.2  christos                 return None
   1254      1.1  christos 
   1255  1.1.1.2  christos             logging.debug("Received %d TCP octets from %s", len(chunk), peer)
   1256  1.1.1.2  christos             buffer += chunk
   1257  1.1.1.2  christos 
   1258  1.1.1.2  christos         return buffer
   1259  1.1.1.2  christos 
   1260  1.1.1.2  christos     async def _send_tcp_response(
   1261  1.1.1.2  christos         self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
   1262      1.1  christos     ) -> None:
   1263  1.1.1.5  christos         socket_info = writer.get_extra_info("sockname")
   1264  1.1.1.5  christos         socket = Peer(socket_info[0], socket_info[1])
   1265  1.1.1.5  christos         responses = self._handle_query(wire, socket, peer, DnsProtocol.TCP)
   1266  1.1.1.2  christos         async for response in responses:
   1267  1.1.1.3  christos             logging.debug("Sending TCP response: %s", response.hex())
   1268  1.1.1.2  christos             writer.write(response)
   1269  1.1.1.2  christos             await writer.drain()
   1270  1.1.1.2  christos 
   1271  1.1.1.5  christos     def _log_query(self, qctx: QueryContext) -> None:
   1272      1.1  christos         logging.info(
   1273  1.1.1.5  christos             "Received %s/%s/%s (ID=%d) query from %s on %s (%s)",
   1274      1.1  christos             qctx.qname.to_text(omit_final_dot=True),
   1275      1.1  christos             dns.rdataclass.to_text(qctx.qclass),
   1276      1.1  christos             dns.rdatatype.to_text(qctx.qtype),
   1277      1.1  christos             qctx.query.id,
   1278  1.1.1.5  christos             qctx.peer,
   1279  1.1.1.5  christos             qctx.socket,
   1280  1.1.1.5  christos             qctx.protocol.name,
   1281      1.1  christos         )
   1282      1.1  christos         logging.debug(
   1283      1.1  christos             "\n".join([f"[IN] {l}" for l in [""] + str(qctx.query).splitlines()])
   1284      1.1  christos         )
   1285      1.1  christos 
   1286      1.1  christos     def _log_response(
   1287  1.1.1.5  christos         self, qctx: QueryContext, response: dns.message.Message | bytes | None
   1288      1.1  christos     ) -> None:
   1289      1.1  christos         if not response:
   1290      1.1  christos             logging.info(
   1291  1.1.1.5  christos                 "Not sending a response to query (ID=%d) from %s on %s (%s)",
   1292      1.1  christos                 qctx.query.id,
   1293  1.1.1.5  christos                 qctx.peer,
   1294  1.1.1.5  christos                 qctx.socket,
   1295  1.1.1.5  christos                 qctx.protocol.name,
   1296      1.1  christos             )
   1297      1.1  christos             return
   1298      1.1  christos 
   1299      1.1  christos         if isinstance(response, dns.message.Message):
   1300      1.1  christos             try:
   1301      1.1  christos                 qname = response.question[0].name.to_text(omit_final_dot=True)
   1302      1.1  christos                 qclass = dns.rdataclass.to_text(response.question[0].rdclass)
   1303      1.1  christos                 qtype = dns.rdatatype.to_text(response.question[0].rdtype)
   1304      1.1  christos             except IndexError:
   1305      1.1  christos                 qname = "<empty>"
   1306      1.1  christos                 qclass = "-"
   1307      1.1  christos                 qtype = "-"
   1308      1.1  christos 
   1309      1.1  christos             logging.info(
   1310  1.1.1.5  christos                 "Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s on %s (%s)",
   1311      1.1  christos                 qname,
   1312      1.1  christos                 qclass,
   1313      1.1  christos                 qtype,
   1314      1.1  christos                 response.id,
   1315      1.1  christos                 len(response.question),
   1316      1.1  christos                 len(response.answer),
   1317      1.1  christos                 len(response.authority),
   1318      1.1  christos                 len(response.additional),
   1319      1.1  christos                 qctx.query.id,
   1320  1.1.1.5  christos                 qctx.peer,
   1321  1.1.1.5  christos                 qctx.socket,
   1322  1.1.1.5  christos                 qctx.protocol.name,
   1323      1.1  christos             )
   1324      1.1  christos             logging.debug(
   1325      1.1  christos                 "\n".join([f"[OUT] {l}" for l in [""] + str(response).splitlines()])
   1326      1.1  christos             )
   1327      1.1  christos             return
   1328      1.1  christos 
   1329      1.1  christos         logging.info(
   1330  1.1.1.5  christos             "Sending response (%d bytes) to a query (ID=%d) from %s on %s (%s)",
   1331      1.1  christos             len(response),
   1332      1.1  christos             qctx.query.id,
   1333  1.1.1.5  christos             qctx.peer,
   1334  1.1.1.5  christos             qctx.socket,
   1335  1.1.1.5  christos             qctx.protocol.name,
   1336      1.1  christos         )
   1337      1.1  christos         logging.debug("[OUT] %s", response.hex())
   1338      1.1  christos 
   1339      1.1  christos     async def _handle_query(
   1340  1.1.1.5  christos         self, wire: bytes, socket: Peer, peer: Peer, protocol: DnsProtocol
   1341      1.1  christos     ) -> AsyncGenerator[bytes, None]:
   1342      1.1  christos         """
   1343      1.1  christos         Yield wire data to send as a response over the established transport.
   1344      1.1  christos         """
   1345  1.1.1.2  christos         try:
   1346  1.1.1.4  christos             query = self._parse_message(wire)
   1347  1.1.1.2  christos         except dns.exception.DNSException as exc:
   1348  1.1.1.2  christos             logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc)
   1349  1.1.1.2  christos             return
   1350  1.1.1.5  christos         response_stub = _make_asyncserver_response(query)
   1351  1.1.1.5  christos         qctx = QueryContext(query, response_stub, socket, peer, protocol)
   1352  1.1.1.5  christos         self._log_query(qctx)
   1353      1.1  christos         responses = self._prepare_responses(qctx)
   1354      1.1  christos         async for response in responses:
   1355  1.1.1.5  christos             self._log_response(qctx, response)
   1356      1.1  christos             if response:
   1357      1.1  christos                 if isinstance(response, dns.message.Message):
   1358      1.1  christos                     response = response.to_wire(max_size=65535)
   1359      1.1  christos                 if protocol == DnsProtocol.UDP:
   1360      1.1  christos                     yield response
   1361      1.1  christos                 else:
   1362      1.1  christos                     response_length = struct.pack("!H", len(response))
   1363      1.1  christos                     yield response_length + response
   1364      1.1  christos 
   1365  1.1.1.4  christos     def _parse_message(self, wire: bytes) -> dns.message.Message:
   1366  1.1.1.4  christos         try:
   1367  1.1.1.4  christos             if isinstance(self._keyring, _NoKeyringType):
   1368  1.1.1.4  christos                 keyring = None
   1369  1.1.1.4  christos             else:
   1370  1.1.1.4  christos                 keyring = self._keyring
   1371  1.1.1.4  christos             return dns.message.from_wire(wire, keyring=keyring)
   1372  1.1.1.4  christos         except dns.message.UnknownTSIGKey as exc:
   1373  1.1.1.4  christos             if isinstance(self._keyring, _NoKeyringType):
   1374  1.1.1.4  christos                 error = "TSIG-signed query received but no `keyring` was provided; "
   1375  1.1.1.4  christos                 error += "either provide a keyring (in which case the server will "
   1376  1.1.1.4  christos                 error += "ignore any TSIG-invalid queries), or set `keyring=None` "
   1377  1.1.1.4  christos                 error += "explicitly to disable TSIG validation altogether. "
   1378  1.1.1.4  christos                 error += "This requires some hacking around a dnspython bug, "
   1379  1.1.1.4  christos                 error += "so there may be unexpected side effects."
   1380  1.1.1.4  christos                 raise ValueError(error) from exc
   1381  1.1.1.4  christos             if self._keyring is None:
   1382  1.1.1.4  christos                 return _DnsMessageWithTsigDisabled.from_wire(wire)
   1383  1.1.1.4  christos             raise
   1384  1.1.1.4  christos 
   1385      1.1  christos     async def _prepare_responses(
   1386      1.1  christos         self, qctx: QueryContext
   1387  1.1.1.5  christos     ) -> AsyncGenerator[dns.message.Message | bytes | None, None]:
   1388      1.1  christos         """
   1389      1.1  christos         Yield response(s) either from response handlers or zone data.
   1390      1.1  christos         """
   1391  1.1.1.4  christos         qctx.response.set_rcode(self._default_rcode)
   1392  1.1.1.4  christos         if self._default_aa:
   1393  1.1.1.4  christos             qctx.response.flags |= dns.flags.AA
   1394  1.1.1.4  christos         qctx.save_initialized_response(with_zone_data=False)
   1395  1.1.1.4  christos 
   1396      1.1  christos         self._prepare_response_from_zone_data(qctx)
   1397  1.1.1.4  christos         qctx.save_initialized_response(with_zone_data=True)
   1398      1.1  christos 
   1399      1.1  christos         response_handled = False
   1400      1.1  christos         async for action in self._run_response_handlers(qctx):
   1401      1.1  christos             yield await action.perform()
   1402      1.1  christos             response_handled = True
   1403      1.1  christos 
   1404      1.1  christos         if not response_handled:
   1405  1.1.1.2  christos             logging.debug("Responding based on zone data")
   1406      1.1  christos             yield qctx.response
   1407      1.1  christos 
   1408      1.1  christos     def _prepare_response_from_zone_data(self, qctx: QueryContext) -> None:
   1409      1.1  christos         """
   1410      1.1  christos         Prepare a response to the query based on the available zone data.
   1411      1.1  christos 
   1412      1.1  christos         The functionality is split across smaller functions that modify the
   1413      1.1  christos         query context until a proper response is formed.
   1414      1.1  christos         """
   1415      1.1  christos         if self._refused_response(qctx):
   1416      1.1  christos             return
   1417      1.1  christos 
   1418      1.1  christos         if self._delegation_response(qctx):
   1419      1.1  christos             return
   1420      1.1  christos 
   1421      1.1  christos         qctx.response.flags |= dns.flags.AA
   1422      1.1  christos 
   1423      1.1  christos         if self._ent_response(qctx):
   1424      1.1  christos             return
   1425      1.1  christos 
   1426      1.1  christos         if self._nxdomain_response(qctx):
   1427      1.1  christos             return
   1428      1.1  christos 
   1429  1.1.1.3  christos         if self._cname_response(qctx):
   1430  1.1.1.3  christos             return
   1431  1.1.1.3  christos 
   1432      1.1  christos         if self._nodata_response(qctx):
   1433      1.1  christos             return
   1434      1.1  christos 
   1435      1.1  christos         self._noerror_response(qctx)
   1436      1.1  christos 
   1437      1.1  christos     def _refused_response(self, qctx: QueryContext) -> bool:
   1438  1.1.1.3  christos         zone = self._zone_tree.find_best_zone(qctx.current_qname)
   1439  1.1.1.3  christos         if zone:
   1440  1.1.1.3  christos             qctx.zone = zone
   1441      1.1  christos             return False
   1442      1.1  christos 
   1443  1.1.1.4  christos         # RCODE is already set to self._default_rcode, i.e. REFUSED by default;
   1444  1.1.1.4  christos         # it should also not be changed when following a CNAME chain
   1445      1.1  christos         return True
   1446      1.1  christos 
   1447      1.1  christos     def _delegation_response(self, qctx: QueryContext) -> bool:
   1448      1.1  christos         assert qctx.zone
   1449      1.1  christos 
   1450  1.1.1.3  christos         name = qctx.current_qname
   1451      1.1  christos         delegation = None
   1452      1.1  christos 
   1453      1.1  christos         while name != qctx.zone.origin:
   1454      1.1  christos             node = qctx.zone.get_node(name)
   1455      1.1  christos             if node:
   1456      1.1  christos                 delegation = node.get_rdataset(qctx.qclass, dns.rdatatype.NS)
   1457      1.1  christos                 if delegation:
   1458      1.1  christos                     break
   1459      1.1  christos             name = name.parent()
   1460      1.1  christos 
   1461      1.1  christos         if not delegation:
   1462      1.1  christos             return False
   1463      1.1  christos 
   1464      1.1  christos         delegation_rrset = dns.rrset.RRset(name, qctx.qclass, dns.rdatatype.NS)
   1465      1.1  christos         delegation_rrset.update(delegation)
   1466      1.1  christos 
   1467      1.1  christos         qctx.response.set_rcode(dns.rcode.NOERROR)
   1468      1.1  christos         qctx.response.authority.append(delegation_rrset)
   1469      1.1  christos 
   1470      1.1  christos         self._delegation_response_additional(qctx)
   1471      1.1  christos 
   1472      1.1  christos         return True
   1473      1.1  christos 
   1474      1.1  christos     def _delegation_response_additional(self, qctx: QueryContext) -> None:
   1475      1.1  christos         assert qctx.zone
   1476      1.1  christos         assert qctx.response.authority[0]
   1477      1.1  christos 
   1478      1.1  christos         for nameserver in qctx.response.authority[0]:
   1479      1.1  christos             if not nameserver.target.is_subdomain(qctx.response.authority[0].name):
   1480      1.1  christos                 continue
   1481      1.1  christos             glue_a = qctx.zone.get_rrset(nameserver.target, dns.rdatatype.A)
   1482      1.1  christos             if glue_a:
   1483      1.1  christos                 qctx.response.additional.append(glue_a)
   1484      1.1  christos             glue_aaaa = qctx.zone.get_rrset(nameserver.target, dns.rdatatype.AAAA)
   1485      1.1  christos             if glue_aaaa:
   1486      1.1  christos                 qctx.response.additional.append(glue_aaaa)
   1487      1.1  christos 
   1488      1.1  christos     def _ent_response(self, qctx: QueryContext) -> bool:
   1489      1.1  christos         assert qctx.zone
   1490      1.1  christos         assert qctx.zone.origin
   1491      1.1  christos 
   1492      1.1  christos         qctx.soa = qctx.zone.find_rrset(qctx.zone.origin, dns.rdatatype.SOA)
   1493      1.1  christos         assert qctx.soa
   1494      1.1  christos 
   1495  1.1.1.3  christos         qctx.node = qctx.zone.get_node(qctx.current_qname)
   1496      1.1  christos         if qctx.node or not any(
   1497  1.1.1.3  christos             n for n in qctx.zone.nodes if n.is_subdomain(qctx.current_qname)
   1498      1.1  christos         ):
   1499      1.1  christos             return False
   1500      1.1  christos 
   1501      1.1  christos         qctx.response.set_rcode(dns.rcode.NOERROR)
   1502      1.1  christos         qctx.response.authority.append(qctx.soa)
   1503      1.1  christos         return True
   1504      1.1  christos 
   1505      1.1  christos     def _nxdomain_response(self, qctx: QueryContext) -> bool:
   1506      1.1  christos         assert qctx.soa
   1507      1.1  christos 
   1508      1.1  christos         if qctx.node:
   1509      1.1  christos             return False
   1510      1.1  christos 
   1511      1.1  christos         qctx.response.set_rcode(dns.rcode.NXDOMAIN)
   1512      1.1  christos         qctx.response.authority.append(qctx.soa)
   1513      1.1  christos         return True
   1514      1.1  christos 
   1515  1.1.1.3  christos     def _cname_response(self, qctx: QueryContext) -> bool:
   1516  1.1.1.3  christos         assert qctx.node
   1517  1.1.1.3  christos 
   1518  1.1.1.3  christos         cname = qctx.node.get_rdataset(qctx.qclass, dns.rdatatype.CNAME)
   1519  1.1.1.3  christos         if not cname:
   1520  1.1.1.3  christos             return False
   1521  1.1.1.3  christos 
   1522  1.1.1.4  christos         qctx.response.set_rcode(dns.rcode.NOERROR)
   1523  1.1.1.3  christos         cname_rrset = dns.rrset.RRset(qctx.current_qname, qctx.qclass, cname.rdtype)
   1524  1.1.1.3  christos         cname_rrset.update(cname)
   1525  1.1.1.3  christos         qctx.response.answer.append(cname_rrset)
   1526  1.1.1.3  christos 
   1527  1.1.1.3  christos         qctx.alias = cname[0].target
   1528  1.1.1.3  christos         self._prepare_response_from_zone_data(qctx)
   1529  1.1.1.3  christos         return True
   1530  1.1.1.3  christos 
   1531      1.1  christos     def _nodata_response(self, qctx: QueryContext) -> bool:
   1532      1.1  christos         assert qctx.node
   1533      1.1  christos         assert qctx.soa
   1534      1.1  christos 
   1535      1.1  christos         qctx.answer = qctx.node.get_rdataset(qctx.qclass, qctx.qtype)
   1536      1.1  christos         if qctx.answer:
   1537      1.1  christos             return False
   1538      1.1  christos 
   1539      1.1  christos         qctx.response.set_rcode(dns.rcode.NOERROR)
   1540  1.1.1.3  christos         if not qctx.response.answer:
   1541  1.1.1.3  christos             qctx.response.authority.append(qctx.soa)
   1542      1.1  christos         return True
   1543      1.1  christos 
   1544      1.1  christos     def _noerror_response(self, qctx: QueryContext) -> None:
   1545      1.1  christos         assert qctx.answer
   1546      1.1  christos 
   1547  1.1.1.3  christos         answer_rrset = dns.rrset.RRset(qctx.current_qname, qctx.qclass, qctx.qtype)
   1548      1.1  christos         answer_rrset.update(qctx.answer)
   1549      1.1  christos 
   1550      1.1  christos         qctx.response.set_rcode(dns.rcode.NOERROR)
   1551      1.1  christos         qctx.response.answer.append(answer_rrset)
   1552      1.1  christos 
   1553      1.1  christos     async def _run_response_handlers(
   1554      1.1  christos         self, qctx: QueryContext
   1555      1.1  christos     ) -> AsyncGenerator[ResponseAction, None]:
   1556      1.1  christos         """
   1557      1.1  christos         Yield response(s) to the query from a matching query handler.
   1558      1.1  christos         """
   1559      1.1  christos         for handler in self._response_handlers:
   1560      1.1  christos             if handler.match(qctx):
   1561  1.1.1.2  christos                 logging.debug("Matched response handler: %s", handler)
   1562      1.1  christos                 async for response in handler.get_responses(qctx):
   1563      1.1  christos                     yield response
   1564      1.1  christos                 return
   1565  1.1.1.2  christos 
   1566  1.1.1.2  christos 
class ControllableAsyncDnsServer(AsyncDnsServer):
    """
    An AsyncDnsServer whose behavior can be dynamically changed by sending TXT
    queries to a "magic" domain.

    Commands (ControlCommand instances) are registered with
    install_control_command() / install_control_commands().  A command handles
    queries for `[<arg>.]*<command subdomain>.<control domain>`.
    """

    _CONTROL_DOMAIN = "_control."

    # Maximum length of a single TXT character-string (one length octet).
    _TXT_STRING_MAX_OCTETS = 255

    @functools.cached_property
    def _control_domain(self) -> dns.name.Name:
        return dns.name.from_text(self._CONTROL_DOMAIN)

    @functools.cached_property
    def _commands(self) -> dict[dns.name.Name, "ControlCommand"]:
        # Maps "<command subdomain>.<control domain>" to the installed command.
        return {}

    def install_control_commands(self, *commands: "ControlCommand") -> None:
        """
        Install every command in `commands`; see install_control_command().
        """
        for command in commands:
            self.install_control_command(command)

    def install_control_command(self, command: "ControlCommand") -> None:
        """
        Make `command` handle control queries at and below
        `<command.control_subdomain>.<control domain>`.

        Raises RuntimeError if another command already handles that name.
        """
        command_subdomain = dns.name.Name([command.control_subdomain])
        control_subdomain = command_subdomain.concatenate(self._control_domain)
        try:
            existing_command = self._commands[control_subdomain]
        except KeyError:
            self._commands[control_subdomain] = command
        else:
            raise RuntimeError(
                f"{control_subdomain} already handled by {existing_command}"
            )

    async def _prepare_responses(
        self, qctx: QueryContext
    ) -> AsyncGenerator[dns.message.Message | bytes | None, None]:
        """
        Detect and handle control queries, falling back to normal processing
        for non-control queries.
        """
        control_response = self._handle_control_command(qctx)
        # _handle_control_command() returns None only for non-control queries.
        if control_response is not None:
            yield await DnsResponseSend(response=control_response).perform()
            return

        async for response in super()._prepare_responses(qctx):
            yield response

    def _handle_control_command(self, qctx: QueryContext) -> dns.message.Message | None:
        """
        Detect and handle control queries.

        Returns None if the query is not below the control domain; otherwise
        returns qctx.response, prepared as described below.

        A control query must be of type TXT; if it is not, a FORMERR response
        is sent back.

        The commands that the server responds to are those installed with
        install_control_command()/install_control_commands().  If the server
        is unable to handle the control query using any of the installed
        commands, an NXDOMAIN response is sent.  A QNAME containing non-ASCII
        labels yields FORMERR.

        Otherwise, the relevant command's handler is expected to provide the
        response via qctx.response and/or return a string that is converted to
        a TXT RRset inserted into the ANSWER section of the response to the
        control query.  The RCODE for a command-provided response defaults to
        NOERROR, but can be overridden by the command's handler.
        """
        if not qctx.qname.is_subdomain(self._control_domain):
            return None

        if qctx.qtype != dns.rdatatype.TXT:
            logging.error("Non-TXT control query %s from %s", qctx.qname, qctx.peer)
            qctx.response.set_rcode(dns.rcode.FORMERR)
            return qctx.response

        # The command subdomain is the single label directly to the left of the
        # control domain; derive the slice from the control domain's length so
        # that a multi-label _CONTROL_DOMAIN works too.
        subdomain_label_count = len(self._control_domain.labels) + 1
        control_subdomain = dns.name.Name(qctx.qname.labels[-subdomain_label_count:])
        try:
            command = self._commands[control_subdomain]
        except KeyError:
            logging.error("Unhandled control query %s from %s", qctx.qname, qctx.peer)
            qctx.response.set_rcode(dns.rcode.NXDOMAIN)
            return qctx.response

        logging.info("Received control query %s from %s", qctx.qname, qctx.peer)
        logging.debug("Handling control query %s using %s", qctx.qname, command)
        qctx.response.set_rcode(dns.rcode.NOERROR)
        qctx.response.flags |= dns.flags.AA

        # Labels left of the command subdomain become the command's arguments.
        command_qname = qctx.qname.relativize(control_subdomain)
        try:
            command_args = [label.decode("ascii") for label in command_qname.labels]
        except UnicodeDecodeError:
            logging.error("Non-ASCII control query %s from %s", qctx.qname, qctx.peer)
            qctx.response.set_rcode(dns.rcode.FORMERR)
            return qctx.response

        command_response = command.handle(command_args, self, qctx)
        if command_response:
            command_response_rrset = dns.rrset.from_text(
                qctx.qname,
                0,
                qctx.qclass,
                dns.rdatatype.TXT,
                self._txt_rdata_text(command_response),
            )
            qctx.response.answer.append(command_response_rrset)

        return qctx.response

    @classmethod
    def _txt_rdata_text(cls, text: str) -> str:
        """
        Render `text` as TXT RDATA in presentation (zone-file) format.

        The UTF-8 encoding of `text` is split into character-strings of at most
        255 octets each.  Every string is double-quoted, with `"` and `\\`
        backslash-escaped and non-printable octets written as `\\DDD`.  Command
        output may echo arbitrary QNAME label content or be long, and without
        this it could not be parsed by dns.rrset.from_text().  Plain printable
        text is rendered exactly as `"<text>"`.
        """
        encoded = text.encode("utf-8")
        quoted_strings = []
        for start in range(0, len(encoded), cls._TXT_STRING_MAX_OCTETS):
            chunk = encoded[start : start + cls._TXT_STRING_MAX_OCTETS]
            pieces = []
            for octet in chunk:
                if octet in b'"\\':
                    pieces.append("\\" + chr(octet))
                elif 0x20 <= octet < 0x7F:
                    pieces.append(chr(octet))
                else:
                    pieces.append(f"\\{octet:03d}")
            quoted_strings.append('"' + "".join(pieces) + '"')
        return " ".join(quoted_strings)
   1668  1.1.1.2  christos 
   1669  1.1.1.2  christos 
class ControlCommand(abc.ABC):
    """
    Abstract base for commands understood by ControllableAsyncDnsServer.

    A concrete command supplies two things: `control_subdomain`, the label
    below the server's control domain whose queries it answers, and
    `handle()`, the callback run for every such control query.
    """

    @property
    @abc.abstractmethod
    def control_subdomain(self) -> str:
        """
        Label (below the control domain) that this command answers for.

        Subclasses must provide it as a string, e.g. via a plain class
        attribute.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def handle(
        self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
    ) -> str | None:
        """
        Perform the action requested by a control query.

        This is a plain (synchronous) method, not a coroutine; the server
        calls it directly while preparing the control query's response.

        `args` holds the QNAME labels found to the left of the command's
        subdomain, decoded as ASCII (non-ASCII control queries never reach
        this method).  With subdomain `my-command`, the query
        `foo.bar.my-command._control.` gives `["foo", "bar"]` and the query
        `my-command._control.` gives `[]`.

        `server` is the server that received the query; the command may
        reconfigure it, e.g. by installing, uninstalling or replacing its
        response handlers.

        `qctx` is the control query's context.  The command may shape the
        reply by editing qctx.response directly, and/or return a string.  A
        returned string is placed in the ANSWER section of the reply as a TXT
        RRset.
        """
        raise NotImplementedError

    def __str__(self) -> str:
        # Commands show up in log messages under their concrete class name.
        return type(self).__name__
   1718  1.1.1.2  christos 
   1719  1.1.1.2  christos 
   1720  1.1.1.2  christos class ToggleResponsesCommand(ControlCommand):
   1721  1.1.1.2  christos     """
   1722  1.1.1.2  christos     Disable/enable sending responses from the server.
   1723  1.1.1.2  christos     """
   1724  1.1.1.2  christos 
   1725  1.1.1.2  christos     control_subdomain = "send-responses"
   1726  1.1.1.2  christos 
   1727  1.1.1.2  christos     def __init__(self) -> None:
   1728  1.1.1.5  christos         self._current_handler: IgnoreAllQueries | None = None
   1729  1.1.1.2  christos 
   1730  1.1.1.2  christos     def handle(
   1731  1.1.1.5  christos         self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
   1732  1.1.1.5  christos     ) -> str | None:
   1733  1.1.1.2  christos         if len(args) != 1:
   1734  1.1.1.2  christos             logging.error("Invalid %s query %s", self, qctx.qname)
   1735  1.1.1.2  christos             qctx.response.set_rcode(dns.rcode.SERVFAIL)
   1736  1.1.1.2  christos             return "invalid query; use exactly one of 'enable' or 'disable' in QNAME"
   1737  1.1.1.2  christos 
   1738  1.1.1.2  christos         mode = args[0]
   1739  1.1.1.2  christos 
   1740  1.1.1.2  christos         if mode == "disable":
   1741  1.1.1.2  christos             if self._current_handler:
   1742  1.1.1.2  christos                 return "sending responses already disabled"
   1743  1.1.1.2  christos             self._current_handler = IgnoreAllQueries()
   1744  1.1.1.2  christos             server.install_response_handler(self._current_handler, prepend=True)
   1745  1.1.1.2  christos             return "sending responses disabled"
   1746  1.1.1.2  christos 
   1747  1.1.1.2  christos         if mode == "enable":
   1748  1.1.1.2  christos             if not self._current_handler:
   1749  1.1.1.2  christos                 return "sending responses already enabled"
   1750  1.1.1.2  christos             server.uninstall_response_handler(self._current_handler)
   1751  1.1.1.2  christos             self._current_handler = None
   1752  1.1.1.2  christos             return "sending responses enabled"
   1753  1.1.1.2  christos 
   1754  1.1.1.2  christos         logging.error("Unrecognized response sending mode '%s'", mode)
   1755  1.1.1.2  christos         qctx.response.set_rcode(dns.rcode.SERVFAIL)
   1756  1.1.1.2  christos         return f"unrecognized response sending mode '{mode}'"
   1757  1.1.1.5  christos 
   1758  1.1.1.5  christos 
class SwitchControlCommand(ControlCommand):
    """
    Switch the server's response handlers based on the control query.

    Every key of the mapping given to the constructor names a sequence of
    response handlers.  A control query whose only argument is one of those
    keys replaces the server's response handlers with the associated
    sequence; anything else is answered with SERVFAIL and a TXT string
    listing the accepted keys.
    """

    control_subdomain = "switch"

    def __init__(self, handler_mapping: dict[str, Sequence[ResponseHandler]]):
        # Kept by reference, not copied.
        self._handler_mapping = handler_mapping

    def handle(
        self, args: list[str], server: ControllableAsyncDnsServer, qctx: QueryContext
    ) -> str | None:
        is_valid = len(args) == 1 and args[0] in self._handler_mapping
        if not is_valid:
            logging.error("Invalid %s query %s", self, qctx.qname)
            qctx.response.set_rcode(dns.rcode.SERVFAIL)
            known_sets = list(self._handler_mapping.keys())
            return f"invalid query; exactly one of {known_sets} is expected in QNAME"

        (set_name,) = args
        server.replace_response_handlers(*self._handler_mapping[set_name])
        return f"switched to handler set '{set_name}'"
   1784