diff --git a/lib/EventLoopThread.hpp b/lib/EventLoopThread.hpp index 3a208d1..0a84d2d 100644 --- a/lib/EventLoopThread.hpp +++ b/lib/EventLoopThread.hpp @@ -3,61 +3,77 @@ #include "Core.hpp" #include "Utilities.hpp" #include -#include -#include +#include #include #include #include +#include -namespace reactor { +namespace reactor +{ class EventLoopThread : public NonCopyable { private: - std::thread thread_; + std::jthread thread_; EventLoop* loop_; - std::mutex mutex_; - std::condition_variable cond_; + std::latch readyLatch_; std::string name_; + /* + * The function executed by the internal thread. + * It creates the EventLoop, signals readiness, and runs the loop. + */ void threadFunc() { - LOG_DEBUG << "EventLoopThread '" << name_ << "' starting"; EventLoop loop; - { - std::lock_guard lock(mutex_); - loop_ = &loop; - cond_.notify_one(); - } - + loop_ = &loop; + readyLatch_.count_down(); loop.loop(); - - std::lock_guard lock(mutex_); - loop_ = nullptr; - LOG_DEBUG << "EventLoopThread '" << name_ << "' finished"; } public: + /* + * Constructs an EventLoopThread. + * Starts a new thread and creates an EventLoop within it, + * waiting until the loop is fully initialized. + */ explicit EventLoopThread(const std::string& name = "EventLoopThread") - : loop_(nullptr), name_(name) + : loop_(nullptr), readyLatch_(1), name_(name) { - thread_ = std::thread([this]() { threadFunc(); }); - std::unique_lock lock(mutex_); - cond_.wait(lock, [this]() { return loop_ != nullptr; }); + thread_ = std::jthread(&EventLoopThread::threadFunc, this); + readyLatch_.wait(); LOG_INFO << "EventLoopThread '" << name_ << "' initialized"; } + /* + * Destructs the EventLoopThread. + * Quits the event loop. The underlying jthread automatically joins. + */ ~EventLoopThread() { if (loop_) { loop_->quit(); - thread_.join(); } LOG_INFO << "EventLoopThread '" << name_ << "' destroyed"; } - EventLoop* getLoop() { return loop_; } - const std::string& name() const { return name_; } + /* + * Returns the EventLoop associated with this thread. + * The pointer is valid for the lifetime of the thread. + */ + EventLoop* getLoop() + { + return loop_; + } + + /* + * Returns the name of the thread. + */ + const std::string& name() const + { + return name_; + } }; class EventLoopThreadPool : public NonCopyable @@ -69,6 +85,10 @@ private: std::string baseThreadName_; public: + /* + * Constructs an EventLoopThreadPool. + * Creates a pool of threads, each running an EventLoop. + */ explicit EventLoopThreadPool(size_t numThreads, const std::string& baseName = "EventLoopThread") : next_(0), baseThreadName_(baseName) { @@ -84,11 +104,19 @@ public: LOG_INFO << "EventLoopThreadPool created with " << numThreads << " threads"; } + /* + * Destructs the thread pool. + * The threads will be quit and joined automatically via their destructors. + */ ~EventLoopThreadPool() { LOG_INFO << "EventLoopThreadPool destroying " << threads_.size() << " threads"; } + /* + * Gets the next EventLoop from the pool in a round-robin fashion. + * This method is thread-safe. + */ EventLoop* getNextLoop() { if (loops_.empty()) { @@ -96,15 +124,34 @@ public: return nullptr; } - size_t index = next_++ % loops_.size(); + size_t index = next_.fetch_add(1, std::memory_order_relaxed) % loops_.size(); LOG_TRACE << "EventLoopThreadPool returning loop " << index; return loops_[index]; } - std::vector getAllLoops() const { return loops_; } - size_t size() const { return loops_.size(); } + /* + * Returns pointers to all EventLoops in the pool. + */ + std::vector getAllLoops() const + { + return loops_; + } - const std::string& getBaseName() const { return baseThreadName_; } + /* + * Returns the number of threads in the pool. + */ + size_t size() const + { + return loops_.size(); + } + + /* + * Returns the base name for threads in the pool. + */ + const std::string& getBaseName() const + { + return baseThreadName_; + } }; -} // namespace reactor +} diff --git a/tests/test_threading.cpp b/tests/test_threading.cpp index a66b2dc..6093099 100644 --- a/tests/test_threading.cpp +++ b/tests/test_threading.cpp @@ -7,7 +7,11 @@ #include #include #include +#include +/* +* Tests basic creation and task execution in an EventLoopThread. +*/ void test_event_loop_thread_basic() { std::cout << "Testing basic EventLoopThread...\n"; @@ -29,6 +33,9 @@ void test_event_loop_thread_basic() std::cout << "✓ Basic EventLoopThread passed\n"; } +/* +* Tests timer functionality within an EventLoopThread. +*/ void test_event_loop_thread_timer() { std::cout << "Testing EventLoopThread timer...\n"; @@ -50,23 +57,26 @@ void test_event_loop_thread_timer() std::cout << "✓ EventLoopThread timer passed (count: " << final_count << ")\n"; } +/* +* Tests the creation and interaction of multiple EventLoopThreads. +*/ void test_multiple_event_loop_threads() { std::cout << "Testing multiple EventLoopThreads...\n"; constexpr int num_threads = 3; std::vector> threads; - std::vector*> counters; + std::vector>> counters; for (int i = 0; i < num_threads; ++i) { std::string name = "Thread-" + std::to_string(i); threads.push_back(std::make_unique(name)); - counters.push_back(new std::atomic{0}); + counters.push_back(std::make_unique>(0)); } for (int i = 0; i < num_threads; ++i) { auto loop = threads[i]->getLoop(); - auto counter = counters[i]; + auto* counter = counters[i].get(); for (int j = 0; j < 10; ++j) { loop->queueInLoop([counter]() { @@ -79,12 +89,14 @@ void test_multiple_event_loop_threads() for (int i = 0; i < num_threads; ++i) { assert(*counters[i] == 10); - delete counters[i]; } std::cout << "✓ Multiple EventLoopThreads passed\n"; } +/* +* Tests the basic creation of an EventLoopThreadPool. +*/ void test_event_loop_thread_pool_basic() { std::cout << "Testing basic EventLoopThreadPool...\n"; @@ -104,6 +116,9 @@ void test_event_loop_thread_pool_basic() std::cout << "✓ Basic EventLoopThreadPool passed\n"; } +/* +* Tests the round-robin distribution of loops from the pool. +*/ void test_thread_pool_round_robin() { std::cout << "Testing thread pool round robin...\n"; @@ -115,6 +130,7 @@ void test_thread_pool_round_robin() selected_loops.push_back(pool.getNextLoop()); } + // With 3 threads, loop i, i+3, and i+6 should be the same. for (int i = 0; i < 3; ++i) { assert(selected_loops[i] == selected_loops[i + 3]); assert(selected_loops[i] == selected_loops[i + 6]); @@ -123,15 +139,18 @@ void test_thread_pool_round_robin() std::cout << "✓ Thread pool round robin passed\n"; } +/* +* Tests that tasks are distributed among threads in the pool. +*/ void test_thread_pool_task_distribution() { std::cout << "Testing thread pool task distribution...\n"; reactor::EventLoopThreadPool pool(3, "TaskDist"); - std::vector*> counters(3); + std::vector>> counters; for (int i = 0; i < 3; ++i) { - counters[i] = new std::atomic{0}; + counters.push_back(std::make_unique>(0)); } std::map loop_to_index; @@ -143,9 +162,9 @@ void test_thread_pool_task_distribution() constexpr int tasks_per_loop = 10; for (int i = 0; i < 3 * tasks_per_loop; ++i) { auto loop = pool.getNextLoop(); - int index = loop_to_index[loop]; + int index = loop_to_index.at(loop); - loop->queueInLoop([counter = counters[index]]() { + loop->queueInLoop([counter = counters[index].get()]() { (*counter)++; }); } @@ -154,12 +173,14 @@ void test_thread_pool_task_distribution() for (int i = 0; i < 3; ++i) { assert(*counters[i] == tasks_per_loop); - delete counters[i]; } std::cout << "✓ Thread pool task distribution passed\n"; } +/* +* Tests the behavior of an empty thread pool. +*/ void test_empty_thread_pool() { std::cout << "Testing empty thread pool...\n"; @@ -173,6 +194,9 @@ void test_empty_thread_pool() std::cout << "✓ Empty thread pool passed\n"; } +/* +* Tests concurrent access to the thread pool's getNextLoop method. +*/ void test_thread_pool_concurrent_access() { std::cout << "Testing thread pool concurrent access...\n"; @@ -205,6 +229,9 @@ void test_thread_pool_concurrent_access() std::cout << "✓ Thread pool concurrent access passed\n"; } +/* +* Tests timer functionality across all threads in a pool. +*/ void test_thread_pool_with_timers() { std::cout << "Testing thread pool with timers...\n"; @@ -224,7 +251,7 @@ void test_thread_pool_with_timers() std::this_thread::sleep_for(std::chrono::milliseconds(100)); - for (int i = 0; i < loops.size(); ++i) { + for (size_t i = 0; i < loops.size(); ++i) { loops[i]->cancel(timer_ids[i]); } @@ -234,6 +261,9 @@ void test_thread_pool_with_timers() std::cout << "✓ Thread pool with timers passed (count: " << final_count << ")\n"; } +/* +* Tests task synchronization between the main thread and a loop thread. +*/ void test_thread_synchronization() { std::cout << "Testing thread synchronization...\n"; @@ -243,22 +273,20 @@ void test_thread_synchronization() std::vector results; std::mutex results_mutex; - std::vector> futures; for (int i = 0; i < 10; ++i) { - std::promise promise; - auto future = promise.get_future(); + // Create a shared_ptr to the promise to make the lambda copyable + auto promise_ptr = std::make_shared>(); + futures.push_back(promise_ptr->get_future()); - loop->queueInLoop([i, &results, &results_mutex, promise = std::move(promise)]() mutable { + loop->queueInLoop([i, &results, &results_mutex, p = promise_ptr]() { { std::lock_guard lock(results_mutex); results.push_back(i); } - promise.set_value(); + p->set_value(); }); - - futures.push_back(std::move(future)); } for (auto& future : futures) { @@ -270,29 +298,33 @@ void test_thread_synchronization() std::cout << "✓ Thread synchronization passed\n"; } +/* +* Tests that tasks are completed before a thread pool is destructed. +*/ void test_thread_pool_destruction() { std::cout << "Testing thread pool destruction...\n"; - std::atomic destructor_count{0}; + std::atomic task_count{0}; { reactor::EventLoopThreadPool pool(2, "DestroyTest"); - auto loops = pool.getAllLoops(); for (auto loop : loops) { - loop->queueInLoop([&destructor_count]() { - destructor_count++; + loop->queueInLoop([&task_count]() { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + task_count++; }); } + } // pool is destroyed here, blocking until threads finish - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - } - - assert(destructor_count == 2); + assert(task_count == 2); std::cout << "✓ Thread pool destruction passed\n"; } +/* +* Main entry point for running all threading tests. +*/ int main() { std::cout << "=== Threading Tests ===\n"; @@ -309,6 +341,6 @@ int main() test_thread_synchronization(); test_thread_pool_destruction(); - std::cout << "All threading tests passed! ✓\n"; + std::cout << "\nAll threading tests passed! ✓\n"; return 0; }