#![deny(warnings, rust_2018_idioms)]
use std::future::Future;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::pin::Pin;
use std::task::{self, Poll};
use std::vec;
use anyhow::{Context, Result};
use bytes::BytesMut;
use hyper::client::connect::dns::Name;
use tokio::task::JoinHandle;
use tower_service::Service;
use tracing::{debug, trace};
use crate::client::{self, DnsClient};
use crate::fbs::dns_enums_generated::{OpCode, ResourceClass, ResourceType};
use crate::fbs::dns_message_generated::*;
/// Addresses produced by a completed lookup; the `Response` type of the `Resolver` service.
#[derive(Debug)]
pub struct ReturnedAddrs {
    addrs: vec::IntoIter<IpAddr>,
}

impl Iterator for ReturnedAddrs {
    type Item = IpAddr;

    fn next(&mut self) -> Option<Self::Item> {
        self.addrs.next()
    }

    // The inner `vec::IntoIter` knows exactly how many addresses remain; forward that so
    // consumers (e.g. `collect`) can preallocate instead of seeing the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.addrs.size_hint()
    }
}

// Sound because `size_hint` above is forwarded from `vec::IntoIter`, which is exact.
impl ExactSizeIterator for ReturnedAddrs {}
/// Future returned by the `Resolver` service: waits on the spawned lookup task and wraps the
/// resolved addresses in a [`ReturnedAddrs`].
pub struct ResolveFuture {
    hosts: JoinHandle<Result<Vec<IpAddr>>>,
}

impl Future for ResolveFuture {
    type Output = Result<ReturnedAddrs>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        // Nothing to do until the spawned lookup task has finished.
        let joined = match Pin::new(&mut self.hosts).poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(joined) => joined,
        };
        // A join failure (task panicked or was cancelled) gets wrapped with context. A lookup
        // error produced inside the task is passed through unchanged.
        let outcome = joined
            .context("Waiting for DNS lookup failed")
            .and_then(|lookup| lookup)
            .map(|found| ReturnedAddrs {
                addrs: found.into_iter(),
            });
        Poll::Ready(outcome)
    }
}
/// A hyper-compatible DNS resolver that queries an upstream DNS server directly over UDP.
/// It may be passed to a hyper HTTP client so that the client doesn't use the default system resolver.
/// In particular, this avoids a possible infinite loop if that system resolver routes requests back to Kapiti.
#[derive(Clone, Copy, Debug)]
pub struct Resolver {
    // TODO(#3) we can't just pass a DnsClient here due to the Copy requirement. Time for crossbeam.
    /// Address of the upstream DNS server that each lookup is sent to.
    dns_server: SocketAddr,
    /// When `true`, lookups request AAAA (IPv6) records; otherwise A (IPv4) records.
    get_ipv6: bool,
    /// UDP payload size advertised in the OPT record of each outgoing query.
    udp_size: u16,
}

impl Resolver {
    /// Creates a resolver that sends its queries to `dns_server`.
    ///
    /// `get_ipv6` selects AAAA instead of A lookups, and `udp_size` is the UDP payload size
    /// advertised in each query's OPT record.
    pub fn new(dns_server: SocketAddr, get_ipv6: bool, udp_size: u16) -> Self {
        Self {
            dns_server,
            get_ipv6,
            udp_size,
        }
    }
}
impl Service<Name> for Resolver {
type Response = ReturnedAddrs;
type Error = anyhow::Error;
type Future = ResolveFuture;
fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<()>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, name: Name) -> Self::Future {
// Name only exposes str, convert back to string
let mut name_string = name.as_str().to_string();
// Also add the trailing '.' that we want
name_string.push('.');
let resource_type = match self.get_ipv6 {
true => (ResourceType::TYPE_AAAA),
false => (ResourceType::TYPE_A),
};
trace!("Resolving host={:?} type={:?}", name_string, resource_type);
// Get copies for async move:
let dns_server_copy = self.dns_server.clone();
let udp_size_copy = self.udp_size;
let query = tokio::task::spawn(async move {
// TODO(#11) OPTIMIZE:
// - reuse client across resolves, and use configured client type instead of always trying UDP
// - reuse fb_builder,request_buffer across resolves
// - for now we rebuild each time to avoid lifetime issues with the async callback
// - not urgent since we only use this resolver for updating filters anyway
let mut client = client::udp::Client::new(dns_server_copy, 10000);
let mut fb_builder = flatbuffers::FlatBufferBuilder::new_with_capacity(1024);
let mut request_buffer = BytesMut::new();
build_request(&mut fb_builder, resource_type, &name_string, udp_size_copy);
{
let request: Message<'_> =
flatbuffers::get_root::<Message<'_>>(fb_builder.finished_data());
client
.encode(&request, &mut request_buffer)
.with_context(|| "failed to construct DNS request")?;
}
// Request has been encoded, so we can reuse this buffer for the response
fb_builder.reset();
let response = client
.query(&mut request_buffer, &mut fb_builder)
.await
.with_context(|| format!("Failed to resolve host {:?}", name_string))?
.with_context(|| {
format!(
"Failed to parse truncated/incomplete response to {} lookup",
name_string
)
})?;
let results = extract_addresses(response, resource_type);
debug!(
"Resolved host={:?} type={:?}: {:?}",
name_string, resource_type, results
);
Ok(results)
});
ResolveFuture { hosts: query }
}
}
/// Serializes a single-question DNS query for `domain` into `fb_builder` and finishes the buffer.
///
/// Afterwards the message can be read back via `fb_builder.finished_data()`. It contains:
/// - a header with ID 0, opcode QUERY, and the recursion-desired and authentic-data bits set;
/// - an OPT record advertising `udp_size`, with the DNSSEC OK flag set;
/// - one question for `domain` with type `resource_type` and class INTERNET.
///
/// `domain` is written as-is. It is presumably expected to be fully-qualified (with a
/// trailing '.'), as the resolver's lookup path builds it that way.
fn build_request(
    fb_builder: &mut flatbuffers::FlatBufferBuilder,
    resource_type: ResourceType,
    domain: &str,
    udp_size: u16,
) {
    let header = Header::create(
        fb_builder,
        &HeaderArgs {
            // NOTE(review): a fixed transaction ID of 0 makes responses easier to spoof.
            // Confirm whether the client assigns or verifies a random ID.
            id: 0,
            is_response: false,
            op_code: OpCode::OP_QUERY as u8,
            authoritative: false,
            truncated: false,
            recursion_desired: true,
            recursion_available: false,
            reserved_9: false,
            authentic_data: true,
            checking_disabled: false,
            response_code: 0,
        },
    );
    let opt = OPT::create(
        fb_builder,
        &OPTArgs {
            option: None,
            udp_size,
            response_code: 0,
            version: 0,
            dnssec_ok: true,
        },
    );
    let question_name = fb_builder.create_string(domain);
    let question = Question::create(
        fb_builder,
        &QuestionArgs {
            name: Some(question_name),
            resource_type: resource_type as u16,
            resource_class: ResourceClass::CLASS_INTERNET as u16,
        },
    );
    let message_question = fb_builder.create_vector(&[question]);
    let message_offset = Message::create(
        fb_builder,
        &MessageArgs {
            header: Some(header),
            opt: Some(opt),
            question: Some(message_question),
            answer: None,
            authority: None,
            additional: None,
        },
    );
    fb_builder.finish_minimal(message_offset);
}
/// Collects the addresses from every answer record in `message` whose type equals
/// `resource_type`.
///
/// Answers of other types, and answers whose data yields no address, are skipped. A message
/// without an answer section produces an empty Vec.
fn extract_addresses(message: Message<'_>, resource_type: ResourceType) -> Vec<IpAddr> {
    let wanted = resource_type as u16;
    let answers = match message.answer() {
        None => return Vec::new(),
        Some(section) => section,
    };
    let mut found = Vec::with_capacity(answers.len());
    found.extend(
        (0..answers.len())
            .map(|idx| answers.get(idx))
            .filter(|record| record.resource_type() == wanted)
            .filter_map(|record| extract_address(record, resource_type)),
    );
    found
}
/// Extracts the IP address carried by a single answer record.
///
/// Returns `None` if the record has no rdata, or if the rdata lacks the A/AAAA payload
/// matching `resource_type`.
///
/// # Panics
///
/// Panics if the record has rdata and `resource_type` is neither `TYPE_A` nor `TYPE_AAAA`.
/// Only those two types are supported for address extraction.
fn extract_address(answer: Resource<'_>, resource_type: ResourceType) -> Option<IpAddr> {
    let rdata = answer.rdata()?;
    if resource_type == ResourceType::TYPE_A {
        rdata.a().map(|a| {
            IpAddr::V4(Ipv4Addr::new(
                a.address1(),
                a.address2(),
                a.address3(),
                a.address4(),
            ))
        })
    } else if resource_type == ResourceType::TYPE_AAAA {
        rdata.aaaa().map(|aaaa| {
            IpAddr::V6(Ipv6Addr::new(
                aaaa.address1(),
                aaaa.address2(),
                aaaa.address3(),
                aaaa.address4(),
                aaaa.address5(),
                aaaa.address6(),
                aaaa.address7(),
                aaaa.address8(),
            ))
        })
    } else {
        panic!(
            "Unsupported resource type for address extraction: {:?}",
            answer.resource_type()
        )
    }
}