~vpzom/shoved

fe7fd696abeb28bb527b4f9ad95ca6888357ea3e — Colin Reeder 1 year, 19 days ago 737c06e
Maybe make reconnection work

I haven't been able to test this since the connection seems to persist
even if I lose network
1 files changed, 109 insertions(+), 67 deletions(-)

M src/main.rs
M src/main.rs => src/main.rs +109 -67
@@ 55,7 55,6 @@ async fn handle_connection(
    let (conn, _) = tokio_tungstenite::connect_async(SERVER_URL).await?;

    println!("connected");

    let (sink, stream) = conn.split();

    let mut sink = sink.with(|msg: types::MozillaPushMessageA2S| {


@@ 71,6 70,7 @@ async fn handle_connection(
    let stream = stream
        .map_err(anyhow::Error::from)
        .try_filter_map(|msg| async {
            println!("{:?}", msg);
            match msg {
                tokio_tungstenite::tungstenite::protocol::Message::Text(msg) => {
                    let msg: types::MozillaPushMessageS2A = serde_json::from_str(&msg)?;


@@ 97,65 97,75 @@ async fn handle_connection(

            Ok(())
        },
        stream.try_for_each(|msg| {
            let channel_tx = channel_tx.clone();
            let register_map = &register_map;
            let sink_tx = sink_tx.clone();
            async move {
                println!("{:?}", msg);
        async {
            stream
                .try_for_each(|msg| {
                    let channel_tx = channel_tx.clone();
                    let register_map = &register_map;
                    let sink_tx = sink_tx.clone();
                    async move {
                        println!("{:?}", msg);

                match msg {
                    types::MozillaPushMessageS2A::Hello { uaid, status, .. } => {
                        if status != 200 {
                            anyhow::bail!("Unexpected status in Hello: {}", status);
                        }
                        match msg {
                            types::MozillaPushMessageS2A::Hello { uaid, status, .. } => {
                                if status != 200 {
                                    anyhow::bail!("Unexpected status in Hello: {}", status);
                                }

                        channel_tx.send(ConnectionMessageIn::SetUAID(uaid))?;
                    }
                    types::MozillaPushMessageS2A::Notification {
                        channel_id,
                        version,
                        data: Some(data),
                    } => {
                        channel_tx.send(ConnectionMessageIn::Notification {
                            channel_id: channel_id.clone(),
                            data: data.data,
                        })?;

                        sink_tx.send(types::MozillaPushMessageA2S::Ack {
                            updates: vec![types::AckUpdate {
                                channel_tx.send(ConnectionMessageIn::SetUAID(uaid))?;
                            }
                            types::MozillaPushMessageS2A::Notification {
                                channel_id,
                                version,
                            }],
                        })?;
                    }
                    types::MozillaPushMessageS2A::Register {
                        channel_id,
                        status,
                        push_endpoint,
                    } => {
                        if let Some(request_id) = register_map.lock().unwrap().remove(&channel_id) {
                            channel_tx.send(ConnectionMessageIn::RegisterResponse {
                                request_id,
                                result: match status {
                                    200 => Ok(RegisterResponse {
                                        endpoint: push_endpoint,
                                data: Some(data),
                            } => {
                                channel_tx.send(ConnectionMessageIn::Notification {
                                    channel_id: channel_id.clone(),
                                    data: data.data,
                                })?;

                                sink_tx.send(types::MozillaPushMessageA2S::Ack {
                                    updates: vec![types::AckUpdate {
                                        channel_id,
                                    }),
                                    status => Err(anyhow::anyhow!(
                                        "Failed to register. Got status {}",
                                        status
                                    )),
                                },
                            })?;
                                        version,
                                    }],
                                })?;
                            }
                            types::MozillaPushMessageS2A::Register {
                                channel_id,
                                status,
                                push_endpoint,
                            } => {
                                if let Some(request_id) =
                                    register_map.lock().unwrap().remove(&channel_id)
                                {
                                    channel_tx.send(ConnectionMessageIn::RegisterResponse {
                                        request_id,
                                        result: match status {
                                            200 => Ok(RegisterResponse {
                                                endpoint: push_endpoint,
                                                channel_id,
                                            }),
                                            status => Err(anyhow::anyhow!(
                                                "Failed to register. Got status {}",
                                                status
                                            )),
                                        },
                                    })?;
                                }
                            }
                            _ => {}
                        }

                        Ok(())
                    }
                    _ => {}
                }
                })
                .await?;

                Ok(())
            }
        }),
            println!("ended");

            Result::<(), _>::Err(anyhow::anyhow!("Stream ended"))
        },
        async {
            while let Some(msg) = channel_rx.recv().await {
                match msg {


@@ 321,24 331,25 @@ async fn main() {
        .unwrap();
    }

    // TODO abstract these to handle reconnections?
    let (in_tx, mut in_rx) = tokio::sync::mpsc::unbounded_channel();
    let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel();
    let (nowhere_tx, _) = tokio::sync::mpsc::unbounded_channel();
    let (out_tx_send, out_tx_recv) = tokio::sync::watch::channel(nowhere_tx);

    let request_destination_map = Arc::new(Mutex::new(HashMap::<
        u64,
        tokio::sync::mpsc::UnboundedSender<ConnectionMessageIn>,
    >::new()));

    let mut uaid: Option<String> = db
        .query_row("SELECT uaid FROM state LIMIT 1", (), |row| row.get(0))
        .optional()
        .unwrap();
    let uaid: Mutex<Option<String>> = Mutex::new(
        db.query_row("SELECT uaid FROM state LIMIT 1", (), |row| row.get(0))
            .optional()
            .unwrap(),
    );

    let existing_channel_ids: Vec<String> = {
    let existing_channel_ids: Mutex<Vec<String>> = {
        let mut stmt = db.prepare("SELECT channel_id FROM registration").unwrap();
        let res: Result<_, _> = stmt.query_map((), |row| row.get(0)).unwrap().collect();
        res.unwrap()
        Mutex::new(res.unwrap())
    };

    let local_set = tokio::task::LocalSet::new();


@@ 346,13 357,30 @@ async fn main() {
    local_set.spawn_local(
        async move {
            futures_util::try_join!(
                handle_connection(uaid.clone(), existing_channel_ids.clone(), in_tx, out_rx),
                async {
                    loop {
                        let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel();
                        out_tx_send.send(out_tx).unwrap();
                        let uaid_copy = uaid.lock().unwrap().clone();
                        let existing_channel_ids_copy = existing_channel_ids.lock().unwrap().clone();
                        let result = handle_connection(uaid_copy, existing_channel_ids_copy, in_tx.clone(), out_rx).await;
                        println!("connection ended with {:?}", result);
                    }

                    // for type inference
                    #[allow(unreachable_code)]
                    Result::<_, anyhow::Error>::Ok(())
                },
                {
                async {
                    let request_destination_map = request_destination_map.clone();
                    while let Some(msg) = in_rx.recv().await {
                        match msg {
                            ConnectionMessageIn::RegisterResponse { ref request_id, result: _ } => {
                            ConnectionMessageIn::RegisterResponse { ref request_id, ref result } => {
                                if let Ok(info) = result {
                                    existing_channel_ids.lock().unwrap().push(info.channel_id.clone());
                                }

                                if let Some(destination) = request_destination_map.lock().unwrap().remove(&request_id) {
                                    let _ = destination.send(msg); // if it's gone we don't care
                                }


@@ 376,7 404,8 @@ async fn main() {
                                }
                            }
                            ConnectionMessageIn::SetUAID(new_uaid) => {
                                match &uaid {
                                let mut lock = uaid.lock().unwrap();
                                match &*lock {
                                    None => {
                                        db.execute("INSERT INTO state (uaid) VALUES (?)", (&new_uaid,))?;
                                    }


@@ 386,7 415,7 @@ async fn main() {
                                        }
                                    }
                                }
                                uaid = Some(new_uaid);
                                *lock = Some(new_uaid);
                            },
                        }
                    }


@@ 409,15 438,28 @@ async fn main() {
                        });

                        let request_destination_map = request_destination_map.clone();
                        let out_tx = out_tx.clone();
                        let mut out_tx_recv = out_tx_recv.clone();

                        tokio::spawn(async move {
                            while let Some(msg) = sub_out_rx.recv().await {
                                match msg {
                                    ConnectionMessageOut::Register { request_id } => {
                                        request_destination_map.lock().unwrap().insert(request_id, sub_in_tx.clone());
                                        out_tx.send(msg).unwrap(); // will handle this better once
                                                                   // reconnections are implemented

                                        let mut msg = msg;
                                        loop {
                                            let result = out_tx_recv.borrow().send(msg);
                                            msg = match result {
                                                Ok(_) => {
                                                    // success!
                                                    break;
                                                }
                                                Err(tokio::sync::mpsc::error::SendError(msg)) => {
                                                    out_tx_recv.changed().await.unwrap();
                                                    msg
                                                }
                                            };
                                        }
                                    }
                                }
                            }