use std::convert::TryFrom;
use std::net::SocketAddr;
use std::time::Duration;
use anyhow::Result;
use async_net::{TcpListener, TcpStream};
use async_io::Timer;
use bytes::BytesMut;
use byteorder::{BigEndian, ByteOrder};
use futures_lite::{AsyncReadExt, AsyncWriteExt, FutureExt};
use tracing::{self, trace, warn};
use crate::{lookup, runner};
/// Accepts TCP connections in an endless loop and forwards each raw stream to
/// `server_query_tx`, tagged with the peer address it came from.
///
/// The stream is queued unread; reading and answering the request is left to
/// whoever receives from the channel.
///
/// # Errors
///
/// Returns as soon as `tcp_listener.accept()` fails or sending on
/// `server_query_tx` fails (for example because the channel has been closed).
/// It never returns `Ok`.
pub async fn listen_tcp(
    tcp_listener: &mut TcpListener,
    server_query_tx: &mut async_channel::Sender<runner::RequestMsg>,
) -> Result<()> {
    loop {
        let (stream, peer) = tcp_listener.accept().await?;
        trace!("Queueing raw TCP stream from {:?}", peer);
        server_query_tx
            .send(runner::RequestMsg {
                src: peer,
                data: runner::RequestData::Tcp(stream),
            })
            .await?;
    }
}
pub async fn handle_tcp_request(
server: &mut lookup::Lookup,
timeout: &Duration,
request_source: SocketAddr,
mut tcp_stream: TcpStream,
tcp_buf: &mut BytesMut,
) {
// Read first two bytes to get expected request size
let mut message_size_bytes: [u8; 2] = [0, 0];
if let Err(ioerr) = tcp_stream
.read_exact(&mut message_size_bytes)
.or(async {
Timer::after(*timeout).await;
return Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"TCP header read timed out",
));
})
.await
{
warn!(
"Reading TCP header from client={} failed ({:?})",
request_source, ioerr
);
return;
};
let request_size = BigEndian::read_u16(&message_size_bytes);
tcp_buf.resize(
usize::try_from(request_size).expect("couldn't convert u16 to usize"),
0,
);
// Read the request itself into our correctly-sized buffer.
if let Err(ioerr) = tcp_stream
.read_exact(tcp_buf)
.or(async {
Timer::after(*timeout).await;
return Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"TCP header read timed out",
));
})
.await
{
warn!(
"Reading TCP request from client={} failed ({:?})",
request_source, ioerr
);
return;
}
if let Err(ioerr) = server.handle_query(tcp_buf).await {
warn!(
"Failed to handle TCP request from client={:?}: {:02X?} ({:?})",
request_source,
&tcp_buf[..],
ioerr
);
return;
}
// Send the response back to the client, prefaced by the u16 payload size
trace!("Raw response to {:?} ({}+2b): {:02X?}", request_source, tcp_buf.len(), &tcp_buf[..]);
BigEndian::write_u16(&mut message_size_bytes, tcp_buf.len() as u16);
if let Err(ioerr) = tcp_stream
.write_all(&message_size_bytes)
.or(async {
Timer::after(*timeout).await;
return Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"TCP header write timed out",
));
})
.await
{
warn!(
"Writing TCP header to client={} failed ({:?})",
request_source, ioerr
);
return;
}
if let Err(ioerr) = tcp_stream
.write_all(tcp_buf)
.or(async {
Timer::after(*timeout).await;
return Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"TCP response write timed out",
));
})
.await
{
warn!(
"Writing TCP response to client={} failed ({:?})",
request_source, ioerr
);
}
}