~hime/protochat

ref: 98d70f9aa232809379a9267ca9acf3a85b8748b0 protochat/linetest/src/shell/net.rs -rw-r--r-- 8.6 KiB
98d70f9a — drbawb (deps) bump crossterm to 0.17.7 2 years ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
use crate::shell::{self, Connection, ConnMsg};
use smolboi::protocol::packet::Packet;
use std::net::TcpStream;

use async_channel::Sender;
use async_dup::Arc;
use byteorder::{ReadBytesExt, WriteBytesExt, NetworkEndian};
use futures_util::{select, AsyncReadExt, AsyncWriteExt, StreamExt};
use smol::Async;
use thiserror::Error;

#[derive(Debug, Error)]
enum ClientError {
    #[error("an unspecified error has occurred.")]
    Unspecified,

    #[error("the packet was well-formed, but not expected at this time.")]
    InvalidState,
}

/// Lifecycle stage of a [`Client`], deciding which packets and shell
/// messages it will act on.
///
/// A fieldless enum, so it derives `Copy`/`Eq` for cheap comparisons and
/// `Debug` so it can appear in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ClientState {
    /// A `Register` request is outstanding; waiting for the server to accept
    /// or reject the requested name.
    Connecting,
    /// The server accepted registration; room messages, joins and parts are
    /// handled.
    Connected,
}

/// Per-connection protocol state machine.
///
/// Translates decoded server packets into messages for the shell, and shell
/// messages into packets written to the server socket.
struct Client {
    /// Connection details; `name` is replaced by the name the server accepts.
    conn: Connection,
    /// Channel back to the shell task (log lines, buffer lines, rejections).
    outbox: Sender<ConnMsg>,
    /// Shared handle to the server socket, used to write outgoing packets.
    sock: Arc<Async<TcpStream>>,
    /// Current lifecycle stage; decides which inputs are accepted.
    state: ClientState,
}

impl Client {
    pub fn with(conn: Connection, sock: &Arc<Async<TcpStream>>, outbox: Sender<ConnMsg>) -> Self {
        Self {
            conn: conn,
            sock: sock.clone(),
            outbox: outbox,
            state: ClientState::Connecting,
        }
    }

    /// Updates client state based on incoming packet.
    /// The state change will be relayed to the parent task which created this client
    /// by way of the `outbox` parameter supplied to the `with` constructor.
    pub async fn handle_net_pkt(&mut self, pkt: Packet) -> anyhow::Result<()> {
        match self.state {
            ClientState::Connecting => match pkt {
                Packet::RegistrationAccepted { name } => {
                    // !!! use the name the server accepted !!!
                    //
                    // this avoids race where user tries to change nick before
                    // the server finishes sending the accept/reject packet for
                    // the outstanding request.
                    self.conn.name = name;
                    self.state = ClientState::Connected;
                    
                    let log_line = ConnMsg::LogLine { msg: "registered ... ok.".to_string() };
                    self.outbox.send(log_line).await?;
                    Ok(())
                },

                Packet::RegistrationRejected => {
                    let msg = ConnMsg::RejectTemp { 
                        reason: "[server]: registration rejected. try another username.".to_string() 
                    };

                    Ok(self.outbox.send(msg).await?)
                },

                Packet::Notice { body } => {
                    let msg = ConnMsg::LogLine { msg: format!("wtf: {}", body) };
                    Ok(self.outbox.send(msg).await?)
                },

                _ => Err(ClientError::InvalidState.into()),
            },

            ClientState::Connected => match pkt {
                Packet::MessageRoom { room, from, body } => {
                    let msg = ConnMsg::BufLine {
                        id: room,
                        msg: format!("{}: {}", from, body),
                    };

                    Ok(self.outbox.send(msg).await?)
                },

                Packet::Notice { body } => {
                    let msg = ConnMsg::LogLine { msg: body };
                    Ok(self.outbox.send(msg).await?)
                },

                _ => Err(ClientError::InvalidState.into()),
            },
        }
    }

    pub async fn handle_shell_msg(&mut self, msg: ConnMsg) -> anyhow::Result<()> {
        match self.state {
            ClientState::Connecting => match msg {
                ConnMsg::ChangeNick { name } => {
                    Ok(write_packet(&mut self.sock, Packet::Register { name: name }).await?)
                },

                _ => Err(ClientError::Unspecified.into()),
            },

            ClientState::Connected => match msg {
                ConnMsg::BufLine { id, msg } => {
                    let packet = Packet::MessageRoom { 
                        room: id,
                        from: self.conn.name.clone(),
                        body: msg,
                    };

                    Ok(write_packet(&mut self.sock, packet).await?)
                },

                ConnMsg::JoinRoom { room } => {
                    Ok(write_packet(&mut self.sock, Packet::Join {  room: room }).await?)
                },

                ConnMsg::PartRoom { room } => {
                    Ok(write_packet(&mut self.sock, Packet::Part { room: room }).await?)
                },

                _ => Err(ClientError::Unspecified.into()),
            },
        }
    }
}

/// Writes `packet` to `wtr` as one length-prefixed frame.
///
/// The frame is a 2-byte big-endian (network order) body length followed by
/// the encoded packet body.
///
/// # Errors
/// Returns an error if the packet fails to encode, if the encoded body is
/// longer than `u16::MAX` bytes (the length prefix cannot represent it), or if
/// writing to `wtr` fails.
async fn write_packet<W>(wtr: &mut W, packet: Packet) -> anyhow::Result<()>
where
    W: AsyncWriteExt + Unpin,
{
    let packet_buf = packet.encode()?;
    let packet_len = packet_buf.len();

    // The 16-bit length prefix caps a body at 64 KiB - 1. Oversized input
    // (e.g. a very long chat line) is reported as an error instead of
    // panicking the connection task.
    anyhow::ensure!(
        packet_len <= u16::MAX as usize,
        "packet body is {} bytes, exceeding the maximum frame size of {} bytes",
        packet_len,
        u16::MAX
    );

    let mut packet_sz = [0u8; 2];
    (&mut packet_sz[..]).write_u16::<NetworkEndian>(packet_len as u16)?;

    // `write` may accept only part of a buffer, which would corrupt the
    // framing; `write_all` keeps going until every byte has been written.
    wtr.write_all(&packet_sz).await?;
    wtr.write_all(&packet_buf).await?;
    Ok(())
}

/// Reads one length-prefixed frame from `stream` and decodes it as a [`Packet`].
///
/// The frame layout matches `write_packet`: a 2-byte big-endian (network
/// order) body length followed by exactly that many body bytes.
///
/// # Errors
/// Returns an error if reading from `stream` fails or the stream ends before a
/// full frame has been read (these surface as `std::io::Error`), or if the
/// body does not decode as a `Packet`.
async fn decode_packet<R>(stream: &mut R) -> anyhow::Result<Packet>
where
    R: AsyncReadExt + Unpin,
{
    let mut packet_header_buf = [0u8; 2];
    stream.read_exact(&mut packet_header_buf).await?;

    // A u16 length bounds the body at 64 KiB - 1, so this allocation is
    // capped no matter what the peer sends.
    let packet_sz = (&packet_header_buf[..]).read_u16::<NetworkEndian>()?;
    let mut packet_body_buf = vec![0u8; usize::from(packet_sz)];
    stream.read_exact(&mut packet_body_buf).await?;

    Ok(Packet::decode(&packet_body_buf)?)
}

/// Opens a TCP connection to `conn.addr` and runs it on a background task.
///
/// Only the connect itself is awaited here. Once the worker has been spawned
/// onto the executor, control returns to the caller, which can then begin
/// polling its end of the connection mailbox. If the worker later exits with
/// an error, a `BUG:` log line describing it is sent to the shell.
///
/// # Errors
/// Returns an error if the TCP socket could not be opened.
pub async fn start_task(
    conn: Connection,
    shell_mx: shell::ConnectionMailbox,
) -> anyhow::Result<()> {
    let stream = Async::<TcpStream>::connect(conn.addr).await?;
    let sock = Arc::new(stream);

    let worker = async move {
        // Keep a sender of our own: `shell_mx` is moved into the worker below.
        let error_outbox = shell_mx.outbox.clone();
        let outcome = net_worker_task(conn, sock, shell_mx).await;
        if let Err(err) = outcome {
            let error_log_ln = format!("BUG: net worker exited unexpectedly {:?}", err);
            // Best effort: the shell may already have gone away.
            let _ = error_outbox.send(ConnMsg::LogLine { msg: error_log_ln }).await;
        }
    };
    smol::Task::spawn(worker).detach();

    Ok(())
}

async fn net_worker_task(
    conn: Connection,
    mut sock: Arc<Async<TcpStream>>,
    shell_mx: shell::ConnectionMailbox
) -> anyhow::Result<()> {
    // start packet decoding thread
    let (packets_tx, packets_rx) = async_channel::unbounded();
        
    let mut packet_decoding_sock = sock.clone();
    smol::Task::spawn(async move {
        loop {
            match decode_packet(&mut packet_decoding_sock).await {
                Ok(packet) => { 
                    if let Err(_) = packets_tx.send(packet).await {
                        panic!("BUG: sent packet but connection event loop is gone?");
                    }
                },

                Err(msg) => {
                    //warn!("malformed packet from client: {:?}", msg);
                    continue; // TODO: kill the client?
                },   
            }
        }
    }).detach();

    //send registration packet
    let packet = Packet::Register {
        name: conn.name.to_string(),
    };

    let packet_buf = packet.encode()?;
    if packet_buf.len() > u16::MAX as usize {
        panic!("BUG: user crafted >64KiB packet, this is not legal.");
    }

    let mut packet_sz = [0u8; 2];
    (&mut packet_sz[..]).write_u16::<NetworkEndian>(packet_buf.len() as u16)?;

    sock.write(&packet_sz).await?;
    sock.write(&packet_buf).await?;


    let mut packets_rx = packets_rx.fuse();
    let mut task_rx = shell_mx.inbox.fuse();

    let mut client = Client::with(conn, &sock, shell_mx.outbox.clone());

    loop {
        select! {
            packet = packets_rx.next() => match packet {
                Some(packet) => {
                    if let Err(msg) = client.handle_net_pkt(packet).await {
                        let error_log_ln = format!("client error: {:?}", msg);
                        let _ = shell_mx.outbox.send(ConnMsg::LogLine { msg: error_log_ln }).await;
                    }
                },

                None => { break },
            },

            task_msg = task_rx.next() => match task_msg {
                Some(task_msg) => {
                    if let Err(msg) = client.handle_shell_msg(task_msg).await {
                        let error_log_ln = format!("client error: {:?}", msg);
                        let _ = shell_mx.outbox.send(ConnMsg::LogLine { msg: error_log_ln }).await;
                    }
                },

                None => { break },
            },
        }
    }

    let log_line = ConnMsg::LogLine { msg: "warning! net worker hung up!".to_string() };
    Ok(shell_mx.outbox.send(log_line).await?)
}