diff --git a/README.md b/README.md index bc46e50..d6d251e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Sockeye -An easy-to-use, fast socket C++ library! Uses epoll on Linux, kqueue on BSD/macOS. +An easy-to-use, fast, and robust C++ epoll socket library. ## API Reference @@ -9,29 +9,35 @@ An easy-to-use, fast socket C++ library! Uses epoll on Linux, kqueue on BSD/macO Main server class for handling TCP connections. #### Constructor + ```cpp -explicit Socket(uint16_t port = 8080) +explicit Socket(uint16_t port = 8080, int timeout_ms = 5000) ``` +Constructs a server instance. `timeout_ms` is the duration of inactivity in milliseconds before a client connection is automatically closed. + #### Methods **`bool start()`** Initialize the server socket and event system. Returns `true` on success. **`void run()`** -Start the event loop. Blocks until `stop()` is called. +Start the event loop. This is a blocking call that runs until `stop()` is called. **`void stop()`** -Stop the server and exit the event loop. +Stops the server and causes the `run()` loop to exit. + +**`bool send(int client_fd, const std::string& data)`** +Sends data to a connected client. This method handles partial sends and ensures all data is written. Returns `true` on success. **`void on_connection(ConnectionHandler handler)`** -Set callback for new client connections. +Set a callback to be executed when a new client connects. **`void on_data(DataHandler handler)`** -Set callback for incoming data from clients. +Set a callback to be executed when data is received from a client. **`void on_disconnect(DisconnectHandler handler)`** -Set callback for client disconnections. +Set a callback to be executed when a client disconnects for any reason (including timeout). #### Handler Types @@ -45,38 +51,7 @@ using DisconnectHandler = std::function; ### Basic Echo Server -```cpp -#include "sockeye.hpp" -#include -#include - -int main() { - sockeye::Socket server(8080); - - server.on_connection([](int client_fd) { - std::cout << "Client connected: " << client_fd << std::endl; - }); - - server.on_data([](int client_fd, const char* data, size_t len) { - // Echo data back to client - send(client_fd, data, len, 0); - }); - - server.on_disconnect([](int client_fd) { - std::cout << "Client disconnected: " << client_fd << std::endl; - }); - - if (!server.start()) { - std::cerr << "Failed to start server" << std::endl; - return 1; - } - - server.run(); - return 0; -} -``` - -### HTTP-like Server +This example demonstrates how to echo all received data back to the client. It captures the `server` object to use the integrated `send` method. ```cpp #include "sockeye.hpp" @@ -84,61 +59,124 @@ int main() { #include int main() { - sockeye::Socket server(8080); + sockeye::Socket server(8080); - server.on_data([](int client_fd, const char* data, size_t len) { - std::string request(data, len); + server.on_connection([](int client_fd) { + std::cout << "Client connected: " << client_fd << std::endl; + }); - // Simple HTTP response - std::string response = - "HTTP/1.1 200 OK\r\n" - "Content-Length: 13\r\n" - "Connection: close\r\n\r\n" - "Hello, World!"; + // Capture server to use its send method + server.on_data([&server](int client_fd, const char* data, size_t len) { + // Echo data back to client using the server's send method + server.send(client_fd, std::string(data, len)); + }); - send(client_fd, response.c_str(), response.length(), 0); - close(client_fd); - }); + server.on_disconnect([](int client_fd) { + std::cout << "Client disconnected: " << client_fd << std::endl; + }); - if (!server.start()) { - std::cerr << "Failed to start server" << std::endl; - return 1; - } + if (!server.start()) { + std::cerr << "Failed to start server" << std::endl; + return 1; + } - std::cout << "Server listening on port 8080" << std::endl; - server.run(); - return 0; + server.run(); + return 0; +} +``` + +### HTTP Server with Keep-Alive + +This example shows a simple HTTP server that properly handles keep-alive connections. The library automatically manages client timeouts. + +```cpp +#include "sockeye.hpp" +#include +#include +#include +#include + +int main() { + sockeye::Socket server(8080); + + // Buffers for accumulating request data per client + std::unordered_map request_buffers; + std::mutex buffer_mutex; + + server.on_data([&](int client_fd, const char* data, size_t len) { + std::string request_chunk(data, len); + + std::lock_guard lock(buffer_mutex); + request_buffers[client_fd] += request_chunk; + + // Check if we have a full HTTP request + if (request_buffers[client_fd].find("\r\n\r\n") != std::string::npos) { + std::string response = + "HTTP/1.1 200 OK\r\n" + "Content-Length: 13\r\n" + "Connection: keep-alive\r\n\r\n" + "Hello, World!"; + + server.send(client_fd, response); + + // Clear buffer for this client for the next request + request_buffers[client_fd].clear(); + } + }); + + server.on_disconnect([&](int client_fd) { + std::cout << "Client disconnected: " << client_fd << std::endl; + std::lock_guard lock(buffer_mutex); + request_buffers.erase(client_fd); + }); + + if (!server.start()) { + std::cerr << "Failed to start server" << std::endl; + return 1; + } + + std::cout << "Server listening on port 8080" << std::endl; + server.run(); + return 0; } ``` ### Graceful Shutdown +This example shows how to catch `SIGINT` (Ctrl+C) and `SIGTERM` signals to shut down the server gracefully. + ```cpp #include "sockeye.hpp" #include +#include sockeye::Socket* server_ptr = nullptr; void signal_handler(int signal) { - if (server_ptr) { - server_ptr->stop(); - } + if (server_ptr) { + std::cout << "\nCaught signal " << signal << ", stopping server..." << std::endl; + server_ptr->stop(); + } } int main() { - sockeye::Socket server(8080); - server_ptr = &server; + sockeye::Socket server(8080); + server_ptr = &server; - signal(SIGINT, signal_handler); - signal(SIGTERM, signal_handler); + signal(SIGINT, signal_handler); + signal(SIGTERM, signal_handler); - // Set up handlers... + // Set up handlers... + server.on_connection([](int fd){ /* ... */ }); - if (!server.start()) { - return 1; - } + if (!server.start()) { + return 1; + } - server.run(); - return 0; + std::cout << "Server started. Press Ctrl+C to exit." << std::endl; + server.run(); + std::cout << "Server stopped." << std::endl; + + return 0; } ``` diff --git a/sockeye.hpp b/sockeye.hpp index fdf0dfb..f669e7f 100644 --- a/sockeye.hpp +++ b/sockeye.hpp @@ -9,6 +9,10 @@ #include #include #include +#include +#include +#include +#include // Added for std::string namespace sockeye { @@ -19,7 +23,8 @@ public: using DataHandler = std::function; using DisconnectHandler = std::function; - explicit Socket(uint16_t port = 8080) : port_(port) {} + explicit Socket(uint16_t port = 8080, int timeout_ms = 5000) + : port_(port), timeout_ms_(timeout_ms) {} ~Socket() { @@ -52,24 +57,32 @@ public: handle_client_data(events[i].data.fd); } } + + check_timeouts(); } } void stop() { running_ = false; } - bool add_client(int client_fd) - { - epoll_event event{}; - event.events = EPOLLIN | EPOLLET; - event.data.fd = client_fd; - return epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, client_fd, &event) != -1; - } + // New: Send data to a client + bool send(int client_fd, const std::string& data) { + ssize_t total_sent = 0; + const char* p_data = data.c_str(); + size_t len = data.length(); - void remove_client(int client_fd) - { - epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, client_fd, nullptr); - close(client_fd); - if (on_disconnect_) on_disconnect_(client_fd); + while (total_sent < static_cast(len)) { + ssize_t sent = ::send(client_fd, p_data + total_sent, len - total_sent, MSG_NOSIGNAL); + if (sent == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + // Can't send right now, handle as an error for simplicity in this context + return false; + } + // Other error + return false; + } + total_sent += sent; + } + return true; } void on_connection(ConnectionHandler handler) { on_connection_ = std::move(handler); } @@ -77,14 +90,22 @@ public: void on_disconnect(DisconnectHandler handler) { on_disconnect_ = std::move(handler); } private: + struct Client { + int fd; + std::chrono::steady_clock::time_point last_activity; + }; + static constexpr int MAX_EVENTS = 1024; static constexpr int BUFFER_SIZE = 8192; uint16_t port_; + int timeout_ms_; int server_fd_ = -1; int epoll_fd_ = -1; bool running_ = true; + std::unordered_map clients_; + ConnectionHandler on_connection_; DataHandler on_data_; DisconnectHandler on_disconnect_; @@ -124,6 +145,26 @@ private: return epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, server_fd_, &event) != -1; } + bool add_client(int client_fd) + { + epoll_event event{}; + event.events = EPOLLIN | EPOLLET; + event.data.fd = client_fd; + if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, client_fd, &event) == -1) { + return false; + } + clients_[client_fd] = {client_fd, std::chrono::steady_clock::now()}; + return true; + } + + void remove_client(int client_fd) + { + epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, client_fd, nullptr); + close(client_fd); + clients_.erase(client_fd); + if (on_disconnect_) on_disconnect_(client_fd); + } + inline void accept_connections() { while (true) { @@ -153,6 +194,11 @@ private: inline void handle_client_data(int client_fd) { + auto it = clients_.find(client_fd); + if (it != clients_.end()) { + it->second.last_activity = std::chrono::steady_clock::now(); + } + if (!on_data_) return; char buffer[BUFFER_SIZE]; @@ -170,6 +216,24 @@ private: } } } + + void check_timeouts() { + if (timeout_ms_ <= 0) return; + + auto now = std::chrono::steady_clock::now(); + std::vector timed_out_clients; + + for (const auto& pair : clients_) { + auto duration = std::chrono::duration_cast(now - pair.second.last_activity); + if (duration.count() > timeout_ms_) { + timed_out_clients.push_back(pair.first); + } + } + + for (int client_fd : timed_out_clients) { + remove_client(client_fd); + } + } }; } // namespace sockeye diff --git a/test.cpp b/test.cpp index 5ee3430..dea4454 100644 --- a/test.cpp +++ b/test.cpp @@ -1,4 +1,5 @@ #include "sockeye.hpp" +#include #include #include #include @@ -9,19 +10,21 @@ #include #include #include +#include using namespace std; -class SocketTest { +class SocketTest +{ public: - void run() { - cout << "Starting socket tests...\n"; + void run() + { + cout << "Starting socket tests with HTTP workload...\n"; start_server(); this_thread::sleep_for(chrono::milliseconds(100)); - test_basic_functionality(); - test_throughput(); + test_http_workload(); server_.stop(); if (server_thread_.joinable()) { @@ -35,24 +38,22 @@ private: struct MemoryInfo { size_t rss_kb = 0; size_t peak_rss_kb = 0; - size_t vsize_kb = 0; }; - MemoryInfo get_memory_info() { + MemoryInfo get_memory_info() + { MemoryInfo info; - struct rusage usage; if (getrusage(RUSAGE_SELF, &usage) == 0) { - info.rss_kb = usage.ru_maxrss; - info.peak_rss_kb = usage.ru_maxrss; - #ifdef __APPLE__ // macOS reports in bytes, convert to KB - info.rss_kb /= 1024; - info.peak_rss_kb /= 1024; + info.rss_kb = usage.ru_maxrss / 1024; + #else + // Linux reports in KB + info.rss_kb = usage.ru_maxrss; #endif + info.peak_rss_kb = info.rss_kb; } - return info; } @@ -60,38 +61,55 @@ private: thread server_thread_; atomic peak_connections_{0}; atomic current_connections_{0}; - atomic messages_received_{0}; + atomic requests_processed_{0}; atomic bytes_received_{0}; chrono::steady_clock::time_point start_time_; + // Per-client request buffering + mutex request_buffer_mutex_; + unordered_map request_buffers_; + // Memory tracking MemoryInfo baseline_memory_; MemoryInfo peak_memory_; - // Synchronization for reliable testing - atomic clients_connected_{0}; - atomic clients_finished_sending_{0}; - mutex stats_mutex_; - - void start_server() { + void start_server() + { server_.on_connection([this](int fd) { int current = ++current_connections_; peak_connections_.store(max(peak_connections_.load(), current)); - clients_connected_++; }); server_.on_data([this](int fd, const char* data, size_t len) { - // Count messages by looking for our delimiter - for (size_t i = 0; i < len; ++i) { - if (data[i] == '\n') { - messages_received_++; - } - } bytes_received_ += len; + string request_chunk(data, len); + + lock_guard lock(request_buffer_mutex_); + request_buffers_[fd] += request_chunk; + + // Process all complete requests in the buffer + while (true) { + auto& buffer = request_buffers_[fd]; + size_t pos = buffer.find("\r\n\r\n"); + if (pos == string::npos) { + break; // No complete request found + } + + requests_processed_++; + + // Remove the processed request from the buffer + buffer.erase(0, pos + 4); + + // Send a standard HTTP response + string response = "HTTP/1.1 200 OK\r\nContent-Length: 13\r\nConnection: keep-alive\r\n\r\nHello, World!"; + server_.send(fd, response); + } }); server_.on_disconnect([this](int fd) { current_connections_--; + lock_guard lock(request_buffer_mutex_); + request_buffers_.erase(fd); }); if (!server_.start()) { @@ -104,48 +122,28 @@ private: }); } - void test_basic_functionality() { - cout << "Testing basic functionality...\n"; - - int sock = create_client_socket(); - if (sock == -1) { - cerr << "Failed to create client socket\n"; - return; - } - - string msg = "Hello, server!\n"; - send(sock, msg.c_str(), msg.length(), 0); - - this_thread::sleep_for(chrono::milliseconds(50)); - close(sock); - - cout << "Basic test completed\n"; - } - - void test_throughput() { + void test_http_workload() + { constexpr int num_clients = 100; - constexpr int messages_per_client = 500; - constexpr int message_size = 1024; - constexpr int expected_messages = num_clients * messages_per_client; + constexpr int requests_per_client = 1000; + constexpr long long expected_requests = num_clients * requests_per_client; - cout << "Testing throughput: " << expected_messages << " messages...\n"; + cout << "Testing HTTP workload: " << num_clients << " clients, " + << requests_per_client << " req/client (" << expected_requests << " total)...\n"; - // Capture baseline memory baseline_memory_ = get_memory_info(); peak_memory_ = baseline_memory_; - // Reset counters - messages_received_ = 0; + requests_processed_ = 0; bytes_received_ = 0; - clients_connected_ = 0; - clients_finished_sending_ = 0; + current_connections_ = 0; + peak_connections_ = 0; start_time_ = chrono::steady_clock::now(); vector clients; clients.reserve(num_clients); - // Start memory monitoring thread atomic monitoring{true}; thread memory_monitor([this, &monitoring]() { while (monitoring) { @@ -158,149 +156,74 @@ private: }); for (int i = 0; i < num_clients; ++i) { - clients.emplace_back([this, i, messages_per_client, message_size]() { - reliable_client_worker(i, messages_per_client, message_size); + clients.emplace_back([this, requests_per_client]() { + http_client_worker(requests_per_client); }); } - // Wait for all clients to connect - while (clients_connected_.load() < num_clients) { - this_thread::sleep_for(chrono::milliseconds(10)); - } - cout << "All clients connected\n"; - - // Wait for all clients to finish for (auto& t : clients) { t.join(); } - cout << "All clients finished sending\n"; - // Wait for server to process all data with timeout - auto deadline = chrono::steady_clock::now() + chrono::seconds(5); - auto last_count = messages_received_.load(); - int stable_iterations = 0; - - while (chrono::steady_clock::now() < deadline) { + // Wait for server to process all requests + auto deadline = chrono::steady_clock::now() + chrono::seconds(10); + while (requests_processed_.load() < expected_requests && chrono::steady_clock::now() < deadline) { this_thread::sleep_for(chrono::milliseconds(50)); - auto current_count = messages_received_.load(); - - if (current_count == last_count) { - stable_iterations++; - if (stable_iterations >= 5) break; // 250ms of stability - } else { - stable_iterations = 0; - last_count = current_count; - } } monitoring = false; memory_monitor.join(); - cout << "Expected: " << expected_messages << ", Received: " << messages_received_.load() << "\n"; + cout << "Expected: " << expected_requests << ", Processed: " << requests_processed_.load() << "\n"; } - void reliable_client_worker(int client_id, int message_count, int message_size) { - int sock = -1; - int retry_count = 0; - constexpr int max_retries = 10; + void http_client_worker(int request_count) + { + int sock = create_client_socket(); + if (sock == -1) return; - // Retry connection with exponential backoff - while (sock == -1 && retry_count < max_retries) { - sock = create_client_socket(); - if (sock == -1) { - retry_count++; - this_thread::sleep_for(chrono::milliseconds(10 << retry_count)); + string request = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"; + char response_buf[1024]; + + for (int i = 0; i < request_count; ++i) { + if (::send(sock, request.c_str(), request.length(), 0) < 0) { + break; // Send failed } - } - if (sock == -1) { - cout << "Client " << client_id << " failed to connect\n"; - return; - } - - // Create message with delimiter - string base_msg(message_size - 1, 'A' + (client_id % 26)); - string msg = base_msg + "\n"; // Add delimiter - - int sent_count = 0; - - // Send messages with proper flow control - for (int i = 0; i < message_count; ++i) { - int attempts = 0; - bool sent = false; - - while (!sent && attempts < 5) { - ssize_t result = send(sock, msg.c_str(), msg.length(), MSG_NOSIGNAL); - - if (result == static_cast(msg.length())) { - sent = true; - sent_count++; - } else if (result == -1) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { - // Socket buffer full, wait longer - this_thread::sleep_for(chrono::milliseconds(1 << attempts)); - attempts++; - } else { - // Connection error - break; - } - } else if (result > 0) { - // Partial send - handle remaining data - size_t remaining = msg.length() - result; - const char* remaining_data = msg.c_str() + result; - - while (remaining > 0) { - ssize_t more = send(sock, remaining_data, remaining, MSG_NOSIGNAL); - if (more <= 0) break; - remaining -= more; - remaining_data += more; - } - - if (remaining == 0) { - sent = true; - sent_count++; + // Read response (basic implementation for testing) + ssize_t bytes_read = 0; + bool response_complete = false; + while (!response_complete) { + ssize_t result = ::recv(sock, response_buf, sizeof(response_buf) - 1, 0); + if (result > 0) { + bytes_read += result; + response_buf[result] = '\0'; + // Simple check for end of our known response + if (strstr(response_buf, "Hello, World!")) { + response_complete = true; } + } else { + break; // Connection closed or error } } - - if (!sent) break; - - // Flow control: small delay every batch - if ((i + 1) % 100 == 0) { - this_thread::sleep_for(chrono::microseconds(500)); - } + if (!response_complete) break; } - // Ensure all data is sent before closing - shutdown(sock, SHUT_WR); - - // Give server time to process - this_thread::sleep_for(chrono::milliseconds(50)); - close(sock); - clients_finished_sending_++; - - if (sent_count != message_count) { - cout << "Client " << client_id << " sent " << sent_count << "/" << message_count << " messages\n"; - } } - int create_client_socket() { + int create_client_socket() + { int sock = socket(AF_INET, SOCK_STREAM, 0); if (sock == -1) return -1; - // Optimize socket int opt = 1; setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)); - // Larger send buffer - int buf_size = 1048576; // 1MB - setsockopt(sock, SOL_SOCKET, SO_SNDBUF, &buf_size, sizeof(buf_size)); - - // Set send timeout struct timeval timeout; timeout.tv_sec = 5; timeout.tv_usec = 0; + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)); setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)); sockaddr_in addr{}; @@ -316,33 +239,36 @@ private: return sock; } - void print_results() { + void print_results() + { auto end_time = chrono::steady_clock::now(); - auto duration = chrono::duration_cast(end_time - start_time_); + auto duration_ms = chrono::duration_cast(end_time - start_time_).count(); + if (duration_ms == 0) duration_ms = 1; - double throughput_mps = static_cast(messages_received_) / duration.count() * 1000; - double throughput_mbps = static_cast(bytes_received_) / (1024 * 1024) / duration.count() * 1000; + long long final_requests = requests_processed_.load(); + double duration_s = static_cast(duration_ms) / 1000.0; + double rps = static_cast(final_requests) / duration_s; + double throughput_mbps = static_cast(bytes_received_) / (1024 * 1024) / duration_s; cout << "\n=== Test Results ===\n"; - cout << "Duration: " << duration.count() << "ms\n"; - cout << "Messages received: " << messages_received_.load() << "\n"; - cout << "Bytes received: " << bytes_received_.load() << "\n"; - cout << "Throughput: " << static_cast(throughput_mps) << " msg/sec\n"; - cout << "Throughput: " << throughput_mbps << " MB/sec\n"; + cout << "Duration: " << duration_ms << " ms\n"; + cout << "Requests processed: " << final_requests << "\n"; + cout << "Throughput: " << static_cast(rps) << " req/sec\n"; + cout << "Data Rate (RX): " << throughput_mbps << " MB/sec\n"; cout << "Peak connections: " << peak_connections_.load() << "\n"; cout << "\n=== Memory Usage ===\n"; cout << "Baseline RSS: " << baseline_memory_.rss_kb << " KB\n"; cout << "Peak RSS: " << peak_memory_.rss_kb << " KB\n"; cout << "Memory increase: " << (peak_memory_.rss_kb - baseline_memory_.rss_kb) << " KB\n"; - if (peak_memory_.vsize_kb > 0) { - cout << "Virtual memory: " << peak_memory_.vsize_kb << " KB\n"; + if (peak_connections_.load() > 0) { + cout << "Memory per connection: " << static_cast(peak_memory_.rss_kb - baseline_memory_.rss_kb) / peak_connections_.load() << " KB\n"; } - cout << "Memory per connection: " << (peak_memory_.rss_kb - baseline_memory_.rss_kb) / peak_connections_.load() << " KB\n"; } }; -int main() { +int main() +{ SocketTest test; test.run(); return 0;