~vpzom/lotide

a2b6f6705df89e82c1160890c8e900914d622860 — Colin Reeder 13 days ago da0ca3a
Allow using X-Forwarded-For to change ratelimit key
1 files changed, 37 insertions(+), 5 deletions(-)

M src/main.rs
M src/main.rs => src/main.rs +37 -5
@@ 853,6 853,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
        Err(other) => Err(other).expect("Failed to parse APUB_PROXY_REWRITES"),
    };

    let allow_forwarded = match std::env::var("ALLOW_FORWARDED") {
        Ok(value) => value.parse().expect("Failed to parse ALLOW_FORWARDED"),
        Err(std::env::VarError::NotPresent) => false,
        Err(other) => Err(other).expect("Failed to parse ALLOW_FORWARDED"),
    };

    let db_pool = deadpool_postgres::Pool::new(
        deadpool_postgres::Manager::new(
            std::env::var("DATABASE_URL")


@@ 943,7 949,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

    let server = hyper::Server::bind(&(std::net::Ipv6Addr::UNSPECIFIED, port).into()).serve(
        hyper::service::make_service_fn(|sock: &hyper::server::conn::AddrStream| {
            let addr = sock.remote_addr().ip();
            let addr_direct = sock.remote_addr().ip();
            let routes = routes.clone();
            let context = context.clone();
            async move {


@@ 951,12 957,38 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
                    let routes = routes.clone();
                    let context = context.clone();
                    async move {
                        let addr = if allow_forwarded {
                            if let Some(value) = req
                                .headers()
                                .get(hyper::header::HeaderName::from_static("x-forwarded-for"))
                            {
                                match value
                                    .to_str()
                                    .map_err(|_| ())
                                    .and_then(|value| value.split(", ").next().ok_or(()))
                                    .and_then(|value| value.parse().map_err(|_| ()))
                                {
                                    Err(_) => {
                                        return Ok(simple_response(
                                            hyper::StatusCode::BAD_REQUEST,
                                            "Invalid X-Forwarded-For value",
                                        ));
                                    }
                                    Ok(value) => value,
                                }
                            } else {
                                addr_direct
                            }
                        } else {
                            addr_direct
                        };

                        let ratelimit_ok = context.api_ratelimit.try_call(addr).await;
                        let result = if !ratelimit_ok {
                            common_response_builder()
                                .status(hyper::StatusCode::TOO_MANY_REQUESTS)
                                .body("Ratelimit exceeded.".into())
                                .map_err(Into::into)
                            Ok(simple_response(
                                hyper::StatusCode::TOO_MANY_REQUESTS,
                                "Ratelimit exceeded.",
                            ))
                        } else if req.method() == hyper::Method::OPTIONS
                            && req.uri().path().starts_with("/api")
                        {