Reactor/lib/Core.hpp
2025-06-28 15:30:14 -05:00

624 lines
14 KiB
C++

#pragma once
#include "Utilities.hpp"
#include <sys/epoll.h>
#include <sys/timerfd.h>
#include <sys/eventfd.h>
#include <poll.h>
#include <unistd.h>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include <functional>
#include <memory>
#include <queue>
#include <chrono>
#include <thread>
#include <atomic>
#include <cassert>
#include <cstring>
#include <mutex>
namespace reactor
{
// Forward declaration: Channel and TimerQueue keep a non-owning EventLoop*.
class EventLoop;
// Timer deadlines are instants on the monotonic steady clock.
using TimePoint = std::chrono::steady_clock::time_point;
// Delays and repeat intervals are whole milliseconds.
using Duration = std::chrono::milliseconds;
// Both timer callbacks and I/O event callbacks take no arguments and return nothing.
using TimerCallback = std::function<void()>;
using EventCallback = std::function<void()>;
class Channel : public NonCopyable
{
private:
EventLoop* loop_;
int fd_;
int events_;
int revents_;
int index_;
bool tied_;
std::weak_ptr<void> tie_;
EventCallback readCallback_;
EventCallback writeCallback_;
EventCallback closeCallback_;
EventCallback errorCallback_;
/*
* Safely handle the event, checking if the tied object is still alive.
*/
void handleEventSafely()
{
LOG_TRACE << "Channel fd=" << fd_ << " handling events: " << revents_;
if ((revents_ & POLLHUP) && !(revents_ & POLLIN)) {
LOG_DEBUG << "Channel fd=" << fd_ << " hangup";
if (closeCallback_) {
closeCallback_();
}
}
if (revents_ & POLLERR) {
LOG_WARN << "Channel fd=" << fd_ << " error event";
if (errorCallback_) {
errorCallback_();
}
}
if (revents_ & (POLLIN | POLLPRI | POLLRDHUP)) {
LOG_TRACE << "Channel fd=" << fd_ << " readable";
if (readCallback_) {
readCallback_();
}
}
if (revents_ & POLLOUT) {
LOG_TRACE << "Channel fd=" << fd_ << " writable";
if (writeCallback_) {
writeCallback_();
}
}
}
public:
static constexpr int kNoneEvent = 0;
static constexpr int kReadEvent = POLLIN | POLLPRI;
static constexpr int kWriteEvent = POLLOUT;
Channel(EventLoop* loop, int fd)
: loop_(loop), fd_(fd), events_(0), revents_(0), index_(-1), tied_(false)
{
LOG_TRACE << "Channel created for fd=" << fd;
}
~Channel()
{
LOG_TRACE << "Channel destroyed for fd=" << fd_;
}
int fd() const { return fd_; }
int events() const { return events_; }
int revents() const { return revents_; }
int index() const { return index_; }
EventLoop* ownerLoop() const { return loop_; }
void setRevents(int revents) { revents_ = revents; }
void setIndex(int index) { index_ = index; }
void enableReading()
{
events_ |= kReadEvent;
update();
LOG_TRACE << "Channel fd=" << fd_ << " enabled reading";
}
void disableReading()
{
events_ &= ~kReadEvent;
update();
LOG_TRACE << "Channel fd=" << fd_ << " disabled reading";
}
void enableWriting()
{
events_ |= kWriteEvent;
update();
LOG_TRACE << "Channel fd=" << fd_ << " enabled writing";
}
void disableWriting()
{
events_ &= ~kWriteEvent;
update();
LOG_TRACE << "Channel fd=" << fd_ << " disabled writing";
}
void disableAll()
{
events_ = kNoneEvent;
update();
LOG_TRACE << "Channel fd=" << fd_ << " disabled all events";
}
bool isNoneEvent() const { return events_ == kNoneEvent; }
bool isWriting() const { return events_ & kWriteEvent; }
bool isReading() const { return events_ & kReadEvent; }
void tie(const std::shared_ptr<void>& obj)
{
tie_ = obj;
tied_ = true;
}
void setReadCallback(EventCallback cb) { readCallback_ = std::move(cb); }
void setWriteCallback(EventCallback cb) { writeCallback_ = std::move(cb); }
void setCloseCallback(EventCallback cb) { closeCallback_ = std::move(cb); }
void setErrorCallback(EventCallback cb) { errorCallback_ = std::move(cb); }
/*
* Handle an event. If the channel is tied to an object,
* ensure the object is still alive before proceeding.
*/
void handleEvent()
{
if (tied_) {
std::shared_ptr<void> guard = tie_.lock();
if (guard) {
handleEventSafely();
} else {
LOG_WARN << "Channel fd=" << fd_ << " tied object expired";
}
} else {
handleEventSafely();
}
}
void remove();
void update();
};
/*
 * Timer: one scheduled callback with an absolute deadline.
 *
 * A positive interval makes the timer repeating; restart() then moves the
 * deadline to `now + interval`. Every Timer gets a unique id from a shared
 * atomic counter (the first id issued is 1).
 */
class Timer : public NonCopyable
{
private:
    TimerCallback callback_;
    TimePoint when_;
    Duration interval_;
    bool repeat_;
    uint64_t id_;
    // Id source shared by all Timers; incremented atomically per construction.
    static inline std::atomic<uint64_t> s_numCreated_{0};

public:
    Timer(TimerCallback cb, TimePoint when, Duration interval = Duration(0))
        : callback_(std::move(cb)),
          when_(when),
          interval_(interval),
          repeat_(interval > Duration::zero()),
          id_(s_numCreated_.fetch_add(1) + 1)
    {
        LOG_TRACE << "Timer created id=" << id_ << " repeat=" << repeat_;
    }

    /* Invoke the stored callback. */
    void run() const
    {
        LOG_TRACE << "Timer id=" << id_ << " executing";
        callback_();
    }

    TimePoint when() const { return when_; }
    bool repeat() const { return repeat_; }
    uint64_t id() const { return id_; }

    /* Reschedule a repeating timer to `now + interval`; one-shot timers are left unchanged. */
    void restart(TimePoint now)
    {
        if (!repeat_) {
            return;
        }
        when_ = now + interval_;
        LOG_TRACE << "Timer id=" << id_ << " restarted";
    }

    /*
     * Note the reversed sense: a Timer that fires LATER compares as "less", so a
     * std::priority_queue<Timer> would surface the earliest deadline.
     * Comparing std::shared_ptr<Timer> objects does not invoke this operator.
     */
    bool operator<(const Timer& other) const { return other.when_ < when_; }
};
class TimerQueue : public NonCopyable
{
private:
EventLoop* loop_;
int timerfd_;
std::unique_ptr<Channel> timerChannel_;
std::priority_queue<std::shared_ptr<Timer>> timers_;
std::unordered_set<uint64_t> activeTimers_;
static int createTimerfd()
{
int fd = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC);
if (fd < 0) {
LOG_FATAL << "Failed to create timerfd: " << strerror(errno);
abort();
}
LOG_DEBUG << "Created timerfd=" << fd;
return fd;
}
void resetTimerfd(TimePoint expiration)
{
auto duration = expiration - std::chrono::steady_clock::now();
auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(duration).count();
if (ns < 100000) {
ns = 100000;
}
itimerspec newValue{};
newValue.it_value.tv_sec = ns / 1000000000;
newValue.it_value.tv_nsec = ns % 1000000000;
if (timerfd_settime(timerfd_, 0, &newValue, nullptr) < 0) {
LOG_ERROR << "timerfd_settime failed: " << strerror(errno);
}
}
void handleRead()
{
uint64_t data;
ssize_t n = read(timerfd_, &data, sizeof(data));
if (n != sizeof(data)) {
LOG_ERROR << "TimerQueue read " << n << " bytes instead of 8";
}
auto now = std::chrono::steady_clock::now();
std::vector<std::shared_ptr<Timer>> expired;
while (!timers_.empty() && timers_.top()->when() <= now) {
expired.push_back(timers_.top());
timers_.pop();
}
LOG_TRACE << "TimerQueue processed " << expired.size() << " expired timers";
for (auto& timer : expired) {
if (activeTimers_.count(timer->id())) {
timer->run();
if (timer->repeat()) {
timer->restart(now);
timers_.push(timer);
} else {
activeTimers_.erase(timer->id());
}
}
}
if (!timers_.empty()) {
resetTimerfd(timers_.top()->when());
}
}
public:
TimerQueue(EventLoop* loop)
: loop_(loop), timerfd_(createTimerfd()),
timerChannel_(std::make_unique<Channel>(loop, timerfd_))
{
timerChannel_->setReadCallback([this]() { handleRead(); });
timerChannel_->enableReading();
LOG_DEBUG << "TimerQueue initialized";
}
~TimerQueue()
{
timerChannel_->disableAll();
timerChannel_->remove();
close(timerfd_);
LOG_DEBUG << "TimerQueue destroyed";
}
uint64_t addTimer(TimerCallback cb, TimePoint when, Duration interval = Duration(0))
{
auto timer = std::make_shared<Timer>(std::move(cb), when, interval);
bool earliestChanged = timers_.empty() || when < timers_.top()->when();
timers_.push(timer);
activeTimers_.insert(timer->id());
if (earliestChanged) {
resetTimerfd(when);
}
LOG_DEBUG << "Added timer id=" << timer->id() << " earliest_changed=" << earliestChanged;
return timer->id();
}
void cancel(uint64_t timerId)
{
auto erased = activeTimers_.erase(timerId);
LOG_DEBUG << "Cancelled timer id=" << timerId << " found=" << (erased > 0);
}
};
class EpollPoller : public NonCopyable
{
private:
static constexpr int kNew = -1;
static constexpr int kAdded = 1;
static constexpr int kDeleted = 2;
static constexpr int kInitEventListSize = 16;
int epollfd_;
std::vector<epoll_event> events_;
std::unordered_map<int, Channel*> channels_;
void update(int operation, Channel* channel)
{
epoll_event event{};
event.events = channel->events();
event.data.ptr = channel;
if (epoll_ctl(epollfd_, operation, channel->fd(), &event) < 0) {
LOG_ERROR << "epoll_ctl op=" << operation << " fd=" << channel->fd()
<< " failed: " << strerror(errno);
} else {
LOG_TRACE << "epoll_ctl op=" << operation << " fd=" << channel->fd() << " success";
}
}
public:
EpollPoller() : epollfd_(epoll_create1(EPOLL_CLOEXEC)), events_(kInitEventListSize)
{
if (epollfd_ < 0) {
LOG_FATAL << "epoll_create1 failed: " << strerror(errno);
abort();
}
LOG_DEBUG << "EpollPoller created with epollfd=" << epollfd_;
}
~EpollPoller()
{
close(epollfd_);
LOG_DEBUG << "EpollPoller destroyed";
}
std::vector<Channel*> poll(int timeoutMs = -1)
{
int numEvents = epoll_wait(epollfd_, events_.data(),
static_cast<int>(events_.size()), timeoutMs);
std::vector<Channel*> activeChannels;
if (numEvents > 0) {
LOG_TRACE << "EpollPoller got " << numEvents << " events";
for (int i = 0; i < numEvents; ++i) {
auto channel = static_cast<Channel*>(events_[i].data.ptr);
channel->setRevents(events_[i].events);
activeChannels.push_back(channel);
}
if (static_cast<size_t>(numEvents) == events_.size()) {
events_.resize(events_.size() * 2);
LOG_DEBUG << "EpollPoller events buffer grown to " << events_.size();
}
} else if (numEvents < 0 && errno != EINTR) {
LOG_ERROR << "EpollPoller::poll failed: " << strerror(errno);
}
return activeChannels;
}
void updateChannel(Channel* channel)
{
int index = channel->index();
int fd = channel->fd();
if (index == kNew || index == kDeleted) {
if (index == kNew) {
channels_[fd] = channel;
}
channel->setIndex(kAdded);
update(EPOLL_CTL_ADD, channel);
} else {
if (channel->isNoneEvent()) {
update(EPOLL_CTL_DEL, channel);
channel->setIndex(kDeleted);
} else {
update(EPOLL_CTL_MOD, channel);
}
}
}
void removeChannel(Channel* channel)
{
int fd = channel->fd();
int index = channel->index();
auto n = channels_.erase(fd);
assert(n == 1);
if (index == kAdded) {
update(EPOLL_CTL_DEL, channel);
}
channel->setIndex(kNew);
LOG_TRACE << "Channel fd=" << fd << " removed from poller";
}
};
/*
 * EventLoop: the reactor core. Each loop() iteration waits on an EpollPoller
 * (for at most kPollTimeMs), dispatches every ready Channel, then runs the
 * functors queued through queueInLoop().
 *
 * The "loop thread" is whichever thread calls loop(); until then
 * isInLoopThread() is false for every thread. Cross-thread work should go
 * through runInLoop()/queueInLoop(): the queue is guarded by mutex_, and an
 * eventfd is used to wake a blocked poll.
 */
class EventLoop : public NonCopyable
{
private:
    // Longest single epoll wait; quit_ is re-checked after every wait.
    static constexpr int kPollTimeMs = 10000;

    // poller_ is declared first so it is constructed first and destroyed last:
    // timerQueue_ and wakeupChannel_ register/unregister channels through it.
    std::unique_ptr<EpollPoller> poller_;
    std::unique_ptr<TimerQueue> timerQueue_;
    int wakeupFd_;
    std::unique_ptr<Channel> wakeupChannel_;
    std::atomic<bool> looping_;
    std::atomic<bool> quit_;
    std::thread::id threadId_;      // recorded by loop(); default id until then
    std::mutex mutex_;              // guards pendingFunctors_
    std::vector<std::function<void()>> pendingFunctors_;
    bool callingPendingFunctors_;   // queueInLoop() reads this only on the loop thread

    /* Create the non-blocking, close-on-exec eventfd used for wakeups; aborts on failure. */
    static int createEventfd()
    {
        int fd = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC);
        if (fd < 0) {
            LOG_FATAL << "Failed to create eventfd: " << strerror(errno);
            abort();
        }
        LOG_DEBUG << "Created eventfd=" << fd;
        return fd;
    }

    /* Make the eventfd readable so a blocked poll() returns. */
    void wakeup()
    {
        uint64_t one = 1;
        ssize_t n = write(wakeupFd_, &one, sizeof(one));
        if (n != static_cast<ssize_t>(sizeof(one))) {
            LOG_ERROR << "EventLoop wakeup writes " << n << " bytes instead of 8";
        }
    }

    /* Drain the eventfd counter after a wakeup. */
    void handleRead()
    {
        uint64_t one;
        ssize_t n = read(wakeupFd_, &one, sizeof(one));
        if (n != static_cast<ssize_t>(sizeof(one))) {
            LOG_ERROR << "EventLoop handleRead reads " << n << " bytes instead of 8";
        }
    }

    /*
     * Run every queued functor. The queue is swapped out under the lock and run
     * without holding it, so functors may call queueInLoop() without deadlocking.
     * Anything they queue runs on a later iteration; queueInLoop() wakes the
     * poller in that case because callingPendingFunctors_ is set.
     */
    void doPendingFunctors()
    {
        std::vector<std::function<void()>> functors;
        callingPendingFunctors_ = true;
        {
            std::lock_guard<std::mutex> lock(mutex_);
            functors.swap(pendingFunctors_);
        }
        if (!functors.empty()) {
            LOG_TRACE << "EventLoop executed " << functors.size() << " pending functors";
        }
        for (const auto& functor : functors) {
            functor();
        }
        callingPendingFunctors_ = false;
    }

public:
    EventLoop()
        : poller_(std::make_unique<EpollPoller>()),
          timerQueue_(std::make_unique<TimerQueue>(this)),
          wakeupFd_(createEventfd()),
          wakeupChannel_(std::make_unique<Channel>(this, wakeupFd_)),
          looping_(false), quit_(false),
          threadId_(),
          pendingFunctors_(),
          callingPendingFunctors_(false)
    {
        wakeupChannel_->setReadCallback([this]() { handleRead(); });
        wakeupChannel_->enableReading();
        LOG_INFO << "EventLoop created";
    }

    ~EventLoop()
    {
        wakeupChannel_->disableAll();
        wakeupChannel_->remove();
        close(wakeupFd_);
        LOG_INFO << "EventLoop destroyed";
    }

    /*
     * Run the loop on the calling thread until quit() is requested.
     *
     * quit_ is NOT reset on entry: a quit() issued before loop() starts (e.g.
     * from the thread that spawned the loop thread) is honoured instead of being
     * silently lost. It is cleared on exit so the loop can be run again later.
     */
    void loop()
    {
        assert(!looping_);
        threadId_ = std::this_thread::get_id();
        looping_ = true;
        LOG_INFO << "EventLoop started looping in thread " << threadId_;
        while (!quit_) {
            auto activeChannels = poller_->poll(kPollTimeMs);
            for (auto channel : activeChannels) {
                channel->handleEvent();
            }
            doPendingFunctors();
        }
        quit_ = false;
        looping_ = false;
        LOG_INFO << "EventLoop stopped looping";
    }

    /*
     * Ask the loop to stop after its current iteration. Off the loop thread the
     * poller is woken so the request is seen without waiting for the poll timeout.
     */
    void quit()
    {
        quit_ = true;
        if (!isInLoopThread()) {
            wakeup();
        }
        LOG_DEBUG << "EventLoop quit requested";
    }

    /* Run `cb` now if called on the loop thread, otherwise queue it for the loop. */
    template<typename F>
    void runInLoop(F&& cb)
    {
        if (isInLoopThread()) {
            cb();
        } else {
            queueInLoop(std::forward<F>(cb));
        }
    }

    /*
     * Queue `cb` to run after the current/next batch of I/O dispatch.
     * Wakes the poller when called from another thread, or from inside
     * doPendingFunctors() (otherwise the next poll could block for its full timeout).
     */
    template<typename F>
    void queueInLoop(F&& cb)
    {
        {
            std::lock_guard<std::mutex> lock(mutex_);
            pendingFunctors_.emplace_back(std::forward<F>(cb));
        }
        if (!isInLoopThread() || callingPendingFunctors_) {
            wakeup();
        }
    }

    // Timer API. These call TimerQueue directly, and TimerQueue takes no locks,
    // so they should presumably be invoked on the loop thread (e.g. via
    // runInLoop) — confirm with callers.
    uint64_t runAt(TimePoint when, TimerCallback cb)
    {
        return timerQueue_->addTimer(std::move(cb), when);
    }

    uint64_t runAfter(Duration delay, TimerCallback cb)
    {
        return runAt(std::chrono::steady_clock::now() + delay, std::move(cb));
    }

    uint64_t runEvery(Duration interval, TimerCallback cb)
    {
        auto when = std::chrono::steady_clock::now() + interval;
        return timerQueue_->addTimer(std::move(cb), when, interval);
    }

    void cancel(uint64_t timerId) { timerQueue_->cancel(timerId); }
    void updateChannel(Channel* channel) { poller_->updateChannel(channel); }
    void removeChannel(Channel* channel) { poller_->removeChannel(channel); }

    bool isInLoopThread() const
    {
        return threadId_ == std::this_thread::get_id();
    }

    // NOTE(review): the message says "created in", but threadId_ is the thread
    // that called loop(), not the constructing thread.
    void assertInLoopThread() const
    {
        if (!isInLoopThread()) {
            LOG_FATAL << "EventLoop was created in thread " << threadId_
                      << " but accessed from thread " << std::this_thread::get_id();
            abort();
        }
    }
};
inline void Channel::update()
{
loop_->updateChannel(this);
}
inline void Channel::remove()
{
assert(isNoneEvent());
loop_->removeChannel(this);
}
}