~nickbp/kapiti

ref: 9d460d0b7ab39ac1e39576b34fa6217049c65edb kapiti/src/hyper_smol.rs -rw-r--r-- 16.0 KiB
9d460d0bNick Parker Update dependencies to latest along with corresponding code changes 2 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
use std::net::{Shutdown, SocketAddr, TcpStream, ToSocketAddrs};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};

use anyhow::{anyhow, bail, Context as _, Error, Result};
use async_lock::{Barrier, Mutex};
use async_rustls::{TlsStream, webpki::DNSNameRef};
use http::Uri;
use hyper::{Body, Client};
use smol::{io, prelude::*, Async, Task};
use tokio::io::ReadBuf;
use tracing::{trace, warn};

use crate::resolver;

/// Builds a Hyper HTTP client that:
/// - Runs its connections and spawned futures on smol
/// - Resolves hostnames through the provided Kapiti Resolver
///
/// `get_ipv6` and `udp_size` are forwarded unchanged to every `resolve_str` call made by the
/// background resolver task.
pub fn client_kapiti(
    mut resolver: resolver::Resolver,
    http2_only: bool,
    get_ipv6: bool,
    udp_size: u16,
) -> Client<SmolConnector> {
    let (resolver_tx, resolver_rx) = async_channel::bounded::<ResolverQuery>(32);

    // A single background task owns the Resolver and serves lookups for every SmolConnector clone.
    // This sidesteps locking + async trouble across the per-connection futures; the alternative
    // would be constructing a brand new Resolver for each query.
    // The task handle is stored in the connector so the task stays alive as long as the client
    // does; the loop itself ends once the channel is closed and drained.
    let lookup_loop = async move {
        trace!("Internal resolver waiting for requests");
        // recv() yields Err once the channel is closed and has no more messages
        while let Ok(msg) = resolver_rx.recv().await {
            trace!("Internal resolver: {}", msg.host);
            // A "host" that already parses as an IP is returned as-is instead of being resolved.
            // This mirrors what the system resolver does via client_system().
            // It is not in the user query path: it only comes up if the admin e.g. gives an
            // upstream as "https://127.0.0.1"
            let endpoint_str = format!("{}:{}", msg.host, msg.port);
            let outcome = match SocketAddr::from_str(&endpoint_str) {
                Ok(literal_addr) => {
                    trace!(
                        "Internal resolver IP shortcut: {} = {:?}",
                        endpoint_str,
                        literal_addr
                    );
                    Ok(literal_addr)
                }
                Err(_) => {
                    // Not a socket address, so perform a real lookup.
                    let resolved = resolver
                        .resolve_str(&msg.host, msg.port, get_ipv6, udp_size)
                        .await;
                    match &resolved {
                        Err(e) => {
                            warn!("Internal resolver lookup failed: {:?}", e);
                        }
                        Ok(_) => {
                            trace!("Internal resolver: {} = {:?}", endpoint_str, resolved);
                        }
                    }
                    resolved
                }
            };
            // Publish the outcome first, then meet the requestor at the barrier
            msg.result.lock().await.replace(outcome);
            msg.result_barrier.wait().await;
        }
        trace!("Internal resolver exiting");
    };
    let resolver_task = Arc::new(smol::spawn(lookup_loop));

    let connector = SmolConnector {
        resolver_task,
        resolver_tx,
    };
    Client::builder()
        .executor(SmolExecutor)
        .http2_only(http2_only)
        .build::<_, Body>(connector)
}

/// Builds a Hyper HTTP client that:
/// - Runs its connections and spawned futures on smol
/// - Resolves hostnames through the system resolver (`ToSocketAddrs`)
///
/// Only the first address returned by the system resolver is used for each lookup.
pub fn client_system(http2_only: bool) -> Client<SmolConnector> {
    let (resolver_tx, resolver_rx) = async_channel::bounded::<ResolverQuery>(32);

    // A dedicated lookup task is not strictly required for the system resolver, but it keeps the
    // connector identical to the one built by client_kapiti().
    let lookup_loop = async move {
        trace!("System resolver waiting for requests");
        // recv() yields Err once the channel is closed and has no more messages
        while let Ok(msg) = resolver_rx.recv().await {
            trace!("System resolver: {}", msg.host);
            // to_socket_addrs() blocks, so run it via smol::unblock with owned copies of the inputs.
            let (host, port) = (msg.host.clone(), msg.port);
            let lookup_result = smol::unblock(move || (host.as_str(), port).to_socket_addrs())
                .await
                .with_context(|| format!("Failed to query for hostname {}", msg.host))
                .and_then(|mut socket_addrs| {
                    socket_addrs
                        .next()
                        .ok_or_else(|| anyhow!("No results for hostname {}", msg.host))
                });

            trace!("System resolver: {} = {:?}", msg.host, lookup_result);
            // Publish the outcome first, then meet the requestor at the barrier
            msg.result.lock().await.replace(lookup_result);
            msg.result_barrier.wait().await;
        }
        trace!("System resolver exiting");
    };
    let resolver_task = Arc::new(smol::spawn(lookup_loop));

    let connector = SmolConnector {
        resolver_task,
        resolver_tx,
    };
    Client::builder()
        .executor(SmolExecutor)
        .http2_only(http2_only)
        .build::<_, Body>(connector)
}

/// Hyper executor that hands every future it receives to smol.
#[derive(Clone)]
struct SmolExecutor;

impl<F> hyper::rt::Executor<F> for SmolExecutor
where
    F: Future + Send + 'static,
{
    /// Spawns `fut` as a detached smol task; its output is discarded.
    fn execute(&self, fut: F) {
        let task = smol::spawn(async move {
            drop(fut.await);
        });
        // Nobody waits on the task, so detach it to let it run to completion on its own.
        task.detach();
    }
}

/// The request for a host to be resolved, along with an output for returning the response.
///
/// A query is sent over the resolver channel by `SmolConnector`; the resolver task stores its
/// outcome in `result` and then waits on `result_barrier`, which is how the requestor learns that
/// the outcome is available.
///
/// NOTE(review): the resolver task waits on the barrier unconditionally. If the requesting future
/// is dropped after sending the query but before reaching the barrier, the resolver task
/// presumably stays parked there and stops serving later queries - confirm, and consider a
/// one-shot reply channel instead.
#[derive(Debug)]
struct ResolverQuery {
    /// The hostname to look up
    host: String,
    /// The port to include in the resolved result
    port: u16,
    /// Barrier to wait for the result to appear. The requestor should wait on this before accessing result.
    /// Shared between the requestor and the resolver task, each of which waits on it once per query.
    result_barrier: Arc<Barrier>,
    /// Where the result should go. Should be an error if the hostname could not be resolved (e.g. not found).
    /// Starts out as `None` and is filled in by the resolver task before it waits on the barrier.
    result: Arc<Mutex<Option<Result<SocketAddr>>>>,
}

/// Connects to URLs.
///
/// Used as the connector of the Hyper clients built by `client_kapiti()` and `client_system()`.
/// It supports `http` and `https` URIs and delegates hostname resolution to a background resolver
/// task, which it reaches through `resolver_tx`.
///
/// Cloning is cheap: clones share the same resolver task and channel.
#[derive(Clone)]
pub struct SmolConnector {
    /// Handle to keep the resolver task from dying prematurely
    /// (never read; it is only held so that the handle is not dropped while a connector exists).
    resolver_task: Arc<Task<()>>,
    /// Channel for sending requests to the resolver task
    resolver_tx: async_channel::Sender<ResolverQuery>,
}

impl hyper::service::Service<Uri> for SmolConnector {
    type Response = SmolStream;
    type Error = Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, uri: Uri) -> Self::Future {
        // Get copy for async move:
        let resolver_tx_copy = self.resolver_tx.clone();

        Box::pin(async move {
            let host = uri
                .host()
                .with_context(|| format!("Cannot parse host: {:?}", uri))?;
            // Release when both the requestor and requestee have called wait
            let result_barrier = Arc::new(Barrier::new(2));
            // Where the requestee will store the result before calling result_barrier.wait
            let result = Arc::new(Mutex::new(None));
            match uri.scheme_str() {
                Some("http") => {
                    // Send lookup, with place to give us back the result:
                    trace!("HTTP lookup: {}", host);
                    let query = ResolverQuery {
                        host: host.to_string(),
                        port: uri.port_u16().unwrap_or(80),
                        result_barrier: result_barrier.clone(),
                        result: result.clone(),
                    };
                    resolver_tx_copy
                        .send(query)
                        .await
                        .context("Failed to send HTTP resolver query")?;
                    trace!("HTTP lookup sent");
                    // Wait on the barrier to complete
                    result_barrier.wait().await;
                    // Barrier has completed, get the stored result.
                    // Jump through weird reference hoops to get the SocketAddr value out of the mutex.
                    match result
                        .lock()
                        .await
                        .as_ref()
                        .expect("Missing resolve result following barrier")
                    {
                        Ok(socket_addr) => {
                            let stream = Async::<TcpStream>::connect(socket_addr.clone()).await?;
                            Ok(SmolStream::Plain(stream))
                        }
                        // e is a reference, and anyhow doesn't like using with_context with it. So just give up and create a new error.
                        Err(e) => Err(anyhow!("Failed to resolve host {:?}: {}", uri, e)),
                    }
                }
                Some("https") => {
                    // Send lookup, with place to give us back the result:
                    trace!("HTTPS lookup: {}", host);
                    let query = ResolverQuery {
                        host: host.to_string(),
                        port: uri.port_u16().unwrap_or(443),
                        result_barrier: result_barrier.clone(),
                        result: result.clone(),
                    };
                    resolver_tx_copy
                        .send(query)
                        .await
                        .context("Failed to send HTTPS resolver query")?;
                    trace!("HTTPS lookup sent");
                    // Wait on the barrier to complete
                    result_barrier.wait().await;
                    // Barrier has completed, get the stored result.
                    // Jump through weird reference hoops to get the SocketAddr value out of the mutex.
                    match result
                        .lock()
                        .await
                        .as_ref()
                        .expect("Missing resolve result following barrier")
                    {
                        Ok(socket_addr) => {
                            let stream = Async::<TcpStream>::connect(socket_addr.clone()).await?;
                            let mut client_config = rustls::ClientConfig::new();
                            // Required for http2/DoH, otherwise we get 'http2 error: protocol error: frame with invalid size'
                            // In particular, we're stuck with rustls for now because rust-native-tls doesn't support configuring this.
                            client_config.alpn_protocols = vec![b"h2".to_vec()];
                            match rustls_native_certs::load_native_certs() {
                                Ok(certs) => {
                                    client_config.root_store = certs;
                                }
                                Err((Some(certs), e)) => {
                                    warn!(
                                        "Some TLS certificates failed to load, trying to continue without them: {:?}",
                                        e
                                    );
                                    client_config.root_store = certs;
                                }
                                Err((None, e)) => {
                                    return Err(e).context("Failed to load native TLS cert store");
                                }
                            }
                            // Disabled for now: Was previously using ct-logs package, but bizarre API mismatches broke it.
                            client_config.ct_logs = None;
                            let connector =
                                async_rustls::TlsConnector::from(Arc::new(client_config));
                            if let Ok(dns_name) = webpki::DnsNameRef::try_from_ascii_str(host) {
                                // Convert new webpki::DnsNameRef (webpki 0.22) to old webpki::DNSNameRef (webpki 0.21) used by async-rustls
                                // See also https://github.com/smol-rs/async-rustls/blob/master/Cargo.toml
                                let dns_name_old: DNSNameRef = DNSNameRef::try_from_ascii(dns_name.as_ref())
                                    .expect("Converting new DnsNameRef to old DNSNameRef failed");
                                let stream = connector.connect(dns_name_old, stream).await?;
                                Ok(SmolStream::Tls(async_rustls::TlsStream::Client(stream)))
                            } else {
                                // Uh-oh, looks like we're trying to connect to an IP. Explain the issue with the underlying library and how to work around it.
                                bail!(
                                    "Unable to parse TLS endpoint: {}
rustls/webpki still don't support IP endpoints. See also: https://github.com/briansmith/webpki/issues/54 and https://github.com/ctz/rustls/issues/184
Try using a hostname instead, e.g. 9.9.9.9 => dns.quad9.net, 8.8.8.8 => dns.google, or 1.1.1.1 => cloudflare-dns.com
If this hostname is a DoH or DoT upstream, you will also need to include at least one IP-based 'regular' UDP/TCP upstream as a fallback so that the DoH or DoT hostname can itself be resolved.",
                                    host
                                );
                            }
                        }
                        // e is a reference, and anyhow doesn't like using with_context with it. So just give up and create a new error.
                        Err(e) => Err(anyhow!("Failed to resolve host {:?}: {}", uri, e)),
                    }
                }
                scheme => bail!("Unsupported scheme: {:?}", scheme),
            }
        })
    }
}

/// A TCP or TCP+TLS connection.
///
/// This is the response type of `SmolConnector`. It implements tokio's `AsyncRead`/`AsyncWrite`
/// and hyper's `Connection` by forwarding to whichever stream the variant wraps.
pub enum SmolStream {
    /// A plain TCP connection.
    /// Created for `http` URIs.
    Plain(Async<TcpStream>),

    /// A TCP connection secured by TLS.
    /// Created for `https` URIs.
    Tls(TlsStream<Async<TcpStream>>),
}

impl hyper::client::connect::Connection for SmolStream {
    /// Reports default connection metadata; nothing specific to the stream is attached.
    fn connected(&self) -> hyper::client::connect::Connected {
        use hyper::client::connect::Connected;

        Connected::new()
    }
}

impl tokio::io::AsyncRead for SmolStream {
    /// Adapts tokio's `ReadBuf`-based read onto the wrapped stream's slice-based `poll_read`.
    ///
    /// The wrapped stream reads into the unfilled part of `buf` and reports how many bytes it
    /// wrote there; that count is then recorded in `buf` via `advance`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Hand the inner stream a plain byte slice covering the not-yet-filled part of the buffer.
        let unfilled = buf.initialize_unfilled();
        let polled = match &mut *self {
            SmolStream::Plain(plain) => Pin::new(plain).poll_read(cx, unfilled),
            SmolStream::Tls(tls) => Pin::new(tls).poll_read(cx, unfilled),
        };
        // On success, mark the bytes that were just read as filled.
        polled.map_ok(|count| {
            buf.advance(count);
        })
    }
}

impl tokio::io::AsyncWrite for SmolStream {
    /// Forwards the write to the wrapped stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            SmolStream::Plain(plain) => Pin::new(plain).poll_write(cx, buf),
            SmolStream::Tls(tls) => Pin::new(tls).poll_write(cx, buf),
        }
    }

    /// Forwards the flush to the wrapped stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            SmolStream::Plain(plain) => Pin::new(plain).poll_flush(cx),
            SmolStream::Tls(tls) => Pin::new(tls).poll_flush(cx),
        }
    }

    /// Closes the write side of the connection.
    ///
    /// A plain connection shuts down the write half of the socket directly and completes
    /// immediately; a TLS connection delegates to the TLS stream's `poll_close`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            SmolStream::Plain(plain) => Poll::Ready(plain.get_ref().shutdown(Shutdown::Write)),
            SmolStream::Tls(tls) => Pin::new(tls).poll_close(cx),
        }
    }
}