Home | History | Annotate | Line # | Download | only in DSO
      1 /* dso.c
      2  *
      3  * Copyright (c) 2018-2024 Apple Inc. All rights reserved.
      4  *
      5  * Licensed under the Apache License, Version 2.0 (the "License");
      6  * you may not use this file except in compliance with the License.
      7  * You may obtain a copy of the License at
      8  *
      9  *     https://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  *
     17  */
     18 
     19 //*************************************************************************************************************
     20 // Headers
     21 
     22 #include <stdio.h>
     23 #include <signal.h>
     24 #include <stdlib.h>
     25 #include <stdbool.h>
     26 #include <stdlib.h>
     27 #include <unistd.h>
     28 #include <string.h>
     29 #include <assert.h>
     30 
     31 #include <netdb.h>           // For gethostbyname()
     32 #include <sys/socket.h>      // For AF_INET, AF_INET6, etc.
     33 #include <net/if.h>          // For IF_NAMESIZE
     34 #include <netinet/in.h>      // For INADDR_NONE
     35 #include <netinet/tcp.h>     // For SOL_TCP, TCP_NOTSENT_LOWAT
     36 #include <arpa/inet.h>       // For inet_addr()
     37 #include <unistd.h>
     38 #include <errno.h>
     39 #include <fcntl.h>
     40 
     41 #include "DNSCommon.h"
     42 #include "mDNSEmbeddedAPI.h"
     43 #include "PlatformCommon.h"
     44 #include "dso.h"
     45 #include "DebugServices.h"   // For check_compile_time
     46 
     47 #ifdef STANDALONE
     48 #undef LogMsg
     49 #define LogMsg INFO
     50 
     51 #include "srp-log.h"
     52 extern uint16_t srp_random16(void);
     53 #define mDNSRandom(x) srp_random16()
     54 #define mDNSPlatformMemAllocateClear(length) mdns_calloc(1, length)
     55 #else // STANDALONE
     56 
     57 // This is only a temporary fix to let the code in this file print unredacted logs.
     58 
     59 #include "srp-log.h"
     60 #undef FAULT
     61 #undef INFO
     62         #define FAULT(fmt, ...)
     63         #define INFO(fmt, ...)
     64 
     65 #endif // STANDALONE
     66 
     67 #include "mdns_strict.h"
     68 
     69 //*************************************************************************************************************
     70 // Remaining work TODO
     71 
     72 // - Add keepalive/inactivity timeout support
     73 // - Notice if it takes a long time to get a response when establishing a session, and treat that
     74 //   as "DSO not supported."
     75 // - TLS support
     76 // - Actually use Network Framework
     77 
     78 
     79 //*************************************************************************************************************
     80 // Globals
     81 
     82 // List of dso connection states that are active. Added when dso_connect_state_create() is called, removed
     83 // when dso_state_cancel() is called. Removals are moved to dso_connections_needing_cleanup for cleanup during
     84 // the idle loop.
     85 // The list of connection states is not declared static so that the discovery proxy can access it as part of
     86 // the "start-dropping-push" test.
     87 dso_state_t *dso_connections;
     88 static dso_state_t *dso_connections_needing_cleanup; // DSO connections that have been shut down but aren't yet freed.
     89 
     90 dso_state_t *dso_find_by_serial(uint32_t serial)
     91 {
     92     dso_state_t *dsop;
     93 
     94     for (dsop = dso_connections; dsop; dsop = dsop->next) {
     95         if (dsop->serial == serial) {
     96             return dsop;
     97         }
     98     }
     99     return NULL;
    100 }
    101 
    102 // This function is called either when an error has occurred requiring the a DSO connection be
    103 // canceled, or else when a connection to a DSO endpoint has been cleanly closed and is ready to be
    104 // canceled for that reason.
    105 
    106 void dso_state_cancel(dso_state_t *dso)
    107 {
    108     dso_state_t **dsop = &dso_connections;
    109     bool status = true;
    110 
    111     // Find dso on the list of connections.
    112     while (*dsop != NULL && *dsop != dso) {
    113         dsop = &(*dsop)->next;
    114     }
    115 
    116     // If we get to the end of the list without finding dso, it means that it's already
    117     // been dropped.
    118     if (*dsop == NULL) {
    119         return;
    120     }
    121 
    122     // When the dso_state_t is canceled, its context may also need to be canceled/released/freed, so we give context a
    123     // callback to do the cleaning work with dso_life_cycle_cancel state.
    124     if (dso->context_callback != NULL) {
    125         status = dso->context_callback(dso_life_cycle_cancel, dso->context, dso);
    126     }
    127 
    128     // If the callback returns a status of true, then we want to free the dso object in the idle loop.
    129     if (status) {
    130         // Remove dso from the list of active dso objects.
    131         *dsop = dso->next;
    132 
    133         // Add it to the list of dso objects needing cleanup.
    134         dso->next = dso_connections_needing_cleanup;
    135         dso_connections_needing_cleanup = dso;
    136     }
    137 }
    138 
    139 void dso_cleanup(bool call_callbacks)
    140 {
    141     dso_state_t *dso, *dnext;
    142     dso_activity_t *ap, *anext;
    143 
    144     for (dso = dso_connections_needing_cleanup; dso; dso = dnext) {
    145         dnext = dso->next;
    146         // Finalize and then free any activities.
    147         for (ap = dso->activities; ap; ap = anext) {
    148             anext = ap->next;
    149             if (ap->finalize) {
    150                 ap->finalize(ap);
    151             }
    152             mdns_free(ap);
    153         }
    154         if (dso->transport != NULL && dso->transport_finalize != NULL) {
    155             dso->transport_finalize(dso->transport, "dso_idle");
    156             dso->transport = NULL;
    157         }
    158         LogMsg("[DSO%u] dso_state_t finalizing - "
    159                "dso: %p, remote name: %s, dso->context: %p", dso->serial, dso, dso->remote_name, dso->context);
    160         if (dso->cb && call_callbacks) {
    161             // Because dso->context is the DNSPushServer that uses the current dso_state_t *dso
    162             // (server->connection) and the server has been canceled by CancelDNSPushServer(), the
    163             // current dso is not used and cannot be recovered (or reconnected). The only thing we can do is to finalize
    164             // it.
    165             dso->cb(dso->context, NULL, dso, kDSOEventType_Finalize);
    166         } else {
    167             if (dso->additl != dso->additl_buf) {
    168                 mdns_free(dso->additl);
    169             }
    170             mdns_free(dso);
    171         }
    172         // Do not touch dso after this point, because it has been freed.
    173     }
    174     dso_connections_needing_cleanup = NULL;
    175 }
    176 
    177 int32_t dso_idle(void *context, int32_t now, int32_t next_timer_event)
    178 {
    179     dso_state_t *dso, *dnext;
    180 
    181     dso_cleanup(true);
    182 
    183     // Do keepalives.
    184     for (dso = dso_connections; dso; dso = dnext) {
    185         dnext = dso->next;
    186         if (dso->inactivity_due == 0) {
    187             if (dso->inactivity_timeout != 0) {
    188                 dso->inactivity_due = NonZeroTime(now + (event_time_t)MIN(dso->inactivity_timeout, INT32_MAX));
    189                 if (next_timer_event - dso->inactivity_due > 0) {
    190                     next_timer_event = dso->inactivity_due;
    191                 }
    192             }
    193         } else if (now - dso->inactivity_due > 0 && dso->cb != NULL) {
    194             dso->cb(dso->context, 0, dso, kDSOEventType_Inactive);
    195             // Should not touch the current dso_state_t after we deliver kDSOEventType_Inactive event, because it is
    196             // possible that the current dso_state_t has been canceled in the callback. Doing any operation to update
    197             // its status will not work as expected.
    198             continue;
    199         }
    200         if (dso->keepalive_due != 0 && dso->keepalive_due - now < 0 && dso->cb != NULL) {
    201             dso_keepalive_context_t kc;
    202             memset(&kc, 0, sizeof kc);
    203             dso->cb(dso->context, &kc, dso, kDSOEventType_Keepalive);
    204             dso->keepalive_due = NonZeroTime(now + (event_time_t)MIN(dso->keepalive_interval, INT32_MAX));
    205             if (next_timer_event - dso->keepalive_due > 0) {
    206                 next_timer_event = dso->keepalive_due;
    207             }
    208         }
    209     }
    210     return dso_transport_idle(context, now, next_timer_event);
    211 }
    212 
    213 void dso_set_event_context(dso_state_t *dso, void *context)
    214 {
    215     dso->context = context;
    216 }
    217 
    218 void dso_set_life_cycle_callback(dso_state_t *dso, dso_life_cycle_context_callback_t callback)
    219 {
    220     dso->context_callback = callback;
    221 }
    222 
    223 void dso_set_event_callback(dso_state_t *dso, dso_event_callback_t callback)
    224 {
    225     dso->cb = callback;
    226 }
    227 
    228 // Called when something happens that establishes a DSO session.
    229 static void dso_session_established(dso_state_t *dso)
    230 {
    231     LogMsg("[DSO%u] DSO session established - dso: %p, remote name: %s.", dso->serial, dso, dso->remote_name);
    232     dso->has_session = true;
    233     // Set up inactivity timer and keepalive timer...
    234 }
    235 
    236 // Create a dso_state_t structure
    237 dso_state_t *dso_state_create(bool is_server, int max_outstanding_queries, const char *remote_name,
    238                               dso_event_callback_t callback, void *const context,
    239                               const dso_life_cycle_context_callback_t context_callback,
    240                               dso_transport_t *transport)
    241 {
    242     dso_state_t *dso;
    243     size_t namelen = strlen(remote_name);
    244     size_t namespace = namelen + 1;
    245     const size_t outsize = (sizeof (dso_outstanding_query_state_t)) + (size_t)max_outstanding_queries * sizeof (dso_outstanding_query_t);
    246 
    247     if ((sizeof (*dso) + outsize + namespace) > UINT_MAX) {
    248         FAULT("Fatal: sizeof (*dso)[%zd], outsize[%zd], "
    249                   "namespace[%zd]", sizeof (*dso), outsize, namespace);
    250         dso = NULL;
    251         goto out;
    252     }
    253     // We allocate everything in a single hunk so that we can free it together as well.
    254     dso = (dso_state_t *) mDNSPlatformMemAllocateClear((uint32_t)((sizeof (*dso)) + outsize + namespace));
    255     if (dso == NULL) {
    256         goto out;
    257     }
    258     dso->outstanding_queries = (dso_outstanding_query_state_t *)(dso + 1);
    259     dso->outstanding_queries->max_outstanding_queries = max_outstanding_queries;
    260 
    261     dso->remote_name = ((char *)dso->outstanding_queries) + outsize;
    262     memcpy(dso->remote_name, remote_name, namelen);
    263     dso->remote_name[namelen] = 0;
    264 
    265     dso->cb = callback;
    266     if (context != NULL) {
    267         dso->context = context;
    268     }
    269     if (context_callback != NULL) {
    270         dso->context_callback = context_callback;
    271         // When dso_state_t is created, the context it holds may need to be reference counted, for example, to retain
    272         // the context. Here we give the context a callback with dso_life_cycle_create state.
    273         context_callback(dso_life_cycle_create, context, dso);
    274     }
    275     dso->transport = transport;
    276     dso->is_server = is_server;
    277 
    278     // Used to uniquely mark dso_state_t objects, incremented once for each dso_state_t created.
    279     // DSO_STATE_INVALID_SERIAL(0) is used to identify invalid dso_state_t.
    280     static uint32_t dso_state_serial = DSO_STATE_INVALID_SERIAL + 1;
    281     dso->serial = dso_state_serial++;
    282 
    283     // Set up additional additional pointer.
    284     dso->additl = dso->additl_buf;
    285     dso->max_additls = MAX_ADDITLS;
    286 
    287     dso->keepalive_interval = 3600 * MSEC_PER_SEC;
    288     dso->inactivity_timeout = 15 * MSEC_PER_SEC;
    289 
    290     dso->next = dso_connections;
    291     dso_connections = dso;
    292 
    293     LogMsg("[DSO%u] New dso_state_t created - dso: %p, remote name: %s, context: %p",
    294            dso->serial, dso, remote_name, context);
    295 out:
    296     return dso;
    297 }
    298 
    299 // Start building a TLV in an outgoing dso message.
    300 void dso_start_tlv(dso_message_t *state, int opcode)
    301 {
    302     // Make sure there's room for the length and the TLV opcode.
    303     if (state->cur + 4 >= state->max) {
    304         LogMsg("dso_start_tlv called when no space in output buffer!");
    305         assert(0);
    306     }
    307 
    308     // We need to not yet have a TLV.
    309     if (state->building_tlv) {
    310         LogMsg("dso_start_tlv called while already building a TLV!");
    311         assert(0);
    312     }
    313     state->building_tlv = true;
    314     state->tlv_len = 0;
    315 
    316     // Set up the TLV header.
    317     state->buf[state->cur] = (uint8_t)(opcode >> 8);
    318     state->buf[state->cur + 1] = opcode & 255;
    319     state->tlv_len_offset = state->cur + 2;
    320     state->cur += 4;
    321 }
    322 
    323 // Add some bytes to a TLV that's being built, but don't copy them--just remember the
    324 // pointer to the buffer.   This is used so that when we have a message to forward, we
    325 // don't copy it into the output buffer--we just use scatter/gather I/O.
    326 void dso_add_tlv_bytes_no_copy(dso_message_t *state, const uint8_t *bytes, size_t len)
    327 {
    328     if (!state->building_tlv) {
    329         LogMsg("add_tlv_bytes called when not building a TLV!");
    330         assert(0);
    331     }
    332     if (state->no_copy_bytes_len) {
    333         LogMsg("add_tlv_bytesNoCopy called twice on the same DSO message.");
    334         assert(0);
    335     }
    336     state->no_copy_bytes_len = len;
    337     state->no_copy_bytes = bytes;
    338     state->no_copy_bytes_offset = state->cur;
    339     state->tlv_len += len;
    340 }
    341 
    342 // Add some bytes to a TLV that's being built.
    343 void dso_add_tlv_bytes(dso_message_t *state, const uint8_t *bytes, size_t len)
    344 {
    345     if (!state->building_tlv) {
    346         LogMsg("add_tlv_bytes called when not building a TLV!");
    347         assert(0);
    348     }
    349     if (state->cur + len > state->max) {
    350         LogMsg("add_tlv_bytes called with no room in output buffer.");
    351         assert(0);
    352     }
    353     memcpy(&state->buf[state->cur], bytes, len);
    354     state->cur += len;
    355     state->tlv_len += len;
    356 }
    357 
    358 // Add a single byte to a TLV that's being built.
    359 void dso_add_tlv_byte(dso_message_t *state, uint8_t byte)
    360 {
    361     if (!state->building_tlv) {
    362         LogMsg("dso_add_tlv_byte called when not building a TLV!");
    363         assert(0);
    364     }
    365     if (state->cur + 1 > state->max) {
    366         LogMsg("dso_add_tlv_byte called with no room in output buffer.");
    367         assert(0);
    368     }
    369     state->buf[state->cur++] = byte;
    370     state->tlv_len++;
    371 }
    372 
    373 // Add an uint16_t to a TLV that's being built.
    374 void dso_add_tlv_u16(dso_message_t *state, uint16_t u16)
    375 {
    376     if (!state->building_tlv) {
    377         LogMsg("dso_add_tlv_u16 called when not building a TLV!");
    378         assert(0);
    379     }
    380     if ((state->cur + sizeof u16) > state->max) {
    381         LogMsg("dso_add_tlv_u16 called with no room in output buffer.");
    382         assert(0);
    383     }
    384     state->buf[state->cur++] = u16 >> 8;
    385     state->buf[state->cur++] = u16 & 255;
    386     state->tlv_len += 2;
    387 }
    388 
    389 // Add an uint32_t to a TLV that's being built.
    390 void dso_add_tlv_u32(dso_message_t *state, uint32_t u32)
    391 {
    392     if (!state->building_tlv) {
    393         LogMsg("dso_add_tlv_u32 called when not building a TLV!");
    394         assert(0);
    395     }
    396     if ((state->cur + sizeof u32) > state->max) {
    397         LogMsg("dso_add_tlv_u32 called with no room in output buffer.");
    398         assert(0);
    399     }
    400     state->buf[state->cur++] = u32 >> 24;
    401     state->buf[state->cur++] = (u32 >> 16) & 255;
    402     state->buf[state->cur++] = (u32 >> 8) & 255;
    403     state->buf[state->cur++] = u32 & 255;
    404     state->tlv_len += 4;
    405 }
    406 
    407 // Finish building a TLV.
    408 void dso_finish_tlv(dso_message_t *state)
    409 {
    410     if (!state->building_tlv) {
    411         LogMsg("dso_finish_tlv called when not building a TLV!");
    412         assert(0);
    413     }
    414 
    415     // A TLV can't be longer than this.
    416     if (state->tlv_len > 65535) {
    417         LogMsg("dso_finish_tlv was given more than 65535 bytes of TLV payload!");
    418         assert(0);
    419     }
    420     state->buf[state->tlv_len_offset] = (uint8_t)(state->tlv_len >> 8);
    421     state->buf[state->tlv_len_offset + 1] = state->tlv_len & 255;
    422     state->tlv_len = 0;
    423     state->building_tlv = false;
    424 }
    425 
    426 dso_activity_t *NULLABLE dso_find_activity(dso_state_t *const NONNULL dso, const char *const NULLABLE name,
    427                                   const char *const NONNULL activity_type, void *const NULLABLE context)
    428 {
    429     dso_activity_t *activity;
    430 
    431     // If we haven't been given something to search for, don't search.
    432     if (name == NULL && context == NULL) {
    433         FAULT("[DSO%u] Cannot search for activity with name and context both equal to NULL - "
    434               "activity_type: " PUB_S_SRP ".", dso->serial, activity_type);
    435         activity = NULL;
    436         goto exit;
    437     }
    438 
    439     for (activity = dso->activities; activity != NULL; activity = activity->next) {
    440         if (activity->activity_type != activity_type) {
    441             continue;
    442         }
    443 
    444         if (name != NULL) {
    445             // If name is specified, always use the name to search for the corresponding activity, even if context is
    446             // also specified.
    447             if (activity->name == NULL) {
    448                 continue;
    449             }
    450             if (strcmp(name, activity->name) != 0) {
    451                 continue;
    452             }
    453             // If the name matches, the corresponding context should also match if the context is not NULL.
    454             if (context != NULL && activity->context != context) {
    455                 FAULT("[DSO%u] The activity specified by the name does not have the expected context - "
    456                     "name: " PRI_S_SRP ", activity_type: " PUB_S_SRP ", context: %p.", dso->serial, name, activity_type,
    457                     context);
    458             }
    459         } else {
    460             // name == NULL && context != NULL
    461             // If name is not specified, use context to search for the activity.
    462             if (context != activity->context) {
    463                 continue;
    464             }
    465         }
    466 
    467         break;
    468     }
    469 
    470 exit:
    471     return activity;
    472 }
    473 
    474 // Make an activity structure to hang off the DSO.
    475 dso_activity_t *dso_add_activity(dso_state_t *dso, const char *name, const char *activity_type,
    476                                  void *context, void (*finalize)(dso_activity_t *))
    477 {
    478     size_t namelen = name ? strlen(name) + 1 : 0;
    479     size_t len;
    480     dso_activity_t *activity;
    481     void *ap;
    482 
    483     // Shouldn't add an activity that's already been added.
    484     activity = dso_find_activity(dso, name, activity_type, context);
    485     if (activity != NULL) {
    486         FAULT("[DSO%u] Trying to add a duplicate activity - activity name: " PRI_S_SRP ", activity type: " PUB_S_SRP
    487             ", activity context: %p.", dso->serial, name, activity_type, context);
    488         return NULL;
    489     }
    490 
    491     len = namelen + sizeof *activity;
    492     ap = mDNSPlatformMemAllocateClear((mDNSu32)len);
    493     if (ap == NULL) {
    494         return NULL;
    495     }
    496     activity = (dso_activity_t *)ap;
    497     ap = (char *)ap + sizeof *activity;
    498 
    499     // Activities can be identified either by name or by context
    500     if (namelen) {
    501         activity->name = ap;
    502         memcpy(activity->name, name, namelen);
    503     } else {
    504         activity->name = NULL;
    505     }
    506     activity->context = context;
    507 
    508     // Activity type is expected to be a string constant; all activities of the same type must
    509     // reference the same constant, not different constants with the same contents.
    510     activity->activity_type = activity_type;
    511     activity->finalize = finalize;
    512 
    513     INFO("[DSO%u] Adding a DSO activity - activity name: " PRI_S_SRP ", activity type: " PUB_S_SRP
    514         ", activity context: %p.", dso->serial, activity->name, activity->activity_type, activity->context);
    515 
    516     // Retain this activity on the list.
    517     activity->next = dso->activities;
    518     dso->activities = activity;
    519 
    520     return activity;
    521 }
    522 
    523 void dso_drop_activity(dso_state_t *dso, dso_activity_t *activity)
    524 {
    525     dso_activity_t **app = &dso->activities;
    526     bool matched = false;
    527 
    528     // Remove this activity from the list.
    529     while (*app) {
    530         if (*app == activity) {
    531             *app = activity->next;
    532             matched = true;
    533             break;
    534         } else {
    535             app = &((*app)->next);
    536         }
    537     }
    538 
    539     // If an activity that's not on the DSO list is passed here, it's an internal consistency
    540     // error that probably indicates something is corrupted.
    541     if (!matched) {
    542         FAULT("[DSO%u] Trying to remove an activity that is not in the list - "
    543             "activity name: " PRI_S_SRP ", activity type: " PUB_S_SRP ", activity context: %p.",
    544             dso->serial, activity->name, activity->activity_type, activity->context);
    545     }
    546     INFO("[DSO%u] Removing a DSO activity - activity name: " PRI_S_SRP ", activity type: " PUB_S_SRP
    547         ", activity context: %p.", dso->serial, activity->name, activity->activity_type, activity->context);
    548 
    549     if (activity->finalize != NULL) {
    550         activity->finalize(activity);
    551     }
    552     mdns_free(activity);
    553 }
    554 
    555 uint32_t dso_ignore_further_responses(dso_state_t *dso, const void *const context)
    556 {
    557     dso_outstanding_query_state_t *midState = dso->outstanding_queries;
    558     int i;
    559     uint32_t disassociated_count = 0;
    560     for (i = 0; i < midState->max_outstanding_queries; i++) {
    561         // The query is still be outstanding, and we want to know it when it comes back, but we forget the context,
    562         // which presumably is a reference to something that's going away.
    563         if (midState->queries[i].context == context) {
    564             midState->queries[i].context = NULL;
    565             INFO("[DSO%u] Disassociate the outstanding dso query with the context - query id: 0x%x, context: %p.",
    566                  dso->serial, midState->queries[i].id, context);
    567             disassociated_count++;
    568         }
    569     }
    570 
    571     return disassociated_count;
    572 }
    573 
    574 void dso_update_outstanding_query_context(dso_state_t *const dso, const void *const old_context,
    575     void *const new_context)
    576 {
    577     dso_outstanding_query_state_t *const states = dso->outstanding_queries;
    578     for (int i = 0; i < states->max_outstanding_queries; i++) {
    579         if (states->queries[i].context == old_context) {
    580             states->queries[i].context = new_context;
    581         }
    582     }
    583 }
    584 
    585 uint32_t dso_connections_reset_outstanding_query_context(const void *const context)
    586 {
    587     uint32_t reset_count = 0;
    588 
    589     if (context == NULL) {
    590         goto exit;
    591     }
    592 
    593     for (dso_state_t *dso_state = dso_connections; dso_state; dso_state = dso_state->next) {
    594         reset_count += dso_ignore_further_responses(dso_state, context);
    595     }
    596 
    597 exit:
    598     return reset_count;
    599 }
    600 
    601 bool dso_make_message(dso_message_t *state, uint8_t *outbuf, size_t outbuf_size, dso_state_t *dso,
    602                       bool unidirectional, bool response, uint16_t xid, int rcode, void *callback_state)
    603 {
    604     DNSMessageHeader *msg_header;
    605     dso_outstanding_query_state_t *midState = dso->outstanding_queries;
    606 
    607     memset(state, 0, sizeof *state);
    608     state->buf = outbuf;
    609     state->max = outbuf_size;
    610 
    611     // We need space for the TCP message length plus the DNS header.
    612     if (state->max < sizeof *msg_header) {
    613         LogMsg("dso_make_message: called without enough buffer space to store a DNS header!");
    614         assert(0);
    615     }
    616 
    617     // This buffer should be 16-bit aligned.
    618     msg_header = (DNSMessageHeader *)state->buf;
    619 
    620     // The DNS header for a DSO message is mostly zeroes
    621     memset(msg_header, 0, sizeof *msg_header);
    622     msg_header->flags.b[0] = (response ? kDNSFlag0_QR_Response : kDNSFlag0_QR_Query) | kDNSFlag0_OP_DSO;
    623 
    624     // Servers can't send DSO messages until there's a DSO session.
    625     if (dso->is_server && !dso->has_session) {
    626         LogMsg("dso_make_message: FATAL: server attempting to make a DSO message with no session!");
    627         assert(0);
    628     }
    629 
    630     // Response-requiring messages need to have a message ID. Replies take the message ID from the message to which
    631     // they are a reply, and also need an rcode.
    632     if (response) {
    633         msg_header->flags.b[1] = (uint8_t)rcode;
    634         msg_header->id.NotAnInteger = xid;
    635     } else if (!unidirectional) {
    636         bool msg_id_ok = true;
    637         uint16_t message_id;
    638         int looping = 0;
    639         int i, avail = -1;
    640 
    641         // If we don't have room for another outstanding message, the caller should try
    642         // again later.
    643         if (midState->outstanding_query_count == midState->max_outstanding_queries) {
    644             return false;
    645         }
    646         // Generate a random message ID.   This doesn't really need to be cryptographically sound
    647         // (right?) because we're encrypting the whole data stream in TLS.
    648         do {
    649             // This would be a surprising fluke, but let's not get killed by it.
    650             if (looping++ > 1000) {
    651                 return false;
    652             }
    653             message_id = (uint16_t)mDNSRandom(UINT16_MAX);
    654             msg_id_ok = true;
    655             if (message_id == 0) {
    656                 msg_id_ok = false;
    657             } else {
    658                 for (i = 0; i < midState->max_outstanding_queries; i++) {
    659                     if (midState->queries[i].id == 0 && avail == -1) {
    660                         avail = i;
    661                     } else if (midState->queries[i].id == message_id) {
    662                         msg_id_ok = false;
    663                     }
    664                 }
    665             }
    666         } while (!msg_id_ok);
    667         if (avail == -1) {
    668             LogMsg("dso_make_message: FATAL: no slots available even though there's supposedly space.");
    669             return false;
    670         }
    671         midState->queries[avail].id = message_id;
    672         midState->queries[avail].context = callback_state;
    673         LogMsg("dso_make_message: added query xid %x into slot %x, context %p", message_id, avail, callback_state);
    674         midState->outstanding_query_count++;
    675         msg_header->id.NotAnInteger = message_id;
    676         state->outstanding_query_number = avail;
    677     } else {
    678         // Clients aren't allowed to send unidirectional messages until there's a session.
    679         if (!dso->has_session) {
    680             LogMsg("dso_make_message: FATAL: client making a DSO unidirectional message with no session!");
    681             assert(0);
    682         }
    683         state->outstanding_query_number = -1;
    684     }
    685 
    686     state->cur = sizeof *msg_header;
    687     return true;
    688 }
    689 
    690 size_t dso_message_length(dso_message_t *state)
    691 {
    692     return state->cur + state->no_copy_bytes_len;
    693 }
    694 
    695 void dso_retry_delay(dso_state_t *dso, const DNSMessageHeader *header)
    696 {
    697     dso_disconnect_context_t context;
    698     if (dso->cb) {
    699         memset(&context, 0, sizeof context);
    700         if (dso->primary.length != 4) {
    701             LogMsg("Invalid DSO Retry Delay length %d from %s", dso->primary.length, dso->remote_name);
    702             dso_send_formerr(dso, header);
    703             return;
    704         }
    705         memcpy(&context, dso->primary.payload, dso->primary.length);
    706         context.reconnect_delay = ntohl(context.reconnect_delay);
    707         dso->cb(dso->context, &context, dso, kDSOEventType_RetryDelay);
    708     }
    709 }
    710 
    711 void dso_keepalive(dso_state_t *dso, const DNSMessageHeader *header, bool response)
    712 {
    713     dso_keepalive_context_t context;
    714     memset(&context, 0, sizeof context);
    715     if (dso->primary.length != 8) {
    716         LogMsg("Invalid DSO Keepalive length %d from %s", dso->primary.length, dso->remote_name);
    717         if (dso->is_server) {
    718             dso_send_formerr(dso, header);
    719         }
    720         return;
    721     }
    722     if (dso->is_server && response) {
    723         LogMsg("Dropping Keepalive Response received by DSO server");
    724         return;
    725     }
    726 
    727     memcpy(&context, dso->primary.payload, dso->primary.length);
    728     context.inactivity_timeout = ntohl(context.inactivity_timeout);
    729     context.keepalive_interval = ntohl(context.keepalive_interval);
    730     context.xid = header->id.NotAnInteger;
    731     context.send_response = true;
    732     if (context.inactivity_timeout > FutureTime || context.keepalive_interval > FutureTime) {
    733         LogMsg("[DSO%u] inactivity_timeoutl[%u] keepalive_interva[%u] is unreasonably large.",
    734                dso->serial, context.inactivity_timeout, context.keepalive_interval);
    735         if (dso->is_server) {
    736             dso_send_formerr(dso, header);
    737         }
    738         return;
    739     }
    740     if (dso->is_server) {
    741         if (dso->cb) {
    742             if (dso->keepalive_interval < context.keepalive_interval) {
    743                 context.keepalive_interval = dso->keepalive_interval;
    744             }
    745             if (dso->inactivity_timeout < context.inactivity_timeout) {
    746                 context.inactivity_timeout = dso->inactivity_timeout;
    747             }
    748             dso->cb(dso->context, &context, dso, kDSOEventType_KeepaliveRcvd);
    749         }
    750         if (context.send_response) {
    751             dso_send_simple_response(dso, kDNSFlag1_RC_NoErr, header, "No Error");
    752         }
    753     } else {
    754         if (dso->keepalive_interval > context.keepalive_interval) {
    755             dso->keepalive_interval = context.keepalive_interval;
    756         }
    757         if (dso->inactivity_timeout > context.inactivity_timeout) {
    758             dso->inactivity_timeout = context.inactivity_timeout;
    759         }
    760         if (dso->cb) {
    761             dso->cb(dso->context, &context, dso, kDSOEventType_KeepaliveRcvd);
    762         }
    763         // Client does not send response.
    764     }
    765 }
    766 
    767 // We received a DSO message; validate it, parse it and, if implemented, dispatch it.
    768 void dso_message_received(dso_state_t *dso, const uint8_t *message, size_t message_length, void *context)
    769 {
    770     int i;
    771     size_t offset;
    772     const DNSMessageHeader *header = (const DNSMessageHeader *)message;
    773     int response = (header->flags.b[0] & kDNSFlag0_QR_Mask) == kDNSFlag0_QR_Response;
    774     dso_query_receive_context_t qcontext;
    775 
    776     if (message_length < 12) {
    777         LogMsg("dso_message_received: response too short: %ld bytes", (long)message_length);
    778         dso_state_cancel(dso);
    779         goto out;
    780     }
    781 
    782     // See if we have sent a message for which a response is expected.
    783     if (response) {
    784         bool expected = false;
    785 
    786         // A zero ID on a response is not permitted.
    787         if (header->id.NotAnInteger == 0) {
    788             LogMsg("dso_message_received: response with id==0 received from %s", dso->remote_name);
    789             dso_state_cancel(dso);
    790             goto out;
    791         }
    792         // It's possible for a DSO response to contain no TLVs, but if that's the case, the length
    793         // should always be twelve.
    794         if (message_length < 16 && message_length != 12) {
    795             LogMsg("dso_message_received: response with bogus length==%ld received from %s", (long)message_length, dso->remote_name);
    796             dso_state_cancel(dso);
    797             goto out;
    798         }
    799         for (i = 0; i < dso->outstanding_queries->max_outstanding_queries; i++) {
    800             if (dso->outstanding_queries->queries[i].id == header->id.NotAnInteger) {
    801                 qcontext.query_context = dso->outstanding_queries->queries[i].context;
    802                 qcontext.rcode = header->flags.b[1] & kDNSFlag1_RC_Mask;
    803                 qcontext.message_context = context;
    804 
    805                 // If we are a client, and we just got an acknowledgment, a session has been established.
    806                 if (!dso->is_server && !dso->has_session && (header->flags.b[1] & kDNSFlag1_RC_Mask) == kDNSFlag1_RC_NoErr) {
    807                     dso_session_established(dso);
    808                 }
    809                 dso->outstanding_queries->queries[i].id = 0;
    810                 dso->outstanding_queries->queries[i].context = 0;
    811                 dso->outstanding_queries->outstanding_query_count--;
    812                 if (dso->outstanding_queries->outstanding_query_count < 0) {
    813                     LogMsg("dso_message_receive: programming error: outstanding_query_count went negative.");
    814                     assert(0);
    815                 }
    816                 // If there were no TLVs, we don't need to parse them.
    817                 expected = true;
    818                 if (message_length == 12) {
    819                     dso->primary.opcode = 0;
    820                     dso->primary.length = 0;
    821                     dso->num_additls = 0;
    822                 }
    823                 break;
    824             }
    825         }
    826 
    827         // This is fatal because we've received a response to a message we didn't send, so
    828         // it's not just that we don't understand what was sent.
    829         if (!expected) {
    830             LogMsg("dso_message_received: fatal: %s sent %ld byte message, QR=1, xid=%02x%02x", dso->remote_name,
    831                    (long)message_length, header->id.b[0], header->id.b[1]);
    832             dso_state_cancel(dso);
    833             goto out;
    834         }
    835     }
    836 
    837     // Make sure that the DNS header is okay (QDCOUNT, ANCOUNT, NSCOUNT and ARCOUNT are all zero)
    838     for (i = 0; i < 4; i++) {
    839         if (message[4 + i * 2] != 0 || message[4 + i * 2 + 1] != 0) {
    840             LogMsg("dso_message_received: fatal: %s sent %ld byte DSO message, %s is nonzero",
    841                    dso->remote_name, (long)message_length,
    842                    (i == 0 ? "QDCOUNT" : (i == 1 ? "ANCOUNT" : ( i == 2 ? "NSCOUNT" : "ARCOUNT"))));
    843             dso_state_cancel(dso);
    844             goto out;
    845         }
    846     }
    847 
    848     // Check that there is space for there to be a primary TLV
    849     if (message_length < 16 && message_length != 12) {
    850         LogMsg("dso_message_received: fatal: %s sent short (%ld byte) DSO message",
    851                dso->remote_name, (long)message_length);
    852 
    853         // Short messages are a fatal error. XXX check DSO document
    854         dso_state_cancel(dso);
    855         goto out;
    856     }
    857 
    858     // If we are a server, and we don't have a session, and this is a message, then we have now established a session.
    859     if (!dso->has_session && dso->is_server && !response) {
    860         dso_session_established(dso);
    861     }
    862 
    863     // If a DSO session isn't yet established, make sure the message is a request (if is_server) or a
    864     // response (if not).
    865     if (!dso->has_session && ((dso->is_server && response) || (!dso->is_server && !response))) {
    866         LogMsg("dso_message_received: received a %s with no established session from %s",
    867                response ? "response" : "request", dso->remote_name);
    868         dso_state_cancel(dso);
    869     }
    870 
    871     // Get the primary TLV and count how many TLVs there are in total
    872     for (int k = 0; k < 2; k++) {
    873         unsigned num_additls = 0;
    874         offset = 12;
    875         while (offset < message_length) {
    876             // Get the TLV opcode
    877             const uint16_t opcode = (uint16_t)(((uint16_t)message[offset]) << 8) + message[offset + 1];
    878             // And the length
    879             const uint16_t length = (uint16_t)(((uint16_t)message[offset + 2]) << 8) + message[offset + 3];
    880 
    881             // Is there room for the contents of this TLV?
    882             if (length + offset > message_length) {
    883                 LogMsg("dso_message_received: fatal: %s: TLV (%d %ld) extends past end (%ld)",
    884                        dso->remote_name, opcode, (long)length, (long)message_length);
    885 
    886                 // Short messages are a fatal error. XXX check DSO document
    887                 dso_state_cancel(dso);
    888                 goto out;
    889             }
    890 
    891             if (k == 0) {
    892                 num_additls++;
    893             } else {
    894                 // Is this the primary TLV?
    895                 if (offset == 12) {
    896                     dso->primary.opcode = opcode;
    897                     dso->primary.length = length;
    898                     dso->primary.payload = &message[offset + 4];
    899                     dso->num_additls = 0;
    900                 } else {
    901                     if (dso->num_additls < dso->max_additls) {
    902                         dso->additl[dso->num_additls].opcode = opcode;
    903                         dso->additl[dso->num_additls].length = length;
    904                         dso->additl[dso->num_additls].payload = &message[offset + 4];
    905                         dso->num_additls++;
    906                     } else {
    907                         // XXX MAX_ADDITLS should be enough for all possible additional TLVs, so this
    908                         // XXX should never happen; if it does, maybe it's a fatal error.
    909                         LogMsg("dso_message_received: %s: ignoring additional TLV (%d %ld) in excess of %d",
    910                                dso->remote_name, opcode, (long)length, dso->max_additls);
    911                     }
    912                 }
    913             }
    914             offset += 4 + length;
    915         }
    916         if (k == 0) {
    917             if (num_additls > dso->max_additls) {
    918                 if (dso->additl != dso->additl_buf) {
    919                     mdns_free(dso->additl);
    920                 }
    921                 dso->additl = mdns_calloc(num_additls, sizeof(*dso->additl));
    922                 if (dso->additl == NULL) {
    923                     dso->additl = dso->additl_buf;
    924                     dso->max_additls = MAX_ADDITLS;
    925                 } else {
    926                     dso->max_additls = num_additls;
    927                 }
    928             }
    929         }
    930     }
    931 
    932     // Call the callback with the message or response
    933     if (dso->cb) {
    934         if (message_length != 12 && dso->primary.opcode == kDSOType_Keepalive) {
    935             dso_keepalive(dso, header, response);
    936         } else if (message_length != 12 && dso->primary.opcode == kDSOType_RetryDelay) {
    937             dso_retry_delay(dso, header);
    938         } else {
    939             if (response) {
    940                 dso->cb(dso->context, &qcontext, dso, kDSOEventType_DSOResponse);
    941             } else {
    942                 dso->cb(dso->context, context, dso, kDSOEventType_DSOMessage);
    943             }
    944         }
    945     }
    946 out:
    947     ;
    948 }
    949 
    950 // This code is currently assuming that we won't get a DNS message, but that's not true.   Fix.
    951 void dns_message_received(dso_state_t *dso, const uint8_t *message, size_t message_length, void *context)
    952 {
    953     const DNSMessageHeader *header;
    954     int opcode, response;
    955 
    956     // We can safely assume that the header is 16-bit aligned.
    957     header = (const DNSMessageHeader *)message;
    958     opcode = header->flags.b[0] & kDNSFlag0_OP_Mask;
    959     response = (header->flags.b[0] & kDNSFlag0_QR_Mask) == kDNSFlag0_QR_Response;
    960 
    961     // Validate the length of the DNS message.
    962     if (message_length < 12) {
    963         LogMsg("dns_message_received: fatal: %s sent short (%ld byte) message",
    964                dso->remote_name, (long)message_length);
    965 
    966         // Short messages are a fatal error.
    967         dso_state_cancel(dso);
    968         return;
    969     }
    970 
    971     // This is not correct for the general case.
    972     if (opcode != kDNSFlag0_OP_DSO) {
    973         LogMsg("dns_message_received: %s sent %ld byte %s, QTYPE=%d",
    974                dso->remote_name, (long)message_length, (response ? "response" : "request"), opcode);
    975         if (dso->cb) {
    976             dso->cb(dso->context, context, dso,
    977                     response ? kDSOEventType_DNSMessage : kDSOEventType_DNSResponse);
    978         }
    979     } else {
    980         dso_message_received(dso, message, message_length, context);
    981     }
    982 }
    983 
    984 const char *dso_event_type_to_string(const dso_event_type_t dso_event_type)
    985 {
    986 #define CASE_TO_STR(s) case kDSOEventType_ ## s: return (#s)
    987     switch(dso_event_type)
    988     {
    989         CASE_TO_STR(DNSMessage);
    990         CASE_TO_STR(DNSResponse);
    991         CASE_TO_STR(DSOMessage);
    992         CASE_TO_STR(Finalize);
    993         CASE_TO_STR(DSOResponse);
    994         CASE_TO_STR(Connected);
    995         CASE_TO_STR(ConnectFailed);
    996         CASE_TO_STR(Disconnected);
    997         CASE_TO_STR(ShouldReconnect);
    998         CASE_TO_STR(Inactive);
    999         CASE_TO_STR(Keepalive);
   1000         CASE_TO_STR(KeepaliveRcvd);
   1001         CASE_TO_STR(RetryDelay);
   1002         MDNS_COVERED_SWITCH_DEFAULT:
   1003             break;
   1004     }
   1005 #undef CASE_TO_STR
   1006     LogMsg("Invalid dso_event_type - dso_event_type: %d.", dso_event_type);
   1007     return "<INVALID dso_event_type>";
   1008 }
   1009 
   1010 // Local Variables:
   1011 // mode: C
   1012 // tab-width: 4
   1013 // c-file-style: "bsd"
   1014 // c-basic-offset: 4
   1015 // fill-column: 108
   1016 // indent-tabs-mode: nil
   1017 // End:
   1018