#pragma once #include "epoll_socket.hpp" #include "router.hpp" #include "parser.hpp" #include "response.hpp" #include "static_file_handler.hpp" #include "kv_store.hpp" #include "session_store.hpp" #include #include #include #include #include #include #include #include #include class Server { public: KeyValueStore store; SessionStore sessions; explicit Server(uint16_t port, Router& router) : port_(port), router_(router) {} ~Server() { stop(); store.save(); sessions.save(); // Wait for all worker threads to finish for (auto& worker : workers_) { if (worker->thread.joinable()) { worker->thread.join(); } } } bool start() { // Load persistent data store.load(); sessions.load(); // Create one worker per CPU core for optimal performance unsigned int num_cores = std::thread::hardware_concurrency(); if (num_cores == 0) num_cores = 1; workers_.reserve(num_cores); // Initialize all worker sockets (SO_REUSEPORT allows this) for (unsigned int i = 0; i < num_cores; ++i) { auto worker = std::make_unique(port_, router_, static_handler_, sessions, keep_alive_timeout_); if (!worker->socket.start()) return false; workers_.push_back(std::move(worker)); } // Start worker threads for (auto& worker : workers_) { worker->thread = std::thread([&worker]() { worker->socket.run(); }); } return true; } void run() { // Wait for all workers to complete for (auto& worker : workers_) { if (worker->thread.joinable()) { worker->thread.join(); } } } void stop() { // Signal all workers to stop for (auto& worker : workers_) { worker->socket.stop(); } } // Utility function to extract path segments static std::string get_path_param(string_view path, size_t segment_index = 0) { size_t start = 0; size_t current_segment = 0; while (start < path.length()) { if (path[start] == '/') { start++; continue; } size_t end = path.find('/', start); if (end == std::string_view::npos) end = path.length(); if (current_segment == segment_index) { return std::string(path.substr(start, end - start)); } current_segment++; start = end; } return ""; } // Enable static file serving void serve_static(const std::string& static_dir, const std::string& url_prefix = "") { static_handler_ = std::make_shared(static_dir, url_prefix); } // Set keep-alive timeout (default: 60 seconds) void set_keep_alive_timeout(int seconds) { keep_alive_timeout_ = seconds; } private: static constexpr int BUFFER_SIZE = 65536; static constexpr int MAX_HEADER_SIZE = 64 * 1024; static constexpr int MAX_BODY_SIZE = 10 * 1024 * 1024; // 10MB default static constexpr int DEFAULT_KEEP_ALIVE_TIMEOUT = 60; // seconds std::shared_ptr static_handler_; int keep_alive_timeout_ = DEFAULT_KEEP_ALIVE_TIMEOUT; // Connection state for HTTP/1.1 keep-alive struct ConnectionState { std::string buffer; // Accumulated request data std::chrono::steady_clock::time_point last_activity; bool reading_body = false; size_t expected_body_length = 0; size_t current_body_length = 0; bool chunked_encoding = false; }; // Worker handles requests in a dedicated thread struct Worker { EpollSocket socket; Router& router; std::shared_ptr& static_handler; SessionStore& sessions; std::array buffer; std::unordered_map connections; // Per-connection state std::thread thread; int& keep_alive_timeout; Worker(uint16_t port, Router& r, std::shared_ptr& sh, SessionStore& s, int& timeout) : socket(port), router(r), static_handler(sh), sessions(s), keep_alive_timeout(timeout) { // Set up event handlers for the epoll socket library socket.on_connection([this](int fd) { handle_connection(fd); }); socket.on_data([this](int fd) { handle_data(fd); }); socket.on_disconnect([this](int fd) { handle_disconnect(fd); }); } void handle_connection(int client_fd) { // Initialize connection state for HTTP/1.1 keep-alive connections[client_fd] = ConnectionState{ .buffer = "", .last_activity = std::chrono::steady_clock::now(), .reading_body = false, .expected_body_length = 0, .current_body_length = 0, .chunked_encoding = false }; } void handle_data(int client_fd) { auto conn_it = connections.find(client_fd); if (conn_it == connections.end()) return; auto& conn_state = conn_it->second; conn_state.last_activity = std::chrono::steady_clock::now(); // Read all available data using edge-triggered epoll while (true) { ssize_t bytes_read = read(client_fd, buffer.data(), buffer.size()); if (bytes_read == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK) { // No more data available break; } // Read error - connection will be closed by epoll library return; } if (bytes_read == 0) { // EOF - client closed connection return; } // Accumulate request data conn_state.buffer.append(buffer.data(), bytes_read); // Prevent memory exhaustion from malicious large requests if (conn_state.buffer.size() > MAX_HEADER_SIZE && !conn_state.reading_body) { send_error_response(client_fd, "Request Header Too Large", 431); return; } if (conn_state.buffer.size() > MAX_BODY_SIZE) { send_error_response(client_fd, "Request Entity Too Large", 413); return; } } // Process all complete requests in the buffer (HTTP/1.1 pipelining) process_requests(client_fd, conn_state); } void handle_disconnect(int client_fd) { // Clean up connection state connections.erase(client_fd); } void process_requests(int client_fd, ConnectionState& conn_state) { while (!conn_state.buffer.empty()) { if (!conn_state.reading_body) { // Look for complete headers size_t header_end = conn_state.buffer.find("\r\n\r\n"); if (header_end == std::string::npos) { // Headers not complete yet return; } // Parse just the headers to get Content-Length std::string_view headers_only(conn_state.buffer.data(), header_end + 4); Request temp_req = Parser::parse(headers_only); if (!temp_req.valid) { send_error_response(client_fd, "Bad Request", 400); conn_state.buffer.clear(); return; } // Check for chunked encoding auto transfer_encoding = temp_req.headers.find("transfer-encoding"); if (transfer_encoding != temp_req.headers.end()) { std::string_view encoding = transfer_encoding->second; if (encoding.find("chunked") != std::string_view::npos) { conn_state.chunked_encoding = true; } } // Set up body reading state conn_state.expected_body_length = temp_req.content_length; conn_state.current_body_length = 0; conn_state.reading_body = (conn_state.expected_body_length > 0 || conn_state.chunked_encoding); if (!conn_state.reading_body) { // No body expected, process the request immediately std::string_view complete_request(conn_state.buffer.data(), header_end + 4); process_single_request(client_fd, complete_request); // Remove processed request from buffer conn_state.buffer.erase(0, header_end + 4); continue; } } if (conn_state.reading_body) { if (conn_state.chunked_encoding) { // Basic chunked encoding support (simplified) size_t chunk_end = conn_state.buffer.find("\r\n0\r\n\r\n"); if (chunk_end == std::string::npos) { // Chunked body not complete yet return; } // Process complete chunked request std::string_view complete_request(conn_state.buffer.data(), chunk_end + 7); process_single_request(client_fd, complete_request); // Remove processed request from buffer conn_state.buffer.erase(0, chunk_end + 7); conn_state.reading_body = false; conn_state.chunked_encoding = false; } else { // Fixed Content-Length body size_t headers_end = conn_state.buffer.find("\r\n\r\n"); if (headers_end == std::string::npos) return; size_t total_expected = headers_end + 4 + conn_state.expected_body_length; if (conn_state.buffer.size() < total_expected) { // Body not complete yet return; } // Process complete request with body std::string_view complete_request(conn_state.buffer.data(), total_expected); process_single_request(client_fd, complete_request); // Remove processed request from buffer conn_state.buffer.erase(0, total_expected); conn_state.reading_body = false; } } } } void process_single_request(int client_fd, std::string_view request_data) { // Parse the complete HTTP request Request req = Parser::parse(request_data); if (!req.valid) { send_error_response(client_fd, "Bad Request", 400); return; } Response response; // Check Connection header for keep-alive preference bool client_wants_keepalive = true; auto connection_header = req.headers.find("connection"); if (connection_header != req.headers.end()) { std::string_view conn_value = connection_header->second; if (conn_value.find("close") != std::string_view::npos) { client_wants_keepalive = false; } } // Try to handle with router first if (router.handle(req, response)) { // Handle session management std::string_view existing_id = sessions.extract_session_id(req); std::string session_id = existing_id.empty() ? sessions.create() : std::string(existing_id); set_session_cookie(response, session_id); send_response(client_fd, response, req.version, client_wants_keepalive); return; } // Try static file handler if available if (static_handler && static_handler->handle(req, response)) { send_response(client_fd, response, req.version, client_wants_keepalive); return; } // No handler found - return 404 response.status = 404; response.set_text("Not Found"); send_response(client_fd, response, req.version, client_wants_keepalive); } void set_session_cookie(Response& response, const std::string& session_id) { // Set secure session cookie response.cookies.push_back("session_id=" + session_id + "; HttpOnly; Path=/; SameSite=Strict"); } void send_response(int client_fd, Response& response, string_view version, bool keep_alive) { // Set appropriate Connection header based on HTTP version and client preference if (version == "HTTP/1.0") { keep_alive = false; // HTTP/1.0 defaults to close } if (!keep_alive) { response.headers["Connection"] = "close"; } else { response.headers["Connection"] = "keep-alive"; response.headers["Keep-Alive"] = "timeout=" + std::to_string(keep_alive_timeout); } // Build and send HTTP response std::string http_response = ResponseBuilder::build_response(response, version); if (!socket.send_data(client_fd, http_response.data(), http_response.size())) { // Send failed - connection probably closed return; } // Close connection if not keeping alive if (!keep_alive) { // The epoll library will handle the actual close via EPOLLHUP shutdown(client_fd, SHUT_WR); } } void send_error_response(int client_fd, const std::string& message, int status) { // Fast path for error responses - always close connection on errors std::string response = ResponseBuilder::build_error_response(status, message); socket.send_data(client_fd, response.data(), response.size()); // Force connection close on errors shutdown(client_fd, SHUT_WR); } // Periodic cleanup of idle connections (call from main loop if needed) void cleanup_idle_connections() { auto now = std::chrono::steady_clock::now(); for (auto it = connections.begin(); it != connections.end();) { auto age = std::chrono::duration_cast( now - it->second.last_activity).count(); if (age > keep_alive_timeout) { // Connection is idle, close it shutdown(it->first, SHUT_WR); it = connections.erase(it); } else { ++it; } } } }; uint16_t port_; Router& router_; std::vector> workers_; };