~nickbp/originz

ref: bench-baseline | originz/src/server.rs | -rw-r--r-- | 20.3 KiB
commit d94f181c | Nick Parker | "Backport current benchmark to older code" | 1 year, 8 months ago
                                                                                
#![deny(warnings)]

use std::convert::TryFrom;
use std::net::IpAddr;

use anyhow::{bail, Context, Result};
use bytes::BytesMut;
use lazy_static::lazy_static;
use redis;
use scopeguard;
use tracing::{debug, level_enabled, trace, warn, Level};

use crate::client::DnsClient;
use crate::codec::{decoder::DNSMessageDecoder, encoder::DNSMessageEncoder, message};
use crate::fbs::dns_enums_conv;
use crate::fbs::dns_enums_generated::{ResourceClass, ResourceType, ResponseCode};
use crate::fbs::dns_message_generated::{Message, Question, OPT};
use crate::filter::{filter, reader};

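/// Per-query server state reused across requests: the upstream DNS client(s), scratch buffers
/// for encoding/decoding, and an optional Redis connection used as a response cache.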
pub struct ServerMembers<'a> {
    client: Box<dyn DnsClient>,
    /// Fallback to use if `client` returns a truncated response.
    /// Only applicable to UDP->TCP fallback.
    fallback_client: Option<Box<dyn DnsClient>>,
    client_buffer: BytesMut,
    // TODO(#3/#5/#6) Move Redis connection into UdpClient after figuring out a structure for TCP/DoH/DoT lookups.
    // TODO(#12) Moving the Redis connection would affect how UdpResolver works, since it currently creates a new UdpClient every time.
    redis_conn: Option<redis::Connection>,
    fbb: flatbuffers::FlatBufferBuilder<'a>,
}

impl<'a> ServerMembers<'a> {
    pub fn new(
        client: Box<dyn DnsClient>,
        fallback_client: Option<Box<dyn DnsClient>>,
        redis_conn: Option<redis::Connection>,
    ) -> ServerMembers<'a> {
        ServerMembers {
            client,
            fallback_client,
            client_buffer: BytesMut::with_capacity(4096),
            redis_conn,
            fbb: flatbuffers::FlatBufferBuilder::new_with_capacity(1024),
        }
    }
}

lazy_static! {
    /// Script for retrieving both the data and TTL for a TTLed value in a single query.
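    /// Returns a two-element array of [value-or-nil, remaining TTL in seconds]; Redis reports a
    /// TTL of -2 for a missing key, which `query_redis_cache` below relies on.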
    static ref GET_WITH_TTL: redis::Script = redis::Script::new("return {redis.call('get',KEYS[1]), redis.call('ttl',KEYS[1])}");

    /// Encoder instance, currently doesn't have state
    static ref ENCODER: DNSMessageEncoder = DNSMessageEncoder::new();
}

/// Information extracted from a DNS request, used internally for checking filters and sending a filtered response.
struct RequestInfo {
    name: String,
    resource_type: ResourceType,
    /// Responses to upstream clients must have the matching request ID.
    received_request_id: u16,
    /// We use the upstream client's requested UDP size for our response.
    requested_udp_size: u16,
}

/// Receives and handles a single query provided in `packet_buffer`, writing the response back
/// into `packet_buffer`.
/// Queries may be answered locally (filter or cache), or proxied to a configured external server.
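///
/// A minimal usage sketch (illustrative only; the `udp_client`, `tcp_fallback_client`, and
/// `filter` values are assumptions about the caller, not code from this crate):
///
/// ```ignore
/// let mut members = ServerMembers::new(udp_client, Some(tcp_fallback_client), None);
/// let mut packet_buffer = BytesMut::with_capacity(4096);
/// // ... copy one received DNS query into packet_buffer ...
/// handle_query(&mut members, &mut packet_buffer, &filter).await?;
/// // On success, packet_buffer holds the encoded response to send back to the requester.
/// ```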
pub async fn handle_query<'a>(
    m: &'_ mut ServerMembers<'a>,
    packet_buffer: &mut BytesMut,
    filter: &filter::Filter,
) -> Result<()> {
    if let Some(request_info) = decode_request_check_local_response(
        &mut m.client,
        &mut m.redis_conn,
        &mut m.fbb,
        packet_buffer,
        &mut m.client_buffer,
        filter,
    )
    .await?
    {
        // Reset m.fbb when exiting this block, success or error.
        let mut fbb = scopeguard::guard(&mut m.fbb, |fbb| {
            fbb.reset();
        });

        if let Some(_response) = m.client.query(&mut m.client_buffer, &mut fbb).await? {
            // Clear the request we received from packet_buffer before reusing it for the final output.
            packet_buffer.clear();

            write_server_response(packet_buffer, &request_info, &fbb, &mut m.redis_conn)?;
            if level_enabled!(Level::TRACE) {
                // Clear fbb before reusing. It will then be cleared again when the above scopeguard clears.
                fbb.reset();
                print_response(&mut fbb, packet_buffer, "remote server")?;
            }
            return Ok(());
        }

        if let Some(fallback_client) = &mut m.fallback_client {
            // Send the request to the fallback client.
            // Filters and caches have already been checked, so we skip those here.
            // However, we do need to re-encode the request for the new client.

            // Clear any partial data that may be in fbb (from e.g. decoding a truncated response from primary client) before redoing the request decode.
            fbb.reset();
            if !DNSMessageDecoder::new().decode(packet_buffer, &mut fbb)? {
                bail!("Failed to parse incomplete request");
            }
            let request: Message<'_> = flatbuffers::get_root::<Message<'_>>(fbb.finished_data());
            // Mark the client buffer as empty so that we don't append on top of a prior request
            m.client_buffer.clear();
            fallback_client.encode(&request, &mut m.client_buffer)?;

            if let Some(_response) = fallback_client
                .query(&mut m.client_buffer, &mut fbb)
                .await?
            {
                // Clear the request we received from packet_buffer before reusing it for the final output.
                packet_buffer.clear();

                write_server_response(packet_buffer, &request_info, &fbb, &mut m.redis_conn)?;
                if level_enabled!(Level::TRACE) {
                    // Clear fbb before reusing. It will then be cleared again when the above scopeguard clears.
                    fbb.reset();
                    print_response(&mut fbb, packet_buffer, "remote fallback server")?;
                }
                return Ok(());
            } else {
                bail!("Failed to parse truncated/incomplete response from fallback client")
            }
        } else {
            // No fallback client available, implying that truncated messages shouldn't be happening.
            bail!("Failed to parse truncated/incomplete response")
        }
    } else {
        // The request was answered locally (filter match or cache hit); the response has already been written to packet_buffer.
        Ok(())
    }
}

/// This function is a bit ugly on account of how flatbuffer builders work, and because we want to avoid having more buffers than necessary.
/// In the general case the request is only decoded once, but if there's a UDP->TCP fallback it gets decoded and encoded a second time for the fallback client, since the request encoding is slightly different between the two clients. This fallback should be rare in practice, so the cost is low.
/// - Decodes the request payload, turning it into a Message
/// - Extracts some useful info from the Message and uses it to check the Filter
/// - If the filter doesn't match, encodes the request to be sent to the DnsClient in `client_buffer`
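/// Returns `Ok(Some(RequestInfo))` when the caller still needs to perform the upstream query, or
/// `Ok(None)` when a response (filter match or cache hit) has already been written to `packet_buffer`.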
async fn decode_request_check_local_response(
    client: &mut Box<dyn DnsClient>,
    redis_conn: &mut Option<redis::Connection>,
    raw_fbb: &mut flatbuffers::FlatBufferBuilder<'_>,
    packet_buffer: &mut BytesMut,
    client_buffer: &mut BytesMut,
    filter: &filter::Filter,
) -> Result<Option<RequestInfo>> {
    // Reset raw_fbb when exiting this block, success or error.
    let mut fbb = scopeguard::guard(raw_fbb, |fbb| {
        fbb.reset();
    });

    // Decode the request message so that we can see what it's querying for
    if !DNSMessageDecoder::new().decode(packet_buffer, &mut fbb)? {
        bail!("Failed to parse incomplete request");
    }

    let request: Message<'_> = flatbuffers::get_root::<Message<'_>>(fbb.finished_data());
    debug!("Incoming request: {}", request);
    if let Some((question, request_info)) = get_question(&request)? {
        if let Some((file_info, entry)) = filter.check(&request_info.name) {
            // Filter had a match.
            // Write response to packet_buffer to send back upstream.
            packet_buffer.clear();
            write_filter_response(
                packet_buffer,
                &request_info,
                &question,
                request.opt(),
                &file_info.source_path,
                entry,
            )?;
            if level_enabled!(Level::TRACE) {
                // Clear fbb before reusing. It will then be cleared again when the above scopeguard clears.
                fbb.reset();
                print_response(&mut fbb, packet_buffer, "filter")?;
            }
            Ok(None)
        } else if let Some(conn) = redis_conn {
            // Redis cache enabled, attempt to read cached result from Redis
            if query_redis_cache(conn, &request_info, packet_buffer)? {
                // Found in Redis
                if level_enabled!(Level::TRACE) {
                    // Clear fbb before reusing. It will then be cleared again when the above scopeguard clears.
                    fbb.reset();
                    print_response(&mut fbb, packet_buffer, "redis")?;
                }
                Ok(None)
            } else {
                // Mark the client buffer as empty so that we don't append on top of a prior request
                client_buffer.clear();
                // Filter and cache both missed: Send the encoded request to the client.
                client.encode(&request, client_buffer)?;
                Ok(Some(request_info))
            }
        } else {
            // No filter match and Redis cache not enabled: go straight to the upstream query
            trace!(
                "No filter entry found for {}, performing upstream query",
                request_info.name
            );
            // Mark the client buffer as empty so that we don't append on top of a prior request
            client_buffer.clear();
            client.encode(&request, client_buffer)?;
            Ok(Some(request_info))
        }
    } else {
        bail!("Missing question in request");
    }
}

/// Writes the response to `packet_buffer` based on the destination info in the retrieved filter entry.
fn write_filter_response(
    packet_buffer: &mut BytesMut,
    request_info: &RequestInfo,
    question: &Question<'_>,
    opt: Option<OPT<'_>>,
    filter_source: &String,
    entry: &reader::FileEntry,
) -> Result<()> {
    if let (None, None) = (entry.dest_ipv4, entry.dest_ipv6) {
        // Return black hole as indicated by filter
        debug!(
            "Got filter entry for {} from {} line {}: dest=NONE",
            request_info.name, filter_source, entry.line_num
        );
        ENCODER.encode_local_response(
            ResponseCode::RESPONSE_NXDOMAIN,
            request_info.received_request_id,
            &question,
            opt,
            None,
            Some(request_info.requested_udp_size),
            packet_buffer,
        )
    } else if request_info.resource_type == ResourceType::TYPE_A {
        // Return IPv4/A result returned by filter
        debug!(
            "Got filter entry for {} from {} line {}: dest={:?}",
            request_info.name, filter_source, entry.line_num, entry.dest_ipv4
        );
        ENCODER.encode_local_response(
            ResponseCode::RESPONSE_NOERROR,
            request_info.received_request_id,
            &question,
            opt,
            entry.dest_ipv4.map(|ip| IpAddr::V4(ip)),
            Some(request_info.requested_udp_size),
            packet_buffer,
        )
    } else if request_info.resource_type == ResourceType::TYPE_AAAA {
        // Return IPv6/AAAA result returned by filter
        debug!(
            "Got filter entry for {} from {} line {}: dest={:?}",
            request_info.name, filter_source, entry.line_num, entry.dest_ipv6
        );
        ENCODER.encode_local_response(
            ResponseCode::RESPONSE_NOERROR,
            request_info.received_request_id,
            &question,
            opt,
            entry.dest_ipv6.map(|ip| IpAddr::V6(ip)),
            Some(request_info.requested_udp_size),
            packet_buffer,
        )
    } else {
        // Misc record type, for a domain that has A and/or AAAA overrides in the filters: record not found.
        // It's a little ambiguous whether we should instead try going upstream when this happens,
        // but if the upstream server had the right information, why would you be putting custom host entries locally?
        // Therefore we explicitly do NOT support misc record types like MX and SRV for hostnames that have a filter entry.
        ENCODER.encode_local_response(
            ResponseCode::RESPONSE_NOERROR,
            request_info.received_request_id,
            &question,
            opt,
            None,
            Some(request_info.requested_udp_size),
            packet_buffer,
        )
    }
}

/// Writes the response returned by a remote DNS server to `packet_buffer`, and to `redis_conn` if enabled.
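/// When caching is enabled, entries are keyed as `kapiti_fbmsg__{resource_type}__{name}` and expire
/// after the response's minimum record TTL, matching the lookup performed in `query_redis_cache`.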
fn write_server_response<'a>(
    packet_buffer: &mut BytesMut,
    request_info: &RequestInfo,
    fbb: &flatbuffers::FlatBufferBuilder<'a>,
    mut redis_conn: &mut Option<redis::Connection>,
) -> Result<()> {
    // Reencode the response to be returned, updating udp_size + message_id to match the original request
    let response: Message<'_> = flatbuffers::get_root::<Message<'_>>(fbb.finished_data());
    ENCODER.encode(
        &response,
        Some(request_info.requested_udp_size),
        packet_buffer,
    )?;

    if let Some(conn) = &mut redis_conn {
        // Redis cache enabled, store response
        let response_min_ttl = message::get_min_ttl(&response).with_context(|| {
            format!(
                "Missing resources in {:?} response for {}",
                request_info.resource_type, request_info.name
            )
        })?;
        let store_result = redis::cmd("SETEX")
            .arg(format!(
                "kapiti_fbmsg__{:?}__{}",
                request_info.resource_type, request_info.name
            ))
            .arg(response_min_ttl as usize)
            .arg(fbb.finished_data())
            .query::<redis::Value>(conn);
        // If cache storage fails, log a warning but don't kill the query.
        match store_result {
            Ok(_) => debug!(
                "Stored {:?} result for {} to Redis cache with TTL={}s",
                request_info.resource_type, request_info.name, response_min_ttl
            ),
            Err(e) => warn!(
                "Failed to store {:?} result for {} in Redis cache, continuing anyway: {:?}",
                request_info.resource_type, request_info.name, e
            ),
        }
    };

    // Set the message ID for the response so that it matches the original request.
    // Keeping the message IDs independent reduces the likelihood of cache poisoning.
    // TODO(#4) Once we support running a TCP server, this would need to support writing at the TCP offset (2).
    message::update_message_id(request_info.received_request_id, packet_buffer, 0)
}

/// Queries Redis for a cached response.
/// If a match is found then `packet_buffer` is populated with the result and Ok(true) is returned, otherwise Ok(false) is returned.
fn query_redis_cache(
    conn: &mut redis::Connection,
    request_info: &RequestInfo,
    packet_buffer: &mut BytesMut,
) -> Result<bool> {
    let redis_key = format!(
        "kapiti_fbmsg__{:?}__{}",
        request_info.resource_type, request_info.name
    );
    let response_option: Option<Vec<redis::Value>> = GET_WITH_TTL
        .key(redis_key)
        .invoke(conn)
        .with_context(|| format!("Reading cached response failed for {}", request_info.name))?;

    if let Some(response_vec) = response_option {
        return match (response_vec.get(0), response_vec.get(1)) {
            (Some(redis::Value::Data(bytes)), Some(redis::Value::Int(raw_redis_ttl))) => {
                // TTL of zero or below alongside data: treat as expired (maybe it just hasn't been evicted yet).
                if raw_redis_ttl <= &0 {
                    return Ok(false);
                }
                let redis_ttl: u32 = u32::try_from(*raw_redis_ttl).with_context(|| {
                    format!(
                        "Invalid TTL value {} with cached {:?} result for {}",
                        raw_redis_ttl, request_info.resource_type, request_info.name
                    )
                })?;
                let orig_response: Message<'_> = flatbuffers::get_root::<Message<'_>>(bytes);
                debug!(
                    "Response from Redis (ttl={:?}): {}",
                    redis_ttl, orig_response
                );

                // Update TTL in response: Search all records for a minimum TTL value, then subtract that from all record TTLs.
                // For example, if the min TTL across all records is 30s and the Redis response TTL is 20s, then 10s has passed and all TTLs should have 10s subtracted.
                let min_ttl = message::get_min_ttl(&orig_response).with_context(|| {
                    format!(
                        "Missing resources in cached {:?} response for {}",
                        request_info.resource_type, request_info.name
                    )
                })?;

                if min_ttl < redis_ttl {
                    // Shouldn't happen: Redis TTL should have started at min_ttl and gone down.
                    warn!(
                        "Redis had invalid TTL {} with {:?} result for {}: {:?}",
                        redis_ttl, request_info.resource_type, request_info.name, orig_response
                    );
                    return Ok(false);
                }
                packet_buffer.clear();
                ENCODER.encode_cached_response(
                    &orig_response,
                    request_info.received_request_id,
                    min_ttl - redis_ttl,
                    Some(request_info.requested_udp_size),
                    packet_buffer,
                )?;
                Ok(true)
            }
            (Some(redis::Value::Nil), Some(redis::Value::Int(-2))) => {
                trace!(
                    "Redis didn't have cached {:?} result for {}",
                    request_info.resource_type,
                    request_info.name
                );
                Ok(false)
            }
            (_other_msg, _other_ttl) => {
                warn!(
                    "Unexpected data in Redis lookup response, bad connection?: {:?}",
                    response_vec
                );
                // Give up on cache and query direct
                Ok(false)
            }
        };
    } else {
        // Shouldn't happen, but may as well handle it
        trace!(
            "Redis didn't have cached {:?} result for {}",
            request_info.resource_type,
            request_info.name
        );
        Ok(false)
    }
}

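/// Finds the first INTERNET-class question with a recognized resource type and extracts the request
/// metadata (name without the trailing '.', request ID, requested UDP size) needed when building the
/// response. Returns `Ok(None)` if no usable question is present.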
fn get_question<'a>(request: &Message<'a>) -> Result<Option<(Question<'a>, RequestInfo)>> {
    if let Some(questions) = request.question() {
        for i in 0..questions.len() {
            let question = questions.get(i);
            if question.resource_class() != ResourceClass::CLASS_INTERNET as u16 {
                continue;
            }

            if let Some(name_str) = question.name() {
                if let Some(resource_type) =
                    dns_enums_conv::resourcetype_int(question.resource_type() as usize)
                {
                    // Remove trailing '.': Filters do not include trailing '.'
                    let mut name_string = name_str.to_string();
                    if !name_string.is_empty() {
                        name_string.pop();
                    }

                    let request_id = request
                        .header()
                        .with_context(|| "missing request header")?
                        .id();
                    return Ok(Some((
                        question,
                        RequestInfo {
                            name: name_string,
                            resource_type,
                            received_request_id: request_id,
                            // For the response UDP size, let's just return whatever the client sent...
                            requested_udp_size: request
                                .opt()
                                .map(|opt| opt.udp_size())
                                .unwrap_or(4096),
                        },
                    )));
                }
            }
        }
    }
    Ok(None)
}

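/// Tracing helper: logs the raw response bytes, then re-decodes them into `fbb` so the parsed
/// message can be logged as well. Callers only invoke this when TRACE logging is enabled.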
fn print_response<'a>(
    fbb: &mut flatbuffers::FlatBufferBuilder<'a>,
    packet_buffer: &mut BytesMut,
    source: &str,
) -> Result<()> {
    trace!(
        "Raw response from {} ({}b): {:02X?}",
        source,
        packet_buffer.len(),
        &packet_buffer[..]
    );

    if !DNSMessageDecoder::new().decode(packet_buffer, fbb)? {
        // Shouldn't happen for our own local data, implies parser bug
        bail!("Failed to re-parse response from {}", source);
    }

    let response: Message<'_> = flatbuffers::get_root::<Message<'_>>(fbb.finished_data());
    debug!("Returning response from {}: {}", source, response);

    Ok(())
}