From 38b24a07d745f2440cc7c610915b1cb8b57c784e Mon Sep 17 00:00:00 2001 From: Aaron Drew Date: Tue, 5 May 2026 16:49:47 +1000 Subject: [PATCH 1/2] Extract DoH proxy core; unify UDP/TCP listeners behind one interface The per-request lifecycle (allocate state, hand to libcurl, route response back to the originating listener, free) lived as static functions in main.c glued together via app_state_t and an is_tcp tag in request_t. The two listener implementations had no shared interface, so dispatch branched on is_tcp at every step. - dns_listener.h: transport-agnostic interface (respond/stop/destroy function pointers + transport tag for metrics). - dns_listener_udp.{c,h}, dns_listener_tcp.{c,h}: adapters. - doh_proxy.{c,h}: owns the request lifecycle, curl resolve list, and bootstrap-await/on-ready handshake. Listeners and dns_poller call into the proxy, not main.c. - dns_common.h: shared DNS protocol constants. main.c shrinks to assembly. The is_tcp boolean is gone. --- src/dns_common.h | 13 ++ src/dns_listener.h | 46 ++++ src/{dns_server_tcp.c => dns_listener_tcp.c} | 93 ++++---- src/dns_listener_tcp.h | 21 ++ src/{dns_server.c => dns_listener_udp.c} | 77 +++++-- src/dns_listener_udp.h | 18 ++ src/dns_server.h | 43 ---- src/dns_server_tcp.h | 19 -- src/doh_proxy.c | 228 +++++++++++++++++++ src/doh_proxy.h | 49 ++++ src/main.c | 219 +++--------------- 11 files changed, 519 insertions(+), 307 deletions(-) create mode 100644 src/dns_common.h create mode 100644 src/dns_listener.h rename src/{dns_server_tcp.c => dns_listener_tcp.c} (87%) create mode 100644 src/dns_listener_tcp.h rename src/{dns_server.c => dns_listener_udp.c} (82%) create mode 100644 src/dns_listener_udp.h delete mode 100644 src/dns_server.h delete mode 100755 src/dns_server_tcp.h create mode 100644 src/doh_proxy.c create mode 100644 src/doh_proxy.h diff --git a/src/dns_common.h b/src/dns_common.h new file mode 100644 index 0000000..e249ab0 --- /dev/null +++ b/src/dns_common.h @@ -0,0 +1,13 @@ +#ifndef _DNS_COMMON_H_ +#define _DNS_COMMON_H_ + +// Constants from the DNS wire format, shared by both UDP and TCP listeners +// and by the DoH proxy core. + +enum { + DNS_HEADER_LENGTH = 12, // RFC1035 4.1.1 header size + DNS_SIZE_LIMIT = 512, // RFC1035 4.2.1 traditional UDP payload limit + DNS_REQUEST_BUFFER_SIZE = 4096 // EDNS default before DNS Flag Day 2020 +}; + +#endif // _DNS_COMMON_H_ diff --git a/src/dns_listener.h b/src/dns_listener.h new file mode 100644 index 0000000..7ec12f2 --- /dev/null +++ b/src/dns_listener.h @@ -0,0 +1,46 @@ +#ifndef _DNS_LISTENER_H_ +#define _DNS_LISTENER_H_ + +#include +#include + +// A DNS listener accepts requests on some transport (UDP, TCP, ...) and routes +// responses back to the originating peer. The proxy core never branches on +// transport — all listener-specific behaviour (datagram-vs-stream framing, +// EDNS-aware truncation, per-client state) lives behind the function pointers +// on dns_listener. + +typedef struct dns_listener dns_listener_t; + +// Transport classification, exposed for callers that need to bin metrics or +// log per-transport — not used for dispatch (dispatch is via the function +// pointers on dns_listener itself). +typedef enum { + DNS_TRANSPORT_UDP, + DNS_TRANSPORT_TCP, +} dns_transport_t; + +// Invoked once per fully-received DNS request. `dns_req` is heap-allocated +// and ownership transfers to the callee. `listener` is a back-pointer the +// callee uses later to deliver the matching response. +typedef void (*dns_request_fn)(void *ctx, dns_listener_t *listener, + struct sockaddr *raddr, + char *dns_req, size_t dns_req_len); + +struct dns_listener { + // Send `dns_resp` to `raddr`. UDP listeners may EDNS-truncate the response + // in place using `dns_req`; TCP listeners ignore the request bytes. + void (*respond)(dns_listener_t *self, struct sockaddr *raddr, + const char *dns_req, size_t dns_req_len, + char *dns_resp, size_t dns_resp_len); + // Stop accepting new requests. Existing per-client state (TCP) is retained + // so any in-flight DoH responses can still be delivered during graceful + // drain. + void (*stop)(dns_listener_t *self); + // Free the listener and any owned resources. + void (*destroy)(dns_listener_t *self); + + dns_transport_t transport; +}; + +#endif // _DNS_LISTENER_H_ diff --git a/src/dns_server_tcp.c b/src/dns_listener_tcp.c similarity index 87% rename from src/dns_server_tcp.c rename to src/dns_listener_tcp.c index 0346eb9..c7e5e23 100644 --- a/src/dns_server_tcp.c +++ b/src/dns_listener_tcp.c @@ -1,11 +1,16 @@ //NOLINTNEXTLINE(bugprone-reserved-identifier,cert-dcl37-c,cert-dcl51-cpp) #define _GNU_SOURCE // needed for having accept4() +#include #include #include +#include +#include +#include #include -#include "dns_server_tcp.h" +#include "dns_common.h" +#include "dns_listener_tcp.h" #include "logging.h" // Platform compatibility @@ -33,8 +38,10 @@ enum { TCP_DNS_MAX_PAYLOAD = UINT16_MAX - sizeof(uint16_t), // Max after 2-byte length prefix }; +typedef struct dns_listener_tcp_s dns_listener_tcp_t; + struct tcp_client_s { - struct dns_server_tcp_s * d; + dns_listener_tcp_t * d; uint64_t id; int sock; @@ -52,10 +59,12 @@ struct tcp_client_s { struct tcp_client_s * next; } __attribute__((packed)) __attribute__((aligned(128))); -struct dns_server_tcp_s { +struct dns_listener_tcp_s { + dns_listener_t base; + struct ev_loop *loop; - dns_req_received_cb cb; + dns_request_fn cb; void *cb_data; int sock; @@ -70,7 +79,7 @@ struct dns_server_tcp_s { static void remove_client(struct tcp_client_s * client) { - dns_server_tcp_t *d = client->d; + dns_listener_tcp_t *d = client->d; DLOG_CLIENT("Removing client, socket %d", client->sock); @@ -114,7 +123,7 @@ static int get_dns_request(struct tcp_client_s *client, return 0; // Partial request } // copy whole request - *dns_req = (char *)malloc(*req_size); // To free buffer after https request is complete. + *dns_req = (char *)malloc(*req_size); // freed when DoH request completes if (*dns_req == NULL) { FLOG_CLIENT("Out of mem"); } @@ -128,7 +137,7 @@ static int get_dns_request(struct tcp_client_s *client, static void read_cb(struct ev_loop __attribute__((unused)) *loop, ev_io *w, int __attribute__((unused)) revents) { struct tcp_client_s *client = (struct tcp_client_s *)w->data; - dns_server_tcp_t *d = client->d; + dns_listener_tcp_t *d = client->d; // Receive data char buf[DNS_REQUEST_BUFFER_SIZE]; // if there would be more data, callback will be called again @@ -191,7 +200,7 @@ static void read_cb(struct ev_loop __attribute__((unused)) *loop, return; } - d->cb(d, 1, d->cb_data, (struct sockaddr*)&client->raddr, dns_req, req_size); + d->cb(d->cb_data, &d->base, (struct sockaddr*)&client->raddr, dns_req, req_size); request_received = 1; } @@ -209,7 +218,7 @@ static void timer_cb(struct ev_loop __attribute__((unused)) *loop, static void accept_cb(struct ev_loop __attribute__((unused)) *loop, ev_io *w, int __attribute__((unused)) revents) { - dns_server_tcp_t *d = (dns_server_tcp_t *)w->data; + dns_listener_tcp_t *d = (dns_listener_tcp_t *)w->data; struct sockaddr_storage client_addr; socklen_t client_addr_len = sizeof(client_addr); @@ -258,7 +267,7 @@ static void accept_cb(struct ev_loop __attribute__((unused)) *loop, DLOG_CLIENT("Accepted client %u of %u, socket %d", d->client_count, d->client_limit, client->sock); } -// Creates and bind a listening non-blocking TCP socket for incoming requests. +// Creates and binds a listening non-blocking TCP socket for incoming requests. static int get_tcp_listen_sock(struct addrinfo *listen_addrinfo) { int sock = socket(listen_addrinfo->ai_family, SOCK_STREAM, 0); if (sock < 0) { @@ -308,33 +317,12 @@ static int get_tcp_listen_sock(struct addrinfo *listen_addrinfo) { return sock; } -dns_server_tcp_t * dns_server_tcp_create( - struct ev_loop *loop, struct addrinfo *listen_addrinfo, - dns_req_received_cb cb, void *data, uint16_t tcp_client_limit) { - dns_server_tcp_t * d = (dns_server_tcp_t *) malloc(sizeof(dns_server_tcp_t)); - if (d == NULL) { - FLOG("Out of mem"); - } - d->loop = loop; - d->cb = cb; - d->cb_data = data; - d->sock = get_tcp_listen_sock(listen_addrinfo); - d->addrlen = listen_addrinfo->ai_addrlen; - d->client_id = 0; - d->client_count = 0; - d->client_limit = tcp_client_limit; - d->clients = NULL; - - ev_io_init(&d->accept_watcher, accept_cb, d->sock, EV_READ); - d->accept_watcher.data = d; - ev_io_start(d->loop, &d->accept_watcher); - - return d; -} +static void tcp_respond(dns_listener_t *self, struct sockaddr *raddr, + const char __attribute__((unused)) *dns_req, + size_t __attribute__((unused)) dns_req_len, + char *resp, size_t resp_len) { + dns_listener_tcp_t *d = (dns_listener_tcp_t *)self; -void dns_server_tcp_respond(dns_server_tcp_t *d, - struct sockaddr *raddr, char *resp, size_t resp_len) -{ // Limit response size to prevent overflow when accounting for the 2-byte // length prefix. The total on-wire size would be resp_len + sizeof(uint16_t). if (resp_len < DNS_HEADER_LENGTH || resp_len > TCP_DNS_MAX_PAYLOAD) { @@ -402,13 +390,42 @@ void dns_server_tcp_respond(dns_server_tcp_t *d, ev_timer_again(d->loop, &client->timer_watcher); } -void dns_server_tcp_stop(dns_server_tcp_t *d) { +static void tcp_stop(dns_listener_t *self) { + dns_listener_tcp_t *d = (dns_listener_tcp_t *)self; while (d->clients) { remove_client(d->clients); //NOLINT(clang-analyzer-unix.Malloc) false use after free detection } ev_io_stop(d->loop, &d->accept_watcher); } -void dns_server_tcp_cleanup(dns_server_tcp_t *d) { +static void tcp_destroy(dns_listener_t *self) { + dns_listener_tcp_t *d = (dns_listener_tcp_t *)self; close(d->sock); + free(d); +} + +dns_listener_t * dns_tcp_listener_create(struct ev_loop *loop, + struct addrinfo *listen_addrinfo, + uint16_t client_limit, + dns_request_fn cb, void *ctx) { + dns_listener_tcp_t * d = (dns_listener_tcp_t *)calloc(1, sizeof(dns_listener_tcp_t)); + if (d == NULL) { + FLOG("Out of mem"); + } + d->base.respond = tcp_respond; + d->base.stop = tcp_stop; + d->base.destroy = tcp_destroy; + d->base.transport = DNS_TRANSPORT_TCP; + d->loop = loop; + d->cb = cb; + d->cb_data = ctx; + d->sock = get_tcp_listen_sock(listen_addrinfo); + d->addrlen = listen_addrinfo->ai_addrlen; + d->client_limit = client_limit; + + ev_io_init(&d->accept_watcher, accept_cb, d->sock, EV_READ); + d->accept_watcher.data = d; + ev_io_start(d->loop, &d->accept_watcher); + + return &d->base; } diff --git a/src/dns_listener_tcp.h b/src/dns_listener_tcp.h new file mode 100644 index 0000000..7a78bf0 --- /dev/null +++ b/src/dns_listener_tcp.h @@ -0,0 +1,21 @@ +#ifndef _DNS_LISTENER_TCP_H_ +#define _DNS_LISTENER_TCP_H_ + +#include +#include +#include + +#include "dns_listener.h" + +// Create a TCP DNS listener bound to `listen_addrinfo`. The returned listener +// implements the dns_listener_t interface; callers should treat it as such +// and use dns_listener_stop / dns_listener_destroy for lifecycle. +// +// `client_limit` caps the number of concurrent TCP clients. `cb` is invoked +// once per fully-received DNS request from any client. +dns_listener_t * dns_tcp_listener_create(struct ev_loop *loop, + struct addrinfo *listen_addrinfo, + uint16_t client_limit, + dns_request_fn cb, void *ctx); + +#endif // _DNS_LISTENER_TCP_H_ diff --git a/src/dns_server.c b/src/dns_listener_udp.c similarity index 82% rename from src/dns_server.c rename to src/dns_listener_udp.c index 3421995..71875d8 100644 --- a/src/dns_server.c +++ b/src/dns_listener_udp.c @@ -1,14 +1,29 @@ #include +#include #include #include +#include #include +#include #include -#include "dns_server.h" +#include "dns_common.h" +#include "dns_listener_udp.h" #include "logging.h" +typedef struct dns_listener_udp_s { + dns_listener_t base; -// Creates and bind a listening UDP socket for incoming requests. + struct ev_loop *loop; + int sock; + socklen_t addrlen; + ev_io watcher; + + dns_request_fn cb; + void *cb_data; +} dns_listener_udp_t; + +// Creates and binds a listening UDP socket for incoming requests. static int get_listen_sock(struct addrinfo *listen_addrinfo) { int sock = socket(listen_addrinfo->ai_family, SOCK_DGRAM, 0); if (sock < 0) { @@ -41,7 +56,7 @@ static int get_listen_sock(struct addrinfo *listen_addrinfo) { static void watcher_cb(struct ev_loop __attribute__((unused)) *loop, ev_io *w, int __attribute__((unused)) revents) { - dns_server_t *d = (dns_server_t *)w->data; + dns_listener_udp_t *d = (dns_listener_udp_t *)w->data; char tmp_buf[DNS_REQUEST_BUFFER_SIZE]; struct sockaddr_storage tmp_raddr; @@ -57,32 +72,18 @@ static void watcher_cb(struct ev_loop __attribute__((unused)) *loop, len, DNS_REQUEST_BUFFER_SIZE); return; } - if (len < DNS_HEADER_LENGTH) { WLOG("Malformed request received, too short: %d", len); return; } - char *dns_req = (char *)malloc((size_t)len); // To free buffer after https request is complete. + char *dns_req = (char *)malloc((size_t)len); // freed when DoH request completes if (dns_req == NULL) { FLOG("Out of mem"); } memcpy(dns_req, tmp_buf, (size_t)len); - d->cb(d, 0, d->cb_data, (struct sockaddr*)&tmp_raddr, dns_req, (size_t)len); -} - -void dns_server_init(dns_server_t *d, struct ev_loop *loop, - struct addrinfo *listen_addrinfo, - dns_req_received_cb cb, void *data) { - d->loop = loop; - d->sock = get_listen_sock(listen_addrinfo); - d->addrlen = listen_addrinfo->ai_addrlen; - d->cb = cb; - d->cb_data = data; - ev_io_init(&d->watcher, watcher_cb, d->sock, EV_READ); - d->watcher.data = d; - ev_io_start(d->loop, &d->watcher); + d->cb(d->cb_data, &d->base, (struct sockaddr*)&tmp_raddr, dns_req, (size_t)len); } static uint16_t get_edns_udp_size(const char *dns_req, const size_t dns_req_len) { @@ -195,8 +196,11 @@ static void truncate_dns_response(char *buf, size_t *buflen, const uint16_t size } } -void dns_server_respond(dns_server_t *d, struct sockaddr *raddr, - const char *dns_req, const size_t dns_req_len, char *dns_resp, size_t dns_resp_len) { +static void udp_respond(dns_listener_t *self, struct sockaddr *raddr, + const char *dns_req, size_t dns_req_len, + char *dns_resp, size_t dns_resp_len) { + dns_listener_udp_t *d = (dns_listener_udp_t *)self; + if (dns_resp_len < DNS_HEADER_LENGTH) { WLOG("Malformed response received, invalid length: %u", dns_resp_len); return; @@ -213,15 +217,40 @@ void dns_server_respond(dns_server_t *d, struct sockaddr *raddr, } ssize_t len = sendto(d->sock, dns_resp, dns_resp_len, 0, raddr, d->addrlen); - if(len == -1) { + if (len == -1) { DLOG("sendto failed: %s", strerror(errno)); } } -void dns_server_stop(dns_server_t *d) { +static void udp_stop(dns_listener_t *self) { + dns_listener_udp_t *d = (dns_listener_udp_t *)self; ev_io_stop(d->loop, &d->watcher); } -void dns_server_cleanup(dns_server_t *d) { +static void udp_destroy(dns_listener_t *self) { + dns_listener_udp_t *d = (dns_listener_udp_t *)self; close(d->sock); + free(d); +} + +dns_listener_t * dns_udp_listener_create(struct ev_loop *loop, + struct addrinfo *listen_addrinfo, + dns_request_fn cb, void *ctx) { + dns_listener_udp_t *d = (dns_listener_udp_t *)calloc(1, sizeof(dns_listener_udp_t)); + if (d == NULL) { + FLOG("Out of mem"); + } + d->base.respond = udp_respond; + d->base.stop = udp_stop; + d->base.destroy = udp_destroy; + d->base.transport = DNS_TRANSPORT_UDP; + d->loop = loop; + d->sock = get_listen_sock(listen_addrinfo); + d->addrlen = listen_addrinfo->ai_addrlen; + d->cb = cb; + d->cb_data = ctx; + ev_io_init(&d->watcher, watcher_cb, d->sock, EV_READ); + d->watcher.data = d; + ev_io_start(d->loop, &d->watcher); + return &d->base; } diff --git a/src/dns_listener_udp.h b/src/dns_listener_udp.h new file mode 100644 index 0000000..2fb34d4 --- /dev/null +++ b/src/dns_listener_udp.h @@ -0,0 +1,18 @@ +#ifndef _DNS_LISTENER_UDP_H_ +#define _DNS_LISTENER_UDP_H_ + +#include +#include + +#include "dns_listener.h" + +// Create a UDP DNS listener bound to `listen_addrinfo`. The returned listener +// implements the dns_listener_t interface; callers should treat it as such +// and use dns_listener_stop / dns_listener_destroy for lifecycle. +// +// `cb` is invoked once per inbound DNS request. +dns_listener_t * dns_udp_listener_create(struct ev_loop *loop, + struct addrinfo *listen_addrinfo, + dns_request_fn cb, void *ctx); + +#endif // _DNS_LISTENER_UDP_H_ diff --git a/src/dns_server.h b/src/dns_server.h deleted file mode 100644 index 0d87165..0000000 --- a/src/dns_server.h +++ /dev/null @@ -1,43 +0,0 @@ -#ifndef _DNS_SERVER_H_ -#define _DNS_SERVER_H_ - -#include -#include -#include -#include -#include -#include - -enum { - DNS_HEADER_LENGTH = 12, // RFC1035 4.1.1 header size - DNS_SIZE_LIMIT = 512, - DNS_REQUEST_BUFFER_SIZE = 4096 // EDNS default before DNS Flag Day 2020 -}; - -struct dns_server_s; - -typedef void (*dns_req_received_cb)(void *dns_server, uint8_t is_tcp, void *data, - struct sockaddr* addr, char *dns_req, size_t dns_req_len); - -typedef struct dns_server_s { - struct ev_loop *loop; - void *cb_data; - dns_req_received_cb cb; - int sock; - socklen_t addrlen; - ev_io watcher; -} dns_server_t; - -void dns_server_init(dns_server_t *d, struct ev_loop *loop, - struct addrinfo *listen_addrinfo, - dns_req_received_cb cb, void *data); - -// Sends a DNS response 'buf' of length 'blen' to 'raddr'. -void dns_server_respond(dns_server_t *d, struct sockaddr *raddr, - const char *dns_req, const size_t dns_req_len, char *dns_resp, size_t dns_resp_len); - -void dns_server_stop(dns_server_t *d); - -void dns_server_cleanup(dns_server_t *d); - -#endif // _DNS_SERVER_H_ diff --git a/src/dns_server_tcp.h b/src/dns_server_tcp.h deleted file mode 100755 index 3fb32a9..0000000 --- a/src/dns_server_tcp.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef _DNS_SERVER_TCP_H_ -#define _DNS_SERVER_TCP_H_ - -#include "dns_server.h" - -typedef struct dns_server_tcp_s dns_server_tcp_t; - -dns_server_tcp_t * dns_server_tcp_create( - struct ev_loop *loop, struct addrinfo *listen_addrinfo, - dns_req_received_cb cb, void *data, uint16_t tcp_client_limit); - -void dns_server_tcp_respond(dns_server_tcp_t *d, - struct sockaddr *raddr, char *resp, size_t resp_len); - -void dns_server_tcp_stop(dns_server_tcp_t *d); - -void dns_server_tcp_cleanup(dns_server_tcp_t *d); - -#endif // _DNS_SERVER_H_ diff --git a/src/doh_proxy.c b/src/doh_proxy.c new file mode 100644 index 0000000..f8a8fef --- /dev/null +++ b/src/doh_proxy.c @@ -0,0 +1,228 @@ +#include +#include +#include + +#include "dns_common.h" +#include "doh_proxy.h" +#include "logging.h" + +struct doh_proxy { + struct ev_loop *loop; + https_client_t *client; + const char *resolver_url; + stat_t *stat; + + // CURLOPT_RESOLVE entries (one slist node, "host:443:ip1,ip2,..."). NULL + // until the first successful resolver update arrives. + struct curl_slist *resolv; + + // True until the first successful resolver update completes. While set, + // inbound DNS requests are dropped (we'd otherwise leak through libcurl's + // fallback resolver and risk a recursion through our own listener). + uint8_t awaiting_bootstrap; + + doh_proxy_ready_fn on_ready; + void *on_ready_ctx; + uint8_t on_ready_fired; +}; + +// Per-request transient state. Lives from doh_proxy_handle_request to +// https_resp_cb, when the response (or failure) returns from libcurl. +// NOLINTNEXTLINE(altera-struct-pack-align) +typedef struct { + doh_proxy_t *proxy; + dns_listener_t *listener; + uint16_t tx_id; + ev_tstamp start_tstamp; + struct sockaddr_storage raddr; + char *dns_req; + size_t dns_req_len; +} doh_request_t; + +doh_proxy_t * doh_proxy_create(struct ev_loop *loop, + https_client_t *client, + const char *resolver_url, + stat_t *stat) { + doh_proxy_t *p = (doh_proxy_t *)calloc(1, sizeof(doh_proxy_t)); + if (p == NULL) { + FLOG("Out of mem"); + } + p->loop = loop; + p->client = client; + p->resolver_url = resolver_url; + p->stat = stat; + return p; +} + +void doh_proxy_await_bootstrap(doh_proxy_t *p) { + p->awaiting_bootstrap = 1; +} + +void doh_proxy_set_on_ready(doh_proxy_t *p, doh_proxy_ready_fn cb, void *cb_ctx) { + p->on_ready = cb; + p->on_ready_ctx = cb_ctx; +} + +static void fire_on_ready(doh_proxy_t *p) { + if (p->on_ready_fired || !p->on_ready) { + return; + } + p->on_ready_fired = 1; + p->on_ready(p->on_ready_ctx); +} + +// Returns 1 if `addr_list` is a (possibly equal, possibly proper) subset of +// `full_list`, where both are comma-separated IP literals. Used to decide +// whether a fresh poll result actually changed anything; if every IP in the +// new list is already in the old list, we skip the curl reset. +static int addr_list_reduced(const char* full_list, const char* list) { + const char *pos = list; + const char *end = list + strlen(list); + while (pos < end) { + char current[50]; + const char *comma = strchr(pos, ','); + size_t ip_len = (size_t)(comma ? comma - pos : end - pos); + if (ip_len >= sizeof(current)) { + DLOG("IP address too long: %zu bytes", ip_len); + return 1; + } + strncpy(current, pos, ip_len); + current[ip_len] = '\0'; + + const char *match_begin = strstr(full_list, current); + if (!match_begin || + !(match_begin == full_list || *(match_begin - 1) == ',') || + !(*(match_begin + ip_len) == ',' || *(match_begin + ip_len) == '\0')) { + DLOG("IP address missing: %s", current); + return 1; + } + + pos += ip_len + 1; + } + return 0; +} + +void doh_proxy_handle_resolver_update(const char *hostname, void *ctx, + const char *addr_list) { + doh_proxy_t *p = (doh_proxy_t *)ctx; + + if (addr_list == NULL) { + WLOG("DNS poll for '%s' returned no usable addresses, will retry.", hostname); + return; + } + + char buf[255 + (sizeof(":443:") - 1) + POLLER_ADDR_LIST_SIZE]; + memset(buf, 0, sizeof(buf)); + if (strlen(hostname) > 254) { FLOG("Hostname too long."); } + int ip_start = snprintf(buf, sizeof(buf) - 1, "%s:443:", hostname); + if (ip_start < 0) { + abort(); // must be impossible + } + (void)snprintf(buf + ip_start, sizeof(buf) - 1 - (uint32_t)ip_start, "%s", addr_list); + + if (p->resolv && p->resolv->data) { + char *old_addr_list = strstr(p->resolv->data, ":443:"); + if (old_addr_list) { + old_addr_list += sizeof(":443:") - 1; + if (!addr_list_reduced(addr_list, old_addr_list)) { + DLOG("DNS server IP address unchanged (%s).", buf + ip_start); + free((void*)addr_list); + return; + } + } + } + free((void*)addr_list); + DLOG("Received new DNS server IP '%s'", buf + ip_start); + curl_slist_free_all(p->resolv); + p->resolv = curl_slist_append(NULL, buf); + + // Reset libcurl: in-flight connections were aimed at the old IP, and curl + // gets confused if we leave them around with a different CURLOPT_RESOLVE. + https_client_reset(p->client); + + if (p->awaiting_bootstrap) { + p->awaiting_bootstrap = 0; + fire_on_ready(p); + } +} + +static void doh_response_cb(void *data, char *buf, size_t buflen) { + doh_request_t *req = (doh_request_t *)data; + if (req == NULL) { + FLOG("Request data is NULL (buflen: %zu)", buflen); + return; + } + doh_proxy_t *p = req->proxy; + DLOG("Received response for id: %hX, len: %zu", req->tx_id, buflen); + + if (buf != NULL) { // NULL on timeout / DNS failure / similar. + if (buflen < DNS_HEADER_LENGTH) { + WLOG("%04hX: Malformed response received, too short: %u", req->tx_id, buflen); + } else { + const uint16_t response_id = ntohs(*((uint16_t*)buf)); + if (req->tx_id != response_id) { + WLOG("DNS request and response IDs are not matching: %hX != %hX", + req->tx_id, response_id); + } else { + req->listener->respond(req->listener, (struct sockaddr*)&req->raddr, + req->dns_req, req->dns_req_len, buf, buflen); + if (p->stat) { + stat_request_end(p->stat, buflen, + ev_now(p->stat->loop) - req->start_tstamp, + req->listener->transport == DNS_TRANSPORT_TCP); + } + } + } + } + + free((void*)req->dns_req); + free(req); +} + +void doh_proxy_handle_request(void *ctx, dns_listener_t *listener, + struct sockaddr *raddr, + char *dns_req, size_t dns_req_len) { + doh_proxy_t *p = (doh_proxy_t *)ctx; + + uint16_t tx_id = ntohs(*((uint16_t*)dns_req)); + DLOG("Received request for id: %hX, len: %zu", tx_id, dns_req_len); + + if (p->awaiting_bootstrap) { + WLOG("%04hX: Query received before bootstrapping is completed, discarding.", tx_id); + free(dns_req); + return; + } + + doh_request_t *req = (doh_request_t *)calloc(1, sizeof(doh_request_t)); + if (req == NULL) { + FLOG("%04hX: Out of mem", tx_id); + } + req->proxy = p; + req->listener = listener; + req->tx_id = tx_id; + req->dns_req = dns_req; + req->dns_req_len = dns_req_len; + // raddr length depends on family; sockaddr_storage holds either. Copy what + // the address actually has, not more. + socklen_t raddr_len = (raddr->sa_family == AF_INET6) + ? sizeof(struct sockaddr_in6) + : sizeof(struct sockaddr_in); + memcpy(&req->raddr, raddr, raddr_len); + + if (p->stat) { + req->start_tstamp = ev_now(p->stat->loop); + stat_request_begin(p->stat, dns_req_len, + listener->transport == DNS_TRANSPORT_TCP); + } + + https_client_fetch(p->client, p->resolver_url, req->dns_req, dns_req_len, + p->resolv, req->tx_id, doh_response_cb, req); +} + +void doh_proxy_destroy(doh_proxy_t *p) { + if (p == NULL) { + return; + } + curl_slist_free_all(p->resolv); + free(p); +} diff --git a/src/doh_proxy.h b/src/doh_proxy.h new file mode 100644 index 0000000..b7e3a69 --- /dev/null +++ b/src/doh_proxy.h @@ -0,0 +1,49 @@ +#ifndef _DOH_PROXY_H_ +#define _DOH_PROXY_H_ + +#include + +#include "dns_listener.h" +#include "dns_poller.h" +#include "https_client.h" +#include "stat.h" + +// The DoH proxy core. Owns the per-request lifecycle (allocate state on +// inbound, hand to the HTTPS client, route the response back to the +// originating listener, free) and the curl resolve list driven by the +// bootstrap DNS poller. +typedef struct doh_proxy doh_proxy_t; + +// Optional callback invoked the first time the proxy is ready to serve +// requests (after bootstrap completes, if bootstrap was required). Called +// at most once. +typedef void (*doh_proxy_ready_fn)(void *ctx); + +doh_proxy_t * doh_proxy_create(struct ev_loop *loop, + https_client_t *client, + const char *resolver_url, + stat_t *stat); + +// Mark the proxy as awaiting bootstrap. Until the first successful resolver +// update, inbound DNS requests will be dropped (libcurl would otherwise fall +// back to gethostbyname() and may deadlock if our resolver depends on us). +void doh_proxy_await_bootstrap(doh_proxy_t *p); + +// Set a one-shot "ready" notifier (e.g. systemd_notify_ready). Fires when +// bootstrap completes, or never if await_bootstrap was never called. +void doh_proxy_set_on_ready(doh_proxy_t *p, doh_proxy_ready_fn cb, void *cb_ctx); + +// dns_request_fn — pass to dns_*_listener_create as the request callback. +// `ctx` must be a doh_proxy_t *. +void doh_proxy_handle_request(void *ctx, dns_listener_t *listener, + struct sockaddr *raddr, + char *dns_req, size_t dns_req_len); + +// dns_poller_cb — pass to dns_poller_init as the resolver-update callback. +// Takes ownership of `addr_list` (will free it). +void doh_proxy_handle_resolver_update(const char *hostname, void *ctx, + const char *addr_list); + +void doh_proxy_destroy(doh_proxy_t *p); + +#endif // _DOH_PROXY_H_ diff --git a/src/main.c b/src/main.c index 0ba7f9c..b94a9c3 100644 --- a/src/main.c +++ b/src/main.c @@ -1,6 +1,7 @@ // Simple UDP-to-HTTPS DNS Proxy // (C) 2016 Aaron Drew +#include #include #include #include @@ -13,37 +14,16 @@ #include #endif +#include "dns_listener.h" +#include "dns_listener_tcp.h" +#include "dns_listener_udp.h" #include "dns_poller.h" -#include "dns_server.h" -#include "dns_server_tcp.h" +#include "doh_proxy.h" #include "https_client.h" #include "logging.h" #include "options.h" #include "stat.h" -// Holds app state required for dns_server_cb. -// NOLINTNEXTLINE(altera-struct-pack-align) -typedef struct { - https_client_t *https_client; - struct curl_slist *resolv; - const char *resolver_url; - stat_t *stat; - uint8_t using_dns_poller; - socklen_t addrlen; -} app_state_t; - -// NOLINTNEXTLINE(altera-struct-pack-align) -typedef struct { - void *dns_server; - uint8_t is_tcp; - char* dns_req; - size_t dns_req_len; - stat_t *stat; - ev_tstamp start_tstamp; - uint16_t tx_id; - struct sockaddr_storage raddr; -} request_t; - static int is_ipv4_address(char *str) { struct in6_addr addr; return inet_pton(AF_INET, str, &addr) == 1; @@ -89,76 +69,7 @@ static void sigpipe_cb(struct ev_loop __attribute__((__unused__)) *loop, ELOG("Received SIGPIPE. Ignoring."); } -static void https_resp_cb(void *data, char *buf, size_t buflen) { - request_t *req = (request_t *)data; - if (req == NULL) { - FLOG("Request data is NULL (buflen: %zu)", buflen); - return; - } - DLOG("Received response for id: %hX, len: %zu", req->tx_id, buflen); - if (buf != NULL) { // May be NULL for timeout, DNS failure, or something similar. - if (buflen < DNS_HEADER_LENGTH) { - WLOG("%04hX: Malformed response received, too short: %u", req->tx_id, buflen); - } else { - const uint16_t response_id = ntohs(*((uint16_t*)buf)); - if (req->tx_id != response_id) { - WLOG("DNS request and response IDs are not matching: %hX != %hX", - req->tx_id, response_id); - } else { - if (req->is_tcp) { - dns_server_tcp_respond((dns_server_tcp_t *)req->dns_server, (struct sockaddr*)&req->raddr, buf, buflen); - } else { - dns_server_respond((dns_server_t *)req->dns_server, (struct sockaddr*)&req->raddr, - req->dns_req, req->dns_req_len, buf, buflen); - } - if (req->stat) { - stat_request_end(req->stat, buflen, ev_now(req->stat->loop) - req->start_tstamp, req->is_tcp); - } - } - } - } - free((void*)req->dns_req); - free(req); -} - -static void dns_server_cb(void *dns_server, uint8_t is_tcp, void *data, - struct sockaddr* tmp_remote_addr, - char *dns_req, size_t dns_req_len) { - app_state_t *app = (app_state_t *)data; - - uint16_t tx_id = ntohs(*((uint16_t*)dns_req)); - DLOG("Received request for id: %hX, len: %d", tx_id, dns_req_len); - - // If we're not yet bootstrapped, don't answer. libcurl will fall back to - // gethostbyname() which can cause a DNS loop due to the nameserver listed - // in resolv.conf being or depending on https_dns_proxy itself. - if(app->using_dns_poller && (app->resolv == NULL || app->resolv->data == NULL)) { - WLOG("%04hX: Query received before bootstrapping is completed, discarding.", tx_id); - free(dns_req); - return; - } - - request_t *req = (request_t *)calloc(1, sizeof(request_t)); - if (req == NULL) { - FLOG("%04hX: Out of mem", tx_id); - } - req->tx_id = tx_id; - memcpy(&req->raddr, tmp_remote_addr, app->addrlen); - req->dns_server = dns_server; - req->is_tcp = is_tcp; - req->dns_req = dns_req; // To free buffer after https request is complete. - req->dns_req_len = dns_req_len; - req->stat = app->stat; - - if (req->stat) { - req->start_tstamp = ev_now(app->stat->loop); - stat_request_begin(app->stat, dns_req_len, is_tcp); - } - https_client_fetch(app->https_client, app->resolver_url, - req->dns_req, dns_req_len, app->resolv, req->tx_id, https_resp_cb, req); -} - -static void systemd_notify_ready(void) { +static void systemd_notify_ready(void __attribute__((__unused__)) *unused) { #if HAS_LIBSYSTEMD == 1 static uint8_t called_once = 0; if (called_once != 0) { @@ -179,67 +90,6 @@ static void systemd_notify_ready(void) { #endif } -static int addr_list_reduced(const char* full_list, const char* list) { - const char *pos = list; - const char *end = list + strlen(list); - while (pos < end) { - char current[50]; - const char *comma = strchr(pos, ','); - size_t ip_len = (size_t)(comma ? comma - pos : end - pos); - if (ip_len >= sizeof(current)) { - DLOG("IP address too long: %zu bytes", ip_len); - return 1; - } - strncpy(current, pos, ip_len); - current[ip_len] = '\0'; - - const char *match_begin = strstr(full_list, current); - if (!match_begin || - !(match_begin == full_list || *(match_begin - 1) == ',') || - !(*(match_begin + ip_len) == ',' || *(match_begin + ip_len) == '\0')) { - DLOG("IP address missing: %s", current); - return 1; - } - - pos += ip_len + 1; - } - return 0; -} - -static void dns_poll_cb(const char* hostname, void *data, - const char* addr_list) { - app_state_t *app = (app_state_t *)data; - char buf[255 + (sizeof(":443:") - 1) + POLLER_ADDR_LIST_SIZE]; - memset(buf, 0, sizeof(buf)); - if (strlen(hostname) > 254) { FLOG("Hostname too long."); } - int ip_start = snprintf(buf, sizeof(buf) - 1, "%s:443:", hostname); - if (ip_start < 0) { - abort(); // must be impossible - } - (void)snprintf(buf + ip_start, sizeof(buf) - 1 - (uint32_t)ip_start, "%s", addr_list); - if (app->resolv == NULL) { - systemd_notify_ready(); - } - if (app->resolv && app->resolv->data) { - char * old_addr_list = strstr(app->resolv->data, ":443:"); - if (old_addr_list) { - old_addr_list += sizeof(":443:") - 1; - if (!addr_list_reduced(addr_list, old_addr_list)) { - DLOG("DNS server IP address unchanged (%s).", buf + ip_start); - free((void*)addr_list); - return; - } - } - } - free((void*)addr_list); - DLOG("Received new DNS server IP '%s'", buf + ip_start); - curl_slist_free_all(app->resolv); - app->resolv = curl_slist_append(NULL, buf); - // Resets curl or it gets in a mess due to IP of streaming connection not - // matching that of configured DNS. - https_client_reset(app->https_client); -} - static int proxy_supports_name_resolution(const char *proxy) { size_t i = 0; @@ -355,9 +205,13 @@ int main(int argc, char *argv[]) { stat_t stat; stat_init(&stat, loop, opt.stats_interval); + stat_t *stat_ptr = (opt.stats_interval ? &stat : NULL); https_client_t https_client; - https_client_init(&https_client, &opt, (opt.stats_interval ? &stat : NULL), loop); + https_client_init(&https_client, &opt, stat_ptr, loop); + + doh_proxy_t *proxy = doh_proxy_create(loop, &https_client, + opt.resolver_url, stat_ptr); struct addrinfo *listen_addrinfo = get_listen_address(opt.listen_addr); @@ -367,20 +221,15 @@ int main(int argc, char *argv[]) { ((struct sockaddr_in6*) listen_addrinfo->ai_addr)->sin6_port = htons((uint16_t)opt.listen_port); } - app_state_t app; - app.https_client = &https_client; - app.resolv = NULL; - app.resolver_url = opt.resolver_url; - app.using_dns_poller = 0; - app.stat = (opt.stats_interval ? &stat : NULL); - app.addrlen = listen_addrinfo->ai_addrlen; - - dns_server_t dns_server; - dns_server_init(&dns_server, loop, listen_addrinfo, dns_server_cb, &app); + dns_listener_t *udp_listener = + dns_udp_listener_create(loop, listen_addrinfo, + doh_proxy_handle_request, proxy); - dns_server_tcp_t * dns_server_tcp = NULL; + dns_listener_t *tcp_listener = NULL; if (opt.tcp_client_limit > 0) { - dns_server_tcp = dns_server_tcp_create(loop, listen_addrinfo, dns_server_cb, &app, (uint16_t)opt.tcp_client_limit); + tcp_listener = dns_tcp_listener_create(loop, listen_addrinfo, + (uint16_t)opt.tcp_client_limit, + doh_proxy_handle_request, proxy); } freeaddrinfo(listen_addrinfo); @@ -418,39 +267,42 @@ int main(int argc, char *argv[]) { logging_events_init(loop); dns_poller_t dns_poller; + uint8_t using_dns_poller = 0; char hostname[255] = {0}; // Domain names shouldn't exceed 253 chars. if (!proxy_supports_name_resolution(opt.curl_proxy)) { if (hostname_from_url(opt.resolver_url, hostname, sizeof(hostname))) { - app.using_dns_poller = 1; + using_dns_poller = 1; + doh_proxy_await_bootstrap(proxy); + doh_proxy_set_on_ready(proxy, systemd_notify_ready, NULL); dns_poller_init(&dns_poller, loop, opt.bootstrap_dns, opt.bootstrap_dns_polling_interval, opt.source_addr, hostname, opt.ipv4 ? AF_INET : AF_UNSPEC, - dns_poll_cb, &app); + doh_proxy_handle_resolver_update, proxy); ILOG("DNS polling initialized for '%s'", hostname); } else { ILOG("Resolver prefix '%s' doesn't appear to contain a " "hostname. DNS polling disabled.", opt.resolver_url); - - systemd_notify_ready(); + systemd_notify_ready(NULL); } + } else { + systemd_notify_ready(NULL); } ev_run(loop, 0); DLOG("loop breaked"); - if (app.using_dns_poller) { + if (using_dns_poller) { dns_poller_cleanup(&dns_poller); } - curl_slist_free_all(app.resolv); logging_events_cleanup(loop); ev_signal_stop(loop, &sigterm); ev_signal_stop(loop, &sigint); ev_signal_stop(loop, &sigpipe); - dns_server_stop(&dns_server); - if (dns_server_tcp != NULL) { - dns_server_tcp_stop(dns_server_tcp); + udp_listener->stop(udp_listener); + if (tcp_listener != NULL) { + tcp_listener->stop(tcp_listener); } stat_stop(&stat); @@ -458,13 +310,14 @@ int main(int argc, char *argv[]) { ev_run(loop, 0); DLOG("loop finished all events"); - dns_server_cleanup(&dns_server); - if (dns_server_tcp != NULL) { - dns_server_tcp_cleanup(dns_server_tcp); - free(dns_server_tcp); - dns_server_tcp = NULL; + udp_listener->destroy(udp_listener); + if (tcp_listener != NULL) { + tcp_listener->destroy(tcp_listener); } + // The CURLOPT_RESOLVE list owned by the proxy must outlive in-flight curl + // easy handles, which is why https_client_cleanup runs first. https_client_cleanup(&https_client); + doh_proxy_destroy(proxy); stat_cleanup(&stat); ev_loop_destroy(loop); From 346509a11a6c0fe131f3fa54944584eabe2e335d Mon Sep 17 00:00:00 2001 From: Aaron Drew Date: Tue, 5 May 2026 16:52:14 +1000 Subject: [PATCH 2/2] Extract DNS response truncation into its own module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ~110 lines of EDNS0/c-ares juggling (parse OPT, drop additional+authority, drop answers until under the limit, set TC) was buried as static functions in the UDP listener — only reachable by booting a real UDP server. Move to dns_truncate.{c,h} with one entry point: dns_truncate_for_udp(). The UDP listener's respond path collapses to validate-length, truncate-for-fit, sendto. The module is now unit-testable with synthetic byte buffers. --- src/dns_listener_udp.c | 123 +----------------------------------- src/dns_truncate.c | 139 +++++++++++++++++++++++++++++++++++++++++ src/dns_truncate.h | 20 ++++++ 3 files changed, 161 insertions(+), 121 deletions(-) create mode 100644 src/dns_truncate.c create mode 100644 src/dns_truncate.h diff --git a/src/dns_listener_udp.c b/src/dns_listener_udp.c index 71875d8..2cf6689 100644 --- a/src/dns_listener_udp.c +++ b/src/dns_listener_udp.c @@ -1,4 +1,3 @@ -#include #include #include #include @@ -9,6 +8,7 @@ #include "dns_common.h" #include "dns_listener_udp.h" +#include "dns_truncate.h" #include "logging.h" typedef struct dns_listener_udp_s { @@ -86,116 +86,6 @@ static void watcher_cb(struct ev_loop __attribute__((unused)) *loop, d->cb(d->cb_data, &d->base, (struct sockaddr*)&tmp_raddr, dns_req, (size_t)len); } -static uint16_t get_edns_udp_size(const char *dns_req, const size_t dns_req_len) { - ares_dns_record_t *dnsrec = NULL; - ares_status_t parse_status = ares_dns_parse((const unsigned char *)dns_req, dns_req_len, 0, &dnsrec); - if (parse_status != ARES_SUCCESS) { - WLOG("Failed to parse DNS request: %s", ares_strerror((int)parse_status)); - return DNS_SIZE_LIMIT; - } - const uint16_t tx_id = ares_dns_record_get_id(dnsrec); - uint16_t udp_size = 0; - const size_t record_count = ares_dns_record_rr_cnt(dnsrec, ARES_SECTION_ADDITIONAL); - for (size_t i = 0; i < record_count; ++i) { - const ares_dns_rr_t *rr = ares_dns_record_rr_get(dnsrec, ARES_SECTION_ADDITIONAL, i); - if (ares_dns_rr_get_type(rr) == ARES_REC_TYPE_OPT) { - udp_size = ares_dns_rr_get_u16(rr, ARES_RR_OPT_UDP_SIZE); - if (udp_size > 0) { - DLOG("%04hX: Found EDNS0 UDP buffer size: %u", tx_id, udp_size); - } - break; - } - } - ares_dns_record_destroy(dnsrec); - if (udp_size < DNS_SIZE_LIMIT) { - DLOG("%04hX: EDNS0 UDP buffer size %u overruled to %d", tx_id, udp_size, DNS_SIZE_LIMIT); - return DNS_SIZE_LIMIT; // RFC6891 4.3 "Values lower than 512 MUST be treated as equal to 512." - } - return udp_size; -} - -static void truncate_dns_response(char *buf, size_t *buflen, const uint16_t size_limit) { - const size_t old_size = *buflen; - buf[2] |= 0x02; // anyway: set truncation flag - - ares_dns_record_t *dnsrec = NULL; - ares_status_t status = ares_dns_parse((const unsigned char *)buf, *buflen, 0, &dnsrec); - if (status != ARES_SUCCESS) { - WLOG("Failed to parse DNS response: %s", ares_strerror((int)status)); - return; - } - const uint16_t tx_id = ares_dns_record_get_id(dnsrec); - - // NOTE: according to current c-ares implementation, removing first or last elements are the fastest! - - // remove every additional and authority record - while (ares_dns_record_rr_cnt(dnsrec, ARES_SECTION_ADDITIONAL) > 0) { - status = ares_dns_record_rr_del(dnsrec, ARES_SECTION_ADDITIONAL, 0); - if (status != ARES_SUCCESS) { - WLOG("%04hX: Could not remove additional record: %s", tx_id, ares_strerror((int)status)); - } - } - while (ares_dns_record_rr_cnt(dnsrec, ARES_SECTION_AUTHORITY) > 0) { - status = ares_dns_record_rr_del(dnsrec, ARES_SECTION_AUTHORITY, 0); - if (status != ARES_SUCCESS) { - WLOG("%04hX: Could not remove authority record: %s", tx_id, ares_strerror((int)status)); - } - } - - // rough estimate to reach size limit - size_t answers = ares_dns_record_rr_cnt(dnsrec, ARES_SECTION_ANSWER); - size_t answers_to_keep = ((size_limit - DNS_HEADER_LENGTH) * answers) / old_size; - answers_to_keep = answers_to_keep > 0 ? answers_to_keep : 1; // try to keep 1 answer - - // remove answer records until fit size limit or running out of answers - unsigned char *new_resp = NULL; - size_t new_resp_len = 0; - for (uint8_t g = 0; g < UINT8_MAX; ++g) { // endless loop guard - status = ares_dns_write(dnsrec, &new_resp, &new_resp_len); - if (status != ARES_SUCCESS) { - WLOG("%04hX: Failed to create truncated DNS response: %s", tx_id, ares_strerror((int)status)); - new_resp = NULL; // just to be sure - break; - } - if (new_resp_len < size_limit || answers == 0) { - break; - } - if (new_resp_len >= old_size) { - WLOG("%04hX: Truncated DNS response size larger or equal to original: %u >= %u", - tx_id, new_resp_len, old_size); // impossible? - } - ares_free_string(new_resp); - new_resp = NULL; - - DLOG("%04hX: DNS response size truncated from %u to %u but to keep %u limit reducing answers from %u to %u", - tx_id, old_size, new_resp_len, size_limit, answers, answers_to_keep); - - while (answers > answers_to_keep) { - status = ares_dns_record_rr_del(dnsrec, ARES_SECTION_ANSWER, answers - 1); - if (status != ARES_SUCCESS) { - WLOG("%04hX: Could not remove answer record: %s", tx_id, ares_strerror((int)status)); - break; - } - --answers; - } - answers = ares_dns_record_rr_cnt(dnsrec, ARES_SECTION_ANSWER); // update to be sure! - answers_to_keep /= 2; - } - ares_dns_record_destroy(dnsrec); - - if (new_resp != NULL) { - if (new_resp_len < old_size) { - memcpy(buf, new_resp, new_resp_len); - *buflen = new_resp_len; - buf[2] |= 0x02; // set truncation flag - ILOG("%04hX: DNS response size truncated from %u to %u to keep %u limit", - tx_id, old_size, new_resp_len, size_limit); - } - ares_free_string(new_resp); - new_resp = NULL; - } -} - static void udp_respond(dns_listener_t *self, struct sockaddr *raddr, const char *dns_req, size_t dns_req_len, char *dns_resp, size_t dns_resp_len) { @@ -205,16 +95,7 @@ static void udp_respond(dns_listener_t *self, struct sockaddr *raddr, WLOG("Malformed response received, invalid length: %u", dns_resp_len); return; } - if (dns_resp_len > DNS_SIZE_LIMIT) { - const uint16_t udp_size = get_edns_udp_size(dns_req, dns_req_len); - if (dns_resp_len > udp_size) { - truncate_dns_response(dns_resp, &dns_resp_len, udp_size); - } else { - uint16_t tx_id = ntohs(*((uint16_t*)dns_req)); - DLOG("%04hX: DNS response size %u larger than %d but EDNS0 UDP buffer size %u allows it", - tx_id, dns_resp_len, DNS_SIZE_LIMIT, udp_size); - } - } + dns_truncate_for_udp(dns_req, dns_req_len, dns_resp, &dns_resp_len); ssize_t len = sendto(d->sock, dns_resp, dns_resp_len, 0, raddr, d->addrlen); if (len == -1) { diff --git a/src/dns_truncate.c b/src/dns_truncate.c new file mode 100644 index 0000000..7c439e9 --- /dev/null +++ b/src/dns_truncate.c @@ -0,0 +1,139 @@ +#include +#include +#include +#include +#include + +#include "dns_common.h" +#include "dns_truncate.h" +#include "logging.h" + +// Returns the size limit the request peer is willing to accept. Reads the +// EDNS0 OPT record from the request's additional section. Falls back to the +// RFC1035 4.2.1 default of 512 if the request can't be parsed or the OPT +// advertises a smaller size (RFC6891 4.3 mandates that values below 512 +// MUST be treated as 512). +static uint16_t get_edns_udp_size(const char *dns_req, const size_t dns_req_len) { + ares_dns_record_t *dnsrec = NULL; + ares_status_t parse_status = ares_dns_parse((const unsigned char *)dns_req, dns_req_len, 0, &dnsrec); + if (parse_status != ARES_SUCCESS) { + WLOG("Failed to parse DNS request: %s", ares_strerror((int)parse_status)); + return DNS_SIZE_LIMIT; + } + const uint16_t tx_id = ares_dns_record_get_id(dnsrec); + uint16_t udp_size = 0; + const size_t record_count = ares_dns_record_rr_cnt(dnsrec, ARES_SECTION_ADDITIONAL); + for (size_t i = 0; i < record_count; ++i) { + const ares_dns_rr_t *rr = ares_dns_record_rr_get(dnsrec, ARES_SECTION_ADDITIONAL, i); + if (ares_dns_rr_get_type(rr) == ARES_REC_TYPE_OPT) { + udp_size = ares_dns_rr_get_u16(rr, ARES_RR_OPT_UDP_SIZE); + if (udp_size > 0) { + DLOG("%04hX: Found EDNS0 UDP buffer size: %u", tx_id, udp_size); + } + break; + } + } + ares_dns_record_destroy(dnsrec); + if (udp_size < DNS_SIZE_LIMIT) { + DLOG("%04hX: EDNS0 UDP buffer size %u overruled to %d", tx_id, udp_size, DNS_SIZE_LIMIT); + return DNS_SIZE_LIMIT; + } + return udp_size; +} + +static void truncate_to_size_limit(char *buf, size_t *buflen, const uint16_t size_limit) { + const size_t old_size = *buflen; + buf[2] |= 0x02; // anyway: set truncation flag + + ares_dns_record_t *dnsrec = NULL; + ares_status_t status = ares_dns_parse((const unsigned char *)buf, *buflen, 0, &dnsrec); + if (status != ARES_SUCCESS) { + WLOG("Failed to parse DNS response: %s", ares_strerror((int)status)); + return; + } + const uint16_t tx_id = ares_dns_record_get_id(dnsrec); + + // NOTE: according to current c-ares implementation, removing first or last elements are the fastest! + + // remove every additional and authority record + while (ares_dns_record_rr_cnt(dnsrec, ARES_SECTION_ADDITIONAL) > 0) { + status = ares_dns_record_rr_del(dnsrec, ARES_SECTION_ADDITIONAL, 0); + if (status != ARES_SUCCESS) { + WLOG("%04hX: Could not remove additional record: %s", tx_id, ares_strerror((int)status)); + } + } + while (ares_dns_record_rr_cnt(dnsrec, ARES_SECTION_AUTHORITY) > 0) { + status = ares_dns_record_rr_del(dnsrec, ARES_SECTION_AUTHORITY, 0); + if (status != ARES_SUCCESS) { + WLOG("%04hX: Could not remove authority record: %s", tx_id, ares_strerror((int)status)); + } + } + + // rough estimate to reach size limit + size_t answers = ares_dns_record_rr_cnt(dnsrec, ARES_SECTION_ANSWER); + size_t answers_to_keep = ((size_limit - DNS_HEADER_LENGTH) * answers) / old_size; + answers_to_keep = answers_to_keep > 0 ? answers_to_keep : 1; // try to keep 1 answer + + // remove answer records until fit size limit or running out of answers + unsigned char *new_resp = NULL; + size_t new_resp_len = 0; + for (uint8_t g = 0; g < UINT8_MAX; ++g) { // endless loop guard + status = ares_dns_write(dnsrec, &new_resp, &new_resp_len); + if (status != ARES_SUCCESS) { + WLOG("%04hX: Failed to create truncated DNS response: %s", tx_id, ares_strerror((int)status)); + new_resp = NULL; // just to be sure + break; + } + if (new_resp_len < size_limit || answers == 0) { + break; + } + if (new_resp_len >= old_size) { + WLOG("%04hX: Truncated DNS response size larger or equal to original: %u >= %u", + tx_id, new_resp_len, old_size); // impossible? + } + ares_free_string(new_resp); + new_resp = NULL; + + DLOG("%04hX: DNS response size truncated from %u to %u but to keep %u limit reducing answers from %u to %u", + tx_id, old_size, new_resp_len, size_limit, answers, answers_to_keep); + + while (answers > answers_to_keep) { + status = ares_dns_record_rr_del(dnsrec, ARES_SECTION_ANSWER, answers - 1); + if (status != ARES_SUCCESS) { + WLOG("%04hX: Could not remove answer record: %s", tx_id, ares_strerror((int)status)); + break; + } + --answers; + } + answers = ares_dns_record_rr_cnt(dnsrec, ARES_SECTION_ANSWER); // update to be sure! + answers_to_keep /= 2; + } + ares_dns_record_destroy(dnsrec); + + if (new_resp != NULL) { + if (new_resp_len < old_size) { + memcpy(buf, new_resp, new_resp_len); + *buflen = new_resp_len; + buf[2] |= 0x02; // set truncation flag + ILOG("%04hX: DNS response size truncated from %u to %u to keep %u limit", + tx_id, old_size, new_resp_len, size_limit); + } + ares_free_string(new_resp); + new_resp = NULL; + } +} + +void dns_truncate_for_udp(const char *dns_req, size_t dns_req_len, + char *resp, size_t *resp_len) { + if (*resp_len <= DNS_SIZE_LIMIT) { + return; // always fits + } + const uint16_t udp_size = get_edns_udp_size(dns_req, dns_req_len); + if (*resp_len <= udp_size) { + uint16_t tx_id = ntohs(*((uint16_t*)dns_req)); + DLOG("%04hX: DNS response size %zu larger than %d but EDNS0 UDP buffer size %u allows it", + tx_id, *resp_len, DNS_SIZE_LIMIT, udp_size); + return; + } + truncate_to_size_limit(resp, resp_len, udp_size); +} diff --git a/src/dns_truncate.h b/src/dns_truncate.h new file mode 100644 index 0000000..d474553 --- /dev/null +++ b/src/dns_truncate.h @@ -0,0 +1,20 @@ +#ifndef _DNS_TRUNCATE_H_ +#define _DNS_TRUNCATE_H_ + +#include + +// Fit a DNS response into the size limit advertised by the request. +// +// If `resp` exceeds the request's EDNS0 UDP buffer size (or RFC1035 4.2.1's +// 512-byte default when no EDNS0 OPT record is present), shrink it in place +// by dropping additional and authority records, then answer records, until +// it fits. The TC flag is set on truncation. A response that already fits +// is left untouched. +// +// Mutates `resp` and `*resp_len`. Caller retains ownership of both buffers. +// +// DNS-over-TCP has no per-message size cap and never needs this. +void dns_truncate_for_udp(const char *dns_req, size_t dns_req_len, + char *resp, size_t *resp_len); + +#endif // _DNS_TRUNCATE_H_