diff --git a/epoll_socket.hpp b/epoll_socket.hpp
index 717f95e..fe198e6 100644
--- a/epoll_socket.hpp
+++ b/epoll_socket.hpp
@@ -9,8 +9,13 @@
 #include
 #include
 #include
+#include
+#include
+#include
+#include

-class EpollSocket {
+class EpollSocket
+{
 public:
     using ConnectionHandler = std::function<void(int)>;
     using DataHandler = std::function<void(int)>;
@@ -18,19 +23,21 @@ public:

     explicit EpollSocket(uint16_t port = 8080) : port_(port) {}

-    ~EpollSocket() {
-        if (server_fd_ != -1) close(server_fd_);
-        if (epoll_fd_ != -1) close(epoll_fd_);
+    ~EpollSocket()
+    {
+        shutdown();
     }

-    bool start() {
+    bool start()
+    {
         if (!create_server_socket()) return false;
         if (!create_epoll()) return false;
         if (!add_server_to_epoll()) return false;
         return true;
     }

-    void run() {
+    void run()
+    {
         std::array<epoll_event, MAX_EVENTS> events;

         while (running_) {
@@ -41,10 +48,20 @@
             }

             for (int i = 0; i < num_events; ++i) {
-                if (events[i].data.fd == server_fd_) {
+                const auto& event = events[i];
+                int fd = event.data.fd;
+
+                if (fd == server_fd_) {
                     accept_connections();
-                } else {
-                    handle_client_event(events[i].data.fd);
+                } else if (event.events & (EPOLLHUP | EPOLLERR)) {
+                    // Client disconnected or error occurred
+                    handle_disconnect(fd);
+                } else if (event.events & EPOLLIN) {
+                    // Data available to read
+                    handle_client_read(fd);
+                } else if (event.events & EPOLLOUT) {
+                    // Socket ready for writing
+                    handle_client_write(fd);
                 }
             }
         }
@@ -52,17 +69,42 @@ public:

     void stop() { running_ = false; }

-    bool add_client(int client_fd) {
-        epoll_event event{};
-        event.events = EPOLLIN | EPOLLET;
-        event.data.fd = client_fd;
+    // Send data to a client, returns true if queued successfully
+    bool send_data(int client_fd, const void* data, size_t len)
+    {
+        if (clients_.find(client_fd) == clients_.end()) {
+            return false; // Client not found
+        }

-        return epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, client_fd, &event) != -1;
-    }
+        auto& client = clients_[client_fd];

-    void remove_client(int client_fd) {
-        epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, client_fd, nullptr);
-        close(client_fd);
+        // Try immediate send if no pending data
+        if (client.write_buffer.empty()) {
+            ssize_t sent = send(client_fd, data, len, MSG_NOSIGNAL);
+            if (sent == static_cast<ssize_t>(len)) {
+                return true; // All data sent immediately
+            }
+            if (sent > 0) {
+                // Partial send, queue the rest
+                const char* remaining = static_cast<const char*>(data) + sent;
+                client.write_buffer.insert(client.write_buffer.end(),
+                                           remaining, remaining + (len - sent));
+            } else if (sent == -1 && errno != EAGAIN && errno != EWOULDBLOCK) {
+                return false; // Real error
+            } else {
+                // EAGAIN/EWOULDBLOCK, queue all data
+                const char* bytes = static_cast<const char*>(data);
+                client.write_buffer.insert(client.write_buffer.end(), bytes, bytes + len);
+            }
+        } else {
+            // Already have pending data, just queue
+            const char* bytes = static_cast<const char*>(data);
+            client.write_buffer.insert(client.write_buffer.end(), bytes, bytes + len);
+        }
+
+        // Enable EPOLLOUT to get notified when we can write
+        enable_write_events(client_fd);
+        return true;
     }

     // Event handlers
@@ -70,28 +112,43 @@ public:
     void on_data(DataHandler handler) { on_data_ = std::move(handler); }
     void on_disconnect(DisconnectHandler handler) { on_disconnect_ = std::move(handler); }

+    // Get number of active connections
+    size_t connection_count() const { return clients_.size(); }
+
 private:
     static constexpr int MAX_EVENTS = 1024;
+    static constexpr size_t MAX_WRITE_BUFFER = 64 * 1024; // 64KB max buffer per client
+
+    struct ClientInfo
+    {
+        std::vector<char> write_buffer;
+        bool write_enabled = false;
+    };

     uint16_t port_;
     int server_fd_ = -1;
     int epoll_fd_ = -1;
     bool running_ = true;
+    std::unordered_map<int, ClientInfo> clients_;
+
     ConnectionHandler on_connection_;
     DataHandler on_data_;
     DisconnectHandler on_disconnect_;

-    bool create_server_socket() {
+    bool create_server_socket()
+    {
+        // Create socket with non-blocking and close-on-exec flags
         server_fd_ = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0);
         if (server_fd_ == -1) return false;

         int opt = 1;
+        // Allow address reuse to avoid "Address already in use" errors
         if (setsockopt(server_fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) {
             return false;
         }

-        // Enable port reuse for load balancing
+        // Enable port reuse for load balancing across multiple processes
         if (setsockopt(server_fd_, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)) == -1) {
             return false;
         }
@@ -110,38 +167,105 @@ private:
             return false;
         }

+        // Use SOMAXCONN for maximum backlog
         return listen(server_fd_, SOMAXCONN) != -1;
     }

-    bool create_epoll() {
+    bool create_epoll()
+    {
+        // Create epoll instance with close-on-exec flag
         epoll_fd_ = epoll_create1(EPOLL_CLOEXEC);
         return epoll_fd_ != -1;
     }

-    bool add_server_to_epoll() {
+    bool add_server_to_epoll()
+    {
         epoll_event event{};
-        event.events = EPOLLIN | EPOLLET;
+        event.events = EPOLLIN | EPOLLET; // Edge-triggered for better performance
         event.data.fd = server_fd_;

         return epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, server_fd_, &event) != -1;
     }

-    inline void accept_connections() {
-        while (true) {
+    bool add_client(int client_fd)
+    {
+        epoll_event event{};
+        event.events = EPOLLIN | EPOLLET; // Start with read events only
+        event.data.fd = client_fd;
+
+        if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, client_fd, &event) == -1) {
+            return false;
+        }
+
+        // Track client connection
+        clients_[client_fd] = ClientInfo{};
+        return true;
+    }
+
+    void remove_client(int client_fd)
+    {
+        // Remove from epoll (ignore errors as fd might be closed)
+        epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, client_fd, nullptr);
+
+        // Remove from client tracking
+        clients_.erase(client_fd);
+
+        // Close the socket
+        close(client_fd);
+    }
+
+    void enable_write_events(int client_fd)
+    {
+        auto it = clients_.find(client_fd);
+        if (it == clients_.end() || it->second.write_enabled) {
+            return; // Client not found or write already enabled
+        }
+
+        epoll_event event{};
+        event.events = EPOLLIN | EPOLLOUT | EPOLLET;
+        event.data.fd = client_fd;
+
+        if (epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, client_fd, &event) != -1) {
+            it->second.write_enabled = true;
+        }
+    }
+
+    void disable_write_events(int client_fd)
+    {
+        auto it = clients_.find(client_fd);
+        if (it == clients_.end() || !it->second.write_enabled) {
+            return; // Client not found or write already disabled
+        }
+
+        epoll_event event{};
+        event.events = EPOLLIN | EPOLLET;
+        event.data.fd = client_fd;
+
+        if (epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, client_fd, &event) != -1) {
+            it->second.write_enabled = false;
+        }
+    }
+
+    inline void accept_connections()
+    {
+        // Accept all pending connections in edge-triggered mode
+        while (running_) {
             sockaddr_in client_addr{};
             socklen_t client_len = sizeof(client_addr);

-            // Use accept4 to set non-blocking atomically
+            // Use accept4 to set non-blocking and close-on-exec atomically
             int client_fd = accept4(server_fd_, reinterpret_cast<sockaddr*>(&client_addr),
                                     &client_len, SOCK_NONBLOCK | SOCK_CLOEXEC);

             if (client_fd == -1) {
-                if (errno == EAGAIN || errno == EWOULDBLOCK) break;
-                continue;
+                if (errno == EAGAIN || errno == EWOULDBLOCK) {
+                    break; // No more connections to accept
+                }
+                continue; // Skip this connection on other errors
             }
-            // Set TCP_NODELAY for client connections
+            // Set TCP_NODELAY for client connections (low latency)
             int opt = 1;
             setsockopt(client_fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt));

@@ -153,7 +277,105 @@ private:
         }
     }

-    inline void handle_client_event(int client_fd) {
+    inline void handle_client_read(int client_fd)
+    {
+        // Check if client is still tracked
+        if (clients_.find(client_fd) == clients_.end()) {
+            return;
+        }
+
+        // Use MSG_PEEK to check connection status without consuming data
+        char peek_buffer[1];
+        ssize_t peek_result = recv(client_fd, peek_buffer, 1, MSG_PEEK | MSG_DONTWAIT);
+
+        if (peek_result == 0) {
+            // Clean disconnect (FIN received)
+            handle_disconnect(client_fd);
+            return;
+        }
+
+        if (peek_result == -1) {
+            if (errno == EAGAIN || errno == EWOULDBLOCK) {
+                // False alarm, no data available
+                return;
+            }
+            // Connection error
+            handle_disconnect(client_fd);
+            return;
+        }
+
+        // Data is available, notify handler
         if (on_data_) on_data_(client_fd);
     }
-};
\ No newline at end of file
+
+    inline void handle_client_write(int client_fd)
+    {
+        auto it = clients_.find(client_fd);
+        if (it == clients_.end()) {
+            return; // Client no longer exists
+        }
+
+        auto& client = it->second;
+
+        // Send as much buffered data as possible
+        while (!client.write_buffer.empty()) {
+            ssize_t sent = send(client_fd, client.write_buffer.data(),
+                                client.write_buffer.size(), MSG_NOSIGNAL);
+
+            if (sent == -1) {
+                if (errno == EAGAIN || errno == EWOULDBLOCK) {
+                    break; // Socket buffer full, try again later
+                }
+                // Connection error
+                handle_disconnect(client_fd);
+                return;
+            }
+
+            if (sent == 0) {
+                // A zero return from send() is unexpected here; retry on the next EPOLLOUT
+                break;
+            }
+
+            // Remove sent data from buffer
+            client.write_buffer.erase(client.write_buffer.begin(),
+                                      client.write_buffer.begin() + sent);
+        }
+
+        // If buffer is empty, disable write events to avoid busy-waiting
+        if (client.write_buffer.empty()) {
+            disable_write_events(client_fd);
+        }
+    }
+
+    inline void handle_disconnect(int client_fd)
+    {
+        // Notify application before removing client
+        if (on_disconnect_) on_disconnect_(client_fd);
+
+        // Clean up client resources
+        remove_client(client_fd);
+    }
+
+    void shutdown()
+    {
+        running_ = false;
+
+        // Close all client connections gracefully
+        for (const auto& [fd, client] : clients_) {
+            close(fd);
+        }
+        clients_.clear();
+
+        // Close server socket
+        if (server_fd_ != -1) {
+            close(server_fd_);
+            server_fd_ = -1;
+        }
+
+        // Close epoll instance
+        if (epoll_fd_ != -1) {
+            close(epoll_fd_);
+            epoll_fd_ = -1;
+        }
+    }
+};
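For reference, a minimal usage sketch of the EpollSocket API introduced above, wired as a standalone echo server. It is illustrative only and not part of the diff: the port, the 4 KB buffer, and the echo behaviour are assumptions. The read loop matches the edge-triggered design of the class, since handle_client_read only peeks and leaves draining the socket to the on_data handler.

// echo_example.cpp -- sketch only; builds against epoll_socket.hpp as changed above.
#include "epoll_socket.hpp"

#include <array>
#include <cstddef>
#include <cstdio>
#include <unistd.h>   // read()

int main()
{
    EpollSocket sock(9090);  // hypothetical port, default is 8080

    sock.on_connection([](int fd) {
        std::printf("client %d connected\n", fd);
    });

    // on_data only signals readability; the handler drains the socket itself.
    sock.on_data([&sock](int fd) {
        std::array<char, 4096> buf;
        while (true) {
            ssize_t n = read(fd, buf.data(), buf.size());
            if (n <= 0) break;  // EAGAIN, EOF, or error
            sock.send_data(fd, buf.data(), static_cast<std::size_t>(n));  // echo back
        }
    });

    sock.on_disconnect([](int fd) {
        std::printf("client %d disconnected\n", fd);
    });

    if (!sock.start()) return 1;
    sock.run();  // blocks until stop() is called (e.g. from a signal handler)
    return 0;
}
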
diff --git a/server.hpp b/server.hpp
index 6c7e9a9..b8afaed 100644
--- a/server.hpp
+++ b/server.hpp
@@ -16,17 +16,20 @@
 #include
 #include

-class Server {
+class Server
+{
 public:
     KeyValueStore store;
     SessionStore sessions;

     explicit Server(uint16_t port, Router& router) : port_(port), router_(router) {}

-    ~Server() {
+    ~Server()
+    {
         stop();
         store.save();
         sessions.save();
+        // Wait for all worker threads to finish
         for (auto& worker : workers_) {
             if (worker->thread.joinable()) {
                 worker->thread.join();
@@ -34,21 +37,26 @@ public:
             }
         }
     }

-    bool start() {
+    bool start()
+    {
+        // Load persistent data
         store.load();
         sessions.load();

+        // Create one worker per CPU core for optimal performance
         unsigned int num_cores = std::thread::hardware_concurrency();
         if (num_cores == 0) num_cores = 1;

         workers_.reserve(num_cores);
+        // Initialize all worker sockets (SO_REUSEPORT allows this)
         for (unsigned int i = 0; i < num_cores; ++i) {
             auto worker = std::make_unique<Worker>(port_, router_, static_handler_,
                                                    sessions);
             if (!worker->socket.start()) return false;
             workers_.push_back(std::move(worker));
         }

+        // Start worker threads
         for (auto& worker : workers_) {
             worker->thread = std::thread([&worker]() { worker->socket.run(); });
         }
@@ -56,7 +64,9 @@ public:
         return true;
     }

-    void run() {
+    void run()
+    {
+        // Wait for all workers to complete
         for (auto& worker : workers_) {
             if (worker->thread.joinable()) {
                 worker->thread.join();
@@ -64,13 +74,17 @@ public:
             }
         }
     }

-    void stop() {
+    void stop()
+    {
+        // Signal all workers to stop
         for (auto& worker : workers_) {
             worker->socket.stop();
         }
     }

-    static std::string get_path_param(string_view path, size_t segment_index = 0) {
+    // Utility function to extract path segments
+    static std::string get_path_param(string_view path, size_t segment_index = 0)
+    {
         size_t start = 0;
         size_t current_segment = 0;

@@ -93,7 +107,9 @@ public:
         return "";
     }

-    void serve_static(const std::string& static_dir, const std::string& url_prefix = "") {
+    // Enable static file serving
+    void serve_static(const std::string& static_dir, const std::string& url_prefix = "")
+    {
         static_handler_ = std::make_shared(static_dir, url_prefix);
     }

@@ -101,103 +117,143 @@ private:
     static constexpr int BUFFER_SIZE = 65536;
     std::shared_ptr static_handler_;

-    struct Worker {
+    // Worker handles requests in a dedicated thread
+    struct Worker
+    {
         EpollSocket socket;
         Router& router;
         std::shared_ptr& static_handler;
         SessionStore& sessions;
         std::array<char, BUFFER_SIZE> buffer;
+        std::string request_accumulator; // For handling partial requests
         std::thread thread;

         Worker(uint16_t port, Router& r, std::shared_ptr& sh, SessionStore& s)
-            : socket(port), router(r), static_handler(sh), sessions(s) {
+            : socket(port), router(r), static_handler(sh), sessions(s)
+        {
+            // Set up event handlers for the new epoll socket library
             socket.on_connection([this](int fd) { handle_connection(fd); });
             socket.on_data([this](int fd) { handle_data(fd); });
             socket.on_disconnect([this](int fd) { handle_disconnect(fd); });
         }

-        void handle_connection(int client_fd) {
-            // Client connected
+        void handle_connection(int client_fd)
+        {
+            // New client connected - no action needed with new library
+            // The library handles connection tracking automatically
         }

-        void handle_data(int client_fd) {
+        void handle_data(int client_fd)
+        {
+            // Read all available data using edge-triggered epoll
             while (true) {
                 ssize_t bytes_read = read(client_fd, buffer.data(), buffer.size());

                 if (bytes_read == -1) {
-                    if (errno == EAGAIN || errno == EWOULDBLOCK) break;
+                    if (errno == EAGAIN || errno == EWOULDBLOCK) {
+                        // No more data available
+                        break;
+                    }
+                    // Read error - connection will be closed by epoll library
                     return;
                 }

-                if (bytes_read == 0) return;
+                if (bytes_read == 0) {
+                    // EOF - client closed connection
+                    return;
+                }

-                std::string request_data(buffer.data(), bytes_read);
-                if (request_data.find("\r\n\r\n") != std::string::npos) {
-                    process_request(client_fd, request_data);
+                // Accumulate request data
+                request_accumulator.append(buffer.data(), bytes_read);
+
+                // Check if we have a complete HTTP request
+                size_t header_end = request_accumulator.find("\r\n\r\n");
+                if (header_end != std::string::npos) {
+                    // Process the complete request
+                    process_request(client_fd, request_accumulator);
+                    // Clear accumulator for next request (HTTP keep-alive)
+                    request_accumulator.clear();
+                    return;
+                }
+
+                // Prevent memory exhaustion from malicious large headers
+                if (request_accumulator.size() > 64 * 1024) {
+                    send_error_response(client_fd, "Request Header Fields Too Large", 431);
+                    request_accumulator.clear();
+                    return;
                 }
             }
         }

-        void handle_disconnect(int client_fd) {
-            // Client disconnected
+        void handle_disconnect(int client_fd)
+        {
+            // Client disconnected - cleanup handled by epoll library
+            // Clear any partial request data for this connection
+            request_accumulator.clear();
         }

-        void process_request(int client_fd, std::string_view request_data) {
-            Request req = Parser::parse(request_data);
+        void process_request(int client_fd, std::string_view request_data)
+        {
+            // Parse the HTTP request
+            Request req = Parser::parse(request_data);

-            if (!req.valid) {
-                send_error_response(client_fd, "Bad Request", 400, req.version);
-                return;
-            }
+            if (!req.valid) {
+                send_error_response(client_fd, "Bad Request", 400);
+                return;
+            }

-            Response response;
+            Response response;

-            if (router.handle(req, response)) {
-                std::string_view existing_id = sessions.extract_session_id(req);
-                std::string session_id = existing_id.empty() ?
-                    sessions.create() : std::string(existing_id);
-
-                set_session_cookie(response, session_id);
+            // Try to handle with router first
+            if (router.handle(req, response)) {
+                // Handle session management
+                std::string_view existing_id = sessions.extract_session_id(req);
+                std::string session_id = existing_id.empty() ?
+                    sessions.create() : std::string(existing_id);
+
+                set_session_cookie(response, session_id);
+                send_response(client_fd, response, req.version);
+                return;
+            }
+
+            // Try static file handler if available
+            if (static_handler && static_handler->handle(req, response)) {
+                send_response(client_fd, response, req.version);
+                return;
+            }
+
+            // No handler found - return 404
+            response.status = 404;
+            response.set_text("Not Found");
             send_response(client_fd, response, req.version);
-            return;
         }

-            if (static_handler && static_handler->handle(req, response)) {
-                send_response(client_fd, response, req.version);
-                return;
+        void set_session_cookie(Response& response, const std::string& session_id)
+        {
+            // Set secure session cookie
+            response.cookies.push_back("session_id=" + session_id +
+                "; HttpOnly; Path=/; SameSite=Strict");
            }

-            response.status = 404;
-            response.set_text("Not Found");
-            send_response(client_fd, response, req.version);
-        }
-
-        void set_session_cookie(Response& response, const std::string& session_id) {
-            response.cookies.push_back("session_id=" + session_id + "; HttpOnly; Path=/; SameSite=Strict");
-        }
-
-        void send_response(int client_fd, const Response& response, string_view version) {
+        void send_response(int client_fd, const Response& response, string_view version)
+        {
+            // Build HTTP response string
             std::string http_response = ResponseBuilder::build_response(response, version);
-            send_raw_response(client_fd, http_response);
-        }

-        void send_raw_response(int client_fd, const std::string& response) {
-            ssize_t total_sent = 0;
-            ssize_t response_len = response.size();
-
-            while (total_sent < response_len) {
-                ssize_t sent = write(client_fd, response.data() + total_sent, response_len - total_sent);
-                if (sent == -1) {
-                    if (errno == EAGAIN || errno == EWOULDBLOCK) continue;
-                    break;
-                }
-                total_sent += sent;
+            // Use the new epoll library's buffered send method
+            if (!socket.send_data(client_fd, http_response.data(), http_response.size())) {
+                // Send failed - connection probably closed
+                // Library will handle cleanup automatically
             }
         }

-        void send_error_response(int client_fd, const std::string& message, int status, string_view version) {
-            std::string response = ResponseBuilder::build_error_response(status, message, version);
-            send_raw_response(client_fd, response);
+        void send_error_response(int client_fd, const std::string& message, int status)
+        {
+            // Fast path for error responses
+            std::string response = ResponseBuilder::build_error_response(status, message);
+
+            // Use the new epoll library's send method
+            socket.send_data(client_fd, response.data(), response.size());
         }
     };
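A note on the new constant: MAX_WRITE_BUFFER is added in epoll_socket.hpp with a "64KB max buffer per client" comment but is not referenced in any hunk shown, so send_data currently queues without bound. If the cap is meant to be enforced, a guard along these lines could run before queueing; this is a sketch only, and the helper name is made up.

#include <cstddef>
#include <vector>

// Hypothetical helper mirroring ClientInfo::write_buffer and MAX_WRITE_BUFFER:
// returns true only if queueing `len` more bytes stays within the cap, so
// send_data() could return false (back-pressure) instead of growing forever.
inline bool within_write_cap(const std::vector<char>& write_buffer,
                             std::size_t len, std::size_t cap)
{
    return write_buffer.size() + len <= cap;
}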
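Also worth noting in server.hpp: request_accumulator lives on the Worker, so every connection served by that worker shares one buffer. If per-connection buffering is wanted, one possible shape is sketched below; the struct and its methods are hypothetical and not part of this diff.

#include <cstddef>
#include <string>
#include <unordered_map>

// Hypothetical per-connection request buffers, keyed by client fd.
struct RequestBuffers
{
    std::unordered_map<int, std::string> by_fd;

    // Append freshly read bytes for one connection and return its buffer.
    std::string& append(int fd, const char* data, std::size_t len)
    {
        std::string& buf = by_fd[fd];
        buf.append(data, len);
        return buf;
    }

    // Drop state when a connection closes (e.g. from Worker::handle_disconnect).
    void erase(int fd) { by_fd.erase(fd); }
};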