use std::io::{self, Write}; use std::time::Duration; use anyhow::{bail, Context, Result}; use async_io::Timer; use async_trait::async_trait; use bytes::{BufMut, BytesMut}; use futures_lite::FutureExt; use hyper::header; use hyper::{Body, Client as HttpClient, Method, Uri}; use lazy_static::lazy_static; use rand::Rng; use scopeguard; use tracing::debug; use crate::client::DnsClient; use crate::codec::{decoder::DNSMessageDecoder, encoder::DNSMessageEncoder, message}; use crate::http::Fetcher; use crate::hyper_smol; use crate::resolver; use crate::specs::message::Message; static MAX_HTTP_BYTES: u16 = 65535; lazy_static! { /// Encoder instance, currently doesn't have state static ref ENCODER: DNSMessageEncoder = DNSMessageEncoder::new(); } pub struct Client { server_url: Uri, fetcher: Fetcher, client: HttpClient, timeout_ms: u64, response_buffer: BytesMut, } /// DNS Client that queries a server over HTTPS (DoH/RFC8484) impl Client { /// Constructs a new `Client` that will query the specified `dns_server`. pub fn new(server_url: Uri, resolver: resolver::Resolver, timeout_ms: u64) -> Result { Ok(Client { server_url, fetcher: Fetcher::new( MAX_HTTP_BYTES as usize, Some("application/dns-message".to_string()), ) // Note that hyper will reject requests with "request has unsupported HTTP version", // unless we ALSO set "http2_only(true)" in the Client builder. .use_http_2(), client: hyper_smol::client_kapiti(resolver, true, false, 4096), timeout_ms, response_buffer: BytesMut::with_capacity(MAX_HTTP_BYTES as usize), }) } } #[async_trait] impl DnsClient for Client { async fn query(&mut self, request: &Message, query_buffer: &mut BytesMut) -> Result> { // Just use our max for the "udp size" ENCODER.encode(request, Some(MAX_HTTP_BYTES), query_buffer)?; let request_id = rand::thread_rng().gen::(); message::update_message_id(request_id, query_buffer, 0)?; // Ensure that response_buffer size is reset when we're done with it, // regardless of success or error. The socket read is based on len, not capacity. let mut response_buffer = scopeguard::guard(&mut self.response_buffer, |buf| { buf.clear(); }); debug!( "Raw request to {} ({}b): {:02X?}", self.server_url, query_buffer.len(), &query_buffer[..] ); // Hyper apparently requires a static copy of the body content? No idea why, and don't care to find out. // Maybe we can clean this up by switching to something with a better API someday. let request_copy = hyper::body::Bytes::from(query_buffer.clone()); let request = self .fetcher .request_builder(&Method::POST, &self.server_url) .header(header::CONTENT_TYPE, "application/dns-message") .header(header::CONTENT_LENGTH, request_copy.len()) .body(Body::from(request_copy)) .context("Failed to build DoH request")?; // Copy for async call: let timeout_ms = self.timeout_ms; let mut response = self .client .request(request) .or(async { Timer::after(Duration::from_millis(timeout_ms)).await; // hyper keeps error types crate-private. Jump through hoops to produce an Ok response with an error. let response = hyper::Response::new(hyper::Body::empty()); let (mut parts, body) = response.into_parts(); parts.status = http::StatusCode::GATEWAY_TIMEOUT; Ok(hyper::Response::from_parts(parts, body)) }) .await .context("DoH query failed")?; if !response.status().is_success() { bail!( "HTTP POST to {} returned status: {}", self.server_url, response.status() ); } { // Write response payload into response_buffer let mut writer = BytesWriter::new(&mut response_buffer); self.fetcher .write_response(&self.server_url.to_string(), &mut writer, &mut response) .await?; } debug!( "Raw response from {} ({}b): {:02X?}", self.server_url, response_buffer.len(), &response_buffer[..] ); match DNSMessageDecoder::new().decode(&response_buffer[..]) { Ok(Some(response)) => { debug!("Response from {}: {}", self.server_url, response); if response.header.truncated { // Message claims to be truncated, shouldn't happen but let's bail anyway return Ok(None); } if response.header.id != request_id { bail!( "Returned transaction id {:?} doesn't match sent {:?}", response.header.id, request_id ); } Ok(Some(response)) } Ok(None) => { // Message was likely corrupted somehow, despite us receiving all the data in the payload debug!( "Unable to parse response from server={} to request={:02X?}: {:02X?}", self.server_url, &query_buffer[..], &response_buffer[..], ); Ok(None) } Err(e) => { // Other parse error Err(e).context(format!( "Failed to parse response from server={} to request={:02X?}: {:02X?}", self.server_url, &query_buffer[..], &response_buffer[..], )) } } } } /// Pass-through writer that counts the number of bytes that have been written. /// Used to consistently measure the decompressed size of a download. struct BytesWriter<'a> { inner: &'a mut BytesMut, } impl<'a> BytesWriter<'a> { fn new(inner: &'a mut BytesMut) -> BytesWriter<'a> { BytesWriter { inner } } } impl<'a> Write for BytesWriter<'a> { fn write(&mut self, buf: &[u8]) -> io::Result { if self.inner.remaining_mut() >= buf.len() { self.inner.put_slice(buf); Ok(buf.len()) } else { Err(io::Error::new( io::ErrorKind::InvalidInput, format!( "Unable to write {} bytes into buffer: {}/{} remaining", buf.len(), self.inner.remaining_mut(), self.inner.len() ), )) } } fn flush(&mut self) -> io::Result<()> { Ok(()) } }