Home | History | Annotate | Line # | Download | only in ans2
      1 """
      2 Copyright (C) Internet Systems Consortium, Inc. ("ISC")
      3 
      4 SPDX-License-Identifier: MPL-2.0
      5 
      6 This Source Code Form is subject to the terms of the Mozilla Public
      7 License, v. 2.0.  If a copy of the MPL was not distributed with this
      8 file, you can obtain one at https://mozilla.org/MPL/2.0/.
      9 
     10 See the COPYRIGHT file distributed with this work for additional
     11 information regarding copyright ownership.
     12 """
     13 
     14 from collections.abc import AsyncGenerator, Collection, Iterable
     15 
     16 import abc
     17 
     18 import dns.rcode
     19 import dns.rdataclass
     20 import dns.rdatatype
     21 import dns.rrset
     22 
     23 from isctest.asyncserver import (
     24     ControllableAsyncDnsServer,
     25     DnsResponseSend,
     26     QueryContext,
     27     ResponseHandler,
     28     SwitchControlCommand,
     29 )
     30 
     31 
def rrset(owner: str, rdtype: dns.rdatatype.RdataType, rdata: str) -> dns.rrset.RRset:
    """
    Build an RRset of class IN with a TTL of 300 from textual rdata.

    `owner` is the owner name, `rdtype` the record type and `rdata` the
    record data in text form; all three are passed on unchanged to
    dns.rrset.from_text().
    """
    ttl = 300
    rdclass = dns.rdataclass.IN
    return dns.rrset.from_text(owner, ttl, rdclass, rdtype, rdata)
     40 
     41 
def soa(serial: int, *, owner: str = "nil.") -> dns.rrset.RRset:
    """
    Build an SOA RRset for `owner` (default "nil.") with the given serial.

    Every other SOA field is fixed: MNAME "ns.nil.", RNAME "root.nil.",
    refresh 300, retry 300, expire 604800, minimum 300.
    """
    rdata = f"ns.nil. root.nil. {serial} 300 300 604800 300"
    return rrset(owner, dns.rdatatype.SOA, rdata)
     48 
     49 
def ns() -> dns.rrset.RRset:
    """
    Build the NS RRset "nil. NS ns.nil.".
    """
    return rrset("nil.", dns.rdatatype.NS, "ns.nil.")
     56 
     57 
def a(address: str, *, owner: str) -> dns.rrset.RRset:
    """
    Build an A RRset mapping `owner` to the IPv4 `address` (text form).
    """
    return rrset(owner, dns.rdatatype.A, address)
     64 
     65 
def txt(data: str, *, owner: str = "nil.") -> dns.rrset.RRset:
    """
    Build a TXT RRset for `owner` (default "nil.") holding `data`.

    `data` is wrapped in double quotes as-is; no escaping is applied.
    """
    quoted = '"{}"'.format(data)
    return rrset(owner, dns.rdatatype.TXT, quoted)
     72 
     73 
class SoaHandler(ResponseHandler):
    """
    Answer SOA queries with a single SOA record carrying a fixed serial.
    """

    def __init__(self, serial: int):
        # Serial number placed in every SOA answer this handler produces.
        self._serial = serial

    def match(self, qctx: QueryContext) -> bool:
        """
        Return True only for queries of type SOA.
        """
        is_soa_query = qctx.qtype == dns.rdatatype.SOA
        return is_soa_query

    async def get_responses(
        self, qctx: QueryContext
    ) -> AsyncGenerator[DnsResponseSend, None]:
        """
        Add the SOA record to qctx.response and send that response.
        """
        response = qctx.response
        response.answer.append(soa(self._serial))
        yield DnsResponseSend(response)
     86 
     87 
class AxfrHandler(ResponseHandler):
    """
    Base class for handlers answering AXFR queries with canned messages.

    Subclasses supply `answers`; one response message is sent for each
    element of it.
    """

    @property
    @abc.abstractmethod
    def answers(self) -> Iterable[Collection[dns.rrset.RRset]]:
        """
        Answer sections of the response packets sent in reply to AXFR
        queries, one collection of RRsets per packet.
        """
        raise NotImplementedError

    def match(self, qctx: QueryContext) -> bool:
        """
        Return True only for queries of type AXFR.
        """
        is_axfr_query = qctx.qtype == dns.rdatatype.AXFR
        return is_axfr_query

    async def get_responses(
        self, qctx: QueryContext
    ) -> AsyncGenerator[DnsResponseSend, None]:
        """
        Yield one response per entry in `answers`, in order.
        """
        for message_rrsets in self.answers:
            # Each message starts from qctx.prepare_new_response() rather
            # than reusing qctx.response.
            message = qctx.prepare_new_response()
            message.answer.extend(message_rrsets)
            yield DnsResponseSend(message)
    109 
    110 
class IxfrHandler(ResponseHandler):
    """
    Base class for handlers answering IXFR queries with one canned message.

    Subclasses supply `answer`, the complete answer section of that message.
    """

    @property
    @abc.abstractmethod
    def answer(self) -> Collection[dns.rrset.RRset]:
        """
        Answer section of the single response packet sent in reply to
        IXFR queries.
        """
        raise NotImplementedError

    def match(self, qctx: QueryContext) -> bool:
        """
        Return True only for queries of type IXFR.
        """
        is_ixfr_query = qctx.qtype == dns.rdatatype.IXFR
        return is_ixfr_query

    async def get_responses(
        self, qctx: QueryContext
    ) -> AsyncGenerator[DnsResponseSend, None]:
        """
        Add every RRset from `answer` to qctx.response and send it.
        """
        response = qctx.response
        response.answer.extend(self.answer)
        yield DnsResponseSend(response)
    130 
    131 
class InitialAfxrHandler(AxfrHandler):
    """
    AXFR of the "nil." zone at serial 1, sent as three messages.
    """

    answers = (
        # Message 1: opening SOA.
        (soa(serial=1),),
        # Message 2: the remaining zone contents.
        (
            ns(),
            txt(data="initial AXFR"),
            a(address="10.0.0.61", owner="a.nil."),
            a(address="10.0.0.62", owner="b.nil."),
        ),
        # Message 3: closing SOA.
        (soa(serial=1),),
    )
    143 
    144 
class SuccessfulIfxrHandler(IxfrHandler):
    """
    IXFR response taking the "nil." zone from serial 1 to serial 3.

    Comments below follow the RFC 1995 layout: a leading new SOA, then
    (old SOA, deletions, new SOA, additions) sequences, then the new SOA.
    """

    answer = (
        # Leading SOA: the serial the transfer ends at.
        soa(serial=3),
        # Difference 1 -> 2: records listed after the old SOA ...
        soa(serial=1),
        a(address="10.0.0.61", owner="a.nil."),
        txt(data="initial AXFR"),
        # ... and records listed after the new SOA.
        soa(serial=2),
        txt(data="successful IXFR"),
        a(address="10.0.1.61", owner="a.nil."),
        # Difference 2 -> 3: no records on either side.
        soa(serial=2),
        soa(serial=3),
        # Trailing SOA.
        soa(serial=3),
    )
    158 
    159 
class NotExactIxfrHandler(IxfrHandler):
    """
    IXFR response from serial 3 to serial 4 whose old-SOA section lists a
    TXT record named "delete-nonexistent-txt-record".

    NOTE(review): presumably that record is absent from the receiver's
    copy of the zone, so the difference cannot be applied exactly and the
    receiver falls back to AXFR -- confirm against the test using this
    server.
    """

    answer = (
        # Leading SOA: target serial.
        soa(serial=4),
        # After the old SOA (the deletion side in the RFC 1995 layout).
        soa(serial=3),
        txt(data="delete-nonexistent-txt-record"),
        # After the new SOA (the addition side).
        soa(serial=4),
        txt(data="this-txt-record-would-be-added"),
        # Trailing SOA.
        soa(serial=4),
    )
    169 
    170 
class FallbackNotExactAxfrHandler(AxfrHandler):
    """
    AXFR of the "nil." zone at serial 3, sent as three messages.

    Installed in main() next to NotExactIxfrHandler under "not_exact".
    """

    answers = (
        # Message 1: opening SOA.
        (soa(serial=3),),
        # Message 2: zone contents.
        (
            ns(),
            txt(data="fallback AXFR"),
        ),
        # Message 3: closing SOA.
        (soa(serial=3),),
    )
    180 
    181 
class TooManyRecordsIxfrHandler(IxfrHandler):
    """
    IXFR response from serial 3 to serial 4 that lists six TXT records
    after the new SOA.

    NOTE(review): the text of the last record suggests the receiver is
    configured with a record limit that the sixth record exceeds -- the
    limit itself is not visible in this file; confirm in the test setup.
    """

    answer = (
        # Leading SOA: target serial.
        soa(serial=4),
        # Old SOA, followed by no records.
        soa(serial=3),
        # New SOA, followed by six TXT records.
        soa(serial=4),
        txt(data="text 1"),
        txt(data="text 2"),
        txt(data="text 3"),
        txt(data="text 4"),
        txt(data="text 5"),
        txt(data="text 6: causing too many records"),
        # Trailing SOA.
        soa(serial=4),
    )
    195 
    196 
class FallbackTooManyRecordsAxfrHandler(AxfrHandler):
    """
    AXFR of the "nil." zone at serial 3, sent as two messages.

    Installed in main() next to TooManyRecordsIxfrHandler under
    "too_many_records".
    """

    answers = (
        # Message 1: opening SOA together with the zone contents.
        (
            soa(serial=3),
            ns(),
            txt(data="fallback AXFR on too many records"),
        ),
        # Message 2: closing SOA.
        (soa(serial=3),),
    )
    206 
    207 
class BadSoaOwnerIxfrHandler(IxfrHandler):
    """
    IXFR response from serial 3 to serial 4 in which the SOA opening the
    second half of the difference is owned by "bad-owner." instead of
    "nil.".
    """

    answer = (
        # Leading SOA: target serial.
        soa(serial=4),
        # Old SOA, followed by no records.
        soa(serial=3),
        # New SOA, but with an owner name other than "nil.".
        soa(serial=4, owner="bad-owner."),
        txt(data="serial 4, malformed IXFR", owner="test.nil."),
        # Trailing SOA.
        soa(serial=4),
    )
    216 
    217 
class FallbackBadSoaOwnerAxfrHandler(AxfrHandler):
    """
    AXFR of the "nil." zone at serial 4, sent as three messages.

    Installed in main() next to BadSoaOwnerIxfrHandler under
    "bad_soa_owner".
    """

    answers = (
        # Message 1: opening SOA.
        (soa(serial=4),),
        # Message 2: zone contents.
        (
            ns(),
            txt(data="serial 4, fallback AXFR", owner="test.nil."),
        ),
        # Message 3: closing SOA.
        (soa(serial=4),),
    )
    227 
    228 
def main() -> None:
    """
    Create the DNS server, register the switch control command and run.

    Each key of the table below names a group of response handlers: an
    SOA handler plus the IXFR and/or AXFR handlers that go with it.
    NOTE(review): presumably SwitchControlCommand activates one group at
    a time when asked to by the test -- confirm in isctest.asyncserver.
    """
    server = ControllableAsyncDnsServer(
        default_aa=True,
        default_rcode=dns.rcode.NOERROR,
    )

    handlers_by_mode = {
        "initial_axfr": (
            SoaHandler(1),
            InitialAfxrHandler(),
        ),
        "successful_ixfr": (
            SoaHandler(3),
            SuccessfulIfxrHandler(),
        ),
        "not_exact": (
            SoaHandler(4),
            NotExactIxfrHandler(),
            FallbackNotExactAxfrHandler(),
        ),
        "too_many_records": (
            SoaHandler(4),
            TooManyRecordsIxfrHandler(),
            FallbackTooManyRecordsAxfrHandler(),
        ),
        "bad_soa_owner": (
            SoaHandler(4),
            BadSoaOwnerIxfrHandler(),
            FallbackBadSoaOwnerAxfrHandler(),
        ),
    }

    server.install_control_command(SwitchControlCommand(handlers_by_mode))
    server.run()
    262 
    263 
# Start the server only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
    266