#![deny(warnings, rust_2018_idioms)]
use std::convert::TryFrom;
use std::net::SocketAddr;
use std::time::Duration;
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use bytes::{BufMut, BytesMut};
use lazy_static::lazy_static;
use rand::Rng;
use scopeguard;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time;
use tracing::{debug, trace};
use crate::client::DnsClient;
use crate::codec::{decoder::DNSMessageDecoder, encoder::DNSMessageEncoder, message};
use crate::fbs::dns_message_generated::Message;
/// The DNS-over-TCP length prefix is 16 bits, so the largest possible message is 65535 bytes.
const MAX_TCP_BYTES: u16 = u16::MAX;
lazy_static! {
/// Lazily-initialized encoder shared by every `Client` (it is a single static instance).
/// The encoder currently has no state, so sharing it across requests is safe.
static ref ENCODER: DNSMessageEncoder = DNSMessageEncoder::new();
}
/// DNS client that queries a single server over TCP.
///
/// The TCP connection is opened lazily by the first query and reused by later
/// queries; after an I/O error it is dropped so that the next query reconnects.
#[derive(Debug)]
pub struct Client {
    /// Address of the DNS server that every query is sent to.
    dns_server: SocketAddr,
    /// Open connection to `dns_server`, or `None` before the first query / after a failure.
    conn: Option<TcpStream>,
    /// Reusable buffer that response payloads are read into.
    response_buffer: BytesMut,
    /// Timeout in milliseconds, applied separately to connect, send, and each receive step.
    timeout_ms: u64,
}
/// Construction and connection management for [`Client`].
impl Client {
    /// Creates a client for `dns_server`.
    ///
    /// No connection is opened here; the first query connects lazily. `timeout_ms`
    /// bounds each individual network operation.
    pub fn new(dns_server: SocketAddr, timeout_ms: u64) -> Self {
        let response_buffer = BytesMut::with_capacity(MAX_TCP_BYTES as usize);
        Self {
            dns_server,
            conn: None,
            response_buffer,
            timeout_ms,
        }
    }

    /// Opens a TCP connection to `self.dns_server` and stores it in `self.conn`.
    ///
    /// # Errors
    /// Returns an error if the connection attempt fails, or if it does not complete
    /// within `timeout_ms` (in which case `self.conn` is reset to `None`).
    async fn connect(&mut self) -> Result<()> {
        let limit = Duration::from_millis(self.timeout_ms);
        let attempt = time::timeout(limit, TcpStream::connect(self.dns_server)).await;
        let stream = if let Ok(connect_result) = attempt {
            connect_result.with_context(|| format!("Failed to connect to {}", self.dns_server))?
        } else {
            self.conn = None;
            bail!("Timed out waiting for connection to {:?}", self.dns_server);
        };
        self.conn = Some(stream);
        Ok(())
    }
}
#[async_trait]
impl DnsClient for Client {
    /// Encodes `request` into `request_buffer`, preceded by the 2-byte big-endian
    /// length prefix used by DNS over TCP (RFC 1035 section 4.2.2).
    ///
    /// The prefix is written at offsets 0..2 and the length is computed from the whole
    /// buffer, so `request_buffer` is expected to be empty on entry; `query` likewise
    /// assumes the message itself starts at offset 2.
    ///
    /// # Errors
    /// Fails if the encoder fails or the encoded message is larger than `MAX_TCP_BYTES`.
    fn encode(&self, request: &Message<'_>, request_buffer: &mut BytesMut) -> Result<()> {
        // Placeholder for the length prefix; filled in once the encoded size is known.
        request_buffer.reserve(2);
        request_buffer.put_u16(0);
        // Pass our maximum as the "udp size"; over TCP the length prefix is the real limit.
        ENCODER.encode(request, Some(MAX_TCP_BYTES), request_buffer)?;
        let encoded_len = request_buffer.len() - 2;
        let message_len = u16::try_from(encoded_len).with_context(|| {
            format!(
                "Encoded request size {} exceeds {} limit: {:?}",
                encoded_len, MAX_TCP_BYTES, request
            )
        })?;
        request_buffer[..2].copy_from_slice(&message_len.to_be_bytes());
        Ok(())
    }

    /// Sends `request_buffer` (as produced by `encode`) and decodes the reply into
    /// `response_fbb`.
    ///
    /// A fresh random transaction id is written into the request before sending.
    /// Returns `Ok(None)` if the decoder declines to parse the response (likely
    /// truncated) or if the response header has the truncated flag set.
    ///
    /// # Errors
    /// Fails on connect/send/receive errors or timeouts, parse errors, a missing
    /// header, or a transaction id mismatch. Whenever the TCP stream may have been
    /// left out of sync (I/O error, timeout, id mismatch) the connection is dropped
    /// so the next query reconnects.
    async fn query<'response>(
        &mut self,
        request_buffer: &mut BytesMut,
        response_fbb: &'response mut flatbuffers::FlatBufferBuilder<'_>,
    ) -> Result<Option<Message<'response>>> {
        if self.conn.is_none() {
            self.connect().await?;
        }
        let request_id = rand::thread_rng().gen::<u16>();
        // For TCP, the message itself starts after the 2-byte length prefix.
        message::update_message_id(request_id, request_buffer, 2)?;
        {
            // Ensure that response_buffer size is reset when we're done with it,
            // regardless of success or error.
            let mut response_buffer = scopeguard::guard(&mut self.response_buffer, |buf| {
                buf.resize(MAX_TCP_BYTES as usize, 0);
            });
            trace!(
                "Raw request to {:?} ({}b): {:02X?}",
                self.dns_server,
                request_buffer.len(),
                &request_buffer[..]
            );

            // Send the length-prefixed request.
            match time::timeout(
                Duration::from_millis(self.timeout_ms),
                self.conn
                    .as_mut()
                    .expect("missing connection")
                    .write_all(request_buffer.as_ref()),
            )
            .await
            {
                Ok(Ok(())) => {}
                Ok(Err(e)) => {
                    // Mark connection as dead, reconnect again on next query
                    self.conn = None;
                    return Err(e).with_context(|| "Error when sending request");
                }
                Err(_elapsed) => {
                    // The request may have been partially written; the stream can't be reused.
                    self.conn = None;
                    bail!(
                        "Send to {} timed out after {}ms",
                        self.dns_server,
                        self.timeout_ms
                    );
                }
            }

            // Read the 2-byte length prefix of the response.
            let response_size = match time::timeout(
                Duration::from_millis(self.timeout_ms),
                self.conn.as_mut().expect("missing connection").read_u16(),
            )
            .await
            {
                Ok(Ok(response_size)) => response_size,
                Ok(Err(e)) => {
                    // Mark connection as dead, reconnect again on next query
                    self.conn = None;
                    return Err(e).with_context(|| "Error when reading header");
                }
                Err(_elapsed) => {
                    // A late reply would otherwise be read as the response to the next query.
                    self.conn = None;
                    bail!(
                        "Recv from {} timed out reading header after {}ms",
                        self.dns_server,
                        self.timeout_ms
                    );
                }
            };

            // Read exactly `response_size` payload bytes.
            response_buffer.resize(usize::from(response_size), 0);
            match time::timeout(
                Duration::from_millis(self.timeout_ms),
                self.conn
                    .as_mut()
                    .expect("missing connection")
                    .read_exact(&mut response_buffer),
            )
            .await
            {
                Ok(Ok(recv_size)) => {
                    // read_exact fills the whole buffer, so this is a no-op safeguard.
                    response_buffer.truncate(recv_size);
                }
                Ok(Err(e)) => {
                    // Mark connection as dead, reconnect again on next query
                    self.conn = None;
                    return Err(e).with_context(|| "Error when reading payload");
                }
                Err(_elapsed) => {
                    // Unread payload bytes would otherwise be misread as the next response's
                    // length prefix.
                    self.conn = None;
                    bail!(
                        "Recv from {} timed out reading payload after {}ms",
                        self.dns_server,
                        self.timeout_ms
                    );
                }
            }
            trace!(
                "Raw response from {:?} ({}b): {:02X?}",
                self.dns_server,
                response_buffer.len(),
                &response_buffer[..]
            );

            // The full payload was consumed, so the stream stays in sync whatever the outcome.
            match DNSMessageDecoder::new().decode(&response_buffer[..], response_fbb) {
                Ok(true) => {
                    // Decoded into response_fbb
                }
                Ok(false) => {
                    // Message was likely truncated, despite us receiving all the data in the payload
                    debug!(
                        "Unable to parse response from server={} to request={:02X?}: {:02X?}",
                        self.dns_server,
                        &request_buffer[..],
                        &response_buffer[..],
                    );
                    return Ok(None);
                }
                Err(e) => {
                    // Other parse error
                    return Err(e).context(format!(
                        "Failed to parse response from server={} to request={:02X?}: {:02X?}",
                        self.dns_server,
                        &request_buffer[..],
                        &response_buffer[..],
                    ));
                }
            }
        }

        let response: Message<'_> =
            flatbuffers::get_root::<Message<'_>>(response_fbb.finished_data());
        debug!("Response from {:?}: {}", self.dns_server, response);
        let header = match response.header() {
            Some(header) => header,
            None => {
                bail!("Missing header in response");
            }
        };
        if header.truncated() {
            // Message claims to be truncated
            return Ok(None);
        }
        let response_id = header.id();
        if response_id != request_id {
            // Over TCP a mismatched id most likely means we read a stale reply to an
            // earlier query, so the stream is out of sync: drop it and reconnect next time.
            self.conn = None;
            bail!(
                "Returned transaction id {:?} doesn't match sent {:?}",
                response_id,
                request_id
            );
        }
        Ok(Some(response))
    }
}