~vpzom/shoved

5a8aef9688117e92e5d12e9840f9844f788a639d — Colin Reeder 1 year, 16 days ago a0bb2cb
Something that might work
4 files changed, 464 insertions(+), 37 deletions(-)

M Cargo.lock
M Cargo.toml
M src/main.rs
M src/types.rs
M Cargo.lock => Cargo.lock +143 -5
@@ 3,6 3,17 @@
version = 3

[[package]]
name = "ahash"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
dependencies = [
 "cfg-if",
 "once_cell",
 "version_check",
]

[[package]]
name = "anyhow"
version = "1.0.71"
source = "registry+https://github.com/rust-lang/crates.io-index"


@@ 33,6 44,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"

[[package]]
name = "bitflags"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6776fc96284a0bb647b615056fc496d1fe1644a7ab01829818a6d91cae888b84"

[[package]]
name = "block-buffer"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"


@@ 118,6 135,27 @@ dependencies = [
]

[[package]]
name = "dirs"
version = "5.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225"
dependencies = [
 "dirs-sys",
]

[[package]]
name = "dirs-sys"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c"
dependencies = [
 "libc",
 "option-ext",
 "redox_users",
 "windows-sys 0.48.0",
]

[[package]]
name = "ece"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"


@@ 157,6 195,18 @@ dependencies = [
]

[[package]]
name = "fallible-iterator"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"

[[package]]
name = "fallible-streaming-iterator"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"

[[package]]
name = "fastrand"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"


@@ 261,6 311,24 @@ dependencies = [
]

[[package]]
name = "hashbrown"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e"
dependencies = [
 "ahash",
]

[[package]]
name = "hashlink"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0761a1b9491c4f2e3d66aa0f62d0fba0af9a0e2852e4d48ea506632a4b56e6aa"
dependencies = [
 "hashbrown",
]

[[package]]
name = "hermit-abi"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"


@@ 344,6 412,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"

[[package]]
name = "jsonrpc-lite"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb4128aba82294c14af2998831c4df3c843940e92b5cfc41bac1229d1e63b88c"
dependencies = [
 "serde",
 "serde_derive",
 "serde_json",
]

[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"


@@ 356,6 435,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b00cc1c228a6782d0f076e7b232802e0c5689d41bb5df366f2a6b6621cfdfe1"

[[package]]
name = "libsqlite3-sys"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afc22eff61b133b115c6e8c74e818c628d6d5e7a502afea6f64dee076dd94326"
dependencies = [
 "pkg-config",
 "vcpkg",
]

[[package]]
name = "linux-raw-sys"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"


@@ 408,7 497,7 @@ version = "0.10.52"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01b8574602df80f7b85fdfc5392fa884a4e3b3f4f35402c070ab34c3d3f78d56"
dependencies = [
 "bitflags",
 "bitflags 1.3.2",
 "cfg-if",
 "foreign-types",
 "libc",


@@ 447,6 536,12 @@ dependencies = [
]

[[package]]
name = "option-ext"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"

[[package]]
name = "percent-encoding"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"


@@ 526,11 621,45 @@ dependencies = [

[[package]]
name = "redox_syscall"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a"
dependencies = [
 "bitflags 1.3.2",
]

[[package]]
name = "redox_syscall"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29"
dependencies = [
 "bitflags",
 "bitflags 1.3.2",
]

[[package]]
name = "redox_users"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b"
dependencies = [
 "getrandom",
 "redox_syscall 0.2.16",
 "thiserror",
]

[[package]]
name = "rusqlite"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2"
dependencies = [
 "bitflags 2.3.1",
 "fallible-iterator",
 "fallible-streaming-iterator",
 "hashlink",
 "libsqlite3-sys",
 "smallvec",
]

[[package]]


@@ 539,7 668,7 @@ version = "0.37.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acf8729d8542766f1b2cf77eb034d52f40d375bb8b615d0b147089946e16613d"
dependencies = [
 "bitflags",
 "bitflags 1.3.2",
 "errno",
 "io-lifetimes",
 "libc",


@@ 568,7 697,7 @@ version = "2.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8"
dependencies = [
 "bitflags",
 "bitflags 1.3.2",
 "core-foundation",
 "core-foundation-sys",
 "libc",


@@ 644,8 773,11 @@ version = "0.1.0"
dependencies = [
 "anyhow",
 "base64 0.21.2",
 "dirs",
 "ece",
 "futures-util",
 "jsonrpc-lite",
 "rusqlite",
 "serde",
 "serde_json",
 "tokio",


@@ 663,6 795,12 @@ dependencies = [
]

[[package]]
name = "smallvec"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"

[[package]]
name = "socket2"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"


@@ 697,7 835,7 @@ checksum = "b9fbec84f381d5795b08656e4912bec604d162bff9291d6189a78f4c8ab87998"
dependencies = [
 "cfg-if",
 "fastrand",
 "redox_syscall",
 "redox_syscall 0.3.5",
 "rustix",
 "windows-sys 0.45.0",
]

M Cargo.toml => Cargo.toml +3 -0
@@ 8,8 8,11 @@ edition = "2021"
[dependencies]
anyhow = "1.0.71"
base64 = "0.21.2"
dirs = "5.0.1"
ece = "2.2.0"
futures-util = { version = "0.3.28", features = ["sink"] }
jsonrpc-lite = "0.6.0"
rusqlite = "0.29.0"
serde = { version = "1.0.163", features = ["derive"] }
serde_json = "1.0.96"
tokio = { version = "1.28.2", features = ["macros", "rt", "sync"] }

M src/main.rs => src/main.rs +311 -30
@@ 1,5 1,11 @@
use base64::Engine;
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use rusqlite::OptionalExtension;
use std::collections::HashMap;
use std::io::Write;
use std::os::unix::process::CommandExt;
use std::sync::{Arc, Mutex};
use tokio::io::{AsyncReadExt, AsyncWriteExt};

mod types;



@@ 12,10 18,37 @@ enum ConnectionMessageOut {

#[derive(Debug)]
enum ConnectionMessageIn {
    Notification { channel_id: String, data: String },
    SetUAID(String),
    RegisterResponse {
        request_id: u64,
        result: Result<RegisterResponse, anyhow::Error>,
    },
    Notification {
        channel_id: String,
        data: String,
    },
}

#[derive(Debug)]
struct RegisterResponse {
    endpoint: String,
    channel_id: String,
}

const REQUEST_ID_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

fn next_request_id() -> u64 {
    REQUEST_ID_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}

const BASE64_ENGINE: base64::engine::GeneralPurpose = base64::engine::GeneralPurpose::new(
    &base64::alphabet::URL_SAFE,
    base64::engine::general_purpose::NO_PAD,
);

async fn handle_connection(
    uaid: Option<String>,
    existing_channel_ids: Vec<String>,
    channel_tx: tokio::sync::mpsc::UnboundedSender<ConnectionMessageIn>,
    mut channel_rx: tokio::sync::mpsc::UnboundedReceiver<ConnectionMessageOut>,
) -> Result<(), anyhow::Error> {


@@ 23,7 56,7 @@ async fn handle_connection(

    println!("connected");

    let (sink, mut stream) = conn.split();
    let (sink, stream) = conn.split();

    let mut sink = sink.with(|msg: types::MozillaPushMessageA2S| {
        futures_util::future::ready(


@@ 33,7 66,9 @@ async fn handle_connection(
        )
    });

    let mut stream = stream
    let (sink_tx, mut sink_rx) = tokio::sync::mpsc::unbounded_channel();

    let stream = stream
        .map_err(anyhow::Error::from)
        .try_filter_map(|msg| async {
            match msg {


@@ 46,28 81,74 @@ async fn handle_connection(
        });

    sink.send(types::MozillaPushMessageA2S::Hello {
        uaid: None,
        channel_ids: vec![],
        uaid,
        channel_ids: existing_channel_ids,
        use_webpush: Some(true),
    })
    .await?;

    let register_map = Mutex::new(HashMap::<String, _>::new());

    futures_util::try_join!(
        async {
            while let Some(msg) = sink_rx.recv().await {
                sink.send(msg).await?;
            }

            Ok(())
        },
        stream.try_for_each(|msg| {
            let channel_tx = channel_tx.clone();
            let register_map = &register_map;
            let sink_tx = sink_tx.clone();
            async move {
                println!("{:?}", msg);

                match msg {
                    types::MozillaPushMessageS2A::Hello { uaid, status, .. } => {
                        if status != 200 {
                            anyhow::bail!("Unexpected status in Hello: {}", status);
                        }

                        channel_tx.send(ConnectionMessageIn::SetUAID(uaid))?;
                    }
                    types::MozillaPushMessageS2A::Notification {
                        channel_id,
                        version,
                        data: Some(data),
                        ..
                    } => {
                        channel_tx.send(ConnectionMessageIn::Notification {
                            channel_id,
                            channel_id: channel_id.clone(),
                            data: data.data,
                        })?;

                        sink_tx.send(types::MozillaPushMessageA2S::Ack {
                            updates: vec![types::AckUpdate {
                                channel_id,
                                version,
                            }],
                        })?;
                    }
                    types::MozillaPushMessageS2A::Register {
                        channel_id,
                        status,
                        push_endpoint,
                    } => {
                        if let Some(request_id) = register_map.lock().unwrap().remove(&channel_id) {
                            channel_tx.send(ConnectionMessageIn::RegisterResponse {
                                request_id,
                                result: match status {
                                    200 => Ok(RegisterResponse {
                                        endpoint: push_endpoint,
                                        channel_id,
                                    }),
                                    status => Err(anyhow::anyhow!(
                                        "Failed to register. Got status {}",
                                        status
                                    )),
                                },
                            })?;
                        }
                    }
                    _ => {}
                }


@@ 79,10 160,14 @@ async fn handle_connection(
            while let Some(msg) = channel_rx.recv().await {
                match msg {
                    ConnectionMessageOut::Register { request_id } => {
                        sink.send(types::MozillaPushMessageA2S::Register {
                            channel_id: uuid::Uuid::new_v4().to_string(),
                        })
                        .await?;
                        let channel_id = uuid::Uuid::new_v4().to_string();

                        register_map
                            .lock()
                            .unwrap()
                            .insert(channel_id.clone(), request_id);

                        sink_tx.send(types::MozillaPushMessageA2S::Register { channel_id })?;
                    }
                }
            }


@@ 94,36 179,232 @@ async fn handle_connection(
    Ok(())
}

async fn handle_client(
    mut stream: tokio::net::UnixStream,
    db: Arc<rusqlite::Connection>,
    out_tx: tokio::sync::mpsc::UnboundedSender<ConnectionMessageOut>,
    mut in_rx: tokio::sync::mpsc::UnboundedReceiver<ConnectionMessageIn>,
) -> Result<(), anyhow::Error> {
    // TODO actually return errors to client

    let uid = stream.peer_cred()?.uid();

    let mut body = Vec::new();
    stream.read_to_end(&mut body).await?;

    let msg: jsonrpc_lite::JsonRpc = serde_json::from_slice(&body)?;
    match msg {
        jsonrpc_lite::JsonRpc::Request(_) => {
            // guaranteed to be Some, I don't really like this API
            let id = msg.get_id().unwrap();
            let method = msg.get_method().unwrap();

            match method {
                "register" => {
                    let params = match msg.get_params() {
                        None => anyhow::bail!("Missing params"),
                        Some(params) => match params {
                            jsonrpc_lite::Params::Map(params) => serde_json::Value::Object(params),
                            _ => anyhow::bail!("Invalid params type"),
                        },
                    };
                    let params: types::ClientRequestRegisterParams =
                        serde_json::from_value(params)?;

                    let (keypair, auth_secret) = ece::generate_keypair_and_auth_secret().unwrap();

                    out_tx.send(ConnectionMessageOut::Register {
                        request_id: next_request_id(),
                    })?;

                    let msg = in_rx.recv().await;

                    let registration = if let Some(ConnectionMessageIn::RegisterResponse {
                        request_id: _,
                        result,
                    }) = msg
                    {
                        result?
                    } else {
                        anyhow::bail!("Unexpected or missing message");
                    };
                    let out_params = serde_json::json!({
                        "endpoint": registration.endpoint,
                        "keys": {
                            "auth": BASE64_ENGINE.encode(auth_secret),
                            "p256dh": BASE64_ENGINE.encode(keypair.pub_as_raw().unwrap()),
                        },
                    });

                    let keypair_components = keypair.raw_components()?;

                    db.execute(
                        "INSERT INTO registration (channel_id, uid, exec, auth_secret, private_key, public_key) VALUES (?, ?, ?, ?, ?, ?)",
                        (&registration.channel_id, uid, params.exec, auth_secret, keypair_components.private_key(), keypair_components.public_key()),
                    )?;

                    let out = jsonrpc_lite::JsonRpc::success(id, &out_params);
                    let bytes = serde_json::to_vec(&out)?;

                    stream.write_all(&bytes).await?;
                }
                _ => anyhow::bail!("Unknown method"),
            }
        }
        _ => {
            // probably shouldn't happen?
        }
    }

    Ok(())
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let system = false;

    let db_path = {
        let db_dir_path = if system {
            std::path::PathBuf::from("/var/db")
        } else {
            dirs::data_local_dir().unwrap()
        };

        std::fs::create_dir_all(&db_dir_path).unwrap();

        db_dir_path.join("shoved.db")
    };

    let db = {
        // I know this isn't ideal, but I don't care
        let db_exists = db_path.exists();

        let db = rusqlite::Connection::open(db_path).unwrap();

        if !db_exists {
            db.execute("CREATE TABLE state (uaid TEXT NOT NULL)", ())
                .unwrap();
            db.execute("CREATE TABLE registration (channel_id TEXT PRIMARY KEY, uid INTEGER NOT NULL, exec TEXT NOT NULL, auth_secret BLOB NOT NULL, private_key BLOB NOT NULL, public_key BLOB NOT NULL)", ()).unwrap();
        }

        Arc::new(db)
    };

    let client_socket_path = if system {
        std::path::PathBuf::from("/var/run/shoved.sock")
    } else {
        let mut path = dirs::runtime_dir().unwrap_or_else(|| std::path::PathBuf::from("/tmp"));
        path.push("shoved.sock");

        path
    };

    std::fs::remove_file(&client_socket_path).unwrap();

    let client_listener = tokio::net::UnixListener::bind(client_socket_path).unwrap();

    // TODO abstract these to handle reconnections?
    let (in_tx, mut in_rx) = tokio::sync::mpsc::unbounded_channel();
    let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel();
    let handle = tokio::spawn(handle_connection(in_tx, out_rx));

    let (keypair, auth_secret) = ece::generate_keypair_and_auth_secret().unwrap();
    let request_destination_map = Arc::new(Mutex::new(HashMap::<
        u64,
        tokio::sync::mpsc::UnboundedSender<ConnectionMessageIn>,
    >::new()));

    let b64e = base64::engine::GeneralPurpose::new(
        &base64::alphabet::URL_SAFE,
        base64::engine::general_purpose::NO_PAD,
    );
    let mut uaid: Option<String> = db
        .query_row("SELECT uaid FROM state LIMIT 1", (), |row| row.get(0))
        .optional()
        .unwrap();

    println!("{}", b64e.encode(keypair.pub_as_raw().unwrap()));
    println!("{}", b64e.encode(auth_secret));
    let existing_channel_ids: Vec<String> = {
        let mut stmt = db.prepare("SELECT channel_id FROM registration").unwrap();
        let res: Result<_, _> = stmt.query_map((), |row| row.get(0)).unwrap().collect();
        res.unwrap()
    };

    out_tx
        .send(ConnectionMessageOut::Register { request_id: 0 })
        .unwrap();
    futures_util::try_join!(
        handle_connection(uaid.clone(), existing_channel_ids.clone(), in_tx, out_rx),
        {
        async {
            let request_destination_map = request_destination_map.clone();
            while let Some(msg) = in_rx.recv().await {
                match msg {
                    ConnectionMessageIn::RegisterResponse { ref request_id, result: _ } => {
                        if let Some(destination) = request_destination_map.lock().unwrap().remove(&request_id) {
                            let _ = destination.send(msg); // if it's gone we don't care
                        }
                    }
                    ConnectionMessageIn::Notification { channel_id, data } => {
                        if let Err(err) = async {
                            let (uid, exec, auth_secret, private_key, public_key): (u32, String, Vec<u8>, Vec<u8>, Vec<u8>) = db.query_row("SELECT uid, exec, auth_secret, private_key, public_key FROM registration WHERE channel_id=?", (&channel_id,), |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?, row.get(4)?)))?;

    tokio::spawn(async move {
        while let Some(msg) = in_rx.recv().await {
            match msg {
                ConnectionMessageIn::Notification { channel_id, data } => {
                    let data = ece::decrypt(&keypair.raw_components().unwrap(), &auth_secret, &b64e.decode(&data).unwrap());
                    println!("{:?}", data);
                            let keypair = ece::EcKeyComponents::new(private_key, public_key);

                            let data = BASE64_ENGINE.decode(&data)?;
                            let data = ece::decrypt(&keypair, &auth_secret, &data)?;

                            let child = std::process::Command::new(exec).uid(uid).stdin(std::process::Stdio::piped()).spawn()?;
                            child.stdin.unwrap().write_all(&data)?;

                            Result::<_, anyhow::Error>::Ok(())
                        }
                        .await {
                            eprintln!("Failed to handle notification: {:?}", err);
                        }
                    }
                    ConnectionMessageIn::SetUAID(new_uaid) => {
                        match &uaid {
                            None => {
                                db.execute("INSERT INTO state (uaid) VALUES (?)", (&new_uaid,))?;
                            }
                            Some(old_uaid) => {
                                if old_uaid != &new_uaid {
                                    db.execute("UPDATE state SET uaid=?", (&new_uaid,))?;
                                }
                            }
                        }
                        uaid = Some(new_uaid);
                    },
                }
            }

            Ok(())
        }
    });
        },
        async {
            loop {
                let (stream, _) = client_listener.accept().await?;

                let (sub_out_tx, mut sub_out_rx) = tokio::sync::mpsc::unbounded_channel();
                let (sub_in_tx, sub_in_rx) = tokio::sync::mpsc::unbounded_channel();

                let db = db.clone();
                tokio::task::spawn_local(async move {
                    if let Err(err) = handle_client(stream, db, sub_out_tx, sub_in_rx).await {
                        eprintln!("Error while handling client request: {:?}", err);
                    }
                });

                let request_destination_map = request_destination_map.clone();
                let out_tx = out_tx.clone();

    handle.await.unwrap().unwrap();
                tokio::spawn(async move {
                    while let Some(msg) = sub_out_rx.recv().await {
                        match msg {
                            ConnectionMessageOut::Register { request_id } => {
                                request_destination_map.lock().unwrap().insert(request_id, sub_in_tx.clone());
                                out_tx.send(msg).unwrap(); // will handle this better once
                                                           // reconnections are implemented
                            }
                        }
                    }
                });
            }

            // need this to infer types
            #[allow(unreachable_code)]
            Ok(())
        },
    ).unwrap();
}

M src/types.rs => src/types.rs +7 -2
@@ 29,8 29,8 @@ pub enum MozillaPushMessageA2S {
#[derive(Serialize, Debug)]
pub struct AckUpdate {
    #[serde(rename = "channelID")]
    channel_id: String,
    version: String,
    pub channel_id: String,
    pub version: String,
}

#[derive(Deserialize, Debug)]


@@ 68,3 68,8 @@ pub struct NotificationMessageData {
    pub headers: HashMap<String, String>,
    pub data: String,
}

#[derive(Deserialize, Debug)]
pub struct ClientRequestRegisterParams {
    pub exec: String,
}