cpp_server/epoll_socket.hpp
2025-06-30 22:28:19 -05:00

381 lines
9.4 KiB
C++

#pragma once
#include <sys/epoll.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <fcntl.h>
#include <unistd.h>
#include <errno.h>
#include <array>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>
/// Edge-triggered (EPOLLET) epoll TCP server listening on all IPv4 interfaces.
///
/// Typical use: construct, register callbacks with on_connection(),
/// on_data() and on_disconnect(), call start(), then run() until stop().
///
/// The object owns the listening socket, the epoll instance and every
/// accepted client socket, and closes all of them in its destructor. It is
/// therefore not copyable: a copy would close the same descriptors twice.
///
/// Threading: there is no internal locking. stop() may be called from any
/// thread; every other member must be used from the thread running run(),
/// which is also the thread on which all callbacks are invoked.
class EpollSocket
{
public:
    /// Invoked after a new client socket has been accepted and registered.
    using ConnectionHandler = std::function<void(int client_fd)>;
    /// Invoked when a client socket has readable data. Notifications are
    /// edge-triggered: the handler should recv() until EAGAIN, because data
    /// left unread does not raise another notification by itself. If recv()
    /// returns 0 (peer closed), call disconnect(client_fd).
    using DataHandler = std::function<void(int client_fd)>;
    /// Invoked while the client is still tracked, just before its socket is
    /// deregistered from epoll and closed.
    using DisconnectHandler = std::function<void(int client_fd)>;

    /// @param port TCP port to bind on INADDR_ANY when start() is called.
    explicit EpollSocket(uint16_t port = 8080) : port_(port) {}

    // Owns raw file descriptors; copying would lead to double close().
    EpollSocket(const EpollSocket&) = delete;
    EpollSocket& operator=(const EpollSocket&) = delete;

    ~EpollSocket()
    {
        shutdown();
    }

    /// Creates the listening socket and the epoll instance and registers the
    /// listener with epoll.
    /// @return false as soon as one step fails; errno is left as set by the
    ///         failing system call.
    bool start()
    {
        return create_server_socket() && create_epoll() && add_server_to_epoll();
    }

    /// Runs the event loop until stop() is called or epoll_wait() fails with
    /// an error other than EINTR.
    void run()
    {
        std::array<epoll_event, MAX_EVENTS> events;
        while (running_) {
            // Finite timeout so that a stop() issued from another thread is
            // noticed even when no socket activity occurs.
            const int num_events = epoll_wait(epoll_fd_, events.data(), MAX_EVENTS, 1000);
            if (num_events == -1) {
                if (errno == EINTR) continue;
                break;
            }
            for (int i = 0; i < num_events; ++i) {
                dispatch(events[i]);
            }
        }
    }

    /// Asks run() to return. Safe to call from any thread; run() observes it
    /// after the current batch of events, or when epoll_wait() times out.
    void stop() { running_ = false; }

    /// Sends len bytes to client_fd without blocking.
    ///
    /// Bytes the kernel does not take immediately are appended to the
    /// client's write buffer and flushed once the socket becomes writable.
    /// While older data is still pending, new data is only queued, which
    /// preserves byte order.
    /// @return true if everything was sent or queued; false if the client is
    ///         not tracked, or send() failed with an error other than
    ///         EAGAIN/EWOULDBLOCK/EINTR (errno is left as set by send()).
    /// NOTE(review): MAX_WRITE_BUFFER is not enforced, so a slow reader lets
    /// the queue grow without bound. Confirm the intended back-pressure policy.
    bool send_data(int client_fd, const void* data, size_t len)
    {
        auto it = clients_.find(client_fd);
        if (it == clients_.end()) {
            return false; // Client not found
        }
        ClientInfo& client = it->second;
        const char* bytes = static_cast<const char*>(data);
        size_t sent_now = 0;

        if (client.write_buffer.empty()) {
            // Nothing pending, so a direct send cannot reorder bytes.
            const ssize_t sent = send(client_fd, bytes, len, MSG_NOSIGNAL);
            if (sent == -1) {
                if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
                    return false; // Real error
                }
            } else {
                sent_now = static_cast<size_t>(sent);
            }
            if (sent_now == len) {
                return true; // All data sent immediately
            }
        }

        // Queue whatever was not sent and ask for an EPOLLOUT notification.
        client.write_buffer.insert(client.write_buffer.end(), bytes + sent_now, bytes + len);
        enable_write_events(client_fd);
        return true;
    }

    /// Closes a client connection on behalf of the application, for example
    /// after recv() returned 0 inside the data handler. The disconnect
    /// handler is invoked first; pending unsent data is discarded. Unknown
    /// fds and re-entrant calls for a client already being closed are ignored.
    void disconnect(int client_fd) { handle_disconnect(client_fd); }

    // Event handlers
    void on_connection(ConnectionHandler handler) { on_connection_ = std::move(handler); }
    void on_data(DataHandler handler) { on_data_ = std::move(handler); }
    void on_disconnect(DisconnectHandler handler) { on_disconnect_ = std::move(handler); }

    /// Number of currently tracked client connections.
    size_t connection_count() const { return clients_.size(); }

private:
    static constexpr int MAX_EVENTS = 1024;
    // Intended per-client cap on queued output; currently not enforced.
    static constexpr size_t MAX_WRITE_BUFFER = 64 * 1024;

    struct ClientInfo
    {
        std::vector<char> write_buffer; // Bytes accepted by send_data() but not yet sent
        bool write_enabled = false;     // Whether EPOLLOUT is in the epoll interest mask
        bool closing = false;           // Set while the disconnect handler runs
    };

    uint16_t port_;
    int server_fd_ = -1;
    int epoll_fd_ = -1;
    std::atomic<bool> running_{true}; // Written by stop(), possibly from another thread
    std::unordered_map<int, ClientInfo> clients_;
    ConnectionHandler on_connection_;
    DataHandler on_data_;
    DisconnectHandler on_disconnect_;

    /// Creates a non-blocking listening socket bound to port_ on INADDR_ANY.
    /// On failure the descriptor (if created) stays in server_fd_ and is
    /// closed by shutdown().
    bool create_server_socket()
    {
        // Create socket with non-blocking and close-on-exec flags
        server_fd_ = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0);
        if (server_fd_ == -1) return false;

        int opt = 1;
        // Allow address reuse to avoid "Address already in use" errors
        if (setsockopt(server_fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) {
            return false;
        }
        // Enable port reuse for load balancing multiple processes
        if (setsockopt(server_fd_, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)) == -1) {
            return false;
        }
        // Disable Nagle's algorithm for lower latency
        if (setsockopt(server_fd_, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)) == -1) {
            return false;
        }

        sockaddr_in addr{};
        addr.sin_family = AF_INET;
        addr.sin_addr.s_addr = INADDR_ANY;
        addr.sin_port = htons(port_);
        if (bind(server_fd_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) == -1) {
            return false;
        }
        // Use SOMAXCONN for maximum backlog
        return listen(server_fd_, SOMAXCONN) != -1;
    }

    /// Creates the epoll instance with close-on-exec set.
    bool create_epoll()
    {
        epoll_fd_ = epoll_create1(EPOLL_CLOEXEC);
        return epoll_fd_ != -1;
    }

    /// Registers the listening socket for edge-triggered readability.
    bool add_server_to_epoll()
    {
        epoll_event event{};
        event.events = EPOLLIN | EPOLLET;
        event.data.fd = server_fd_;
        return epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, server_fd_, &event) != -1;
    }

    /// Registers a freshly accepted client for edge-triggered reads and
    /// starts tracking it. Write interest is added later on demand.
    bool add_client(int client_fd)
    {
        epoll_event event{};
        event.events = EPOLLIN | EPOLLET;
        event.data.fd = client_fd;
        if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, client_fd, &event) == -1) {
            return false;
        }
        clients_[client_fd] = ClientInfo{};
        return true;
    }

    /// Deregisters, forgets and closes a client socket. No callback is run.
    void remove_client(int client_fd)
    {
        // Errors are ignored: removal is best-effort and close() follows anyway.
        epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, client_fd, nullptr);
        clients_.erase(client_fd);
        close(client_fd);
    }

    /// Adds or removes EPOLLOUT from a tracked client's interest mask, issuing
    /// EPOLL_CTL_MOD only when the state actually changes.
    void set_write_interest(int client_fd, bool enabled)
    {
        auto it = clients_.find(client_fd);
        if (it == clients_.end() || it->second.write_enabled == enabled) {
            return; // Unknown client, or already in the requested state
        }
        epoll_event event{};
        event.events = EPOLLIN | EPOLLET;
        if (enabled) event.events |= EPOLLOUT;
        event.data.fd = client_fd;
        if (epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, client_fd, &event) != -1) {
            it->second.write_enabled = enabled;
        }
    }

    void enable_write_events(int client_fd) { set_write_interest(client_fd, true); }
    void disable_write_events(int client_fd) { set_write_interest(client_fd, false); }

    /// Routes one epoll event to the matching handler(s).
    void dispatch(const epoll_event& event)
    {
        const int fd = event.data.fd;
        if (fd == server_fd_) {
            accept_connections();
            return;
        }
        if (event.events & (EPOLLHUP | EPOLLERR)) {
            // Client disconnected or error occurred
            handle_disconnect(fd);
            return;
        }
        // In edge-triggered mode one notification may carry both EPOLLIN and
        // EPOLLOUT; servicing only one of them can lose the other edge for
        // good, stalling queued output. Each handler re-checks that the fd is
        // still tracked, so a disconnect during the read is handled safely.
        if (event.events & EPOLLIN) {
            handle_client_read(fd);
        }
        if (event.events & EPOLLOUT) {
            handle_client_write(fd);
        }
    }

    /// True for accept4() errors after which another call may succeed: an
    /// interrupted call, or (on Linux) a network error that belonged to an
    /// aborted pending connection. accept(2) says these should be retried.
    static bool is_transient_accept_error(int err)
    {
        switch (err) {
        case EINTR:
        case ECONNABORTED:
        case EPROTO:
        case ENETDOWN:
        case ENOPROTOOPT:
        case EHOSTDOWN:
        case ENONET:
        case EHOSTUNREACH:
        case EOPNOTSUPP:
        case ENETUNREACH:
            return true;
        default:
            return false;
        }
    }

    /// Accepts every pending connection. With EPOLLET the listener must be
    /// drained, otherwise remaining connections raise no new notification.
    void accept_connections()
    {
        while (running_) {
            // accept4 sets non-blocking and close-on-exec atomically; the peer
            // address is not used, so it is not requested.
            const int client_fd = accept4(server_fd_, nullptr, nullptr,
                                          SOCK_NONBLOCK | SOCK_CLOEXEC);
            if (client_fd == -1) {
                if (is_transient_accept_error(errno)) {
                    continue;
                }
                // EAGAIN/EWOULDBLOCK: queue drained. Anything else (e.g. EMFILE,
                // ENFILE, ENOBUFS, ENOMEM) would fail again immediately, so
                // retrying here would spin; give up until the next notification.
                break;
            }

            // Set TCP_NODELAY for client connections (low latency); best-effort.
            int opt = 1;
            setsockopt(client_fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt));

            if (add_client(client_fd)) {
                if (on_connection_) on_connection_(client_fd);
            } else {
                close(client_fd);
            }
        }
    }

    /// Handles readability: detects an orderly close or socket error with a
    /// one-byte MSG_PEEK (consuming nothing), otherwise hands the fd to the
    /// data handler.
    void handle_client_read(int client_fd)
    {
        if (clients_.find(client_fd) == clients_.end()) {
            return; // Already removed earlier in this batch
        }

        char peek_byte;
        ssize_t peek_result;
        do {
            peek_result = recv(client_fd, &peek_byte, 1, MSG_PEEK | MSG_DONTWAIT);
        } while (peek_result == -1 && errno == EINTR);

        if (peek_result == 0) {
            // Clean disconnect (FIN received, nothing left to read)
            handle_disconnect(client_fd);
            return;
        }
        if (peek_result == -1) {
            if (errno == EAGAIN || errno == EWOULDBLOCK) {
                return; // Spurious wakeup, no data available
            }
            handle_disconnect(client_fd); // Connection error
            return;
        }
        if (on_data_) on_data_(client_fd);
    }

    /// Flushes as much of the client's write buffer as the socket accepts and
    /// drops EPOLLOUT interest once the buffer is empty.
    void handle_client_write(int client_fd)
    {
        auto it = clients_.find(client_fd);
        if (it == clients_.end()) {
            return; // Client no longer exists
        }
        std::vector<char>& buffer = it->second.write_buffer;

        size_t flushed = 0;
        while (flushed < buffer.size()) {
            const ssize_t sent = send(client_fd, buffer.data() + flushed,
                                      buffer.size() - flushed, MSG_NOSIGNAL);
            if (sent == -1) {
                if (errno == EINTR) continue;
                if (errno == EAGAIN || errno == EWOULDBLOCK) break; // Socket buffer full
                handle_disconnect(client_fd); // Connection error; buffer is discarded
                return;
            }
            if (sent == 0) {
                break; // Not expected for a non-empty send; avoid spinning
            }
            flushed += static_cast<size_t>(sent);
        }

        // Drop the sent prefix once instead of shifting the buffer after
        // every send() call.
        buffer.erase(buffer.begin(), buffer.begin() + static_cast<std::ptrdiff_t>(flushed));

        // If buffer is empty, disable write events to avoid useless wakeups
        if (buffer.empty()) {
            disable_write_events(client_fd);
        }
    }

    /// Notifies the application, then deregisters and closes client_fd.
    /// Ignores fds that are not tracked, so an unrelated descriptor is never
    /// closed. Also ignores clients whose disconnect is already in progress,
    /// so a handler calling disconnect() again cannot recurse or close twice.
    void handle_disconnect(int client_fd)
    {
        auto it = clients_.find(client_fd);
        if (it == clients_.end() || it->second.closing) {
            return;
        }
        it->second.closing = true;
        // Notify while the client is still tracked.
        if (on_disconnect_) on_disconnect_(client_fd);
        remove_client(client_fd);
    }

    /// Stops the loop and closes every descriptor the object owns. The
    /// disconnect handler is NOT invoked for clients closed here.
    void shutdown()
    {
        running_ = false;
        for (const auto& entry : clients_) {
            close(entry.first);
        }
        clients_.clear();
        if (server_fd_ != -1) {
            close(server_fd_);
            server_fd_ = -1;
        }
        if (epoll_fd_ != -1) {
            close(epoll_fd_);
            epoll_fd_ = -1;
        }
    }
};