update threading

This commit is contained in:
Sky Johnson 2025-06-28 17:32:57 -05:00
parent 9ca52ef39a
commit 2ac41374d5
2 changed files with 134 additions and 55 deletions

View File

@ -3,61 +3,77 @@
#include "Core.hpp" #include "Core.hpp"
#include "Utilities.hpp" #include "Utilities.hpp"
#include <thread> #include <thread>
#include <mutex> #include <string>
#include <condition_variable>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <atomic> #include <atomic>
#include <latch>
namespace reactor { namespace reactor
{
class EventLoopThread : public NonCopyable class EventLoopThread : public NonCopyable
{ {
private: private:
std::thread thread_; std::jthread thread_;
EventLoop* loop_; EventLoop* loop_;
std::mutex mutex_; std::latch readyLatch_;
std::condition_variable cond_;
std::string name_; std::string name_;
/*
* The function executed by the internal thread.
* It creates the EventLoop, signals readiness, and runs the loop.
*/
void threadFunc() void threadFunc()
{ {
LOG_DEBUG << "EventLoopThread '" << name_ << "' starting";
EventLoop loop; EventLoop loop;
{ loop_ = &loop;
std::lock_guard<std::mutex> lock(mutex_); readyLatch_.count_down();
loop_ = &loop;
cond_.notify_one();
}
loop.loop(); loop.loop();
std::lock_guard<std::mutex> lock(mutex_);
loop_ = nullptr;
LOG_DEBUG << "EventLoopThread '" << name_ << "' finished";
} }
public: public:
/*
* Constructs an EventLoopThread.
* Starts a new thread and creates an EventLoop within it,
* waiting until the loop is fully initialized.
*/
explicit EventLoopThread(const std::string& name = "EventLoopThread") explicit EventLoopThread(const std::string& name = "EventLoopThread")
: loop_(nullptr), name_(name) : loop_(nullptr), readyLatch_(1), name_(name)
{ {
thread_ = std::thread([this]() { threadFunc(); }); thread_ = std::jthread(&EventLoopThread::threadFunc, this);
std::unique_lock<std::mutex> lock(mutex_); readyLatch_.wait();
cond_.wait(lock, [this]() { return loop_ != nullptr; });
LOG_INFO << "EventLoopThread '" << name_ << "' initialized"; LOG_INFO << "EventLoopThread '" << name_ << "' initialized";
} }
/*
* Destructs the EventLoopThread.
* Quits the event loop. The underlying jthread automatically joins.
*/
~EventLoopThread() ~EventLoopThread()
{ {
if (loop_) { if (loop_) {
loop_->quit(); loop_->quit();
thread_.join();
} }
LOG_INFO << "EventLoopThread '" << name_ << "' destroyed"; LOG_INFO << "EventLoopThread '" << name_ << "' destroyed";
} }
EventLoop* getLoop() { return loop_; } /*
const std::string& name() const { return name_; } * Returns the EventLoop associated with this thread.
* The pointer is valid for the lifetime of the thread.
*/
EventLoop* getLoop()
{
return loop_;
}
/*
* Returns the name of the thread.
*/
const std::string& name() const
{
return name_;
}
}; };
class EventLoopThreadPool : public NonCopyable class EventLoopThreadPool : public NonCopyable
@ -69,6 +85,10 @@ private:
std::string baseThreadName_; std::string baseThreadName_;
public: public:
/*
* Constructs an EventLoopThreadPool.
* Creates a pool of threads, each running an EventLoop.
*/
explicit EventLoopThreadPool(size_t numThreads, const std::string& baseName = "EventLoopThread") explicit EventLoopThreadPool(size_t numThreads, const std::string& baseName = "EventLoopThread")
: next_(0), baseThreadName_(baseName) : next_(0), baseThreadName_(baseName)
{ {
@ -84,11 +104,19 @@ public:
LOG_INFO << "EventLoopThreadPool created with " << numThreads << " threads"; LOG_INFO << "EventLoopThreadPool created with " << numThreads << " threads";
} }
/*
* Destructs the thread pool.
* The threads will be quit and joined automatically via their destructors.
*/
~EventLoopThreadPool() ~EventLoopThreadPool()
{ {
LOG_INFO << "EventLoopThreadPool destroying " << threads_.size() << " threads"; LOG_INFO << "EventLoopThreadPool destroying " << threads_.size() << " threads";
} }
/*
* Gets the next EventLoop from the pool in a round-robin fashion.
* This method is thread-safe.
*/
EventLoop* getNextLoop() EventLoop* getNextLoop()
{ {
if (loops_.empty()) { if (loops_.empty()) {
@ -96,15 +124,34 @@ public:
return nullptr; return nullptr;
} }
size_t index = next_++ % loops_.size(); size_t index = next_.fetch_add(1, std::memory_order_relaxed) % loops_.size();
LOG_TRACE << "EventLoopThreadPool returning loop " << index; LOG_TRACE << "EventLoopThreadPool returning loop " << index;
return loops_[index]; return loops_[index];
} }
std::vector<EventLoop*> getAllLoops() const { return loops_; } /*
size_t size() const { return loops_.size(); } * Returns pointers to all EventLoops in the pool.
*/
std::vector<EventLoop*> getAllLoops() const
{
return loops_;
}
const std::string& getBaseName() const { return baseThreadName_; } /*
* Returns the number of threads in the pool.
*/
size_t size() const
{
return loops_.size();
}
/*
* Returns the base name for threads in the pool.
*/
const std::string& getBaseName() const
{
return baseThreadName_;
}
}; };
} // namespace reactor }

View File

@ -7,7 +7,11 @@
#include <atomic> #include <atomic>
#include <vector> #include <vector>
#include <future> #include <future>
#include <memory>
/*
* Tests basic creation and task execution in an EventLoopThread.
*/
void test_event_loop_thread_basic() void test_event_loop_thread_basic()
{ {
std::cout << "Testing basic EventLoopThread...\n"; std::cout << "Testing basic EventLoopThread...\n";
@ -29,6 +33,9 @@ void test_event_loop_thread_basic()
std::cout << "✓ Basic EventLoopThread passed\n"; std::cout << "✓ Basic EventLoopThread passed\n";
} }
/*
* Tests timer functionality within an EventLoopThread.
*/
void test_event_loop_thread_timer() void test_event_loop_thread_timer()
{ {
std::cout << "Testing EventLoopThread timer...\n"; std::cout << "Testing EventLoopThread timer...\n";
@ -50,23 +57,26 @@ void test_event_loop_thread_timer()
std::cout << "✓ EventLoopThread timer passed (count: " << final_count << ")\n"; std::cout << "✓ EventLoopThread timer passed (count: " << final_count << ")\n";
} }
/*
* Tests the creation and interaction of multiple EventLoopThreads.
*/
void test_multiple_event_loop_threads() void test_multiple_event_loop_threads()
{ {
std::cout << "Testing multiple EventLoopThreads...\n"; std::cout << "Testing multiple EventLoopThreads...\n";
constexpr int num_threads = 3; constexpr int num_threads = 3;
std::vector<std::unique_ptr<reactor::EventLoopThread>> threads; std::vector<std::unique_ptr<reactor::EventLoopThread>> threads;
std::vector<std::atomic<int>*> counters; std::vector<std::unique_ptr<std::atomic<int>>> counters;
for (int i = 0; i < num_threads; ++i) { for (int i = 0; i < num_threads; ++i) {
std::string name = "Thread-" + std::to_string(i); std::string name = "Thread-" + std::to_string(i);
threads.push_back(std::make_unique<reactor::EventLoopThread>(name)); threads.push_back(std::make_unique<reactor::EventLoopThread>(name));
counters.push_back(new std::atomic<int>{0}); counters.push_back(std::make_unique<std::atomic<int>>(0));
} }
for (int i = 0; i < num_threads; ++i) { for (int i = 0; i < num_threads; ++i) {
auto loop = threads[i]->getLoop(); auto loop = threads[i]->getLoop();
auto counter = counters[i]; auto* counter = counters[i].get();
for (int j = 0; j < 10; ++j) { for (int j = 0; j < 10; ++j) {
loop->queueInLoop([counter]() { loop->queueInLoop([counter]() {
@ -79,12 +89,14 @@ void test_multiple_event_loop_threads()
for (int i = 0; i < num_threads; ++i) { for (int i = 0; i < num_threads; ++i) {
assert(*counters[i] == 10); assert(*counters[i] == 10);
delete counters[i];
} }
std::cout << "✓ Multiple EventLoopThreads passed\n"; std::cout << "✓ Multiple EventLoopThreads passed\n";
} }
/*
* Tests the basic creation of an EventLoopThreadPool.
*/
void test_event_loop_thread_pool_basic() void test_event_loop_thread_pool_basic()
{ {
std::cout << "Testing basic EventLoopThreadPool...\n"; std::cout << "Testing basic EventLoopThreadPool...\n";
@ -104,6 +116,9 @@ void test_event_loop_thread_pool_basic()
std::cout << "✓ Basic EventLoopThreadPool passed\n"; std::cout << "✓ Basic EventLoopThreadPool passed\n";
} }
/*
* Tests the round-robin distribution of loops from the pool.
*/
void test_thread_pool_round_robin() void test_thread_pool_round_robin()
{ {
std::cout << "Testing thread pool round robin...\n"; std::cout << "Testing thread pool round robin...\n";
@ -115,6 +130,7 @@ void test_thread_pool_round_robin()
selected_loops.push_back(pool.getNextLoop()); selected_loops.push_back(pool.getNextLoop());
} }
// With 3 threads, loop i, i+3, and i+6 should be the same.
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
assert(selected_loops[i] == selected_loops[i + 3]); assert(selected_loops[i] == selected_loops[i + 3]);
assert(selected_loops[i] == selected_loops[i + 6]); assert(selected_loops[i] == selected_loops[i + 6]);
@ -123,15 +139,18 @@ void test_thread_pool_round_robin()
std::cout << "✓ Thread pool round robin passed\n"; std::cout << "✓ Thread pool round robin passed\n";
} }
/*
* Tests that tasks are distributed among threads in the pool.
*/
void test_thread_pool_task_distribution() void test_thread_pool_task_distribution()
{ {
std::cout << "Testing thread pool task distribution...\n"; std::cout << "Testing thread pool task distribution...\n";
reactor::EventLoopThreadPool pool(3, "TaskDist"); reactor::EventLoopThreadPool pool(3, "TaskDist");
std::vector<std::atomic<int>*> counters(3); std::vector<std::unique_ptr<std::atomic<int>>> counters;
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
counters[i] = new std::atomic<int>{0}; counters.push_back(std::make_unique<std::atomic<int>>(0));
} }
std::map<reactor::EventLoop*, int> loop_to_index; std::map<reactor::EventLoop*, int> loop_to_index;
@ -143,9 +162,9 @@ void test_thread_pool_task_distribution()
constexpr int tasks_per_loop = 10; constexpr int tasks_per_loop = 10;
for (int i = 0; i < 3 * tasks_per_loop; ++i) { for (int i = 0; i < 3 * tasks_per_loop; ++i) {
auto loop = pool.getNextLoop(); auto loop = pool.getNextLoop();
int index = loop_to_index[loop]; int index = loop_to_index.at(loop);
loop->queueInLoop([counter = counters[index]]() { loop->queueInLoop([counter = counters[index].get()]() {
(*counter)++; (*counter)++;
}); });
} }
@ -154,12 +173,14 @@ void test_thread_pool_task_distribution()
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
assert(*counters[i] == tasks_per_loop); assert(*counters[i] == tasks_per_loop);
delete counters[i];
} }
std::cout << "✓ Thread pool task distribution passed\n"; std::cout << "✓ Thread pool task distribution passed\n";
} }
/*
* Tests the behavior of an empty thread pool.
*/
void test_empty_thread_pool() void test_empty_thread_pool()
{ {
std::cout << "Testing empty thread pool...\n"; std::cout << "Testing empty thread pool...\n";
@ -173,6 +194,9 @@ void test_empty_thread_pool()
std::cout << "✓ Empty thread pool passed\n"; std::cout << "✓ Empty thread pool passed\n";
} }
/*
* Tests concurrent access to the thread pool's getNextLoop method.
*/
void test_thread_pool_concurrent_access() void test_thread_pool_concurrent_access()
{ {
std::cout << "Testing thread pool concurrent access...\n"; std::cout << "Testing thread pool concurrent access...\n";
@ -205,6 +229,9 @@ void test_thread_pool_concurrent_access()
std::cout << "✓ Thread pool concurrent access passed\n"; std::cout << "✓ Thread pool concurrent access passed\n";
} }
/*
* Tests timer functionality across all threads in a pool.
*/
void test_thread_pool_with_timers() void test_thread_pool_with_timers()
{ {
std::cout << "Testing thread pool with timers...\n"; std::cout << "Testing thread pool with timers...\n";
@ -224,7 +251,7 @@ void test_thread_pool_with_timers()
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
for (int i = 0; i < loops.size(); ++i) { for (size_t i = 0; i < loops.size(); ++i) {
loops[i]->cancel(timer_ids[i]); loops[i]->cancel(timer_ids[i]);
} }
@ -234,6 +261,9 @@ void test_thread_pool_with_timers()
std::cout << "✓ Thread pool with timers passed (count: " << final_count << ")\n"; std::cout << "✓ Thread pool with timers passed (count: " << final_count << ")\n";
} }
/*
* Tests task synchronization between the main thread and a loop thread.
*/
void test_thread_synchronization() void test_thread_synchronization()
{ {
std::cout << "Testing thread synchronization...\n"; std::cout << "Testing thread synchronization...\n";
@ -243,22 +273,20 @@ void test_thread_synchronization()
std::vector<int> results; std::vector<int> results;
std::mutex results_mutex; std::mutex results_mutex;
std::vector<std::future<void>> futures; std::vector<std::future<void>> futures;
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
std::promise<void> promise; // Create a shared_ptr to the promise to make the lambda copyable
auto future = promise.get_future(); auto promise_ptr = std::make_shared<std::promise<void>>();
futures.push_back(promise_ptr->get_future());
loop->queueInLoop([i, &results, &results_mutex, promise = std::move(promise)]() mutable { loop->queueInLoop([i, &results, &results_mutex, p = promise_ptr]() {
{ {
std::lock_guard<std::mutex> lock(results_mutex); std::lock_guard<std::mutex> lock(results_mutex);
results.push_back(i); results.push_back(i);
} }
promise.set_value(); p->set_value();
}); });
futures.push_back(std::move(future));
} }
for (auto& future : futures) { for (auto& future : futures) {
@ -270,29 +298,33 @@ void test_thread_synchronization()
std::cout << "✓ Thread synchronization passed\n"; std::cout << "✓ Thread synchronization passed\n";
} }
/*
* Tests that tasks are completed before a thread pool is destructed.
*/
void test_thread_pool_destruction() void test_thread_pool_destruction()
{ {
std::cout << "Testing thread pool destruction...\n"; std::cout << "Testing thread pool destruction...\n";
std::atomic<int> destructor_count{0}; std::atomic<int> task_count{0};
{ {
reactor::EventLoopThreadPool pool(2, "DestroyTest"); reactor::EventLoopThreadPool pool(2, "DestroyTest");
auto loops = pool.getAllLoops(); auto loops = pool.getAllLoops();
for (auto loop : loops) { for (auto loop : loops) {
loop->queueInLoop([&destructor_count]() { loop->queueInLoop([&task_count]() {
destructor_count++; std::this_thread::sleep_for(std::chrono::milliseconds(20));
task_count++;
}); });
} }
} // pool is destroyed here, blocking until threads finish
std::this_thread::sleep_for(std::chrono::milliseconds(50)); assert(task_count == 2);
}
assert(destructor_count == 2);
std::cout << "✓ Thread pool destruction passed\n"; std::cout << "✓ Thread pool destruction passed\n";
} }
/*
* Main entry point for running all threading tests.
*/
int main() int main()
{ {
std::cout << "=== Threading Tests ===\n"; std::cout << "=== Threading Tests ===\n";
@ -309,6 +341,6 @@ int main()
test_thread_synchronization(); test_thread_synchronization();
test_thread_pool_destruction(); test_thread_pool_destruction();
std::cout << "All threading tests passed! ✓\n"; std::cout << "\nAll threading tests passed! ✓\n";
return 0; return 0;
} }