first commit

This commit is contained in:
Sky Johnson 2025-06-12 19:01:53 -05:00
commit 78780555f0
9 changed files with 991 additions and 0 deletions

58
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,58 @@
{
"files.associations": {
"*.mllo.php": "blade",
"*.template": "blade",
"array": "cpp",
"atomic": "cpp",
"bit": "cpp",
"cctype": "cpp",
"clocale": "cpp",
"cmath": "cpp",
"compare": "cpp",
"concepts": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdint": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"deque": "cpp",
"string": "cpp",
"unordered_map": "cpp",
"vector": "cpp",
"exception": "cpp",
"algorithm": "cpp",
"functional": "cpp",
"iterator": "cpp",
"memory": "cpp",
"memory_resource": "cpp",
"numeric": "cpp",
"optional": "cpp",
"random": "cpp",
"ratio": "cpp",
"string_view": "cpp",
"system_error": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"utility": "cpp",
"initializer_list": "cpp",
"iosfwd": "cpp",
"iostream": "cpp",
"istream": "cpp",
"limits": "cpp",
"new": "cpp",
"numbers": "cpp",
"ostream": "cpp",
"semaphore": "cpp",
"span": "cpp",
"sstream": "cpp",
"stdexcept": "cpp",
"stop_token": "cpp",
"streambuf": "cpp",
"thread": "cpp",
"cinttypes": "cpp",
"typeinfo": "cpp"
}
}

6
http_common.hpp Normal file
View File

@ -0,0 +1,6 @@
#pragma once
#include <cstdint>
enum class HttpMethod : uint8_t {
GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS, UNKNOWN
};

150
http_parser.hpp Normal file
View File

@ -0,0 +1,150 @@
#pragma once
#include "http_common.hpp"
#include "router.hpp"
#include <string_view>
#include <unordered_map>
#include <string>
using std::string_view;
struct HttpRequest {
HttpMethod method = HttpMethod::UNKNOWN;
string_view path;
string_view query;
string_view version;
string_view body;
std::unordered_map<string_view, string_view> headers;
std::unordered_map<std::string, std::string> params; // URL parameters
size_t content_length = 0;
bool valid = false;
};
class HttpParser {
public:
static HttpRequest parse(string_view data) {
HttpRequest req;
const char* ptr = data.data();
const char* end = ptr + data.size();
// Parse method
const char* method_end = find_char(ptr, end, ' ');
if (!method_end) return req;
req.method = parse_method(string_view(ptr, method_end - ptr));
ptr = method_end + 1;
// Parse path and query
const char* path_end = find_char(ptr, end, ' ');
if (!path_end) return req;
const char* query_start = find_char(ptr, path_end, '?');
if (query_start) {
req.path = string_view(ptr, query_start - ptr);
req.query = string_view(query_start + 1, path_end - query_start - 1);
} else {
req.path = string_view(ptr, path_end - ptr);
}
ptr = path_end + 1;
// Parse version
const char* version_end = find_char(ptr, end, '\r');
if (!version_end || version_end + 1 >= end || *(version_end + 1) != '\n') return req;
req.version = string_view(ptr, version_end - ptr);
ptr = version_end + 2;
// Parse headers
while (ptr < end - 1) {
if (*ptr == '\r' && *(ptr + 1) == '\n') {
// End of headers
ptr += 2;
break;
}
const char* header_end = find_char(ptr, end, '\r');
if (!header_end || header_end + 1 >= end || *(header_end + 1) != '\n') break;
const char* colon = find_char(ptr, header_end, ':');
if (!colon) {
ptr = header_end + 2;
continue;
}
string_view name(ptr, colon - ptr);
const char* value_start = colon + 1;
while (value_start < header_end && *value_start == ' ') value_start++;
string_view value(value_start, header_end - value_start);
req.headers[name] = value;
// Check for Content-Length
if (name.size() == 14 && strncasecmp(name.data(), "content-length", 14) == 0) {
req.content_length = parse_int(value);
}
ptr = header_end + 2;
}
// Body
if (ptr < end) {
req.body = string_view(ptr, end - ptr);
}
req.valid = true;
return req;
}
private:
static const char* find_char(const char* start, const char* end, char c) {
for (const char* p = start; p < end; ++p) {
if (*p == c) return p;
}
return nullptr;
}
static HttpMethod parse_method(string_view method) {
switch (method.size()) {
case 3:
if (method == "GET") return HttpMethod::GET;
if (method == "PUT") return HttpMethod::PUT;
break;
case 4:
if (method == "POST") return HttpMethod::POST;
if (method == "HEAD") return HttpMethod::HEAD;
break;
case 5:
if (method == "PATCH") return HttpMethod::PATCH;
break;
case 6:
if (method == "DELETE") return HttpMethod::DELETE;
break;
case 7:
if (method == "OPTIONS") return HttpMethod::OPTIONS;
break;
}
return HttpMethod::UNKNOWN;
}
static size_t parse_int(string_view str) {
size_t result = 0;
for (char c : str) {
if (c >= '0' && c <= '9') {
result = result * 10 + (c - '0');
} else {
break;
}
}
return result;
}
static int strncasecmp(const char* s1, const char* s2, size_t n) {
for (size_t i = 0; i < n; ++i) {
char c1 = s1[i] >= 'A' && s1[i] <= 'Z' ? s1[i] + 32 : s1[i];
char c2 = s2[i] >= 'A' && s2[i] <= 'Z' ? s2[i] + 32 : s2[i];
if (c1 != c2) return c1 - c2;
}
return 0;
}
};

90
main.cpp Normal file
View File

@ -0,0 +1,90 @@
#include "server.hpp"
#include <iostream>
#include <signal.h>
HttpServer* server = nullptr;
void signal_handler(int sig) {
if (server) {
std::cout << "\nShutting down server...\n";
delete server;
exit(0);
}
}
int main() {
signal(SIGINT, signal_handler);
signal(SIGTERM, signal_handler);
Router router;
// Root route
router.get("/", [](const HttpRequest& req, HttpResponse& res) {
res.set_text("Hello, World! HTTP Server with Router\n");
});
// API routes
router.get("/api/status", [](const HttpRequest& req, HttpResponse& res) {
res.set_json("{\"status\":\"running\",\"version\":\"1.0\"}");
});
// Users routes
router.get("/users", [](const HttpRequest& req, HttpResponse& res) {
res.set_json("{\"users\":[{\"id\":1,\"name\":\"Alice\"},{\"id\":2,\"name\":\"Bob\"}]}");
});
router.get("/users/:id", [](const HttpRequest& req, HttpResponse& res) {
std::string id = req.params.at("id");
std::string response = "{\"id\":" + id + ",\"name\":\"User" + id + "\"}";
res.set_json(response);
});
router.post("/users", [](const HttpRequest& req, HttpResponse& res) {
std::string body_preview = std::string(req.body.substr(0, 100));
res.set_json("{\"message\":\"User created\",\"received\":\"" + body_preview + "\"}");
res.status = 201;
});
// Request info route
router.get("/request-info", [](const HttpRequest& req, HttpResponse& res) {
std::string info = "Method: " + std::to_string(static_cast<int>(req.method)) + "\n";
info += "Path: " + std::string(req.path) + "\n";
info += "Query: " + std::string(req.query) + "\n";
info += "Headers: " + std::to_string(req.headers.size()) + "\n";
info += "Body length: " + std::to_string(req.body.size()) + "\n";
res.set_text(info);
});
// Different HTTP methods
router.put("/users/:id", [](const HttpRequest& req, HttpResponse& res) {
std::string id = req.params.at("id");
res.set_json("{\"message\":\"User " + id + " updated\"}");
});
router.del("/users/:id", [](const HttpRequest& req, HttpResponse& res) {
std::string id = req.params.at("id");
res.set_json("{\"message\":\"User " + id + " deleted\"}");
});
server = new HttpServer(8080, router);
if (!server->start()) {
std::cerr << "Failed to start server\n";
return 1;
}
std::cout << "Server running on http://localhost:8080\n";
std::cout << "Test routes:\n";
std::cout << " GET /\n";
std::cout << " GET /api/status\n";
std::cout << " GET /users\n";
std::cout << " GET /users/123\n";
std::cout << " POST /users\n";
std::cout << " GET /request-info\n";
std::cout << "Press Ctrl+C to stop\n";
server->run();
delete server;
return 0;
}

200
router.hpp Normal file
View File

@ -0,0 +1,200 @@
#pragma once
#include "http_common.hpp"
#include "http_parser.hpp"
#include <string>
#include <unordered_map>
#include <vector>
#include <functional>
#include <string_view>
#include <memory>
using std::string_view;
struct HttpResponse {
int status = 200;
std::string body;
std::string content_type = "text/plain";
std::unordered_map<std::string, std::string> headers;
void set_json(const std::string& json) {
body = json;
content_type = "application/json";
}
void set_text(const std::string& text) {
body = text;
content_type = "text/plain";
}
void set_html(const std::string& html) {
body = html;
content_type = "text/html";
}
};
using Handler = std::function<void(const HttpRequest&, HttpResponse&)>;
struct TrieNode {
std::string prefix;
std::string param_name; // For parameter nodes
std::unordered_map<HttpMethod, Handler> handlers;
std::unordered_map<char, std::unique_ptr<TrieNode>> children;
bool is_param = false;
TrieNode(string_view p = "") : prefix(p) {}
};
class Router {
private:
std::unique_ptr<TrieNode> root = std::make_unique<TrieNode>();
TrieNode* insert_path(string_view path) {
TrieNode* current = root.get();
size_t i = 0;
while (i < path.length()) {
if (path[i] == '/') {
i++;
continue;
}
bool is_param = path[i] == ':';
if (is_param) i++; // skip ':'
size_t start = i;
while (i < path.length() && path[i] != '/') i++;
std::string segment(path.substr(start, i - start));
if (is_param) {
// All parameters share the same node
auto it = current->children.find(':');
if (it == current->children.end()) {
auto new_node = std::make_unique<TrieNode>();
new_node->is_param = true;
new_node->param_name = segment;
current->children[':'] = std::move(new_node);
}
current = current->children[':'].get();
} else {
char first_char = segment[0];
auto it = current->children.find(first_char);
if (it == current->children.end()) {
auto new_node = std::make_unique<TrieNode>(segment);
current->children[first_char] = std::move(new_node);
current = current->children[first_char].get();
} else {
current = it->second.get();
if (current->prefix != segment) {
// Handle prefix mismatch - split node if needed
auto new_node = std::make_unique<TrieNode>(segment);
current->children[first_char] = std::move(new_node);
current = current->children[first_char].get();
}
}
}
}
return current;
}
TrieNode* find_path(string_view path, std::unordered_map<std::string, std::string>& params) const {
TrieNode* current = root.get();
size_t i = 0;
while (i < path.length() && current) {
if (path[i] == '/') {
i++;
continue;
}
size_t start = i;
while (i < path.length() && path[i] != '/') i++;
string_view segment = path.substr(start, i - start);
// Try exact match first
auto it = current->children.find(segment[0]);
if (it != current->children.end() && it->second->prefix == segment) {
current = it->second.get();
continue;
}
// Try parameter match
auto param_it = current->children.find(':');
if (param_it != current->children.end()) {
current = param_it->second.get();
params[current->param_name] = std::string(segment);
continue;
}
return nullptr;
}
return current;
}
public:
void get(string_view path, Handler handler) {
insert_path(path)->handlers[HttpMethod::GET] = std::move(handler);
}
void post(string_view path, Handler handler) {
insert_path(path)->handlers[HttpMethod::POST] = std::move(handler);
}
void put(string_view path, Handler handler) {
insert_path(path)->handlers[HttpMethod::PUT] = std::move(handler);
}
void del(string_view path, Handler handler) {
insert_path(path)->handlers[HttpMethod::DELETE] = std::move(handler);
}
void patch(string_view path, Handler handler) {
insert_path(path)->handlers[HttpMethod::PATCH] = std::move(handler);
}
bool route(HttpMethod method, string_view path, Handler& out_handler, std::unordered_map<std::string, std::string>& params) const {
TrieNode* node = find_path(path, params);
if (!node) return false;
auto it = node->handlers.find(method);
if (it == node->handlers.end()) return false;
out_handler = it->second;
return true;
}
bool handle(HttpRequest& request, HttpResponse& response) const {
Handler handler;
if (route(request.method, request.path, handler, request.params)) {
handler(request, response);
return true;
}
return false;
}
};
// Usage example:
/*
Router router;
router.get("/users", [](const HttpRequest& req, HttpResponse& res) {
res.set_json("{\"users\": []}");
});
router.get("/users/:id", [](const HttpRequest& req, HttpResponse& res) {
std::string id = req.params.at("id");
res.set_json("{\"id\": \"" + id + "\"}");
});
router.post("/users", [](const HttpRequest& req, HttpResponse& res) {
// Use req.body to get POST data
res.status = 201;
res.set_json("{\"created\": true}");
});
router.get("/users/:id/posts/:postId", [](const HttpRequest& req, HttpResponse& res) {
std::string userId = req.params.at("id");
std::string postId = req.params.at("postId");
res.set_json("{\"userId\": \"" + userId + "\", \"postId\": \"" + postId + "\"}");
});
*/

BIN
server Executable file

Binary file not shown.

149
server.hpp Normal file
View File

@ -0,0 +1,149 @@
#pragma once
#include "sockets/epoll.hpp"
#include "router.hpp"
#include "http_parser.hpp"
#include <iostream>
#include <string.h>
#include <string_view>
#include <array>
#include <string>
class HttpServer {
public:
explicit HttpServer(uint16_t port, Router& router) : socket_(port), router_(router) {
socket_.on_connection([this](int fd) { handle_connection(fd); });
socket_.on_data([this](int fd) { handle_data(fd); });
socket_.on_disconnect([this](int fd) { handle_disconnect(fd); });
}
bool start() { return socket_.start(); }
void run() { socket_.run(); }
void stop() { socket_.stop(); }
// Utility function to extract path parameters
static std::string get_path_param(string_view path, size_t segment_index = 0) {
size_t start = 0;
size_t current_segment = 0;
while (start < path.length()) {
if (path[start] == '/') {
start++;
continue;
}
size_t end = path.find('/', start);
if (end == std::string_view::npos) end = path.length();
if (current_segment == segment_index) {
return std::string(path.substr(start, end - start));
}
current_segment++;
start = end;
}
return "";
}
private:
static constexpr int BUFFER_SIZE = 8192;
EpollSocket socket_;
Router& router_;
std::array<char, BUFFER_SIZE> buffer_;
void handle_connection(int client_fd) {
// Client connected
}
void handle_data(int client_fd) {
while (true) {
ssize_t bytes_read = read(client_fd, buffer_.data(), buffer_.size());
if (bytes_read == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) break;
return;
}
if (bytes_read == 0) return;
std::string request_data(buffer_.data(), bytes_read);
if (request_data.find("\r\n\r\n") != std::string::npos) {
process_request(client_fd, request_data);
}
}
}
void handle_disconnect(int client_fd) {
// Client disconnected
}
void process_request(int client_fd, std::string_view request_data) {
HttpRequest req = HttpParser::parse(request_data);
if (!req.valid) {
send_error_response(client_fd, "Bad Request", 400);
return;
}
HttpResponse response;
// Try to route the request (router will populate req.params)
if (router_.handle(req, response)) {
send_http_response(client_fd, response);
} else {
response.status = 404;
response.set_text("Not Found");
send_http_response(client_fd, response);
}
}
void send_http_response(int client_fd, const HttpResponse& response) {
std::string http_response = "HTTP/1.1 " + std::to_string(response.status);
switch (response.status) {
case 200: http_response += " OK"; break;
case 201: http_response += " Created"; break;
case 400: http_response += " Bad Request"; break;
case 404: http_response += " Not Found"; break;
case 500: http_response += " Internal Server Error"; break;
default: http_response += " Unknown"; break;
}
http_response += "\r\n";
http_response += "Content-Type: " + response.content_type + "\r\n";
http_response += "Content-Length: " + std::to_string(response.body.size()) + "\r\n";
// Add custom headers
for (const auto& [key, value] : response.headers) {
http_response += key + ": " + value + "\r\n";
}
http_response += "Connection: keep-alive\r\n\r\n";
http_response += response.body;
send_raw_response(client_fd, http_response);
}
void send_raw_response(int client_fd, const std::string& response) {
ssize_t total_sent = 0;
ssize_t response_len = response.size();
while (total_sent < response_len) {
ssize_t sent = write(client_fd, response.data() + total_sent, response_len - total_sent);
if (sent == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) continue;
break;
}
total_sent += sent;
}
}
void send_error_response(int client_fd, const std::string& message, int status) {
HttpResponse response;
response.status = status;
response.set_text(message);
response.headers["Connection"] = "close";
send_http_response(client_fd, response);
}
};

156
sockets/epoll.hpp Normal file
View File

@ -0,0 +1,156 @@
#pragma once
#include <sys/epoll.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <fcntl.h>
#include <unistd.h>
#include <errno.h>
#include <array>
#include <functional>
class EpollSocket {
public:
using ConnectionHandler = std::function<void(int client_fd)>;
using DataHandler = std::function<void(int client_fd)>;
using DisconnectHandler = std::function<void(int client_fd)>;
explicit EpollSocket(uint16_t port = 8080) : port_(port) {}
~EpollSocket() {
if (server_fd_ != -1) close(server_fd_);
if (epoll_fd_ != -1) close(epoll_fd_);
}
bool start() {
if (!create_server_socket()) return false;
if (!create_epoll()) return false;
if (!add_server_to_epoll()) return false;
return true;
}
void run() {
std::array<epoll_event, MAX_EVENTS> events;
while (running_) {
int num_events = epoll_wait(epoll_fd_, events.data(), MAX_EVENTS, -1);
if (num_events == -1) break;
for (int i = 0; i < num_events; ++i) {
if (events[i].data.fd == server_fd_) {
accept_connections();
} else {
handle_client_event(events[i].data.fd);
}
}
}
}
void stop() { running_ = false; }
bool add_client(int client_fd) {
epoll_event event{};
event.events = EPOLLIN | EPOLLET;
event.data.fd = client_fd;
return epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, client_fd, &event) != -1;
}
void remove_client(int client_fd) {
epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, client_fd, nullptr);
close(client_fd);
}
// Event handlers
void on_connection(ConnectionHandler handler) { on_connection_ = std::move(handler); }
void on_data(DataHandler handler) { on_data_ = std::move(handler); }
void on_disconnect(DisconnectHandler handler) { on_disconnect_ = std::move(handler); }
private:
static constexpr int MAX_EVENTS = 1024;
uint16_t port_;
int server_fd_ = -1;
int epoll_fd_ = -1;
bool running_ = true;
ConnectionHandler on_connection_;
DataHandler on_data_;
DisconnectHandler on_disconnect_;
bool create_server_socket() {
server_fd_ = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0);
if (server_fd_ == -1) return false;
int opt = 1;
if (setsockopt(server_fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) {
return false;
}
// Enable port reuse for load balancing
if (setsockopt(server_fd_, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)) == -1) {
return false;
}
// Disable Nagle's algorithm for lower latency
if (setsockopt(server_fd_, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)) == -1) {
return false;
}
sockaddr_in addr{};
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = INADDR_ANY;
addr.sin_port = htons(port_);
if (bind(server_fd_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) == -1) {
return false;
}
return listen(server_fd_, SOMAXCONN) != -1;
}
bool create_epoll() {
epoll_fd_ = epoll_create1(EPOLL_CLOEXEC);
return epoll_fd_ != -1;
}
bool add_server_to_epoll() {
epoll_event event{};
event.events = EPOLLIN | EPOLLET;
event.data.fd = server_fd_;
return epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, server_fd_, &event) != -1;
}
inline void accept_connections() {
while (true) {
sockaddr_in client_addr{};
socklen_t client_len = sizeof(client_addr);
// Use accept4 to set non-blocking atomically
int client_fd = accept4(server_fd_,
reinterpret_cast<sockaddr*>(&client_addr),
&client_len,
SOCK_NONBLOCK | SOCK_CLOEXEC);
if (client_fd == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) break;
continue;
}
// Set TCP_NODELAY for client connections
int opt = 1;
setsockopt(client_fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt));
if (add_client(client_fd)) {
if (on_connection_) on_connection_(client_fd);
} else {
close(client_fd);
}
}
}
inline void handle_client_event(int client_fd) {
if (on_data_) on_data_(client_fd);
}
};

182
sockets/uring.hpp Normal file
View File

@ -0,0 +1,182 @@
#pragma once
#include <liburing.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <unistd.h>
#include <errno.h>
#include <functional>
#include <unordered_map>
#include <memory>
enum class OpType : uint8_t {
ACCEPT,
READ,
WRITE
};
struct IoData {
OpType type;
int fd;
char* buffer;
size_t size;
};
class IoUringSocket {
public:
using ConnectionHandler = std::function<void(int client_fd)>;
using DataHandler = std::function<void(int client_fd, const char* data, size_t len)>;
using DisconnectHandler = std::function<void(int client_fd)>;
explicit IoUringSocket(uint16_t port = 8080, int queue_depth = 256)
: port_(port), queue_depth_(queue_depth) {}
~IoUringSocket() {
if (server_fd_ != -1) close(server_fd_);
io_uring_queue_exit(&ring_);
for (auto& [fd, buffer] : buffers_) {
delete[] buffer;
}
}
bool start() {
if (!create_server_socket()) return false;
if (io_uring_queue_init(queue_depth_, &ring_, 0) < 0) return false;
submit_accept();
return true;
}
void run() {
while (running_) {
io_uring_cqe* cqe;
int ret = io_uring_wait_cqe(&ring_, &cqe);
if (ret < 0) break;
handle_completion(cqe);
io_uring_cqe_seen(&ring_, cqe);
}
}
void stop() { running_ = false; }
// Event handlers
void on_connection(ConnectionHandler handler) { on_connection_ = std::move(handler); }
void on_data(DataHandler handler) { on_data_ = std::move(handler); }
void on_disconnect(DisconnectHandler handler) { on_disconnect_ = std::move(handler); }
private:
static constexpr int BUFFER_SIZE = 8192;
uint16_t port_;
int queue_depth_;
int server_fd_ = -1;
bool running_ = true;
io_uring ring_;
std::unordered_map<int, char*> buffers_;
ConnectionHandler on_connection_;
DataHandler on_data_;
DisconnectHandler on_disconnect_;
bool create_server_socket() {
server_fd_ = socket(AF_INET, SOCK_STREAM, 0);
if (server_fd_ == -1) return false;
int opt = 1;
if (setsockopt(server_fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) {
return false;
}
sockaddr_in addr{};
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = INADDR_ANY;
addr.sin_port = htons(port_);
if (bind(server_fd_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) == -1) {
return false;
}
return listen(server_fd_, SOMAXCONN) != -1;
}
void submit_accept() {
auto* data = new IoData{OpType::ACCEPT, server_fd_, nullptr, 0};
io_uring_sqe* sqe = io_uring_get_sqe(&ring_);
io_uring_prep_accept(sqe, server_fd_, nullptr, nullptr, 0);
io_uring_sqe_set_data(sqe, data);
io_uring_submit(&ring_);
}
void submit_read(int client_fd) {
if (buffers_.find(client_fd) == buffers_.end()) {
buffers_[client_fd] = new char[BUFFER_SIZE];
}
auto* data = new IoData{OpType::READ, client_fd, buffers_[client_fd], BUFFER_SIZE};
io_uring_sqe* sqe = io_uring_get_sqe(&ring_);
io_uring_prep_read(sqe, client_fd, data->buffer, data->size, 0);
io_uring_sqe_set_data(sqe, data);
io_uring_submit(&ring_);
}
void handle_completion(io_uring_cqe* cqe) {
auto* data = static_cast<IoData*>(io_uring_cqe_get_data(cqe));
if (!data) return;
switch (data->type) {
case OpType::ACCEPT:
handle_accept(cqe->res);
submit_accept(); // Continue accepting
break;
case OpType::READ:
handle_read(data->fd, cqe->res, data->buffer);
break;
case OpType::WRITE:
// Write completed, nothing special needed
break;
}
delete data;
}
void handle_accept(int result) {
if (result < 0) return;
int client_fd = result;
if (on_connection_) on_connection_(client_fd);
submit_read(client_fd);
}
void handle_read(int client_fd, int result, char* buffer) {
if (result <= 0) {
if (on_disconnect_) on_disconnect_(client_fd);
cleanup_client(client_fd);
return;
}
if (on_data_) on_data_(client_fd, buffer, result);
submit_read(client_fd); // Continue reading
}
void cleanup_client(int client_fd) {
close(client_fd);
if (buffers_.find(client_fd) != buffers_.end()) {
delete[] buffers_[client_fd];
buffers_.erase(client_fd);
}
}
public:
void write_async(int client_fd, const char* data, size_t len) {
auto* io_data = new IoData{OpType::WRITE, client_fd, nullptr, len};
io_uring_sqe* sqe = io_uring_get_sqe(&ring_);
io_uring_prep_write(sqe, client_fd, data, len, 0);
io_uring_sqe_set_data(sqe, io_data);
io_uring_submit(&ring_);
}
};