1 /* dso.c 2 * 3 * Copyright (c) 2018-2024 Apple Inc. All rights reserved. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * https://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 //************************************************************************************************************* 20 // Headers 21 22 #include <stdio.h> 23 #include <signal.h> 24 #include <stdlib.h> 25 #include <stdbool.h> 26 #include <stdlib.h> 27 #include <unistd.h> 28 #include <string.h> 29 #include <assert.h> 30 31 #include <netdb.h> // For gethostbyname() 32 #include <sys/socket.h> // For AF_INET, AF_INET6, etc. 33 #include <net/if.h> // For IF_NAMESIZE 34 #include <netinet/in.h> // For INADDR_NONE 35 #include <netinet/tcp.h> // For SOL_TCP, TCP_NOTSENT_LOWAT 36 #include <arpa/inet.h> // For inet_addr() 37 #include <unistd.h> 38 #include <errno.h> 39 #include <fcntl.h> 40 41 #include "DNSCommon.h" 42 #include "mDNSEmbeddedAPI.h" 43 #include "PlatformCommon.h" 44 #include "dso.h" 45 #include "DebugServices.h" // For check_compile_time 46 47 #ifdef STANDALONE 48 #undef LogMsg 49 #define LogMsg INFO 50 51 #include "srp-log.h" 52 extern uint16_t srp_random16(void); 53 #define mDNSRandom(x) srp_random16() 54 #define mDNSPlatformMemAllocateClear(length) mdns_calloc(1, length) 55 #else // STANDALONE 56 57 // This is only a temporary fix to let the code in this file print unredacted logs. 58 59 #include "srp-log.h" 60 #undef FAULT 61 #undef INFO 62 #define FAULT(fmt, ...) 63 #define INFO(fmt, ...) 64 65 #endif // STANDALONE 66 67 #include "mdns_strict.h" 68 69 //************************************************************************************************************* 70 // Remaining work TODO 71 72 // - Add keepalive/inactivity timeout support 73 // - Notice if it takes a long time to get a response when establishing a session, and treat that 74 // as "DSO not supported." 75 // - TLS support 76 // - Actually use Network Framework 77 78 79 //************************************************************************************************************* 80 // Globals 81 82 // List of dso connection states that are active. Added when dso_connect_state_create() is called, removed 83 // when dso_state_cancel() is called. Removals are moved to dso_connections_needing_cleanup for cleanup during 84 // the idle loop. 85 // The list of connection states is not declared static so that the discovery proxy can access it as part of 86 // the "start-dropping-push" test. 87 dso_state_t *dso_connections; 88 static dso_state_t *dso_connections_needing_cleanup; // DSO connections that have been shut down but aren't yet freed. 89 90 dso_state_t *dso_find_by_serial(uint32_t serial) 91 { 92 dso_state_t *dsop; 93 94 for (dsop = dso_connections; dsop; dsop = dsop->next) { 95 if (dsop->serial == serial) { 96 return dsop; 97 } 98 } 99 return NULL; 100 } 101 102 // This function is called either when an error has occurred requiring the a DSO connection be 103 // canceled, or else when a connection to a DSO endpoint has been cleanly closed and is ready to be 104 // canceled for that reason. 105 106 void dso_state_cancel(dso_state_t *dso) 107 { 108 dso_state_t **dsop = &dso_connections; 109 bool status = true; 110 111 // Find dso on the list of connections. 112 while (*dsop != NULL && *dsop != dso) { 113 dsop = &(*dsop)->next; 114 } 115 116 // If we get to the end of the list without finding dso, it means that it's already 117 // been dropped. 118 if (*dsop == NULL) { 119 return; 120 } 121 122 // When the dso_state_t is canceled, its context may also need to be canceled/released/freed, so we give context a 123 // callback to do the cleaning work with dso_life_cycle_cancel state. 124 if (dso->context_callback != NULL) { 125 status = dso->context_callback(dso_life_cycle_cancel, dso->context, dso); 126 } 127 128 // If the callback returns a status of true, then we want to free the dso object in the idle loop. 129 if (status) { 130 // Remove dso from the list of active dso objects. 131 *dsop = dso->next; 132 133 // Add it to the list of dso objects needing cleanup. 134 dso->next = dso_connections_needing_cleanup; 135 dso_connections_needing_cleanup = dso; 136 } 137 } 138 139 void dso_cleanup(bool call_callbacks) 140 { 141 dso_state_t *dso, *dnext; 142 dso_activity_t *ap, *anext; 143 144 for (dso = dso_connections_needing_cleanup; dso; dso = dnext) { 145 dnext = dso->next; 146 // Finalize and then free any activities. 147 for (ap = dso->activities; ap; ap = anext) { 148 anext = ap->next; 149 if (ap->finalize) { 150 ap->finalize(ap); 151 } 152 mdns_free(ap); 153 } 154 if (dso->transport != NULL && dso->transport_finalize != NULL) { 155 dso->transport_finalize(dso->transport, "dso_idle"); 156 dso->transport = NULL; 157 } 158 LogMsg("[DSO%u] dso_state_t finalizing - " 159 "dso: %p, remote name: %s, dso->context: %p", dso->serial, dso, dso->remote_name, dso->context); 160 if (dso->cb && call_callbacks) { 161 // Because dso->context is the DNSPushServer that uses the current dso_state_t *dso 162 // (server->connection) and the server has been canceled by CancelDNSPushServer(), the 163 // current dso is not used and cannot be recovered (or reconnected). The only thing we can do is to finalize 164 // it. 165 dso->cb(dso->context, NULL, dso, kDSOEventType_Finalize); 166 } else { 167 if (dso->additl != dso->additl_buf) { 168 mdns_free(dso->additl); 169 } 170 mdns_free(dso); 171 } 172 // Do not touch dso after this point, because it has been freed. 173 } 174 dso_connections_needing_cleanup = NULL; 175 } 176 177 int32_t dso_idle(void *context, int32_t now, int32_t next_timer_event) 178 { 179 dso_state_t *dso, *dnext; 180 181 dso_cleanup(true); 182 183 // Do keepalives. 184 for (dso = dso_connections; dso; dso = dnext) { 185 dnext = dso->next; 186 if (dso->inactivity_due == 0) { 187 if (dso->inactivity_timeout != 0) { 188 dso->inactivity_due = NonZeroTime(now + (event_time_t)MIN(dso->inactivity_timeout, INT32_MAX)); 189 if (next_timer_event - dso->inactivity_due > 0) { 190 next_timer_event = dso->inactivity_due; 191 } 192 } 193 } else if (now - dso->inactivity_due > 0 && dso->cb != NULL) { 194 dso->cb(dso->context, 0, dso, kDSOEventType_Inactive); 195 // Should not touch the current dso_state_t after we deliver kDSOEventType_Inactive event, because it is 196 // possible that the current dso_state_t has been canceled in the callback. Doing any operation to update 197 // its status will not work as expected. 198 continue; 199 } 200 if (dso->keepalive_due != 0 && dso->keepalive_due - now < 0 && dso->cb != NULL) { 201 dso_keepalive_context_t kc; 202 memset(&kc, 0, sizeof kc); 203 dso->cb(dso->context, &kc, dso, kDSOEventType_Keepalive); 204 dso->keepalive_due = NonZeroTime(now + (event_time_t)MIN(dso->keepalive_interval, INT32_MAX)); 205 if (next_timer_event - dso->keepalive_due > 0) { 206 next_timer_event = dso->keepalive_due; 207 } 208 } 209 } 210 return dso_transport_idle(context, now, next_timer_event); 211 } 212 213 void dso_set_event_context(dso_state_t *dso, void *context) 214 { 215 dso->context = context; 216 } 217 218 void dso_set_life_cycle_callback(dso_state_t *dso, dso_life_cycle_context_callback_t callback) 219 { 220 dso->context_callback = callback; 221 } 222 223 void dso_set_event_callback(dso_state_t *dso, dso_event_callback_t callback) 224 { 225 dso->cb = callback; 226 } 227 228 // Called when something happens that establishes a DSO session. 229 static void dso_session_established(dso_state_t *dso) 230 { 231 LogMsg("[DSO%u] DSO session established - dso: %p, remote name: %s.", dso->serial, dso, dso->remote_name); 232 dso->has_session = true; 233 // Set up inactivity timer and keepalive timer... 234 } 235 236 // Create a dso_state_t structure 237 dso_state_t *dso_state_create(bool is_server, int max_outstanding_queries, const char *remote_name, 238 dso_event_callback_t callback, void *const context, 239 const dso_life_cycle_context_callback_t context_callback, 240 dso_transport_t *transport) 241 { 242 dso_state_t *dso; 243 size_t namelen = strlen(remote_name); 244 size_t namespace = namelen + 1; 245 const size_t outsize = (sizeof (dso_outstanding_query_state_t)) + (size_t)max_outstanding_queries * sizeof (dso_outstanding_query_t); 246 247 if ((sizeof (*dso) + outsize + namespace) > UINT_MAX) { 248 FAULT("Fatal: sizeof (*dso)[%zd], outsize[%zd], " 249 "namespace[%zd]", sizeof (*dso), outsize, namespace); 250 dso = NULL; 251 goto out; 252 } 253 // We allocate everything in a single hunk so that we can free it together as well. 254 dso = (dso_state_t *) mDNSPlatformMemAllocateClear((uint32_t)((sizeof (*dso)) + outsize + namespace)); 255 if (dso == NULL) { 256 goto out; 257 } 258 dso->outstanding_queries = (dso_outstanding_query_state_t *)(dso + 1); 259 dso->outstanding_queries->max_outstanding_queries = max_outstanding_queries; 260 261 dso->remote_name = ((char *)dso->outstanding_queries) + outsize; 262 memcpy(dso->remote_name, remote_name, namelen); 263 dso->remote_name[namelen] = 0; 264 265 dso->cb = callback; 266 if (context != NULL) { 267 dso->context = context; 268 } 269 if (context_callback != NULL) { 270 dso->context_callback = context_callback; 271 // When dso_state_t is created, the context it holds may need to be reference counted, for example, to retain 272 // the context. Here we give the context a callback with dso_life_cycle_create state. 273 context_callback(dso_life_cycle_create, context, dso); 274 } 275 dso->transport = transport; 276 dso->is_server = is_server; 277 278 // Used to uniquely mark dso_state_t objects, incremented once for each dso_state_t created. 279 // DSO_STATE_INVALID_SERIAL(0) is used to identify invalid dso_state_t. 280 static uint32_t dso_state_serial = DSO_STATE_INVALID_SERIAL + 1; 281 dso->serial = dso_state_serial++; 282 283 // Set up additional additional pointer. 284 dso->additl = dso->additl_buf; 285 dso->max_additls = MAX_ADDITLS; 286 287 dso->keepalive_interval = 3600 * MSEC_PER_SEC; 288 dso->inactivity_timeout = 15 * MSEC_PER_SEC; 289 290 dso->next = dso_connections; 291 dso_connections = dso; 292 293 LogMsg("[DSO%u] New dso_state_t created - dso: %p, remote name: %s, context: %p", 294 dso->serial, dso, remote_name, context); 295 out: 296 return dso; 297 } 298 299 // Start building a TLV in an outgoing dso message. 300 void dso_start_tlv(dso_message_t *state, int opcode) 301 { 302 // Make sure there's room for the length and the TLV opcode. 303 if (state->cur + 4 >= state->max) { 304 LogMsg("dso_start_tlv called when no space in output buffer!"); 305 assert(0); 306 } 307 308 // We need to not yet have a TLV. 309 if (state->building_tlv) { 310 LogMsg("dso_start_tlv called while already building a TLV!"); 311 assert(0); 312 } 313 state->building_tlv = true; 314 state->tlv_len = 0; 315 316 // Set up the TLV header. 317 state->buf[state->cur] = (uint8_t)(opcode >> 8); 318 state->buf[state->cur + 1] = opcode & 255; 319 state->tlv_len_offset = state->cur + 2; 320 state->cur += 4; 321 } 322 323 // Add some bytes to a TLV that's being built, but don't copy them--just remember the 324 // pointer to the buffer. This is used so that when we have a message to forward, we 325 // don't copy it into the output buffer--we just use scatter/gather I/O. 326 void dso_add_tlv_bytes_no_copy(dso_message_t *state, const uint8_t *bytes, size_t len) 327 { 328 if (!state->building_tlv) { 329 LogMsg("add_tlv_bytes called when not building a TLV!"); 330 assert(0); 331 } 332 if (state->no_copy_bytes_len) { 333 LogMsg("add_tlv_bytesNoCopy called twice on the same DSO message."); 334 assert(0); 335 } 336 state->no_copy_bytes_len = len; 337 state->no_copy_bytes = bytes; 338 state->no_copy_bytes_offset = state->cur; 339 state->tlv_len += len; 340 } 341 342 // Add some bytes to a TLV that's being built. 343 void dso_add_tlv_bytes(dso_message_t *state, const uint8_t *bytes, size_t len) 344 { 345 if (!state->building_tlv) { 346 LogMsg("add_tlv_bytes called when not building a TLV!"); 347 assert(0); 348 } 349 if (state->cur + len > state->max) { 350 LogMsg("add_tlv_bytes called with no room in output buffer."); 351 assert(0); 352 } 353 memcpy(&state->buf[state->cur], bytes, len); 354 state->cur += len; 355 state->tlv_len += len; 356 } 357 358 // Add a single byte to a TLV that's being built. 359 void dso_add_tlv_byte(dso_message_t *state, uint8_t byte) 360 { 361 if (!state->building_tlv) { 362 LogMsg("dso_add_tlv_byte called when not building a TLV!"); 363 assert(0); 364 } 365 if (state->cur + 1 > state->max) { 366 LogMsg("dso_add_tlv_byte called with no room in output buffer."); 367 assert(0); 368 } 369 state->buf[state->cur++] = byte; 370 state->tlv_len++; 371 } 372 373 // Add an uint16_t to a TLV that's being built. 374 void dso_add_tlv_u16(dso_message_t *state, uint16_t u16) 375 { 376 if (!state->building_tlv) { 377 LogMsg("dso_add_tlv_u16 called when not building a TLV!"); 378 assert(0); 379 } 380 if ((state->cur + sizeof u16) > state->max) { 381 LogMsg("dso_add_tlv_u16 called with no room in output buffer."); 382 assert(0); 383 } 384 state->buf[state->cur++] = u16 >> 8; 385 state->buf[state->cur++] = u16 & 255; 386 state->tlv_len += 2; 387 } 388 389 // Add an uint32_t to a TLV that's being built. 390 void dso_add_tlv_u32(dso_message_t *state, uint32_t u32) 391 { 392 if (!state->building_tlv) { 393 LogMsg("dso_add_tlv_u32 called when not building a TLV!"); 394 assert(0); 395 } 396 if ((state->cur + sizeof u32) > state->max) { 397 LogMsg("dso_add_tlv_u32 called with no room in output buffer."); 398 assert(0); 399 } 400 state->buf[state->cur++] = u32 >> 24; 401 state->buf[state->cur++] = (u32 >> 16) & 255; 402 state->buf[state->cur++] = (u32 >> 8) & 255; 403 state->buf[state->cur++] = u32 & 255; 404 state->tlv_len += 4; 405 } 406 407 // Finish building a TLV. 408 void dso_finish_tlv(dso_message_t *state) 409 { 410 if (!state->building_tlv) { 411 LogMsg("dso_finish_tlv called when not building a TLV!"); 412 assert(0); 413 } 414 415 // A TLV can't be longer than this. 416 if (state->tlv_len > 65535) { 417 LogMsg("dso_finish_tlv was given more than 65535 bytes of TLV payload!"); 418 assert(0); 419 } 420 state->buf[state->tlv_len_offset] = (uint8_t)(state->tlv_len >> 8); 421 state->buf[state->tlv_len_offset + 1] = state->tlv_len & 255; 422 state->tlv_len = 0; 423 state->building_tlv = false; 424 } 425 426 dso_activity_t *NULLABLE dso_find_activity(dso_state_t *const NONNULL dso, const char *const NULLABLE name, 427 const char *const NONNULL activity_type, void *const NULLABLE context) 428 { 429 dso_activity_t *activity; 430 431 // If we haven't been given something to search for, don't search. 432 if (name == NULL && context == NULL) { 433 FAULT("[DSO%u] Cannot search for activity with name and context both equal to NULL - " 434 "activity_type: " PUB_S_SRP ".", dso->serial, activity_type); 435 activity = NULL; 436 goto exit; 437 } 438 439 for (activity = dso->activities; activity != NULL; activity = activity->next) { 440 if (activity->activity_type != activity_type) { 441 continue; 442 } 443 444 if (name != NULL) { 445 // If name is specified, always use the name to search for the corresponding activity, even if context is 446 // also specified. 447 if (activity->name == NULL) { 448 continue; 449 } 450 if (strcmp(name, activity->name) != 0) { 451 continue; 452 } 453 // If the name matches, the corresponding context should also match if the context is not NULL. 454 if (context != NULL && activity->context != context) { 455 FAULT("[DSO%u] The activity specified by the name does not have the expected context - " 456 "name: " PRI_S_SRP ", activity_type: " PUB_S_SRP ", context: %p.", dso->serial, name, activity_type, 457 context); 458 } 459 } else { 460 // name == NULL && context != NULL 461 // If name is not specified, use context to search for the activity. 462 if (context != activity->context) { 463 continue; 464 } 465 } 466 467 break; 468 } 469 470 exit: 471 return activity; 472 } 473 474 // Make an activity structure to hang off the DSO. 475 dso_activity_t *dso_add_activity(dso_state_t *dso, const char *name, const char *activity_type, 476 void *context, void (*finalize)(dso_activity_t *)) 477 { 478 size_t namelen = name ? strlen(name) + 1 : 0; 479 size_t len; 480 dso_activity_t *activity; 481 void *ap; 482 483 // Shouldn't add an activity that's already been added. 484 activity = dso_find_activity(dso, name, activity_type, context); 485 if (activity != NULL) { 486 FAULT("[DSO%u] Trying to add a duplicate activity - activity name: " PRI_S_SRP ", activity type: " PUB_S_SRP 487 ", activity context: %p.", dso->serial, name, activity_type, context); 488 return NULL; 489 } 490 491 len = namelen + sizeof *activity; 492 ap = mDNSPlatformMemAllocateClear((mDNSu32)len); 493 if (ap == NULL) { 494 return NULL; 495 } 496 activity = (dso_activity_t *)ap; 497 ap = (char *)ap + sizeof *activity; 498 499 // Activities can be identified either by name or by context 500 if (namelen) { 501 activity->name = ap; 502 memcpy(activity->name, name, namelen); 503 } else { 504 activity->name = NULL; 505 } 506 activity->context = context; 507 508 // Activity type is expected to be a string constant; all activities of the same type must 509 // reference the same constant, not different constants with the same contents. 510 activity->activity_type = activity_type; 511 activity->finalize = finalize; 512 513 INFO("[DSO%u] Adding a DSO activity - activity name: " PRI_S_SRP ", activity type: " PUB_S_SRP 514 ", activity context: %p.", dso->serial, activity->name, activity->activity_type, activity->context); 515 516 // Retain this activity on the list. 517 activity->next = dso->activities; 518 dso->activities = activity; 519 520 return activity; 521 } 522 523 void dso_drop_activity(dso_state_t *dso, dso_activity_t *activity) 524 { 525 dso_activity_t **app = &dso->activities; 526 bool matched = false; 527 528 // Remove this activity from the list. 529 while (*app) { 530 if (*app == activity) { 531 *app = activity->next; 532 matched = true; 533 break; 534 } else { 535 app = &((*app)->next); 536 } 537 } 538 539 // If an activity that's not on the DSO list is passed here, it's an internal consistency 540 // error that probably indicates something is corrupted. 541 if (!matched) { 542 FAULT("[DSO%u] Trying to remove an activity that is not in the list - " 543 "activity name: " PRI_S_SRP ", activity type: " PUB_S_SRP ", activity context: %p.", 544 dso->serial, activity->name, activity->activity_type, activity->context); 545 } 546 INFO("[DSO%u] Removing a DSO activity - activity name: " PRI_S_SRP ", activity type: " PUB_S_SRP 547 ", activity context: %p.", dso->serial, activity->name, activity->activity_type, activity->context); 548 549 if (activity->finalize != NULL) { 550 activity->finalize(activity); 551 } 552 mdns_free(activity); 553 } 554 555 uint32_t dso_ignore_further_responses(dso_state_t *dso, const void *const context) 556 { 557 dso_outstanding_query_state_t *midState = dso->outstanding_queries; 558 int i; 559 uint32_t disassociated_count = 0; 560 for (i = 0; i < midState->max_outstanding_queries; i++) { 561 // The query is still be outstanding, and we want to know it when it comes back, but we forget the context, 562 // which presumably is a reference to something that's going away. 563 if (midState->queries[i].context == context) { 564 midState->queries[i].context = NULL; 565 INFO("[DSO%u] Disassociate the outstanding dso query with the context - query id: 0x%x, context: %p.", 566 dso->serial, midState->queries[i].id, context); 567 disassociated_count++; 568 } 569 } 570 571 return disassociated_count; 572 } 573 574 void dso_update_outstanding_query_context(dso_state_t *const dso, const void *const old_context, 575 void *const new_context) 576 { 577 dso_outstanding_query_state_t *const states = dso->outstanding_queries; 578 for (int i = 0; i < states->max_outstanding_queries; i++) { 579 if (states->queries[i].context == old_context) { 580 states->queries[i].context = new_context; 581 } 582 } 583 } 584 585 uint32_t dso_connections_reset_outstanding_query_context(const void *const context) 586 { 587 uint32_t reset_count = 0; 588 589 if (context == NULL) { 590 goto exit; 591 } 592 593 for (dso_state_t *dso_state = dso_connections; dso_state; dso_state = dso_state->next) { 594 reset_count += dso_ignore_further_responses(dso_state, context); 595 } 596 597 exit: 598 return reset_count; 599 } 600 601 bool dso_make_message(dso_message_t *state, uint8_t *outbuf, size_t outbuf_size, dso_state_t *dso, 602 bool unidirectional, bool response, uint16_t xid, int rcode, void *callback_state) 603 { 604 DNSMessageHeader *msg_header; 605 dso_outstanding_query_state_t *midState = dso->outstanding_queries; 606 607 memset(state, 0, sizeof *state); 608 state->buf = outbuf; 609 state->max = outbuf_size; 610 611 // We need space for the TCP message length plus the DNS header. 612 if (state->max < sizeof *msg_header) { 613 LogMsg("dso_make_message: called without enough buffer space to store a DNS header!"); 614 assert(0); 615 } 616 617 // This buffer should be 16-bit aligned. 618 msg_header = (DNSMessageHeader *)state->buf; 619 620 // The DNS header for a DSO message is mostly zeroes 621 memset(msg_header, 0, sizeof *msg_header); 622 msg_header->flags.b[0] = (response ? kDNSFlag0_QR_Response : kDNSFlag0_QR_Query) | kDNSFlag0_OP_DSO; 623 624 // Servers can't send DSO messages until there's a DSO session. 625 if (dso->is_server && !dso->has_session) { 626 LogMsg("dso_make_message: FATAL: server attempting to make a DSO message with no session!"); 627 assert(0); 628 } 629 630 // Response-requiring messages need to have a message ID. Replies take the message ID from the message to which 631 // they are a reply, and also need an rcode. 632 if (response) { 633 msg_header->flags.b[1] = (uint8_t)rcode; 634 msg_header->id.NotAnInteger = xid; 635 } else if (!unidirectional) { 636 bool msg_id_ok = true; 637 uint16_t message_id; 638 int looping = 0; 639 int i, avail = -1; 640 641 // If we don't have room for another outstanding message, the caller should try 642 // again later. 643 if (midState->outstanding_query_count == midState->max_outstanding_queries) { 644 return false; 645 } 646 // Generate a random message ID. This doesn't really need to be cryptographically sound 647 // (right?) because we're encrypting the whole data stream in TLS. 648 do { 649 // This would be a surprising fluke, but let's not get killed by it. 650 if (looping++ > 1000) { 651 return false; 652 } 653 message_id = (uint16_t)mDNSRandom(UINT16_MAX); 654 msg_id_ok = true; 655 if (message_id == 0) { 656 msg_id_ok = false; 657 } else { 658 for (i = 0; i < midState->max_outstanding_queries; i++) { 659 if (midState->queries[i].id == 0 && avail == -1) { 660 avail = i; 661 } else if (midState->queries[i].id == message_id) { 662 msg_id_ok = false; 663 } 664 } 665 } 666 } while (!msg_id_ok); 667 if (avail == -1) { 668 LogMsg("dso_make_message: FATAL: no slots available even though there's supposedly space."); 669 return false; 670 } 671 midState->queries[avail].id = message_id; 672 midState->queries[avail].context = callback_state; 673 LogMsg("dso_make_message: added query xid %x into slot %x, context %p", message_id, avail, callback_state); 674 midState->outstanding_query_count++; 675 msg_header->id.NotAnInteger = message_id; 676 state->outstanding_query_number = avail; 677 } else { 678 // Clients aren't allowed to send unidirectional messages until there's a session. 679 if (!dso->has_session) { 680 LogMsg("dso_make_message: FATAL: client making a DSO unidirectional message with no session!"); 681 assert(0); 682 } 683 state->outstanding_query_number = -1; 684 } 685 686 state->cur = sizeof *msg_header; 687 return true; 688 } 689 690 size_t dso_message_length(dso_message_t *state) 691 { 692 return state->cur + state->no_copy_bytes_len; 693 } 694 695 void dso_retry_delay(dso_state_t *dso, const DNSMessageHeader *header) 696 { 697 dso_disconnect_context_t context; 698 if (dso->cb) { 699 memset(&context, 0, sizeof context); 700 if (dso->primary.length != 4) { 701 LogMsg("Invalid DSO Retry Delay length %d from %s", dso->primary.length, dso->remote_name); 702 dso_send_formerr(dso, header); 703 return; 704 } 705 memcpy(&context, dso->primary.payload, dso->primary.length); 706 context.reconnect_delay = ntohl(context.reconnect_delay); 707 dso->cb(dso->context, &context, dso, kDSOEventType_RetryDelay); 708 } 709 } 710 711 void dso_keepalive(dso_state_t *dso, const DNSMessageHeader *header, bool response) 712 { 713 dso_keepalive_context_t context; 714 memset(&context, 0, sizeof context); 715 if (dso->primary.length != 8) { 716 LogMsg("Invalid DSO Keepalive length %d from %s", dso->primary.length, dso->remote_name); 717 if (dso->is_server) { 718 dso_send_formerr(dso, header); 719 } 720 return; 721 } 722 if (dso->is_server && response) { 723 LogMsg("Dropping Keepalive Response received by DSO server"); 724 return; 725 } 726 727 memcpy(&context, dso->primary.payload, dso->primary.length); 728 context.inactivity_timeout = ntohl(context.inactivity_timeout); 729 context.keepalive_interval = ntohl(context.keepalive_interval); 730 context.xid = header->id.NotAnInteger; 731 context.send_response = true; 732 if (context.inactivity_timeout > FutureTime || context.keepalive_interval > FutureTime) { 733 LogMsg("[DSO%u] inactivity_timeoutl[%u] keepalive_interva[%u] is unreasonably large.", 734 dso->serial, context.inactivity_timeout, context.keepalive_interval); 735 if (dso->is_server) { 736 dso_send_formerr(dso, header); 737 } 738 return; 739 } 740 if (dso->is_server) { 741 if (dso->cb) { 742 if (dso->keepalive_interval < context.keepalive_interval) { 743 context.keepalive_interval = dso->keepalive_interval; 744 } 745 if (dso->inactivity_timeout < context.inactivity_timeout) { 746 context.inactivity_timeout = dso->inactivity_timeout; 747 } 748 dso->cb(dso->context, &context, dso, kDSOEventType_KeepaliveRcvd); 749 } 750 if (context.send_response) { 751 dso_send_simple_response(dso, kDNSFlag1_RC_NoErr, header, "No Error"); 752 } 753 } else { 754 if (dso->keepalive_interval > context.keepalive_interval) { 755 dso->keepalive_interval = context.keepalive_interval; 756 } 757 if (dso->inactivity_timeout > context.inactivity_timeout) { 758 dso->inactivity_timeout = context.inactivity_timeout; 759 } 760 if (dso->cb) { 761 dso->cb(dso->context, &context, dso, kDSOEventType_KeepaliveRcvd); 762 } 763 // Client does not send response. 764 } 765 } 766 767 // We received a DSO message; validate it, parse it and, if implemented, dispatch it. 768 void dso_message_received(dso_state_t *dso, const uint8_t *message, size_t message_length, void *context) 769 { 770 int i; 771 size_t offset; 772 const DNSMessageHeader *header = (const DNSMessageHeader *)message; 773 int response = (header->flags.b[0] & kDNSFlag0_QR_Mask) == kDNSFlag0_QR_Response; 774 dso_query_receive_context_t qcontext; 775 776 if (message_length < 12) { 777 LogMsg("dso_message_received: response too short: %ld bytes", (long)message_length); 778 dso_state_cancel(dso); 779 goto out; 780 } 781 782 // See if we have sent a message for which a response is expected. 783 if (response) { 784 bool expected = false; 785 786 // A zero ID on a response is not permitted. 787 if (header->id.NotAnInteger == 0) { 788 LogMsg("dso_message_received: response with id==0 received from %s", dso->remote_name); 789 dso_state_cancel(dso); 790 goto out; 791 } 792 // It's possible for a DSO response to contain no TLVs, but if that's the case, the length 793 // should always be twelve. 794 if (message_length < 16 && message_length != 12) { 795 LogMsg("dso_message_received: response with bogus length==%ld received from %s", (long)message_length, dso->remote_name); 796 dso_state_cancel(dso); 797 goto out; 798 } 799 for (i = 0; i < dso->outstanding_queries->max_outstanding_queries; i++) { 800 if (dso->outstanding_queries->queries[i].id == header->id.NotAnInteger) { 801 qcontext.query_context = dso->outstanding_queries->queries[i].context; 802 qcontext.rcode = header->flags.b[1] & kDNSFlag1_RC_Mask; 803 qcontext.message_context = context; 804 805 // If we are a client, and we just got an acknowledgment, a session has been established. 806 if (!dso->is_server && !dso->has_session && (header->flags.b[1] & kDNSFlag1_RC_Mask) == kDNSFlag1_RC_NoErr) { 807 dso_session_established(dso); 808 } 809 dso->outstanding_queries->queries[i].id = 0; 810 dso->outstanding_queries->queries[i].context = 0; 811 dso->outstanding_queries->outstanding_query_count--; 812 if (dso->outstanding_queries->outstanding_query_count < 0) { 813 LogMsg("dso_message_receive: programming error: outstanding_query_count went negative."); 814 assert(0); 815 } 816 // If there were no TLVs, we don't need to parse them. 817 expected = true; 818 if (message_length == 12) { 819 dso->primary.opcode = 0; 820 dso->primary.length = 0; 821 dso->num_additls = 0; 822 } 823 break; 824 } 825 } 826 827 // This is fatal because we've received a response to a message we didn't send, so 828 // it's not just that we don't understand what was sent. 829 if (!expected) { 830 LogMsg("dso_message_received: fatal: %s sent %ld byte message, QR=1, xid=%02x%02x", dso->remote_name, 831 (long)message_length, header->id.b[0], header->id.b[1]); 832 dso_state_cancel(dso); 833 goto out; 834 } 835 } 836 837 // Make sure that the DNS header is okay (QDCOUNT, ANCOUNT, NSCOUNT and ARCOUNT are all zero) 838 for (i = 0; i < 4; i++) { 839 if (message[4 + i * 2] != 0 || message[4 + i * 2 + 1] != 0) { 840 LogMsg("dso_message_received: fatal: %s sent %ld byte DSO message, %s is nonzero", 841 dso->remote_name, (long)message_length, 842 (i == 0 ? "QDCOUNT" : (i == 1 ? "ANCOUNT" : ( i == 2 ? "NSCOUNT" : "ARCOUNT")))); 843 dso_state_cancel(dso); 844 goto out; 845 } 846 } 847 848 // Check that there is space for there to be a primary TLV 849 if (message_length < 16 && message_length != 12) { 850 LogMsg("dso_message_received: fatal: %s sent short (%ld byte) DSO message", 851 dso->remote_name, (long)message_length); 852 853 // Short messages are a fatal error. XXX check DSO document 854 dso_state_cancel(dso); 855 goto out; 856 } 857 858 // If we are a server, and we don't have a session, and this is a message, then we have now established a session. 859 if (!dso->has_session && dso->is_server && !response) { 860 dso_session_established(dso); 861 } 862 863 // If a DSO session isn't yet established, make sure the message is a request (if is_server) or a 864 // response (if not). 865 if (!dso->has_session && ((dso->is_server && response) || (!dso->is_server && !response))) { 866 LogMsg("dso_message_received: received a %s with no established session from %s", 867 response ? "response" : "request", dso->remote_name); 868 dso_state_cancel(dso); 869 } 870 871 // Get the primary TLV and count how many TLVs there are in total 872 for (int k = 0; k < 2; k++) { 873 unsigned num_additls = 0; 874 offset = 12; 875 while (offset < message_length) { 876 // Get the TLV opcode 877 const uint16_t opcode = (uint16_t)(((uint16_t)message[offset]) << 8) + message[offset + 1]; 878 // And the length 879 const uint16_t length = (uint16_t)(((uint16_t)message[offset + 2]) << 8) + message[offset + 3]; 880 881 // Is there room for the contents of this TLV? 882 if (length + offset > message_length) { 883 LogMsg("dso_message_received: fatal: %s: TLV (%d %ld) extends past end (%ld)", 884 dso->remote_name, opcode, (long)length, (long)message_length); 885 886 // Short messages are a fatal error. XXX check DSO document 887 dso_state_cancel(dso); 888 goto out; 889 } 890 891 if (k == 0) { 892 num_additls++; 893 } else { 894 // Is this the primary TLV? 895 if (offset == 12) { 896 dso->primary.opcode = opcode; 897 dso->primary.length = length; 898 dso->primary.payload = &message[offset + 4]; 899 dso->num_additls = 0; 900 } else { 901 if (dso->num_additls < dso->max_additls) { 902 dso->additl[dso->num_additls].opcode = opcode; 903 dso->additl[dso->num_additls].length = length; 904 dso->additl[dso->num_additls].payload = &message[offset + 4]; 905 dso->num_additls++; 906 } else { 907 // XXX MAX_ADDITLS should be enough for all possible additional TLVs, so this 908 // XXX should never happen; if it does, maybe it's a fatal error. 909 LogMsg("dso_message_received: %s: ignoring additional TLV (%d %ld) in excess of %d", 910 dso->remote_name, opcode, (long)length, dso->max_additls); 911 } 912 } 913 } 914 offset += 4 + length; 915 } 916 if (k == 0) { 917 if (num_additls > dso->max_additls) { 918 if (dso->additl != dso->additl_buf) { 919 mdns_free(dso->additl); 920 } 921 dso->additl = mdns_calloc(num_additls, sizeof(*dso->additl)); 922 if (dso->additl == NULL) { 923 dso->additl = dso->additl_buf; 924 dso->max_additls = MAX_ADDITLS; 925 } else { 926 dso->max_additls = num_additls; 927 } 928 } 929 } 930 } 931 932 // Call the callback with the message or response 933 if (dso->cb) { 934 if (message_length != 12 && dso->primary.opcode == kDSOType_Keepalive) { 935 dso_keepalive(dso, header, response); 936 } else if (message_length != 12 && dso->primary.opcode == kDSOType_RetryDelay) { 937 dso_retry_delay(dso, header); 938 } else { 939 if (response) { 940 dso->cb(dso->context, &qcontext, dso, kDSOEventType_DSOResponse); 941 } else { 942 dso->cb(dso->context, context, dso, kDSOEventType_DSOMessage); 943 } 944 } 945 } 946 out: 947 ; 948 } 949 950 // This code is currently assuming that we won't get a DNS message, but that's not true. Fix. 951 void dns_message_received(dso_state_t *dso, const uint8_t *message, size_t message_length, void *context) 952 { 953 const DNSMessageHeader *header; 954 int opcode, response; 955 956 // We can safely assume that the header is 16-bit aligned. 957 header = (const DNSMessageHeader *)message; 958 opcode = header->flags.b[0] & kDNSFlag0_OP_Mask; 959 response = (header->flags.b[0] & kDNSFlag0_QR_Mask) == kDNSFlag0_QR_Response; 960 961 // Validate the length of the DNS message. 962 if (message_length < 12) { 963 LogMsg("dns_message_received: fatal: %s sent short (%ld byte) message", 964 dso->remote_name, (long)message_length); 965 966 // Short messages are a fatal error. 967 dso_state_cancel(dso); 968 return; 969 } 970 971 // This is not correct for the general case. 972 if (opcode != kDNSFlag0_OP_DSO) { 973 LogMsg("dns_message_received: %s sent %ld byte %s, QTYPE=%d", 974 dso->remote_name, (long)message_length, (response ? "response" : "request"), opcode); 975 if (dso->cb) { 976 dso->cb(dso->context, context, dso, 977 response ? kDSOEventType_DNSMessage : kDSOEventType_DNSResponse); 978 } 979 } else { 980 dso_message_received(dso, message, message_length, context); 981 } 982 } 983 984 const char *dso_event_type_to_string(const dso_event_type_t dso_event_type) 985 { 986 #define CASE_TO_STR(s) case kDSOEventType_ ## s: return (#s) 987 switch(dso_event_type) 988 { 989 CASE_TO_STR(DNSMessage); 990 CASE_TO_STR(DNSResponse); 991 CASE_TO_STR(DSOMessage); 992 CASE_TO_STR(Finalize); 993 CASE_TO_STR(DSOResponse); 994 CASE_TO_STR(Connected); 995 CASE_TO_STR(ConnectFailed); 996 CASE_TO_STR(Disconnected); 997 CASE_TO_STR(ShouldReconnect); 998 CASE_TO_STR(Inactive); 999 CASE_TO_STR(Keepalive); 1000 CASE_TO_STR(KeepaliveRcvd); 1001 CASE_TO_STR(RetryDelay); 1002 MDNS_COVERED_SWITCH_DEFAULT: 1003 break; 1004 } 1005 #undef CASE_TO_STR 1006 LogMsg("Invalid dso_event_type - dso_event_type: %d.", dso_event_type); 1007 return "<INVALID dso_event_type>"; 1008 } 1009 1010 // Local Variables: 1011 // mode: C 1012 // tab-width: 4 1013 // c-file-style: "bsd" 1014 // c-basic-offset: 4 1015 // fill-column: 108 1016 // indent-tabs-mode: nil 1017 // End: 1018