~nickbp/kapiti

ref: 025082b1e3b359a5343d5977bcdd37a3f57c3e0c kapiti/src/server.rs -rw-r--r-- 9.8 KiB
025082b1Nick Parker Basic initial structure for OPT COOKIE support (#21) 1 year, 10 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
#![deny(warnings, rust_2018_idioms)]

use std::io;
use std::net::{IpAddr, SocketAddr};
use std::result::Result;

use bytes::BytesMut;
use log::{debug, log_enabled, trace, Level};
use scopeguard;

use crate::client::client;
use crate::codec::{decoder::DNSMessageDecoder, encoder::DNSMessageEncoder, message, optcookie};
use crate::fbs::dns_enums_conv;
use crate::fbs::dns_enums_generated::{ResourceClass, ResourceType, ResponseCode};
use crate::fbs::dns_message_generated::{Message, Question};
use crate::filter::filter;

/// Builds an `Err(io::Error)` of kind `InvalidData` whose message is formatted from the
/// macro arguments exactly like `format!`.
///
/// The matcher and transcriber both use `*` so the repetition is consistent.
macro_rules! io_err {
    ($($arg:tt)*) => (Err(io::Error::new(io::ErrorKind::InvalidData, format!($($arg)*))))
}

/// Per-server state that `handle_query` reuses from one query to the next.
pub struct ServerMembers<'a> {
    /// Client used by `handle_query` to forward queries that are not answered locally.
    client: client::DnsClient,
    /// Scratch FlatBuffer builder for decoded requests/responses.
    /// `handle_query` resets it via scope guards after each use.
    fbb: flatbuffers::FlatBufferBuilder<'a>,
}

impl<'a> ServerMembers<'a> {
    /// Creates the server state.
    ///
    /// `query_addr` is handed to `client::DnsClient::new` for upstream queries.
    /// The builder starts with a 1024-byte capacity.
    pub fn new(query_addr: SocketAddr) -> Self {
        let client = client::DnsClient::new(query_addr);
        let fbb = flatbuffers::FlatBufferBuilder::new_with_capacity(1024);
        Self { client, fbb }
    }
}

/// Receives and handles a single query provided by `packet_buffer`,
/// and sends back a response again via `packet_buffer`.
///
/// The query is either answered locally according to `filter`, or re-encoded and forwarded
/// upstream via `m.client`. In the forwarded case, the upstream response is re-encoded for the
/// original requester. It keeps the requester's UDP size (4096 if the request had no OPT) and
/// its message ID.
///
/// # Errors
/// Returns an `InvalidData` error when:
/// - the request is incomplete or has no header, or
/// - the upstream response is truncated/incomplete.
///
/// Any error from decoding, encoding, reading the OPT cookie, or the upstream query is
/// propagated.
pub async fn handle_query<'a>(
    m: &'_ mut ServerMembers<'a>,
    packet_buffer: &mut BytesMut,
    filter: &filter::Filter,
) -> Result<(), io::Error> {
    let received_request_id: u16;
    let requested_udp_size: u16;
    // TODO(cookie): the requester's OPT COOKIE is parsed (and therefore validated) but is not
    // yet echoed in the response. The leading underscore keeps `#![deny(warnings)]` from
    // rejecting this intentionally-unread binding.
    let _requested_cookie: Option<optcookie::COOKIE>;
    {
        // Reset the shared fbb (holding the decoded request) when exiting this block,
        // success or error.
        let mut fbb = scopeguard::guard(&mut m.fbb, |fbb| {
            fbb.reset();
        });

        // Decode the request message so that we can see what it's querying for
        if !DNSMessageDecoder::new().decode(packet_buffer, &mut fbb)? {
            return io_err!("Failed to parse incomplete request");
        }
        packet_buffer.clear();

        let request: Message<'_> =
            flatbuffers::get_root::<Message<'_>>(fbb.finished_data());
        debug!("Incoming request: {}", request);
        received_request_id = match request.header() {
            Some(header) => header.id(),
            // The request came off the network: report malformed input instead of panicking.
            None => return io_err!("Missing request header"),
        };

        // For the response UDP size, lets just return whatever the client sent...
        // Assign the function-level cookie binding directly (no shadowing), so the parsed value
        // outlives this block.
        if let Some(opt) = request.opt() {
            requested_udp_size = opt.udp_size();
            _requested_cookie = optcookie::read_opt_entry_cookie(opt)?;
        } else {
            requested_udp_size = 4096;
            _requested_cookie = None;
        }

        if check_local_response(&request, received_request_id, packet_buffer, filter, requested_udp_size)? {
            // Wrote local response to packet_buffer, exit early
            if log_enabled!(Level::Debug) {
                // Clear fbb before reusing. It will then be cleared again when the above scopeguard clears.
                fbb.reset();
                print_response(&mut fbb, packet_buffer)?;
            }
            return Ok(());
        }

        // Reencode the request message to be forwarded to the server,
        // using the UDP size reported by the client's `last_udp_size()`.
        DNSMessageEncoder::new().encode(
            request,
            Some(m.client.last_udp_size()),
            packet_buffer,
        )?;
    }

    // Reset the shared fbb (which will hold the upstream response) when exiting,
    // success or error.
    let mut fbb = scopeguard::guard(&mut m.fbb, |fbb| {
        fbb.reset();
    });

    match m.client.query(packet_buffer, &mut fbb).await? {
        Some(response) => {
            // Clear the request we sent from packet_buffer before reusing it for the final output.
            packet_buffer.clear();

            // Reencode the response to be sent to our client, updating udp_size + message_id.
            DNSMessageEncoder::new().encode(
                response,
                Some(requested_udp_size),
                packet_buffer,
            )?;

            // Set the message ID for the response so that it matches the original request.
            // Keeping the message IDs independent reduces the likelihood of cache poisoning.
            message::update_message_id(received_request_id, packet_buffer)?;

            if log_enabled!(Level::Debug) {
                // Clear fbb before reusing. It will then be cleared again when the above scopeguard clears.
                fbb.reset();
                print_response(&mut fbb, packet_buffer)?;
            }
            Ok(())
        }
        None => {
            // TODO(#3) handle truncated/fragmented response with fallback to TCP
            io_err!("Failed to parse truncated/incomplete response")
        }
    }
}

/// Decides whether `request` is answered locally, as dictated by `filter`.
///
/// On a filter hit, a locally-built response is encoded into `packet_buffer` and `Ok(true)` is
/// returned. The response is chosen as follows:
/// - NXDOMAIN if the entry has no IPv4 and no IPv6 destination;
/// - otherwise NOERROR carrying the matching A/AAAA address (if any);
/// - otherwise NOERROR with no answer for other record types.
///
/// Without a usable question or a filter hit, `Ok(false)` is returned and `packet_buffer` is left
/// untouched.
fn check_local_response(
    request: &Message<'_>,
    received_request_id: u16,
    packet_buffer: &mut BytesMut,
    filter: &filter::Filter,
    response_udp_size: u16,
) -> Result<bool, io::Error> {
    let (question, name, resource_type) = match get_question(request) {
        Some(found) => found,
        None => return Ok(false),
    };

    let (file_info, entry) = match filter.check(&name) {
        Some(hit) => hit,
        None => {
            trace!(
                "No filter entry found for {}, performing upstream query",
                name
            );
            return Ok(false);
        }
    };

    // Pick the response code and optional answer address, logging the filter hit first.
    let (rcode, answer) = if entry.dest_ipv4.is_none() && entry.dest_ipv6.is_none() {
        // Black hole, as indicated by the filter.
        debug!(
            "Got filter entry for {} from {} line {}: dest=NONE",
            name, file_info.source_path, entry.line_num
        );
        (ResponseCode::RESPONSE_NXDOMAIN, None)
    } else if resource_type == ResourceType::TYPE_A {
        debug!(
            "Got filter entry for {} from {} line {}: dest={:?}",
            name, file_info.source_path, entry.line_num, entry.dest_ipv4
        );
        (ResponseCode::RESPONSE_NOERROR, entry.dest_ipv4.map(IpAddr::V4))
    } else if resource_type == ResourceType::TYPE_AAAA {
        debug!(
            "Got filter entry for {} from {} line {}: dest={:?}",
            name, file_info.source_path, entry.line_num, entry.dest_ipv6
        );
        (ResponseCode::RESPONSE_NOERROR, entry.dest_ipv6.map(IpAddr::V6))
    } else {
        // Other record types (MX, SRV, ...) for a host with local A/AAAA overrides get an empty
        // NOERROR answer instead of going upstream. If an upstream server had the right data,
        // there would be no reason for a local host entry, so these are deliberately unsupported.
        (ResponseCode::RESPONSE_NOERROR, None)
    };

    DNSMessageEncoder::new().encode_local_response(
        rcode,
        received_request_id,
        &question,
        request.opt(),
        answer,
        Some(response_udp_size),
        packet_buffer,
    )?;
    // The caller sends packet_buffer back to the requester.
    Ok(true)
}

/// Returns the first question in `request` that can be matched against the filters.
///
/// A matching question is in class INTERNET, has a resource type that
/// `dns_enums_conv::resourcetype_int` recognizes, and has a name.
///
/// The result holds the question, its name with any single trailing '.' removed (filter entries
/// are written without it), and its resource type. Returns `None` when no question qualifies.
fn get_question<'a>(request: &Message<'a>) -> Option<(Question<'a>, String, ResourceType)> {
    let questions = request.question()?;
    for i in 0..questions.len() {
        let question = questions.get(i);
        if question.resource_class() != ResourceClass::CLASS_INTERNET as u16 {
            continue;
        }

        let resource_type =
            match dns_enums_conv::resourcetype_int(question.resource_type() as usize) {
                Some(resource_type) => resource_type,
                None => continue,
            };

        if let Some(name_str) = question.name() {
            // Only strip an actual trailing '.'. Popping the last character unconditionally
            // would corrupt a name that was not fully qualified.
            let name_string = name_str.strip_suffix('.').unwrap_or(name_str).to_string();
            return Some((question, name_string, resource_type));
        }
    }
    None
}

/// Logs the encoded response held in `packet_buffer`.
///
/// The raw bytes are logged at trace level. The response is then decoded into `fbb` and logged
/// at debug level. `fbb` is scratch space here, and the callers in this file reset it before
/// calling.
///
/// # Errors
/// Propagates decoder errors. Returns an `InvalidData` error if the decoder reports the bytes
/// as incomplete.
fn print_response<'a>(
    fbb: &mut flatbuffers::FlatBufferBuilder<'a>,
    packet_buffer: &mut BytesMut,
) -> Result<(), io::Error> {
    trace!("Raw response ({}b): {:02X?}", packet_buffer.len(), &packet_buffer[..]);

    let fully_decoded = DNSMessageDecoder::new().decode(packet_buffer, fbb)?;
    if fully_decoded {
        let decoded = flatbuffers::get_root::<Message<'_>>(fbb.finished_data());
        debug!("Returning response: {}", decoded);
        Ok(())
    } else {
        // We encoded these bytes ourselves, so failing here points at a parser bug.
        io_err!("Failed to re-parse response")
    }
}