diff --git a/CMakeLists.txt b/CMakeLists.txt index 3c4f475..6f4f4e7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -78,7 +78,7 @@ add_executable(${TARGET_NAME} target_compile_definitions(${TARGET_NAME} PRIVATE APPLICATION_NAME="${TARGET_NAME}") -target_link_libraries(${TARGET_NAME} pthread sqlite3) +target_link_libraries(${TARGET_NAME} pthread sqlite3 sodium) # Optional: Print build type at configuration time message(STATUS "Configuring build type: ${CMAKE_BUILD_TYPE}") \ No newline at end of file diff --git a/src/database/database.cpp b/src/database/database.cpp index 05b2700..171c542 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -36,9 +36,29 @@ bool Database::exec(const std::string& sqlQuery){ return exec(sqlQuery.c_str()); } -map Database::getStrMap(const string& sql){ - sqlite3_stmt* stmt = prepareStmt(sql); +sqlite3_stmt* Database::bind(const std::string sql, const Database::QueryData& data) { + sqlite3_stmt* stmt = nullptr; + if (sqlite3_prepare_v2(m_db, sql.c_str(), -1, &stmt, nullptr) != SQLITE_OK) { + CROW_LOG_ERROR << "Failed to prepare statement: " << sqlite3_errmsg(m_db); + return nullptr; + } + for (int i = 0; i < data.size(); i++) { + std::visit([stmt, i](auto&& val) { + using T = std::decay_t; + if constexpr (std::is_same_v) + sqlite3_bind_int64(stmt, i + 1, val); + else if constexpr (std::is_same_v) + sqlite3_bind_double(stmt, i + 1, val); + else if constexpr (std::is_same_v) + sqlite3_bind_text(stmt, i + 1, val.c_str(), -1, SQLITE_TRANSIENT); + }, data[i]); + } + return stmt; +} + +map Database::getStrMap(const std::string sql, const Database::QueryData& data){ map map; + sqlite3_stmt* stmt = bind(sql, data); if (stmt == nullptr) return map; @@ -52,8 +72,8 @@ map Database::getStrMap(const string& sql){ return map; } -string Database::getStr(const string& sql){ - sqlite3_stmt* stmt = prepareStmt(sql); +string Database::getStr(const std::string sql, const Database::QueryData& data){ + sqlite3_stmt* stmt = bind(sql, data); string str; if (stmt == nullptr) return str; diff --git a/src/database/database.hpp b/src/database/database.hpp index a5fbca7..e7aa512 100644 --- a/src/database/database.hpp +++ b/src/database/database.hpp @@ -3,12 +3,16 @@ #include "sqlite3.h" #include +#include +#include #include #include #include class Database { + typedef std::vector> QueryData; + public: Database(); ~Database(); @@ -24,9 +28,11 @@ public: std::set getStrSet(const std::string& sql); - string getStr(const string& sql) + std::string getStr(const std::string sql, const QueryData& data); - std::map getStrMap(const std::string& sql); + sqlite3_stmt* bind(const std::string sql, const QueryData& data); + + std::map getStrMap(const std::string sql, const QueryData& data); private: sqlite3_stmt* prepareStmt(const std::string& sql); diff --git a/src/login/login.cpp b/src/login/login.cpp index 0325e3a..05069f9 100644 --- a/src/login/login.cpp +++ b/src/login/login.cpp @@ -13,7 +13,7 @@ struct Session { std::unordered_map sessions; -bool verifyHashWithPassword(std::string& const hash, std::string const& password) +bool verifyHashWithPassword(const std::string& hash, std::string const& password) { if (crypto_pwhash_str_verify(hash.c_str(), password.c_str(), password.size()) == 0) { return true; @@ -22,7 +22,7 @@ bool verifyHashWithPassword(std::string& const hash, std::string const& password } } -std::string hashPassword(std::string& const password) +std::string hashPassword(const std::string& password) { // Allocate storage for the hash char hash[crypto_pwhash_STRBYTES]; @@ -103,21 +103,19 @@ auto login_required = [](auto handler){ bool loginUser(const std::string& username, const std::string& password) { - auto sql = format("SELECT password_hash FROM users HERE username = '{}' LIMIT 1;", username); + auto sql = "SELECT id, password_hash FROM users WHERE username = '?' LIMIT 1;"; auto db = Database(); if (!db.open()) return false; - auto opt_str = db.getStr(sql.c_str()); - if (opt_str.has_value()) { - return verifyHashWithPassword(opt_str.value(), password); + auto str = db.getStr(sql, {username}); + if (!str.empty()) { + return verifyHashWithPassword(str, password); } else { return false; } } -} - bool initDB() { @@ -172,35 +170,27 @@ bool initLogin(crow::SimpleApp& app) CROW_ROUTE(app, "/login").methods("POST"_method) ([](const crow::request& req){ - auto cookie_it = req.get_header_value("Cookie").find("session_id="); - if (cookie_it == std::string::npos) - return crow::response(401, "No session"); - - // extract session_id - std::string session_id = req.get_header_value("Cookie").substr(cookie_it + 11, 32); - auto it = sessions.find(session_id); - if (it == sessions.end()) return crow::response(401, "Invalid session"); - - auto session = it->second; - // parse form auto body = crow::query_string(req.body); std::string csrf_token = body.get("csrf_token"); std::string username = body.get("username"); std::string password = body.get("password"); - if (csrf_token != session.csrf_token) return crow::response(403, "CSRF failed"); + if (csrf_token != csrf_token) return crow::response(403, "CSRF failed"); - bool ok = loginUser(); + bool ok = loginUser(username, password); if (!ok) return crow::response(401, "Invalid credentials"); + std::string session_id = generate_session_id(); // regenerate session, mark as logged in - sessions[session_id].user_id = generate_session_id(); // user ID + sessions[session_id].user_id; // user ID crow::response res; res.add_header("HX-Redirect", "/dashboard"); // htmx redirect return res; }); + + return true; } } \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index ff0affd..fbfb3b6 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -108,7 +108,10 @@ int main() { } shadowrun::initApi(app); - login::initLogin(app); + if(!login::initLogin(app)) + { + CROW_LOG_ERROR << "Failed to init Login API"; + } app.loglevel(crow::LogLevel::INFO); app.port(httpPort).multithreaded().run(); diff --git a/src/shadowrun/ShadowrunDb.cpp b/src/shadowrun/ShadowrunDb.cpp index ea6b3f1..75e3e45 100644 --- a/src/shadowrun/ShadowrunDb.cpp +++ b/src/shadowrun/ShadowrunDb.cpp @@ -99,12 +99,12 @@ std::set getCharacters(){ } std::map getCharacterData(int64_t characterKey) { - auto sql = format("SELECT name, value FROM shadowrun_data WHERE character_id = {};", characterKey); + std::string sql = "SELECT name, value FROM shadowrun_data WHERE character_id = ?;"; auto db = Database(); if (!db.open()) return std::map(); - return db.getStrMap(sql); + return db.getStrMap(sql, {characterKey}); } } \ No newline at end of file