M Cargo.lock => Cargo.lock +55 -3
@@ 26,6 26,15 @@ dependencies = [
]
[[package]]
+name = "ahash"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e8fd72866655d1904d6b0997d0b07ba561047d070fbe29de039031c641b61217"
+dependencies = [
+ "const-random",
+]
+
+[[package]]
name = "aho-corasick"
version = "0.7.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ 84,7 93,7 @@ dependencies = [
"base64",
"blowfish",
"byteorder",
- "getrandom",
+ "getrandom 0.1.14",
]
[[package]]
@@ 220,6 229,26 @@ dependencies = [
]
[[package]]
+name = "const-random"
+version = "0.1.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "02dc82c12dc2ee6e1ded861cf7d582b46f66f796d1b6c93fa28b911ead95da02"
+dependencies = [
+ "const-random-macro",
+ "proc-macro-hack",
+]
+
+[[package]]
+name = "const-random-macro"
+version = "0.1.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fc757bbb9544aa296c2ae00c679e81f886b37e28e59097defe0cf524306f6685"
+dependencies = [
+ "getrandom 0.2.0",
+ "proc-macro-hack",
+]
+
+[[package]]
name = "core-foundation"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ 274,6 303,17 @@ dependencies = [
]
[[package]]
+name = "dashmap"
+version = "3.11.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0f260e2fc850179ef410018660006951c1b55b79e8087e87111a2c388994b9b5"
+dependencies = [
+ "ahash",
+ "cfg-if",
+ "num_cpus",
+]
+
+[[package]]
name = "deadpool"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ 597,6 637,17 @@ dependencies = [
]
[[package]]
+name = "getrandom"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ee8025cf36f917e6a52cce185b7c7177689b838b7ec138364e50cc2277a56cf4"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "wasi",
+]
+
+[[package]]
name = "h2"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ 949,6 1000,7 @@ dependencies = [
"bumpalo",
"bytes",
"chrono",
+ "dashmap",
"deadpool-postgres",
"either",
"fast_chemail",
@@ 1406,7 1458,7 @@ version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
dependencies = [
- "getrandom",
+ "getrandom 0.1.14",
"libc",
"rand_chacha",
"rand_core",
@@ 1429,7 1481,7 @@ version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
dependencies = [
- "getrandom",
+ "getrandom 0.1.14",
]
[[package]]
M Cargo.toml => Cargo.toml +1 -0
@@ 47,6 47,7 @@ rand = "0.7.3"
bs58 = "0.3.1"
bumpalo = "3.4.0"
tokio-util = "0.3.1"
+dashmap = "3.11.10"
[dev-dependencies]
rand = "0.7.3"
M src/main.rs => src/main.rs +13 -3
@@ 8,6 8,7 @@ use std::sync::Arc;
use trout::hyper::RoutingFailureExtHyper;
mod apub_util;
+mod ratelimit;
mod routes;
mod tasks;
mod worker;
@@ 139,6 140,7 @@ pub struct BaseContext {
pub http_client: HttpClient,
pub apub_proxy_rewrites: bool,
pub media_location: Option<std::path::PathBuf>,
+ pub api_ratelimit: ratelimit::RatelimitBucket<std::net::IpAddr>,
pub local_hostname: String,
}
@@ 929,6 931,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
host_url_apub,
http_client: hyper::Client::builder().build(hyper_tls::HttpsConnector::new()),
apub_proxy_rewrites,
+ api_ratelimit: ratelimit::RatelimitBucket::new(1),
});
let worker_trigger = worker::start_worker(base_context.clone());
@@ 939,15 942,22 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
});
let server = hyper::Server::bind(&(std::net::Ipv6Addr::UNSPECIFIED, port).into()).serve(
- hyper::service::make_service_fn(|_| {
+ hyper::service::make_service_fn(|sock: &hyper::server::conn::AddrStream| {
+ let addr = sock.remote_addr().ip();
let routes = routes.clone();
let context = context.clone();
- async {
+ async move {
Ok::<_, hyper::Error>(hyper::service::service_fn(move |req| {
let routes = routes.clone();
let context = context.clone();
async move {
- let result = if req.method() == hyper::Method::OPTIONS
+ let ratelimit_ok = context.api_ratelimit.try_call(addr).await;
+ let result = if !ratelimit_ok {
+ common_response_builder()
+ .status(hyper::StatusCode::TOO_MANY_REQUESTS)
+ .body("Ratelimit exceeded.".into())
+ .map_err(Into::into)
+ } else if req.method() == hyper::Method::OPTIONS
&& req.uri().path().starts_with("/api")
{
hyper::Response::builder()
A src/ratelimit.rs => src/ratelimit.rs +86 -0
@@ 0,0 1,86 @@
+use std::sync::atomic::AtomicU16;
+
+pub struct RatelimitBucket<K> {
+ cap: u16,
+ inner: tokio::sync::RwLock<Inner<K>>,
+}
+
+impl<K: Eq + std::hash::Hash + std::fmt::Debug> RatelimitBucket<K> {
+ pub fn new(cap: u16) -> Self {
+ Self {
+ cap,
+ inner: tokio::sync::RwLock::new(Inner {
+ divider_time: std::time::Instant::now(),
+ last_minute: None,
+ current_minute: dashmap::DashMap::new(),
+ }),
+ }
+ }
+
+ pub async fn try_call(&self, key: K) -> bool {
+ let now = std::time::Instant::now();
+ let inner = self.inner.read().await;
+ let seconds_into = now.duration_since(inner.divider_time).as_secs();
+ if seconds_into >= 60 {
+ println!("new minute");
+ std::mem::drop(inner);
+ let mut inner = self.inner.write().await;
+
+ let seconds_into_new = now.duration_since(inner.divider_time).as_secs();
+
+ // check again
+ if seconds_into_new >= 120 {
+ // more than two minutes elapsed, reset
+ inner.last_minute = None;
+ inner.current_minute = dashmap::DashMap::new();
+ inner.divider_time = now;
+
+ self.try_for_current(0, &inner, key).await
+ } else if seconds_into_new >= 60 {
+ let mut tmp = dashmap::DashMap::new();
+ std::mem::swap(&mut tmp, &mut inner.current_minute);
+ inner.last_minute = Some(tmp.into_read_only());
+ inner.divider_time += std::time::Duration::new(60, 0);
+
+ self.try_for_current(seconds_into_new - 60, &inner, key)
+ .await
+ } else {
+ self.try_for_current(seconds_into_new, &inner, key).await
+ }
+ } else {
+ self.try_for_current(seconds_into, &inner, key).await
+ }
+ }
+
+ async fn try_for_current(&self, seconds_into: u64, inner: &Inner<K>, key: K) -> bool {
+ println!("key={:?}", key);
+ let prev_count = if let Some(last_minute) = &inner.last_minute {
+ if let Some(prev_count) = last_minute.get(&key) {
+ (u64::from(prev_count.load(std::sync::atomic::Ordering::Relaxed))
+ * (60 - seconds_into)
+ / 60) as u16
+ } else {
+ 0
+ }
+ } else {
+ 0
+ };
+
+ let count = prev_count
+ + inner
+ .current_minute
+ .entry(key)
+ .or_default()
+ .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
+
+ println!("count={:?}", count);
+
+ count < self.cap
+ }
+}
+
+struct Inner<K> {
+ divider_time: std::time::Instant,
+ last_minute: Option<dashmap::ReadOnlyView<K, AtomicU16>>,
+ current_minute: dashmap::DashMap<K, AtomicU16>,
+}