~vpzom/shoved

5fa7f7da94d5955dd9e3e60daddda696e1af8118 — Colin Reeder 4 months ago 1f1fa68 master v0.1.0
Switch to boot-time for timeouts
3 files changed, 102 insertions(+), 93 deletions(-)

M Cargo.lock
M Cargo.toml
M src/main.rs
M Cargo.lock => Cargo.lock +16 -6
@@ 781,7 781,7 @@ dependencies = [
 "serde",
 "serde_json",
 "tokio",
 "tokio-stream",
 "tokio-timerfd",
 "tokio-tungstenite",
 "uuid",
]


@@ 862,6 862,15 @@ dependencies = [
]

[[package]]
name = "timerfd"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d3fd47d83ad0b5c7be2e8db0b9d712901ef6ce5afbcc6f676761004f5104ea2"
dependencies = [
 "rustix",
]

[[package]]
name = "tinyvec"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"


@@ 914,13 923,14 @@ dependencies = [
]

[[package]]
name = "tokio-stream"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842"
name = "tokio-timerfd"
version = "0.2.0"
source = "git+https://github.com/polachok/tokio-timerfd?rev=4064e77#4064e777f145ce71e937318527460bfcf2da1c4f"
dependencies = [
 "futures-core",
 "pin-project-lite",
 "libc",
 "slab",
 "timerfd",
 "tokio",
]


M Cargo.toml => Cargo.toml +1 -1
@@ 23,6 23,6 @@ rusqlite = "0.29.0"
serde = { version = "1.0.163", features = ["derive"] }
serde_json = "1.0.96"
tokio = { version = "1.28.2", features = ["macros", "rt", "sync"] }
tokio-stream = "0.1.14"
tokio-timerfd = { git = "https://github.com/polachok/tokio-timerfd", rev = "4064e77", features = ["boottime"] }
tokio-tungstenite = { version = "0.19.0", features = ["native-tls"] }
uuid = { version = "1.3.3", features = ["v4"] }

M src/main.rs => src/main.rs +85 -86
@@ 1,5 1,5 @@
use base64::Engine;
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use futures_util::{SinkExt, StreamExt};
use rusqlite::OptionalExtension;
use std::collections::HashMap;
use std::io::Write;


@@ 63,7 63,7 @@ async fn handle_connection(
    let (conn, _) = tokio_tungstenite::connect_async(SERVER_URL).await?;

    println!("connected");
    let (sink, stream) = conn.split();
    let (sink, mut stream) = conn.split();

    let mut sink = sink.with(|msg: types::MozillaPushMessageA2S| {
        futures_util::future::ready(


@@ 75,24 75,6 @@ async fn handle_connection(

    let (sink_tx, mut sink_rx) = tokio::sync::mpsc::unbounded_channel();

    // not importing this trait because it conflicts with the one from futures_util
    let stream =
        tokio_stream::StreamExt::timeout(stream.map_err(anyhow::Error::from), PING_TIMEOUT)
            .map(|item| match item {
                Ok(value) => value,
                Err(_) => Err(anyhow::anyhow!("Timeout")),
            })
            .try_filter_map(|msg| async {
                println!("{:?}", msg);
                match msg {
                    tokio_tungstenite::tungstenite::protocol::Message::Text(msg) => {
                        let msg: types::MozillaPushMessageS2A = serde_json::from_str(&msg)?;
                        Ok(Some(msg))
                    }
                    _ => Ok(None),
                }
            });

    sink.send(types::MozillaPushMessageA2S::Hello {
        uaid,
        channel_ids: existing_channel_ids,


@@ 111,80 93,95 @@ async fn handle_connection(
            Ok(())
        },
        async {
            stream
                .try_for_each(|msg| {
                    let channel_tx = channel_tx.clone();
                    let register_map = &register_map;
                    let sink_tx = sink_tx.clone();
                    let successfully_connected_tx = successfully_connected_tx.clone();
                    async move {
                        println!("{:?}", msg);
            loop {
                let msg = match futures_util::future::select(
                    stream.next(),
                    tokio_timerfd::sleep(PING_TIMEOUT),
                )
                .await
                {
                    futures_util::future::Either::Left((msg, _)) => msg,
                    futures_util::future::Either::Right(_) => anyhow::bail!("Timeout"),
                };

                        match msg {
                            types::MozillaPushMessageS2A::Hello { uaid, status, .. } => {
                                if status != 200 {
                                    anyhow::bail!("Unexpected status in Hello: {}", status);
                                }
                println!("{:?}", msg);

                                match successfully_connected_tx.try_send(()) {
                                    Ok(_) => {}
                                    Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
                                        panic!("Received multiple Hello responses")
                                    }
                                    Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
                                        panic!("somehow lost controller")
                                    }
                                }
                let msg = match msg {
                    Some(Ok(tokio_tungstenite::tungstenite::protocol::Message::Text(msg))) => {
                        let msg: types::MozillaPushMessageS2A = serde_json::from_str(&msg)?;
                        msg
                    }
                    Some(Err(err)) => return Err(err.into()),
                    None => break,
                    _ => continue,
                };

                                channel_tx.send(ConnectionMessageIn::SetUAID(uaid))?;
                let channel_tx = channel_tx.clone();
                let register_map = &register_map;
                let sink_tx = sink_tx.clone();
                let successfully_connected_tx = successfully_connected_tx.clone();

                println!("{:?}", msg);

                match msg {
                    types::MozillaPushMessageS2A::Hello { uaid, status, .. } => {
                        if status != 200 {
                            anyhow::bail!("Unexpected status in Hello: {}", status);
                        }

                        match successfully_connected_tx.try_send(()) {
                            Ok(_) => {}
                            Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
                                panic!("Received multiple Hello responses")
                            }
                            types::MozillaPushMessageS2A::Notification {
                            Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
                                panic!("somehow lost controller")
                            }
                        }

                        channel_tx.send(ConnectionMessageIn::SetUAID(uaid))?;
                    }
                    types::MozillaPushMessageS2A::Notification {
                        channel_id,
                        version,
                        data: Some(data),
                    } => {
                        channel_tx.send(ConnectionMessageIn::Notification {
                            channel_id: channel_id.clone(),
                            data: data.data,
                        })?;

                        sink_tx.send(types::MozillaPushMessageA2S::Ack {
                            updates: vec![types::AckUpdate {
                                channel_id,
                                version,
                                data: Some(data),
                            } => {
                                channel_tx.send(ConnectionMessageIn::Notification {
                                    channel_id: channel_id.clone(),
                                    data: data.data,
                                })?;

                                sink_tx.send(types::MozillaPushMessageA2S::Ack {
                                    updates: vec![types::AckUpdate {
                            }],
                        })?;
                    }
                    types::MozillaPushMessageS2A::Register {
                        channel_id,
                        status,
                        push_endpoint,
                    } => {
                        if let Some(request_id) = register_map.lock().unwrap().remove(&channel_id) {
                            channel_tx.send(ConnectionMessageIn::RegisterResponse {
                                request_id,
                                result: match status {
                                    200 => Ok(RegisterResponse {
                                        endpoint: push_endpoint,
                                        channel_id,
                                        version,
                                    }],
                                })?;
                            }
                            types::MozillaPushMessageS2A::Register {
                                channel_id,
                                status,
                                push_endpoint,
                            } => {
                                if let Some(request_id) =
                                    register_map.lock().unwrap().remove(&channel_id)
                                {
                                    channel_tx.send(ConnectionMessageIn::RegisterResponse {
                                        request_id,
                                        result: match status {
                                            200 => Ok(RegisterResponse {
                                                endpoint: push_endpoint,
                                                channel_id,
                                            }),
                                            status => Err(anyhow::anyhow!(
                                                "Failed to register. Got status {}",
                                                status
                                            )),
                                        },
                                    })?;
                                }
                            }
                            _ => {}
                                    }),
                                    status => Err(anyhow::anyhow!(
                                        "Failed to register. Got status {}",
                                        status
                                    )),
                                },
                            })?;
                        }

                        Ok(())
                    }
                })
                .await?;
                    _ => {}
                }
            }

            println!("ended");



@@ 402,7 399,9 @@ async fn main() {
                            backoff_time = START_BACKOFF_TIME;
                        } else {
                            println!("connection failed, waiting {} seconds", backoff_time.as_secs());
                            tokio::time::sleep(backoff_time).await;
                            tokio_timerfd::sleep(backoff_time).await?; // if this somehow fails,
                                                                       // maybe we should actually
                                                                       // just crash instead

                            backoff_time *= 2;
                            if backoff_time > MAX_BACKOFF_TIME {