use std::net::IpAddr;
use std::sync::{Arc, Mutex};
use anyhow::{bail, Result};
use bytes::BytesMut;
use lazy_static::lazy_static;
use redis;
use tracing::{debug, level_enabled, trace, Level};
use crate::codec::{decoder::DNSMessageDecoder, encoder::DNSMessageEncoder, message::RequestInfo};
use crate::specs::enums_generated;
use crate::specs::message::{IntEnum, Message, Question, OPT};
use crate::filter::{filter::Filter, reader};
use crate::resolver::Resolver;
// Process-wide statics, initialized lazily by `lazy_static` on first access.
lazy_static! {
/// Redis Lua script that fetches both the value and the TTL of a key in a single query.
/// It returns a two-element array: the result of `get` followed by the result of `ttl`, both for `KEYS[1]`.
static ref GET_WITH_TTL: redis::Script = redis::Script::new("return {redis.call('get',KEYS[1]), redis.call('ttl',KEYS[1])}");
/// Shared encoder instance; the encoder currently doesn't hold any state.
static ref ENCODER: DNSMessageEncoder = DNSMessageEncoder::new();
/// Source name reported in place of a file path when a filter match
/// has no associated file info (i.e. the target was hardcoded).
static ref HARDCODED_SOURCE_NAME: String = "hardcoded".to_string();
}
/// Answers DNS queries: the local filter is consulted first, and queries
/// without a filter match are handed to the resolver.
pub struct Lookup {
/// Resolver that a query is passed to when the filter has no entry for the requested name.
resolver: Resolver,
/// Filter shared behind a mutex; it is locked only for the duration of a single name check.
filter: Arc<Mutex<Filter>>,
}
impl Lookup {
    /// Builds a `Lookup` from the upstream `resolver` and the shared, mutex-guarded `filter`.
    pub fn new(resolver: Resolver, filter: Arc<Mutex<Filter>>) -> Lookup {
        Lookup { resolver, filter }
    }

    /// Handles the single query held in `packet_buffer` and leaves the response in that same buffer.
    ///
    /// The queried name is checked against the filter first. On a match the response is encoded
    /// locally; otherwise the request is passed on to the resolver.
    ///
    /// # Errors
    /// Fails when the request cannot be decoded or is incomplete, when it carries no usable
    /// question, when the filter mutex cannot be locked, when the local response cannot be
    /// written (or, with TRACE logging enabled, re-parsed), or when the resolver fails.
    pub async fn handle_query(self: &mut Lookup, packet_buffer: &mut BytesMut) -> Result<()> {
        // We need the decoded request to know which name/type is being asked for.
        let request = match DNSMessageDecoder::new().decode(packet_buffer)? {
            Some(decoded) => decoded,
            None => bail!("Failed to parse incomplete request"),
        };
        debug!("Incoming request: {}", request);

        let (question, request_info) = match get_question(&request)? {
            Some(found) => found,
            None => bail!("Missing question in request"),
        };

        // The guard only lives inside the `Ok` arm, so the lock is released before any
        // further work happens. Cloning the matched data lets us drop the lock right away.
        let filter_hit: Option<(String, reader::FilterEntry)> = match self.filter.lock() {
            Ok(guard) => guard
                .check(&request_info.name)
                .map(|(file_info, file_entry)| {
                    // Entries without file info are reported under the hardcoded source name.
                    let source = match file_info {
                        Some(info) => info.source_path.clone(),
                        None => HARDCODED_SOURCE_NAME.clone(),
                    };
                    (source, (*file_entry).clone())
                }),
            Err(e) => bail!("Failed to lock query filter: {:?}", e),
        };

        if let Some((file_source_path, entry)) = filter_hit {
            // The filter matched: replace the request bytes with a locally built response.
            packet_buffer.clear();
            write_filter_response(
                packet_buffer,
                &request_info,
                question,
                &request.opt,
                &file_source_path,
                &entry,
            )?;
            if level_enabled!(Level::TRACE) {
                match DNSMessageDecoder::new().decode(packet_buffer)? {
                    Some(response) => {
                        trace!("Returning response from filter: {}", response);
                    }
                    // We produced these bytes ourselves, so this would point at a codec bug.
                    None => bail!(
                        "Failed to re-parse response from filter ({}b): {:02X?}",
                        packet_buffer.len(),
                        &packet_buffer[..]
                    ),
                }
            }
            return Ok(());
        }

        // Nothing matched locally: hand the request to the resolver.
        trace!(
            "No filter nor cache entry found for {}, performing upstream query",
            request_info.name
        );
        // The same buffer carries the response back to the caller.
        packet_buffer.clear();
        self.resolver
            .resolve_raw(&request, &request_info, packet_buffer)
            .await
    }
}
/// Writes the response for a filter match into `packet_buffer`, based on the destination
/// info in the retrieved filter entry.
///
/// The response depends on the entry and the requested record type:
/// - entry with neither an IPv4 nor an IPv6 destination: `NXDOMAIN` (blocked domain);
/// - `A` request: `NOERROR` with the entry's IPv4 destination, if any;
/// - `AAAA` request: `NOERROR` with the entry's IPv6 destination, if any;
/// - any other record type: `NOERROR` without an address.
///
/// `filter_source` names where the entry came from; together with the entry's line number
/// (when present) it is passed to the encoder as `source` or `source:line`.
///
/// # Errors
/// Propagates any error returned by the encoder.
fn write_filter_response(
    packet_buffer: &mut BytesMut,
    request_info: &RequestInfo,
    question: &Question,
    opt: &Option<OPT>,
    filter_source: &str,
    entry: &reader::FilterEntry,
) -> Result<()> {
    let filter_info = match entry.line_num {
        Some(line_num) => format!("{}:{}", filter_source, line_num),
        None => filter_source.to_string(),
    };

    // Work out the response code and the (optional) address first, so that the
    // encoder is invoked from exactly one place.
    let (response_code, dest_ip) = if entry.dest_ipv4.is_none() && entry.dest_ipv6.is_none() {
        // Return blocked domain
        debug!(
            "Got block entry for {} from {} line {:?}: dest=NXDOMAIN",
            request_info.name, filter_source, entry.line_num
        );
        (enums_generated::ResponseCode::NXDOMAIN, None)
    } else if request_info.resource_type == enums_generated::ResourceType::A {
        // Return configured IPv4/A override
        debug!(
            "Got override {:?} entry for {} from {} line {:?}: dest={:?}",
            request_info.resource_type,
            request_info.name,
            filter_source,
            entry.line_num,
            entry.dest_ipv4
        );
        (
            enums_generated::ResponseCode::NOERROR,
            entry.dest_ipv4.map(IpAddr::V4),
        )
    } else if request_info.resource_type == enums_generated::ResourceType::AAAA {
        // Return configured IPv6/AAAA override
        debug!(
            "Got override {:?} entry for {} from {} line {:?}: dest={:?}",
            request_info.resource_type,
            request_info.name,
            filter_source,
            entry.line_num,
            entry.dest_ipv6
        );
        (
            enums_generated::ResponseCode::NOERROR,
            entry.dest_ipv6.map(IpAddr::V6),
        )
    } else {
        // Misc record type, for a domain that's got A and/or AAAA overrides: Record not found.
        // It's a little ambiguous whether we should instead try going upstream if this happens.
        // But if an upstream server had the right information, there would be little reason
        // to put custom host entries locally. Therefore we explicitly do NOT support misc
        // record types like MX and SRV for hostnames that have an override entry.
        debug!(
            "Got override {:?} entry for {} from {} line {:?}: dest=NONE",
            request_info.resource_type, request_info.name, filter_source, entry.line_num
        );
        (enums_generated::ResponseCode::NOERROR, None)
    };

    ENCODER.encode_local_response(
        response_code,
        request_info.received_request_id,
        question,
        opt,
        &filter_info,
        dest_ip,
        Some(request_info.requested_udp_size),
        packet_buffer,
    )
}
/// Picks the first usable question out of `request` and summarises it as a [`RequestInfo`].
///
/// A question is usable when its class is `INTERNET` and its resource type is a known
/// enum value; all other questions are skipped.
///
/// Returns `Ok(Some((question, info)))` for the first usable question, or `Ok(None)` when
/// the request contains none. The current implementation never returns `Err`.
pub fn get_question<'a>(request: &'a Message) -> Result<Option<(&'a Question, RequestInfo)>> {
    for question in &request.question {
        if question.resource_class != IntEnum::Enum(enums_generated::ResourceClass::INTERNET) {
            continue;
        }
        if let IntEnum::Enum(resource_type) = question.resource_type {
            // Filters do not include the trailing '.', so strip it from the name.
            // Only a '.' is removed: a name rendered without the trailing dot is left
            // intact instead of losing its last character.
            let mut name = question.name.to_string();
            if name.ends_with('.') {
                name.pop();
            }
            return Ok(Some((
                question,
                RequestInfo {
                    name,
                    resource_type,
                    received_request_id: request.header.id,
                    // For the response UDP size, just echo whatever the client sent in its
                    // OPT record, falling back to 4096 when the request has no OPT record.
                    // NOTE(review): confirm that 4096 is the intended fallback for clients
                    // that did not advertise a UDP size at all.
                    requested_udp_size: request
                        .opt
                        .as_ref()
                        .map_or(4096, |opt| opt.udp_size),
                },
            )));
        }
    }
    Ok(None)
}