~vpzom/shoved

a9688acd7ed6abd5b2175c0dc60567fdc088a7e1 — Colin Reeder 1 year, 19 days ago 5a8aef9
Somethat that really works
2 files changed, 115 insertions(+), 72 deletions(-)

A examples/register.rs
M src/main.rs
A examples/register.rs => examples/register.rs +35 -0
@@ 0,0 1,35 @@
use std::io::Write;

fn main() {
    let system = false;

    let exec = std::env::args().skip(1).next().expect("Missing executable");

    let client_socket_path = if system {
        std::path::PathBuf::from("/var/run/shoved.sock")
    } else {
        let mut path = dirs::runtime_dir().unwrap_or_else(|| std::path::PathBuf::from("/tmp"));
        path.push("shoved.sock");

        path
    };

    let mut stream = std::os::unix::net::UnixStream::connect(client_socket_path).unwrap();

    stream
        .write_all(
            &serde_json::to_vec(&jsonrpc_lite::JsonRpc::request_with_params(
                0,
                "register",
                serde_json::json!({
                    "exec": exec,
                }),
            ))
            .unwrap(),
        )
        .unwrap();

    stream.shutdown(std::net::Shutdown::Write).unwrap();

    std::io::copy(&mut stream, &mut std::io::stdout()).unwrap();
}

M src/main.rs => src/main.rs +80 -72
@@ 323,88 323,96 @@ async fn main() {
        res.unwrap()
    };

    futures_util::try_join!(
        handle_connection(uaid.clone(), existing_channel_ids.clone(), in_tx, out_rx),
        {
        async {
            let request_destination_map = request_destination_map.clone();
            while let Some(msg) = in_rx.recv().await {
                match msg {
                    ConnectionMessageIn::RegisterResponse { ref request_id, result: _ } => {
                        if let Some(destination) = request_destination_map.lock().unwrap().remove(&request_id) {
                            let _ = destination.send(msg); // if it's gone we don't care
                        }
                    }
                    ConnectionMessageIn::Notification { channel_id, data } => {
                        if let Err(err) = async {
                            let (uid, exec, auth_secret, private_key, public_key): (u32, String, Vec<u8>, Vec<u8>, Vec<u8>) = db.query_row("SELECT uid, exec, auth_secret, private_key, public_key FROM registration WHERE channel_id=?", (&channel_id,), |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?, row.get(4)?)))?;
    let local_set = tokio::task::LocalSet::new();

    local_set.spawn_local(
        async move {
            futures_util::try_join!(
                handle_connection(uaid.clone(), existing_channel_ids.clone(), in_tx, out_rx),
                {
                async {
                    let request_destination_map = request_destination_map.clone();
                    while let Some(msg) = in_rx.recv().await {
                        match msg {
                            ConnectionMessageIn::RegisterResponse { ref request_id, result: _ } => {
                                if let Some(destination) = request_destination_map.lock().unwrap().remove(&request_id) {
                                    let _ = destination.send(msg); // if it's gone we don't care
                                }
                            }
                            ConnectionMessageIn::Notification { channel_id, data } => {
                                if let Err(err) = async {
                                    let (uid, exec, auth_secret, private_key, public_key): (u32, String, Vec<u8>, Vec<u8>, Vec<u8>) = db.query_row("SELECT uid, exec, auth_secret, private_key, public_key FROM registration WHERE channel_id=?", (&channel_id,), |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?, row.get(4)?)))?;

                            let keypair = ece::EcKeyComponents::new(private_key, public_key);
                                    let keypair = ece::EcKeyComponents::new(private_key, public_key);

                            let data = BASE64_ENGINE.decode(&data)?;
                            let data = ece::decrypt(&keypair, &auth_secret, &data)?;
                                    let data = BASE64_ENGINE.decode(&data)?;
                                    let data = ece::decrypt(&keypair, &auth_secret, &data)?;

                            let child = std::process::Command::new(exec).uid(uid).stdin(std::process::Stdio::piped()).spawn()?;
                            child.stdin.unwrap().write_all(&data)?;
                                    let child = std::process::Command::new(exec).uid(uid).stdin(std::process::Stdio::piped()).spawn()?;
                                    child.stdin.unwrap().write_all(&data)?;

                            Result::<_, anyhow::Error>::Ok(())
                        }
                        .await {
                            eprintln!("Failed to handle notification: {:?}", err);
                        }
                    }
                    ConnectionMessageIn::SetUAID(new_uaid) => {
                        match &uaid {
                            None => {
                                db.execute("INSERT INTO state (uaid) VALUES (?)", (&new_uaid,))?;
                            }
                            Some(old_uaid) => {
                                if old_uaid != &new_uaid {
                                    db.execute("UPDATE state SET uaid=?", (&new_uaid,))?;
                                    Result::<_, anyhow::Error>::Ok(())
                                }
                                .await {
                                    eprintln!("Failed to handle notification: {:?}", err);
                                }
                            }
                            ConnectionMessageIn::SetUAID(new_uaid) => {
                                match &uaid {
                                    None => {
                                        db.execute("INSERT INTO state (uaid) VALUES (?)", (&new_uaid,))?;
                                    }
                                    Some(old_uaid) => {
                                        if old_uaid != &new_uaid {
                                            db.execute("UPDATE state SET uaid=?", (&new_uaid,))?;
                                        }
                                    }
                                }
                                uaid = Some(new_uaid);
                            },
                        }
                        uaid = Some(new_uaid);
                    },
                }
            }

            Ok(())
        }
        },
        async {
            loop {
                let (stream, _) = client_listener.accept().await?;

                let (sub_out_tx, mut sub_out_rx) = tokio::sync::mpsc::unbounded_channel();
                let (sub_in_tx, sub_in_rx) = tokio::sync::mpsc::unbounded_channel();

                let db = db.clone();
                tokio::task::spawn_local(async move {
                    if let Err(err) = handle_client(stream, db, sub_out_tx, sub_in_rx).await {
                        eprintln!("Error while handling client request: {:?}", err);
                    }
                });

                let request_destination_map = request_destination_map.clone();
                let out_tx = out_tx.clone();

                tokio::spawn(async move {
                    while let Some(msg) = sub_out_rx.recv().await {
                        match msg {
                            ConnectionMessageOut::Register { request_id } => {
                                request_destination_map.lock().unwrap().insert(request_id, sub_in_tx.clone());
                                out_tx.send(msg).unwrap(); // will handle this better once
                                                           // reconnections are implemented
                    Ok(())
                }
                },
                async {
                    loop {
                        let (stream, _) = client_listener.accept().await?;

                        let (sub_out_tx, mut sub_out_rx) = tokio::sync::mpsc::unbounded_channel();
                        let (sub_in_tx, sub_in_rx) = tokio::sync::mpsc::unbounded_channel();

                        let db = db.clone();
                        tokio::task::spawn_local(async move {
                            if let Err(err) = handle_client(stream, db, sub_out_tx, sub_in_rx).await {
                                eprintln!("Error while handling client request: {:?}", err);
                            }
                        }
                        });

                        let request_destination_map = request_destination_map.clone();
                        let out_tx = out_tx.clone();

                        tokio::spawn(async move {
                            while let Some(msg) = sub_out_rx.recv().await {
                                match msg {
                                    ConnectionMessageOut::Register { request_id } => {
                                        request_destination_map.lock().unwrap().insert(request_id, sub_in_tx.clone());
                                        out_tx.send(msg).unwrap(); // will handle this better once
                                                                   // reconnections are implemented
                                    }
                                }
                            }
                        });
                    }
                });
            }

            // need this to infer types
            #[allow(unreachable_code)]
            Ok(())
        },
    ).unwrap();
                    // need this to infer types
                    #[allow(unreachable_code)]
                    Ok(())
                },
            )
        }
    );

    local_set.await
}