#![deny(warnings, rust_2018_idioms)]
use std::net::{SocketAddr, ToSocketAddrs};
use std::time::Duration;
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use bytes::BytesMut;
use lazy_static::lazy_static;
use rand::Rng;
use scopeguard;
use tokio::net::UdpSocket;
use tokio::time;
use tracing::{debug, trace, warn};
use crate::client::DnsClient;
use crate::codec::{decoder::DNSMessageDecoder, encoder::DNSMessageEncoder, message};
use crate::fbs::dns_message_generated::Message;
/// DNS client that sends queries to a single upstream server over UDP.
///
/// Each request is retried with exponentially growing per-attempt waits (1s, 2s, 4s, ...).
/// This client does not speak TCP itself: when the decoder reports an incomplete UDP
/// response, or the response has its truncated flag set, `query` returns `Ok(None)` so that
/// the caller can retry the request over TCP.
pub struct Client {
    /// Address of the upstream DNS server that every query is sent to.
    dns_server: SocketAddr,
    /// UDP payload size written into outgoing requests; updated from the OPT record of the
    /// most recent validated response.
    last_udp_size: u16,
    /// Scratch buffer that UDP responses are read into; its length is restored to 4096
    /// after every query because the socket read is bounded by `len`, not capacity.
    response_buffer: Vec<u8>,
    /// Approximate total timeout budget per request, in milliseconds (see [`Client::new`]).
    timeout_ms: u64,
}

impl Client {
    /// Constructs a new `Client` that will query the specified `dns_server`.
    ///
    /// `timeout_ms` is an approximate timeout for each request. Per-attempt waits grow as
    /// 1s, 2s, 4s, ..., so the effective timeout is rounded up to a sum of these steps,
    /// for example 10000 -> 15000 (1s + 2s + 4s + 8s).
    pub fn new(dns_server: SocketAddr, timeout_ms: u64) -> Self {
        Client {
            dns_server,
            // Initial UDP payload size to advertise, equal to the receive buffer length.
            last_udp_size: 4096,
            response_buffer: vec![0; 4096],
            timeout_ms,
        }
    }
}
lazy_static! {
    /// Lazily-created encoder instance used by `Client::encode`.
    /// `DNSMessageEncoder` currently doesn't have state, so a single shared instance is
    /// reused instead of constructing a new encoder for every request.
    static ref ENCODER: DNSMessageEncoder = DNSMessageEncoder::new();
}
#[async_trait]
impl DnsClient for Client {
    /// Encodes `request` into `request_buffer`, writing `last_udp_size` as the UDP payload
    /// size of the outgoing request.
    fn encode(&self, request: &Message<'_>, request_buffer: &mut BytesMut) -> Result<()> {
        // Ensure the request to our server has the correct UDP size in the request.
        ENCODER.encode(request, Some(self.last_udp_size), request_buffer)
    }

    /// Sends the already-encoded `request_buffer` to the server over UDP and decodes the
    /// reply into `response_fbb`.
    ///
    /// Returns `Ok(None)` when the decoder reports an incomplete response or the response
    /// has its truncated flag set, so the caller can fall back to TCP.
    ///
    /// # Errors
    ///
    /// Fails when no response arrives within the timeout, when the response cannot be
    /// parsed, when it has no header, or when its transaction id differs from the one sent.
    async fn query<'response>(
        &mut self,
        request_buffer: &mut BytesMut,
        response_fbb: &'response mut flatbuffers::FlatBufferBuilder<'_>,
    ) -> Result<Option<Message<'response>>> {
        // Length `response_buffer` is restored to after every query (matches `Client::new`).
        const RESPONSE_BUFFER_SIZE: usize = 4096;
        // RFC 6891 §6.2.5: advertised payload sizes below 512 are treated as 512.
        const MIN_UDP_SIZE: u16 = 512;
        // Never advertise more than the receive buffer can hold; larger datagrams would be
        // silently cut off by the socket read.
        const MAX_UDP_SIZE: u16 = RESPONSE_BUFFER_SIZE as u16;

        let mut request_id: u16 = 0;
        {
            // Ensure that response_buffer is reset to its full size when we're done with it,
            // regardless of success or error. The socket read is based on len, not capacity.
            let mut response_buffer = scopeguard::guard(&mut self.response_buffer, |buf| {
                buf.resize(RESPONSE_BUFFER_SIZE, 0);
            });
            let response_size = send_recv_exponential_backoff(
                &self.dns_server,
                request_buffer,
                &mut request_id,
                response_buffer.as_mut_slice(),
                self.timeout_ms,
            )
            .await?;
            // Shorten to actual size received
            response_buffer.truncate(response_size);

            trace!(
                "Raw response from {:?} ({}b): {:02X?}",
                self.dns_server,
                response_buffer.len(),
                &response_buffer[..]
            );
            match DNSMessageDecoder::new().decode(&response_buffer, response_fbb) {
                Ok(true) => {
                    // Decoded into response_fbb
                }
                Ok(false) => {
                    // Message was likely truncated, upstream can fall back to TCP
                    debug!(
                        "Unable to parse response from server={} to request={:02X?}: {:02X?}",
                        self.dns_server,
                        &request_buffer[..],
                        &response_buffer[..],
                    );
                    return Ok(None);
                }
                Err(e) => {
                    // Other parse error
                    return Err(e).context(format!(
                        "Failed to parse response from server={} to request={:02X?}: {:02X?}",
                        self.dns_server,
                        &request_buffer[..],
                        &response_buffer[..],
                    ));
                }
            }
        }

        let response: Message<'_> =
            flatbuffers::get_root::<Message<'_>>(response_fbb.finished_data());
        debug!("Response from {:?}: {}", self.dns_server, response);

        let header = match response.header() {
            Some(header) => header,
            None => bail!("Missing header in response"),
        };

        // Validate the transaction id before trusting any other header field (including the
        // truncated flag), so a reply carrying the wrong id is rejected rather than being
        // reported as truncated.
        let response_id = header.id();
        if response_id != request_id {
            bail!(
                "Returned transaction id {:?} doesn't match sent {:?}",
                response_id,
                request_id
            );
        }
        if header.truncated() {
            // Message claims to be truncated, upstream can fall back to TCP
            return Ok(None);
        }

        // After passing validation, update udp_size for the next request to this server,
        // clamped to the range we can legally advertise and actually receive.
        if let Some(opt) = response.opt() {
            let udp_size = opt.udp_size().max(MIN_UDP_SIZE).min(MAX_UDP_SIZE);
            trace!(
                "Using udp_size={} for server={}",
                udp_size,
                self.dns_server
            );
            self.last_udp_size = udp_size;
        }
        Ok(Some(response))
    }
}
/// Sends `request_buffer` to `dest` over UDP and waits for a reply originating from `dest`,
/// retrying with per-attempt waits of 1s, 2s, 4s, ...
///
/// Each attempt binds a fresh socket on an OS-assigned port and writes a new random
/// transaction id into `request_buffer` (the id is also stored in `request_id`). Attempts
/// continue until their combined wait reaches `total_timeout_ms`, so the effective timeout
/// is the smallest sum 1s + 2s + ... + 2^k s that is at least `total_timeout_ms`; at least
/// one 1s attempt is always made.
///
/// Returns the number of bytes received into the start of `response_buffer`.
///
/// # Errors
///
/// Fails if the local socket cannot be created, the request cannot be sent, receiving
/// fails, or no reply from `dest` arrives within the timeout budget.
async fn send_recv_exponential_backoff(
    dest: &SocketAddr,
    request_buffer: &mut BytesMut,
    request_id: &mut u16,
    response_buffer: &mut [u8],
    total_timeout_ms: u64,
) -> Result<usize> {
    // Bind to the unspecified address of the same family as `dest`: an IPv4 socket cannot
    // send to an IPv6 address and vice versa.
    let bind_addr = if dest.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
    let mut remaining_timeout_ms = total_timeout_ms;
    // Start at 1s, then 2s, then 4s, ...
    let mut timeout_ms: u64 = 1000;
    loop {
        // NOTE: This assumes that port 0 results in a random port each time.
        // In particular we DONT want it to just increment by 1 or something each time.
        // Apparently this is OS-specific but Linux at least should do what we want.
        let client_addr = match bind_addr.to_socket_addrs()?.next() {
            Some(addr) => addr,
            None => bail!("Failed to resolve local bind address {}", bind_addr),
        };
        let mut conn = UdpSocket::bind(client_addr)
            .await
            .with_context(|| format!("Failed to bind UDP socket to {}", client_addr))?;

        // We regenerate the request ID on every retry. We're changing the client port each time,
        // so scrambling the request ID shouldn't result in "old" mismatched responses anyway.
        // This reduces the likelihood of someone trying to poison our cache by sending a request
        // and then flooding us with responses that match that request's message id.
        *request_id = rand::thread_rng().gen::<u16>();
        message::update_message_id(*request_id, request_buffer, 0)?;
        trace!(
            "Raw request to {:?} ({}b): {:02X?}",
            &dest,
            request_buffer.len(),
            &request_buffer[..]
        );

        // (Re)send request. Shouldn't time out but just in case... A send failure is
        // reported instead of ignored: waiting for a reply to a request that never left
        // would only burn the timeout budget.
        time::timeout(
            Duration::from_millis(1000),
            conn.send_to(request_buffer.as_ref(), dest),
        )
        .await
        .with_context(|| format!("Timed out sending DNS request to {}", dest))?
        .with_context(|| format!("Failed to send DNS request to {}", dest))?;

        match time::timeout(
            Duration::from_millis(timeout_ms),
            conn.recv_from(&mut *response_buffer),
        )
        .await
        {
            // Got a response from somewhere
            Ok(recv) => {
                let (recvsize, recvdest) = recv
                    // A different error occurred, give up
                    .with_context(|| format!("Failed to receive DNS response from {}", dest))?;
                // Before returning, check that the response is from who we're waiting for
                if *dest == recvdest {
                    return Ok(recvsize);
                }
                // If it doesn't match, resend and resume waiting, unless this was the last retry
                warn!(
                    "Response origin {:?} doesn't match request target {:?}",
                    recvdest, dest
                );
            }
            // Timeout occurred, try again (or exit loop)
            Err(_) => {
                debug!("UDP request to {} timed out after {}ms", dest, timeout_ms);
            }
        }

        // Charge this attempt's wait against the budget (the full wait is charged even if
        // the attempt ended early on a mismatched origin), then double the next wait.
        remaining_timeout_ms = remaining_timeout_ms.saturating_sub(timeout_ms);
        if remaining_timeout_ms == 0 {
            // No retries left, give up
            bail!("Timed out waiting for response from {:?}", dest);
        }
        timeout_ms = timeout_ms.saturating_mul(2);
    }
}