/*
 * path: root/dns.c
 * blob: 8c6ec32f9de795f30ed135e86625e55e3d03dab3
 */











































































































                                                                                                                 

                            

                                                                       



                            






                                                   
















                                                                           































































































                                                                                                    

                       




                                
               


















                                                                                         
#define _GNU_SOURCE
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <stdint.h>
#include <signal.h>
#include <unistd.h>
#include <poll.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include "domain2name.c" // this is from http://git.sijanec.eu/sijanec/dnsfind
/* DNS RR TYPE values (RFC 1035 section 3.2.2). */
enum type {
	A = 1,		/* host address */
	Ns,		/* authoritative name server */
	Md,		/* mail destination (obsolete) */
	Mf,		/* mail forwarder (obsolete) */
	Cname,		/* canonical name for an alias */
	Soa,		/* start of a zone of authority */
	Mb,		/* mailbox domain name */
	Mg,		/* mail group member */
	M = Mg,		/* former truncated spelling of Mg, kept for existing users */
	Mr,		/* mail rename domain name */
	Null,		/* null RR */
	Wks,		/* well known service description */
	Ptr,		/* domain name pointer (answer type for reverse lookups) */
	Hinfo,		/* host information */
	Minfo,		/* mailbox or mail list information */
	Mx,		/* mail exchange */
	Txt		/* text strings */
};
/* DNS CLASS values (RFC 1035 section 3.2.4). */
enum class {
	In = 1,		/* the Internet */
	Cs,		/* CSNET (obsolete) */
	Ch,		/* CHAOS */
	Hs,		/* Hesiod */
	He = Hs,	/* former spelling of Hs, kept for existing users */
	Any = 255	/* QCLASS "*": any class */
};
/*
 * Bits and fields of the 16-bit DNS header flags word (RFC 1035 section 4.1.1),
 * in host byte order; they are converted with htons() before being stored
 * in struct header.
 */
#define RESPONSE (1 << 15)		/* QR = 1: message is a response */
#define QUESTION (0 << 15)		/* QR = 0: message is a query */
/* OPCODE field, bits 14-11 */
#define QUERY (0 << 11)			/* opcode 0: standard query */
#define IQUERY (1 << 11)		/* opcode 1: inverse query */
#define STATUS (1 << 12)		/* opcode 2: server status request */
#define AA (1 << 10)			/* authoritative answer */
#define TC (1 << 9)			/* message truncated */
#define RD (1 << 8)			/* recursion desired */
#define RA (1 << 7)			/* recursion available */
/* RCODE field, bits 3-0 */
#define SUCCESS (0 << 0)		/* 0: no error */
#define FORMAT_ERROR (1 << 0)		/* 1: format error */
#define SERVFAIL (1 << 1)		/* 2: server failure */
#define NXDOMAIN (FORMAT_ERROR | SERVFAIL)	/* 3: name error */
#define NI (1 << 2)			/* 4: not implemented */
#define FORBIDDEN (NI | FORMAT_ERROR)	/* 5: refused (must not alias NXDOMAIN) */
/*
 * Fixed 12-byte DNS message header. Fields hold network byte order values
 * (this file writes them with htons()). data[] marks where the question
 * section starts, directly after the header.
 * The struct-level packed attribute already packs every member.
 */
struct header {
	uint16_t xid;
	uint16_t flags;
	uint16_t qdcount;
	uint16_t ancount;
	uint16_t nscount;
	uint16_t arcount;
	char data[];
} __attribute__((packed));
/*
 * Fixed part of a resource record. The owner name that precedes it in the
 * packet is not included: the first byte of this struct is the first byte
 * of TYPE. data[] is where RDATA (len bytes) begins.
 */
struct rr {
	uint16_t type;
	uint16_t class;
	uint32_t ttl;
	uint16_t len;
	char data[];
} __attribute__((packed));
/* Fixed part of a question entry, following its QNAME: QTYPE then QCLASS. */
struct question {
	uint16_t type;
	uint16_t class;
} __attribute__((packed));
/* Severity handed to a dns_log_handler. */
enum dns_loglevel {
	DNS_DEBUG = 0,
	DNS_INFO = 1,
	DNS_WARN = 2,
	DNS_ERROR = 3
};
/* Logging callback: (userdata, level, area, message). */
typedef void (*dns_log_handler) (void * const, const enum dns_loglevel,
		const char * const, const char * const);
/* Server state; obtained from dns_init and released with dns_free. */
typedef struct dns {
	int fd;				/* UDP socket, -1 while not open */
	char * domain;			/* this is as it appears in the packet */
	dns_log_handler log_handler;	/* receives every log message */
	void * log_userdata;		/* first argument given to log_handler */
	struct sockaddr_in sockaddr;	/* address the socket is bound to */
} dns;
static void dns_default_log_handler (void * const u __attribute__((unused)),
		const enum dns_loglevel l, const char * a, const char * m) {
	char * n = "unspec";
	switch (l) {
		case DNS_DEBUG:
			n = "DEBUG";
			break;
		case DNS_INFO:
			n = "INFO";
			break;
		case DNS_WARN:
			n = "WARN";
			break;
		case DNS_ERROR:
			n = "ERROR";
			break;
	}
	fprintf(stderr, "%s %s %s\n", n, a, m);
}
/*
 * Allocate a server context with defaults:
 *   - the socket is not open yet (fd -1);
 *   - messages go to dns_default_log_handler with zeroed (NULL) userdata;
 *   - the bind address is 0.0.0.0, port 53.
 * Until dns_set_domain is called, the answer domain is a single-label
 * placeholder asking the caller to do so.
 * Returns NULL if an allocation fails. Release the result with dns_free.
 */
struct dns * dns_init (void) {
	static const char placeholder[] = " call dns_set_domain to set the domain";
	struct dns * ctx = calloc(1, sizeof *ctx);
	if (ctx == NULL)
		return NULL;
	ctx->domain = strdup(placeholder);
	if (ctx->domain == NULL) {
		free(ctx);
		return NULL;
	}
	/* The leading space becomes the length byte of one label spanning the
	 * rest of the string. The string's NUL then ends the name. */
	ctx->domain[0] = (char) (sizeof placeholder - 2);
	ctx->fd = -1;
	ctx->log_handler = dns_default_log_handler;
	ctx->sockaddr.sin_family = AF_INET;
	ctx->sockaddr.sin_port = htons(53);
	ctx->sockaddr.sin_addr.s_addr = htonl(INADDR_ANY);
	return ctx;
}
/*
 * Set the domain sent as the PTR answer. domain is text that domain2name()
 * converts into the stored form (as it appears in the packet). On any
 * failure (NULL domain, domain2name_len rejecting it, or malloc failing)
 * the error is logged and the previously configured domain is kept.
 */
static void dns_set_domain (struct dns * dns, const char * domain) {
	char buf[64];
	if (!domain) {
		dns->log_handler(dns->log_userdata, DNS_ERROR, "dns",
				"dns_set_domain called with NULL domain");
		return;
	}
	size_t domain_len = strlen(domain);
	int required = domain2name_len(domain, domain_len);
	if (required <= 0) {
		snprintf(buf, sizeof buf, "domain2name_len failed with %d", required);
		dns->log_handler(dns->log_userdata, DNS_ERROR, "dns", buf);
		return;
	}
	char * new = malloc(required);
	if (!new) {
		snprintf(buf, sizeof buf, "malloc of size %d failed", required);
		dns->log_handler(dns->log_userdata, DNS_ERROR, "dns", buf);
		return;
	}
	domain2name(new, domain, domain_len);
	/* swap only once the new name is fully built */
	free(dns->domain);
	dns->domain = new;
}
/* Install the callback that receives every log message; passing NULL
 * restores dns_default_log_handler. */
static void dns_set_log_handler (struct dns * dns, dns_log_handler log_handler) {
	dns->log_handler = log_handler != NULL ? log_handler : dns_default_log_handler;
}
/* Store the opaque pointer passed as the first argument to the log handler. */
static void dns_set_log_userdata (struct dns * dns, void * log_userdata) {
	(*dns).log_userdata = log_userdata;
}
/*
 * Set the UDP port (host byte order) the socket binds to. The value is
 * converted to 16 bits. It is used when dns_run_once next opens the socket
 * (it binds only while dns->fd == -1).
 */
static void dns_set_port (struct dns * dns, int port) {
	const uint16_t netport = htons((uint16_t) port);
	dns->sockaddr.sin_port = netport;
}
static void dns_set_ip (struct dns * dns, const char * ip) {
	if (!ip) {
		dns->sockaddr.sin_addr.s_addr = INADDR_ANY;
		return;
	}
	inet_aton(ip, &dns->sockaddr.sin_addr);
}
static void dns_run_once (struct dns * dns) {
#define BUFLEN 4096
	char buf[BUFLEN];
	if (dns->fd == -1) {
		if ((dns->fd = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)) == -1) {
			sprintf(buf, "socket failed with %s", strerror(errno));
			dns->log_handler(dns->log_userdata, DNS_ERROR, "dns", buf);
			return;
		}
		int ž = 1;
		if (setsockopt(dns->fd, SOL_SOCKET, SO_BROADCAST, &ž, sizeof(ž)) == -1) {
			sprintf(buf, "setsockopt failed with %s", strerror(errno));
			dns->log_handler(dns->log_userdata, DNS_ERROR, "dns", buf);
			close(dns->fd);
			dns->fd = -1;
			return;
		}
		if (bind(dns->fd, (struct sockaddr *) &dns->sockaddr, sizeof(struct sockaddr))) {
			sprintf(buf, "bind failed with %s", strerror(errno));
			dns->log_handler(dns->log_userdata, DNS_ERROR, "dns", buf);
			close(dns->fd);
			dns->fd = -1;
			return;
		}
	}
	struct sockaddr_in sender;
	socklen_t sendl = sizeof sender;
	int size = recvfrom(dns->fd, buf, 65535, MSG_DONTWAIT, (struct sockaddr *) &sender, &sendl);
	if (size == -1) {
		if (errno != EWOULDBLOCK) {
			sprintf(buf, "recvfrom failed with %s", strerror(errno));
			dns->log_handler(dns->log_userdata, DNS_ERROR, "dns", buf);
			close(dns->fd);
			dns->fd = -1;
			return;
		}
		return;
	}
	struct header * header = (struct header *) buf;	// time for some overwriting
	header->flags = htons(RESPONSE | QUERY | AA | SUCCESS);
	header->ancount = htons(1);	// we keep question number intact, as we send all qs back
	header->nscount = 0;
	header->arcount = 0;
	buf[BUFLEN-1] = '\0'; // strstr and strlen on untrusted data!
	char * ouranswerisat = memmem(buf, size, "in-addr", strlen("in-addr"));
	if (!ouranswerisat) {
		dns->log_handler(dns->log_userdata, DNS_INFO, "dns",
				"received request without 'in-addr' string ... weird.");
		return;
	}
	ouranswerisat += 17;
	struct rr * rr = (struct rr *) (ouranswerisat + strlen(header->data) + 1);
	if (ouranswerisat + strlen(header->data) + 1 + sizeof(struct rr)
			+ strlen(dns->domain) + 1 >= buf + BUFLEN) {
		dns->log_handler(dns->log_userdata, DNS_WARN, "dns", "sent packet would be to big");
		return;
	}
	strcpy(ouranswerisat, header->data);
	rr->type = htons(Ptr);
	rr->class = htons(In);
	rr->ttl = 0;
	rr->len = htons(strlen(dns->domain)+1);
	strcpy(rr->data, dns->domain);
	int len = rr->data - buf + strlen(dns->domain) + 1;
	if (sendto(dns->fd, buf, len, 0, (struct sockaddr *) &sender, sizeof(sender)) == -1) {
		sprintf(buf, "sendto failed with %s", strerror(errno));
		dns->log_handler(dns->log_userdata, DNS_ERROR, "dns", buf);
		close(dns->fd);
		dns->fd = -1;
		return;
	}
	char dst[INET_ADDRSTRLEN];
	const char * resp = inet_ntop(AF_INET, &sender.sin_addr, dst, INET_ADDRSTRLEN);
	sprintf(buf, "successfully sent DNS reply to %s", resp ? resp : "[inet_ntop failed]");
	dns->log_handler(dns->log_userdata, DNS_ERROR, "dns", buf);
}
/*
 * Release a context from dns_init: close its socket if it is open, then free
 * the domain and the context itself. A NULL argument is accepted and ignored.
 */
static void dns_free (struct dns * dns) {
	if (dns) {
		if (dns->fd != -1)
			close(dns->fd);
		free(dns->domain);
	}
	free(dns);	/* free(NULL) is a no-op */
}
#if IX_DNS_MAIN
/* Set by the signal handler. volatile sig_atomic_t is the object type the C
 * standard allows a signal handler to write portably. */
volatile sig_atomic_t shouldexit = 0;
void handler (int signal __attribute__((unused))) {
	shouldexit = 1;
}
/*
 * Standalone server. It uses dns_init's defaults (0.0.0.0, port 53) and
 * answers reverse queries with the domain given as argv[1] until SIGINT or
 * SIGTERM arrives.
 */
int main (int argc, char ** argv) {
	if (argc != 1+1) {
		fprintf(stderr, "usage: %s respond.to.ptr.with.this.domain.\n", argv[0]);
		return 1;
	}
	signal(SIGINT, handler);
	signal(SIGTERM, handler);
	dns * dns = dns_init();
	if (!dns) {
		fprintf(stderr, "dns_init failed: out of memory\n");
		return 1;
	}
	dns_set_domain(dns, argv[1]);
	while (!shouldexit) {
		dns_run_once(dns);
		if (dns->fd == -1) {
			/* socket setup failed (already logged): retry later instead of spinning */
			sleep(1);
			continue;
		}
		/* dns_run_once never waits, so block here until a datagram is readable.
		 * The timeout bounds the delay if a signal lands just before poll(). */
		struct pollfd pfd = { .fd = dns->fd, .events = POLLIN };
		poll(&pfd, 1, 1000);
	}
	dns_free(dns);
	return 0;
}
#endif