diff --git a/src/database/database.cpp b/src/database/database.cpp index 171c542..d29da42 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -1,6 +1,7 @@ #include "crow.h" #include "utils.hpp" #include "database.hpp" +#include using namespace std; @@ -72,18 +73,6 @@ map Database::getStrMap(const std::string sql, const Database::Q return map; } -string Database::getStr(const std::string sql, const Database::QueryData& data){ - sqlite3_stmt* stmt = bind(sql, data); - string str; - if (stmt == nullptr) - return str; - - str = reinterpret_cast(sqlite3_column_text(stmt, 0)); - - sqlite3_finalize(stmt); - return str; -} - set Database::getStrSet(const string& sql){ sqlite3_stmt* stmt = prepareStmt(sql); set vec; @@ -99,20 +88,6 @@ set Database::getStrSet(const string& sql){ return vec; } -std::optional Database::getInt(const char* sql) { - sqlite3_stmt* stmt = prepareStmt(sql); - if (stmt == nullptr) - return {}; - - std::optional id; - if (sqlite3_step(stmt) == SQLITE_ROW) { - id = sqlite3_column_int64(stmt, 0); - } - - sqlite3_finalize(stmt); - return id; -} - std::optional Database::insert(const char* sql) { sqlite3_stmt* stmt = prepareStmt(sql); if (stmt == nullptr) diff --git a/src/database/database.hpp b/src/database/database.hpp index e7aa512..7a5ea23 100644 --- a/src/database/database.hpp +++ b/src/database/database.hpp @@ -21,14 +21,27 @@ public: bool exec(const char* sqlQuery); bool exec(const std::string& sqlQuery); - /// returns true if the sql statment returns at least one row - std::optional getInt(const char* sql); - std::optional insert(const char* sql); std::set getStrSet(const std::string& sql); - std::string getStr(const std::string sql, const QueryData& data); + template + std::optional get(const std::string sql, const QueryData& data){ + sqlite3_stmt* stmt = bind(sql, data); + T ret; + if ((stmt == nullptr) || (sqlite3_step(stmt) != SQLITE_ROW)) + return {}; + + if constexpr (std::is_same_v) { + ret = sqlite3_column_int64(stmt, 0); + } + else if constexpr (std::is_same_v){ + ret = reinterpret_cast(sqlite3_column_text(stmt, 0)); + } + + sqlite3_finalize(stmt); + return ret; + } sqlite3_stmt* bind(const std::string sql, const QueryData& data); diff --git a/src/login/login.cpp b/src/login/login.cpp index 05069f9..ef2f525 100644 --- a/src/login/login.cpp +++ b/src/login/login.cpp @@ -41,21 +41,6 @@ std::string hashPassword(const std::string& password) return hash; } -std::string generate_session_id(size_t bytes = 32) { - std::vector buf(bytes); - randombytes_buf(buf.data(), buf.size()); - - // Convert to hex - std::string hex; - hex.reserve(bytes * 2); - static const char *hexmap = "0123456789abcdef"; - for (unsigned char b : buf) { - hex.push_back(hexmap[b >> 4]); - hex.push_back(hexmap[b & 0xF]); - } - return hex; -} - std::string get_session_id(const crow::request& req) { auto cookie_header = req.get_header_value("Cookie"); std::string prefix = "session_id="; @@ -103,15 +88,15 @@ auto login_required = [](auto handler){ bool loginUser(const std::string& username, const std::string& password) { - auto sql = "SELECT id, password_hash FROM users WHERE username = '?' LIMIT 1;"; + auto sql = "SELECT id password_hash FROM users WHERE username = '?' LIMIT 1;"; auto db = Database(); if (!db.open()) return false; - auto str = db.getStr(sql, {username}); - if (!str.empty()) { - return verifyHashWithPassword(str, password); + auto opt_str = db.get(sql, {username}); + if (opt_str.has_value()) { + return verifyHashWithPassword(opt_str.value(), password); } else { return false; } @@ -170,21 +155,31 @@ bool initLogin(crow::SimpleApp& app) CROW_ROUTE(app, "/login").methods("POST"_method) ([](const crow::request& req){ + auto cookie_it = req.get_header_value("Cookie").find("session_id="); + if (cookie_it == std::string::npos) + return crow::response(401, "No session"); + + // extract session_id + std::string session_id = req.get_header_value("Cookie").substr(cookie_it + 11, 32); + auto it = sessions.find(session_id); + if (it == sessions.end()) return crow::response(401, "Invalid session"); + + auto session = it->second; + // parse form auto body = crow::query_string(req.body); std::string csrf_token = body.get("csrf_token"); std::string username = body.get("username"); std::string password = body.get("password"); - if (csrf_token != csrf_token) return crow::response(403, "CSRF failed"); + if (csrf_token != session.csrf_token) return crow::response(403, "CSRF failed"); bool ok = loginUser(username, password); if (!ok) return crow::response(401, "Invalid credentials"); - std::string session_id = generate_session_id(); // regenerate session, mark as logged in - sessions[session_id].user_id; // user ID + sessions[session_id].user_id = "123"; // user ID crow::response res; res.add_header("HX-Redirect", "/dashboard"); // htmx redirect diff --git a/src/shadowrun/ShadowrunDb.cpp b/src/shadowrun/ShadowrunDb.cpp index 75e3e45..18938a8 100644 --- a/src/shadowrun/ShadowrunDb.cpp +++ b/src/shadowrun/ShadowrunDb.cpp @@ -41,13 +41,13 @@ bool initDb() { } int64_t getKeyOfCharacter(const string& name){ - auto sql = format("SELECT id FROM shadowrun_characters WHERE name = '{}' LIMIT 1;", name); + std::string sql = "SELECT id FROM shadowrun_characters WHERE name = ? LIMIT 1;"; auto db = Database(); if (!db.open()) return -1; - auto opt_int = db.getInt(sql.c_str()); + auto opt_int = db.get(sql, {name}); if (opt_int.has_value()) { return opt_int.value(); } else {