#pragma once #include "Utilities.hpp" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace reactor { using TimePoint = std::chrono::steady_clock::time_point; using Duration = std::chrono::milliseconds; using TimerCallback = std::function; using EventCallback = std::function; class EventLoop; class Channel : public NonCopyable { private: EventLoop* loop_; int fd_; int events_; int revents_; int index_; bool tied_; std::weak_ptr tie_; EventCallback readCallback_; EventCallback writeCallback_; EventCallback closeCallback_; EventCallback errorCallback_; /* * Safely handle the event, checking if the tied object is still alive. */ void handleEventSafely() { LOG_TRACE << "Channel fd=" << fd_ << " handling events: " << revents_; if ((revents_ & POLLHUP) && !(revents_ & POLLIN)) { LOG_DEBUG << "Channel fd=" << fd_ << " hangup"; if (closeCallback_) { closeCallback_(); } } if (revents_ & POLLERR) { LOG_WARN << "Channel fd=" << fd_ << " error event"; if (errorCallback_) { errorCallback_(); } } if (revents_ & (POLLIN | POLLPRI | POLLRDHUP)) { LOG_TRACE << "Channel fd=" << fd_ << " readable"; if (readCallback_) { readCallback_(); } } if (revents_ & POLLOUT) { LOG_TRACE << "Channel fd=" << fd_ << " writable"; if (writeCallback_) { writeCallback_(); } } } public: static constexpr int kNoneEvent = 0; static constexpr int kReadEvent = POLLIN | POLLPRI; static constexpr int kWriteEvent = POLLOUT; Channel(EventLoop* loop, int fd) : loop_(loop), fd_(fd), events_(0), revents_(0), index_(-1), tied_(false) { LOG_TRACE << "Channel created for fd=" << fd; } ~Channel() { LOG_TRACE << "Channel destroyed for fd=" << fd_; } int fd() const { return fd_; } int events() const { return events_; } int revents() const { return revents_; } int index() const { return index_; } EventLoop* ownerLoop() const { return loop_; } void setRevents(int revents) { revents_ = revents; } void setIndex(int index) { index_ = index; } void enableReading() { events_ |= kReadEvent; update(); LOG_TRACE << "Channel fd=" << fd_ << " enabled reading"; } void disableReading() { events_ &= ~kReadEvent; update(); LOG_TRACE << "Channel fd=" << fd_ << " disabled reading"; } void enableWriting() { events_ |= kWriteEvent; update(); LOG_TRACE << "Channel fd=" << fd_ << " enabled writing"; } void disableWriting() { events_ &= ~kWriteEvent; update(); LOG_TRACE << "Channel fd=" << fd_ << " disabled writing"; } void disableAll() { events_ = kNoneEvent; update(); LOG_TRACE << "Channel fd=" << fd_ << " disabled all events"; } bool isNoneEvent() const { return events_ == kNoneEvent; } bool isWriting() const { return events_ & kWriteEvent; } bool isReading() const { return events_ & kReadEvent; } void tie(const std::shared_ptr& obj) { tie_ = obj; tied_ = true; } void setReadCallback(EventCallback cb) { readCallback_ = std::move(cb); } void setWriteCallback(EventCallback cb) { writeCallback_ = std::move(cb); } void setCloseCallback(EventCallback cb) { closeCallback_ = std::move(cb); } void setErrorCallback(EventCallback cb) { errorCallback_ = std::move(cb); } /* * Handle an event. If the channel is tied to an object, * ensure the object is still alive before proceeding. */ void handleEvent() { if (tied_) { std::shared_ptr guard = tie_.lock(); if (guard) { handleEventSafely(); } else { LOG_WARN << "Channel fd=" << fd_ << " tied object expired"; } } else { handleEventSafely(); } } void remove(); void update(); }; class Timer : public NonCopyable { private: TimerCallback callback_; TimePoint when_; Duration interval_; bool repeat_; uint64_t id_; static inline std::atomic s_numCreated_{0}; public: Timer(TimerCallback cb, TimePoint when, Duration interval = Duration(0)) : callback_(std::move(cb)), when_(when), interval_(interval), repeat_(interval.count() > 0), id_(++s_numCreated_) { LOG_TRACE << "Timer created id=" << id_ << " repeat=" << repeat_; } void run() const { LOG_TRACE << "Timer id=" << id_ << " executing"; callback_(); } TimePoint when() const { return when_; } bool repeat() const { return repeat_; } uint64_t id() const { return id_; } void restart(TimePoint now) { if (repeat_) { when_ = now + interval_; LOG_TRACE << "Timer id=" << id_ << " restarted"; } } bool operator<(const Timer& other) const { return when_ > other.when_; } }; class TimerQueue : public NonCopyable { private: EventLoop* loop_; int timerfd_; std::unique_ptr timerChannel_; std::priority_queue> timers_; std::unordered_set activeTimers_; static int createTimerfd() { int fd = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC); if (fd < 0) { LOG_FATAL << "Failed to create timerfd: " << strerror(errno); abort(); } LOG_DEBUG << "Created timerfd=" << fd; return fd; } void resetTimerfd(TimePoint expiration) { auto duration = expiration - std::chrono::steady_clock::now(); auto ns = std::chrono::duration_cast(duration).count(); if (ns < 100000) { ns = 100000; } itimerspec newValue{}; newValue.it_value.tv_sec = ns / 1000000000; newValue.it_value.tv_nsec = ns % 1000000000; if (timerfd_settime(timerfd_, 0, &newValue, nullptr) < 0) { LOG_ERROR << "timerfd_settime failed: " << strerror(errno); } } void handleRead() { uint64_t data; ssize_t n = read(timerfd_, &data, sizeof(data)); if (n != sizeof(data)) { LOG_ERROR << "TimerQueue read " << n << " bytes instead of 8"; } auto now = std::chrono::steady_clock::now(); std::vector> expired; while (!timers_.empty() && timers_.top()->when() <= now) { expired.push_back(timers_.top()); timers_.pop(); } LOG_TRACE << "TimerQueue processed " << expired.size() << " expired timers"; for (auto& timer : expired) { if (activeTimers_.count(timer->id())) { timer->run(); if (timer->repeat()) { timer->restart(now); timers_.push(timer); } else { activeTimers_.erase(timer->id()); } } } if (!timers_.empty()) { resetTimerfd(timers_.top()->when()); } } public: TimerQueue(EventLoop* loop) : loop_(loop), timerfd_(createTimerfd()), timerChannel_(std::make_unique(loop, timerfd_)) { timerChannel_->setReadCallback([this]() { handleRead(); }); timerChannel_->enableReading(); LOG_DEBUG << "TimerQueue initialized"; } ~TimerQueue() { timerChannel_->disableAll(); timerChannel_->remove(); close(timerfd_); LOG_DEBUG << "TimerQueue destroyed"; } uint64_t addTimer(TimerCallback cb, TimePoint when, Duration interval = Duration(0)) { auto timer = std::make_shared(std::move(cb), when, interval); bool earliestChanged = timers_.empty() || when < timers_.top()->when(); timers_.push(timer); activeTimers_.insert(timer->id()); if (earliestChanged) { resetTimerfd(when); } LOG_DEBUG << "Added timer id=" << timer->id() << " earliest_changed=" << earliestChanged; return timer->id(); } void cancel(uint64_t timerId) { auto erased = activeTimers_.erase(timerId); LOG_DEBUG << "Cancelled timer id=" << timerId << " found=" << (erased > 0); } }; class EpollPoller : public NonCopyable { private: static constexpr int kNew = -1; static constexpr int kAdded = 1; static constexpr int kDeleted = 2; static constexpr int kInitEventListSize = 16; int epollfd_; std::vector events_; std::unordered_map channels_; void update(int operation, Channel* channel) { epoll_event event{}; event.events = channel->events(); event.data.ptr = channel; if (epoll_ctl(epollfd_, operation, channel->fd(), &event) < 0) { LOG_ERROR << "epoll_ctl op=" << operation << " fd=" << channel->fd() << " failed: " << strerror(errno); } else { LOG_TRACE << "epoll_ctl op=" << operation << " fd=" << channel->fd() << " success"; } } public: EpollPoller() : epollfd_(epoll_create1(EPOLL_CLOEXEC)), events_(kInitEventListSize) { if (epollfd_ < 0) { LOG_FATAL << "epoll_create1 failed: " << strerror(errno); abort(); } LOG_DEBUG << "EpollPoller created with epollfd=" << epollfd_; } ~EpollPoller() { close(epollfd_); LOG_DEBUG << "EpollPoller destroyed"; } std::vector poll(int timeoutMs = -1) { int numEvents = epoll_wait(epollfd_, events_.data(), static_cast(events_.size()), timeoutMs); std::vector activeChannels; if (numEvents > 0) { LOG_TRACE << "EpollPoller got " << numEvents << " events"; for (int i = 0; i < numEvents; ++i) { auto channel = static_cast(events_[i].data.ptr); channel->setRevents(events_[i].events); activeChannels.push_back(channel); } if (static_cast(numEvents) == events_.size()) { events_.resize(events_.size() * 2); LOG_DEBUG << "EpollPoller events buffer grown to " << events_.size(); } } else if (numEvents < 0 && errno != EINTR) { LOG_ERROR << "EpollPoller::poll failed: " << strerror(errno); } return activeChannels; } void updateChannel(Channel* channel) { int index = channel->index(); int fd = channel->fd(); if (index == kNew || index == kDeleted) { if (index == kNew) { channels_[fd] = channel; } channel->setIndex(kAdded); update(EPOLL_CTL_ADD, channel); } else { if (channel->isNoneEvent()) { update(EPOLL_CTL_DEL, channel); channel->setIndex(kDeleted); } else { update(EPOLL_CTL_MOD, channel); } } } void removeChannel(Channel* channel) { int fd = channel->fd(); int index = channel->index(); auto n = channels_.erase(fd); assert(n == 1); if (index == kAdded) { update(EPOLL_CTL_DEL, channel); } channel->setIndex(kNew); LOG_TRACE << "Channel fd=" << fd << " removed from poller"; } }; class EventLoop : public NonCopyable { private: std::unique_ptr poller_; std::unique_ptr timerQueue_; int wakeupFd_; std::unique_ptr wakeupChannel_; std::atomic looping_; std::atomic quit_; std::thread::id threadId_; std::mutex mutex_; std::vector> pendingFunctors_; bool callingPendingFunctors_; static int createEventfd() { int fd = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC); if (fd < 0) { LOG_FATAL << "Failed to create eventfd: " << strerror(errno); abort(); } LOG_DEBUG << "Created eventfd=" << fd; return fd; } void wakeup() { uint64_t one = 1; ssize_t n = write(wakeupFd_, &one, sizeof(one)); if (n != sizeof(one)) { LOG_ERROR << "EventLoop wakeup writes " << n << " bytes instead of 8"; } } void handleRead() { uint64_t one; ssize_t n = read(wakeupFd_, &one, sizeof(one)); if (n != sizeof(one)) { LOG_ERROR << "EventLoop handleRead reads " << n << " bytes instead of 8"; } } void doPendingFunctors() { std::vector> functors; callingPendingFunctors_ = true; { std::lock_guard lock(mutex_); functors.swap(pendingFunctors_); } if (!functors.empty()) { LOG_TRACE << "EventLoop executed " << functors.size() << " pending functors"; } for (const auto& functor : functors) { functor(); } callingPendingFunctors_ = false; } public: EventLoop() : poller_(std::make_unique()), timerQueue_(std::make_unique(this)), wakeupFd_(createEventfd()), wakeupChannel_(std::make_unique(this, wakeupFd_)), looping_(false), quit_(false), threadId_(), pendingFunctors_(), callingPendingFunctors_(false) { wakeupChannel_->setReadCallback([this]() { handleRead(); }); wakeupChannel_->enableReading(); LOG_INFO << "EventLoop created"; } ~EventLoop() { wakeupChannel_->disableAll(); wakeupChannel_->remove(); close(wakeupFd_); LOG_INFO << "EventLoop destroyed"; } void loop() { assert(!looping_); threadId_ = std::this_thread::get_id(); looping_ = true; quit_ = false; LOG_INFO << "EventLoop started looping in thread " << threadId_; while (!quit_) { auto activeChannels = poller_->poll(10000); for (auto channel : activeChannels) { channel->handleEvent(); } doPendingFunctors(); } looping_ = false; LOG_INFO << "EventLoop stopped looping"; } void quit() { quit_ = true; if (!isInLoopThread()) { wakeup(); } LOG_DEBUG << "EventLoop quit requested"; } template void runInLoop(F&& cb) { if (isInLoopThread()) { cb(); } else { queueInLoop(std::forward(cb)); } } template void queueInLoop(F&& cb) { { std::lock_guard lock(mutex_); pendingFunctors_.emplace_back(std::forward(cb)); } if (!isInLoopThread() || callingPendingFunctors_) { wakeup(); } } uint64_t runAt(TimePoint when, TimerCallback cb) { return timerQueue_->addTimer(std::move(cb), when); } uint64_t runAfter(Duration delay, TimerCallback cb) { return runAt(std::chrono::steady_clock::now() + delay, std::move(cb)); } uint64_t runEvery(Duration interval, TimerCallback cb) { auto when = std::chrono::steady_clock::now() + interval; return timerQueue_->addTimer(std::move(cb), when, interval); } void cancel(uint64_t timerId) { timerQueue_->cancel(timerId); } void updateChannel(Channel* channel) { poller_->updateChannel(channel); } void removeChannel(Channel* channel) { poller_->removeChannel(channel); } bool isInLoopThread() const { return threadId_ == std::this_thread::get_id(); } void assertInLoopThread() const { if (!isInLoopThread()) { LOG_FATAL << "EventLoop was created in thread " << threadId_ << " but accessed from thread " << std::this_thread::get_id(); abort(); } } }; inline void Channel::update() { loop_->updateChannel(this); } inline void Channel::remove() { assert(isNoneEvent()); loop_->removeChannel(this); } }