cpp_server/server.hpp
2025-06-30 22:28:19 -05:00

431 lines
12 KiB
C++

#pragma once
#include "epoll_socket.hpp"
#include "router.hpp"
#include "parser.hpp"
#include "response.hpp"
#include "static_file_handler.hpp"
#include "kv_store.hpp"
#include "session_store.hpp"
#include <string.h>
#include <string_view>
#include <array>
#include <string>
#include <thread>
#include <vector>
#include <memory>
#include <chrono>
#include <unordered_map>
class Server
{
public:
KeyValueStore store;
SessionStore sessions;
explicit Server(uint16_t port, Router& router) : port_(port), router_(router) {}
~Server()
{
stop();
store.save();
sessions.save();
// Wait for all worker threads to finish
for (auto& worker : workers_) {
if (worker->thread.joinable()) {
worker->thread.join();
}
}
}
bool start()
{
// Load persistent data
store.load();
sessions.load();
// Create one worker per CPU core for optimal performance
unsigned int num_cores = std::thread::hardware_concurrency();
if (num_cores == 0) num_cores = 1;
workers_.reserve(num_cores);
// Initialize all worker sockets (SO_REUSEPORT allows this)
for (unsigned int i = 0; i < num_cores; ++i) {
auto worker = std::make_unique<Worker>(port_, router_, static_handler_, sessions, keep_alive_timeout_);
if (!worker->socket.start()) return false;
workers_.push_back(std::move(worker));
}
// Start worker threads
for (auto& worker : workers_) {
worker->thread = std::thread([&worker]() { worker->socket.run(); });
}
return true;
}
void run()
{
// Wait for all workers to complete
for (auto& worker : workers_) {
if (worker->thread.joinable()) {
worker->thread.join();
}
}
}
void stop()
{
// Signal all workers to stop
for (auto& worker : workers_) {
worker->socket.stop();
}
}
// Utility function to extract path segments
static std::string get_path_param(string_view path, size_t segment_index = 0)
{
size_t start = 0;
size_t current_segment = 0;
while (start < path.length()) {
if (path[start] == '/') {
start++;
continue;
}
size_t end = path.find('/', start);
if (end == std::string_view::npos) end = path.length();
if (current_segment == segment_index) {
return std::string(path.substr(start, end - start));
}
current_segment++;
start = end;
}
return "";
}
// Enable static file serving
void serve_static(const std::string& static_dir, const std::string& url_prefix = "")
{
static_handler_ = std::make_shared<StaticFileHandler>(static_dir, url_prefix);
}
// Set keep-alive timeout (default: 60 seconds)
void set_keep_alive_timeout(int seconds)
{
keep_alive_timeout_ = seconds;
}
private:
static constexpr int BUFFER_SIZE = 65536;
static constexpr int MAX_HEADER_SIZE = 64 * 1024;
static constexpr int MAX_BODY_SIZE = 10 * 1024 * 1024; // 10MB default
static constexpr int DEFAULT_KEEP_ALIVE_TIMEOUT = 60; // seconds
std::shared_ptr<StaticFileHandler> static_handler_;
int keep_alive_timeout_ = DEFAULT_KEEP_ALIVE_TIMEOUT;
// Connection state for HTTP/1.1 keep-alive
struct ConnectionState
{
std::string buffer; // Accumulated request data
std::chrono::steady_clock::time_point last_activity;
bool reading_body = false;
size_t expected_body_length = 0;
size_t current_body_length = 0;
bool chunked_encoding = false;
};
// Worker handles requests in a dedicated thread
struct Worker
{
EpollSocket socket;
Router& router;
std::shared_ptr<StaticFileHandler>& static_handler;
SessionStore& sessions;
std::array<char, BUFFER_SIZE> buffer;
std::unordered_map<int, ConnectionState> connections; // Per-connection state
std::thread thread;
int& keep_alive_timeout;
Worker(uint16_t port, Router& r, std::shared_ptr<StaticFileHandler>& sh, SessionStore& s, int& timeout)
: socket(port), router(r), static_handler(sh), sessions(s), keep_alive_timeout(timeout)
{
// Set up event handlers for the epoll socket library
socket.on_connection([this](int fd) { handle_connection(fd); });
socket.on_data([this](int fd) { handle_data(fd); });
socket.on_disconnect([this](int fd) { handle_disconnect(fd); });
}
void handle_connection(int client_fd)
{
// Initialize connection state for HTTP/1.1 keep-alive
connections[client_fd] = ConnectionState{
.buffer = "",
.last_activity = std::chrono::steady_clock::now(),
.reading_body = false,
.expected_body_length = 0,
.current_body_length = 0,
.chunked_encoding = false
};
}
void handle_data(int client_fd)
{
auto conn_it = connections.find(client_fd);
if (conn_it == connections.end()) return;
auto& conn_state = conn_it->second;
conn_state.last_activity = std::chrono::steady_clock::now();
// Read all available data using edge-triggered epoll
while (true) {
ssize_t bytes_read = read(client_fd, buffer.data(), buffer.size());
if (bytes_read == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// No more data available
break;
}
// Read error - connection will be closed by epoll library
return;
}
if (bytes_read == 0) {
// EOF - client closed connection
return;
}
// Accumulate request data
conn_state.buffer.append(buffer.data(), bytes_read);
// Prevent memory exhaustion from malicious large requests
if (conn_state.buffer.size() > MAX_HEADER_SIZE && !conn_state.reading_body) {
send_error_response(client_fd, "Request Header Too Large", 431);
return;
}
if (conn_state.buffer.size() > MAX_BODY_SIZE) {
send_error_response(client_fd, "Request Entity Too Large", 413);
return;
}
}
// Process all complete requests in the buffer (HTTP/1.1 pipelining)
process_requests(client_fd, conn_state);
}
void handle_disconnect(int client_fd)
{
// Clean up connection state
connections.erase(client_fd);
}
void process_requests(int client_fd, ConnectionState& conn_state)
{
while (!conn_state.buffer.empty()) {
if (!conn_state.reading_body) {
// Look for complete headers
size_t header_end = conn_state.buffer.find("\r\n\r\n");
if (header_end == std::string::npos) {
// Headers not complete yet
return;
}
// Parse just the headers to get Content-Length
std::string_view headers_only(conn_state.buffer.data(), header_end + 4);
Request temp_req = Parser::parse(headers_only);
if (!temp_req.valid) {
send_error_response(client_fd, "Bad Request", 400);
conn_state.buffer.clear();
return;
}
// Check for chunked encoding
auto transfer_encoding = temp_req.headers.find("transfer-encoding");
if (transfer_encoding != temp_req.headers.end()) {
std::string_view encoding = transfer_encoding->second;
if (encoding.find("chunked") != std::string_view::npos) {
conn_state.chunked_encoding = true;
}
}
// Set up body reading state
conn_state.expected_body_length = temp_req.content_length;
conn_state.current_body_length = 0;
conn_state.reading_body = (conn_state.expected_body_length > 0 || conn_state.chunked_encoding);
if (!conn_state.reading_body) {
// No body expected, process the request immediately
std::string_view complete_request(conn_state.buffer.data(), header_end + 4);
process_single_request(client_fd, complete_request);
// Remove processed request from buffer
conn_state.buffer.erase(0, header_end + 4);
continue;
}
}
if (conn_state.reading_body) {
if (conn_state.chunked_encoding) {
// Basic chunked encoding support (simplified)
size_t chunk_end = conn_state.buffer.find("\r\n0\r\n\r\n");
if (chunk_end == std::string::npos) {
// Chunked body not complete yet
return;
}
// Process complete chunked request
std::string_view complete_request(conn_state.buffer.data(), chunk_end + 7);
process_single_request(client_fd, complete_request);
// Remove processed request from buffer
conn_state.buffer.erase(0, chunk_end + 7);
conn_state.reading_body = false;
conn_state.chunked_encoding = false;
} else {
// Fixed Content-Length body
size_t headers_end = conn_state.buffer.find("\r\n\r\n");
if (headers_end == std::string::npos) return;
size_t total_expected = headers_end + 4 + conn_state.expected_body_length;
if (conn_state.buffer.size() < total_expected) {
// Body not complete yet
return;
}
// Process complete request with body
std::string_view complete_request(conn_state.buffer.data(), total_expected);
process_single_request(client_fd, complete_request);
// Remove processed request from buffer
conn_state.buffer.erase(0, total_expected);
conn_state.reading_body = false;
}
}
}
}
void process_single_request(int client_fd, std::string_view request_data)
{
// Parse the complete HTTP request
Request req = Parser::parse(request_data);
if (!req.valid) {
send_error_response(client_fd, "Bad Request", 400);
return;
}
Response response;
// Check Connection header for keep-alive preference
bool client_wants_keepalive = true;
auto connection_header = req.headers.find("connection");
if (connection_header != req.headers.end()) {
std::string_view conn_value = connection_header->second;
if (conn_value.find("close") != std::string_view::npos) {
client_wants_keepalive = false;
}
}
// Try to handle with router first
if (router.handle(req, response)) {
// Handle session management
std::string_view existing_id = sessions.extract_session_id(req);
std::string session_id = existing_id.empty() ?
sessions.create() : std::string(existing_id);
set_session_cookie(response, session_id);
send_response(client_fd, response, req.version, client_wants_keepalive);
return;
}
// Try static file handler if available
if (static_handler && static_handler->handle(req, response)) {
send_response(client_fd, response, req.version, client_wants_keepalive);
return;
}
// No handler found - return 404
response.status = 404;
response.set_text("Not Found");
send_response(client_fd, response, req.version, client_wants_keepalive);
}
void set_session_cookie(Response& response, const std::string& session_id)
{
// Set secure session cookie
response.cookies.push_back("session_id=" + session_id +
"; HttpOnly; Path=/; SameSite=Strict");
}
void send_response(int client_fd, Response& response, string_view version, bool keep_alive)
{
// Set appropriate Connection header based on HTTP version and client preference
if (version == "HTTP/1.0") {
keep_alive = false; // HTTP/1.0 defaults to close
}
if (!keep_alive) {
response.headers["Connection"] = "close";
} else {
response.headers["Connection"] = "keep-alive";
response.headers["Keep-Alive"] = "timeout=" + std::to_string(keep_alive_timeout);
}
// Build and send HTTP response
std::string http_response = ResponseBuilder::build_response(response, version);
if (!socket.send_data(client_fd, http_response.data(), http_response.size())) {
// Send failed - connection probably closed
return;
}
// Close connection if not keeping alive
if (!keep_alive) {
// The epoll library will handle the actual close via EPOLLHUP
shutdown(client_fd, SHUT_WR);
}
}
void send_error_response(int client_fd, const std::string& message, int status)
{
// Fast path for error responses - always close connection on errors
std::string response = ResponseBuilder::build_error_response(status, message);
socket.send_data(client_fd, response.data(), response.size());
// Force connection close on errors
shutdown(client_fd, SHUT_WR);
}
// Periodic cleanup of idle connections (call from main loop if needed)
void cleanup_idle_connections()
{
auto now = std::chrono::steady_clock::now();
for (auto it = connections.begin(); it != connections.end();) {
auto age = std::chrono::duration_cast<std::chrono::seconds>(
now - it->second.last_activity).count();
if (age > keep_alive_timeout) {
// Connection is idle, close it
shutdown(it->first, SHUT_WR);
it = connections.erase(it);
} else {
++it;
}
}
}
};
uint16_t port_;
Router& router_;
std::vector<std::unique_ptr<Worker>> workers_;
};