~nickbp/originz

ref: fbaed2a25114cf06aaa5d509c0d19d28ac4faa6d originz/src/client/udp.rs -rw-r--r-- 8.3 KiB
fbaed2a2Nick Parker Implement benchmark test for UDP client/UDP upstream (#10) 1 year, 10 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
#![deny(warnings, rust_2018_idioms)]

use std::net::{SocketAddr, ToSocketAddrs};
use std::time::Duration;

use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use bytes::BytesMut;
use lazy_static::lazy_static;
use rand::Rng;
use scopeguard;
use tokio::net::UdpSocket;
use tokio::time;
use tracing::{debug, trace, warn};

use crate::client::DnsClient;
use crate::codec::{decoder::DNSMessageDecoder, encoder::DNSMessageEncoder, message};
use crate::fbs::dns_message_generated::Message;

/// DNS client that queries a single server over UDP.
///
/// Requests are retried with base2 exponential backoff (1s, 2s, 4s, ...). This client does not
/// speak TCP itself: when a UDP response is truncated, or cannot be parsed as a complete message,
/// `query` returns `Ok(None)` so that the caller can retry the request over TCP.
pub struct Client {
    /// Server that every query is sent to.
    dns_server: SocketAddr,
    /// EDNS UDP payload size advertised in requests, updated from the server's responses.
    last_udp_size: u16,
    /// Scratch buffer that UDP responses are received into.
    response_buffer: Vec<u8>,
    /// Retry budget in milliseconds, see `Client::new`.
    timeout_ms: u64,
}

impl Client {
    /// Constructs a new `Client` that will query the specified `dns_server`.
    ///
    /// `timeout_ms` is a retry budget rather than a hard limit. Attempts of 1s, 2s, 4s, ... are
    /// made until the budget is used up, and the final attempt always runs to its full length.
    /// So when every attempt times out, the total wait is 2^n - 1 seconds and is always longer
    /// than `timeout_ms`. For example 10000 -> 15000 (1s + 2s + 4s + 8s).
    pub fn new(dns_server: SocketAddr, timeout_ms: u64) -> Self {
        Self {
            dns_server,
            // Start out advertising the full size of the receive buffer.
            last_udp_size: 4096,
            response_buffer: vec![0; 4096],
            timeout_ms,
        }
    }
}

lazy_static! {
    /// Encoder shared by every `Client` (used by `Client::encode`).
    /// It currently doesn't have state, so one lazily-created instance is reused for all
    /// requests; the per-server UDP size is passed in on each `encode` call instead.
    static ref ENCODER: DNSMessageEncoder = DNSMessageEncoder::new();
}

#[async_trait]
impl DnsClient for Client {
    /// Encodes `request` into `request_buffer`, advertising `last_udp_size` as the EDNS UDP
    /// payload size so the server knows how large a UDP response it may send back.
    fn encode(&self, request: &Message<'_>, request_buffer: &mut BytesMut) -> Result<()> {
        // Ensure the request to our server has the correct UDP size in the request.
        ENCODER.encode(request, Some(self.last_udp_size), request_buffer)
    }

    /// Sends the already-encoded `request_buffer` to the server and decodes the reply into
    /// `response_fbb`.
    ///
    /// The transaction ID inside `request_buffer` is overwritten with a fresh random value on
    /// every send attempt, and the reply must echo the ID of the last attempt.
    ///
    /// Returns `Ok(None)` when the reply could not be fully parsed or has its truncated flag set,
    /// so that the caller can retry over TCP. Returns an error if no reply arrives, if the reply
    /// fails to decode, lacks a header, or carries a transaction ID that doesn't match.
    /// On success, the reply's EDNS UDP size (clamped to 512..=4096) is remembered for the next
    /// request to this server.
    async fn query<'response>(
        &mut self,
        request_buffer: &mut BytesMut,
        response_fbb: &'response mut flatbuffers::FlatBufferBuilder<'_>,
    ) -> Result<Option<Message<'response>>> {
        // Length of `response_buffer` between queries (matches the allocation in `Client::new`).
        // Datagrams longer than this would be cut short on receipt, so we must never invite the
        // server to send larger UDP responses.
        const RESPONSE_BUFFER_SIZE: u16 = 4096;
        // RFC 6891 section 6.2.3: UDP payload sizes below 512 are treated as 512.
        const MIN_UDP_SIZE: u16 = 512;

        // Receives the transaction ID of the final attempt that was sent.
        let mut request_id: u16 = 0;
        {
            // Ensure that response_buffer is reset to full size when we're done with it,
            // regardless of success or error. The socket read is based on len, not capacity.
            let mut response_buffer = scopeguard::guard(&mut self.response_buffer, |buf| {
                buf.resize(usize::from(RESPONSE_BUFFER_SIZE), 0);
            });
            let response_size: usize;
            {
                let mut response_buffer_slice = response_buffer.as_mut();
                response_size = send_recv_exponential_backoff(
                    &self.dns_server,
                    request_buffer,
                    &mut request_id,
                    &mut response_buffer_slice,
                    self.timeout_ms,
                )
                .await?;
                // Shorten to actual size received
                response_buffer.truncate(response_size);
            }

            trace!(
                "Raw response from {:?} ({}b): {:02X?}",
                self.dns_server,
                response_buffer.len(),
                &response_buffer[..]
            );

            match DNSMessageDecoder::new().decode(&response_buffer, response_fbb) {
                Ok(true) => {
                    // Decoded into response_fbb
                }
                Ok(false) => {
                    // Message was likely truncated, upstream can fall back to TCP
                    debug!(
                        "Unable to parse response from server={} to request={:02X?}: {:02X?}",
                        self.dns_server,
                        &request_buffer[..],
                        &response_buffer[..],
                    );
                    return Ok(None);
                }
                Err(e) => {
                    // Other parse error
                    return Err(e).context(format!(
                        "Failed to parse response from server={} to request={:02X?}: {:02X?}",
                        self.dns_server,
                        &request_buffer[..],
                        &response_buffer[..],
                    ));
                }
            }
        }

        let response: Message<'_> =
            flatbuffers::get_root::<Message<'_>>(response_fbb.finished_data());
        debug!("Response from {:?}: {}", self.dns_server, response);

        let header = match response.header() {
            Some(header) => header,
            None => bail!("Missing header in response"),
        };
        // Validate the transaction ID before trusting any other field. Otherwise a stray or
        // forged datagram could force a TCP fallback just by setting the truncated flag.
        if header.id() != request_id {
            bail!(
                "Returned transaction id {:?} doesn't match sent {:?}",
                header.id(),
                request_id
            );
        }
        if header.truncated() {
            // Message claims to be truncated, let the caller retry over TCP.
            return Ok(None);
        }

        // After passing validation, update udp_size for the next request to this server.
        // Clamp it: advertising more than the receive buffer holds would get large replies cut
        // short, and RFC 6891 forbids treating anything below 512 as the limit.
        if let Some(opt) = response.opt() {
            let udp_size = opt.udp_size().max(MIN_UDP_SIZE).min(RESPONSE_BUFFER_SIZE);
            trace!("Using udp_size={} for server={}", udp_size, self.dns_server);
            self.last_udp_size = udp_size;
        }

        Ok(Some(response))
    }
}

/// Sends `request_buffer` to `dest` over UDP and waits for the reply, retrying with per-attempt
/// receive timeouts that double each time (1s, 2s, 4s, ...).
///
/// Every attempt binds a new socket on an OS-chosen port and writes a new random transaction ID
/// into `request_buffer`. The ID of the last attempt sent is stored in `request_id` so the caller
/// can validate the reply against it.
///
/// `total_timeout_ms` is a retry budget rather than a hard limit. Before each retry, that retry's
/// wait is deducted from the budget; the attempt that no longer fits is still made, as the last
/// one. When every attempt times out, the total wait is 2^n - 1 seconds and always exceeds
/// `total_timeout_ms` (e.g. 10000 -> 1s + 2s + 4s + 8s = 15s).
///
/// Returns the number of bytes received into `response_buffer`. Datagrams arriving from any
/// address other than `dest` are discarded with a warning and use up the current attempt.
///
/// # Errors
///
/// Fails if no local socket can be bound, if the request cannot be sent (or sending takes longer
/// than 1s), if receiving fails, or if every attempt ends without a reply from `dest`.
async fn send_recv_exponential_backoff(
    dest: &SocketAddr,
    request_buffer: &mut BytesMut,
    request_id: &mut u16,
    response_buffer: &mut [u8],
    total_timeout_ms: u64,
) -> Result<usize> {
    // Bind to the wildcard address of the destination's family: a socket bound to 0.0.0.0
    // cannot send to an IPv6 server. Port 0 asks the OS for an ephemeral port.
    let bind_str = if dest.is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
    let client_addr = match bind_str.to_socket_addrs()?.next() {
        Some(addr) => addr,
        None => bail!("Failed to resolve local bind address {}", bind_str),
    };

    // Start at 1s, then 2s, then 4s, ...
    let mut remaining_timeout_ms = total_timeout_ms;
    let mut timeout_ms: u64 = 1000;
    loop {
        // NOTE: This assumes that port 0 results in a random port each time.
        // In particular we DONT want it to just increment by 1 or something each time.
        // Apparently this is OS-specific but Linux at least should do what we want.
        let mut conn = UdpSocket::bind(client_addr)
            .await
            .with_context(|| format!("Failed to bind UDP socket to {}", client_addr))?;

        // We regenerate the request ID on every retry. We're changing the client port each time,
        // so scrambling the request ID shouldn't result in "old" mismatched responses anyway.
        // This reduces the likelihood of someone trying to poison our cache by sending a request
        // and then flooding us with responses that match that request's message id.
        *request_id = rand::thread_rng().gen::<u16>();
        message::update_message_id(*request_id, request_buffer, 0)?;

        trace!(
            "Raw request to {:?} ({}b): {:02X?}",
            &dest,
            request_buffer.len(),
            &request_buffer[..]
        );
        // (Re)send request. Shouldn't time out but just in case...
        // Both layers are checked: ignoring the inner send error would leave us waiting out the
        // whole receive timeout for a reply that can never arrive.
        time::timeout(
            Duration::from_millis(1000),
            conn.send_to(request_buffer.as_ref(), dest),
        )
        .await
        .with_context(|| format!("Timed out sending DNS request to {}", dest))?
        .with_context(|| format!("Failed to send DNS request to {}", dest))?;

        match time::timeout(
            Duration::from_millis(timeout_ms),
            conn.recv_from(&mut *response_buffer),
        )
        .await
        {
            // Got a response from somewhere
            Ok(recv) => {
                let (recvsize, recvdest) = recv
                    // A different error occurred, give up
                    .with_context(|| format!("Failed to receive DNS response from {}", dest))?;
                // Before returning, check that the response is from who we're waiting for
                if *dest == recvdest {
                    return Ok(recvsize);
                }
                // If it doesn't match, resend and resume waiting, unless this was the last retry
                warn!(
                    "Response origin {:?} doesn't match request target {:?}",
                    recvdest, dest
                );
            }
            // Timeout occurred, try again (or exit loop)
            Err(_) => {
                debug!("UDP request to {} timed out after {}ms", dest, timeout_ms);
            }
        }

        if remaining_timeout_ms == 0 {
            // No retries left, give up
            bail!("Timed out waiting for response from {:?}", dest);
        }
        // Double the wait and charge the next attempt against the budget. Once it no longer
        // fits, the budget drops to zero and that attempt becomes the last one.
        // Saturating arithmetic keeps huge budgets from overflowing in debug builds.
        timeout_ms = timeout_ms.saturating_mul(2);
        remaining_timeout_ms = remaining_timeout_ms.saturating_sub(timeout_ms);
    }
}