cpp_server/sockets/uring.hpp
2025-06-12 19:01:53 -05:00

183 lines
4.3 KiB
C++

#pragma once
#include <errno.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <atomic>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <unordered_map>

#include <liburing.h>
/// Tags which io_uring operation an IoData record was submitted for, so the
/// completion handler knows how to interpret the result.
enum class OpType : uint8_t {
    ACCEPT = 0,  // accept a new connection on the listening socket
    READ = 1,    // read from a client socket
    WRITE = 2,   // write to a client socket
};
/// Per-submission record attached to an io_uring SQE as user data and
/// recovered from the matching CQE when the operation completes.
///
/// Members carry default initializers so a default-constructed record never
/// holds indeterminate values; the type is still an aggregate, so brace
/// initialization such as IoData{OpType::READ, fd, buf, n} keeps working.
struct IoData {
    OpType type{OpType::ACCEPT};  // which operation this record belongs to
    int fd{-1};                   // descriptor the operation targets (-1 = none)
    char* buffer{nullptr};        // data buffer for the operation, or nullptr if none recorded;
                                  // ownership is decided by whoever creates the record
    size_t size{0};               // byte count associated with the operation
};
class IoUringSocket {
public:
using ConnectionHandler = std::function<void(int client_fd)>;
using DataHandler = std::function<void(int client_fd, const char* data, size_t len)>;
using DisconnectHandler = std::function<void(int client_fd)>;
explicit IoUringSocket(uint16_t port = 8080, int queue_depth = 256)
: port_(port), queue_depth_(queue_depth) {}
~IoUringSocket() {
if (server_fd_ != -1) close(server_fd_);
io_uring_queue_exit(&ring_);
for (auto& [fd, buffer] : buffers_) {
delete[] buffer;
}
}
bool start() {
if (!create_server_socket()) return false;
if (io_uring_queue_init(queue_depth_, &ring_, 0) < 0) return false;
submit_accept();
return true;
}
void run() {
while (running_) {
io_uring_cqe* cqe;
int ret = io_uring_wait_cqe(&ring_, &cqe);
if (ret < 0) break;
handle_completion(cqe);
io_uring_cqe_seen(&ring_, cqe);
}
}
void stop() { running_ = false; }
// Event handlers
void on_connection(ConnectionHandler handler) { on_connection_ = std::move(handler); }
void on_data(DataHandler handler) { on_data_ = std::move(handler); }
void on_disconnect(DisconnectHandler handler) { on_disconnect_ = std::move(handler); }
private:
static constexpr int BUFFER_SIZE = 8192;
uint16_t port_;
int queue_depth_;
int server_fd_ = -1;
bool running_ = true;
io_uring ring_;
std::unordered_map<int, char*> buffers_;
ConnectionHandler on_connection_;
DataHandler on_data_;
DisconnectHandler on_disconnect_;
bool create_server_socket() {
server_fd_ = socket(AF_INET, SOCK_STREAM, 0);
if (server_fd_ == -1) return false;
int opt = 1;
if (setsockopt(server_fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) {
return false;
}
sockaddr_in addr{};
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = INADDR_ANY;
addr.sin_port = htons(port_);
if (bind(server_fd_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) == -1) {
return false;
}
return listen(server_fd_, SOMAXCONN) != -1;
}
void submit_accept() {
auto* data = new IoData{OpType::ACCEPT, server_fd_, nullptr, 0};
io_uring_sqe* sqe = io_uring_get_sqe(&ring_);
io_uring_prep_accept(sqe, server_fd_, nullptr, nullptr, 0);
io_uring_sqe_set_data(sqe, data);
io_uring_submit(&ring_);
}
void submit_read(int client_fd) {
if (buffers_.find(client_fd) == buffers_.end()) {
buffers_[client_fd] = new char[BUFFER_SIZE];
}
auto* data = new IoData{OpType::READ, client_fd, buffers_[client_fd], BUFFER_SIZE};
io_uring_sqe* sqe = io_uring_get_sqe(&ring_);
io_uring_prep_read(sqe, client_fd, data->buffer, data->size, 0);
io_uring_sqe_set_data(sqe, data);
io_uring_submit(&ring_);
}
void handle_completion(io_uring_cqe* cqe) {
auto* data = static_cast<IoData*>(io_uring_cqe_get_data(cqe));
if (!data) return;
switch (data->type) {
case OpType::ACCEPT:
handle_accept(cqe->res);
submit_accept(); // Continue accepting
break;
case OpType::READ:
handle_read(data->fd, cqe->res, data->buffer);
break;
case OpType::WRITE:
// Write completed, nothing special needed
break;
}
delete data;
}
void handle_accept(int result) {
if (result < 0) return;
int client_fd = result;
if (on_connection_) on_connection_(client_fd);
submit_read(client_fd);
}
void handle_read(int client_fd, int result, char* buffer) {
if (result <= 0) {
if (on_disconnect_) on_disconnect_(client_fd);
cleanup_client(client_fd);
return;
}
if (on_data_) on_data_(client_fd, buffer, result);
submit_read(client_fd); // Continue reading
}
void cleanup_client(int client_fd) {
close(client_fd);
if (buffers_.find(client_fd) != buffers_.end()) {
delete[] buffers_[client_fd];
buffers_.erase(client_fd);
}
}
public:
void write_async(int client_fd, const char* data, size_t len) {
auto* io_data = new IoData{OpType::WRITE, client_fd, nullptr, len};
io_uring_sqe* sqe = io_uring_get_sqe(&ring_);
io_uring_prep_write(sqe, client_fd, data, len, 0);
io_uring_sqe_set_data(sqe, io_data);
io_uring_submit(&ring_);
}
};