~sircmpwn/hare-ev

ed023beb4b4db88e22f608aa001682ac18cad230 — Drew DeVault a month ago aa01cdd
ev::dns: fall back to TCP on response truncation

Signed-off-by: Drew DeVault <sir@cmpwn.com>
1 files changed, 170 insertions(+), 14 deletions(-)

M ev/dns/dns.ha
M ev/dns/dns.ha => ev/dns/dns.ha +170 -14
@@ 1,10 1,13 @@
use endian;
use errors;
use ev;
use io;
use net;
use net::dns;
use net::ip;
use net::udp;
use time;
use types;
use unix::resolvconf;

// TODO: Let users customize this?


@@ 17,20 20,36 @@ export type querycb = fn(
) void;

// State for a single in-flight DNS query. One instance is heap-allocated by
// query, attached to the sockets and the timer as their user data, and freed
// by query_destroy.
type qstate = struct {
	// NOTE(review): still referenced by some of the UDP code paths in this
	// file alongside the query/rbuf fields below, which appear to supersede
	// it - confirm whether it can be dropped.
	buf: [512]u8,
	// Event loop objects
	loop: *ev::loop,
	// Sockets the query is sent and received on. socket4 starts out as a
	// UDP socket; once the TCP fallback connects, it is closed and replaced
	// with the TCP socket regardless of address family.
	socket4: *ev::file,
	socket6: *ev::file,
	// In-flight requests on socket4 and socket6. After switching to TCP
	// only r4 is used.
	r4: ev::req,
	r6: ev::req,
	// Timer created with timeoutcb as its callback
	timer: *ev::file,

	// Request ID, used to match UDP responses to this query. During the TCP
	// fallback this field is reused to hold the expected length of the
	// response, as read from the two-byte length prefix.
	rid: u16,

	// Outgoing DNS request: the encoded message and its length in bytes
	query: [512]u8,
	qlen: u16,

	// Response buffer (heap allocated) and, for the TCP fallback, the number
	// of response bytes received into it so far
	rbuf: []u8,
	rbuf_valid: u16,

	// Length buffer: holds the big-endian two-byte length prefix that
	// precedes each DNS message sent or received over TCP
	zbuf: [2]u8,

	// Callback and user data
	cb: *querycb,
	user: nullable *opaque,
};

// Performs a DNS query against the provided set of DNS servers, or the list of
// servers from /etc/resolv.conf if none are specified. The DNS message passed
// to the callback is only valid for the duration of the callback.
export fn query(
	loop: *ev::loop,
	query: *dns::message,


@@ 51,15 70,17 @@ export fn query(
	const socket6 = ev::listen_udp(loop, ip::ANY_V6, 0)?;
	const timeout = ev::newtimer(loop, &timeoutcb, time::clock::MONOTONIC)?;
	let state = alloc(qstate {
		loop = loop,
		socket4 = socket4,
		socket6 = socket6,
		timer = timeout,
		rbuf = alloc([0...], 512),
		rid = query.header.id,
		cb = cb,
		user = user,
		...
	});
	const z = dns::encode(state.buf, query)?;
	state.qlen = dns::encode(state.query, query)?: u16;
	ev::setuser(socket4, state);
	ev::setuser(socket6, state);
	ev::setuser(timeout, state);


@@ 68,7 89,7 @@ export fn query(
	// Note: the initial set of requests is sent directly through net::udp
	// as it is assumed they can fit into the kernel's internal send buffer
	// and will finish without blocking
	const buf = state.buf[..z];
	const buf = state.query[..state.qlen];
	for (const server .. servers) {
		match (server) {
		case ip::addr4 =>


@@ 78,8 99,8 @@ export fn query(
		};
	};

	state.r4 = ev::recvfrom(socket4, &qrecvcb, state.buf);
	state.r6 = ev::recvfrom(socket6, &qrecvcb, state.buf);
	state.r4 = ev::recvfrom(socket4, &qrecvcb, state.rbuf);
	state.r6 = ev::recvfrom(socket6, &qrecvcb, state.rbuf);
	return ev::mkreq(&query_cancel, state);
};



@@ 94,14 115,20 @@ fn query_destroy(q: *qstate) void = {
	ev::close(q.socket4);
	ev::close(q.socket6);
	ev::close(q.timer);
	free(q.rbuf);
	free(q);
};

// Delivers the final result of a query to the user's callback and then tears
// down the query.
//
// The callback and its user pointer are copied out of q first. The callback is
// invoked while q is still alive. If the result is a decoded message, it is
// freed as soon as the callback returns, so the callback must not retain it.
// Finally q is destroyed, exactly once; the caller must not touch q after this
// function returns.
fn query_complete(q: *qstate, r: (*dns::message | dns::error)) void = {
	const cb = q.cb;
	const user = q.user;
	cb(user, r);
	match (r) {
	case let msg: *dns::message =>
		dns::message_free(msg);
	case => void;
	};
	// Destroy the state only after the callback has run, and only once:
	// query_destroy frees q.
	query_destroy(q);
};

fn timeoutcb(file: *ev::file) void = {


@@ 111,6 138,18 @@ fn timeoutcb(file: *ev::file) void = {

// Receive callback for the UDP sockets. Hands the datagram (or error) to qrecv
// and, if that produced a final result, completes the query with it. When
// qrecv returns void there is nothing to report yet and the query stays
// pending.
fn qrecvcb(file: *ev::file, r: ((size, ip::addr, u16) | net::error)) void = {
	const state = ev::getuser(file): *qstate;
	const outcome = match (qrecv(state, file, r)) {
	case void =>
		return;
	case let done: (*dns::message | dns::error) =>
		yield done;
	};
	query_complete(state, outcome);
};

fn qrecv(
	q: *qstate,
	file: *ev::file,
	r: ((size, ip::addr, u16) | net::error),
) (*dns::message | dns::error | void) = {
	let req: *ev::req = if (file == q.socket4) &q.r4 else &q.r6;
	*req = ev::req { ... };



@@ 122,24 161,141 @@ fn qrecvcb(file: *ev::file, r: ((size, ip::addr, u16) | net::error)) void = {
		return;
	};

	const resp = match (dns::decode(q.buf[..z])) {
	const resp = match (dns::decode(q.rbuf[..z])) {
	case dns::format =>
		*req = ev::recvfrom(file, &qrecvcb, q.buf);
		*req = ev::recvfrom(file, &qrecvcb, q.rbuf);
		return;
	case let msg: *dns::message =>
		yield msg;
	};
	defer dns::message_free(resp);

	if (resp.header.id != q.rid || resp.header.op.qr != dns::qr::RESPONSE) {
		*req = ev::recvfrom(file, &qrecvcb, q.buf);
		*req = ev::recvfrom(file, &qrecvcb, q.rbuf);
		dns::message_free(resp);
		return;
	};

	if (!resp.header.op.tc) {
		query_complete(q, resp);
		return resp;
	};

	dns::message_free(resp);

	// Response truncated, retry over TCP
	//
	// Note that when we switch to TCP, we only use the r4 field for
	// in-flight requests (even if we're using IPv6), and likewise once the
	// TCP connection is established the UDP socket at socket4 is closed and
	// replaced with the TCP socket (regardless of domain).

	// Cancel in-flight UDP queries
	ev::cancel(&q.r4);
	ev::cancel(&q.r6);

	match (ev::connect_tcp(q.loop, &qconnected, addr, 53, q)) {
	case let req: ev::req =>
		q.r4 = req;
	case let err: net::error =>
		return err;
	case let err: errors::error =>
		return err: net::error;
	};
};

// Connect callback for the TCP fallback. user is the *qstate that was handed
// to ev::connect_tcp.
//
// On failure the query is completed with the connection error. On success the
// TCP socket replaces the UDP socket stored in socket4, and the query is
// written to it as a two-byte big-endian length followed by the encoded
// message; qtcp_write_cb continues from there.
fn qconnected(result: (*ev::file | net::error), user: nullable *opaque) void = {
	const q = user: *qstate;
	// The connect request has finished, so nothing is in flight on r4.
	q.r4 = ev::req { ... };
	const sock = match (result) {
	case let file: *ev::file =>
		yield file;
	case let err: net::error =>
		query_complete(q, err);
		return;
	};

	// The TCP socket takes over the socket4 slot so that query_destroy
	// closes it along with everything else.
	ev::close(q.socket4);
	q.socket4 = sock;
	// The write and read callbacks look the state up with ev::getuser, so
	// make sure it is attached to the new socket.
	ev::setuser(sock, q);

	// Length prefix for the outgoing message
	endian::beputu16(q.zbuf, q.qlen);

	q.r4 = ev::writev(sock,
		&qtcp_write_cb,
		io::mkvector(q.zbuf),
		io::mkvector(q.query[..q.qlen]));
};

// Write callback for the TCP fallback: called once the length-prefixed query
// has been written to the TCP socket. On success, starts reading the two-byte
// length prefix of the response into zbuf. On error, completes the query with
// that error.
fn qtcp_write_cb(file: *ev::file, result: (size | io::error)) void = {
	const q = ev::getuser(file): *qstate;
	q.r4 = ev::req { ... };
	match (result) {
	case let z: size =>
		// XXX: some (stupid) configurations may have a TCP buffer less
		// than 514 bytes, which we might want to handle, but generally
		// the request should make it to the TCP buffer in a single
		// writev call.
		assert(z: u16 == q.qlen + 2);
	case let err: io::error =>
		query_complete(q, err);
		// q has been destroyed by query_complete; do not queue a read.
		return;
	};

	q.r4 = ev::read(file, &qtcp_readlength_cb, q.zbuf);
};

// Read callback for the two-byte length prefix of a TCP response.
//
// Anything other than a complete two-byte read fails the query: an I/O error
// is passed through, while EOF or a short read is reported as dns::format. On
// success the response buffer is replaced with one of exactly the announced
// size and reading of the message body begins.
fn qtcp_readlength_cb(
	file: *ev::file,
	result: (size | io::EOF | io::error),
) void = {
	const q = ev::getuser(file): *qstate;
	// The read request has finished, so nothing is in flight on r4.
	q.r4 = ev::req { ... };
	match (result) {
	case let z: size =>
		if (z != 2) {
			query_complete(q, dns::format);
			return;
		};
	case let err: io::error =>
		query_complete(q, err);
		return;
	case io::EOF =>
		query_complete(q, dns::format);
		return;
	};

	const rlen = endian::begetu16(q.zbuf);
	// From here on rid holds the expected response length rather than the
	// request ID; qtcp_readdata_cb compares rbuf_valid against it.
	q.rid = rlen;
	// Release the previous response buffer before replacing it, otherwise
	// it is leaked: query_destroy only frees the current rbuf.
	free(q.rbuf);
	q.rbuf = alloc([0...], rlen);
	q.rbuf_valid = 0;
	q.r4 = ev::read(file, &qtcp_readdata_cb, q.rbuf);
};

// Read callback for the body of a TCP response.
//
// Accumulates received bytes in rbuf until the length announced by the prefix
// (stored in rid by qtcp_readlength_cb) has arrived, issuing further reads for
// the remainder as needed. Once complete, the message is decoded and the query
// is completed with the result. An I/O error is passed through to the caller;
// EOF before the full message has arrived, receiving more data than expected,
// or an undecodable message is reported as dns::format.
fn qtcp_readdata_cb(
	file: *ev::file,
	result: (size | io::EOF | io::error),
) void = {
	const q = ev::getuser(file): *qstate;
	q.r4 = ev::req { ... };
	match (result) {
	case let z: size =>
		// Compare as size so that the sum cannot wrap around in u16.
		if (q.rbuf_valid: size + z > q.rid: size) {
			query_complete(q, dns::format);
			return;
		};
		q.rbuf_valid += z: u16;
	case let err: io::error =>
		query_complete(q, err);
		return;
	case io::EOF =>
		// The peer closed the connection before sending the whole
		// message; no more data will arrive.
		query_complete(q, dns::format);
		return;
	};

	if (q.rbuf_valid < q.rid) {
		// Read more data from the socket
		q.r4 = ev::read(file, &qtcp_readdata_cb, q.rbuf[q.rbuf_valid..]);
		return;
	};

	const resp = match (dns::decode(q.rbuf[..q.rbuf_valid])) {
	case dns::format =>
		query_complete(q, dns::format);
		return;
	case let msg: *dns::message =>
		yield msg;
	};
	// query_complete frees the message after the user's callback returns.
	query_complete(q, resp);
};