use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use anyhow::{bail, Context, Result};
use async_lock::{Barrier, Mutex};
use bytes::BytesMut;
use lazy_static::lazy_static;
use tracing::{debug, warn};
use crate::cache;
use crate::client::DnsClient;
use crate::codec::{encoder::DNSMessageEncoder, message::RequestInfo};
use crate::specs::enums_generated::{OpCode, ResourceClass, ResourceType, ResponseCode};
use crate::specs::message::*;
lazy_static! {
    /// Encoder instance, currently doesn't have state.
    ///
    /// Shared by `Resolver::resolve_raw` to re-encode responses before they are returned
    /// to the requesting client.
    static ref ENCODER: DNSMessageEncoder = DNSMessageEncoder::new();
}
/// A DNS resolver that queries from a set of upstream DNS sources.
///
/// Lookups are first sent to a cache task over `cache_tx`; on a miss the upstream
/// `clients` are queried and the response is handed back to the cache task.
pub struct Resolver {
    /// Channel to the cache task, used both for cache fetches and cache stores.
    cache_tx: async_channel::Sender<cache::task::CacheMsg>,
    /// Upstream DNS clients, queried in order after a cache miss.
    clients: Vec<Box<dyn DnsClient + Send>>,
    /// Scratch buffer lent to upstream client queries; cleared before each query.
    client_buffer: BytesMut,
}
impl Resolver {
pub fn new(
cache_tx: async_channel::Sender<cache::task::CacheMsg>,
clients: Vec<Box<dyn DnsClient + Send>>,
) -> Self {
Resolver {
cache_tx,
clients,
client_buffer: BytesMut::with_capacity(4096),
}
}
/// A high-level query for getting an A/AAAA record for a given hostname.
pub async fn resolve_str(
&mut self,
host: &String,
port: u16,
get_ipv6: bool,
udp_size: u16,
) -> Result<SocketAddr> {
let resource_type = match get_ipv6 {
true => ResourceType::AAAA,
false => ResourceType::A,
};
// Ensure the host in the DNS message has the period:
let mut host_with_period = host.clone();
host_with_period.push('.');
let request_info = RequestInfo {
// The host used for cache/filter lookup meanwhile should not have the trailing period:
name: host.clone(),
resource_type,
received_request_id: 0,
requested_udp_size: udp_size,
};
let request = build_request(resource_type, host_with_period, udp_size);
let response = self.resolve(&request, &request_info)
.await
.with_context(|| format!("Failed to resolve host {:?}", host))?;
let results = extract_addresses(response, resource_type);
debug!(
"Resolved host={:?} type={:?}: {:?}",
host, resource_type, results
);
Ok(SocketAddr::new(
// If there are multiple IPs, just return the first one
*results.get(0).with_context(|| {
format!("No {:?} results for host: {:?}", resource_type, host)
})?,
port,
))
}
/// A low-level query that accepts and returns raw DNS query payloads.
pub async fn resolve_raw(
&mut self,
request: &Message,
request_info: &RequestInfo,
response_buffer: &mut BytesMut,
) -> Result<()> {
let response = self.resolve(request, request_info).await?;
debug!("Response to client: {}", response);
// Reencode the response to be returned, overriding udp_size to match the original request
ENCODER.encode(
&response,
Some(request_info.requested_udp_size),
response_buffer,
)
}
async fn resolve(&mut self, request: &Message, request_info: &RequestInfo) -> Result<Message> {
// Check if cache has cached result: Send request and wait for response via barrier+arc
let result_barrier = Arc::new(Barrier::new(2));
let result = Arc::new(Mutex::new(None));
self.cache_tx.send(cache::task::CacheMsg::Fetch(cache::task::CacheFetch {
request_info: (*request_info).clone(),
result_barrier: result_barrier.clone(),
result: result.clone(),
})).await.context("Failed to send cache fetch query")?;
// Wait on the barrier to complete
result_barrier.wait().await;
// Barrier has completed, get the stored result.
// Do a swap to get the result out without yet another copy.
match result.lock().await.replace(Ok(None)).expect("Missing fetch result following barrier") {
Ok(Some(cache_result)) => {
return Ok(cache_result)
},
Ok(None) => {
// cache miss - continue with upstream queries below
},
Err(e) => {
// cache fail - complain but continue with upstream queries
warn!("Cache lookup failed for request {:?}: {}", request_info, e)
},
}
// Cache didn't have anything, so query upstream clients.
for client in &mut self.clients {
// Mark the client buffer as empty so that we don't append on top of a prior request
self.client_buffer.clear();
if let Some(mut response) = client.query(request, &mut self.client_buffer).await? {
// Store fetched result to cache (no response needed)
self.cache_tx.send(cache::task::CacheMsg::Store(cache::task::CacheStore{
request_info: (*request_info).clone(),
response: response.clone(),
})).await.context("Failed to send cache store query")?;
// Set the message ID for the response so that it matches the original request.
// Keeping the message IDs independent reduces the likelihood of cache poisoning.
response.header.id = request_info.received_request_id;
return Ok(response);
}
}
bail!("All upstreams failed to return a response");
}
}
/// Builds a recursive query asking for `resource_type` records of `domain` in the
/// INTERNET class.
///
/// The header uses ID 0 with the RD and AD bits set. An EDNS `OPT` record advertises
/// `udp_size` and sets the DO bit. The answer, authority and additional sections are empty.
fn build_request(resource_type: ResourceType, domain: String, udp_size: u16) -> Message {
    let header = Header {
        id: 0,
        is_response: false,
        op_code: IntEnum::Enum(OpCode::QUERY),
        authoritative: false,
        truncated: false,
        recursion_desired: true,
        recursion_available: false,
        reserved_9: false,
        authentic_data: true,
        checking_disabled: false,
        response_code: IntEnum::Enum(ResponseCode::NOERROR),
    };
    let edns = OPT {
        option: Vec::new(),
        udp_size,
        response_code: 0,
        version: 0,
        dnssec_ok: true,
    };
    let question = vec![Question {
        name: domain,
        resource_type: IntEnum::Enum(resource_type),
        resource_class: IntEnum::Enum(ResourceClass::INTERNET),
    }];
    Message {
        header,
        opt: Some(edns),
        question,
        answer: Vec::new(),
        authority: Vec::new(),
        additional: Vec::new(),
    }
}
/// Collects the addresses from `message`'s answer section whose record type equals
/// `resource_type`, in answer order.
///
/// Matching records whose data is not of the expected kind are skipped.
fn extract_addresses(message: Message, resource_type: ResourceType) -> Vec<IpAddr> {
    let wanted = IntEnum::Enum(resource_type);
    message
        .answer
        .iter()
        .filter(|record| record.resource_type == wanted)
        .filter_map(|record| extract_address(record, resource_type))
        .collect()
}
/// Converts one answer record into an IP address, interpreting it as `resource_type`.
///
/// Returns `None` when the record's data does not match the requested kind
/// (A data for an A request, AAAA data for an AAAA request).
///
/// # Panics
/// Panics if `resource_type` is neither `A` nor `AAAA`.
fn extract_address(answer: &Resource, resource_type: ResourceType) -> Option<IpAddr> {
    match resource_type {
        ResourceType::A => match &answer.rdata {
            ResourceData::A(v4) => Some(IpAddr::V4(Ipv4Addr::new(
                v4.address1,
                v4.address2,
                v4.address3,
                v4.address4,
            ))),
            _ => None,
        },
        ResourceType::AAAA => match &answer.rdata {
            ResourceData::AAAA(v6) => Some(IpAddr::V6(Ipv6Addr::new(
                v6.address1,
                v6.address2,
                v6.address3,
                v6.address4,
                v6.address5,
                v6.address6,
                v6.address7,
                v6.address8,
            ))),
            _ => None,
        },
        _ => panic!(
            "Unsupported resource type for address extraction: {:?}",
            answer.resource_type
        ),
    }
}