~koehr/k0r

449c87d8baa655b925710bedde251312d78fe775 — koehr 8 months ago f58a9d8
initialize database if not existing
7 files changed, 128 insertions(+), 34 deletions(-)

M Cargo.lock
M Cargo.toml
M src/db.rs
M src/main.rs
M src/server.rs
M src/short_code.rs
A test.db
M Cargo.lock => Cargo.lock +8 -0
@@ 996,7 996,9 @@ dependencies = [
 "rusqlite",
 "serde",
 "serde_json",
 "text_io",
 "url",
 "uuid",
]

[[package]]


@@ 1744,6 1746,12 @@ dependencies = [
]

[[package]]
name = "text_io"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6cb170b4f47dc48835fbc56259c12d8963e542b05a24be2e3a1f5a6c320fd2d4"

[[package]]
name = "thiserror"
version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"

M Cargo.toml => Cargo.toml +2 -0
@@ 26,9 26,11 @@ rusqlite = "0.24"
futures = "0.3"
log = { version = "0.4", features = ["max_level_debug", "release_max_level_info"] }
pretty_env_logger = "0.4"
uuid = { version = "0.8", features = ["v4"] }
failure = "0.1.8"
exitcode = "1.1.2"
human-panic = "1.0.3"
text_io = "0.1.8"

[build-dependencies]
ructe = { version = "0.13", features = ["mime03"] }

M src/db.rs => src/db.rs +60 -3
@@ 3,16 3,68 @@ use failure::Error;
use futures::{Future, TryFutureExt};
use r2d2_sqlite::SqliteConnectionManager;

use super::short_code::ShortCode;
use super::short_code::{random_uuid, ShortCode};

pub type Pool = r2d2::Pool<SqliteConnectionManager>;
pub type Connection = r2d2::PooledConnection<SqliteConnectionManager>;

pub enum Queries {
    GetURL(String),
    StoreNewURL(String, String),
    NeedsInit,
    InitDB,
    CreateUser(i64, bool),       // rate_limit, is_admin
    GetURL(String),              // short_code
    StoreNewURL(String, String), // api_key, url
}

fn check_database_schema(conn: Connection) -> Result<String, Error> {
    match conn.query_row(
        "SELECT COUNT(name) FROM sqlite_master WHERE type='table' AND name IN ('Users', 'URLs')",
        &[0],
        |row| row.get(0),
    ) {
        Ok(2) => Ok(String::from("")),
        _ => Err(Error::from(rusqlite::Error::QueryReturnedNoRows)),
    }
}

/// Initializes a new SQlite database with the default schema.
fn init_database(conn: Connection) -> Result<String, Error> {
    match conn.execute_batch(
        "
        BEGIN;
        CREATE TABLE IF NOT EXISTS Users(
          UserID   INTEGER PRIMARY KEY,
          APIKey    TEXT UNIQUE NOT NULL,
          RateLimit INTEGER DEFAULT 0,
          Admin     SMALLINT DEFAULT 0
        );
        CREATE UNIQUE INDEX IF NOT EXISTS idx_api_key ON Users(APIKey);
        CREATE TABLE IF NOT EXISTS URLs(
          ID      INTEGER PRIMARY KEY,
          URL     TEXT NOT NULL,
          Visits  INTEGER DEFAULT 0,
          UserID INTEGER NOT NULL,
          FOREIGN KEY(UserID) REFERENCES Users(UserID)
        );
        COMMIT;",
    ) {
        Ok(_) => Ok(String::from("")), // db::query expects Result<String, Error>
        Err(err) => Err(Error::from(err)),
    }
}

/// Creates a user entry with random API key and returns the API key.
fn create_user(conn: Connection, rate_limit: i64, is_admin: bool) -> Result<String, Error> {
    let new_key = random_uuid();
    let is_admin = if is_admin { "1" } else { "0" };
    let _ = conn.execute(
        "INSERT INTO Users VALUES(NULL, ?1, ?2, ?3)",
        &[&new_key, &rate_limit.to_string(), &is_admin.to_string()],
    )?;
    Ok(new_key)
}

/// Looks up an URL by translating the short_code to its ID
fn get_url(conn: Connection, short_code: &str) -> Result<String, Error> {
    let row_id = ShortCode::from_code(short_code)?.n;
    conn.query_row(


@@ 23,6 75,7 @@ fn get_url(conn: Connection, short_code: &str) -> Result<String, Error> {
    .map_err(Error::from)
}

/// Stores a new URL if api_key is assigned to a valid user
fn store_url(conn: Connection, api_key: &str, url: &str) -> Result<String, Error> {
    let user_id: i64 = conn.query_row(
        "SELECT UserID from Users WHERE APIKey = ?",


@@ 39,9 92,13 @@ fn store_url(conn: Connection, api_key: &str, url: &str) -> Result<String, Error
    Ok(short_code)
}

/// translates Queries to function calls and returns the result
pub fn query(pool: &Pool, query: Queries) -> impl Future<Output = Result<String, AWError>> {
    let pool = pool.clone();
    web::block(move || match query {
        Queries::NeedsInit => check_database_schema(pool.get()?),
        Queries::InitDB => init_database(pool.get()?),
        Queries::CreateUser(rate_limit, is_admin) => create_user(pool.get()?, rate_limit, is_admin),
        Queries::GetURL(short_code) => get_url(pool.get()?, &short_code),
        Queries::StoreNewURL(api_key, url) => store_url(pool.get()?, &api_key, &url),
    })

M src/main.rs => src/main.rs +10 -2
@@ 3,6 3,7 @@ extern crate log;
extern crate pretty_env_logger;
use human_panic::setup_panic;
use std::path::PathBuf;
use text_io::read;

mod actix_ructe;
mod db;


@@ 34,8 35,15 @@ fn main() {
        }

        if !path.is_file() {
            error!("DB path not found and cannot be created: {:?}", path);
            std::process::exit(exitcode::CANTCREAT);
            println!(
                "Database file {} not found. Create it? [y/N]",
                path.to_str().unwrap()
            );
            let input: String = read!("{}\n");
            if input != "y" && input != "Y" {
                error!("DB path not found and cannot be created: {:?}", path);
                std::process::exit(exitcode::CANTCREAT);
            }
        }

        debug!("Starting server...");

M src/server.rs => src/server.rs +37 -29
@@ 30,6 30,20 @@ fn get_request_origin(req: &HttpRequest) -> String {
    req.connection_info().remote_addr().unwrap_or("unkown origin").to_string()
}

fn build_400_json(msg: &str) -> HttpResponse {
    let body = format!("{{\"status\": \"error\", \"message\": \"{}\"}}", msg);

    HttpResponse::BadRequest()
        .content_type(CONTENT_TYPE_JSON)
        .body(body)
}

fn build_404_json() -> HttpResponse {
    HttpResponse::BadRequest()
        .content_type(CONTENT_TYPE_JSON)
        .body("{{\"status\": \"error\", \"message\": \"Not Found\"}}")
}

/// Index page handler
#[actix_web::get("/")]
async fn index() -> HttpResponse {


@@ 62,15 76,11 @@ fn static_file(path: web::Path<String>) -> HttpResponse {
/// with a redirect or, if not found, a JSON error
#[actix_web::get("/{short_code}")]
async fn redirect(req: HttpRequest, db: DB) -> Result<HttpResponse, Error> {
    let respond_with_not_found = HttpResponse::NotFound()
        .content_type(CONTENT_TYPE_JSON)
        .body("{{\"status\": \"error\", \"message\": \"URL not found\"}}");

    let short_code = req.match_info().get("short_code").unwrap_or("0");

    if IGNORED_SHORT_CODES.contains(&short_code) {
        debug!("{} queried {}: IGNORED", get_request_origin(&req), short_code);
        Ok(respond_with_not_found)
        Ok(build_404_json())

    } else if let Ok(url) = db::query(&db, db::Queries::GetURL(short_code.to_owned())).await {
        let body = format!("Would redirect to <a href=\"{}\">{}</a>.", url, url);


@@ 79,59 89,57 @@ async fn redirect(req: HttpRequest, db: DB) -> Result<HttpResponse, Error> {

    } else {
        debug!("{} queried {}, got Not Found", get_request_origin(&req), short_code);
        Ok(respond_with_not_found)
        Ok(build_404_json())
    }
}


#[derive(serde::Deserialize)]
struct UrlPostData {
    url: String,
    key: String,
}

fn build_400(msg: &str) -> HttpResponse {
    let body = format!("{{\"status\": \"error\", \"message\": \"{}\"}}", msg);

    HttpResponse::BadRequest()
        .content_type(CONTENT_TYPE_JSON)
        .body(body)
}

#[actix_web::post("/")]
async fn add_url(_req: HttpRequest, data: JSON, db: DB) -> Result<HttpResponse, Error> {
    match Url::parse(&data.url) {
        Ok(parsed_url) => {
            if !parsed_url.has_authority() {
                debug!("{} posted \"{}\", got Invalid, no authority.", get_request_origin(&_req), &data.url);
                return Ok(build_400("Invalid URL, cannot be path only or data URL"));
                return Ok(build_400_json("Invalid URL, cannot be path only or data URL"));
            }
            match db::query(&db, db::Queries::StoreNewURL(data.key.clone(), data.url.clone())).await {
                Ok(code) => {
                    debug!("{} posted \"{}\" with key \"{}\", got {}", get_request_origin(&_req), &data.url, &data.key, code);
                    Ok(HttpResponse::Created()
                        .content_type(CONTENT_TYPE_JSON)
                        .body(format!("{{\"status\": \"ok\", \"message\": \"{}\"}}", code)))
                }
                Err(err) => {
                    debug!("{} posted \"{}\" with key \"{}\", got {}", get_request_origin(&_req), &data.url, &data.key, err);
                    Ok(build_400("Invalid API key"))
                }

            let query_result = db::query(
                &db,
                db::Queries::StoreNewURL(data.key.clone(), data.url.clone())
            ).await;

            debug!("{} posted \"{}\" with key \"{}\", got {:?}", get_request_origin(&_req), &data.url, &data.key, query_result);

            match query_result {
                Ok(code) => Ok(HttpResponse::Created()
                    .content_type(CONTENT_TYPE_JSON)
                    .body(format!("{{\"status\": \"ok\", \"message\": \"{}\"}}", code))),
                Err(_) => Ok(build_400_json("Invalid API key"))
            }
        },
        Err(_) => {
            debug!("{} posted \"{}\", got Invalid, Parser Error.", get_request_origin(&_req), &data.url);
            Ok(build_400("Invalid URL"))
            Ok(build_400_json("Invalid URL"))
        },
    }
}

#[actix_web::main]
pub async fn start(db_path: PathBuf) -> std::io::Result<()> {
    debug!("Canonical database path is {:?}", db_path.canonicalize());

    let db_manager = SqliteConnectionManager::file(db_path);
    let db_pool = db::Pool::new(db_manager).unwrap();

    if (db::query(&db_pool, db::Queries::NeedsInit).await).is_err() {
        let _ = db::query(&db_pool, db::Queries::InitDB).await;
        let _ = db::query(&db_pool, db::Queries::CreateUser(0, true)).await;
    }

    println!("Server is listening on 127.0.0.1:8080");

    actix_web::HttpServer::new(move || {

M src/short_code.rs => src/short_code.rs +11 -0
@@ 1,5 1,6 @@
use radix_fmt::radix_36;
use std::num::ParseIntError;
use uuid::Uuid;

pub struct ShortCode {
    pub code: String,


@@ 18,3 19,13 @@ impl ShortCode {
        Ok(ShortCode { code, n })
    }
}

/// Creates a new random UUID and encodes it as lower case hyphenated string
// see https://docs.rs/uuid/0.8.2/uuid/adapter/struct.Hyphenated.html
// in case you wonder about that Uuid::encode_buffer()
pub fn random_uuid() -> String {
    let uuid = Uuid::new_v4();
    uuid.to_hyphenated()
        .encode_lower(&mut Uuid::encode_buffer())
        .to_owned()
}

A test.db => test.db +0 -0