#![deny(warnings, rust_2018_idioms)]
use std::convert::TryFrom;
use std::net::IpAddr;
use anyhow::{bail, Context, Result};
use bytes::BytesMut;
use lazy_static::lazy_static;
use redis;
use scopeguard;
use tracing::{debug, level_enabled, trace, warn, Level};
use crate::client::DnsClient;
use crate::codec::{decoder::DNSMessageDecoder, encoder::DNSMessageEncoder, message};
use crate::fbs::dns_enums_conv;
use crate::fbs::dns_enums_generated::{ResourceClass, ResourceType, ResponseCode};
use crate::fbs::dns_message_generated::{Message, Question, OPT};
use crate::filter::{filter, reader};
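/// Long-lived state reused across queries: the upstream DNS client(s), scratch buffers,
/// an optional Redis cache connection, and a reusable FlatBuffer builder.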
pub struct ServerMembers<'a> {
client: Box<dyn DnsClient>,
/// Fallback to use if `client` returns a truncated response.
/// Only applicable to UDP->TCP fallback.
fallback_client: Option<Box<dyn DnsClient>>,
client_buffer: BytesMut,
// TODO(#3/#5/#6) Move Redis connection into UdpClient after figuring out a structure for TCP/DoH/DoT lookups.
// TODO(#12) Moving the Redis connection would affect how UdpResolver works, since it currently creates a new UdpClient every time.
redis_conn: Option<redis::Connection>,
fbb: flatbuffers::FlatBufferBuilder<'a>,
}
impl<'a> ServerMembers<'a> {
pub fn new(
client: Box<dyn DnsClient>,
fallback_client: Option<Box<dyn DnsClient>>,
redis_conn: Option<redis::Connection>,
) -> ServerMembers<'a> {
ServerMembers {
client,
fallback_client,
client_buffer: BytesMut::with_capacity(4096),
redis_conn,
fbb: flatbuffers::FlatBufferBuilder::new_with_capacity(1024),
}
}
}
lazy_static! {
/// Script for retrieving both the data and TTL for a TTLed value in a single query.
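/// Returns a two-element array of `{value, ttl}`. Per Redis semantics, the TTL is -2 when the key
/// does not exist and -1 when it exists without an expiry; `query_redis_cache` treats both as misses.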
static ref GET_WITH_TTL: redis::Script = redis::Script::new("return {redis.call('get',KEYS[1]), redis.call('ttl',KEYS[1])}");
/// Encoder instance; currently stateless.
static ref ENCODER: DNSMessageEncoder = DNSMessageEncoder::new();
}
/// Information extracted from a DNS request, used internally for checking filters and building responses.
struct RequestInfo {
name: String,
resource_type: ResourceType,
/// Responses to the requesting client must carry the same request ID as the original request.
received_request_id: u16,
/// We use the requesting client's advertised UDP payload size when encoding our response.
requested_udp_size: u16,
}
/// Receives and handles a single query provided in `packet_buffer`, writing the response back
/// into `packet_buffer`.
/// Queries may be answered locally (filter match or cache hit), or proxied to a configured external server.
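///
/// A rough usage sketch (the upstream `DnsClient` and `filter::Filter` construction, and the
/// surrounding I/O loop, are hypothetical and depend on the rest of the crate):
///
/// ```ignore
/// let mut members = ServerMembers::new(upstream_client, None, None);
/// let mut packet_buffer = BytesMut::with_capacity(4096);
/// loop {
///     packet_buffer.clear();
///     // ... read one DNS request datagram into packet_buffer ...
///     handle_query(&mut members, &mut packet_buffer, &filter).await?;
///     // ... send the contents of packet_buffer back to the requester ...
/// }
/// ```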
pub async fn handle_query<'a>(
m: &'_ mut ServerMembers<'a>,
packet_buffer: &mut BytesMut,
filter: &filter::Filter,
) -> Result<()> {
if let Some(request_info) = decode_request_check_local_response(
&mut m.client,
&mut m.redis_conn,
&mut m.fbb,
packet_buffer,
&mut m.client_buffer,
filter,
)
.await?
{
// Reset m.fbb when exiting this block, success or error.
let mut fbb = scopeguard::guard(&mut m.fbb, |fbb| {
fbb.reset();
});
if let Some(_response) = m.client.query(&mut m.client_buffer, &mut fbb).await? {
// Clear the request we received from packet_buffer before reusing it for the final output.
packet_buffer.clear();
write_server_response(packet_buffer, &request_info, &fbb, &mut m.redis_conn)?;
if level_enabled!(Level::TRACE) {
// Clear fbb before reusing. It will then be cleared again when the above scopeguard clears.
fbb.reset();
print_response(&mut fbb, packet_buffer, "remote server")?;
}
return Ok(());
}
if let Some(fallback_client) = &mut m.fallback_client {
// Send the request to the fallback client.
// Filters and caches were already checked before the primary query, so we skip them here.
// However, we do need to re-encode the request for the new client.
// Clear any partial data left in fbb (e.g. from decoding a truncated response from the primary client) before re-decoding the request.
fbb.reset();
if !DNSMessageDecoder::new().decode(packet_buffer, &mut fbb)? {
bail!("Failed to parse incomplete request");
}
let request: Message<'_> = flatbuffers::get_root::<Message<'_>>(fbb.finished_data());
// Mark the client buffer as empty so that we don't append on top of a prior request
m.client_buffer.clear();
fallback_client.encode(&request, &mut m.client_buffer)?;
if let Some(_response) = fallback_client
.query(&mut m.client_buffer, &mut fbb)
.await?
{
// Clear the request we received from packet_buffer before reusing it for the final output.
packet_buffer.clear();
write_server_response(packet_buffer, &request_info, &fbb, &mut m.redis_conn)?;
if level_enabled!(Level::TRACE) {
// Clear fbb before reusing. It will then be cleared again when the above scopeguard clears.
fbb.reset();
print_response(&mut fbb, packet_buffer, "remote fallback server")?;
}
return Ok(());
} else {
bail!("Failed to parse truncated/incomplete response from fallback client")
}
} else {
// No fallback client available, implying that truncated messages shouldn't be happening.
bail!("Failed to parse truncated/incomplete response")
}
} else {
// A local response (filter match or cache hit) has already been written to packet_buffer.
Ok(())
}
}
/// This function is a bit ugly because of how FlatBuffer builders work, and because we want to avoid having more buffers than necessary.
/// In the common case the request is decoded and encoded once; on a UDP->TCP fallback, `handle_query` decodes and re-encodes it a second time for the fallback client, since the request encoding differs slightly between the two clients. That fallback should be rare in practice, so the cost is low.
/// - Decodes the request payload, turning it into a Message
/// - Extracts some useful info from the Message and uses it to check the Filter
/// - If the filter doesn't match, checks the Redis cache (when enabled) for a stored response
/// - If the filter and cache both miss, encodes the request to be sent to the DnsClient in `client_buffer`
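///
/// Returns `Ok(None)` when a local response (filter match or cache hit) has already been written
/// to `packet_buffer`, or `Ok(Some(RequestInfo))` when the request has been encoded into
/// `client_buffer` for an upstream query.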
async fn decode_request_check_local_response(
client: &mut Box<dyn DnsClient>,
redis_conn: &mut Option<redis::Connection>,
raw_fbb: &mut flatbuffers::FlatBufferBuilder<'_>,
packet_buffer: &mut BytesMut,
client_buffer: &mut BytesMut,
filter: &filter::Filter,
) -> Result<Option<RequestInfo>> {
// Reset raw_fbb when exiting this block, success or error.
let mut fbb = scopeguard::guard(raw_fbb, |fbb| {
fbb.reset();
});
// Decode the request message so that we can see what it's querying for
if !DNSMessageDecoder::new().decode(packet_buffer, &mut fbb)? {
bail!("Failed to parse incomplete request");
}
let request: Message<'_> = flatbuffers::get_root::<Message<'_>>(fbb.finished_data());
debug!("Incoming request: {}", request);
if let Some((question, request_info)) = get_question(&request)? {
if let Some((file_info, entry)) = filter.check(&request_info.name) {
// Filter had a match.
// Write the response to packet_buffer to send back to the requester.
packet_buffer.clear();
write_filter_response(
packet_buffer,
&request_info,
&question,
request.opt(),
&file_info.source_path,
entry,
)?;
if level_enabled!(Level::TRACE) {
// Clear fbb before reusing. It will then be cleared again when the above scopeguard clears.
fbb.reset();
print_response(&mut fbb, packet_buffer, "filter")?;
}
Ok(None)
} else if let Some(conn) = redis_conn {
// Redis cache enabled, attempt to read cached result from Redis
if query_redis_cache(conn, &request_info, packet_buffer)? {
// Found in Redis
if level_enabled!(Level::TRACE) {
// Clear fbb before reusing. It will then be cleared again when the above scopeguard clears.
fbb.reset();
print_response(&mut fbb, packet_buffer, "redis")?;
}
Ok(None)
} else {
// Mark the client buffer as empty so that we don't append on top of a prior request
client_buffer.clear();
// Filter and cache both missed: Send the encoded request to the client.
client.encode(&request, client_buffer)?;
Ok(Some(request_info))
}
} else {
// No filter match and the Redis cache is not enabled: go straight to the upstream query.
trace!(
"No filter entry found for {}, performing upstream query",
request_info.name
);
// Mark the client buffer as empty so that we don't append on top of a prior request
client_buffer.clear();
client.encode(&request, client_buffer)?;
Ok(Some(request_info))
}
} else {
bail!("Missing question in request");
}
}
/// Writes the response to `packet_buffer` based on the destination info in the retrieved filter entry.
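///
/// Behavior summary:
/// - Filter entry has neither an IPv4 nor an IPv6 destination: respond with NXDOMAIN (explicit black hole).
/// - A or AAAA query: respond with NOERROR and the filter's IPv4 or IPv6 destination, if one is set.
/// - Any other record type: respond with NOERROR and no records.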
fn write_filter_response(
packet_buffer: &mut BytesMut,
request_info: &RequestInfo,
question: &Question<'_>,
opt: Option<OPT<'_>>,
filter_source: &str,
entry: &reader::FileEntry,
) -> Result<()> {
if let (None, None) = (entry.dest_ipv4, entry.dest_ipv6) {
// Return black hole as indicated by filter
debug!(
"Got filter entry for {} from {} line {}: dest=NONE",
request_info.name, filter_source, entry.line_num
);
ENCODER.encode_local_response(
ResponseCode::RESPONSE_NXDOMAIN,
request_info.received_request_id,
&question,
opt,
None,
Some(request_info.requested_udp_size),
packet_buffer,
)
} else if request_info.resource_type == ResourceType::TYPE_A {
// Return IPv4/A result returned by filter
debug!(
"Got filter entry for {} from {} line {}: dest={:?}",
request_info.name, filter_source, entry.line_num, entry.dest_ipv4
);
ENCODER.encode_local_response(
ResponseCode::RESPONSE_NOERROR,
request_info.received_request_id,
&question,
opt,
entry.dest_ipv4.map(IpAddr::V4),
Some(request_info.requested_udp_size),
packet_buffer,
)
} else if request_info.resource_type == ResourceType::TYPE_AAAA {
// Return IPv6/AAAA result returned by filter
debug!(
"Got filter entry for {} from {} line {}: dest={:?}",
request_info.name, filter_source, entry.line_num, entry.dest_ipv6
);
ENCODER.encode_local_response(
ResponseCode::RESPONSE_NOERROR,
request_info.received_request_id,
&question,
opt,
entry.dest_ipv6.map(IpAddr::V6),
Some(request_info.requested_udp_size),
packet_buffer,
)
} else {
// Misc record type for a domain that has A and/or AAAA overrides in the filters: record not found.
// It's a little ambiguous whether we should instead try going upstream in this case,
// but if you had an upstream server with the right information, why would you be adding custom host entries locally?
// Therefore we explicitly do NOT support misc record types like MX and SRV for hostnames that have a filter entry.
ENCODER.encode_local_response(
ResponseCode::RESPONSE_NOERROR,
request_info.received_request_id,
&question,
opt,
None,
Some(request_info.requested_udp_size),
packet_buffer,
)
}
}
/// Writes the response returned by a remote DNS server to `packet_buffer`, and caches it via `redis_conn` if a cache connection is configured.
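/// When caching, the entry is stored with SETEX using the minimum TTL found in the response, so
/// Redis expires the cached message no later than its shortest-lived record.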
fn write_server_response<'a>(
packet_buffer: &mut BytesMut,
request_info: &RequestInfo,
fbb: &flatbuffers::FlatBufferBuilder<'a>,
mut redis_conn: &mut Option<redis::Connection>,
) -> Result<()> {
// Reencode the response to be returned, updating udp_size + message_id to match the original request
let response: Message<'_> = flatbuffers::get_root::<Message<'_>>(fbb.finished_data());
ENCODER.encode(
&response,
Some(request_info.requested_udp_size),
packet_buffer,
)?;
if let Some(conn) = &mut redis_conn {
// Redis cache enabled, store response
let response_min_ttl = message::get_min_ttl(&response).with_context(|| {
format!(
"Missing resources in {:?} response for {}",
request_info.resource_type, request_info.name
)
})?;
let store_result = redis::cmd("SETEX")
.arg(format!(
"kapiti_fbmsg__{:?}__{}",
request_info.resource_type, request_info.name
))
.arg(response_min_ttl as usize)
.arg(fbb.finished_data())
.query::<redis::Value>(conn);
// If cache storage fails, log a warning but don't kill the query.
match store_result {
Ok(_) => debug!(
"Stored {:?} result for {} to Redis cache with TTL={}s",
request_info.resource_type, request_info.name, response_min_ttl
),
Err(e) => warn!(
"Failed to store {:?} result for {} in Redis cache, continuing anyway: {:?}",
request_info.resource_type, request_info.name, e
),
}
};
// Set the message ID for the response so that it matches the original request.
// Keeping the message IDs independent reduces the likelihood of cache poisoning.
// TODO(#4) Once we support running a TCP server, this would need to support writing at the TCP offset (2).
message::update_message_id(request_info.received_request_id, packet_buffer, 0)
}
/// Queries Redis for a cached response.
/// If a match is found then `packet_buffer` is populated with the result and Ok(true) is returned, otherwise Ok(false) is returned.
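/// The cache key is `kapiti_fbmsg__{resource_type}__{name}`, matching the key written by
/// `write_server_response`; the stored value is the raw FlatBuffer-encoded upstream response.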
fn query_redis_cache(
conn: &mut redis::Connection,
request_info: &RequestInfo,
packet_buffer: &mut BytesMut,
) -> Result<bool> {
let redis_key = format!(
"kapiti_fbmsg__{:?}__{}",
request_info.resource_type, request_info.name
);
let response_option: Option<Vec<redis::Value>> = GET_WITH_TTL
.key(redis_key)
.invoke(conn)
.with_context(|| format!("Reading cached response failed for {}", request_info.name))?;
if let Some(response_vec) = response_option {
return match (response_vec.get(0), response_vec.get(1)) {
(Some(redis::Value::Data(bytes)), Some(redis::Value::Int(raw_redis_ttl))) => {
// Treat TTL <= 0 as a cache miss: the key either has no expiry (-1) or has expired but not yet been evicted.
if *raw_redis_ttl <= 0 {
return Ok(false);
}
let redis_ttl: u32 = u32::try_from(*raw_redis_ttl).with_context(|| {
format!(
"Invalid TTL value {} with cached {:?} result for {}",
raw_redis_ttl, request_info.resource_type, request_info.name
)
})?;
let orig_response: Message<'_> = flatbuffers::get_root::<Message<'_>>(bytes);
debug!(
"Response from Redis (ttl={:?}): {}",
redis_ttl, orig_response
);
// Update TTL in response: Search all records for a minimum TTL value, then subtract that from all record TTLs.
// For example, if the min TTL across all records is 30s and the Redis response TTL is 20s, then 10s has passed and all TTLs should have 10s subtracted.
let min_ttl = message::get_min_ttl(&orig_response).with_context(|| {
format!(
"Missing resources in cached {:?} response for {}",
request_info.resource_type, request_info.name
)
})?;
if min_ttl < redis_ttl {
// Shouldn't happen: Redis TTL should have started at min_ttl and gone down.
warn!(
"Redis had invalid TTL {} with {:?} result for {}: {:?}",
redis_ttl, request_info.resource_type, request_info.name, orig_response
);
return Ok(false);
}
packet_buffer.clear();
ENCODER.encode_cached_response(
&orig_response,
request_info.received_request_id,
min_ttl - redis_ttl,
Some(request_info.requested_udp_size),
packet_buffer,
)?;
Ok(true)
}
(Some(redis::Value::Nil), Some(redis::Value::Int(-2))) => {
trace!(
"Redis didn't have cached {:?} result for {}",
request_info.resource_type,
request_info.name
);
Ok(false)
}
(_other_msg, _other_ttl) => {
warn!(
"Unexpected data in Redis lookup response, bad connection?: {:?}",
response_vec
);
// Give up on cache and query direct
Ok(false)
}
};
} else {
// Shouldn't happen, but may as well handle it
trace!(
"Redis didn't have cache {:?} result for/{}",
request_info.resource_type,
request_info.name
);
Ok(false)
}
}
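/// Extracts the first usable INTERNET-class question from the request, along with the metadata
/// needed to check filters, query the cache, and build a response: the queried name (without its
/// trailing '.'), the resource type, the original request ID, and the requested UDP payload size.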
fn get_question<'a>(request: &Message<'a>) -> Result<Option<(Question<'a>, RequestInfo)>> {
if let Some(questions) = request.question() {
for i in 0..questions.len() {
let question = questions.get(i);
if question.resource_class() != ResourceClass::CLASS_INTERNET as u16 {
continue;
}
if let Some(name_str) = question.name() {
if let Some(resource_type) =
dns_enums_conv::resourcetype_int(question.resource_type() as usize)
{
// Remove trailing '.': Filters do not include trailing '.'
let mut name_string = name_str.to_string();
if !name_string.is_empty() {
name_string.pop();
}
let request_id = request
.header()
.with_context(|| "missing request header")?
.id();
return Ok(Some((
question,
RequestInfo {
name: name_string,
resource_type,
received_request_id: request_id,
// For the response UDP size, let's just echo whatever the client sent, defaulting to 4096 when no OPT record is present.
requested_udp_size: request
.opt()
.map(|opt| opt.udp_size())
.unwrap_or(4096),
},
)));
}
}
}
}
Ok(None)
}
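/// Re-decodes the encoded response in `packet_buffer` and logs it; only used for trace-level
/// debugging of the bytes we are about to send back.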
fn print_response<'a>(
fbb: &mut flatbuffers::FlatBufferBuilder<'a>,
packet_buffer: &mut BytesMut,
source: &str,
) -> Result<()> {
trace!(
"Raw response from {} ({}b): {:02X?}",
source,
packet_buffer.len(),
&packet_buffer[..]
);
if !DNSMessageDecoder::new().decode(packet_buffer, fbb)? {
// Shouldn't happen for our own locally produced data; implies a parser bug
bail!("Failed to re-parse response from {}", source);
}
let response: Message<'_> = flatbuffers::get_root::<Message<'_>>(fbb.finished_data());
debug!("Returning response from {}: {}", source, response);
Ok(())
}