commit 78780555f00e8c3a9556182fa83295d55d65601a Author: Sky Johnson Date: Thu Jun 12 19:01:53 2025 -0500 first commit diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..5402af2 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,58 @@ +{ + "files.associations": { + "*.mllo.php": "blade", + "*.template": "blade", + "array": "cpp", + "atomic": "cpp", + "bit": "cpp", + "cctype": "cpp", + "clocale": "cpp", + "cmath": "cpp", + "compare": "cpp", + "concepts": "cpp", + "cstdarg": "cpp", + "cstddef": "cpp", + "cstdint": "cpp", + "cstdio": "cpp", + "cstdlib": "cpp", + "ctime": "cpp", + "cwchar": "cpp", + "cwctype": "cpp", + "deque": "cpp", + "string": "cpp", + "unordered_map": "cpp", + "vector": "cpp", + "exception": "cpp", + "algorithm": "cpp", + "functional": "cpp", + "iterator": "cpp", + "memory": "cpp", + "memory_resource": "cpp", + "numeric": "cpp", + "optional": "cpp", + "random": "cpp", + "ratio": "cpp", + "string_view": "cpp", + "system_error": "cpp", + "tuple": "cpp", + "type_traits": "cpp", + "utility": "cpp", + "initializer_list": "cpp", + "iosfwd": "cpp", + "iostream": "cpp", + "istream": "cpp", + "limits": "cpp", + "new": "cpp", + "numbers": "cpp", + "ostream": "cpp", + "semaphore": "cpp", + "span": "cpp", + "sstream": "cpp", + "stdexcept": "cpp", + "stop_token": "cpp", + "streambuf": "cpp", + "thread": "cpp", + "cinttypes": "cpp", + "typeinfo": "cpp" + } +} diff --git a/http_common.hpp b/http_common.hpp new file mode 100644 index 0000000..a9938c2 --- /dev/null +++ b/http_common.hpp @@ -0,0 +1,6 @@ +#pragma once +#include + +enum class HttpMethod : uint8_t { + GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS, UNKNOWN +}; diff --git a/http_parser.hpp b/http_parser.hpp new file mode 100644 index 0000000..d901b2f --- /dev/null +++ b/http_parser.hpp @@ -0,0 +1,150 @@ +#pragma once + +#include "http_common.hpp" +#include "router.hpp" +#include +#include +#include + +using std::string_view; + +struct HttpRequest { + HttpMethod method = HttpMethod::UNKNOWN; + string_view path; + string_view query; + string_view version; + string_view body; + std::unordered_map headers; + std::unordered_map params; // URL parameters + size_t content_length = 0; + + bool valid = false; +}; + +class HttpParser { +public: + static HttpRequest parse(string_view data) { + HttpRequest req; + const char* ptr = data.data(); + const char* end = ptr + data.size(); + + // Parse method + const char* method_end = find_char(ptr, end, ' '); + if (!method_end) return req; + + req.method = parse_method(string_view(ptr, method_end - ptr)); + ptr = method_end + 1; + + // Parse path and query + const char* path_end = find_char(ptr, end, ' '); + if (!path_end) return req; + + const char* query_start = find_char(ptr, path_end, '?'); + if (query_start) { + req.path = string_view(ptr, query_start - ptr); + req.query = string_view(query_start + 1, path_end - query_start - 1); + } else { + req.path = string_view(ptr, path_end - ptr); + } + ptr = path_end + 1; + + // Parse version + const char* version_end = find_char(ptr, end, '\r'); + if (!version_end || version_end + 1 >= end || *(version_end + 1) != '\n') return req; + + req.version = string_view(ptr, version_end - ptr); + ptr = version_end + 2; + + // Parse headers + while (ptr < end - 1) { + if (*ptr == '\r' && *(ptr + 1) == '\n') { + // End of headers + ptr += 2; + break; + } + + const char* header_end = find_char(ptr, end, '\r'); + if (!header_end || header_end + 1 >= end || *(header_end + 1) != '\n') break; + + const char* colon = find_char(ptr, header_end, ':'); + if (!colon) { + ptr = header_end + 2; + continue; + } + + string_view name(ptr, colon - ptr); + const char* value_start = colon + 1; + while (value_start < header_end && *value_start == ' ') value_start++; + + string_view value(value_start, header_end - value_start); + req.headers[name] = value; + + // Check for Content-Length + if (name.size() == 14 && strncasecmp(name.data(), "content-length", 14) == 0) { + req.content_length = parse_int(value); + } + + ptr = header_end + 2; + } + + // Body + if (ptr < end) { + req.body = string_view(ptr, end - ptr); + } + + req.valid = true; + return req; + } + +private: + static const char* find_char(const char* start, const char* end, char c) { + for (const char* p = start; p < end; ++p) { + if (*p == c) return p; + } + return nullptr; + } + + static HttpMethod parse_method(string_view method) { + switch (method.size()) { + case 3: + if (method == "GET") return HttpMethod::GET; + if (method == "PUT") return HttpMethod::PUT; + break; + case 4: + if (method == "POST") return HttpMethod::POST; + if (method == "HEAD") return HttpMethod::HEAD; + break; + case 5: + if (method == "PATCH") return HttpMethod::PATCH; + break; + case 6: + if (method == "DELETE") return HttpMethod::DELETE; + break; + case 7: + if (method == "OPTIONS") return HttpMethod::OPTIONS; + break; + } + return HttpMethod::UNKNOWN; + } + + static size_t parse_int(string_view str) { + size_t result = 0; + for (char c : str) { + if (c >= '0' && c <= '9') { + result = result * 10 + (c - '0'); + } else { + break; + } + } + return result; + } + + static int strncasecmp(const char* s1, const char* s2, size_t n) { + for (size_t i = 0; i < n; ++i) { + char c1 = s1[i] >= 'A' && s1[i] <= 'Z' ? s1[i] + 32 : s1[i]; + char c2 = s2[i] >= 'A' && s2[i] <= 'Z' ? s2[i] + 32 : s2[i]; + if (c1 != c2) return c1 - c2; + } + return 0; + } +}; diff --git a/main.cpp b/main.cpp new file mode 100644 index 0000000..9496b57 --- /dev/null +++ b/main.cpp @@ -0,0 +1,90 @@ +#include "server.hpp" +#include +#include + +HttpServer* server = nullptr; + +void signal_handler(int sig) { + if (server) { + std::cout << "\nShutting down server...\n"; + delete server; + exit(0); + } +} + +int main() { + signal(SIGINT, signal_handler); + signal(SIGTERM, signal_handler); + + Router router; + + // Root route + router.get("/", [](const HttpRequest& req, HttpResponse& res) { + res.set_text("Hello, World! HTTP Server with Router\n"); + }); + + // API routes + router.get("/api/status", [](const HttpRequest& req, HttpResponse& res) { + res.set_json("{\"status\":\"running\",\"version\":\"1.0\"}"); + }); + + // Users routes + router.get("/users", [](const HttpRequest& req, HttpResponse& res) { + res.set_json("{\"users\":[{\"id\":1,\"name\":\"Alice\"},{\"id\":2,\"name\":\"Bob\"}]}"); + }); + + router.get("/users/:id", [](const HttpRequest& req, HttpResponse& res) { + std::string id = req.params.at("id"); + std::string response = "{\"id\":" + id + ",\"name\":\"User" + id + "\"}"; + res.set_json(response); + }); + + router.post("/users", [](const HttpRequest& req, HttpResponse& res) { + std::string body_preview = std::string(req.body.substr(0, 100)); + res.set_json("{\"message\":\"User created\",\"received\":\"" + body_preview + "\"}"); + res.status = 201; + }); + + // Request info route + router.get("/request-info", [](const HttpRequest& req, HttpResponse& res) { + std::string info = "Method: " + std::to_string(static_cast(req.method)) + "\n"; + info += "Path: " + std::string(req.path) + "\n"; + info += "Query: " + std::string(req.query) + "\n"; + info += "Headers: " + std::to_string(req.headers.size()) + "\n"; + info += "Body length: " + std::to_string(req.body.size()) + "\n"; + res.set_text(info); + }); + + // Different HTTP methods + router.put("/users/:id", [](const HttpRequest& req, HttpResponse& res) { + std::string id = req.params.at("id"); + res.set_json("{\"message\":\"User " + id + " updated\"}"); + }); + + router.del("/users/:id", [](const HttpRequest& req, HttpResponse& res) { + std::string id = req.params.at("id"); + res.set_json("{\"message\":\"User " + id + " deleted\"}"); + }); + + server = new HttpServer(8080, router); + + if (!server->start()) { + std::cerr << "Failed to start server\n"; + return 1; + } + + std::cout << "Server running on http://localhost:8080\n"; + std::cout << "Test routes:\n"; + std::cout << " GET /\n"; + std::cout << " GET /api/status\n"; + std::cout << " GET /users\n"; + std::cout << " GET /users/123\n"; + std::cout << " POST /users\n"; + std::cout << " GET /request-info\n"; + std::cout << "Press Ctrl+C to stop\n"; + + server->run(); + + delete server; + return 0; +} diff --git a/router.hpp b/router.hpp new file mode 100644 index 0000000..4f6d920 --- /dev/null +++ b/router.hpp @@ -0,0 +1,200 @@ +#pragma once +#include "http_common.hpp" +#include "http_parser.hpp" +#include +#include +#include +#include +#include +#include + +using std::string_view; +struct HttpResponse { + int status = 200; + std::string body; + std::string content_type = "text/plain"; + std::unordered_map headers; + + void set_json(const std::string& json) { + body = json; + content_type = "application/json"; + } + + void set_text(const std::string& text) { + body = text; + content_type = "text/plain"; + } + + void set_html(const std::string& html) { + body = html; + content_type = "text/html"; + } +}; + +using Handler = std::function; + +struct TrieNode { + std::string prefix; + std::string param_name; // For parameter nodes + std::unordered_map handlers; + std::unordered_map> children; + bool is_param = false; + + TrieNode(string_view p = "") : prefix(p) {} +}; + +class Router { +private: + std::unique_ptr root = std::make_unique(); + + TrieNode* insert_path(string_view path) { + TrieNode* current = root.get(); + size_t i = 0; + + while (i < path.length()) { + if (path[i] == '/') { + i++; + continue; + } + + bool is_param = path[i] == ':'; + if (is_param) i++; // skip ':' + + size_t start = i; + while (i < path.length() && path[i] != '/') i++; + + std::string segment(path.substr(start, i - start)); + + if (is_param) { + // All parameters share the same node + auto it = current->children.find(':'); + if (it == current->children.end()) { + auto new_node = std::make_unique(); + new_node->is_param = true; + new_node->param_name = segment; + current->children[':'] = std::move(new_node); + } + current = current->children[':'].get(); + } else { + char first_char = segment[0]; + auto it = current->children.find(first_char); + if (it == current->children.end()) { + auto new_node = std::make_unique(segment); + current->children[first_char] = std::move(new_node); + current = current->children[first_char].get(); + } else { + current = it->second.get(); + if (current->prefix != segment) { + // Handle prefix mismatch - split node if needed + auto new_node = std::make_unique(segment); + current->children[first_char] = std::move(new_node); + current = current->children[first_char].get(); + } + } + } + } + return current; + } + + TrieNode* find_path(string_view path, std::unordered_map& params) const { + TrieNode* current = root.get(); + size_t i = 0; + + while (i < path.length() && current) { + if (path[i] == '/') { + i++; + continue; + } + + size_t start = i; + while (i < path.length() && path[i] != '/') i++; + + string_view segment = path.substr(start, i - start); + + // Try exact match first + auto it = current->children.find(segment[0]); + if (it != current->children.end() && it->second->prefix == segment) { + current = it->second.get(); + continue; + } + + // Try parameter match + auto param_it = current->children.find(':'); + if (param_it != current->children.end()) { + current = param_it->second.get(); + params[current->param_name] = std::string(segment); + continue; + } + + return nullptr; + } + return current; + } + +public: + void get(string_view path, Handler handler) { + insert_path(path)->handlers[HttpMethod::GET] = std::move(handler); + } + + void post(string_view path, Handler handler) { + insert_path(path)->handlers[HttpMethod::POST] = std::move(handler); + } + + void put(string_view path, Handler handler) { + insert_path(path)->handlers[HttpMethod::PUT] = std::move(handler); + } + + void del(string_view path, Handler handler) { + insert_path(path)->handlers[HttpMethod::DELETE] = std::move(handler); + } + + void patch(string_view path, Handler handler) { + insert_path(path)->handlers[HttpMethod::PATCH] = std::move(handler); + } + + bool route(HttpMethod method, string_view path, Handler& out_handler, std::unordered_map& params) const { + TrieNode* node = find_path(path, params); + if (!node) return false; + + auto it = node->handlers.find(method); + if (it == node->handlers.end()) return false; + + out_handler = it->second; + return true; + } + + bool handle(HttpRequest& request, HttpResponse& response) const { + Handler handler; + if (route(request.method, request.path, handler, request.params)) { + handler(request, response); + return true; + } + return false; + } +}; + +// Usage example: +/* +Router router; + +router.get("/users", [](const HttpRequest& req, HttpResponse& res) { + res.set_json("{\"users\": []}"); +}); + +router.get("/users/:id", [](const HttpRequest& req, HttpResponse& res) { + std::string id = req.params.at("id"); + res.set_json("{\"id\": \"" + id + "\"}"); +}); + +router.post("/users", [](const HttpRequest& req, HttpResponse& res) { + // Use req.body to get POST data + res.status = 201; + res.set_json("{\"created\": true}"); +}); + +router.get("/users/:id/posts/:postId", [](const HttpRequest& req, HttpResponse& res) { + std::string userId = req.params.at("id"); + std::string postId = req.params.at("postId"); + res.set_json("{\"userId\": \"" + userId + "\", \"postId\": \"" + postId + "\"}"); +}); +*/ diff --git a/server b/server new file mode 100755 index 0000000..e9f8251 Binary files /dev/null and b/server differ diff --git a/server.hpp b/server.hpp new file mode 100644 index 0000000..b930bce --- /dev/null +++ b/server.hpp @@ -0,0 +1,149 @@ +#pragma once + +#include "sockets/epoll.hpp" +#include "router.hpp" +#include "http_parser.hpp" +#include +#include +#include +#include +#include + +class HttpServer { +public: + explicit HttpServer(uint16_t port, Router& router) : socket_(port), router_(router) { + socket_.on_connection([this](int fd) { handle_connection(fd); }); + socket_.on_data([this](int fd) { handle_data(fd); }); + socket_.on_disconnect([this](int fd) { handle_disconnect(fd); }); + } + + bool start() { return socket_.start(); } + void run() { socket_.run(); } + void stop() { socket_.stop(); } + + // Utility function to extract path parameters + static std::string get_path_param(string_view path, size_t segment_index = 0) { + size_t start = 0; + size_t current_segment = 0; + + while (start < path.length()) { + if (path[start] == '/') { + start++; + continue; + } + + size_t end = path.find('/', start); + if (end == std::string_view::npos) end = path.length(); + + if (current_segment == segment_index) { + return std::string(path.substr(start, end - start)); + } + + current_segment++; + start = end; + } + return ""; + } + +private: + static constexpr int BUFFER_SIZE = 8192; + + EpollSocket socket_; + Router& router_; + std::array buffer_; + + void handle_connection(int client_fd) { + // Client connected + } + + void handle_data(int client_fd) { + while (true) { + ssize_t bytes_read = read(client_fd, buffer_.data(), buffer_.size()); + + if (bytes_read == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) break; + return; + } + + if (bytes_read == 0) return; + + std::string request_data(buffer_.data(), bytes_read); + if (request_data.find("\r\n\r\n") != std::string::npos) { + process_request(client_fd, request_data); + } + } + } + + void handle_disconnect(int client_fd) { + // Client disconnected + } + + void process_request(int client_fd, std::string_view request_data) { + HttpRequest req = HttpParser::parse(request_data); + + if (!req.valid) { + send_error_response(client_fd, "Bad Request", 400); + return; + } + + HttpResponse response; + + // Try to route the request (router will populate req.params) + if (router_.handle(req, response)) { + send_http_response(client_fd, response); + } else { + response.status = 404; + response.set_text("Not Found"); + send_http_response(client_fd, response); + } + } + + void send_http_response(int client_fd, const HttpResponse& response) { + std::string http_response = "HTTP/1.1 " + std::to_string(response.status); + + switch (response.status) { + case 200: http_response += " OK"; break; + case 201: http_response += " Created"; break; + case 400: http_response += " Bad Request"; break; + case 404: http_response += " Not Found"; break; + case 500: http_response += " Internal Server Error"; break; + default: http_response += " Unknown"; break; + } + + http_response += "\r\n"; + http_response += "Content-Type: " + response.content_type + "\r\n"; + http_response += "Content-Length: " + std::to_string(response.body.size()) + "\r\n"; + + // Add custom headers + for (const auto& [key, value] : response.headers) { + http_response += key + ": " + value + "\r\n"; + } + + http_response += "Connection: keep-alive\r\n\r\n"; + http_response += response.body; + + send_raw_response(client_fd, http_response); + } + + void send_raw_response(int client_fd, const std::string& response) { + ssize_t total_sent = 0; + ssize_t response_len = response.size(); + + while (total_sent < response_len) { + ssize_t sent = write(client_fd, response.data() + total_sent, response_len - total_sent); + if (sent == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) continue; + break; + } + total_sent += sent; + } + } + + void send_error_response(int client_fd, const std::string& message, int status) { + HttpResponse response; + response.status = status; + response.set_text(message); + response.headers["Connection"] = "close"; + send_http_response(client_fd, response); + } +}; diff --git a/sockets/epoll.hpp b/sockets/epoll.hpp new file mode 100644 index 0000000..989499b --- /dev/null +++ b/sockets/epoll.hpp @@ -0,0 +1,156 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class EpollSocket { +public: + using ConnectionHandler = std::function; + using DataHandler = std::function; + using DisconnectHandler = std::function; + + explicit EpollSocket(uint16_t port = 8080) : port_(port) {} + + ~EpollSocket() { + if (server_fd_ != -1) close(server_fd_); + if (epoll_fd_ != -1) close(epoll_fd_); + } + + bool start() { + if (!create_server_socket()) return false; + if (!create_epoll()) return false; + if (!add_server_to_epoll()) return false; + return true; + } + + void run() { + std::array events; + + while (running_) { + int num_events = epoll_wait(epoll_fd_, events.data(), MAX_EVENTS, -1); + if (num_events == -1) break; + + for (int i = 0; i < num_events; ++i) { + if (events[i].data.fd == server_fd_) { + accept_connections(); + } else { + handle_client_event(events[i].data.fd); + } + } + } + } + + void stop() { running_ = false; } + + bool add_client(int client_fd) { + epoll_event event{}; + event.events = EPOLLIN | EPOLLET; + event.data.fd = client_fd; + + return epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, client_fd, &event) != -1; + } + + void remove_client(int client_fd) { + epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, client_fd, nullptr); + close(client_fd); + } + + // Event handlers + void on_connection(ConnectionHandler handler) { on_connection_ = std::move(handler); } + void on_data(DataHandler handler) { on_data_ = std::move(handler); } + void on_disconnect(DisconnectHandler handler) { on_disconnect_ = std::move(handler); } + +private: + static constexpr int MAX_EVENTS = 1024; + + uint16_t port_; + int server_fd_ = -1; + int epoll_fd_ = -1; + bool running_ = true; + + ConnectionHandler on_connection_; + DataHandler on_data_; + DisconnectHandler on_disconnect_; + + bool create_server_socket() { + server_fd_ = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0); + if (server_fd_ == -1) return false; + + int opt = 1; + if (setsockopt(server_fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) { + return false; + } + + // Enable port reuse for load balancing + if (setsockopt(server_fd_, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)) == -1) { + return false; + } + + // Disable Nagle's algorithm for lower latency + if (setsockopt(server_fd_, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)) == -1) { + return false; + } + + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = INADDR_ANY; + addr.sin_port = htons(port_); + + if (bind(server_fd_, reinterpret_cast(&addr), sizeof(addr)) == -1) { + return false; + } + + return listen(server_fd_, SOMAXCONN) != -1; + } + + bool create_epoll() { + epoll_fd_ = epoll_create1(EPOLL_CLOEXEC); + return epoll_fd_ != -1; + } + + bool add_server_to_epoll() { + epoll_event event{}; + event.events = EPOLLIN | EPOLLET; + event.data.fd = server_fd_; + return epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, server_fd_, &event) != -1; + } + + inline void accept_connections() { + while (true) { + sockaddr_in client_addr{}; + socklen_t client_len = sizeof(client_addr); + + // Use accept4 to set non-blocking atomically + int client_fd = accept4(server_fd_, + reinterpret_cast(&client_addr), + &client_len, + SOCK_NONBLOCK | SOCK_CLOEXEC); + + if (client_fd == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) break; + continue; + } + + // Set TCP_NODELAY for client connections + int opt = 1; + setsockopt(client_fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)); + + if (add_client(client_fd)) { + if (on_connection_) on_connection_(client_fd); + } else { + close(client_fd); + } + } + } + + inline void handle_client_event(int client_fd) { + if (on_data_) on_data_(client_fd); + } +}; diff --git a/sockets/uring.hpp b/sockets/uring.hpp new file mode 100644 index 0000000..ca20983 --- /dev/null +++ b/sockets/uring.hpp @@ -0,0 +1,182 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +enum class OpType : uint8_t { + ACCEPT, + READ, + WRITE +}; + +struct IoData { + OpType type; + int fd; + char* buffer; + size_t size; +}; + +class IoUringSocket { +public: + using ConnectionHandler = std::function; + using DataHandler = std::function; + using DisconnectHandler = std::function; + + explicit IoUringSocket(uint16_t port = 8080, int queue_depth = 256) + : port_(port), queue_depth_(queue_depth) {} + + ~IoUringSocket() { + if (server_fd_ != -1) close(server_fd_); + io_uring_queue_exit(&ring_); + for (auto& [fd, buffer] : buffers_) { + delete[] buffer; + } + } + + bool start() { + if (!create_server_socket()) return false; + if (io_uring_queue_init(queue_depth_, &ring_, 0) < 0) return false; + submit_accept(); + return true; + } + + void run() { + while (running_) { + io_uring_cqe* cqe; + int ret = io_uring_wait_cqe(&ring_, &cqe); + if (ret < 0) break; + + handle_completion(cqe); + io_uring_cqe_seen(&ring_, cqe); + } + } + + void stop() { running_ = false; } + + // Event handlers + void on_connection(ConnectionHandler handler) { on_connection_ = std::move(handler); } + void on_data(DataHandler handler) { on_data_ = std::move(handler); } + void on_disconnect(DisconnectHandler handler) { on_disconnect_ = std::move(handler); } + +private: + static constexpr int BUFFER_SIZE = 8192; + + uint16_t port_; + int queue_depth_; + int server_fd_ = -1; + bool running_ = true; + io_uring ring_; + + std::unordered_map buffers_; + ConnectionHandler on_connection_; + DataHandler on_data_; + DisconnectHandler on_disconnect_; + + bool create_server_socket() { + server_fd_ = socket(AF_INET, SOCK_STREAM, 0); + if (server_fd_ == -1) return false; + + int opt = 1; + if (setsockopt(server_fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) { + return false; + } + + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = INADDR_ANY; + addr.sin_port = htons(port_); + + if (bind(server_fd_, reinterpret_cast(&addr), sizeof(addr)) == -1) { + return false; + } + + return listen(server_fd_, SOMAXCONN) != -1; + } + + void submit_accept() { + auto* data = new IoData{OpType::ACCEPT, server_fd_, nullptr, 0}; + + io_uring_sqe* sqe = io_uring_get_sqe(&ring_); + io_uring_prep_accept(sqe, server_fd_, nullptr, nullptr, 0); + io_uring_sqe_set_data(sqe, data); + io_uring_submit(&ring_); + } + + void submit_read(int client_fd) { + if (buffers_.find(client_fd) == buffers_.end()) { + buffers_[client_fd] = new char[BUFFER_SIZE]; + } + + auto* data = new IoData{OpType::READ, client_fd, buffers_[client_fd], BUFFER_SIZE}; + + io_uring_sqe* sqe = io_uring_get_sqe(&ring_); + io_uring_prep_read(sqe, client_fd, data->buffer, data->size, 0); + io_uring_sqe_set_data(sqe, data); + io_uring_submit(&ring_); + } + + void handle_completion(io_uring_cqe* cqe) { + auto* data = static_cast(io_uring_cqe_get_data(cqe)); + if (!data) return; + + switch (data->type) { + case OpType::ACCEPT: + handle_accept(cqe->res); + submit_accept(); // Continue accepting + break; + + case OpType::READ: + handle_read(data->fd, cqe->res, data->buffer); + break; + + case OpType::WRITE: + // Write completed, nothing special needed + break; + } + + delete data; + } + + void handle_accept(int result) { + if (result < 0) return; + + int client_fd = result; + if (on_connection_) on_connection_(client_fd); + submit_read(client_fd); + } + + void handle_read(int client_fd, int result, char* buffer) { + if (result <= 0) { + if (on_disconnect_) on_disconnect_(client_fd); + cleanup_client(client_fd); + return; + } + + if (on_data_) on_data_(client_fd, buffer, result); + submit_read(client_fd); // Continue reading + } + + void cleanup_client(int client_fd) { + close(client_fd); + if (buffers_.find(client_fd) != buffers_.end()) { + delete[] buffers_[client_fd]; + buffers_.erase(client_fd); + } + } + +public: + void write_async(int client_fd, const char* data, size_t len) { + auto* io_data = new IoData{OpType::WRITE, client_fd, nullptr, len}; + + io_uring_sqe* sqe = io_uring_get_sqe(&ring_); + io_uring_prep_write(sqe, client_fd, data, len, 0); + io_uring_sqe_set_data(sqe, io_data); + io_uring_submit(&ring_); + } +};