use crate::{report_parse_error, require_login, simple_response, to_200, to_204, BracketID, BracketType, DbPool, Error, LoginError, UserID};
/// A user reference taken from a path segment: either the literal `~me`
/// (meaning the caller) or an explicit user ID.
#[derive(Debug, Clone, Copy)]
enum UserIDOrMe {
    /// The `~me` alias, to be resolved against the logged-in user.
    Me,
    /// An explicit user ID parsed from the path.
    UserID(UserID),
}

impl UserIDOrMe {
    /// Returns the concrete user ID, substituting `me` for [`UserIDOrMe::Me`].
    pub fn resolve(self, me: UserID) -> UserID {
        if let UserIDOrMe::UserID(explicit) = self {
            explicit
        } else {
            me
        }
    }
}

impl std::str::FromStr for UserIDOrMe {
    type Err = <UserID as std::str::FromStr>::Err;

    /// Parses exactly `~me` as [`UserIDOrMe::Me`]; any other text is parsed as
    /// a `UserID`, propagating that type's parse error on failure.
    fn from_str(src: &str) -> Result<Self, Self::Err> {
        match src {
            "~me" => Ok(UserIDOrMe::Me),
            other => other.parse::<UserID>().map(UserIDOrMe::UserID),
        }
    }
}
async fn require_me(user: UserIDOrMe, me: UserID) -> Result<(), Error> {
match user {
UserIDOrMe::Me => Ok(()),
UserIDOrMe::UserID(user) => {
if user == me {
Ok(())
} else {
Err(Error::UserError(simple_response(hyper::StatusCode::FORBIDDEN, "You are not authorized to do this")))
}
}
}
}
/// Returns `true` if `username` is non-empty and consists solely of ASCII
/// letters and digits.
///
/// The explicit emptiness check matters: `chars().all(..)` is vacuously true
/// for `""`, which would otherwise let an empty username pass validation.
fn is_allowed_username(username: &str) -> bool {
    !username.is_empty() && username.chars().all(|chr| chr.is_ascii_alphanumeric())
}
/// Builds the route tree for users, relative to wherever this node is mounted
/// (presumably `/users` — confirm at the call site):
///
/// * `POST /` — create a user and return a login token.
/// * `PATCH /{user}` — edit a user.
/// * `GET /{user}` — fetch a user.
/// * `GET /{user}/brackets` — list the brackets the user administers.
///
/// `{user}` is parsed as a [`UserIDOrMe`], so `~me` is accepted as well as an ID.
pub fn route_users() -> crate::RouteNode<()> {
    let brackets_node =
        crate::RouteNode::new().with_handler_async("GET", |(user,), ctx, req| async move {
            let listing = route_users_brackets_list_fn(ctx.db_pool.clone(), req, user).await;
            listing.map(|v| to_200(&v))
        });

    let user_node = crate::RouteNode::new()
        .with_handler_async("PATCH", |(user,), ctx, req| async move {
            let outcome = route_users_edit_fn(ctx.db_pool.clone(), req, user).await;
            outcome.map(to_204)
        })
        .with_handler_async("GET", |(user,), ctx, req| async move {
            let profile = route_users_get_fn(ctx.db_pool.clone(), req, user).await;
            profile.map(|v| to_200(&v))
        })
        .with_child("brackets", brackets_node);

    crate::RouteNode::new()
        .with_handler_async("POST", |_, ctx, req| async move {
            let created = route_users_create_fn(ctx.db_pool.clone(), req).await;
            created.map(|v| to_200(&v))
        })
        .with_child_parse::<UserIDOrMe, _>(user_node)
}
async fn route_users_create_fn(
db_pool: DbPool,
req: hyper::Request<hyper::Body>,
) -> Result<serde_json::Value, Error> {
#[derive(serde_derive::Deserialize, Debug, Default)]
struct UserCreateBody {
username: Option<String>,
password: Option<String>,
}
let body = hyper::body::to_bytes(req.into_body()).await?;
let body: UserCreateBody = if body.is_empty() {
Default::default()
} else {
serde_json::from_slice(&body).map_err(report_parse_error)?
};
let client = db_pool.get().await?;
let (user_id, username): (i32, _) = match if let Some(username) = body.username {
if !is_allowed_username(&username) {
return Err(Error::UserError(crate::simple_response(hyper::StatusCode::BAD_REQUEST, "Invalid characters in username")));
}
if let Some(password) = body.password {
let password_hash =
tokio::task::spawn_blocking(|| bcrypt::hash(password, bcrypt::DEFAULT_COST))
.await??;
let stmt = client
.prepare("INSERT INTO users (username, password_hash) VALUES ($1, $2) RETURNING id")
.await?;
Some((
client
.query_one(&stmt, &[&username, &password_hash])
.await?,
Some(username),
))
} else {
None
}
} else if body.password.is_some() {
None
} else {
let stmt = client
.prepare("INSERT INTO users DEFAULT VALUES RETURNING id")
.await?;
Some((client.query_one(&stmt, &[]).await?, None))
} {
Some((row, username)) => (row.get(0), username),
None => {
return Err(Error::UserError({
let mut res =
hyper::Response::new("username and password cannot be used separately".into());
*res.status_mut() = hyper::StatusCode::BAD_REQUEST;
res
}));
}
};
let token = uuid::Uuid::new_v4();
{
let stmt = client
.prepare("INSERT INTO logins (token, user_id) VALUES ($1, $2)")
.await?;
client.execute(&stmt, &[&token, &user_id]).await?;
}
Ok(serde_json::json!({
"token": token.to_string(),
"user": {
"id": user_id,
"username": username,
}
}))
}
async fn route_users_get_fn(
db_pool: DbPool,
req: hyper::Request<hyper::Body>,
user: UserIDOrMe,
) -> Result<serde_json::Value, Error> {
let user: i32 = match user {
UserIDOrMe::Me => {
let user = crate::get_login(&req, &db_pool)
.await
.map_err(LoginError::to_user)?;
match user {
None => {
Err(Error::UserError({
let mut res = hyper::Response::new("You are not logged in".into());
*res.status_mut() = hyper::StatusCode::UNAUTHORIZED;
res
}))
},
Some(user) => Ok(user),
}?
},
UserIDOrMe::UserID(user) => user,
};
let client = db_pool.get().await?;
let row = client.query_opt("SELECT username FROM users WHERE id=$1", &[&user]).await?;
match row {
None => Err(Error::UserError({
let mut res = hyper::Response::new("No such user".into());
*res.status_mut() = hyper::StatusCode::NOT_FOUND;
res
})),
Some(row) => {
let username: Option<String> = row.get(0);
Ok(serde_json::json!({
"username": username,
"id": user
}))
}
}
}
async fn route_users_edit_fn(
db_pool: DbPool,
req: hyper::Request<hyper::Body>,
user: UserIDOrMe,
) -> Result<(), Error> {
#[derive(serde_derive::Deserialize, Debug)]
struct UserEditBody {
username: Option<String>,
password: Option<String>,
email: Option<String>,
}
let login_user = require_login(&req, &db_pool).await?;
let user = user.resolve(login_user);
let body: UserEditBody = serde_json::from_slice(&hyper::body::to_bytes(req.into_body()).await?)
.map_err(report_parse_error)?;
if user != login_user {
return Err(Error::UserError({
let mut res =
hyper::Response::new("You don't have permission to modify this user".into());
*res.status_mut() = hyper::StatusCode::FORBIDDEN;
res
}));
}
let client = db_pool.get().await?;
let mut values = Vec::with_capacity(3);
let mut columns = Vec::with_capacity(3);
if let Some(username) = body.username {
if !is_allowed_username(&username) {
return Err(Error::UserError({
let mut res = hyper::Response::new("Invalid characters in username".into());
*res.status_mut() = hyper::StatusCode::BAD_REQUEST;
res
}));
}
values.push(username);
columns.push("username");
}
if let Some(password) = body.password {
let hash =
tokio::task::spawn_blocking(|| bcrypt::hash(password, bcrypt::DEFAULT_COST)).await??;
values.push(hash);
columns.push("password_hash");
}
if let Some(email) = body.email {
if !email.contains('@') {
return Err(Error::UserError({
let mut res = hyper::Response::new("Invalid email address".into());
*res.status_mut() = hyper::StatusCode::BAD_REQUEST;
res
}));
}
values.push(email);
columns.push("email");
}
let mut values: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
values.iter().map(|x| x as _).collect();
values.push(&user);
let sql: &str = &format!(
"UPDATE users SET {} WHERE id=${}",
columns
.into_iter()
.enumerate()
.map(|(idx, col)| format!("{} = ${}", col, idx + 1))
.collect::<Vec<_>>()
.join(", "),
values.len()
);
client.execute(sql, &values[..]).await?;
Ok(())
}
async fn route_users_brackets_list_fn(
db_pool: DbPool,
req: hyper::Request<hyper::Body>,
user: UserIDOrMe,
) -> Result<serde_json::Value, Error> {
let login = require_login(&req, &db_pool).await?;
require_me(user, login).await?;
let user = login;
let client = db_pool.get().await?;
let stmt = client.prepare("SELECT id, name, type FROM brackets, bracket_access WHERE bracket_access.bracket=brackets.id AND is_admin=TRUE AND user_id=$1").await?;
let rows = client.query(&stmt, &[&user]).await?;
Ok(serde_json::Value::Array(
rows.into_iter()
.map(|row| {
serde_json::json!({
"id": BracketID::from_internal(row.get(0)).to_external(),
"name": row.get::<_, String>(1),
"type": BracketType::from_internal(&row.get::<_, String>(2)),
})
})
.collect(),
))
}