~koehr/k0r

1536c38a976172917ccd4fba07c7cd035e68fae2 — koehr 4 months ago c2123a3
better error handling, multiple db return types
5 files changed, 186 insertions(+), 59 deletions(-)

M Cargo.lock
M Cargo.toml
M src/db.rs
M src/main.rs
M src/server.rs
M Cargo.lock => Cargo.lock +2 -0
@@ 984,6 984,7 @@ dependencies = [
 "actix-web",
 "exitcode",
 "failure",
 "failure_derive",
 "futures",
 "human-panic",
 "log",


@@ 1503,6 1504,7 @@ dependencies = [
 "libsqlite3-sys",
 "memchr",
 "smallvec",
 "uuid",
]

[[package]]

M Cargo.toml => Cargo.toml +1 -0
@@ 28,6 28,7 @@ log = { version = "0.4", features = ["max_level_debug", "release_max_level_info"
pretty_env_logger = "0.4"
uuid = { version = "0.8", features = ["v4"] }
failure = "0.1.8"
failure_derive = "0.1.1"
exitcode = "1.1.2"
human-panic = "1.0.3"
text_io = "0.1.8"

M src/db.rs => src/db.rs +152 -43
@@ 1,106 1,215 @@
use actix_web::{web, Error as AWError};
use failure::Error;
use failure_derive::Fail;
use futures::{Future, TryFutureExt};
use r2d2_sqlite::SqliteConnectionManager;
use rusqlite::NO_PARAMS;

use super::short_code::{random_uuid, ShortCode};

type Result<T = DBValue, E = Error> = std::result::Result<T, E>;

#[derive(Debug, Fail)]
pub enum DBError {
    #[fail(display = "The loaded database has not been initialized.")]
    InvalidSchema,

    #[fail(display = "Database error: {} ({})", msg, src)]
    SqliteError { msg: String, src: rusqlite::Error },
}

#[derive(Debug)]
pub enum DBValue {
    String(String),
    Number(i64),
    // Bool(bool),
    None,
}

pub type Pool = r2d2::Pool<SqliteConnectionManager>;
pub type Connection = r2d2::PooledConnection<SqliteConnectionManager>;

#[derive(serde::Deserialize)]
pub struct UrlPostData {
    pub url: String,
    pub title: Option<String>,
    pub description: Option<String>,
    pub key: String,
}

pub enum Queries {
    NeedsInit,
    CountUsers,
    InitDB,
    CreateUser(i64, bool),       // rate_limit, is_admin
    GetURL(String),              // short_code
    StoreNewURL(String, String), // api_key, url
    CreateUser(i64, bool),    // rate_limit, is_admin
    GetURL(String),           // short_code
    StoreNewURL(UrlPostData), // api_key, url, title?, description?
}

fn check_database_schema(conn: Connection) -> Result<String, Error> {
    match conn.query_row(
        "SELECT COUNT(name) FROM sqlite_master WHERE type='table' AND name IN ('Users', 'URLs')",
        &[0],
        |row| row.get(0),
    ) {
        Ok(2) => Ok(String::from("")),
        _ => Err(Error::from(rusqlite::Error::QueryReturnedNoRows)),
fn get_database_schema(conn: &Connection) -> Result<String> {
    let mut stmt = conn.prepare(
        "
        SELECT
          m.name as table_name,
          p.name as column_name
        FROM
          sqlite_master AS m
        JOIN
          pragma_table_info(m.name) AS p
        ORDER BY
          m.name,
          p.name;
    ",
    )?;
    let mut rows = stmt.query(NO_PARAMS)?;

    let mut tuples = Vec::new();
    while let Some(row) = rows.next()? {
        let table: String = row.get(0)?;
        let column: String = row.get(1)?;
        tuples.push(format!("{}|{}", table, column));
    }

    let schema = tuples.join("\n");
    Ok(schema)
}

fn check_database_schema(conn: Connection) -> Result {
    // TODO: is that really a good way to check the schema?
    let expected_schema = String::from(
        "URLs|created_at
URLs|description
URLs|title
URLs|url
URLs|user_id
URLs|visits
Users|api_key
Users|is_admin
Users|rate_limit
Users|rowid",
    );
    let schema = get_database_schema(&conn)?;

    if schema == expected_schema {
        debug!("Schema validated!");
        Ok(DBValue::None)
    } else {
        debug!("Schema not valid!");
        Err(Error::from(DBError::InvalidSchema))
    }
}

/// Initializes a new SQlite database with the default schema.
fn init_database(conn: Connection) -> Result<String, Error> {
    match conn.execute_batch(
fn init_database(conn: Connection) -> Result {
    conn.execute_batch(
        "
        BEGIN;
        CREATE TABLE IF NOT EXISTS Users(
          UserID   INTEGER PRIMARY KEY,
          APIKey    TEXT UNIQUE NOT NULL,
          RateLimit INTEGER DEFAULT 0,
          Admin     SMALLINT DEFAULT 0
          rowid     INTEGER NOT NULL,
          api_key    TEXT UNIQUE NOT NULL,
          rate_limit INTEGER DEFAULT 0,
          is_admin     SMALLINT DEFAULT 0,
          PRIMARY KEY(rowid)
        );
        CREATE UNIQUE INDEX IF NOT EXISTS idx_api_key ON Users(APIKey);
        CREATE UNIQUE INDEX IF NOT EXISTS idx_api_key ON Users(api_key);
        CREATE TABLE IF NOT EXISTS URLs(
          ID      INTEGER PRIMARY KEY,
          URL     TEXT NOT NULL,
          Visits  INTEGER DEFAULT 0,
          UserID INTEGER NOT NULL,
          FOREIGN KEY(UserID) REFERENCES Users(UserID)
          url         TEXT NOT NULL,
          visits      INTEGER DEFAULT 0,
          title       TEXT,
          description TEXT,
          created_at   DATETIME,
          user_id      INTEGER NOT NULL,
          FOREIGN KEY(user_id) REFERENCES Users(rowid)
        );
        COMMIT;",
    ) {
        Ok(_) => Ok(String::from("")), // db::query expects Result<String, Error>
        Err(err) => Err(Error::from(err)),
    }
    )
    .map(|_| DBValue::None)
    .map_err(|src| {
        let msg = "Failed to init DB!".to_owned();
        Error::from(DBError::SqliteError { msg, src })
    })
}

fn count_users(conn: Connection) -> Result {
    conn.query_row("SELECT COUNT(rowid) FROM USERS", NO_PARAMS, |row| {
        row.get(0)
    })
    .map(|v| DBValue::Number(v))
    .map_err(|src| {
        let msg = "Could not check users.".to_owned();
        Error::from(DBError::SqliteError { msg, src })
    })
}

/// Creates a user entry with random API key and returns the API key.
fn create_user(conn: Connection, rate_limit: i64, is_admin: bool) -> Result<String, Error> {
fn create_user(conn: Connection, rate_limit: i64, is_admin: bool) -> Result {
    let new_key = random_uuid();
    let is_admin = if is_admin { "1" } else { "0" };
    let _ = conn.execute(
    conn.execute(
        "INSERT INTO Users VALUES(NULL, ?1, ?2, ?3)",
        &[&new_key, &rate_limit.to_string(), &is_admin.to_string()],
    )?;
    Ok(new_key)
    )
    .map(|_| DBValue::String(new_key))
    .map_err(|src| {
        let msg = "Could not create user.".to_owned();
        Error::from(DBError::SqliteError { msg, src })
    })
}

/// Looks up an URL by translating the short_code to its ID
fn get_url(conn: Connection, short_code: &str) -> Result<String, Error> {
fn get_url(conn: Connection, short_code: &str) -> Result {
    let row_id = ShortCode::from_code(short_code)?.n;

    conn.query_row(
        "SELECT URL FROM URLs WHERE ID = ?",
        "SELECT url FROM URLs WHERE rowid = ?",
        &[row_id as i64],
        |row| row.get(0),
    )
    .map_err(Error::from)
    .map(|url| DBValue::String(url))
    .map_err(|src| {
        let msg = "Could not retrieve URL".to_owned();
        Error::from(DBError::SqliteError { msg, src })
    })
}

/// Stores a new URL if api_key is assigned to a valid user
fn store_url(conn: Connection, api_key: &str, url: &str) -> Result<String, Error> {
fn store_url(conn: Connection, data: &UrlPostData) -> Result {
    let user_id: i64 = conn.query_row(
        "SELECT UserID from Users WHERE APIKey = ?",
        &[api_key],
        "SELECT rowid FROM Users WHERE api_key = ?",
        &[&data.key],
        |row| row.get(0),
    )?;
    let _ = conn.execute(
        "INSERT INTO URLs VALUES(NULL, ?, 0, ?)",
        &[url, &(user_id.to_string())],
    let _ = conn.execute_named(
        "INSERT INTO URLs VALUES(:url, 0, :title, :description, DATETIME('now'), :user_id)",
        &[
            (":url", &data.url),
            (":title", data.title.as_ref().unwrap_or(&String::from(""))),
            (
                ":description",
                data.description.as_ref().unwrap_or(&String::from("")),
            ),
            (":user_id", &(user_id.to_string())),
        ],
    )?;
    // TODO: In case a plain [0-9a-z] string will be included into
    // IGNORED_SHORT_CODES, this function should work around such IDs as well.
    let short_code = ShortCode::new(conn.last_insert_rowid() as usize).code;
    Ok(short_code)
    Ok(DBValue::String(short_code))
}

/// translates Queries to function calls and returns the result
pub fn query(pool: &Pool, query: Queries) -> impl Future<Output = Result<String, AWError>> {
pub fn query(
    pool: &Pool,
    query: Queries,
) -> impl Future<Output = std::result::Result<DBValue, AWError>> {
    let pool = pool.clone();
    web::block(move || match query {
        Queries::NeedsInit => check_database_schema(pool.get()?),
        Queries::CountUsers => count_users(pool.get()?),
        Queries::InitDB => init_database(pool.get()?),
        Queries::CreateUser(rate_limit, is_admin) => create_user(pool.get()?, rate_limit, is_admin),
        Queries::GetURL(short_code) => get_url(pool.get()?, &short_code),
        Queries::StoreNewURL(api_key, url) => store_url(pool.get()?, &api_key, &url),
        Queries::StoreNewURL(url_data) => store_url(pool.get()?, &url_data),
    })
    .map_err(AWError::from)
}

M src/main.rs => src/main.rs +15 -2
@@ 12,6 12,8 @@ mod db;
mod server;
mod short_code;

use db::DBValue;

// This includes the template code generated by ructe
include!(concat!(env!("OUT_DIR"), "/templates.rs"));



@@ 57,9 59,20 @@ async fn init_db_pool(path_str: String) -> db::Pool {
    let db_pool = db::Pool::new(db_manager).unwrap();

    if (db::query(&db_pool, db::Queries::NeedsInit).await).is_err() {
        debug!("New database. Initializing schema and adding super user...");
        debug!("New database. Initializing schema...");
        let _ = db::query(&db_pool, db::Queries::InitDB).await;
        let _ = db::query(&db_pool, db::Queries::CreateUser(0, true)).await;
    }

    match db::query(&db_pool, db::Queries::CountUsers).await {
        Ok(DBValue::Number(0)) => {
            debug!("Adding super user...");
            if let Some(err) = (db::query(&db_pool, db::Queries::CreateUser(0, true)).await).err() {
                panic!("Failed to create super user! {}", err);
            }
        },
        Ok(DBValue::Number(_)) => { /* nothing to do */ },
        Ok(v) => debug!("Got unexpected value when counting users: {:#?}", v),
        Err(err) => panic!("Failed to create super user! {}", err),
    }

    db_pool

M src/server.rs => src/server.rs +16 -14
@@ 11,7 11,7 @@ use actix_web::{
use std::time::{Duration, SystemTime};
use super::templates::{self, statics::StaticFile};
use super::render;
use super::db;
use super::db::{self, DBValue};

const CONTENT_TYPE_HTML: &str = "content-type: text/html; charset=utf-8";
const CONTENT_TYPE_JSON: &str = "content-type: application/json; charset=utf-8";


@@ 23,7 23,7 @@ const  FAR: Duration = Duration::from_secs(180 * 24 * 60 * 60);
const IGNORED_SHORT_CODES: &[&str] = &["favicon.ico"]; // TODO: make db::store_url aware of this

type DB = web::Data<db::Pool>;
type JSON = web::Json<UrlPostData>;
type JSON = web::Json<db::UrlPostData>;

fn get_request_origin(req: &HttpRequest) -> String {
    req.connection_info().remote_addr().unwrap_or("unkown origin").to_string()


@@ 43,6 43,12 @@ fn build_404_json() -> HttpResponse {
        .body("{{\"status\": \"error\", \"message\": \"Not Found\"}}")
}

fn build_500_json() -> HttpResponse {
    HttpResponse::InternalServerError()
        .content_type(CONTENT_TYPE_JSON)
        .body("{{\"status\": \"error\", \"message\": \"Internal Server Error\"}}")
}

/// Index page handler
#[actix_web::get("/")]
async fn index() -> HttpResponse {


@@ 81,7 87,7 @@ async fn redirect(req: HttpRequest, db: DB) -> Result<HttpResponse, Error> {
        debug!("{} queried {}: IGNORED", get_request_origin(&req), short_code);
        Ok(build_404_json())

    } else if let Ok(url) = db::query(&db, db::Queries::GetURL(short_code.to_owned())).await {
    } else if let Ok(DBValue::String(url)) = db::query(&db, db::Queries::GetURL(short_code.to_owned())).await {
        let body = format!("Would redirect to <a href=\"{}\">{}</a>.", url, url);
        debug!("{} queried {}, got {}", get_request_origin(&req), short_code, url);
        Ok(HttpResponse::Ok().content_type(CONTENT_TYPE_HTML).body(body))


@@ 93,12 99,6 @@ async fn redirect(req: HttpRequest, db: DB) -> Result<HttpResponse, Error> {
}


#[derive(serde::Deserialize)]
struct UrlPostData {
    url: String,
    key: String,
}

#[actix_web::post("/")]
async fn add_url(_req: HttpRequest, data: JSON, db: DB) -> Result<HttpResponse, Error> {
    match Url::parse(&data.url) {


@@ 110,16 110,18 @@ async fn add_url(_req: HttpRequest, data: JSON, db: DB) -> Result<HttpResponse, 

            let query_result = db::query(
                &db,
                db::Queries::StoreNewURL(data.key.clone(), data.url.clone())
                db::Queries::StoreNewURL(data.into_inner())
            ).await;

            debug!("{} posted \"{}\" with key \"{}\", got {:?}", get_request_origin(&_req), &data.url, &data.key, query_result);

            match query_result {
                Ok(code) => Ok(HttpResponse::Created()
                Ok(DBValue::String(code)) => Ok(HttpResponse::Created()
                    .content_type(CONTENT_TYPE_JSON)
                    .body(format!("{{\"status\": \"ok\", \"message\": \"{}\"}}", code))),
                Err(_) => Ok(build_400_json("Invalid API key"))
                Err(err) => Ok(build_400_json(&format!("Invalid API key: {:?}", err))),
                _ => {
                    debug!("Got unexpected type back from StoreNewURL query: {:#?}", query_result);
                    Ok(build_500_json())
                }
            }
        },
        Err(_) => {