use crate::shell::{self, Connection, ConnMsg};
use smolboi::protocol::packet::Packet;
use std::net::TcpStream;
use async_channel::Sender;
use async_dup::Arc;
use byteorder::{ReadBytesExt, WriteBytesExt, NetworkEndian};
use futures_util::{select, AsyncReadExt, AsyncWriteExt, StreamExt};
use smol::Async;
use thiserror::Error;
/// Errors produced by [`Client`] when it is handed input it cannot act on.
#[derive(Debug, Error)]
enum ClientError {
    /// Catch-all error; returned by `Client::handle_shell_msg` for shell messages
    /// that are not handled in the client's current state.
    #[error("an unspecified error has occurred.")]
    Unspecified,
    /// A server packet arrived that the current [`ClientState`] does not accept
    /// (for example a room message before registration has completed).
    #[error("the packet was well-formed, but not expected at this time.")]
    InvalidState,
}
/// Registration lifecycle of a [`Client`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ClientState {
    /// A `Register` request is outstanding; the server has not yet accepted a name.
    Connecting,
    /// The server accepted the registration; room traffic (messages, joins, parts)
    /// is allowed.
    Connected,
}
/// Per-connection protocol state machine.
///
/// Turns packets received from the server into [`ConnMsg`]s for the parent task,
/// and turns shell [`ConnMsg`]s into packets written to the server socket.
struct Client {
    /// Connection details; `name` is overwritten with the name the server accepted.
    conn: Connection,
    /// Channel used to relay log lines and state changes to the parent task.
    outbox: Sender<ConnMsg>,
    /// Handle to the server socket that outgoing packets are written to.
    /// NOTE(review): `net_worker_task` reads from a clone of this same socket.
    sock: Arc<Async<TcpStream>>,
    /// Current registration state; `Client::with` starts it at `Connecting`.
    state: ClientState,
}
impl Client {
    /// Creates a client in the `Connecting` state.
    ///
    /// `sock` is cloned so the caller keeps its own handle to the socket. Log lines
    /// and state changes are reported to the parent task through `outbox`.
    pub fn with(conn: Connection, sock: &Arc<Async<TcpStream>>, outbox: Sender<ConnMsg>) -> Self {
        Self {
            conn,
            sock: sock.clone(),
            outbox,
            state: ClientState::Connecting,
        }
    }

    /// Updates client state based on an incoming packet from the server.
    ///
    /// The resulting state change or log line is relayed to the parent task which
    /// created this client by way of the `outbox` supplied to [`Client::with`].
    ///
    /// # Errors
    /// Returns [`ClientError::InvalidState`] if `pkt` is not accepted in the current
    /// state, or an error if the message cannot be sent on `outbox` (receiver gone).
    pub async fn handle_net_pkt(&mut self, pkt: Packet) -> anyhow::Result<()> {
        let msg = match self.state {
            ClientState::Connecting => match pkt {
                Packet::RegistrationAccepted { name } => {
                    // !!! use the name the server accepted !!!
                    //
                    // this avoids a race where the user tries to change nick before
                    // the server finishes sending the accept/reject packet for
                    // the outstanding request.
                    self.conn.name = name;
                    self.state = ClientState::Connected;
                    ConnMsg::LogLine { msg: "registered ... ok.".to_string() }
                }
                Packet::RegistrationRejected => ConnMsg::RejectTemp {
                    reason: "[server]: registration rejected. try another username.".to_string(),
                },
                Packet::Notice { body } => ConnMsg::LogLine { msg: format!("wtf: {}", body) },
                _ => return Err(ClientError::InvalidState.into()),
            },
            ClientState::Connected => match pkt {
                Packet::MessageRoom { room, from, body } => ConnMsg::BufLine {
                    id: room,
                    msg: format!("{}: {}", from, body),
                },
                Packet::Notice { body } => ConnMsg::LogLine { msg: body },
                _ => return Err(ClientError::InvalidState.into()),
            },
        };
        self.outbox.send(msg).await?;
        Ok(())
    }

    /// Translates a message from the shell into a packet and writes it to the server.
    ///
    /// While `Connecting`, only `ChangeNick` is accepted; it sends a fresh `Register`
    /// request. Once `Connected`, buffer lines, joins and parts are forwarded.
    ///
    /// # Errors
    /// Returns [`ClientError::Unspecified`] for messages not handled in the current
    /// state, or any error from encoding/writing the packet.
    pub async fn handle_shell_msg(&mut self, msg: ConnMsg) -> anyhow::Result<()> {
        let packet = match self.state {
            ClientState::Connecting => match msg {
                ConnMsg::ChangeNick { name } => Packet::Register { name },
                _ => return Err(ClientError::Unspecified.into()),
            },
            ClientState::Connected => match msg {
                ConnMsg::BufLine { id, msg } => Packet::MessageRoom {
                    room: id,
                    from: self.conn.name.clone(),
                    body: msg,
                },
                ConnMsg::JoinRoom { room } => Packet::Join { room },
                ConnMsg::PartRoom { room } => Packet::Part { room },
                _ => return Err(ClientError::Unspecified.into()),
            },
        };
        write_packet(&mut self.sock, packet).await
    }
}
/// Writes `packet` to `wtr` as one length-prefixed frame.
///
/// Frame layout: a `u16` byte count in network (big-endian) order, followed by
/// the encoded packet. The frame is assembled in memory and sent with
/// `write_all`, so a short write on the socket is retried instead of silently
/// truncating the frame.
///
/// # Errors
/// Fails if the packet cannot be encoded, if the encoded packet is longer than
/// `u16::MAX` bytes (the length prefix cannot represent it), or if writing fails.
async fn write_packet<W>(wtr: &mut W, packet: Packet) -> anyhow::Result<()>
where
    W: AsyncWriteExt + Unpin,
{
    let body = packet.encode()?;
    let body_len = body.len();
    if body_len > u16::MAX as usize {
        anyhow::bail!(
            "encoded packet is {} bytes; frames are limited to {} bytes",
            body_len,
            u16::MAX
        );
    }
    let mut frame = Vec::with_capacity(2 + body_len);
    // Cannot truncate: bounds-checked above.
    frame.write_u16::<NetworkEndian>(body_len as u16)?;
    frame.extend_from_slice(&body);
    wtr.write_all(&frame).await?;
    Ok(())
}
/// Reads one length-prefixed frame from `stream` and decodes it into a [`Packet`].
///
/// The frame layout matches `write_packet`: a `u16` byte count in network
/// (big-endian) order, followed by that many bytes of encoded packet.
///
/// # Errors
/// Fails on any read error (including EOF before a complete frame arrives) or
/// if the body does not decode as a `Packet`.
async fn decode_packet<R>(stream: &mut R) -> anyhow::Result<Packet>
where
    R: AsyncReadExt + Unpin,
{
    let mut header = [0u8; 2];
    stream.read_exact(&mut header).await?;
    // The length is a u16, so the body buffer is bounded by construction
    // (at most u16::MAX bytes); no further size check is needed.
    let body_len = usize::from((&header[..]).read_u16::<NetworkEndian>()?);
    let mut body = vec![0u8; body_len];
    stream.read_exact(&mut body).await?;
    Ok(Packet::decode(&body)?)
}
/// Connects to the server at `conn.addr` and spawns the connection's network
/// worker in the background.
///
/// Control returns to the caller once the TCP connection is open and the worker
/// task has been handed to the executor, at which point the caller can begin
/// polling the associated task mailbox. If the worker later exits with an
/// error, that error is reported as a log line on `shell_mx.outbox`.
///
/// # Errors
/// Returns an error if the TCP socket could not be opened.
pub async fn start_task(
    conn: Connection,
    shell_mx: shell::ConnectionMailbox,
) -> anyhow::Result<()> {
    let sock = Arc::new(Async::<TcpStream>::connect(conn.addr).await?);
    // Keep our own outbox handle: `shell_mx` is moved into the worker, but we
    // still need to report its failure after it returns.
    let task_outbox = shell_mx.outbox.clone();
    smol::Task::spawn(async move {
        if let Err(err) = net_worker_task(conn, sock, shell_mx).await {
            let error_log_ln = format!("BUG: net worker exited unexpectedly {:?}", err);
            let _ = task_outbox.send(ConnMsg::LogLine { msg: error_log_ln }).await;
        }
    })
    .detach();
    Ok(())
}
async fn net_worker_task(
conn: Connection,
mut sock: Arc<Async<TcpStream>>,
shell_mx: shell::ConnectionMailbox
) -> anyhow::Result<()> {
// start packet decoding thread
let (packets_tx, packets_rx) = async_channel::unbounded();
let (quit_tx, quit_rx) = async_channel::unbounded::<void::Void>();
// copy these fields so they don't move into task
let mut packet_decoding_sock = sock.clone();
let packet_shell_outbox = shell_mx.outbox.clone();
let packet_conn_id = conn.id;
smol::Task::spawn(async move {
loop {
match decode_packet(&mut packet_decoding_sock).await {
Ok(packet) => {
if let Err(_) = packets_tx.send(packet).await {
panic!("BUG: sent packet but connection event loop is gone?");
}
},
Err(msg) => {
let log_line = ConnMsg::LogLine {
msg: format!("FATAL: conn({}) received malformed packet: {:?}.", packet_conn_id, msg),
};
let _ = packet_shell_outbox.send(log_line).await;
break;
},
}
}
drop(quit_tx);
}).detach();
// send out registration packet
let packet = Packet::Register {
name: conn.name.to_string(),
};
let packet_buf = packet.encode()?;
if packet_buf.len() > u16::MAX as usize {
panic!("BUG: user crafted >64KiB packet, this is not legal.");
}
let mut packet_sz = [0u8; 2];
(&mut packet_sz[..]).write_u16::<NetworkEndian>(packet_buf.len() as u16)?;
sock.write(&packet_sz).await?;
sock.write(&packet_buf).await?;
let mut packets_rx = packets_rx.fuse();
let mut task_rx = shell_mx.inbox.fuse();
let mut quit_rx = quit_rx.fuse();
let mut client = Client::with(conn, &sock, shell_mx.outbox.clone());
loop {
select! {
packet = packets_rx.next() => match packet {
Some(packet) => {
if let Err(msg) = client.handle_net_pkt(packet).await {
let error_log_ln = format!("client error: {:?}", msg);
let _ = shell_mx.outbox.send(ConnMsg::LogLine { msg: error_log_ln }).await;
}
},
None => { break },
},
task_msg = task_rx.next() => match task_msg {
Some(task_msg) => {
if let Err(msg) = client.handle_shell_msg(task_msg).await {
let error_log_ln = format!("client error: {:?}", msg);
let _ = shell_mx.outbox.send(ConnMsg::LogLine { msg: error_log_ln }).await;
}
},
None => { break },
},
quit = quit_rx.next() => if quit.is_none() { break },
}
}
let log_line = ConnMsg::LogLine { msg: "warning! net worker hung up!".to_string() };
Ok(shell_mx.outbox.send(log_line).await?)
}