// Reactor/tests/test_tcp_server.cpp
//
// Tests for reactor::TcpServer (C++).

#include "../lib/TcpServer.hpp"
#include "../lib/Socket.hpp"
#include "../lib/InetAddress.hpp"
#include "../lib/Core.hpp"
#include "../lib/Buffer.hpp"
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cerrno>
#include <chrono>
#include <iostream>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <vector>
/*
 * A minimal TCP client used to drive the TcpServer under test.
 *
 * Thin wrapper over reactor::Socket. Every operation reports failure through
 * its return value rather than throwing, so callers can assert on the result.
 */
class TestClient
{
private:
    reactor::Socket socket_;

public:
    TestClient()
        : socket_(reactor::Socket::createTcp())
    {
    }

    /*
     * Starts connecting to `addr`.
     *
     * Returns true when connect() succeeds immediately or fails with
     * EINPROGRESS (handshake still running, as on a non-blocking socket).
     * The short sleep gives an in-progress handshake time to finish; it does
     * not verify that the connection was actually established.
     */
    bool connect(const reactor::InetAddress& addr)
    {
        const int result = socket_.connect(addr);
        if (result == 0 || errno == EINPROGRESS) {
            std::this_thread::sleep_for(std::chrono::milliseconds(20));
            return true;
        }
        return false;
    }

    /*
     * Writes all of `data`, continuing after short (partial) writes.
     *
     * Returns true only if every byte was written. Returns false as soon as a
     * write reports an error or makes no progress.
     */
    bool send(const std::string& data)
    {
        size_t offset = 0;
        while (offset < data.size()) {
            const ssize_t sent = socket_.write(data.data() + offset, data.size() - offset);
            if (sent <= 0) {
                return false;
            }
            offset += static_cast<size_t>(sent);
        }
        return true;
    }

    /*
     * Performs a single read of at most min(max_size, 1024) bytes.
     * Larger `max_size` values are clamped to the internal buffer size.
     *
     * Returns the bytes read, or an empty string when read() returns 0 or a
     * negative value. That covers EOF, errors, and (if the socket is
     * non-blocking) no data available yet.
     */
    std::string receive(size_t max_size = 1024)
    {
        char buffer[1024];
        const ssize_t received = socket_.read(buffer, std::min(max_size, sizeof(buffer)));
        if (received > 0) {
            return std::string(buffer, static_cast<size_t>(received));
        }
        return "";
    }

    /*
     * Half-closes the connection: only the write side is shut down (so the
     * peer sees end-of-stream). The Socket object itself lives until this
     * TestClient is destroyed.
     */
    void close()
    {
        socket_.shutdownWrite();
    }
};
/*
 * Tests basic server functionality: start, connect, send/receive, disconnect.
 *
 * Cross-thread conditions are awaited by polling with a deadline rather than
 * fixed sleeps. Calls with side effects (send) run outside assert() so the
 * request is still sent when the file is built with NDEBUG.
 */
void test_tcp_server_basic()
{
    std::cout << "Testing basic TCP server...\n";

    // Calls `pred` (once per iteration) until it returns true or ~2 s elapse.
    // Returns whether `pred` succeeded.
    const auto wait_until = [](const auto& pred) {
        const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(2);
        for (;;) {
            if (pred()) {
                return true;
            }
            if (std::chrono::steady_clock::now() >= deadline) {
                return false;
            }
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
        }
    };

    reactor::EventLoop loop;
    reactor::InetAddress listen_addr(0); // Port 0 asks OS for any free port
    reactor::TcpServer server(&loop, listen_addr, "TestServer");
    std::atomic<bool> connection_received{false};
    std::atomic<bool> message_received{false};
    server.setConnectionCallback([&](const reactor::TcpConnectionPtr& conn) {
        if (conn->connected()) {
            connection_received = true;
            std::cout << "New connection: " << conn->name() << "\n";
        } else {
            std::cout << "Connection closed: " << conn->name() << "\n";
        }
    });
    server.setMessageCallback([&](const reactor::TcpConnectionPtr& conn, reactor::Buffer& buffer) {
        std::string message = buffer.readAll();
        std::cout << "Received: " << message << "\n";
        message_received = true;
        conn->send("Echo: " + message);
    });
    server.start();
    // Get the actual port assigned by the OS after the server starts listening
    reactor::InetAddress actual_listen_addr("127.0.0.1", server.listenAddress().port());
    std::thread server_thread([&loop]() {
        loop.loop();
    });
    // Give the loop thread a moment to enter loop().
    std::this_thread::sleep_for(std::chrono::milliseconds(10));

    TestClient client;
    const bool connected = client.connect(actual_listen_addr);
    assert(connected);
    const bool accepted = wait_until([&] { return connection_received.load(); });
    assert(accepted);

    const bool sent = client.send("Hello Server");
    assert(sent);
    const bool delivered = wait_until([&] { return message_received.load(); });
    assert(delivered);

    // The flag is set before the echo is sent, and TCP may deliver the reply
    // in pieces, so keep reading until the whole expected reply has arrived.
    const std::string expected = "Echo: Hello Server";
    std::string response;
    wait_until([&] {
        response += client.receive();
        return response.size() >= expected.size();
    });
    assert(response == expected);

    client.close();
    loop.quit();
    server_thread.join();
    std::cout << "✓ Basic TCP server passed\n";
}
/*
* Tests the server's ability to handle multiple concurrent connections.
*/
void test_multiple_connections()
{
std::cout << "Testing multiple connections...\n";
reactor::EventLoop loop;
reactor::InetAddress listen_addr(0);
reactor::TcpServer server(&loop, listen_addr, "MultiServer");
std::atomic<int> connection_count{0};
std::atomic<int> message_count{0};
server.setConnectionCallback([&](const reactor::TcpConnectionPtr& conn) {
if (conn->connected()) {
connection_count++;
} else {
connection_count.fetch_sub(1);
}
});
server.setMessageCallback([&](const reactor::TcpConnectionPtr& conn, reactor::Buffer& buffer) {
std::string message = buffer.readAll();
message_count++;
conn->send("Response: " + message);
});
server.start();
reactor::InetAddress actual_listen_addr("127.0.0.1", server.listenAddress().port());
std::thread server_thread([&loop]() {
loop.loop();
});
std::this_thread::sleep_for(std::chrono::milliseconds(10));
constexpr int num_clients = 5;
std::vector<std::unique_ptr<TestClient>> clients;
for (int i = 0; i < num_clients; ++i) {
auto client = std::make_unique<TestClient>();
bool connected = client->connect(actual_listen_addr);
assert(connected);
clients.push_back(std::move(client));
}
std::this_thread::sleep_for(std::chrono::milliseconds(20));
assert(connection_count == num_clients);
for (int i = 0; i < num_clients; ++i) {
std::string message = "Message " + std::to_string(i);
assert(clients[i]->send(message));
}
std::this_thread::sleep_for(std::chrono::milliseconds(50));
assert(message_count == num_clients);
for (auto& client : clients) {
client->close();
}
std::this_thread::sleep_for(std::chrono::milliseconds(50));
assert(connection_count == 0);
loop.quit();
server_thread.join();
std::cout << "✓ Multiple connections passed\n";
}
/*
 * Tests the server's thread pool for distributing work.
 *
 * Four clients connect concurrently to a server configured with
 * setThreadNum(2). The test checks that every message is processed and echoed
 * and that message callbacks ran on at least two distinct threads. Replies and
 * counts are awaited by polling with a deadline instead of fixed sleeps.
 * send() runs outside assert() so messages are still sent under NDEBUG.
 */
void test_server_with_thread_pool()
{
    std::cout << "Testing server with thread pool...\n";

    // Calls `pred` (once per iteration) until it returns true or ~2 s elapse.
    // Returns whether `pred` succeeded. Stateless, so safe to share across threads.
    const auto wait_until = [](const auto& pred) {
        const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(2);
        for (;;) {
            if (pred()) {
                return true;
            }
            if (std::chrono::steady_clock::now() >= deadline) {
                return false;
            }
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
        }
    };

    reactor::EventLoop loop;
    reactor::InetAddress listen_addr(0); // Port 0 asks OS for any free port
    reactor::TcpServer server(&loop, listen_addr, "ThreadPoolServer");
    server.setThreadNum(2);
    std::atomic<int> message_count{0};
    std::vector<std::thread::id> thread_ids;
    std::mutex thread_ids_mutex;
    server.setMessageCallback([&](const reactor::TcpConnectionPtr& conn, reactor::Buffer& buffer) {
        {
            std::lock_guard<std::mutex> lock(thread_ids_mutex);
            thread_ids.push_back(std::this_thread::get_id());
        }
        std::string message = buffer.readAll();
        ++message_count;
        // Simulate a little work so handlers on different threads overlap.
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
        conn->send("Processed: " + message);
    });
    server.start();
    reactor::InetAddress actual_listen_addr("127.0.0.1", server.listenAddress().port());
    std::thread server_thread([&loop]() {
        loop.loop();
    });
    // Give the loop thread a moment to enter loop().
    std::this_thread::sleep_for(std::chrono::milliseconds(20));

    constexpr int num_clients = 4;
    std::vector<std::thread> client_threads;
    client_threads.reserve(num_clients);
    for (int i = 0; i < num_clients; ++i) {
        client_threads.emplace_back([&actual_listen_addr, &wait_until, i]() {
            TestClient client;
            const bool connected = client.connect(actual_listen_addr);
            assert(connected);
            const std::string message = "Client" + std::to_string(i);
            const bool sent = client.send(message);
            assert(sent);
            // Keep reading until the full reply has arrived (TCP may split it).
            const std::string expected = "Processed: " + message;
            std::string response;
            wait_until([&] {
                response += client.receive(1024);
                return response.size() >= expected.size();
            });
            assert(response == expected);
            client.close();
        });
    }
    for (auto& thread : client_threads) {
        thread.join();
    }
    const bool all_processed = wait_until([&] { return message_count.load() == num_clients; });
    assert(all_processed);
    {
        std::lock_guard<std::mutex> lock(thread_ids_mutex);
        const std::set<std::thread::id> unique_threads(thread_ids.begin(), thread_ids.end());
        assert(unique_threads.size() >= 2);
    }
    loop.quit();
    server_thread.join();
    std::cout << "✓ Server with thread pool passed\n";
}
/*
 * Test driver: runs each TCP server test in a fixed order, then prints a
 * summary line. A failing assert (when assertions are enabled) aborts the run
 * before the summary is printed.
 */
int main()
{
    using TestCase = void (*)();
    static const TestCase kTests[] = {
        test_tcp_server_basic,
        test_multiple_connections,
        test_server_with_thread_pool,
    };

    std::cout << "=== TCP Server Tests ===\n";
    for (const TestCase run_test : kTests) {
        run_test();
    }
    std::cout << "All TCP server tests passed! ✓\n";
    return 0;
}