v1.0 lib
This commit is contained in:
parent
eec6bfde5e
commit
14a395aa3c
262
lib/Buffer.hpp
Normal file
262
lib/Buffer.hpp
Normal file
@ -0,0 +1,262 @@
|
||||
#pragma once
|
||||
|
||||
#include "Utilities.hpp"
|
||||
#include <sys/uio.h>
|
||||
#include <unistd.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
|
||||
namespace reactor {
|
||||
|
||||
class Buffer : public NonCopyable
|
||||
{
|
||||
private:
|
||||
std::vector<char> buffer_;
|
||||
size_t readIndex_;
|
||||
size_t writeIndex_;
|
||||
size_t initialCap_;
|
||||
|
||||
static constexpr size_t kBufferOffset = 8;
|
||||
|
||||
public:
|
||||
static constexpr size_t kInitialSize = 1024;
|
||||
|
||||
explicit Buffer(size_t initialSize = kInitialSize)
|
||||
: buffer_(initialSize + kBufferOffset), readIndex_(kBufferOffset),
|
||||
writeIndex_(kBufferOffset), initialCap_(initialSize)
|
||||
{
|
||||
LOG_TRACE << "Buffer created with size " << initialSize;
|
||||
}
|
||||
|
||||
~Buffer()
|
||||
{
|
||||
LOG_TRACE << "Buffer destroyed, had " << readableBytes() << " readable bytes";
|
||||
}
|
||||
|
||||
size_t readableBytes() const { return writeIndex_ - readIndex_; }
|
||||
size_t writableBytes() const { return buffer_.size() - writeIndex_; }
|
||||
size_t prependableBytes() const { return readIndex_; }
|
||||
|
||||
const char* peek() const { return &buffer_[readIndex_]; }
|
||||
char* beginWrite() { return &buffer_[writeIndex_]; }
|
||||
const char* beginWrite() const { return &buffer_[writeIndex_]; }
|
||||
|
||||
void retrieve(size_t len)
|
||||
{
|
||||
if (len < readableBytes()) {
|
||||
readIndex_ += len;
|
||||
} else {
|
||||
retrieveAll();
|
||||
}
|
||||
}
|
||||
|
||||
void retrieveAll()
|
||||
{
|
||||
if (buffer_.size() > (initialCap_ * 2)) {
|
||||
buffer_.resize(initialCap_ + kBufferOffset);
|
||||
LOG_DEBUG << "Buffer shrunk to " << buffer_.size();
|
||||
}
|
||||
readIndex_ = kBufferOffset;
|
||||
writeIndex_ = kBufferOffset;
|
||||
}
|
||||
|
||||
void hasWritten(size_t len)
|
||||
{
|
||||
writeIndex_ += len;
|
||||
LOG_TRACE << "Buffer written " << len << " bytes, total readable: " << readableBytes();
|
||||
}
|
||||
|
||||
void append(const char* data, size_t len)
|
||||
{
|
||||
ensureWritableBytes(len);
|
||||
std::copy(data, data + len, beginWrite());
|
||||
hasWritten(len);
|
||||
}
|
||||
|
||||
void append(const std::string& str)
|
||||
{
|
||||
append(str.data(), str.size());
|
||||
}
|
||||
|
||||
void ensureWritableBytes(size_t len)
|
||||
{
|
||||
if (writableBytes() >= len) return;
|
||||
|
||||
if (readIndex_ + writableBytes() >= (len + kBufferOffset)) {
|
||||
// Move readable data to front
|
||||
std::copy(&buffer_[readIndex_], &buffer_[writeIndex_], &buffer_[kBufferOffset]);
|
||||
writeIndex_ = kBufferOffset + (writeIndex_ - readIndex_);
|
||||
readIndex_ = kBufferOffset;
|
||||
LOG_TRACE << "Buffer compacted, readable bytes: " << readableBytes();
|
||||
} else {
|
||||
// Grow buffer
|
||||
size_t newLen = std::max(buffer_.size() * 2, kBufferOffset + readableBytes() + len);
|
||||
buffer_.resize(newLen);
|
||||
LOG_DEBUG << "Buffer grown to " << newLen;
|
||||
}
|
||||
}
|
||||
|
||||
// Network byte order append functions
|
||||
void appendInt8(uint8_t x)
|
||||
{
|
||||
append(reinterpret_cast<const char*>(&x), sizeof(x));
|
||||
}
|
||||
|
||||
void appendInt16(uint16_t x)
|
||||
{
|
||||
uint16_t be = htons(x);
|
||||
append(reinterpret_cast<const char*>(&be), sizeof(be));
|
||||
}
|
||||
|
||||
void appendInt32(uint32_t x)
|
||||
{
|
||||
uint32_t be = htonl(x);
|
||||
append(reinterpret_cast<const char*>(&be), sizeof(be));
|
||||
}
|
||||
|
||||
void appendInt64(uint64_t x)
|
||||
{
|
||||
uint64_t be = hton64(x);
|
||||
append(reinterpret_cast<const char*>(&be), sizeof(be));
|
||||
}
|
||||
|
||||
// Network byte order read functions
|
||||
uint8_t readInt8()
|
||||
{
|
||||
uint8_t result = peekInt8();
|
||||
retrieve(sizeof(result));
|
||||
return result;
|
||||
}
|
||||
|
||||
uint16_t readInt16()
|
||||
{
|
||||
uint16_t result = peekInt16();
|
||||
retrieve(sizeof(result));
|
||||
return result;
|
||||
}
|
||||
|
||||
uint32_t readInt32()
|
||||
{
|
||||
uint32_t result = peekInt32();
|
||||
retrieve(sizeof(result));
|
||||
return result;
|
||||
}
|
||||
|
||||
uint64_t readInt64()
|
||||
{
|
||||
uint64_t result = peekInt64();
|
||||
retrieve(sizeof(result));
|
||||
return result;
|
||||
}
|
||||
|
||||
// Network byte order peek functions
|
||||
uint8_t peekInt8() const
|
||||
{
|
||||
assert(readableBytes() >= sizeof(uint8_t));
|
||||
return *reinterpret_cast<const uint8_t*>(peek());
|
||||
}
|
||||
|
||||
uint16_t peekInt16() const
|
||||
{
|
||||
assert(readableBytes() >= sizeof(uint16_t));
|
||||
uint16_t be = *reinterpret_cast<const uint16_t*>(peek());
|
||||
return ntohs(be);
|
||||
}
|
||||
|
||||
uint32_t peekInt32() const
|
||||
{
|
||||
assert(readableBytes() >= sizeof(uint32_t));
|
||||
uint32_t be = *reinterpret_cast<const uint32_t*>(peek());
|
||||
return ntohl(be);
|
||||
}
|
||||
|
||||
uint64_t peekInt64() const
|
||||
{
|
||||
assert(readableBytes() >= sizeof(uint64_t));
|
||||
uint64_t be = *reinterpret_cast<const uint64_t*>(peek());
|
||||
return ntoh64(be);
|
||||
}
|
||||
|
||||
// Prepend functions for efficient header insertion
|
||||
void prepend(const char* data, size_t len)
|
||||
{
|
||||
assert(len <= prependableBytes());
|
||||
readIndex_ -= len;
|
||||
std::copy(data, data + len, &buffer_[readIndex_]);
|
||||
}
|
||||
|
||||
void prependInt8(uint8_t x)
|
||||
{
|
||||
prepend(reinterpret_cast<const char*>(&x), sizeof(x));
|
||||
}
|
||||
|
||||
void prependInt16(uint16_t x)
|
||||
{
|
||||
uint16_t be = htons(x);
|
||||
prepend(reinterpret_cast<const char*>(&be), sizeof(be));
|
||||
}
|
||||
|
||||
void prependInt32(uint32_t x)
|
||||
{
|
||||
uint32_t be = htonl(x);
|
||||
prepend(reinterpret_cast<const char*>(&be), sizeof(be));
|
||||
}
|
||||
|
||||
void prependInt64(uint64_t x)
|
||||
{
|
||||
uint64_t be = hton64(x);
|
||||
prepend(reinterpret_cast<const char*>(&be), sizeof(be));
|
||||
}
|
||||
|
||||
std::string read(size_t len)
|
||||
{
|
||||
if (len > readableBytes()) len = readableBytes();
|
||||
std::string result(peek(), len);
|
||||
retrieve(len);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string readAll()
|
||||
{
|
||||
return read(readableBytes());
|
||||
}
|
||||
|
||||
ssize_t readFd(int fd, int* savedErrno = nullptr)
|
||||
{
|
||||
char extrabuf[65536];
|
||||
iovec vec[2];
|
||||
size_t writable = writableBytes();
|
||||
|
||||
vec[0].iov_base = beginWrite();
|
||||
vec[0].iov_len = writable;
|
||||
vec[1].iov_base = extrabuf;
|
||||
vec[1].iov_len = sizeof(extrabuf);
|
||||
|
||||
const int iovcnt = (writable < sizeof(extrabuf)) ? 2 : 1;
|
||||
ssize_t n = readv(fd, vec, iovcnt);
|
||||
|
||||
if (n < 0) {
|
||||
if (savedErrno) *savedErrno = errno;
|
||||
LOG_ERROR << "readFd error: " << strerror(errno);
|
||||
} else if (static_cast<size_t>(n) <= writable) {
|
||||
writeIndex_ += n;
|
||||
} else {
|
||||
writeIndex_ = buffer_.size();
|
||||
append(extrabuf, n - writable);
|
||||
}
|
||||
|
||||
LOG_TRACE << "readFd returned " << n << " bytes, buffer now has " << readableBytes();
|
||||
return n;
|
||||
}
|
||||
|
||||
void swap(Buffer& other) noexcept
|
||||
{
|
||||
buffer_.swap(other.buffer_);
|
||||
std::swap(readIndex_, other.readIndex_);
|
||||
std::swap(writeIndex_, other.writeIndex_);
|
||||
std::swap(initialCap_, other.initialCap_);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace reactor
|
591
lib/Core.hpp
Normal file
591
lib/Core.hpp
Normal file
@ -0,0 +1,591 @@
|
||||
#pragma once
|
||||
|
||||
#include "Utilities.hpp"
|
||||
#include <sys/epoll.h>
|
||||
#include <sys/timerfd.h>
|
||||
#include <sys/eventfd.h>
|
||||
#include <poll.h>
|
||||
#include <unistd.h>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
|
||||
namespace reactor {
|
||||
|
||||
using TimePoint = std::chrono::steady_clock::time_point;
|
||||
using Duration = std::chrono::milliseconds;
|
||||
using TimerCallback = std::function<void()>;
|
||||
using EventCallback = std::function<void()>;
|
||||
|
||||
class EventLoop;
|
||||
|
||||
class Channel : public NonCopyable
|
||||
{
|
||||
private:
|
||||
EventLoop* loop_;
|
||||
int fd_;
|
||||
int events_;
|
||||
int revents_;
|
||||
int index_;
|
||||
bool tied_;
|
||||
std::weak_ptr<void> tie_;
|
||||
EventCallback readCallback_;
|
||||
EventCallback writeCallback_;
|
||||
EventCallback closeCallback_;
|
||||
EventCallback errorCallback_;
|
||||
|
||||
void handleEventSafely()
|
||||
{
|
||||
LOG_TRACE << "Channel fd=" << fd_ << " handling events: " << revents_;
|
||||
|
||||
if ((revents_ & POLLHUP) && !(revents_ & POLLIN)) {
|
||||
LOG_DEBUG << "Channel fd=" << fd_ << " hangup";
|
||||
if (closeCallback_) closeCallback_();
|
||||
}
|
||||
if (revents_ & POLLERR) {
|
||||
LOG_WARN << "Channel fd=" << fd_ << " error event";
|
||||
if (errorCallback_) errorCallback_();
|
||||
}
|
||||
if (revents_ & (POLLIN | POLLPRI | POLLRDHUP)) {
|
||||
LOG_TRACE << "Channel fd=" << fd_ << " readable";
|
||||
if (readCallback_) readCallback_();
|
||||
}
|
||||
if (revents_ & POLLOUT) {
|
||||
LOG_TRACE << "Channel fd=" << fd_ << " writable";
|
||||
if (writeCallback_) writeCallback_();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr int kNoneEvent = 0;
|
||||
static constexpr int kReadEvent = POLLIN | POLLPRI;
|
||||
static constexpr int kWriteEvent = POLLOUT;
|
||||
|
||||
Channel(EventLoop* loop, int fd)
|
||||
: loop_(loop), fd_(fd), events_(0), revents_(0), index_(-1), tied_(false)
|
||||
{
|
||||
LOG_TRACE << "Channel created for fd=" << fd;
|
||||
}
|
||||
|
||||
~Channel()
|
||||
{
|
||||
LOG_TRACE << "Channel destroyed for fd=" << fd_;
|
||||
}
|
||||
|
||||
int fd() const { return fd_; }
|
||||
int events() const { return events_; }
|
||||
int revents() const { return revents_; }
|
||||
int index() const { return index_; }
|
||||
EventLoop* ownerLoop() const { return loop_; }
|
||||
|
||||
void setRevents(int revents) { revents_ = revents; }
|
||||
void setIndex(int index) { index_ = index; }
|
||||
|
||||
void enableReading()
|
||||
{
|
||||
events_ |= kReadEvent;
|
||||
update();
|
||||
LOG_TRACE << "Channel fd=" << fd_ << " enabled reading";
|
||||
}
|
||||
|
||||
void disableReading()
|
||||
{
|
||||
events_ &= ~kReadEvent;
|
||||
update();
|
||||
LOG_TRACE << "Channel fd=" << fd_ << " disabled reading";
|
||||
}
|
||||
|
||||
void enableWriting()
|
||||
{
|
||||
events_ |= kWriteEvent;
|
||||
update();
|
||||
LOG_TRACE << "Channel fd=" << fd_ << " enabled writing";
|
||||
}
|
||||
|
||||
void disableWriting()
|
||||
{
|
||||
events_ &= ~kWriteEvent;
|
||||
update();
|
||||
LOG_TRACE << "Channel fd=" << fd_ << " disabled writing";
|
||||
}
|
||||
|
||||
void disableAll()
|
||||
{
|
||||
events_ = kNoneEvent;
|
||||
update();
|
||||
LOG_TRACE << "Channel fd=" << fd_ << " disabled all events";
|
||||
}
|
||||
|
||||
bool isNoneEvent() const { return events_ == kNoneEvent; }
|
||||
bool isWriting() const { return events_ & kWriteEvent; }
|
||||
bool isReading() const { return events_ & kReadEvent; }
|
||||
|
||||
void tie(const std::shared_ptr<void>& obj)
|
||||
{
|
||||
tie_ = obj;
|
||||
tied_ = true;
|
||||
}
|
||||
|
||||
void setReadCallback(EventCallback cb) { readCallback_ = std::move(cb); }
|
||||
void setWriteCallback(EventCallback cb) { writeCallback_ = std::move(cb); }
|
||||
void setCloseCallback(EventCallback cb) { closeCallback_ = std::move(cb); }
|
||||
void setErrorCallback(EventCallback cb) { errorCallback_ = std::move(cb); }
|
||||
|
||||
void handleEvent()
|
||||
{
|
||||
if (tied_) {
|
||||
std::shared_ptr<void> guard = tie_.lock();
|
||||
if (guard) {
|
||||
handleEventSafely();
|
||||
} else {
|
||||
LOG_WARN << "Channel fd=" << fd_ << " tied object expired";
|
||||
}
|
||||
} else {
|
||||
handleEventSafely();
|
||||
}
|
||||
}
|
||||
|
||||
void remove();
|
||||
void update();
|
||||
};
|
||||
|
||||
class Timer : public NonCopyable
|
||||
{
|
||||
private:
|
||||
TimerCallback callback_;
|
||||
TimePoint when_;
|
||||
Duration interval_;
|
||||
bool repeat_;
|
||||
uint64_t id_;
|
||||
static inline std::atomic<uint64_t> s_numCreated_{0};
|
||||
|
||||
public:
|
||||
Timer(TimerCallback cb, TimePoint when, Duration interval = Duration(0))
|
||||
: callback_(std::move(cb)), when_(when), interval_(interval),
|
||||
repeat_(interval.count() > 0), id_(++s_numCreated_)
|
||||
{
|
||||
LOG_TRACE << "Timer created id=" << id_ << " repeat=" << repeat_;
|
||||
}
|
||||
|
||||
void run() const
|
||||
{
|
||||
LOG_TRACE << "Timer id=" << id_ << " executing";
|
||||
callback_();
|
||||
}
|
||||
|
||||
TimePoint when() const { return when_; }
|
||||
bool repeat() const { return repeat_; }
|
||||
uint64_t id() const { return id_; }
|
||||
|
||||
void restart(TimePoint now)
|
||||
{
|
||||
if (repeat_) {
|
||||
when_ = now + interval_;
|
||||
LOG_TRACE << "Timer id=" << id_ << " restarted";
|
||||
}
|
||||
}
|
||||
|
||||
bool operator<(const Timer& other) const { return when_ > other.when_; }
|
||||
};
|
||||
|
||||
class TimerQueue : public NonCopyable
|
||||
{
|
||||
private:
|
||||
EventLoop* loop_;
|
||||
int timerfd_;
|
||||
std::unique_ptr<Channel> timerChannel_;
|
||||
std::priority_queue<std::shared_ptr<Timer>> timers_;
|
||||
std::unordered_set<uint64_t> activeTimers_;
|
||||
|
||||
static int createTimerfd()
|
||||
{
|
||||
int fd = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC);
|
||||
if (fd < 0) {
|
||||
LOG_FATAL << "Failed to create timerfd: " << strerror(errno);
|
||||
abort();
|
||||
}
|
||||
LOG_DEBUG << "Created timerfd=" << fd;
|
||||
return fd;
|
||||
}
|
||||
|
||||
void resetTimerfd(TimePoint expiration)
|
||||
{
|
||||
auto duration = expiration - std::chrono::steady_clock::now();
|
||||
auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(duration).count();
|
||||
if (ns < 100000) ns = 100000;
|
||||
|
||||
itimerspec newValue{};
|
||||
newValue.it_value.tv_sec = ns / 1000000000;
|
||||
newValue.it_value.tv_nsec = ns % 1000000000;
|
||||
|
||||
if (timerfd_settime(timerfd_, 0, &newValue, nullptr) < 0) {
|
||||
LOG_ERROR << "timerfd_settime failed: " << strerror(errno);
|
||||
}
|
||||
}
|
||||
|
||||
void handleRead()
|
||||
{
|
||||
uint64_t data;
|
||||
ssize_t n = read(timerfd_, &data, sizeof(data));
|
||||
if (n != sizeof(data)) {
|
||||
LOG_ERROR << "TimerQueue read " << n << " bytes instead of 8";
|
||||
}
|
||||
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
std::vector<std::shared_ptr<Timer>> expired;
|
||||
|
||||
while (!timers_.empty() && timers_.top()->when() <= now) {
|
||||
expired.push_back(timers_.top());
|
||||
timers_.pop();
|
||||
}
|
||||
|
||||
LOG_TRACE << "TimerQueue processed " << expired.size() << " expired timers";
|
||||
|
||||
for (auto& timer : expired) {
|
||||
if (activeTimers_.count(timer->id())) {
|
||||
timer->run();
|
||||
if (timer->repeat()) {
|
||||
timer->restart(now);
|
||||
timers_.push(timer);
|
||||
} else {
|
||||
activeTimers_.erase(timer->id());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!timers_.empty()) {
|
||||
resetTimerfd(timers_.top()->when());
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
TimerQueue(EventLoop* loop)
|
||||
: loop_(loop), timerfd_(createTimerfd()),
|
||||
timerChannel_(std::make_unique<Channel>(loop, timerfd_))
|
||||
{
|
||||
timerChannel_->setReadCallback([this]() { handleRead(); });
|
||||
timerChannel_->enableReading();
|
||||
LOG_DEBUG << "TimerQueue initialized";
|
||||
}
|
||||
|
||||
~TimerQueue()
|
||||
{
|
||||
timerChannel_->disableAll();
|
||||
timerChannel_->remove();
|
||||
close(timerfd_);
|
||||
LOG_DEBUG << "TimerQueue destroyed";
|
||||
}
|
||||
|
||||
uint64_t addTimer(TimerCallback cb, TimePoint when, Duration interval = Duration(0))
|
||||
{
|
||||
auto timer = std::make_shared<Timer>(std::move(cb), when, interval);
|
||||
bool earliestChanged = timers_.empty() || when < timers_.top()->when();
|
||||
|
||||
timers_.push(timer);
|
||||
activeTimers_.insert(timer->id());
|
||||
|
||||
if (earliestChanged) {
|
||||
resetTimerfd(when);
|
||||
}
|
||||
|
||||
LOG_DEBUG << "Added timer id=" << timer->id() << " earliest_changed=" << earliestChanged;
|
||||
return timer->id();
|
||||
}
|
||||
|
||||
void cancel(uint64_t timerId)
|
||||
{
|
||||
auto erased = activeTimers_.erase(timerId);
|
||||
LOG_DEBUG << "Cancelled timer id=" << timerId << " found=" << (erased > 0);
|
||||
}
|
||||
};
|
||||
|
||||
class EpollPoller : public NonCopyable
|
||||
{
|
||||
private:
|
||||
static constexpr int kNew = -1;
|
||||
static constexpr int kAdded = 1;
|
||||
static constexpr int kDeleted = 2;
|
||||
static constexpr int kInitEventListSize = 16;
|
||||
|
||||
int epollfd_;
|
||||
std::vector<epoll_event> events_;
|
||||
std::unordered_map<int, Channel*> channels_;
|
||||
|
||||
void update(int operation, Channel* channel)
|
||||
{
|
||||
epoll_event event{};
|
||||
event.events = channel->events();
|
||||
event.data.ptr = channel;
|
||||
|
||||
if (epoll_ctl(epollfd_, operation, channel->fd(), &event) < 0) {
|
||||
LOG_ERROR << "epoll_ctl op=" << operation << " fd=" << channel->fd()
|
||||
<< " failed: " << strerror(errno);
|
||||
} else {
|
||||
LOG_TRACE << "epoll_ctl op=" << operation << " fd=" << channel->fd() << " success";
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
EpollPoller() : epollfd_(epoll_create1(EPOLL_CLOEXEC)), events_(kInitEventListSize)
|
||||
{
|
||||
if (epollfd_ < 0) {
|
||||
LOG_FATAL << "epoll_create1 failed: " << strerror(errno);
|
||||
abort();
|
||||
}
|
||||
LOG_DEBUG << "EpollPoller created with epollfd=" << epollfd_;
|
||||
}
|
||||
|
||||
~EpollPoller()
|
||||
{
|
||||
close(epollfd_);
|
||||
LOG_DEBUG << "EpollPoller destroyed";
|
||||
}
|
||||
|
||||
std::vector<Channel*> poll(int timeoutMs = -1)
|
||||
{
|
||||
int numEvents = epoll_wait(epollfd_, events_.data(),
|
||||
static_cast<int>(events_.size()), timeoutMs);
|
||||
|
||||
std::vector<Channel*> activeChannels;
|
||||
|
||||
if (numEvents > 0) {
|
||||
LOG_TRACE << "EpollPoller got " << numEvents << " events";
|
||||
|
||||
for (int i = 0; i < numEvents; ++i) {
|
||||
auto channel = static_cast<Channel*>(events_[i].data.ptr);
|
||||
channel->setRevents(events_[i].events);
|
||||
activeChannels.push_back(channel);
|
||||
}
|
||||
|
||||
if (static_cast<size_t>(numEvents) == events_.size()) {
|
||||
events_.resize(events_.size() * 2);
|
||||
LOG_DEBUG << "EpollPoller events buffer grown to " << events_.size();
|
||||
}
|
||||
} else if (numEvents < 0 && errno != EINTR) {
|
||||
LOG_ERROR << "EpollPoller::poll failed: " << strerror(errno);
|
||||
}
|
||||
|
||||
return activeChannels;
|
||||
}
|
||||
|
||||
void updateChannel(Channel* channel)
|
||||
{
|
||||
int index = channel->index();
|
||||
int fd = channel->fd();
|
||||
|
||||
if (index == kNew || index == kDeleted) {
|
||||
if (index == kNew) {
|
||||
channels_[fd] = channel;
|
||||
}
|
||||
channel->setIndex(kAdded);
|
||||
update(EPOLL_CTL_ADD, channel);
|
||||
} else {
|
||||
if (channel->isNoneEvent()) {
|
||||
update(EPOLL_CTL_DEL, channel);
|
||||
channel->setIndex(kDeleted);
|
||||
} else {
|
||||
update(EPOLL_CTL_MOD, channel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void removeChannel(Channel* channel)
|
||||
{
|
||||
int fd = channel->fd();
|
||||
int index = channel->index();
|
||||
|
||||
auto n = channels_.erase(fd);
|
||||
assert(n == 1);
|
||||
|
||||
if (index == kAdded) {
|
||||
update(EPOLL_CTL_DEL, channel);
|
||||
}
|
||||
|
||||
channel->setIndex(kNew);
|
||||
LOG_TRACE << "Channel fd=" << fd << " removed from poller";
|
||||
}
|
||||
};
|
||||
|
||||
class EventLoop : public NonCopyable
|
||||
{
|
||||
private:
|
||||
std::unique_ptr<EpollPoller> poller_;
|
||||
std::unique_ptr<TimerQueue> timerQueue_;
|
||||
int wakeupFd_;
|
||||
std::unique_ptr<Channel> wakeupChannel_;
|
||||
std::atomic<bool> looping_;
|
||||
std::atomic<bool> quit_;
|
||||
std::thread::id threadId_;
|
||||
LockFreeQueue<std::function<void()>> pendingFunctors_;
|
||||
bool callingPendingFunctors_;
|
||||
|
||||
static int createEventfd()
|
||||
{
|
||||
int fd = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC);
|
||||
if (fd < 0) {
|
||||
LOG_FATAL << "Failed to create eventfd: " << strerror(errno);
|
||||
abort();
|
||||
}
|
||||
LOG_DEBUG << "Created eventfd=" << fd;
|
||||
return fd;
|
||||
}
|
||||
|
||||
void wakeup()
|
||||
{
|
||||
uint64_t one = 1;
|
||||
ssize_t n = write(wakeupFd_, &one, sizeof(one));
|
||||
if (n != sizeof(one)) {
|
||||
LOG_ERROR << "EventLoop wakeup writes " << n << " bytes instead of 8";
|
||||
}
|
||||
}
|
||||
|
||||
void handleRead()
|
||||
{
|
||||
uint64_t one;
|
||||
ssize_t n = read(wakeupFd_, &one, sizeof(one));
|
||||
if (n != sizeof(one)) {
|
||||
LOG_ERROR << "EventLoop handleRead reads " << n << " bytes instead of 8";
|
||||
}
|
||||
}
|
||||
|
||||
void doPendingFunctors()
|
||||
{
|
||||
callingPendingFunctors_ = true;
|
||||
|
||||
std::function<void()> functor;
|
||||
int count = 0;
|
||||
while (pendingFunctors_.dequeue(functor)) {
|
||||
functor();
|
||||
++count;
|
||||
}
|
||||
|
||||
if (count > 0) {
|
||||
LOG_TRACE << "EventLoop executed " << count << " pending functors";
|
||||
}
|
||||
|
||||
callingPendingFunctors_ = false;
|
||||
}
|
||||
|
||||
public:
|
||||
EventLoop()
|
||||
: poller_(std::make_unique<EpollPoller>()),
|
||||
timerQueue_(std::make_unique<TimerQueue>(this)),
|
||||
wakeupFd_(createEventfd()),
|
||||
wakeupChannel_(std::make_unique<Channel>(this, wakeupFd_)),
|
||||
looping_(false), quit_(false),
|
||||
threadId_(std::this_thread::get_id()),
|
||||
callingPendingFunctors_(false)
|
||||
{
|
||||
wakeupChannel_->setReadCallback([this]() { handleRead(); });
|
||||
wakeupChannel_->enableReading();
|
||||
LOG_INFO << "EventLoop created in thread " << threadId_;
|
||||
}
|
||||
|
||||
~EventLoop()
|
||||
{
|
||||
wakeupChannel_->disableAll();
|
||||
wakeupChannel_->remove();
|
||||
close(wakeupFd_);
|
||||
LOG_INFO << "EventLoop destroyed";
|
||||
}
|
||||
|
||||
void loop()
|
||||
{
|
||||
assert(!looping_);
|
||||
assertInLoopThread();
|
||||
looping_ = true;
|
||||
quit_ = false;
|
||||
|
||||
LOG_INFO << "EventLoop started looping";
|
||||
|
||||
while (!quit_) {
|
||||
auto activeChannels = poller_->poll(10000);
|
||||
for (auto channel : activeChannels) {
|
||||
channel->handleEvent();
|
||||
}
|
||||
doPendingFunctors();
|
||||
}
|
||||
|
||||
looping_ = false;
|
||||
LOG_INFO << "EventLoop stopped looping";
|
||||
}
|
||||
|
||||
void quit()
|
||||
{
|
||||
quit_ = true;
|
||||
if (!isInLoopThread()) {
|
||||
wakeup();
|
||||
}
|
||||
LOG_DEBUG << "EventLoop quit requested";
|
||||
}
|
||||
|
||||
void runInLoop(std::function<void()> cb)
|
||||
{
|
||||
if (isInLoopThread()) {
|
||||
cb();
|
||||
} else {
|
||||
queueInLoop(std::move(cb));
|
||||
}
|
||||
}
|
||||
|
||||
void queueInLoop(std::function<void()> cb)
|
||||
{
|
||||
pendingFunctors_.enqueue(std::move(cb));
|
||||
|
||||
if (!isInLoopThread() || callingPendingFunctors_) {
|
||||
wakeup();
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t runAt(TimePoint when, TimerCallback cb)
|
||||
{
|
||||
return timerQueue_->addTimer(std::move(cb), when);
|
||||
}
|
||||
|
||||
uint64_t runAfter(Duration delay, TimerCallback cb)
|
||||
{
|
||||
return runAt(std::chrono::steady_clock::now() + delay, std::move(cb));
|
||||
}
|
||||
|
||||
uint64_t runEvery(Duration interval, TimerCallback cb)
|
||||
{
|
||||
auto when = std::chrono::steady_clock::now() + interval;
|
||||
return timerQueue_->addTimer(std::move(cb), when, interval);
|
||||
}
|
||||
|
||||
void cancel(uint64_t timerId) { timerQueue_->cancel(timerId); }
|
||||
|
||||
void updateChannel(Channel* channel) { poller_->updateChannel(channel); }
|
||||
void removeChannel(Channel* channel) { poller_->removeChannel(channel); }
|
||||
|
||||
bool isInLoopThread() const { return threadId_ == std::this_thread::get_id(); }
|
||||
|
||||
void assertInLoopThread() const
|
||||
{
|
||||
if (!isInLoopThread()) {
|
||||
LOG_FATAL << "EventLoop was created in thread " << threadId_
|
||||
<< " but accessed from thread " << std::this_thread::get_id();
|
||||
abort();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
inline void Channel::update()
|
||||
{
|
||||
loop_->updateChannel(this);
|
||||
}
|
||||
|
||||
inline void Channel::remove()
|
||||
{
|
||||
assert(isNoneEvent());
|
||||
loop_->removeChannel(this);
|
||||
}
|
||||
|
||||
} // namespace reactor
|
110
lib/EventLoopThread.hpp
Normal file
110
lib/EventLoopThread.hpp
Normal file
@ -0,0 +1,110 @@
|
||||
#pragma once
|
||||
|
||||
#include "Core.hpp"
|
||||
#include "Utilities.hpp"
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <atomic>
|
||||
|
||||
namespace reactor {
|
||||
|
||||
class EventLoopThread : public NonCopyable
|
||||
{
|
||||
private:
|
||||
std::thread thread_;
|
||||
EventLoop* loop_;
|
||||
std::mutex mutex_;
|
||||
std::condition_variable cond_;
|
||||
std::string name_;
|
||||
|
||||
void threadFunc()
|
||||
{
|
||||
LOG_DEBUG << "EventLoopThread '" << name_ << "' starting";
|
||||
EventLoop loop;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
loop_ = &loop;
|
||||
cond_.notify_one();
|
||||
}
|
||||
|
||||
loop.loop();
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
loop_ = nullptr;
|
||||
LOG_DEBUG << "EventLoopThread '" << name_ << "' finished";
|
||||
}
|
||||
|
||||
public:
|
||||
explicit EventLoopThread(const std::string& name = "EventLoopThread")
|
||||
: loop_(nullptr), name_(name)
|
||||
{
|
||||
thread_ = std::thread([this]() { threadFunc(); });
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
cond_.wait(lock, [this]() { return loop_ != nullptr; });
|
||||
LOG_INFO << "EventLoopThread '" << name_ << "' initialized";
|
||||
}
|
||||
|
||||
~EventLoopThread()
|
||||
{
|
||||
if (loop_) {
|
||||
loop_->quit();
|
||||
thread_.join();
|
||||
}
|
||||
LOG_INFO << "EventLoopThread '" << name_ << "' destroyed";
|
||||
}
|
||||
|
||||
EventLoop* getLoop() { return loop_; }
|
||||
const std::string& name() const { return name_; }
|
||||
};
|
||||
|
||||
class EventLoopThreadPool : public NonCopyable
|
||||
{
|
||||
private:
|
||||
std::vector<std::unique_ptr<EventLoopThread>> threads_;
|
||||
std::vector<EventLoop*> loops_;
|
||||
std::atomic<size_t> next_;
|
||||
std::string baseThreadName_;
|
||||
|
||||
public:
|
||||
explicit EventLoopThreadPool(size_t numThreads, const std::string& baseName = "EventLoopThread")
|
||||
: next_(0), baseThreadName_(baseName)
|
||||
{
|
||||
LOG_INFO << "Creating EventLoopThreadPool with " << numThreads << " threads";
|
||||
|
||||
for (size_t i = 0; i < numThreads; ++i) {
|
||||
std::string threadName = baseThreadName_ + "-" + std::to_string(i);
|
||||
auto thread = std::make_unique<EventLoopThread>(threadName);
|
||||
loops_.push_back(thread->getLoop());
|
||||
threads_.push_back(std::move(thread));
|
||||
}
|
||||
|
||||
LOG_INFO << "EventLoopThreadPool created with " << numThreads << " threads";
|
||||
}
|
||||
|
||||
~EventLoopThreadPool()
|
||||
{
|
||||
LOG_INFO << "EventLoopThreadPool destroying " << threads_.size() << " threads";
|
||||
}
|
||||
|
||||
EventLoop* getNextLoop()
|
||||
{
|
||||
if (loops_.empty()) {
|
||||
LOG_WARN << "EventLoopThreadPool has no loops available";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
size_t index = next_++ % loops_.size();
|
||||
LOG_TRACE << "EventLoopThreadPool returning loop " << index;
|
||||
return loops_[index];
|
||||
}
|
||||
|
||||
std::vector<EventLoop*> getAllLoops() const { return loops_; }
|
||||
size_t size() const { return loops_.size(); }
|
||||
|
||||
const std::string& getBaseName() const { return baseThreadName_; }
|
||||
};
|
||||
|
||||
} // namespace reactor
|
170
lib/InetAddress.hpp
Normal file
170
lib/InetAddress.hpp
Normal file
@ -0,0 +1,170 @@
|
||||
#pragma once
|
||||
|
||||
#include "Utilities.hpp"
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
|
||||
namespace reactor {
|
||||
|
||||
class InetAddress
|
||||
{
|
||||
private:
|
||||
union
|
||||
{
|
||||
sockaddr_in addr4_;
|
||||
sockaddr_in6 addr6_;
|
||||
};
|
||||
bool isIpV6_;
|
||||
|
||||
public:
|
||||
explicit InetAddress(uint16_t port = 0, bool ipv6 = false, bool loopback = false)
|
||||
: isIpV6_(ipv6)
|
||||
{
|
||||
if (ipv6) {
|
||||
memset(&addr6_, 0, sizeof(addr6_));
|
||||
addr6_.sin6_family = AF_INET6;
|
||||
addr6_.sin6_addr = loopback ? in6addr_loopback : in6addr_any;
|
||||
addr6_.sin6_port = htons(port);
|
||||
} else {
|
||||
memset(&addr4_, 0, sizeof(addr4_));
|
||||
addr4_.sin_family = AF_INET;
|
||||
addr4_.sin_addr.s_addr = htonl(loopback ? INADDR_LOOPBACK : INADDR_ANY);
|
||||
addr4_.sin_port = htons(port);
|
||||
}
|
||||
LOG_TRACE << "InetAddress created: " << toIpPort();
|
||||
}
|
||||
|
||||
InetAddress(const std::string& ip, uint16_t port)
|
||||
{
|
||||
if (ip.find(':') != std::string::npos) {
|
||||
isIpV6_ = true;
|
||||
memset(&addr6_, 0, sizeof(addr6_));
|
||||
addr6_.sin6_family = AF_INET6;
|
||||
addr6_.sin6_port = htons(port);
|
||||
if (inet_pton(AF_INET6, ip.c_str(), &addr6_.sin6_addr) <= 0) {
|
||||
LOG_ERROR << "Invalid IPv6 address: " << ip;
|
||||
}
|
||||
} else {
|
||||
isIpV6_ = false;
|
||||
memset(&addr4_, 0, sizeof(addr4_));
|
||||
addr4_.sin_family = AF_INET;
|
||||
addr4_.sin_port = htons(port);
|
||||
if (inet_pton(AF_INET, ip.c_str(), &addr4_.sin_addr) <= 0) {
|
||||
LOG_ERROR << "Invalid IPv4 address: " << ip;
|
||||
}
|
||||
}
|
||||
LOG_TRACE << "InetAddress created from ip:port: " << toIpPort();
|
||||
}
|
||||
|
||||
explicit InetAddress(const sockaddr_in& addr) : addr4_(addr), isIpV6_(false)
|
||||
{
|
||||
LOG_TRACE << "InetAddress created from sockaddr_in: " << toIpPort();
|
||||
}
|
||||
|
||||
explicit InetAddress(const sockaddr_in6& addr) : addr6_(addr), isIpV6_(true)
|
||||
{
|
||||
LOG_TRACE << "InetAddress created from sockaddr_in6: " << toIpPort();
|
||||
}
|
||||
|
||||
const sockaddr* getSockAddr() const
|
||||
{
|
||||
if (isIpV6_) {
|
||||
return reinterpret_cast<const sockaddr*>(&addr6_);
|
||||
} else {
|
||||
return reinterpret_cast<const sockaddr*>(&addr4_);
|
||||
}
|
||||
}
|
||||
|
||||
socklen_t getSockLen() const { return isIpV6_ ? sizeof(addr6_) : sizeof(addr4_); }
|
||||
bool isIpV6() const { return isIpV6_; }
|
||||
uint16_t port() const { return ntohs(isIpV6_ ? addr6_.sin6_port : addr4_.sin_port); }
|
||||
|
||||
std::string toIp() const
|
||||
{
|
||||
char buf[INET6_ADDRSTRLEN];
|
||||
if (isIpV6_) {
|
||||
inet_ntop(AF_INET6, &addr6_.sin6_addr, buf, sizeof(buf));
|
||||
} else {
|
||||
inet_ntop(AF_INET, &addr4_.sin_addr, buf, sizeof(buf));
|
||||
}
|
||||
return std::string(buf);
|
||||
}
|
||||
|
||||
std::string toIpPort() const
|
||||
{
|
||||
return isIpV6_ ? "[" + toIp() + "]:" + std::to_string(port())
|
||||
: toIp() + ":" + std::to_string(port());
|
||||
}
|
||||
|
||||
bool operator==(const InetAddress& other) const
|
||||
{
|
||||
if (isIpV6_ != other.isIpV6_) return false;
|
||||
|
||||
if (isIpV6_) {
|
||||
return memcmp(&addr6_, &other.addr6_, sizeof(addr6_)) == 0;
|
||||
} else {
|
||||
return memcmp(&addr4_, &other.addr4_, sizeof(addr4_)) == 0;
|
||||
}
|
||||
}
|
||||
|
||||
bool operator!=(const InetAddress& other) const
|
||||
{
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
bool operator<(const InetAddress& other) const
|
||||
{
|
||||
if (isIpV6_ != other.isIpV6_) {
|
||||
return !isIpV6_;
|
||||
}
|
||||
|
||||
if (isIpV6_) {
|
||||
return memcmp(&addr6_, &other.addr6_, sizeof(addr6_)) < 0;
|
||||
} else {
|
||||
return memcmp(&addr4_, &other.addr4_, sizeof(addr4_)) < 0;
|
||||
}
|
||||
}
|
||||
|
||||
std::string familyToString() const
|
||||
{
|
||||
return isIpV6_ ? "IPv6" : "IPv4";
|
||||
}
|
||||
|
||||
static bool resolve(const std::string& hostname, InetAddress& result)
|
||||
{
|
||||
// Simple resolution - in a real implementation you'd use getaddrinfo
|
||||
if (hostname == "localhost") {
|
||||
result = InetAddress(0, false, true);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Try to parse as IP address directly
|
||||
InetAddress addr(hostname, 0);
|
||||
if (addr.toIp() != "0.0.0.0" && addr.toIp() != "::") {
|
||||
result = addr;
|
||||
return true;
|
||||
}
|
||||
|
||||
LOG_WARN << "Could not resolve hostname: " << hostname;
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace reactor
|
||||
|
||||
namespace std {
|
||||
template<>
|
||||
struct hash<reactor::InetAddress>
|
||||
{
|
||||
size_t operator()(const reactor::InetAddress& addr) const
|
||||
{
|
||||
size_t seed = 0;
|
||||
reactor::hashCombine(seed, addr.toIp());
|
||||
reactor::hashCombine(seed, addr.port());
|
||||
reactor::hashCombine(seed, addr.isIpV6());
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
} // namespace std
|
318
lib/Socket.hpp
Normal file
318
lib/Socket.hpp
Normal file
@ -0,0 +1,318 @@
|
||||
#pragma once
|
||||
|
||||
#include "InetAddress.hpp"
|
||||
#include "Utilities.hpp"
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
#include <cassert>
|
||||
|
||||
namespace reactor {
|
||||
|
||||
class Socket : public NonCopyable
|
||||
{
|
||||
private:
|
||||
int fd_;
|
||||
|
||||
void setNonBlockAndCloseOnExec()
|
||||
{
|
||||
int flags = fcntl(fd_, F_GETFL, 0);
|
||||
flags |= O_NONBLOCK;
|
||||
fcntl(fd_, F_SETFL, flags);
|
||||
|
||||
flags = fcntl(fd_, F_GETFD, 0);
|
||||
flags |= FD_CLOEXEC;
|
||||
fcntl(fd_, F_SETFD, flags);
|
||||
}
|
||||
|
||||
public:
|
||||
explicit Socket(int fd) : fd_(fd)
|
||||
{
|
||||
LOG_TRACE << "Socket created with fd=" << fd_;
|
||||
}
|
||||
|
||||
~Socket()
|
||||
{
|
||||
if (fd_ >= 0) {
|
||||
close(fd_);
|
||||
LOG_TRACE << "Socket fd=" << fd_ << " closed";
|
||||
}
|
||||
}
|
||||
|
||||
Socket(Socket&& other) noexcept : fd_(other.fd_)
|
||||
{
|
||||
other.fd_ = -1;
|
||||
LOG_TRACE << "Socket moved fd=" << fd_;
|
||||
}
|
||||
|
||||
Socket& operator=(Socket&& other) noexcept
|
||||
{
|
||||
if (this != &other) {
|
||||
if (fd_ >= 0) {
|
||||
close(fd_);
|
||||
LOG_TRACE << "Socket fd=" << fd_ << " closed in move assignment";
|
||||
}
|
||||
fd_ = other.fd_;
|
||||
other.fd_ = -1;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
static Socket createTcp(bool ipv6 = false)
|
||||
{
|
||||
int fd = socket(ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0);
|
||||
if (fd < 0) {
|
||||
LOG_FATAL << "Failed to create TCP socket: " << strerror(errno);
|
||||
abort();
|
||||
}
|
||||
LOG_DEBUG << "Created TCP socket fd=" << fd << " ipv6=" << ipv6;
|
||||
return Socket(fd);
|
||||
}
|
||||
|
||||
static Socket createUdp(bool ipv6 = false)
|
||||
{
|
||||
int fd = socket(ipv6 ? AF_INET6 : AF_INET, SOCK_DGRAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0);
|
||||
if (fd < 0) {
|
||||
LOG_FATAL << "Failed to create UDP socket: " << strerror(errno);
|
||||
abort();
|
||||
}
|
||||
LOG_DEBUG << "Created UDP socket fd=" << fd << " ipv6=" << ipv6;
|
||||
return Socket(fd);
|
||||
}
|
||||
|
||||
void bind(const InetAddress& addr)
|
||||
{
|
||||
int ret = ::bind(fd_, addr.getSockAddr(), addr.getSockLen());
|
||||
if (ret < 0) {
|
||||
LOG_FATAL << "Socket bind to " << addr.toIpPort() << " failed: " << strerror(errno);
|
||||
abort();
|
||||
}
|
||||
LOG_INFO << "Socket fd=" << fd_ << " bound to " << addr.toIpPort();
|
||||
}
|
||||
|
||||
void listen(int backlog = SOMAXCONN)
|
||||
{
|
||||
int ret = ::listen(fd_, backlog);
|
||||
if (ret < 0) {
|
||||
LOG_FATAL << "Socket listen failed: " << strerror(errno);
|
||||
abort();
|
||||
}
|
||||
LOG_INFO << "Socket fd=" << fd_ << " listening with backlog=" << backlog;
|
||||
}
|
||||
|
||||
int accept(InetAddress& peerAddr)
|
||||
{
|
||||
sockaddr_in6 addr;
|
||||
socklen_t len = sizeof(addr);
|
||||
int connfd = accept4(fd_, reinterpret_cast<sockaddr*>(&addr), &len, SOCK_NONBLOCK | SOCK_CLOEXEC);
|
||||
|
||||
if (connfd >= 0) {
|
||||
if (addr.sin6_family == AF_INET) {
|
||||
peerAddr = InetAddress(*reinterpret_cast<sockaddr_in*>(&addr));
|
||||
} else {
|
||||
peerAddr = InetAddress(addr);
|
||||
}
|
||||
LOG_DEBUG << "Socket fd=" << fd_ << " accepted connection fd=" << connfd
|
||||
<< " from " << peerAddr.toIpPort();
|
||||
} else if (errno != EAGAIN && errno != EWOULDBLOCK) {
|
||||
LOG_ERROR << "Socket accept failed: " << strerror(errno);
|
||||
}
|
||||
|
||||
return connfd;
|
||||
}
|
||||
|
||||
int connect(const InetAddress& addr)
|
||||
{
|
||||
int ret = ::connect(fd_, addr.getSockAddr(), addr.getSockLen());
|
||||
if (ret < 0 && errno != EINPROGRESS) {
|
||||
LOG_ERROR << "Socket connect to " << addr.toIpPort() << " failed: " << strerror(errno);
|
||||
} else {
|
||||
LOG_DEBUG << "Socket fd=" << fd_ << " connecting to " << addr.toIpPort();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void setReuseAddr(bool on = true)
|
||||
{
|
||||
int optval = on ? 1 : 0;
|
||||
if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)) < 0) {
|
||||
LOG_ERROR << "setsockopt SO_REUSEADDR failed: " << strerror(errno);
|
||||
} else {
|
||||
LOG_TRACE << "Socket fd=" << fd_ << " SO_REUSEADDR=" << on;
|
||||
}
|
||||
}
|
||||
|
||||
void setReusePort(bool on = true)
|
||||
{
|
||||
int optval = on ? 1 : 0;
|
||||
if (setsockopt(fd_, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval)) < 0) {
|
||||
LOG_ERROR << "setsockopt SO_REUSEPORT failed: " << strerror(errno);
|
||||
} else {
|
||||
LOG_TRACE << "Socket fd=" << fd_ << " SO_REUSEPORT=" << on;
|
||||
}
|
||||
}
|
||||
|
||||
void setTcpNoDelay(bool on = true)
|
||||
{
|
||||
int optval = on ? 1 : 0;
|
||||
if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &optval, sizeof(optval)) < 0) {
|
||||
LOG_ERROR << "setsockopt TCP_NODELAY failed: " << strerror(errno);
|
||||
} else {
|
||||
LOG_TRACE << "Socket fd=" << fd_ << " TCP_NODELAY=" << on;
|
||||
}
|
||||
}
|
||||
|
||||
void setKeepAlive(bool on = true)
|
||||
{
|
||||
int optval = on ? 1 : 0;
|
||||
if (setsockopt(fd_, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) < 0) {
|
||||
LOG_ERROR << "setsockopt SO_KEEPALIVE failed: " << strerror(errno);
|
||||
} else {
|
||||
LOG_TRACE << "Socket fd=" << fd_ << " SO_KEEPALIVE=" << on;
|
||||
}
|
||||
}
|
||||
|
||||
void setTcpKeepAlive(int idle, int interval, int count)
|
||||
{
|
||||
if (setsockopt(fd_, IPPROTO_TCP, TCP_KEEPIDLE, &idle, sizeof(idle)) < 0 ||
|
||||
setsockopt(fd_, IPPROTO_TCP, TCP_KEEPINTVL, &interval, sizeof(interval)) < 0 ||
|
||||
setsockopt(fd_, IPPROTO_TCP, TCP_KEEPCNT, &count, sizeof(count)) < 0) {
|
||||
LOG_ERROR << "setsockopt TCP_KEEP* failed: " << strerror(errno);
|
||||
} else {
|
||||
LOG_TRACE << "Socket fd=" << fd_ << " TCP keepalive: idle=" << idle
|
||||
<< " interval=" << interval << " count=" << count;
|
||||
}
|
||||
}
|
||||
|
||||
void setRecvBuffer(int size)
|
||||
{
|
||||
if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size)) < 0) {
|
||||
LOG_ERROR << "setsockopt SO_RCVBUF failed: " << strerror(errno);
|
||||
} else {
|
||||
LOG_TRACE << "Socket fd=" << fd_ << " SO_RCVBUF=" << size;
|
||||
}
|
||||
}
|
||||
|
||||
void setSendBuffer(int size)
|
||||
{
|
||||
if (setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size)) < 0) {
|
||||
LOG_ERROR << "setsockopt SO_SNDBUF failed: " << strerror(errno);
|
||||
} else {
|
||||
LOG_TRACE << "Socket fd=" << fd_ << " SO_SNDBUF=" << size;
|
||||
}
|
||||
}
|
||||
|
||||
ssize_t read(void* buf, size_t len)
|
||||
{
|
||||
ssize_t n = ::read(fd_, buf, len);
|
||||
if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK) {
|
||||
LOG_ERROR << "Socket read failed: " << strerror(errno);
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
ssize_t write(const void* buf, size_t len)
|
||||
{
|
||||
ssize_t n = ::write(fd_, buf, len);
|
||||
if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK) {
|
||||
LOG_ERROR << "Socket write failed: " << strerror(errno);
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
ssize_t sendTo(const void* buf, size_t len, const InetAddress& addr)
|
||||
{
|
||||
ssize_t n = sendto(fd_, buf, len, 0, addr.getSockAddr(), addr.getSockLen());
|
||||
if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK) {
|
||||
LOG_ERROR << "Socket sendto failed: " << strerror(errno);
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
ssize_t recvFrom(void* buf, size_t len, InetAddress& addr)
|
||||
{
|
||||
sockaddr_in6 sockaddr;
|
||||
socklen_t addrlen = sizeof(sockaddr);
|
||||
ssize_t n = recvfrom(fd_, buf, len, 0, reinterpret_cast<struct sockaddr*>(&sockaddr), &addrlen);
|
||||
|
||||
if (n >= 0) {
|
||||
if (sockaddr.sin6_family == AF_INET) {
|
||||
addr = InetAddress(*reinterpret_cast<sockaddr_in*>(&sockaddr));
|
||||
} else {
|
||||
addr = InetAddress(sockaddr);
|
||||
}
|
||||
} else if (errno != EAGAIN && errno != EWOULDBLOCK) {
|
||||
LOG_ERROR << "Socket recvfrom failed: " << strerror(errno);
|
||||
}
|
||||
|
||||
return n;
|
||||
}
|
||||
|
||||
void shutdownWrite()
|
||||
{
|
||||
if (shutdown(fd_, SHUT_WR) < 0) {
|
||||
LOG_ERROR << "Socket shutdown write failed: " << strerror(errno);
|
||||
} else {
|
||||
LOG_DEBUG << "Socket fd=" << fd_ << " shutdown write";
|
||||
}
|
||||
}
|
||||
|
||||
void shutdownRead()
|
||||
{
|
||||
if (shutdown(fd_, SHUT_RD) < 0) {
|
||||
LOG_ERROR << "Socket shutdown read failed: " << strerror(errno);
|
||||
} else {
|
||||
LOG_DEBUG << "Socket fd=" << fd_ << " shutdown read";
|
||||
}
|
||||
}
|
||||
|
||||
int getSocketError()
|
||||
{
|
||||
int optval;
|
||||
socklen_t optlen = sizeof(optval);
|
||||
if (getsockopt(fd_, SOL_SOCKET, SO_ERROR, &optval, &optlen) < 0) {
|
||||
return errno;
|
||||
}
|
||||
return optval;
|
||||
}
|
||||
|
||||
int fd() const { return fd_; }
|
||||
|
||||
static InetAddress getLocalAddr(int sockfd)
|
||||
{
|
||||
sockaddr_in6 addr;
|
||||
socklen_t addrlen = sizeof(addr);
|
||||
if (getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen) < 0) {
|
||||
LOG_ERROR << "getsockname failed: " << strerror(errno);
|
||||
return InetAddress();
|
||||
}
|
||||
|
||||
if (addr.sin6_family == AF_INET) {
|
||||
return InetAddress(*reinterpret_cast<sockaddr_in*>(&addr));
|
||||
}
|
||||
return InetAddress(addr);
|
||||
}
|
||||
|
||||
static InetAddress getPeerAddr(int sockfd)
|
||||
{
|
||||
sockaddr_in6 addr;
|
||||
socklen_t addrlen = sizeof(addr);
|
||||
if (getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen) < 0) {
|
||||
LOG_ERROR << "getpeername failed: " << strerror(errno);
|
||||
return InetAddress();
|
||||
}
|
||||
|
||||
if (addr.sin6_family == AF_INET) {
|
||||
return InetAddress(*reinterpret_cast<sockaddr_in*>(&addr));
|
||||
}
|
||||
return InetAddress(addr);
|
||||
}
|
||||
|
||||
bool isSelfConnected()
|
||||
{
|
||||
return getLocalAddr(fd_) == getPeerAddr(fd_);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace reactor
|
364
lib/TcpConnection.hpp
Normal file
364
lib/TcpConnection.hpp
Normal file
@ -0,0 +1,364 @@
|
||||
#pragma once
|
||||
|
||||
#include "Core.hpp"
|
||||
#include "Socket.hpp"
|
||||
#include "Buffer.hpp"
|
||||
#include "Utilities.hpp"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include <errno.h>
|
||||
|
||||
namespace reactor {
|
||||
|
||||
class TcpConnection;
|
||||
using TcpConnectionPtr = std::shared_ptr<TcpConnection>;
|
||||
using MessageCallback = std::function<void(const TcpConnectionPtr&, Buffer&)>;
|
||||
using ConnectionCallback = std::function<void(const TcpConnectionPtr&)>;
|
||||
using WriteCompleteCallback = std::function<void(const TcpConnectionPtr&)>;
|
||||
using HighWaterMarkCallback = std::function<void(const TcpConnectionPtr&, size_t)>;
|
||||
|
||||
class TcpConnection : public NonCopyable, public std::enable_shared_from_this<TcpConnection>
|
||||
{
|
||||
public:
|
||||
enum StateE { kDisconnected, kConnecting, kConnected, kDisconnecting };
|
||||
|
||||
private:
|
||||
EventLoop* loop_;
|
||||
Socket socket_;
|
||||
std::unique_ptr<Channel> channel_;
|
||||
InetAddress localAddr_;
|
||||
InetAddress peerAddr_;
|
||||
std::string name_;
|
||||
StateE state_;
|
||||
Buffer inputBuffer_;
|
||||
Buffer outputBuffer_;
|
||||
MessageCallback messageCallback_;
|
||||
ConnectionCallback connectionCallback_;
|
||||
ConnectionCallback closeCallback_;
|
||||
WriteCompleteCallback writeCompleteCallback_;
|
||||
HighWaterMarkCallback highWaterMarkCallback_;
|
||||
size_t highWaterMark_;
|
||||
|
||||
void setState(StateE s) { state_ = s; }
|
||||
|
||||
void handleRead()
|
||||
{
|
||||
loop_->assertInLoopThread();
|
||||
int savedErrno = 0;
|
||||
ssize_t n = inputBuffer_.readFd(socket_.fd(), &savedErrno);
|
||||
|
||||
if (n > 0) {
|
||||
LOG_TRACE << "TcpConnection " << name_ << " read " << n << " bytes";
|
||||
if (messageCallback_) {
|
||||
messageCallback_(shared_from_this(), inputBuffer_);
|
||||
}
|
||||
} else if (n == 0) {
|
||||
LOG_DEBUG << "TcpConnection " << name_ << " peer closed";
|
||||
handleClose();
|
||||
} else {
|
||||
errno = savedErrno;
|
||||
LOG_ERROR << "TcpConnection " << name_ << " read error: " << strerror(savedErrno);
|
||||
handleError();
|
||||
}
|
||||
}
|
||||
|
||||
void handleWrite()
|
||||
{
|
||||
loop_->assertInLoopThread();
|
||||
if (channel_->isWriting()) {
|
||||
ssize_t n = socket_.write(outputBuffer_.peek(), outputBuffer_.readableBytes());
|
||||
if (n > 0) {
|
||||
outputBuffer_.retrieve(n);
|
||||
LOG_TRACE << "TcpConnection " << name_ << " wrote " << n << " bytes, "
|
||||
<< outputBuffer_.readableBytes() << " bytes left";
|
||||
|
||||
if (outputBuffer_.readableBytes() == 0) {
|
||||
channel_->disableWriting();
|
||||
if (writeCompleteCallback_) {
|
||||
loop_->queueInLoop([self = shared_from_this()]() {
|
||||
self->writeCompleteCallback_(self);
|
||||
});
|
||||
}
|
||||
if (state_ == kDisconnecting) {
|
||||
shutdownInLoop();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LOG_ERROR << "TcpConnection " << name_ << " write error: " << strerror(errno);
|
||||
}
|
||||
} else {
|
||||
LOG_TRACE << "TcpConnection " << name_ << " not writing, ignore";
|
||||
}
|
||||
}
|
||||
|
||||
void handleClose()
|
||||
{
|
||||
loop_->assertInLoopThread();
|
||||
LOG_DEBUG << "TcpConnection " << name_ << " state=" << stateToString();
|
||||
assert(state_ == kConnected || state_ == kDisconnecting);
|
||||
setState(kDisconnected);
|
||||
channel_->disableAll();
|
||||
|
||||
auto guardThis = shared_from_this();
|
||||
if (connectionCallback_) {
|
||||
connectionCallback_(guardThis);
|
||||
}
|
||||
if (closeCallback_) {
|
||||
closeCallback_(guardThis);
|
||||
}
|
||||
}
|
||||
|
||||
void handleError()
|
||||
{
|
||||
int err = socket_.getSocketError();
|
||||
LOG_ERROR << "TcpConnection " << name_ << " SO_ERROR=" << err << " " << strerror(err);
|
||||
handleClose();
|
||||
}
|
||||
|
||||
void sendInLoop(const char* data, size_t len)
|
||||
{
|
||||
loop_->assertInLoopThread();
|
||||
ssize_t nwrote = 0;
|
||||
size_t remaining = len;
|
||||
bool faultError = false;
|
||||
|
||||
if (state_ == kDisconnected) {
|
||||
LOG_WARN << "TcpConnection " << name_ << " disconnected, give up writing";
|
||||
return;
|
||||
}
|
||||
|
||||
if (!channel_->isWriting() && outputBuffer_.readableBytes() == 0) {
|
||||
nwrote = socket_.write(data, len);
|
||||
if (nwrote >= 0) {
|
||||
remaining = len - nwrote;
|
||||
if (remaining == 0 && writeCompleteCallback_) {
|
||||
loop_->queueInLoop([self = shared_from_this()]() {
|
||||
self->writeCompleteCallback_(self);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
nwrote = 0;
|
||||
if (errno != EWOULDBLOCK) {
|
||||
LOG_ERROR << "TcpConnection " << name_ << " send error: " << strerror(errno);
|
||||
if (errno == EPIPE || errno == ECONNRESET) {
|
||||
faultError = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert(remaining <= len);
|
||||
if (!faultError && remaining > 0) {
|
||||
size_t oldLen = outputBuffer_.readableBytes();
|
||||
if (oldLen + remaining >= highWaterMark_ &&
|
||||
oldLen < highWaterMark_ &&
|
||||
highWaterMarkCallback_) {
|
||||
loop_->queueInLoop([self = shared_from_this(), mark = oldLen + remaining]() {
|
||||
self->highWaterMarkCallback_(self, mark);
|
||||
});
|
||||
}
|
||||
outputBuffer_.append(data + nwrote, remaining);
|
||||
if (!channel_->isWriting()) {
|
||||
channel_->enableWriting();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void shutdownInLoop()
|
||||
{
|
||||
loop_->assertInLoopThread();
|
||||
if (!channel_->isWriting()) {
|
||||
socket_.shutdownWrite();
|
||||
}
|
||||
}
|
||||
|
||||
void forceCloseInLoop()
|
||||
{
|
||||
loop_->assertInLoopThread();
|
||||
if (state_ == kConnected || state_ == kDisconnecting) {
|
||||
handleClose();
|
||||
}
|
||||
}
|
||||
|
||||
std::string stateToString() const
|
||||
{
|
||||
switch (state_) {
|
||||
case kDisconnected: return "kDisconnected";
|
||||
case kConnecting: return "kConnecting";
|
||||
case kConnected: return "kConnected";
|
||||
case kDisconnecting: return "kDisconnecting";
|
||||
default: return "unknown state";
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
TcpConnection(EventLoop* loop, const std::string& name, int sockfd,
|
||||
const InetAddress& localAddr, const InetAddress& peerAddr)
|
||||
: loop_(loop), socket_(sockfd), channel_(std::make_unique<Channel>(loop, sockfd)),
|
||||
localAddr_(localAddr), peerAddr_(peerAddr), name_(name), state_(kConnecting),
|
||||
highWaterMark_(64*1024*1024)
|
||||
{
|
||||
channel_->setReadCallback([this]() { handleRead(); });
|
||||
channel_->setWriteCallback([this]() { handleWrite(); });
|
||||
channel_->setCloseCallback([this]() { handleClose(); });
|
||||
channel_->setErrorCallback([this]() { handleError(); });
|
||||
|
||||
socket_.setKeepAlive(true);
|
||||
socket_.setTcpNoDelay(true);
|
||||
|
||||
LOG_INFO << "TcpConnection " << name_ << " created from "
|
||||
<< localAddr_.toIpPort() << " to " << peerAddr_.toIpPort() << " fd=" << sockfd;
|
||||
}
|
||||
|
||||
~TcpConnection()
|
||||
{
|
||||
LOG_INFO << "TcpConnection " << name_ << " destroyed state=" << stateToString();
|
||||
assert(state_ == kDisconnected);
|
||||
}
|
||||
|
||||
void connectEstablished()
|
||||
{
|
||||
loop_->assertInLoopThread();
|
||||
assert(state_ == kConnecting);
|
||||
setState(kConnected);
|
||||
channel_->tie(shared_from_this());
|
||||
channel_->enableReading();
|
||||
|
||||
if (connectionCallback_) {
|
||||
connectionCallback_(shared_from_this());
|
||||
}
|
||||
|
||||
LOG_INFO << "TcpConnection " << name_ << " established";
|
||||
}
|
||||
|
||||
void connectDestroyed()
|
||||
{
|
||||
loop_->assertInLoopThread();
|
||||
if (state_ == kConnected) {
|
||||
setState(kDisconnected);
|
||||
channel_->disableAll();
|
||||
if (connectionCallback_) {
|
||||
connectionCallback_(shared_from_this());
|
||||
}
|
||||
}
|
||||
channel_->remove();
|
||||
LOG_INFO << "TcpConnection " << name_ << " destroyed";
|
||||
}
|
||||
|
||||
const std::string& name() const { return name_; }
|
||||
const InetAddress& localAddr() const { return localAddr_; }
|
||||
const InetAddress& peerAddr() const { return peerAddr_; }
|
||||
bool connected() const { return state_ == kConnected; }
|
||||
bool disconnected() const { return state_ == kDisconnected; }
|
||||
EventLoop* getLoop() const { return loop_; }
|
||||
|
||||
void send(const std::string& message)
|
||||
{
|
||||
if (state_ == kConnected) {
|
||||
if (loop_->isInLoopThread()) {
|
||||
sendInLoop(message.data(), message.size());
|
||||
} else {
|
||||
loop_->runInLoop([self = shared_from_this(), message]() {
|
||||
self->sendInLoop(message.data(), message.size());
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void send(const char* data, size_t len)
|
||||
{
|
||||
if (state_ == kConnected) {
|
||||
if (loop_->isInLoopThread()) {
|
||||
sendInLoop(data, len);
|
||||
} else {
|
||||
std::string message(data, len);
|
||||
loop_->runInLoop([self = shared_from_this(), message]() {
|
||||
self->sendInLoop(message.data(), message.size());
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void send(Buffer& buffer)
|
||||
{
|
||||
if (state_ == kConnected) {
|
||||
if (loop_->isInLoopThread()) {
|
||||
sendInLoop(buffer.peek(), buffer.readableBytes());
|
||||
buffer.retrieveAll();
|
||||
} else {
|
||||
std::string message = buffer.readAll();
|
||||
loop_->runInLoop([self = shared_from_this(), message]() {
|
||||
self->sendInLoop(message.data(), message.size());
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void shutdown()
|
||||
{
|
||||
if (state_ == kConnected) {
|
||||
setState(kDisconnecting);
|
||||
loop_->runInLoop([self = shared_from_this()]() {
|
||||
self->shutdownInLoop();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void forceClose()
|
||||
{
|
||||
if (state_ == kConnected || state_ == kDisconnecting) {
|
||||
setState(kDisconnecting);
|
||||
loop_->queueInLoop([self = shared_from_this()]() {
|
||||
self->forceCloseInLoop();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void forceCloseWithDelay(double seconds)
|
||||
{
|
||||
if (state_ == kConnected || state_ == kDisconnecting) {
|
||||
setState(kDisconnecting);
|
||||
loop_->runAfter(Duration(static_cast<int>(seconds * 1000)),
|
||||
[self = shared_from_this()]() {
|
||||
self->forceClose();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void setTcpNoDelay(bool on) { socket_.setTcpNoDelay(on); }
|
||||
void setTcpKeepAlive(bool on) { socket_.setKeepAlive(on); }
|
||||
|
||||
void setMessageCallback(MessageCallback cb) { messageCallback_ = std::move(cb); }
|
||||
void setConnectionCallback(ConnectionCallback cb) { connectionCallback_ = std::move(cb); }
|
||||
void setCloseCallback(ConnectionCallback cb) { closeCallback_ = std::move(cb); }
|
||||
void setWriteCompleteCallback(WriteCompleteCallback cb) { writeCompleteCallback_ = std::move(cb); }
|
||||
void setHighWaterMarkCallback(HighWaterMarkCallback cb, size_t highWaterMark)
|
||||
{
|
||||
highWaterMarkCallback_ = std::move(cb);
|
||||
highWaterMark_ = highWaterMark;
|
||||
}
|
||||
|
||||
Buffer* inputBuffer() { return &inputBuffer_; }
|
||||
Buffer* outputBuffer() { return &outputBuffer_; }
|
||||
|
||||
void startRead()
|
||||
{
|
||||
loop_->runInLoop([self = shared_from_this()]() {
|
||||
if (!self->channel_->isReading()) {
|
||||
self->channel_->enableReading();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void stopRead()
|
||||
{
|
||||
loop_->runInLoop([self = shared_from_this()]() {
|
||||
if (self->channel_->isReading()) {
|
||||
self->channel_->disableReading();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace reactor
|
236
lib/TcpServer.hpp
Normal file
236
lib/TcpServer.hpp
Normal file
@ -0,0 +1,236 @@
|
||||
#pragma once
|
||||
|
||||
#include "Core.hpp"
|
||||
#include "Socket.hpp"
|
||||
#include "TcpConnection.hpp"
|
||||
#include "EventLoopThread.hpp"
|
||||
#include "Utilities.hpp"
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
|
||||
namespace reactor {
|
||||
|
||||
using NewConnectionCallback = std::function<void(int, const InetAddress&)>;
|
||||
|
||||
class Acceptor : public NonCopyable
|
||||
{
|
||||
private:
|
||||
EventLoop* loop_;
|
||||
Socket acceptSocket_;
|
||||
std::unique_ptr<Channel> acceptChannel_;
|
||||
NewConnectionCallback newConnectionCallback_;
|
||||
bool listening_;
|
||||
int idleFd_;
|
||||
|
||||
void handleRead()
|
||||
{
|
||||
loop_->assertInLoopThread();
|
||||
InetAddress peerAddr;
|
||||
int connfd = acceptSocket_.accept(peerAddr);
|
||||
|
||||
if (connfd >= 0) {
|
||||
if (newConnectionCallback_) {
|
||||
newConnectionCallback_(connfd, peerAddr);
|
||||
} else {
|
||||
close(connfd);
|
||||
LOG_WARN << "Acceptor no callback for new connection, closing fd=" << connfd;
|
||||
}
|
||||
} else {
|
||||
LOG_ERROR << "Acceptor accept failed: " << strerror(errno);
|
||||
if (errno == EMFILE) {
|
||||
close(idleFd_);
|
||||
idleFd_ = ::accept(acceptSocket_.fd(), nullptr, nullptr);
|
||||
close(idleFd_);
|
||||
idleFd_ = ::open("/dev/null", O_RDONLY | O_CLOEXEC);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
Acceptor(EventLoop* loop, const InetAddress& listenAddr, bool reusePort = true)
|
||||
: loop_(loop), acceptSocket_(Socket::createTcp(listenAddr.isIpV6())),
|
||||
acceptChannel_(std::make_unique<Channel>(loop, acceptSocket_.fd())),
|
||||
listening_(false), idleFd_(::open("/dev/null", O_RDONLY | O_CLOEXEC))
|
||||
{
|
||||
acceptSocket_.setReuseAddr(true);
|
||||
if (reusePort) {
|
||||
acceptSocket_.setReusePort(true);
|
||||
}
|
||||
acceptSocket_.bind(listenAddr);
|
||||
acceptChannel_->setReadCallback([this]() { handleRead(); });
|
||||
LOG_INFO << "Acceptor created for " << listenAddr.toIpPort();
|
||||
}
|
||||
|
||||
~Acceptor()
|
||||
{
|
||||
acceptChannel_->disableAll();
|
||||
acceptChannel_->remove();
|
||||
close(idleFd_);
|
||||
LOG_INFO << "Acceptor destroyed";
|
||||
}
|
||||
|
||||
void listen()
|
||||
{
|
||||
loop_->assertInLoopThread();
|
||||
listening_ = true;
|
||||
acceptSocket_.listen();
|
||||
acceptChannel_->enableReading();
|
||||
LOG_INFO << "Acceptor listening";
|
||||
}
|
||||
|
||||
bool listening() const { return listening_; }
|
||||
|
||||
void setNewConnectionCallback(NewConnectionCallback cb)
|
||||
{
|
||||
newConnectionCallback_ = std::move(cb);
|
||||
}
|
||||
};
|
||||
|
||||
class TcpServer : public NonCopyable
|
||||
{
|
||||
private:
|
||||
EventLoop* loop_;
|
||||
std::string name_;
|
||||
std::unique_ptr<Acceptor> acceptor_;
|
||||
std::unique_ptr<EventLoopThreadPool> threadPool_;
|
||||
MessageCallback messageCallback_;
|
||||
ConnectionCallback connectionCallback_;
|
||||
WriteCompleteCallback writeCompleteCallback_;
|
||||
|
||||
std::unordered_map<std::string, TcpConnectionPtr> connections_;
|
||||
std::atomic<int> nextConnId_;
|
||||
bool started_;
|
||||
|
||||
void newConnection(int sockfd, const InetAddress& peerAddr)
|
||||
{
|
||||
loop_->assertInLoopThread();
|
||||
EventLoop* ioLoop = threadPool_->getNextLoop();
|
||||
if (!ioLoop) ioLoop = loop_;
|
||||
|
||||
std::string connName = name_ + "-" + peerAddr.toIpPort() + "#" + std::to_string(nextConnId_++);
|
||||
InetAddress localAddr = Socket::getLocalAddr(sockfd);
|
||||
|
||||
LOG_INFO << "TcpServer new connection " << connName << " from " << peerAddr.toIpPort();
|
||||
|
||||
auto conn = std::make_shared<TcpConnection>(ioLoop, connName, sockfd, localAddr, peerAddr);
|
||||
connections_[connName] = conn;
|
||||
|
||||
conn->setMessageCallback(messageCallback_);
|
||||
conn->setConnectionCallback(connectionCallback_);
|
||||
conn->setWriteCompleteCallback(writeCompleteCallback_);
|
||||
conn->setCloseCallback([this](const TcpConnectionPtr& conn) {
|
||||
removeConnection(conn);
|
||||
});
|
||||
|
||||
ioLoop->runInLoop([conn]() { conn->connectEstablished(); });
|
||||
}
|
||||
|
||||
void removeConnection(const TcpConnectionPtr& conn)
|
||||
{
|
||||
loop_->runInLoop([this, conn]() {
|
||||
LOG_INFO << "TcpServer removing connection " << conn->name();
|
||||
size_t n = connections_.erase(conn->name());
|
||||
assert(n == 1);
|
||||
|
||||
EventLoop* ioLoop = conn->getLoop();
|
||||
ioLoop->queueInLoop([conn]() { conn->connectDestroyed(); });
|
||||
});
|
||||
}
|
||||
|
||||
void removeConnectionInLoop(const TcpConnectionPtr& conn)
|
||||
{
|
||||
loop_->assertInLoopThread();
|
||||
LOG_INFO << "TcpServer removing connection " << conn->name();
|
||||
size_t n = connections_.erase(conn->name());
|
||||
assert(n == 1);
|
||||
|
||||
EventLoop* ioLoop = conn->getLoop();
|
||||
ioLoop->queueInLoop([conn]() { conn->connectDestroyed(); });
|
||||
}
|
||||
|
||||
public:
|
||||
TcpServer(EventLoop* loop, const InetAddress& listenAddr, const std::string& name,
|
||||
bool reusePort = true)
|
||||
: loop_(loop), name_(name),
|
||||
acceptor_(std::make_unique<Acceptor>(loop, listenAddr, reusePort)),
|
||||
threadPool_(std::make_unique<EventLoopThreadPool>(0, name + "-EventLoop")),
|
||||
nextConnId_(1), started_(false)
|
||||
{
|
||||
acceptor_->setNewConnectionCallback([this](int sockfd, const InetAddress& addr) {
|
||||
newConnection(sockfd, addr);
|
||||
});
|
||||
LOG_INFO << "TcpServer " << name_ << " created for " << listenAddr.toIpPort();
|
||||
}
|
||||
|
||||
~TcpServer()
|
||||
{
|
||||
loop_->assertInLoopThread();
|
||||
LOG_INFO << "TcpServer " << name_ << " destructing with " << connections_.size() << " connections";
|
||||
|
||||
for (auto& item : connections_) {
|
||||
auto conn = item.second;
|
||||
auto ioLoop = conn->getLoop();
|
||||
ioLoop->runInLoop([conn]() { conn->forceClose(); });
|
||||
}
|
||||
}
|
||||
|
||||
void setThreadNum(int numThreads)
|
||||
{
|
||||
assert(0 <= numThreads);
|
||||
threadPool_ = std::make_unique<EventLoopThreadPool>(numThreads, name_ + "-EventLoop");
|
||||
LOG_INFO << "TcpServer " << name_ << " set thread pool size to " << numThreads;
|
||||
}
|
||||
|
||||
void start()
|
||||
{
|
||||
if (!started_) {
|
||||
started_ = true;
|
||||
if (!acceptor_->listening()) {
|
||||
loop_->runInLoop([this]() { acceptor_->listen(); });
|
||||
}
|
||||
LOG_INFO << "TcpServer " << name_ << " started with " << threadPool_->size() << " threads";
|
||||
}
|
||||
}
|
||||
|
||||
void setMessageCallback(MessageCallback cb) { messageCallback_ = std::move(cb); }
|
||||
void setConnectionCallback(ConnectionCallback cb) { connectionCallback_ = std::move(cb); }
|
||||
void setWriteCompleteCallback(WriteCompleteCallback cb) { writeCompleteCallback_ = std::move(cb); }
|
||||
|
||||
const std::string& name() const { return name_; }
|
||||
const char* ipPort() const { return acceptor_ ? "listening" : "not-listening"; }
|
||||
EventLoop* getLoop() const { return loop_; }
|
||||
|
||||
size_t numConnections() const
|
||||
{
|
||||
return connections_.size();
|
||||
}
|
||||
|
||||
std::vector<TcpConnectionPtr> getConnections() const
|
||||
{
|
||||
std::vector<TcpConnectionPtr> result;
|
||||
result.reserve(connections_.size());
|
||||
for (const auto& item : connections_) {
|
||||
result.push_back(item.second);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
TcpConnectionPtr getConnection(const std::string& name) const
|
||||
{
|
||||
auto it = connections_.find(name);
|
||||
return it != connections_.end() ? it->second : TcpConnectionPtr();
|
||||
}
|
||||
|
||||
void forceCloseAllConnections()
|
||||
{
|
||||
for (auto& item : connections_) {
|
||||
auto conn = item.second;
|
||||
auto ioLoop = conn->getLoop();
|
||||
ioLoop->runInLoop([conn]() { conn->forceClose(); });
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace reactor
|
346
lib/Utilities.hpp
Normal file
346
lib/Utilities.hpp
Normal file
@ -0,0 +1,346 @@
|
||||
#pragma once
|
||||
|
||||
#include <future>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <condition_variable>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <chrono>
|
||||
#include <type_traits>
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <netinet/in.h>
|
||||
#include <sstream>
|
||||
|
||||
namespace reactor {
|
||||
|
||||
// NonCopyable base class
|
||||
class NonCopyable
|
||||
{
|
||||
protected:
|
||||
NonCopyable() = default;
|
||||
~NonCopyable() = default;
|
||||
NonCopyable(const NonCopyable&) = delete;
|
||||
NonCopyable& operator=(const NonCopyable&) = delete;
|
||||
NonCopyable(NonCopyable&&) noexcept = default;
|
||||
NonCopyable& operator=(NonCopyable&&) noexcept = default;
|
||||
};
|
||||
|
||||
// Network byte order utilities
|
||||
inline uint64_t hton64(uint64_t n)
|
||||
{
|
||||
static const int one = 1;
|
||||
static const char sig = *(char*)&one;
|
||||
if (sig == 0) return n;
|
||||
char* ptr = reinterpret_cast<char*>(&n);
|
||||
std::reverse(ptr, ptr + sizeof(uint64_t));
|
||||
return n;
|
||||
}
|
||||
|
||||
inline uint64_t ntoh64(uint64_t n) { return hton64(n); }
|
||||
|
||||
// Lock-free MPSC queue
|
||||
template<typename T>
|
||||
class LockFreeQueue : public NonCopyable
|
||||
{
|
||||
private:
|
||||
struct Node
|
||||
{
|
||||
Node() = default;
|
||||
Node(const T& data) : data_(std::make_unique<T>(data)) {}
|
||||
Node(T&& data) : data_(std::make_unique<T>(std::move(data))) {}
|
||||
std::unique_ptr<T> data_;
|
||||
std::atomic<Node*> next_{nullptr};
|
||||
};
|
||||
|
||||
std::atomic<Node*> head_;
|
||||
std::atomic<Node*> tail_;
|
||||
|
||||
public:
|
||||
LockFreeQueue() : head_(new Node), tail_(head_.load()) {}
|
||||
|
||||
~LockFreeQueue()
|
||||
{
|
||||
T output;
|
||||
while (dequeue(output)) {}
|
||||
delete head_.load();
|
||||
}
|
||||
|
||||
void enqueue(T&& input)
|
||||
{
|
||||
Node* node = new Node(std::move(input));
|
||||
Node* prevhead = head_.exchange(node, std::memory_order_acq_rel);
|
||||
prevhead->next_.store(node, std::memory_order_release);
|
||||
}
|
||||
|
||||
void enqueue(const T& input)
|
||||
{
|
||||
Node* node = new Node(input);
|
||||
Node* prevhead = head_.exchange(node, std::memory_order_acq_rel);
|
||||
prevhead->next_.store(node, std::memory_order_release);
|
||||
}
|
||||
|
||||
bool dequeue(T& output)
|
||||
{
|
||||
Node* tail = tail_.load(std::memory_order_relaxed);
|
||||
Node* next = tail->next_.load(std::memory_order_acquire);
|
||||
|
||||
if (next == nullptr) return false;
|
||||
|
||||
output = std::move(*next->data_);
|
||||
tail_.store(next, std::memory_order_release);
|
||||
delete tail;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool empty()
|
||||
{
|
||||
Node* tail = tail_.load(std::memory_order_relaxed);
|
||||
Node* next = tail->next_.load(std::memory_order_acquire);
|
||||
return next == nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// Object Pool
|
||||
template<typename T>
|
||||
class ObjectPool : public NonCopyable, public std::enable_shared_from_this<ObjectPool<T>>
|
||||
{
|
||||
private:
|
||||
std::vector<T*> objects_;
|
||||
std::mutex mutex_;
|
||||
|
||||
public:
|
||||
std::shared_ptr<T> getObject()
|
||||
{
|
||||
static_assert(!std::is_pointer_v<T>, "ObjectPool type cannot be pointer");
|
||||
|
||||
T* p = nullptr;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (!objects_.empty()) {
|
||||
p = objects_.back();
|
||||
objects_.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
if (!p) p = new T;
|
||||
|
||||
std::weak_ptr<ObjectPool<T>> weakPtr = this->shared_from_this();
|
||||
return std::shared_ptr<T>(p, [weakPtr](T* ptr) {
|
||||
auto self = weakPtr.lock();
|
||||
if (self) {
|
||||
std::lock_guard<std::mutex> lock(self->mutex_);
|
||||
self->objects_.push_back(ptr);
|
||||
} else {
|
||||
delete ptr;
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Simple Logger
|
||||
enum class LogLevel { TRACE, DEBUG, INFO, WARN, ERROR, FATAL };
|
||||
|
||||
class Logger : public NonCopyable
|
||||
{
|
||||
private:
|
||||
static inline LogLevel level_ = LogLevel::INFO;
|
||||
static inline std::unique_ptr<std::ofstream> file_;
|
||||
static inline std::mutex mutex_;
|
||||
|
||||
std::ostringstream stream_;
|
||||
LogLevel msgLevel_;
|
||||
|
||||
static const char* levelString(LogLevel level)
|
||||
{
|
||||
switch (level) {
|
||||
case LogLevel::TRACE: return "TRACE";
|
||||
case LogLevel::DEBUG: return "DEBUG";
|
||||
case LogLevel::INFO: return "INFO ";
|
||||
case LogLevel::WARN: return "WARN ";
|
||||
case LogLevel::ERROR: return "ERROR";
|
||||
case LogLevel::FATAL: return "FATAL";
|
||||
}
|
||||
return "UNKNOWN";
|
||||
}
|
||||
|
||||
static std::string timestamp()
|
||||
{
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto time_t = std::chrono::system_clock::to_time_t(now);
|
||||
auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
now.time_since_epoch()) % 1000;
|
||||
|
||||
char buf[64];
|
||||
std::strftime(buf, sizeof(buf), "%Y%m%d %H:%M:%S", std::localtime(&time_t));
|
||||
return std::string(buf) + "." + std::to_string(ms.count());
|
||||
}
|
||||
|
||||
public:
|
||||
Logger(LogLevel level) : msgLevel_(level)
|
||||
{
|
||||
if (level >= level_) {
|
||||
stream_ << timestamp() << " " << levelString(level) << " ";
|
||||
}
|
||||
}
|
||||
|
||||
~Logger()
|
||||
{
|
||||
if (msgLevel_ >= level_) {
|
||||
stream_ << "\n";
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (file_ && file_->is_open()) {
|
||||
*file_ << stream_.str();
|
||||
file_->flush();
|
||||
} else {
|
||||
std::cout << stream_.str();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
Logger& operator<<(const T& value)
|
||||
{
|
||||
if (msgLevel_ >= level_) {
|
||||
stream_ << value;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
static void setLevel(LogLevel level) { level_ = level; }
|
||||
static void setLogFile(const std::string& filename)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
file_ = std::make_unique<std::ofstream>(filename, std::ios::app);
|
||||
}
|
||||
};
|
||||
|
||||
// Task Queue interface
|
||||
class TaskQueue : public NonCopyable
|
||||
{
|
||||
public:
|
||||
virtual ~TaskQueue() = default;
|
||||
virtual void runTaskInQueue(const std::function<void()>& task) = 0;
|
||||
virtual void runTaskInQueue(std::function<void()>&& task) = 0;
|
||||
virtual std::string getName() const { return ""; }
|
||||
|
||||
void syncTaskInQueue(const std::function<void()>& task)
|
||||
{
|
||||
std::promise<void> promise;
|
||||
auto future = promise.get_future();
|
||||
runTaskInQueue([&]() {
|
||||
task();
|
||||
promise.set_value();
|
||||
});
|
||||
future.wait();
|
||||
}
|
||||
};
|
||||
|
||||
// Concurrent Task Queue
|
||||
class ConcurrentTaskQueue : public TaskQueue
|
||||
{
|
||||
private:
|
||||
std::vector<std::thread> threads_;
|
||||
std::queue<std::function<void()>> taskQueue_;
|
||||
std::mutex taskMutex_;
|
||||
std::condition_variable taskCond_;
|
||||
std::atomic<bool> stop_{false};
|
||||
std::string name_;
|
||||
|
||||
void workerThread(int threadId)
|
||||
{
|
||||
while (!stop_) {
|
||||
std::function<void()> task;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(taskMutex_);
|
||||
taskCond_.wait(lock, [this]() { return stop_ || !taskQueue_.empty(); });
|
||||
|
||||
if (taskQueue_.empty()) continue;
|
||||
|
||||
task = std::move(taskQueue_.front());
|
||||
taskQueue_.pop();
|
||||
}
|
||||
task();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
ConcurrentTaskQueue(size_t threadNum, const std::string& name = "ConcurrentTaskQueue")
|
||||
: name_(name)
|
||||
{
|
||||
for (size_t i = 0; i < threadNum; ++i) {
|
||||
threads_.emplace_back(&ConcurrentTaskQueue::workerThread, this, i);
|
||||
}
|
||||
}
|
||||
|
||||
~ConcurrentTaskQueue()
|
||||
{
|
||||
stop_ = true;
|
||||
taskCond_.notify_all();
|
||||
for (auto& t : threads_) {
|
||||
if (t.joinable()) t.join();
|
||||
}
|
||||
}
|
||||
|
||||
void runTaskInQueue(const std::function<void()>& task) override
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(taskMutex_);
|
||||
taskQueue_.push(task);
|
||||
taskCond_.notify_one();
|
||||
}
|
||||
|
||||
void runTaskInQueue(std::function<void()>&& task) override
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(taskMutex_);
|
||||
taskQueue_.push(std::move(task));
|
||||
taskCond_.notify_one();
|
||||
}
|
||||
|
||||
std::string getName() const override { return name_; }
|
||||
|
||||
size_t getTaskCount()
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(taskMutex_);
|
||||
return taskQueue_.size();
|
||||
}
|
||||
};
|
||||
|
||||
// Logging macros
|
||||
#define LOG_TRACE Logger(LogLevel::TRACE)
|
||||
#define LOG_DEBUG Logger(LogLevel::DEBUG)
|
||||
#define LOG_INFO Logger(LogLevel::INFO)
|
||||
#define LOG_WARN Logger(LogLevel::WARN)
|
||||
#define LOG_ERROR Logger(LogLevel::ERROR)
|
||||
#define LOG_FATAL Logger(LogLevel::FATAL)
|
||||
|
||||
// Utility functions
|
||||
template<typename T>
|
||||
void hashCombine(std::size_t& seed, const T& value)
|
||||
{
|
||||
std::hash<T> hasher;
|
||||
seed ^= hasher(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
||||
}
|
||||
|
||||
inline std::vector<std::string> splitString(const std::string& s, const std::string& delimiter)
|
||||
{
|
||||
std::vector<std::string> result;
|
||||
size_t start = 0;
|
||||
size_t end = s.find(delimiter);
|
||||
|
||||
while (end != std::string::npos) {
|
||||
result.push_back(s.substr(start, end - start));
|
||||
start = end + delimiter.length();
|
||||
end = s.find(delimiter, start);
|
||||
}
|
||||
result.push_back(s.substr(start));
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace reactor
|
Loading…
x
Reference in New Issue
Block a user