624 lines
14 KiB
C++
624 lines
14 KiB
C++
#pragma once
|
|
|
|
#include "Utilities.hpp"
|
|
#include <sys/epoll.h>
|
|
#include <sys/timerfd.h>
|
|
#include <sys/eventfd.h>
|
|
#include <poll.h>
|
|
#include <unistd.h>
|
|
#include <vector>
|
|
#include <unordered_map>
|
|
#include <unordered_set>
|
|
#include <functional>
|
|
#include <memory>
|
|
#include <queue>
|
|
#include <chrono>
|
|
#include <thread>
|
|
#include <atomic>
|
|
#include <cassert>
|
|
#include <cstring>
|
|
#include <mutex>
|
|
|
|
namespace reactor
|
|
{
|
|
|
|
// ---------------------------------------------------------------------------
// Vocabulary types shared by every component declared in this header.
// ---------------------------------------------------------------------------

// Deadlines are measured on the monotonic steady clock.
using TimePoint = std::chrono::steady_clock::time_point;

// Timer delays and repeat intervals, in whole milliseconds.
using Duration = std::chrono::milliseconds;

// Timer expirations and channel readiness notifications both invoke a
// nullary callable.
using TimerCallback = std::function<void()>;
using EventCallback = std::function<void()>;

// Defined further down. Channel and TimerQueue only store an EventLoop*,
// so a forward declaration is sufficient at this point.
class EventLoop;
|
|
|
|
class Channel : public NonCopyable
|
|
{
|
|
private:
|
|
EventLoop* loop_;
|
|
int fd_;
|
|
int events_;
|
|
int revents_;
|
|
int index_;
|
|
bool tied_;
|
|
std::weak_ptr<void> tie_;
|
|
EventCallback readCallback_;
|
|
EventCallback writeCallback_;
|
|
EventCallback closeCallback_;
|
|
EventCallback errorCallback_;
|
|
|
|
/*
|
|
* Safely handle the event, checking if the tied object is still alive.
|
|
*/
|
|
void handleEventSafely()
|
|
{
|
|
LOG_TRACE << "Channel fd=" << fd_ << " handling events: " << revents_;
|
|
|
|
if ((revents_ & POLLHUP) && !(revents_ & POLLIN)) {
|
|
LOG_DEBUG << "Channel fd=" << fd_ << " hangup";
|
|
if (closeCallback_) {
|
|
closeCallback_();
|
|
}
|
|
}
|
|
if (revents_ & POLLERR) {
|
|
LOG_WARN << "Channel fd=" << fd_ << " error event";
|
|
if (errorCallback_) {
|
|
errorCallback_();
|
|
}
|
|
}
|
|
if (revents_ & (POLLIN | POLLPRI | POLLRDHUP)) {
|
|
LOG_TRACE << "Channel fd=" << fd_ << " readable";
|
|
if (readCallback_) {
|
|
readCallback_();
|
|
}
|
|
}
|
|
if (revents_ & POLLOUT) {
|
|
LOG_TRACE << "Channel fd=" << fd_ << " writable";
|
|
if (writeCallback_) {
|
|
writeCallback_();
|
|
}
|
|
}
|
|
}
|
|
|
|
public:
|
|
static constexpr int kNoneEvent = 0;
|
|
static constexpr int kReadEvent = POLLIN | POLLPRI;
|
|
static constexpr int kWriteEvent = POLLOUT;
|
|
|
|
Channel(EventLoop* loop, int fd)
|
|
: loop_(loop), fd_(fd), events_(0), revents_(0), index_(-1), tied_(false)
|
|
{
|
|
LOG_TRACE << "Channel created for fd=" << fd;
|
|
}
|
|
|
|
~Channel()
|
|
{
|
|
LOG_TRACE << "Channel destroyed for fd=" << fd_;
|
|
}
|
|
|
|
int fd() const { return fd_; }
|
|
int events() const { return events_; }
|
|
int revents() const { return revents_; }
|
|
int index() const { return index_; }
|
|
EventLoop* ownerLoop() const { return loop_; }
|
|
|
|
void setRevents(int revents) { revents_ = revents; }
|
|
void setIndex(int index) { index_ = index; }
|
|
|
|
void enableReading()
|
|
{
|
|
events_ |= kReadEvent;
|
|
update();
|
|
LOG_TRACE << "Channel fd=" << fd_ << " enabled reading";
|
|
}
|
|
|
|
void disableReading()
|
|
{
|
|
events_ &= ~kReadEvent;
|
|
update();
|
|
LOG_TRACE << "Channel fd=" << fd_ << " disabled reading";
|
|
}
|
|
|
|
void enableWriting()
|
|
{
|
|
events_ |= kWriteEvent;
|
|
update();
|
|
LOG_TRACE << "Channel fd=" << fd_ << " enabled writing";
|
|
}
|
|
|
|
void disableWriting()
|
|
{
|
|
events_ &= ~kWriteEvent;
|
|
update();
|
|
LOG_TRACE << "Channel fd=" << fd_ << " disabled writing";
|
|
}
|
|
|
|
void disableAll()
|
|
{
|
|
events_ = kNoneEvent;
|
|
update();
|
|
LOG_TRACE << "Channel fd=" << fd_ << " disabled all events";
|
|
}
|
|
|
|
bool isNoneEvent() const { return events_ == kNoneEvent; }
|
|
bool isWriting() const { return events_ & kWriteEvent; }
|
|
bool isReading() const { return events_ & kReadEvent; }
|
|
|
|
void tie(const std::shared_ptr<void>& obj)
|
|
{
|
|
tie_ = obj;
|
|
tied_ = true;
|
|
}
|
|
|
|
void setReadCallback(EventCallback cb) { readCallback_ = std::move(cb); }
|
|
void setWriteCallback(EventCallback cb) { writeCallback_ = std::move(cb); }
|
|
void setCloseCallback(EventCallback cb) { closeCallback_ = std::move(cb); }
|
|
void setErrorCallback(EventCallback cb) { errorCallback_ = std::move(cb); }
|
|
|
|
/*
|
|
* Handle an event. If the channel is tied to an object,
|
|
* ensure the object is still alive before proceeding.
|
|
*/
|
|
void handleEvent()
|
|
{
|
|
if (tied_) {
|
|
std::shared_ptr<void> guard = tie_.lock();
|
|
if (guard) {
|
|
handleEventSafely();
|
|
} else {
|
|
LOG_WARN << "Channel fd=" << fd_ << " tied object expired";
|
|
}
|
|
} else {
|
|
handleEventSafely();
|
|
}
|
|
}
|
|
|
|
void remove();
|
|
void update();
|
|
};
|
|
|
|
class Timer : public NonCopyable
|
|
{
|
|
private:
|
|
TimerCallback callback_;
|
|
TimePoint when_;
|
|
Duration interval_;
|
|
bool repeat_;
|
|
uint64_t id_;
|
|
static inline std::atomic<uint64_t> s_numCreated_{0};
|
|
|
|
public:
|
|
Timer(TimerCallback cb, TimePoint when, Duration interval = Duration(0))
|
|
: callback_(std::move(cb)), when_(when), interval_(interval),
|
|
repeat_(interval.count() > 0), id_(++s_numCreated_)
|
|
{
|
|
LOG_TRACE << "Timer created id=" << id_ << " repeat=" << repeat_;
|
|
}
|
|
|
|
void run() const
|
|
{
|
|
LOG_TRACE << "Timer id=" << id_ << " executing";
|
|
callback_();
|
|
}
|
|
|
|
TimePoint when() const { return when_; }
|
|
bool repeat() const { return repeat_; }
|
|
uint64_t id() const { return id_; }
|
|
|
|
void restart(TimePoint now)
|
|
{
|
|
if (repeat_) {
|
|
when_ = now + interval_;
|
|
LOG_TRACE << "Timer id=" << id_ << " restarted";
|
|
}
|
|
}
|
|
|
|
bool operator<(const Timer& other) const { return when_ > other.when_; }
|
|
};
|
|
|
|
class TimerQueue : public NonCopyable
|
|
{
|
|
private:
|
|
EventLoop* loop_;
|
|
int timerfd_;
|
|
std::unique_ptr<Channel> timerChannel_;
|
|
std::priority_queue<std::shared_ptr<Timer>> timers_;
|
|
std::unordered_set<uint64_t> activeTimers_;
|
|
|
|
static int createTimerfd()
|
|
{
|
|
int fd = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC);
|
|
if (fd < 0) {
|
|
LOG_FATAL << "Failed to create timerfd: " << strerror(errno);
|
|
abort();
|
|
}
|
|
LOG_DEBUG << "Created timerfd=" << fd;
|
|
return fd;
|
|
}
|
|
|
|
void resetTimerfd(TimePoint expiration)
|
|
{
|
|
auto duration = expiration - std::chrono::steady_clock::now();
|
|
auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(duration).count();
|
|
if (ns < 100000) {
|
|
ns = 100000;
|
|
}
|
|
|
|
itimerspec newValue{};
|
|
newValue.it_value.tv_sec = ns / 1000000000;
|
|
newValue.it_value.tv_nsec = ns % 1000000000;
|
|
|
|
if (timerfd_settime(timerfd_, 0, &newValue, nullptr) < 0) {
|
|
LOG_ERROR << "timerfd_settime failed: " << strerror(errno);
|
|
}
|
|
}
|
|
|
|
void handleRead()
|
|
{
|
|
uint64_t data;
|
|
ssize_t n = read(timerfd_, &data, sizeof(data));
|
|
if (n != sizeof(data)) {
|
|
LOG_ERROR << "TimerQueue read " << n << " bytes instead of 8";
|
|
}
|
|
|
|
auto now = std::chrono::steady_clock::now();
|
|
std::vector<std::shared_ptr<Timer>> expired;
|
|
|
|
while (!timers_.empty() && timers_.top()->when() <= now) {
|
|
expired.push_back(timers_.top());
|
|
timers_.pop();
|
|
}
|
|
|
|
LOG_TRACE << "TimerQueue processed " << expired.size() << " expired timers";
|
|
|
|
for (auto& timer : expired) {
|
|
if (activeTimers_.count(timer->id())) {
|
|
timer->run();
|
|
if (timer->repeat()) {
|
|
timer->restart(now);
|
|
timers_.push(timer);
|
|
} else {
|
|
activeTimers_.erase(timer->id());
|
|
}
|
|
}
|
|
}
|
|
|
|
if (!timers_.empty()) {
|
|
resetTimerfd(timers_.top()->when());
|
|
}
|
|
}
|
|
|
|
public:
|
|
TimerQueue(EventLoop* loop)
|
|
: loop_(loop), timerfd_(createTimerfd()),
|
|
timerChannel_(std::make_unique<Channel>(loop, timerfd_))
|
|
{
|
|
timerChannel_->setReadCallback([this]() { handleRead(); });
|
|
timerChannel_->enableReading();
|
|
LOG_DEBUG << "TimerQueue initialized";
|
|
}
|
|
|
|
~TimerQueue()
|
|
{
|
|
timerChannel_->disableAll();
|
|
timerChannel_->remove();
|
|
close(timerfd_);
|
|
LOG_DEBUG << "TimerQueue destroyed";
|
|
}
|
|
|
|
uint64_t addTimer(TimerCallback cb, TimePoint when, Duration interval = Duration(0))
|
|
{
|
|
auto timer = std::make_shared<Timer>(std::move(cb), when, interval);
|
|
bool earliestChanged = timers_.empty() || when < timers_.top()->when();
|
|
|
|
timers_.push(timer);
|
|
activeTimers_.insert(timer->id());
|
|
|
|
if (earliestChanged) {
|
|
resetTimerfd(when);
|
|
}
|
|
|
|
LOG_DEBUG << "Added timer id=" << timer->id() << " earliest_changed=" << earliestChanged;
|
|
return timer->id();
|
|
}
|
|
|
|
void cancel(uint64_t timerId)
|
|
{
|
|
auto erased = activeTimers_.erase(timerId);
|
|
LOG_DEBUG << "Cancelled timer id=" << timerId << " found=" << (erased > 0);
|
|
}
|
|
};
|
|
|
|
class EpollPoller : public NonCopyable
|
|
{
|
|
private:
|
|
static constexpr int kNew = -1;
|
|
static constexpr int kAdded = 1;
|
|
static constexpr int kDeleted = 2;
|
|
static constexpr int kInitEventListSize = 16;
|
|
|
|
int epollfd_;
|
|
std::vector<epoll_event> events_;
|
|
std::unordered_map<int, Channel*> channels_;
|
|
|
|
void update(int operation, Channel* channel)
|
|
{
|
|
epoll_event event{};
|
|
event.events = channel->events();
|
|
event.data.ptr = channel;
|
|
|
|
if (epoll_ctl(epollfd_, operation, channel->fd(), &event) < 0) {
|
|
LOG_ERROR << "epoll_ctl op=" << operation << " fd=" << channel->fd()
|
|
<< " failed: " << strerror(errno);
|
|
} else {
|
|
LOG_TRACE << "epoll_ctl op=" << operation << " fd=" << channel->fd() << " success";
|
|
}
|
|
}
|
|
|
|
public:
|
|
EpollPoller() : epollfd_(epoll_create1(EPOLL_CLOEXEC)), events_(kInitEventListSize)
|
|
{
|
|
if (epollfd_ < 0) {
|
|
LOG_FATAL << "epoll_create1 failed: " << strerror(errno);
|
|
abort();
|
|
}
|
|
LOG_DEBUG << "EpollPoller created with epollfd=" << epollfd_;
|
|
}
|
|
|
|
~EpollPoller()
|
|
{
|
|
close(epollfd_);
|
|
LOG_DEBUG << "EpollPoller destroyed";
|
|
}
|
|
|
|
std::vector<Channel*> poll(int timeoutMs = -1)
|
|
{
|
|
int numEvents = epoll_wait(epollfd_, events_.data(),
|
|
static_cast<int>(events_.size()), timeoutMs);
|
|
|
|
std::vector<Channel*> activeChannels;
|
|
|
|
if (numEvents > 0) {
|
|
LOG_TRACE << "EpollPoller got " << numEvents << " events";
|
|
|
|
for (int i = 0; i < numEvents; ++i) {
|
|
auto channel = static_cast<Channel*>(events_[i].data.ptr);
|
|
channel->setRevents(events_[i].events);
|
|
activeChannels.push_back(channel);
|
|
}
|
|
|
|
if (static_cast<size_t>(numEvents) == events_.size()) {
|
|
events_.resize(events_.size() * 2);
|
|
LOG_DEBUG << "EpollPoller events buffer grown to " << events_.size();
|
|
}
|
|
} else if (numEvents < 0 && errno != EINTR) {
|
|
LOG_ERROR << "EpollPoller::poll failed: " << strerror(errno);
|
|
}
|
|
|
|
return activeChannels;
|
|
}
|
|
|
|
void updateChannel(Channel* channel)
|
|
{
|
|
int index = channel->index();
|
|
int fd = channel->fd();
|
|
|
|
if (index == kNew || index == kDeleted) {
|
|
if (index == kNew) {
|
|
channels_[fd] = channel;
|
|
}
|
|
channel->setIndex(kAdded);
|
|
update(EPOLL_CTL_ADD, channel);
|
|
} else {
|
|
if (channel->isNoneEvent()) {
|
|
update(EPOLL_CTL_DEL, channel);
|
|
channel->setIndex(kDeleted);
|
|
} else {
|
|
update(EPOLL_CTL_MOD, channel);
|
|
}
|
|
}
|
|
}
|
|
|
|
void removeChannel(Channel* channel)
|
|
{
|
|
int fd = channel->fd();
|
|
int index = channel->index();
|
|
|
|
auto n = channels_.erase(fd);
|
|
assert(n == 1);
|
|
|
|
if (index == kAdded) {
|
|
update(EPOLL_CTL_DEL, channel);
|
|
}
|
|
|
|
channel->setIndex(kNew);
|
|
LOG_TRACE << "Channel fd=" << fd << " removed from poller";
|
|
}
|
|
};
|
|
|
|
class EventLoop : public NonCopyable
|
|
{
|
|
private:
|
|
std::unique_ptr<EpollPoller> poller_;
|
|
std::unique_ptr<TimerQueue> timerQueue_;
|
|
int wakeupFd_;
|
|
std::unique_ptr<Channel> wakeupChannel_;
|
|
std::atomic<bool> looping_;
|
|
std::atomic<bool> quit_;
|
|
std::thread::id threadId_;
|
|
std::mutex mutex_;
|
|
std::vector<std::function<void()>> pendingFunctors_;
|
|
bool callingPendingFunctors_;
|
|
|
|
static int createEventfd()
|
|
{
|
|
int fd = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC);
|
|
if (fd < 0) {
|
|
LOG_FATAL << "Failed to create eventfd: " << strerror(errno);
|
|
abort();
|
|
}
|
|
LOG_DEBUG << "Created eventfd=" << fd;
|
|
return fd;
|
|
}
|
|
|
|
void wakeup()
|
|
{
|
|
uint64_t one = 1;
|
|
ssize_t n = write(wakeupFd_, &one, sizeof(one));
|
|
if (n != sizeof(one)) {
|
|
LOG_ERROR << "EventLoop wakeup writes " << n << " bytes instead of 8";
|
|
}
|
|
}
|
|
|
|
void handleRead()
|
|
{
|
|
uint64_t one;
|
|
ssize_t n = read(wakeupFd_, &one, sizeof(one));
|
|
if (n != sizeof(one)) {
|
|
LOG_ERROR << "EventLoop handleRead reads " << n << " bytes instead of 8";
|
|
}
|
|
}
|
|
|
|
void doPendingFunctors()
|
|
{
|
|
std::vector<std::function<void()>> functors;
|
|
callingPendingFunctors_ = true;
|
|
{
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
functors.swap(pendingFunctors_);
|
|
}
|
|
|
|
if (!functors.empty()) {
|
|
LOG_TRACE << "EventLoop executed " << functors.size() << " pending functors";
|
|
}
|
|
|
|
for (const auto& functor : functors) {
|
|
functor();
|
|
}
|
|
|
|
callingPendingFunctors_ = false;
|
|
}
|
|
|
|
public:
|
|
EventLoop()
|
|
: poller_(std::make_unique<EpollPoller>()),
|
|
timerQueue_(std::make_unique<TimerQueue>(this)),
|
|
wakeupFd_(createEventfd()),
|
|
wakeupChannel_(std::make_unique<Channel>(this, wakeupFd_)),
|
|
looping_(false), quit_(false),
|
|
threadId_(),
|
|
pendingFunctors_(),
|
|
callingPendingFunctors_(false)
|
|
{
|
|
wakeupChannel_->setReadCallback([this]() { handleRead(); });
|
|
wakeupChannel_->enableReading();
|
|
LOG_INFO << "EventLoop created";
|
|
}
|
|
|
|
~EventLoop()
|
|
{
|
|
wakeupChannel_->disableAll();
|
|
wakeupChannel_->remove();
|
|
close(wakeupFd_);
|
|
LOG_INFO << "EventLoop destroyed";
|
|
}
|
|
|
|
void loop()
|
|
{
|
|
assert(!looping_);
|
|
threadId_ = std::this_thread::get_id();
|
|
looping_ = true;
|
|
quit_ = false;
|
|
|
|
LOG_INFO << "EventLoop started looping in thread " << threadId_;
|
|
|
|
while (!quit_) {
|
|
auto activeChannels = poller_->poll(10000);
|
|
for (auto channel : activeChannels) {
|
|
channel->handleEvent();
|
|
}
|
|
doPendingFunctors();
|
|
}
|
|
|
|
looping_ = false;
|
|
LOG_INFO << "EventLoop stopped looping";
|
|
}
|
|
|
|
void quit()
|
|
{
|
|
quit_ = true;
|
|
if (!isInLoopThread()) {
|
|
wakeup();
|
|
}
|
|
LOG_DEBUG << "EventLoop quit requested";
|
|
}
|
|
|
|
template<typename F>
|
|
void runInLoop(F&& cb)
|
|
{
|
|
if (isInLoopThread()) {
|
|
cb();
|
|
} else {
|
|
queueInLoop(std::forward<F>(cb));
|
|
}
|
|
}
|
|
|
|
template<typename F>
|
|
void queueInLoop(F&& cb)
|
|
{
|
|
{
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
pendingFunctors_.emplace_back(std::forward<F>(cb));
|
|
}
|
|
|
|
if (!isInLoopThread() || callingPendingFunctors_) {
|
|
wakeup();
|
|
}
|
|
}
|
|
|
|
uint64_t runAt(TimePoint when, TimerCallback cb)
|
|
{
|
|
return timerQueue_->addTimer(std::move(cb), when);
|
|
}
|
|
|
|
uint64_t runAfter(Duration delay, TimerCallback cb)
|
|
{
|
|
return runAt(std::chrono::steady_clock::now() + delay, std::move(cb));
|
|
}
|
|
|
|
uint64_t runEvery(Duration interval, TimerCallback cb)
|
|
{
|
|
auto when = std::chrono::steady_clock::now() + interval;
|
|
return timerQueue_->addTimer(std::move(cb), when, interval);
|
|
}
|
|
|
|
void cancel(uint64_t timerId) { timerQueue_->cancel(timerId); }
|
|
|
|
void updateChannel(Channel* channel) { poller_->updateChannel(channel); }
|
|
void removeChannel(Channel* channel) { poller_->removeChannel(channel); }
|
|
|
|
bool isInLoopThread() const
|
|
{
|
|
return threadId_ == std::this_thread::get_id();
|
|
}
|
|
|
|
void assertInLoopThread() const
|
|
{
|
|
if (!isInLoopThread()) {
|
|
LOG_FATAL << "EventLoop was created in thread " << threadId_
|
|
<< " but accessed from thread " << std::this_thread::get_id();
|
|
abort();
|
|
}
|
|
}
|
|
};
|
|
|
|
inline void Channel::update()
|
|
{
|
|
loop_->updateChannel(this);
|
|
}
|
|
|
|
inline void Channel::remove()
|
|
{
|
|
assert(isNoneEvent());
|
|
loop_->removeChannel(this);
|
|
}
|
|
|
|
}
|