@@ 43,12 43,19 @@ static ERR_500_HTML: &'static [u8] = include_bytes!("../web/500.html");
static CONFIG: OnceCell<Config> = OnceCell::new();
#[derive(Deserialize, Debug)]
+struct UserConfig {
+ token: String,
+ max_file_size: Option<i64>,
+}
+
+#[derive(Deserialize, Debug)]
struct Config {
data: String,
expiration_days: i64,
listen: String,
max_file_size: i64,
url: String,
+ users: Vec<UserConfig>,
}
impl std::default::Default for Config {
@@ 59,6 66,7 @@ impl std::default::Default for Config {
listen: "127.0.0.1:8080".into(),
max_file_size: 10000000,
url: "http://localhost:8080".into(),
+ users: Vec::new(),
}
}
}
@@ 301,6 309,59 @@ async fn on_get(config: &Config, req: Request<Body>) -> Result<Response<Body>, B
}
async fn on_put(config: &Config, req: Request<Body>) -> Result<Response<Body>, Infallible> {
+ let headers = req.headers();
+
+ let mut max_file_size = config.max_file_size;
+
+ if let Some(auth) = headers.get(header::AUTHORIZATION) {
+ let mut parts = match auth.to_str() {
+ Err(err) => {
+ error!("Authorization header exists but can't be stringified: {}", err);
+ return Ok(Response::builder()
+ .status(400).body(Body::from(ERR_400_HTML)).unwrap());
+ },
+ Ok(s) => s.split(" "),
+ };
+
+ let scheme = parts.nth(0);
+ let mut tok = "".to_owned();
+ for part in parts {
+ if tok != "" {
+ tok += " ";
+ tok += part;
+ } else {
+ tok += part;
+ }
+ }
+
+ if scheme.is_none() || scheme.unwrap().to_ascii_lowercase() != "bearer" || tok == "" {
+ error!("Invalid authorization header");
+ return Ok(Response::builder()
+ .status(400).body(Body::from(ERR_400_HTML)).unwrap());
+ }
+
+ // TODO(perf): Consider a O(1) or O(log n) lookup
+ let mut user_config: Option<&UserConfig> = None;
+ for user in &config.users {
+ if user.token == tok {
+ user_config = Some(&user);
+ break;
+ }
+ }
+
+ if user_config.is_none() {
+ error!("Invalid authorization token");
+ return Ok(Response::builder()
+ .status(400).body(Body::from(ERR_400_HTML)).unwrap());
+ }
+
+ let user_config = user_config.unwrap();
+
+ if let Some(user_max_file_size) = user_config.max_file_size {
+ max_file_size = user_max_file_size;
+ }
+ }
+
let (parts, mut reqbody) = req.into_parts();
let pathstr = parts.uri.path();
let path = Path::new(pathstr);
@@ 320,9 381,7 @@ async fn on_put(config: &Config, req: Request<Body>) -> Result<Response<Body>, I
Err(err) => {
error!("Failed to create random file: {}", err);
return Ok(Response::builder()
- .status(500)
- .body(Body::from(ERR_500_HTML))
- .unwrap());
+ .status(500).body(Body::from(ERR_500_HTML)).unwrap());
},
Ok((file, name)) => (file, name),
};
@@ 354,8 413,8 @@ async fn on_put(config: &Config, req: Request<Body>) -> Result<Response<Body>, I
first = false;
}
- if bodysize + chunk.len() > config.max_file_size as usize {
- error!("Uploaded file exceeds max size ({} bytes)", config.max_file_size);
+ if bodysize + chunk.len() > max_file_size as usize {
+ error!("Uploaded file exceeds max size ({} bytes)", max_file_size);
delete_file(&config.data, &name);
return respond_413();
}
@@ 408,7 467,7 @@ async fn main() {
debug!("Logger initialized");
if let Ok(confstr) = fs::read_to_string("config.toml") {
- CONFIG.set(toml::from_str(&confstr).unwrap()).unwrap();
+ CONFIG.set(toml::from_str(&confstr).expect("Failed to parse config")).unwrap();
} else {
CONFIG.set(Config::default()).unwrap();
}