~whbboyd/russet

54373050cb865a858349e4260094132abaaf7977 — Will Boyd 4 months ago f55259e
Add some rate limiting
7 files changed, 92 insertions(+), 14 deletions(-)

M Cargo.lock
M Cargo.toml
M russet.toml.sample
M src/conf.rs
M src/http/mod.rs
M src/main.rs
M src/server.rs
M Cargo.lock => Cargo.lock +3 -1
@@ 1879,7 1879,7 @@ dependencies = [

[[package]]
name = "russet"
version = "0.9.2"
version = "0.10.0"
dependencies = [
 "argon2",
 "atom_syndication",


@@ 1900,6 1900,7 @@ dependencies = [
 "sqlx",
 "tokio",
 "toml",
 "tower",
 "tracing",
 "tracing-subscriber",
 "ulid",


@@ 2676,6 2677,7 @@ dependencies = [
 "pin-project",
 "pin-project-lite",
 "tokio",
 "tokio-util",
 "tower-layer",
 "tower-service",
 "tracing",

M Cargo.toml => Cargo.toml +5 -4
@@ 1,6 1,6 @@
[package]
name = "russet"
version = "0.9.2"
version = "0.10.0"
edition = "2021"
license = "AGPL-3.0"



@@ 17,6 17,7 @@ rss = "2.0"
axum = { version = "0.7", features = ["tracing"] }
axum-extra = { version = "0.9", features = ["cookie"] }
axum-macros = "0.4"
tower = { version = "0.4", features = ["limit"] }

# HTTP client
reqwest = "0.11"


@@ 28,10 29,10 @@ tokio = { version = "1.36", features = ["full"] }
sqlx = { version = "0.7", features = ["sqlite", "migrate", "runtime-tokio-native-tls"] }

# Configuration (general/config file/CLI)
merge = "0.1"
toml = "0.8"
clap = { version = "4.5", features = ["derive"] }
merge = "0.1"
rpassword = "7.3"
toml = "0.8"

# Assorted time conversions
chrono = "0.4"


@@ 49,6 50,6 @@ sailfish = "0.8"
tracing = { version = "0.1", features = ["log"] }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }


# Serde is used for all sorts of stuff
serde = "1.0"


M russet.toml.sample => russet.toml.sample +18 -0
@@ 17,3 17,21 @@ listen = "127.0.0.1:9892"
# std::time::Duration. Check interval management will likely get overhauled
# soon.
feed_check_interval = { "secs" = 3_600, "nanos" = 0 }

# Settings for rate limiting. The defaults are intended to be conservative;
# you'll want to tune them appropriately to whatever hardware you're running
# Russet on.
[rate_limiting]

# Total number of concurrent connections the application will accept. Additional
# connections over this limit will block until currently-executing requests
# complete.
global_concurrent_limit = 1024

# Total number of concurrent login attempts the application will accept. Login
# requests are very expensive in terms of CPU due to the need to hash passwords,
# so these are limited separately. Under most circumstances, you probably want
# it set to fewer than the number of CPUs available to Russet.
login_concurrent_limit = 4

# TODO: per-client rate limiting

M src/conf.rs => src/conf.rs +32 -2
@@ 1,4 1,4 @@
use clap::{ Parser, Subcommand};
use clap::{ Args, Parser, Subcommand};
use merge::Merge;
use serde::Deserialize;
use std::num::ParseIntError;


@@ 10,6 10,7 @@ use std::time::Duration;
pub struct Config {
	/// Command
	#[command(subcommand)]
	#[serde(skip)] // Putting a command in a config file makes very little sense
	pub command: Option<Command>,

	/// Config file


@@ 27,8 28,9 @@ pub struct Config {

	/// Pepper for password hashing
	///
	/// (Not exposed on the CLI; let's not enoucrage putting secrets in
	/// (Not exposed on the CLI; let's not encourage putting secrets in
	/// commandlines or shell histories.)
	#[arg(hide = true)]
	pub pepper: Option<String>,

	/// Duration between feed checks, in seconds


@@ 41,6 43,9 @@ pub struct Config {
		)
	)]
	pub feed_check_interval: Option<Duration>,

	#[command(flatten)]
	pub rate_limiting: RateLimitingConfig,
}
impl Default for Config {
	fn default() -> Self {


@@ 51,6 56,7 @@ impl Default for Config {
			listen_address: Some("127.0.0.1:9892".to_string()),
			pepper: Some("IzvoEPMQIi82NSXTz7cZ".to_string()),
			feed_check_interval: Some(Duration::from_secs(3_600)),
			rate_limiting: RateLimitingConfig::default(),
		}
	}
}


@@ 62,6 68,7 @@ impl std::fmt::Debug for Config {
			.field("listen_address", &self.listen_address)
			.field("pepper", &"<redacted>")
			.field("feed_check_interval", &self.feed_check_interval.map(|duration| duration.as_secs()))
			.field("rate_limiting", &self.rate_limiting)
			.finish()
	}
}


@@ 103,3 110,26 @@ pub enum Command {
		url: String,
	},
}

#[derive(Args, Debug, Deserialize, Merge)]
pub struct RateLimitingConfig {
	/// Limit of concurrent connections application-wide.
	#[arg(short, long, value_name = "CONNECTIONS")]
	pub global_concurrent_limit: Option<u32>,

	/// Limit of concurrent connections to the login endpoint.
	///
	/// This has a separate limit because it is very expensive in terms of CPU.
	/// Typically it should be set to less than the number of processors
	/// available to Russet.
	#[arg(short = 'o', long, value_name = "CONNECTIONS")]
	pub login_concurrent_limit: Option<u32>,
}
impl Default for RateLimitingConfig {
	fn default() -> Self {
		RateLimitingConfig {
			global_concurrent_limit: Some(1024),
			login_concurrent_limit: Some(4),
		}
	}
}

M src/http/mod.rs => src/http/mod.rs +13 -3
@@ 1,7 1,7 @@
use axum::extract::{ Form, State };
use axum::response::{ Html, Redirect };
use axum::Router;
use axum::routing::{ any, get };
use axum::routing::{ any, get, post };
use crate::domain::model::{ Entry, Feed };
use crate::domain::RussetDomainService;
use crate::http::session::AuthenticatedUser;


@@ 12,6 12,8 @@ use sailfish::TemplateOnce;
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Semaphore;
use tower::limit::GlobalConcurrencyLimitLayer;

mod entry;
mod feed;


@@ 20,16 22,24 @@ mod session;
mod static_routes;
mod subscribe;

pub fn russet_router<Persistence>() -> Router<AppState<Persistence>>
pub fn russet_router<Persistence>(
	global_concurrent_limit: u32,
	login_concurrent_limit: u32,
) -> Router<AppState<Persistence>>
where Persistence: RussetPersistenceLayer {
	let global_limit_semaphore = Arc::new(Semaphore::new(global_concurrent_limit.try_into().unwrap()));
	let login_limit_sempahore = Arc::new(Semaphore::new(login_concurrent_limit.try_into().unwrap()));
	Router::new()
		.route("/login", post(login::login_user))
		.layer(GlobalConcurrencyLimitLayer::with_semaphore(login_limit_sempahore))
		.route("/login", get(login::login_page))
		.route("/styles.css", get(static_routes::styles))
		.route("/login", get(login::login_page).post(login::login_user))
		.route("/", get(home))
		.route("/entry/:id", get(entry::mark_read_redirect))
		.route("/feed/:id", get(feed::feed_page).post(feed::unsubscribe))
		.route("/subscribe", get(subscribe::subscribe_page).post(subscribe::subscribe))
		.route("/*any", any(|| async { Redirect::to("/") }))
		.layer(GlobalConcurrencyLimitLayer::with_semaphore(global_limit_semaphore))
}

#[derive(Debug)]

M src/main.rs => src/main.rs +17 -2
@@ 67,7 67,16 @@ async fn main() -> Result<()> {
	let db_file = config.db_file.expect("No db_file");
	let listen_address = config.listen_address.expect("No listen_address");
	let pepper = config.pepper.expect("No pepper");
	let feed_check_interval = config.feed_check_interval.expect("No feed_check_interval");
	let feed_check_interval =
		config.feed_check_interval.expect("No feed_check_interval");
	let global_concurrent_limit = config
		.rate_limiting
		.global_concurrent_limit
		.expect("No global_concurrent_limit");
	let login_concurrent_limit = config
		.rate_limiting
		.login_concurrent_limit
		.expect("No login_concurrent_limit");

	let db = SqlDatabase::new(Path::new(&db_file)).await?;
	let readers: Vec<Box<dyn RussetFeedReader>> = vec![


@@ 82,7 91,13 @@ async fn main() -> Result<()> {
	)?);

	match command {
		Command::Run => start(domain_service, listen_address).await?,
		Command::Run => start(
				domain_service,
				listen_address,
				global_concurrent_limit,
				login_concurrent_limit
			)
			.await?,
		Command::AddUser { user_name, password } => {
			info!("Adding user {user_name}…");
			let plaintext_password = match password {

M src/server.rs => src/server.rs +4 -2
@@ 11,6 11,8 @@ const SESSION_CLEANUP_INTERVAL: Duration = Duration::from_secs(3_600);
pub async fn start<Persistence>(
	domain_service: Arc<RussetDomainService<Persistence>>,
	listen: String,
	global_concurrent_limit: u32,
	login_concurrent_limit: u32,
) -> Result<()>
where Persistence: RussetPersistenceLayer {
	info!("Starting {}…", crate::APP_NAME);


@@ 40,9 42,9 @@ where Persistence: RussetPersistenceLayer {
		}
	} );

	// Setup for Axum
	// Start the HTTP server
	let app_state = AppState { domain_service: domain_service.clone() };
	let routes = russet_router()
	let routes = russet_router(global_concurrent_limit, login_concurrent_limit)
		.with_state(app_state);
	let listener = tokio::net::TcpListener::bind(&listen).await?;
	let graceful_exit_signal = async {