~nickbp/kapiti

ref: 4697278f2ab267f47865a0385e0167c8ff401ead kapiti/src/client/https.rs -rw-r--r-- 6.9 KiB
4697278fNick Parker Give some files better names (http=>fetcher, then downloader=>updater to avoid confusion with fetcher) 4 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
use std::io::{self, Write};
use std::time::Duration;

use anyhow::{bail, Context, Result};
use async_io::Timer;
use async_trait::async_trait;
use bytes::{BufMut, BytesMut};
use futures_lite::FutureExt;
use hyper::header;
use hyper::{Body, Client as HttpClient, Method, Uri};
use lazy_static::lazy_static;
use rand::Rng;
use scopeguard;
use tracing::debug;

use crate::client::DnsClient;
use crate::codec::{decoder::DNSMessageDecoder, encoder::DNSMessageEncoder, message};
use crate::fetcher::Fetcher;
use crate::hyper_smol;
use crate::resolver;
use crate::specs::message::Message;

/// Largest DoH message size handled, in bytes (the maximum value of a `u16`).
///
/// Used as the "udp size" when encoding queries, passed to `Fetcher::new` (presumably as its
/// response size limit — confirm against `Fetcher`), and as the response buffer's capacity.
/// A `const` rather than a `static`: it is a plain value that never needs a fixed address.
const MAX_HTTP_BYTES: u16 = 65535;

lazy_static! {
    /// Single shared DNS message encoder, used by every `query` call in this file.
    /// It is built by `DNSMessageEncoder::new()` on first access; the encoder
    /// currently carries no state, so sharing one instance is sufficient.
    static ref ENCODER: DNSMessageEncoder = {
        DNSMessageEncoder::new()
    };
}

/// DNS client that queries a server over HTTPS (DoH, RFC 8484).
pub struct Client {
    /// URL of the DoH endpoint that queries are POSTed to.
    server_url: Uri,
    /// Builds the HTTP requests and copies response bodies out, configured for
    /// `application/dns-message` payloads.
    fetcher: Fetcher,
    /// HTTP client used to send the requests.
    client: HttpClient<hyper_smol::SmolConnector>,
    /// How long a query waits on the HTTP request before failing, in milliseconds.
    timeout_ms: u64,
    /// Reusable buffer for response payloads, allocated with `MAX_HTTP_BYTES` capacity.
    response_buffer: BytesMut,
}

impl Client {
    /// Constructs a new `Client` that will query the DoH endpoint at `server_url`.
    ///
    /// `resolver` is handed to the underlying HTTP connector (presumably to resolve the
    /// server's hostname — TODO confirm in `hyper_smol::client_kapiti`). `timeout_ms` bounds
    /// how long each query waits on the HTTP request.
    ///
    /// Currently always returns `Ok`.
    pub fn new(server_url: Uri, resolver: resolver::Resolver, timeout_ms: u64) -> Result<Self> {
        Ok(Client {
            server_url,
            fetcher: Fetcher::new(
                MAX_HTTP_BYTES as usize,
                Some("application/dns-message".to_string()),
            )
            // Note that hyper will reject requests with "request has unsupported HTTP version",
            // unless we ALSO set "http2_only(true)" in the Client builder.
            .use_http_2(),
            client: hyper_smol::client_kapiti(resolver, true, false, 4096),
            timeout_ms,
            response_buffer: BytesMut::with_capacity(MAX_HTTP_BYTES as usize),
        })
    }
}

#[async_trait]
impl DnsClient for Client {
    async fn query(&mut self, request: &Message, query_buffer: &mut BytesMut) -> Result<Option<Message>> {
        // Just use our max for the "udp size"
        ENCODER.encode(request, Some(MAX_HTTP_BYTES), query_buffer)?;

        let request_id = rand::thread_rng().gen::<u16>();
        message::update_message_id(request_id, query_buffer, 0)?;

        // Ensure that response_buffer size is reset when we're done with it,
        // regardless of success or error. The socket read is based on len, not capacity.
        let mut response_buffer = scopeguard::guard(&mut self.response_buffer, |buf| {
            buf.clear();
        });

        debug!(
            "Raw request to {} ({}b): {:02X?}",
            self.server_url,
            query_buffer.len(),
            &query_buffer[..]
        );

        // Hyper apparently requires a static copy of the body content? No idea why, and don't care to find out.
        // Maybe we can clean this up by switching to something with a better API someday.
        let request_copy = hyper::body::Bytes::from(query_buffer.clone());
        let request = self
            .fetcher
            .request_builder(&Method::POST, &self.server_url)
            .header(header::CONTENT_TYPE, "application/dns-message")
            .header(header::CONTENT_LENGTH, request_copy.len())
            .body(Body::from(request_copy))
            .context("Failed to build DoH request")?;

        // Copy for async call:
        let timeout_ms = self.timeout_ms;
        let mut response = self
            .client
            .request(request)
            .or(async {
                Timer::after(Duration::from_millis(timeout_ms)).await;
                // hyper keeps error types crate-private. Jump through hoops to produce an Ok response with an error.
                let response = hyper::Response::new(hyper::Body::empty());
                let (mut parts, body) = response.into_parts();
                parts.status = http::StatusCode::GATEWAY_TIMEOUT;
                Ok(hyper::Response::from_parts(parts, body))
            })
            .await
            .context("DoH query failed")?;

        if !response.status().is_success() {
            bail!(
                "HTTP POST to {} returned status: {}",
                self.server_url,
                response.status()
            );
        }

        {
            // Write response payload into response_buffer
            let mut writer = BytesWriter::new(&mut response_buffer);
            self.fetcher
                .write_response(&self.server_url.to_string(), &mut writer, &mut response)
                .await?;
        }

        debug!(
            "Raw response from {} ({}b): {:02X?}",
            self.server_url,
            response_buffer.len(),
            &response_buffer[..]
        );

        match DNSMessageDecoder::new().decode(&response_buffer[..]) {
            Ok(Some(response)) => {
                debug!("Response from {}: {}", self.server_url, response);

                if response.header.truncated {
                    // Message claims to be truncated, shouldn't happen but let's bail anyway
                    return Ok(None);
                }
                if response.header.id != request_id {
                    bail!(
                        "Returned transaction id {:?} doesn't match sent {:?}",
                        response.header.id,
                        request_id
                    );
                }

                Ok(Some(response))
            }
            Ok(None) => {
                // Message was likely corrupted somehow, despite us receiving all the data in the payload
                debug!(
                    "Unable to parse response from server={} to request={:02X?}: {:02X?}",
                    self.server_url,
                    &query_buffer[..],
                    &response_buffer[..],
                );
                Ok(None)
            }
            Err(e) => {
                // Other parse error
                Err(e).context(format!(
                    "Failed to parse response from server={} to request={:02X?}: {:02X?}",
                    self.server_url,
                    &query_buffer[..],
                    &response_buffer[..],
                ))
            }
        }
    }
}

/// `io::Write` adapter that appends everything written to it onto a borrowed `BytesMut`.
///
/// Used to collect a DoH response payload into the client's reusable response buffer.
struct BytesWriter<'a> {
    /// Destination buffer; writes are appended after its current contents.
    inner: &'a mut BytesMut,
}

impl<'a> BytesWriter<'a> {
    fn new(inner: &'a mut BytesMut) -> BytesWriter<'a> {
        BytesWriter { inner }
    }
}

impl<'a> Write for BytesWriter<'a> {
    /// Appends all of `buf` to the wrapped buffer.
    ///
    /// The write is all-or-nothing: if `remaining_mut()` reports less space than
    /// `buf.len()`, nothing is written and an `InvalidInput` error is returned.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // NOTE(review): in bytes 1.x, `BytesMut::remaining_mut()` is `usize::MAX - len()`
        // (the buffer grows on demand), so this check only enforces a size cap under the
        // bytes 0.x meaning (`capacity() - len()`). Confirm which bytes version is in use.
        let remaining = self.inner.remaining_mut();
        if remaining < buf.len() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "Unable to write {} bytes into buffer: {}/{} remaining",
                    buf.len(),
                    remaining,
                    // Report the buffer's capacity, not the bytes already written (`len`).
                    self.inner.capacity()
                ),
            ));
        }
        self.inner.put_slice(buf);
        Ok(buf.len())
    }

    /// No-op: bytes are appended to memory directly, so there is nothing to flush.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}