~nickbp/kapiti

ref: 9d460d0b7ab39ac1e39576b34fa6217049c65edb kapiti/src/listen_tcp.rs -rw-r--r-- 3.7 KiB
9d460d0b — Nick Parker: Update dependencies to latest along with corresponding code changes (6 months ago)
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
use std::convert::TryFrom;
use std::net::SocketAddr;
use std::time::Duration;

use anyhow::Result;
use async_net::{TcpListener, TcpStream};
use async_io::Timer;
use bytes::BytesMut;
use byteorder::{BigEndian, ByteOrder};
use futures_lite::{AsyncReadExt, AsyncWriteExt, FutureExt};
use tracing::{self, trace, warn};

use crate::{lookup, runner};

/// Accepts TCP connections in an endless loop, queueing each raw stream together with its peer
/// address on `server_query_tx` for a request handler to process.
///
/// # Errors
///
/// Returns the first error from either `accept` on the listener or `send` on the channel
/// (e.g. because the channel has been closed). It never returns `Ok`.
pub async fn listen_tcp(
    tcp_listener: &mut TcpListener,
    server_query_tx: &mut async_channel::Sender<runner::RequestMsg>,
) -> Result<()> {
    loop {
        let (stream, peer) = tcp_listener.accept().await?;
        trace!("Queueing raw TCP stream from {:?}", peer);
        server_query_tx
            .send(runner::RequestMsg {
                src: peer,
                data: runner::RequestData::Tcp(stream),
            })
            .await?;
    }
}

pub async fn handle_tcp_request(
    server: &mut lookup::Lookup,
    timeout: &Duration,
    request_source: SocketAddr,
    mut tcp_stream: TcpStream,
    tcp_buf: &mut BytesMut,
) {
    // Read first two bytes to get expected request size
    let mut message_size_bytes: [u8; 2] = [0, 0];
    if let Err(ioerr) = tcp_stream
        .read_exact(&mut message_size_bytes)
        .or(async {
            Timer::after(*timeout).await;
            return Err(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "TCP header read timed out",
            ));
        })
        .await
    {
        warn!(
            "Reading TCP header from client={} failed ({:?})",
            request_source, ioerr
        );
        return;
    };
    let request_size = BigEndian::read_u16(&message_size_bytes);
    tcp_buf.resize(
        usize::try_from(request_size).expect("couldn't convert u16 to usize"),
        0,
    );

    // Read the request itself into our correctly-sized buffer.
    if let Err(ioerr) = tcp_stream
        .read_exact(tcp_buf)
        .or(async {
            Timer::after(*timeout).await;
            return Err(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "TCP header read timed out",
            ));
        })
        .await
    {
        warn!(
            "Reading TCP request from client={} failed ({:?})",
            request_source, ioerr
        );
        return;
    }

    if let Err(ioerr) = server.handle_query(tcp_buf).await {
        warn!(
            "Failed to handle TCP request from client={:?}: {:02X?} ({:?})",
            request_source,
            &tcp_buf[..],
            ioerr
        );
        return;
    }

    // Send the response back to the client, prefaced by the u16 payload size
    trace!("Raw response to {:?} ({}+2b): {:02X?}", request_source, tcp_buf.len(), &tcp_buf[..]);
    BigEndian::write_u16(&mut message_size_bytes, tcp_buf.len() as u16);
    if let Err(ioerr) = tcp_stream
        .write_all(&message_size_bytes)
        .or(async {
            Timer::after(*timeout).await;
            return Err(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "TCP header write timed out",
            ));
        })
        .await
    {
        warn!(
            "Writing TCP header to client={} failed ({:?})",
            request_source, ioerr
        );
        return;
    }
    if let Err(ioerr) = tcp_stream
        .write_all(tcp_buf)
        .or(async {
            Timer::after(*timeout).await;
            return Err(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                "TCP response write timed out",
            ));
        })
        .await
    {
        warn!(
            "Writing TCP response to client={} failed ({:?})",
            request_source, ioerr
        );
    }
}