#![deny(warnings)]
use std::convert::TryFrom;
use std::io::{self, Write};
use std::time::Duration;
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use bytes::{Buf, BufMut, BytesMut};
use hyper::client::HttpConnector;
use hyper::header;
use hyper::{Body, Client as HttpClient, Method, Uri};
use hyper_rustls::HttpsConnector;
use lazy_static::lazy_static;
use rand::Rng;
use rustls::ClientConfig;
use scopeguard;
use tracing::{debug, trace, warn};
use crate::client::{hyper::Resolver, DnsClient};
use crate::codec::{decoder::DNSMessageDecoder, encoder::DNSMessageEncoder, message};
use crate::fbs::dns_message_generated::Message;
use crate::http::Fetcher;
/// Size limit, in bytes, used throughout this client: it is passed to `Fetcher::new`, used as the
/// capacity of the response buffer, and given to the encoder as the "udp size" of a request.
static MAX_HTTP_BYTES: u16 = 65535;
lazy_static! {
/// Shared encoder instance. It currently doesn't have state, so one instance serves every request.
static ref ENCODER: DNSMessageEncoder = DNSMessageEncoder::new();
}
/// DNS client that queries a server over HTTPS (DoH/RFC8484).
pub struct Client {
/// URL of the DoH endpoint that every query is POSTed to.
query_url: Uri,
/// Helper used to build the HTTP requests and to write response bodies into a buffer.
fetcher: Fetcher,
/// HTTP client that actually sends the requests.
client: HttpClient<HttpsConnector<HttpConnector<Resolver>>, Body>,
/// Scratch buffer that receives each response payload; it is cleared at the end of every query.
response_buffer: BytesMut,
}
/// Construction of the DNS-over-HTTPS (DoH/RFC8484) client.
impl Client {
    /// Creates a `Client` that sends its queries to the DoH endpoint at `server_url`.
    ///
    /// `resolver` and `timeout_ms` are forwarded to `build_fetch_client`, which uses them to
    /// resolve and connect to the endpoint itself.
    ///
    /// # Errors
    /// Fails if `server_url` cannot be parsed as a URI, or if the HTTP client cannot be built.
    pub fn new(resolver: Resolver, server_url: String, timeout_ms: u64) -> Result<Self> {
        let max_bytes = MAX_HTTP_BYTES as usize;

        let query_url = Uri::try_from(&server_url)
            .with_context(|| format!("Failed to create URI for DoH server: {}", server_url))?;

        // Note that hyper will reject requests with "request has unsupported HTTP version",
        // unless "http2_only(true)" is ALSO set in the Client builder.
        let fetcher =
            Fetcher::new(max_bytes, Some(String::from("application/dns-message"))).use_http_2();

        let client = build_fetch_client(resolver, timeout_ms)?;
        let response_buffer = BytesMut::with_capacity(max_bytes);

        Ok(Self {
            query_url,
            fetcher,
            client,
            response_buffer,
        })
    }
}
/// Builds the HTTP/2 client used to reach the configured DoH server.
///
/// The provided `resolver` is used to resolve the HTTPS endpoint itself, and `timeout_ms` is
/// applied to both the connect timeout and the happy-eyeballs timeout.
///
/// # Errors
/// Fails if the native TLS certificate store cannot be loaded at all. A partially loaded store
/// only produces a warning.
fn build_fetch_client(
    resolver: Resolver,
    timeout_ms: u64,
) -> Result<HttpClient<HttpsConnector<HttpConnector<Resolver>>, Body>> {
    let timeout = Duration::from_millis(timeout_ms);

    let mut http_connector = HttpConnector::<_>::new_with_resolver(resolver);
    http_connector.set_connect_timeout(Some(timeout));
    http_connector.set_happy_eyeballs_timeout(Some(timeout));
    http_connector.set_keepalive(Some(Duration::from_secs(90)));
    // Without this the plain connector errors out on https urls before the HTTPS wrapper sees
    // them, see also HttpsConnector::new_():
    http_connector.enforce_http(false);

    // TLS configuration for the HTTPS connector that wraps the HTTP connector.
    // Allows HTTPS but doesn't require it.
    let mut https_config = ClientConfig::new();
    https_config.alpn_protocols = vec![b"h2".to_vec()];

    let root_store = match rustls_native_certs::load_native_certs() {
        Ok(store) => store,
        Err((Some(partial_store), err)) => {
            warn!(
                "Some TLS certificates failed to load, trying to continue without them: {:?}",
                err
            );
            partial_store
        }
        Err((None, err)) => {
            return Err(err).with_context(|| "Failed to load native TLS cert store");
        }
    };
    https_config.root_store = root_store;
    https_config.ct_logs = Some(&ct_logs::LOGS);

    let https_connector = HttpsConnector::from((http_connector, https_config));
    let client = HttpClient::builder()
        // Required to avoid "request has unsupported HTTP version" errors when sending HTTP/2 requests
        .http2_only(true)
        .build::<HttpsConnector<_>, Body>(https_connector);
    Ok(client)
}
#[async_trait]
impl DnsClient for Client {
    /// Encodes `request` into `request_buffer` in DNS wire format.
    ///
    /// `MAX_HTTP_BYTES` is passed to the encoder as the "udp size", since that is the largest
    /// payload this client accepts.
    fn encode(&self, request: &Message<'_>, request_buffer: &mut BytesMut) -> Result<()> {
        ENCODER.encode(request, Some(MAX_HTTP_BYTES), request_buffer)
    }

    /// Sends the encoded request in `request_buffer` to the DoH server as an HTTP POST and decodes
    /// the reply into `response_fbb`.
    ///
    /// A fresh random message id is written into the request before it is sent, and the response
    /// must echo that id. The contents of `request_buffer` are consumed by the call.
    ///
    /// # Returns
    /// - `Ok(Some(message))` when a response was received, decoded, and matched the request id.
    /// - `Ok(None)` when the decoder reports that the payload could not be parsed as a complete
    ///   message, or when the response header has the truncated flag set.
    ///
    /// # Errors
    /// Fails when the message id cannot be updated, the HTTP request cannot be built or sent, the
    /// server answers with a non-success status, the response body cannot be read or decoded, the
    /// response has no header, or the response id differs from the request id.
    async fn query<'response>(
        &mut self,
        request_buffer: &mut BytesMut,
        response_fbb: &'response mut flatbuffers::FlatBufferBuilder<'_>,
    ) -> Result<Option<Message<'response>>> {
        let request_id = rand::thread_rng().gen::<u16>();
        message::update_message_id(request_id, request_buffer, 0)?;
        {
            // Ensure that response_buffer is emptied when we're done with it, regardless of
            // success or error, so that the next query starts appending from length zero.
            let mut response_buffer = scopeguard::guard(&mut self.response_buffer, |buf| {
                buf.clear();
            });
            trace!(
                "Raw request to {} ({}b): {:02X?}",
                self.query_url,
                request_buffer.len(),
                &request_buffer[..]
            );

            // `to_bytes()` drains `request_buffer`, so hold on to the extracted payload: it is
            // needed both for the HTTP body and for the diagnostics further down (which would
            // otherwise always print an empty request). Cloning `Bytes` only bumps a refcount.
            let request_bytes = request_buffer.to_bytes();
            let request = self
                .fetcher
                .request_builder(&Method::POST, &self.query_url)
                .header(header::CONTENT_TYPE, "application/dns-message")
                .header(header::CONTENT_LENGTH, request_bytes.len())
                .body(Body::from(request_bytes.clone()))
                .with_context(|| "Failed to build DoH request")?;

            let mut response = self
                .client
                .request(request)
                .await
                .with_context(|| "DoH query failed")?;
            if !response.status().is_success() {
                bail!(
                    "HTTP POST to {} returned status: {}",
                    self.query_url,
                    response.status()
                );
            }

            {
                // Write response payload into response_buffer
                let mut writer = BytesWriter::new(&mut response_buffer);
                self.fetcher
                    .write_response(&self.query_url.to_string(), &mut writer, &mut response)
                    .await?;
            }
            trace!(
                "Raw response from {} ({}b): {:02X?}",
                self.query_url,
                response_buffer.len(),
                &response_buffer[..]
            );

            match DNSMessageDecoder::new().decode(&response_buffer[..], response_fbb) {
                Ok(true) => {
                    // Decoded into response_fbb
                }
                Ok(false) => {
                    // Message was likely truncated, despite us receiving all the data in the payload
                    debug!(
                        "Unable to parse response from server={} to request={:02X?}: {:02X?}",
                        self.query_url,
                        &request_bytes[..],
                        &response_buffer[..],
                    );
                    return Ok(None);
                }
                Err(e) => {
                    // Other parse error
                    return Err(e).context(format!(
                        "Failed to parse response from server={} to request={:02X?}: {:02X?}",
                        self.query_url,
                        &request_bytes[..],
                        &response_buffer[..],
                    ));
                }
            }
        }

        let response: Message<'_> =
            flatbuffers::get_root::<Message<'_>>(response_fbb.finished_data());
        debug!("Response from {}: {}", self.query_url, response);

        let response_id: u16 = match response.header() {
            Some(header) => {
                if header.truncated() {
                    // Message claims to be truncated
                    return Ok(None);
                }
                header.id()
            }
            None => {
                bail!("Missing header in response");
            }
        };
        // Reject replies that don't echo the id we generated for this request.
        if response_id != request_id {
            bail!(
                "Returned transaction id {:?} doesn't match sent {:?}",
                response_id,
                request_id
            );
        }
        Ok(Some(response))
    }
}
/// Adapter that exposes a borrowed `BytesMut` through `std::io::Write`, so that code which
/// writes to an `io::Write` sink can append directly into the buffer.
/// It keeps no byte counter of its own; the buffer's `len()` reflects what has been appended.
struct BytesWriter<'a> {
/// Destination buffer that written bytes are appended to.
inner: &'a mut BytesMut,
}
impl<'a> BytesWriter<'a> {
fn new(inner: &'a mut BytesMut) -> BytesWriter<'a> {
BytesWriter { inner }
}
}
impl<'a> Write for BytesWriter<'a> {
    /// Appends all of `buf` to the wrapped buffer, or fails without writing anything.
    ///
    /// The write is refused with `InvalidInput` when `buf` does not fit into the buffer's
    /// already-allocated spare capacity, so the buffer never grows beyond the size it was
    /// created with.
    ///
    /// The check deliberately uses `capacity() - len()` rather than `BufMut::remaining_mut()`:
    /// in `bytes` releases where `BytesMut` grows implicitly, `remaining_mut()` reports
    /// `usize::MAX - len`, which would make the limit ineffective.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let capacity = self.inner.capacity();
        let spare = capacity.saturating_sub(self.inner.len());
        if buf.len() <= spare {
            self.inner.put_slice(buf);
            Ok(buf.len())
        } else {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "Unable to write {} bytes into buffer: {}/{} remaining",
                    buf.len(),
                    spare,
                    capacity
                ),
            ))
        }
    }

    /// Nothing is buffered outside of the wrapped `BytesMut`, so flushing is a no-op.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}