~nickbp/kapiti

ref: 9d460d0b7ab39ac1e39576b34fa6217049c65edb kapiti/src/resolver.rs -rw-r--r-- 8.0 KiB
9d460d0bNick Parker Update dependencies to latest along with corresponding code changes 6 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;

use anyhow::{bail, Context, Result};
use async_lock::{Barrier, Mutex};
use bytes::BytesMut;
use lazy_static::lazy_static;
use tracing::{debug, warn};

use crate::cache;
use crate::client::DnsClient;
use crate::codec::{encoder::DNSMessageEncoder, message::RequestInfo};
use crate::specs::enums_generated::{OpCode, ResourceClass, ResourceType, ResponseCode};
use crate::specs::message::*;

lazy_static! {
    /// Encoder instance, currently doesn't have state.
    /// Used by `Resolver::resolve_raw` to re-encode responses before they are returned to the
    /// requesting client.
    static ref ENCODER: DNSMessageEncoder = DNSMessageEncoder::new();
}

/// A DNS resolver that queries from a set of upstream DNS sources.
///
/// Lookups are first sent to a cache task over `cache_tx`; on a cache miss the upstream
/// `clients` are consulted.
pub struct Resolver {
    /// Channel to the cache task, used both to fetch cached responses and to store new ones.
    cache_tx: async_channel::Sender<cache::task::CacheMsg>,
    /// Upstream DNS clients, queried in order.
    clients: Vec<Box<dyn DnsClient + Send>>,
    /// Scratch buffer handed to upstream clients for each query; cleared before every use.
    client_buffer: BytesMut,
}

impl Resolver {
    /// Creates a resolver that consults the cache task via `cache_tx` and falls back to the
    /// upstream `clients`, which are queried in the given order.
    pub fn new(
        cache_tx: async_channel::Sender<cache::task::CacheMsg>,
        clients: Vec<Box<dyn DnsClient + Send>>,
    ) -> Self {
        Resolver {
            cache_tx,
            clients,
            client_buffer: BytesMut::with_capacity(4096),
        }
    }

    /// A high-level query for getting an A/AAAA record for a given hostname.
    ///
    /// Queries AAAA records when `get_ipv6` is true and A records otherwise, and returns the
    /// first resulting address paired with `port`. `host` may be given with or without a
    /// trailing period.
    ///
    /// # Errors
    ///
    /// Fails if the lookup itself fails, or if the response contains no address of the
    /// requested type.
    pub async fn resolve_str(
        &mut self,
        host: &str,
        port: u16,
        get_ipv6: bool,
        udp_size: u16,
    ) -> Result<SocketAddr> {
        let resource_type = if get_ipv6 {
            ResourceType::AAAA
        } else {
            ResourceType::A
        };

        // The host used for cache/filter lookup must not have the trailing period, while the
        // name in the DNS message must have exactly one. Normalize first so that a caller
        // passing a fully-qualified "example.com." doesn't produce "example.com..".
        let name = host.strip_suffix('.').unwrap_or(host);
        let host_with_period = format!("{}.", name);

        let request_info = RequestInfo {
            name: name.to_string(),
            resource_type,
            received_request_id: 0,
            requested_udp_size: udp_size,
        };

        let request = build_request(resource_type, host_with_period, udp_size);
        let response = self
            .resolve(&request, &request_info)
            .await
            .with_context(|| format!("Failed to resolve host {:?}", host))?;

        let results = extract_addresses(response, resource_type);
        debug!(
            "Resolved host={:?} type={:?}: {:?}",
            host, resource_type, results
        );
        // If there are multiple IPs, just return the first one
        let addr = *results.first().with_context(|| {
            format!("No {:?} results for host: {:?}", resource_type, host)
        })?;
        Ok(SocketAddr::new(addr, port))
    }

    /// A low-level query that accepts and returns raw DNS query payloads.
    ///
    /// The resolved response is encoded into `response_buffer`, with its UDP size overridden
    /// by `request_info.requested_udp_size`.
    pub async fn resolve_raw(
        &mut self,
        request: &Message,
        request_info: &RequestInfo,
        response_buffer: &mut BytesMut,
    ) -> Result<()> {
        let response = self.resolve(request, request_info).await?;
        debug!("Response to client: {}", response);

        // Reencode the response to be returned, overriding udp_size to match the original request
        ENCODER.encode(
            &response,
            Some(request_info.requested_udp_size),
            response_buffer,
        )
    }

    /// Resolves `request`, consulting the cache first and then each upstream client in order.
    ///
    /// The returned message's header ID is always set to `request_info.received_request_id`.
    /// Responses obtained from an upstream are sent to the cache task for storage.
    ///
    /// # Errors
    ///
    /// Fails if a message can't be sent to the cache task, or if no upstream client returned a
    /// response. An error from an individual upstream is logged and the next client is tried;
    /// if every client fails, the last such error is returned with added context.
    async fn resolve(&mut self, request: &Message, request_info: &RequestInfo) -> Result<Message> {
        // Check if cache has cached result: Send request and wait for response via barrier+arc
        let result_barrier = Arc::new(Barrier::new(2));
        let result = Arc::new(Mutex::new(None));
        self.cache_tx
            .send(cache::task::CacheMsg::Fetch(cache::task::CacheFetch {
                request_info: request_info.clone(),
                result_barrier: result_barrier.clone(),
                result: result.clone(),
            }))
            .await
            .context("Failed to send cache fetch query")?;

        // Wait on the barrier to complete
        result_barrier.wait().await;
        // Barrier has completed, so the stored result can be moved out of the shared slot
        // without another copy. The lock guard is released at the end of this statement.
        let fetched = result
            .lock()
            .await
            .take()
            .expect("Missing fetch result following barrier");
        match fetched {
            Ok(Some(mut cache_result)) => {
                // Match the original request's ID, exactly as is done for upstream responses.
                cache_result.header.id = request_info.received_request_id;
                return Ok(cache_result);
            }
            Ok(None) => {
                // cache miss - continue with upstream queries below
            }
            Err(e) => {
                // cache fail - complain but continue with upstream queries
                warn!("Cache lookup failed for request {:?}: {}", request_info, e)
            }
        }

        // Cache didn't have anything, so query upstream clients in order until one responds.
        let mut last_error = None;
        for client in &mut self.clients {
            // Mark the client buffer as empty so that we don't append on top of a prior request
            self.client_buffer.clear();
            match client.query(request, &mut self.client_buffer).await {
                Ok(Some(mut response)) => {
                    // Store fetched result to cache (no response needed)
                    self.cache_tx
                        .send(cache::task::CacheMsg::Store(cache::task::CacheStore {
                            request_info: request_info.clone(),
                            response: response.clone(),
                        }))
                        .await
                        .context("Failed to send cache store query")?;

                    // Set the message ID for the response so that it matches the original request.
                    // Keeping the message IDs independent reduces the likelihood of cache poisoning.
                    response.header.id = request_info.received_request_id;

                    return Ok(response);
                }
                Ok(None) => {
                    // No response from this upstream - try the next one.
                }
                Err(e) => {
                    // One failing upstream must not prevent the remaining ones from being tried.
                    warn!("Upstream query failed for request {:?}: {:#}", request_info, e);
                    last_error = Some(e);
                }
            }
        }
        match last_error {
            Some(e) => Err(e).context("All upstreams failed to return a response"),
            None => bail!("All upstreams failed to return a response"),
        }
    }
}

/// Builds a single-question query message for `resource_type` records of `domain`.
///
/// The message has a zero ID, the QUERY opcode, recursion desired and the AD bit set in the
/// header, one INTERNET-class question, empty answer/authority/additional sections, and an OPT
/// record advertising `udp_size` with the DNSSEC OK flag set.
fn build_request(resource_type: ResourceType, domain: String, udp_size: u16) -> Message {
    let header = Header {
        id: 0,
        is_response: false,
        op_code: IntEnum::Enum(OpCode::QUERY),
        authoritative: false,
        truncated: false,
        recursion_desired: true,
        recursion_available: false,
        reserved_9: false,
        authentic_data: true,
        checking_disabled: false,
        response_code: IntEnum::Enum(ResponseCode::NOERROR),
    };

    let edns = OPT {
        option: Vec::new(),
        udp_size,
        response_code: 0,
        version: 0,
        dnssec_ok: true,
    };

    let question = vec![Question {
        name: domain,
        resource_type: IntEnum::Enum(resource_type),
        resource_class: IntEnum::Enum(ResourceClass::INTERNET),
    }];

    Message {
        header,
        opt: Some(edns),
        question,
        answer: Vec::new(),
        authority: Vec::new(),
        additional: Vec::new(),
    }
}

/// Collects the addresses of type `resource_type` found in the answer section of `message`.
///
/// Answers whose record type differs from `resource_type`, or whose data can't be converted to
/// an address, are skipped. The result keeps the order of the answer section.
fn extract_addresses(message: Message, resource_type: ResourceType) -> Vec<IpAddr> {
    let wanted = IntEnum::Enum(resource_type);
    message
        .answer
        .iter()
        .filter(|record| record.resource_type == wanted)
        .filter_map(|record| extract_address(record, resource_type))
        .collect()
}

/// Converts a single answer record into an IP address.
///
/// Returns `Some` when `resource_type` is `A` and the record holds A data, or when
/// `resource_type` is `AAAA` and the record holds AAAA data. Returns `None` when the record's
/// data doesn't match the requested type.
///
/// # Panics
///
/// Panics if `resource_type` is neither `A` nor `AAAA`.
fn extract_address(answer: &Resource, resource_type: ResourceType) -> Option<IpAddr> {
    if resource_type == ResourceType::A {
        match &answer.rdata {
            ResourceData::A(a) => Some(IpAddr::V4(Ipv4Addr::new(
                a.address1, a.address2, a.address3, a.address4,
            ))),
            _ => None,
        }
    } else if resource_type == ResourceType::AAAA {
        match &answer.rdata {
            ResourceData::AAAA(aaaa) => Some(IpAddr::V6(Ipv6Addr::new(
                aaaa.address1,
                aaaa.address2,
                aaaa.address3,
                aaaa.address4,
                aaaa.address5,
                aaaa.address6,
                aaaa.address7,
                aaaa.address8,
            ))),
            _ => None,
        }
    } else {
        // Report the requested type: that is the value this function can't handle. The
        // answer's own type may well be a perfectly ordinary A/AAAA record.
        panic!(
            "Unsupported resource type for address extraction: {:?}",
            resource_type
        );
    }
}