From 3acb9721835123e88d9d393b82bca1276ef1c834 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Mon, 30 Jun 2025 22:28:19 -0500 Subject: [PATCH] server update --- common.hpp | 1 - epoll_socket.hpp | 1 - main.cpp | 12 +-- server.hpp | 235 ++++++++++++++++++++++++++++++++++++++++------- 4 files changed, 207 insertions(+), 42 deletions(-) diff --git a/common.hpp b/common.hpp index 6409c69..0a1c14a 100644 --- a/common.hpp +++ b/common.hpp @@ -2,7 +2,6 @@ #include #include -#include #define likely(x) __builtin_expect(!!(x), 1) #define unlikely(x) __builtin_expect(!!(x), 0) diff --git a/epoll_socket.hpp b/epoll_socket.hpp index fe198e6..012dd58 100644 --- a/epoll_socket.hpp +++ b/epoll_socket.hpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include diff --git a/main.cpp b/main.cpp index e551e8f..c851b0e 100644 --- a/main.cpp +++ b/main.cpp @@ -4,7 +4,7 @@ Server* server = nullptr; -void signal_handler(int sig) { +void signal_handler(int) { if (server) { std::cout << "\nShutting down server...\n"; server->store.save(); @@ -20,18 +20,18 @@ int main() { Router router; // Root route - router.get("/", [](const Request& req, Response& res) { + router.get("/", [](const Request&, Response& res) { res.set_text("Hello, World! HTTP Server with Router\n"); res.set_cookie("test_cookie", "hey there!"); }); // API routes - router.get("/api/status", [](const Request& req, Response& res) { + router.get("/api/status", [](const Request&, Response& res) { res.set_json("{\"status\":\"running\",\"version\":\"1.0\"}"); }); // Users routes - router.get("/users", [](const Request& req, Response& res) { + router.get("/users", [](const Request&, Response& res) { res.set_json("{\"users\":[{\"id\":1,\"name\":\"Alice\"},{\"id\":2,\"name\":\"Bob\"}]}"); }); @@ -56,11 +56,11 @@ int main() { res.set_text(info); }); - router.get("/foo", [](const Request& req, Response& res) { + router.get("/foo", [](const Request&, Response& res) { res.set_text("Admin"); }); - router.get("/admin/counter", [](const Request& req, Response& res) { + router.get("/admin/counter", [](const Request&, Response& res) { auto current = server->store.get("visit_count"); int count = current ? current.value() : 0; server->store.set("visit_count", count + 1); diff --git a/server.hpp b/server.hpp index b8afaed..d01a19b 100644 --- a/server.hpp +++ b/server.hpp @@ -7,7 +7,6 @@ #include "static_file_handler.hpp" #include "kv_store.hpp" #include "session_store.hpp" -#include #include #include #include @@ -15,6 +14,8 @@ #include #include #include +#include +#include class Server { @@ -51,7 +52,7 @@ public: // Initialize all worker sockets (SO_REUSEPORT allows this) for (unsigned int i = 0; i < num_cores; ++i) { - auto worker = std::make_unique(port_, router_, static_handler_, sessions); + auto worker = std::make_unique(port_, router_, static_handler_, sessions, keep_alive_timeout_); if (!worker->socket.start()) return false; workers_.push_back(std::move(worker)); } @@ -113,9 +114,31 @@ public: static_handler_ = std::make_shared(static_dir, url_prefix); } + // Set keep-alive timeout (default: 60 seconds) + void set_keep_alive_timeout(int seconds) + { + keep_alive_timeout_ = seconds; + } + private: static constexpr int BUFFER_SIZE = 65536; + static constexpr int MAX_HEADER_SIZE = 64 * 1024; + static constexpr int MAX_BODY_SIZE = 10 * 1024 * 1024; // 10MB default + static constexpr int DEFAULT_KEEP_ALIVE_TIMEOUT = 60; // seconds + std::shared_ptr static_handler_; + int keep_alive_timeout_ = DEFAULT_KEEP_ALIVE_TIMEOUT; + + // Connection state for HTTP/1.1 keep-alive + struct ConnectionState + { + std::string buffer; // Accumulated request data + std::chrono::steady_clock::time_point last_activity; + bool reading_body = false; + size_t expected_body_length = 0; + size_t current_body_length = 0; + bool chunked_encoding = false; + }; // Worker handles requests in a dedicated thread struct Worker @@ -125,13 +148,14 @@ private: std::shared_ptr& static_handler; SessionStore& sessions; std::array buffer; - std::string request_accumulator; // For handling partial requests + std::unordered_map connections; // Per-connection state std::thread thread; + int& keep_alive_timeout; - Worker(uint16_t port, Router& r, std::shared_ptr& sh, SessionStore& s) - : socket(port), router(r), static_handler(sh), sessions(s) + Worker(uint16_t port, Router& r, std::shared_ptr& sh, SessionStore& s, int& timeout) + : socket(port), router(r), static_handler(sh), sessions(s), keep_alive_timeout(timeout) { - // Set up event handlers for the new epoll socket library + // Set up event handlers for the epoll socket library socket.on_connection([this](int fd) { handle_connection(fd); }); socket.on_data([this](int fd) { handle_data(fd); }); socket.on_disconnect([this](int fd) { handle_disconnect(fd); }); @@ -139,12 +163,25 @@ private: void handle_connection(int client_fd) { - // New client connected - no action needed with new library - // The library handles connection tracking automatically + // Initialize connection state for HTTP/1.1 keep-alive + connections[client_fd] = ConnectionState{ + .buffer = "", + .last_activity = std::chrono::steady_clock::now(), + .reading_body = false, + .expected_body_length = 0, + .current_body_length = 0, + .chunked_encoding = false + }; } void handle_data(int client_fd) { + auto conn_it = connections.find(client_fd); + if (conn_it == connections.end()) return; + + auto& conn_state = conn_it->second; + conn_state.last_activity = std::chrono::steady_clock::now(); + // Read all available data using edge-triggered epoll while (true) { ssize_t bytes_read = read(client_fd, buffer.data(), buffer.size()); @@ -164,37 +201,120 @@ private: } // Accumulate request data - request_accumulator.append(buffer.data(), bytes_read); + conn_state.buffer.append(buffer.data(), bytes_read); - // Check if we have a complete HTTP request - size_t header_end = request_accumulator.find("\r\n\r\n"); - if (header_end != std::string::npos) { - // Process the complete request - process_request(client_fd, request_accumulator); - // Clear accumulator for next request (HTTP keep-alive) - request_accumulator.clear(); + // Prevent memory exhaustion from malicious large requests + if (conn_state.buffer.size() > MAX_HEADER_SIZE && !conn_state.reading_body) { + send_error_response(client_fd, "Request Header Too Large", 431); return; } - // Prevent memory exhaustion from malicious large headers - if (request_accumulator.size() > 64 * 1024) { - send_error_response(client_fd, "Request Header Too Large", 431); - request_accumulator.clear(); + if (conn_state.buffer.size() > MAX_BODY_SIZE) { + send_error_response(client_fd, "Request Entity Too Large", 413); return; } } + + // Process all complete requests in the buffer (HTTP/1.1 pipelining) + process_requests(client_fd, conn_state); } void handle_disconnect(int client_fd) { - // Client disconnected - cleanup handled by epoll library - // Clear any partial request data for this connection - request_accumulator.clear(); + // Clean up connection state + connections.erase(client_fd); } - void process_request(int client_fd, std::string_view request_data) + void process_requests(int client_fd, ConnectionState& conn_state) { - // Parse the HTTP request + while (!conn_state.buffer.empty()) { + if (!conn_state.reading_body) { + // Look for complete headers + size_t header_end = conn_state.buffer.find("\r\n\r\n"); + if (header_end == std::string::npos) { + // Headers not complete yet + return; + } + + // Parse just the headers to get Content-Length + std::string_view headers_only(conn_state.buffer.data(), header_end + 4); + Request temp_req = Parser::parse(headers_only); + + if (!temp_req.valid) { + send_error_response(client_fd, "Bad Request", 400); + conn_state.buffer.clear(); + return; + } + + // Check for chunked encoding + auto transfer_encoding = temp_req.headers.find("transfer-encoding"); + if (transfer_encoding != temp_req.headers.end()) { + std::string_view encoding = transfer_encoding->second; + if (encoding.find("chunked") != std::string_view::npos) { + conn_state.chunked_encoding = true; + } + } + + // Set up body reading state + conn_state.expected_body_length = temp_req.content_length; + conn_state.current_body_length = 0; + conn_state.reading_body = (conn_state.expected_body_length > 0 || conn_state.chunked_encoding); + + if (!conn_state.reading_body) { + // No body expected, process the request immediately + std::string_view complete_request(conn_state.buffer.data(), header_end + 4); + process_single_request(client_fd, complete_request); + + // Remove processed request from buffer + conn_state.buffer.erase(0, header_end + 4); + continue; + } + } + + if (conn_state.reading_body) { + if (conn_state.chunked_encoding) { + // Basic chunked encoding support (simplified) + size_t chunk_end = conn_state.buffer.find("\r\n0\r\n\r\n"); + if (chunk_end == std::string::npos) { + // Chunked body not complete yet + return; + } + + // Process complete chunked request + std::string_view complete_request(conn_state.buffer.data(), chunk_end + 7); + process_single_request(client_fd, complete_request); + + // Remove processed request from buffer + conn_state.buffer.erase(0, chunk_end + 7); + conn_state.reading_body = false; + conn_state.chunked_encoding = false; + } else { + // Fixed Content-Length body + size_t headers_end = conn_state.buffer.find("\r\n\r\n"); + if (headers_end == std::string::npos) return; + + size_t total_expected = headers_end + 4 + conn_state.expected_body_length; + + if (conn_state.buffer.size() < total_expected) { + // Body not complete yet + return; + } + + // Process complete request with body + std::string_view complete_request(conn_state.buffer.data(), total_expected); + process_single_request(client_fd, complete_request); + + // Remove processed request from buffer + conn_state.buffer.erase(0, total_expected); + conn_state.reading_body = false; + } + } + } + } + + void process_single_request(int client_fd, std::string_view request_data) + { + // Parse the complete HTTP request Request req = Parser::parse(request_data); if (!req.valid) { @@ -204,6 +324,16 @@ private: Response response; + // Check Connection header for keep-alive preference + bool client_wants_keepalive = true; + auto connection_header = req.headers.find("connection"); + if (connection_header != req.headers.end()) { + std::string_view conn_value = connection_header->second; + if (conn_value.find("close") != std::string_view::npos) { + client_wants_keepalive = false; + } + } + // Try to handle with router first if (router.handle(req, response)) { // Handle session management @@ -212,20 +342,20 @@ private: sessions.create() : std::string(existing_id); set_session_cookie(response, session_id); - send_response(client_fd, response, req.version); + send_response(client_fd, response, req.version, client_wants_keepalive); return; } // Try static file handler if available if (static_handler && static_handler->handle(req, response)) { - send_response(client_fd, response, req.version); + send_response(client_fd, response, req.version, client_wants_keepalive); return; } // No handler found - return 404 response.status = 404; response.set_text("Not Found"); - send_response(client_fd, response, req.version); + send_response(client_fd, response, req.version, client_wants_keepalive); } void set_session_cookie(Response& response, const std::string& session_id) @@ -235,25 +365,62 @@ private: "; HttpOnly; Path=/; SameSite=Strict"); } - void send_response(int client_fd, const Response& response, string_view version) + void send_response(int client_fd, Response& response, string_view version, bool keep_alive) { - // Build HTTP response string + // Set appropriate Connection header based on HTTP version and client preference + if (version == "HTTP/1.0") { + keep_alive = false; // HTTP/1.0 defaults to close + } + + if (!keep_alive) { + response.headers["Connection"] = "close"; + } else { + response.headers["Connection"] = "keep-alive"; + response.headers["Keep-Alive"] = "timeout=" + std::to_string(keep_alive_timeout); + } + + // Build and send HTTP response std::string http_response = ResponseBuilder::build_response(response, version); - // Use the new epoll library's buffered send method if (!socket.send_data(client_fd, http_response.data(), http_response.size())) { // Send failed - connection probably closed - // Library will handle cleanup automatically + return; + } + + // Close connection if not keeping alive + if (!keep_alive) { + // The epoll library will handle the actual close via EPOLLHUP + shutdown(client_fd, SHUT_WR); } } void send_error_response(int client_fd, const std::string& message, int status) { - // Fast path for error responses + // Fast path for error responses - always close connection on errors std::string response = ResponseBuilder::build_error_response(status, message); - // Use the new epoll library's send method socket.send_data(client_fd, response.data(), response.size()); + // Force connection close on errors + shutdown(client_fd, SHUT_WR); + } + + // Periodic cleanup of idle connections (call from main loop if needed) + void cleanup_idle_connections() + { + auto now = std::chrono::steady_clock::now(); + + for (auto it = connections.begin(); it != connections.end();) { + auto age = std::chrono::duration_cast( + now - it->second.last_activity).count(); + + if (age > keep_alive_timeout) { + // Connection is idle, close it + shutdown(it->first, SHUT_WR); + it = connections.erase(it); + } else { + ++it; + } + } } };