From 14a395aa3cd46a67d69004dbcdc8714d83be620e Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Fri, 27 Jun 2025 18:01:00 -0500 Subject: [PATCH] v1.0 lib --- lib/Buffer.hpp | 262 ++++++++++++++++++ lib/Core.hpp | 591 ++++++++++++++++++++++++++++++++++++++++ lib/EventLoopThread.hpp | 110 ++++++++ lib/InetAddress.hpp | 170 ++++++++++++ lib/Socket.hpp | 318 +++++++++++++++++++++ lib/TcpConnection.hpp | 364 +++++++++++++++++++++++++ lib/TcpServer.hpp | 236 ++++++++++++++++ lib/Utilities.hpp | 346 +++++++++++++++++++++++ 8 files changed, 2397 insertions(+) create mode 100644 lib/Buffer.hpp create mode 100644 lib/Core.hpp create mode 100644 lib/EventLoopThread.hpp create mode 100644 lib/InetAddress.hpp create mode 100644 lib/Socket.hpp create mode 100644 lib/TcpConnection.hpp create mode 100644 lib/TcpServer.hpp create mode 100644 lib/Utilities.hpp diff --git a/lib/Buffer.hpp b/lib/Buffer.hpp new file mode 100644 index 0000000..c44b42a --- /dev/null +++ b/lib/Buffer.hpp @@ -0,0 +1,262 @@ +#pragma once + +#include "Utilities.hpp" +#include +#include +#include +#include +#include + +namespace reactor { + +class Buffer : public NonCopyable +{ +private: + std::vector buffer_; + size_t readIndex_; + size_t writeIndex_; + size_t initialCap_; + + static constexpr size_t kBufferOffset = 8; + +public: + static constexpr size_t kInitialSize = 1024; + + explicit Buffer(size_t initialSize = kInitialSize) + : buffer_(initialSize + kBufferOffset), readIndex_(kBufferOffset), + writeIndex_(kBufferOffset), initialCap_(initialSize) + { + LOG_TRACE << "Buffer created with size " << initialSize; + } + + ~Buffer() + { + LOG_TRACE << "Buffer destroyed, had " << readableBytes() << " readable bytes"; + } + + size_t readableBytes() const { return writeIndex_ - readIndex_; } + size_t writableBytes() const { return buffer_.size() - writeIndex_; } + size_t prependableBytes() const { return readIndex_; } + + const char* peek() const { return &buffer_[readIndex_]; } + char* beginWrite() { return &buffer_[writeIndex_]; } + const char* beginWrite() const { return &buffer_[writeIndex_]; } + + void retrieve(size_t len) + { + if (len < readableBytes()) { + readIndex_ += len; + } else { + retrieveAll(); + } + } + + void retrieveAll() + { + if (buffer_.size() > (initialCap_ * 2)) { + buffer_.resize(initialCap_ + kBufferOffset); + LOG_DEBUG << "Buffer shrunk to " << buffer_.size(); + } + readIndex_ = kBufferOffset; + writeIndex_ = kBufferOffset; + } + + void hasWritten(size_t len) + { + writeIndex_ += len; + LOG_TRACE << "Buffer written " << len << " bytes, total readable: " << readableBytes(); + } + + void append(const char* data, size_t len) + { + ensureWritableBytes(len); + std::copy(data, data + len, beginWrite()); + hasWritten(len); + } + + void append(const std::string& str) + { + append(str.data(), str.size()); + } + + void ensureWritableBytes(size_t len) + { + if (writableBytes() >= len) return; + + if (readIndex_ + writableBytes() >= (len + kBufferOffset)) { + // Move readable data to front + std::copy(&buffer_[readIndex_], &buffer_[writeIndex_], &buffer_[kBufferOffset]); + writeIndex_ = kBufferOffset + (writeIndex_ - readIndex_); + readIndex_ = kBufferOffset; + LOG_TRACE << "Buffer compacted, readable bytes: " << readableBytes(); + } else { + // Grow buffer + size_t newLen = std::max(buffer_.size() * 2, kBufferOffset + readableBytes() + len); + buffer_.resize(newLen); + LOG_DEBUG << "Buffer grown to " << newLen; + } + } + + // Network byte order append functions + void appendInt8(uint8_t x) + { + append(reinterpret_cast(&x), sizeof(x)); + } + + void appendInt16(uint16_t x) + { + uint16_t be = htons(x); + append(reinterpret_cast(&be), sizeof(be)); + } + + void appendInt32(uint32_t x) + { + uint32_t be = htonl(x); + append(reinterpret_cast(&be), sizeof(be)); + } + + void appendInt64(uint64_t x) + { + uint64_t be = hton64(x); + append(reinterpret_cast(&be), sizeof(be)); + } + + // Network byte order read functions + uint8_t readInt8() + { + uint8_t result = peekInt8(); + retrieve(sizeof(result)); + return result; + } + + uint16_t readInt16() + { + uint16_t result = peekInt16(); + retrieve(sizeof(result)); + return result; + } + + uint32_t readInt32() + { + uint32_t result = peekInt32(); + retrieve(sizeof(result)); + return result; + } + + uint64_t readInt64() + { + uint64_t result = peekInt64(); + retrieve(sizeof(result)); + return result; + } + + // Network byte order peek functions + uint8_t peekInt8() const + { + assert(readableBytes() >= sizeof(uint8_t)); + return *reinterpret_cast(peek()); + } + + uint16_t peekInt16() const + { + assert(readableBytes() >= sizeof(uint16_t)); + uint16_t be = *reinterpret_cast(peek()); + return ntohs(be); + } + + uint32_t peekInt32() const + { + assert(readableBytes() >= sizeof(uint32_t)); + uint32_t be = *reinterpret_cast(peek()); + return ntohl(be); + } + + uint64_t peekInt64() const + { + assert(readableBytes() >= sizeof(uint64_t)); + uint64_t be = *reinterpret_cast(peek()); + return ntoh64(be); + } + + // Prepend functions for efficient header insertion + void prepend(const char* data, size_t len) + { + assert(len <= prependableBytes()); + readIndex_ -= len; + std::copy(data, data + len, &buffer_[readIndex_]); + } + + void prependInt8(uint8_t x) + { + prepend(reinterpret_cast(&x), sizeof(x)); + } + + void prependInt16(uint16_t x) + { + uint16_t be = htons(x); + prepend(reinterpret_cast(&be), sizeof(be)); + } + + void prependInt32(uint32_t x) + { + uint32_t be = htonl(x); + prepend(reinterpret_cast(&be), sizeof(be)); + } + + void prependInt64(uint64_t x) + { + uint64_t be = hton64(x); + prepend(reinterpret_cast(&be), sizeof(be)); + } + + std::string read(size_t len) + { + if (len > readableBytes()) len = readableBytes(); + std::string result(peek(), len); + retrieve(len); + return result; + } + + std::string readAll() + { + return read(readableBytes()); + } + + ssize_t readFd(int fd, int* savedErrno = nullptr) + { + char extrabuf[65536]; + iovec vec[2]; + size_t writable = writableBytes(); + + vec[0].iov_base = beginWrite(); + vec[0].iov_len = writable; + vec[1].iov_base = extrabuf; + vec[1].iov_len = sizeof(extrabuf); + + const int iovcnt = (writable < sizeof(extrabuf)) ? 2 : 1; + ssize_t n = readv(fd, vec, iovcnt); + + if (n < 0) { + if (savedErrno) *savedErrno = errno; + LOG_ERROR << "readFd error: " << strerror(errno); + } else if (static_cast(n) <= writable) { + writeIndex_ += n; + } else { + writeIndex_ = buffer_.size(); + append(extrabuf, n - writable); + } + + LOG_TRACE << "readFd returned " << n << " bytes, buffer now has " << readableBytes(); + return n; + } + + void swap(Buffer& other) noexcept + { + buffer_.swap(other.buffer_); + std::swap(readIndex_, other.readIndex_); + std::swap(writeIndex_, other.writeIndex_); + std::swap(initialCap_, other.initialCap_); + } +}; + +} // namespace reactor diff --git a/lib/Core.hpp b/lib/Core.hpp new file mode 100644 index 0000000..6b84e67 --- /dev/null +++ b/lib/Core.hpp @@ -0,0 +1,591 @@ +#pragma once + +#include "Utilities.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace reactor { + +using TimePoint = std::chrono::steady_clock::time_point; +using Duration = std::chrono::milliseconds; +using TimerCallback = std::function; +using EventCallback = std::function; + +class EventLoop; + +class Channel : public NonCopyable +{ +private: + EventLoop* loop_; + int fd_; + int events_; + int revents_; + int index_; + bool tied_; + std::weak_ptr tie_; + EventCallback readCallback_; + EventCallback writeCallback_; + EventCallback closeCallback_; + EventCallback errorCallback_; + + void handleEventSafely() + { + LOG_TRACE << "Channel fd=" << fd_ << " handling events: " << revents_; + + if ((revents_ & POLLHUP) && !(revents_ & POLLIN)) { + LOG_DEBUG << "Channel fd=" << fd_ << " hangup"; + if (closeCallback_) closeCallback_(); + } + if (revents_ & POLLERR) { + LOG_WARN << "Channel fd=" << fd_ << " error event"; + if (errorCallback_) errorCallback_(); + } + if (revents_ & (POLLIN | POLLPRI | POLLRDHUP)) { + LOG_TRACE << "Channel fd=" << fd_ << " readable"; + if (readCallback_) readCallback_(); + } + if (revents_ & POLLOUT) { + LOG_TRACE << "Channel fd=" << fd_ << " writable"; + if (writeCallback_) writeCallback_(); + } + } + +public: + static constexpr int kNoneEvent = 0; + static constexpr int kReadEvent = POLLIN | POLLPRI; + static constexpr int kWriteEvent = POLLOUT; + + Channel(EventLoop* loop, int fd) + : loop_(loop), fd_(fd), events_(0), revents_(0), index_(-1), tied_(false) + { + LOG_TRACE << "Channel created for fd=" << fd; + } + + ~Channel() + { + LOG_TRACE << "Channel destroyed for fd=" << fd_; + } + + int fd() const { return fd_; } + int events() const { return events_; } + int revents() const { return revents_; } + int index() const { return index_; } + EventLoop* ownerLoop() const { return loop_; } + + void setRevents(int revents) { revents_ = revents; } + void setIndex(int index) { index_ = index; } + + void enableReading() + { + events_ |= kReadEvent; + update(); + LOG_TRACE << "Channel fd=" << fd_ << " enabled reading"; + } + + void disableReading() + { + events_ &= ~kReadEvent; + update(); + LOG_TRACE << "Channel fd=" << fd_ << " disabled reading"; + } + + void enableWriting() + { + events_ |= kWriteEvent; + update(); + LOG_TRACE << "Channel fd=" << fd_ << " enabled writing"; + } + + void disableWriting() + { + events_ &= ~kWriteEvent; + update(); + LOG_TRACE << "Channel fd=" << fd_ << " disabled writing"; + } + + void disableAll() + { + events_ = kNoneEvent; + update(); + LOG_TRACE << "Channel fd=" << fd_ << " disabled all events"; + } + + bool isNoneEvent() const { return events_ == kNoneEvent; } + bool isWriting() const { return events_ & kWriteEvent; } + bool isReading() const { return events_ & kReadEvent; } + + void tie(const std::shared_ptr& obj) + { + tie_ = obj; + tied_ = true; + } + + void setReadCallback(EventCallback cb) { readCallback_ = std::move(cb); } + void setWriteCallback(EventCallback cb) { writeCallback_ = std::move(cb); } + void setCloseCallback(EventCallback cb) { closeCallback_ = std::move(cb); } + void setErrorCallback(EventCallback cb) { errorCallback_ = std::move(cb); } + + void handleEvent() + { + if (tied_) { + std::shared_ptr guard = tie_.lock(); + if (guard) { + handleEventSafely(); + } else { + LOG_WARN << "Channel fd=" << fd_ << " tied object expired"; + } + } else { + handleEventSafely(); + } + } + + void remove(); + void update(); +}; + +class Timer : public NonCopyable +{ +private: + TimerCallback callback_; + TimePoint when_; + Duration interval_; + bool repeat_; + uint64_t id_; + static inline std::atomic s_numCreated_{0}; + +public: + Timer(TimerCallback cb, TimePoint when, Duration interval = Duration(0)) + : callback_(std::move(cb)), when_(when), interval_(interval), + repeat_(interval.count() > 0), id_(++s_numCreated_) + { + LOG_TRACE << "Timer created id=" << id_ << " repeat=" << repeat_; + } + + void run() const + { + LOG_TRACE << "Timer id=" << id_ << " executing"; + callback_(); + } + + TimePoint when() const { return when_; } + bool repeat() const { return repeat_; } + uint64_t id() const { return id_; } + + void restart(TimePoint now) + { + if (repeat_) { + when_ = now + interval_; + LOG_TRACE << "Timer id=" << id_ << " restarted"; + } + } + + bool operator<(const Timer& other) const { return when_ > other.when_; } +}; + +class TimerQueue : public NonCopyable +{ +private: + EventLoop* loop_; + int timerfd_; + std::unique_ptr timerChannel_; + std::priority_queue> timers_; + std::unordered_set activeTimers_; + + static int createTimerfd() + { + int fd = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC); + if (fd < 0) { + LOG_FATAL << "Failed to create timerfd: " << strerror(errno); + abort(); + } + LOG_DEBUG << "Created timerfd=" << fd; + return fd; + } + + void resetTimerfd(TimePoint expiration) + { + auto duration = expiration - std::chrono::steady_clock::now(); + auto ns = std::chrono::duration_cast(duration).count(); + if (ns < 100000) ns = 100000; + + itimerspec newValue{}; + newValue.it_value.tv_sec = ns / 1000000000; + newValue.it_value.tv_nsec = ns % 1000000000; + + if (timerfd_settime(timerfd_, 0, &newValue, nullptr) < 0) { + LOG_ERROR << "timerfd_settime failed: " << strerror(errno); + } + } + + void handleRead() + { + uint64_t data; + ssize_t n = read(timerfd_, &data, sizeof(data)); + if (n != sizeof(data)) { + LOG_ERROR << "TimerQueue read " << n << " bytes instead of 8"; + } + + auto now = std::chrono::steady_clock::now(); + std::vector> expired; + + while (!timers_.empty() && timers_.top()->when() <= now) { + expired.push_back(timers_.top()); + timers_.pop(); + } + + LOG_TRACE << "TimerQueue processed " << expired.size() << " expired timers"; + + for (auto& timer : expired) { + if (activeTimers_.count(timer->id())) { + timer->run(); + if (timer->repeat()) { + timer->restart(now); + timers_.push(timer); + } else { + activeTimers_.erase(timer->id()); + } + } + } + + if (!timers_.empty()) { + resetTimerfd(timers_.top()->when()); + } + } + +public: + TimerQueue(EventLoop* loop) + : loop_(loop), timerfd_(createTimerfd()), + timerChannel_(std::make_unique(loop, timerfd_)) + { + timerChannel_->setReadCallback([this]() { handleRead(); }); + timerChannel_->enableReading(); + LOG_DEBUG << "TimerQueue initialized"; + } + + ~TimerQueue() + { + timerChannel_->disableAll(); + timerChannel_->remove(); + close(timerfd_); + LOG_DEBUG << "TimerQueue destroyed"; + } + + uint64_t addTimer(TimerCallback cb, TimePoint when, Duration interval = Duration(0)) + { + auto timer = std::make_shared(std::move(cb), when, interval); + bool earliestChanged = timers_.empty() || when < timers_.top()->when(); + + timers_.push(timer); + activeTimers_.insert(timer->id()); + + if (earliestChanged) { + resetTimerfd(when); + } + + LOG_DEBUG << "Added timer id=" << timer->id() << " earliest_changed=" << earliestChanged; + return timer->id(); + } + + void cancel(uint64_t timerId) + { + auto erased = activeTimers_.erase(timerId); + LOG_DEBUG << "Cancelled timer id=" << timerId << " found=" << (erased > 0); + } +}; + +class EpollPoller : public NonCopyable +{ +private: + static constexpr int kNew = -1; + static constexpr int kAdded = 1; + static constexpr int kDeleted = 2; + static constexpr int kInitEventListSize = 16; + + int epollfd_; + std::vector events_; + std::unordered_map channels_; + + void update(int operation, Channel* channel) + { + epoll_event event{}; + event.events = channel->events(); + event.data.ptr = channel; + + if (epoll_ctl(epollfd_, operation, channel->fd(), &event) < 0) { + LOG_ERROR << "epoll_ctl op=" << operation << " fd=" << channel->fd() + << " failed: " << strerror(errno); + } else { + LOG_TRACE << "epoll_ctl op=" << operation << " fd=" << channel->fd() << " success"; + } + } + +public: + EpollPoller() : epollfd_(epoll_create1(EPOLL_CLOEXEC)), events_(kInitEventListSize) + { + if (epollfd_ < 0) { + LOG_FATAL << "epoll_create1 failed: " << strerror(errno); + abort(); + } + LOG_DEBUG << "EpollPoller created with epollfd=" << epollfd_; + } + + ~EpollPoller() + { + close(epollfd_); + LOG_DEBUG << "EpollPoller destroyed"; + } + + std::vector poll(int timeoutMs = -1) + { + int numEvents = epoll_wait(epollfd_, events_.data(), + static_cast(events_.size()), timeoutMs); + + std::vector activeChannels; + + if (numEvents > 0) { + LOG_TRACE << "EpollPoller got " << numEvents << " events"; + + for (int i = 0; i < numEvents; ++i) { + auto channel = static_cast(events_[i].data.ptr); + channel->setRevents(events_[i].events); + activeChannels.push_back(channel); + } + + if (static_cast(numEvents) == events_.size()) { + events_.resize(events_.size() * 2); + LOG_DEBUG << "EpollPoller events buffer grown to " << events_.size(); + } + } else if (numEvents < 0 && errno != EINTR) { + LOG_ERROR << "EpollPoller::poll failed: " << strerror(errno); + } + + return activeChannels; + } + + void updateChannel(Channel* channel) + { + int index = channel->index(); + int fd = channel->fd(); + + if (index == kNew || index == kDeleted) { + if (index == kNew) { + channels_[fd] = channel; + } + channel->setIndex(kAdded); + update(EPOLL_CTL_ADD, channel); + } else { + if (channel->isNoneEvent()) { + update(EPOLL_CTL_DEL, channel); + channel->setIndex(kDeleted); + } else { + update(EPOLL_CTL_MOD, channel); + } + } + } + + void removeChannel(Channel* channel) + { + int fd = channel->fd(); + int index = channel->index(); + + auto n = channels_.erase(fd); + assert(n == 1); + + if (index == kAdded) { + update(EPOLL_CTL_DEL, channel); + } + + channel->setIndex(kNew); + LOG_TRACE << "Channel fd=" << fd << " removed from poller"; + } +}; + +class EventLoop : public NonCopyable +{ +private: + std::unique_ptr poller_; + std::unique_ptr timerQueue_; + int wakeupFd_; + std::unique_ptr wakeupChannel_; + std::atomic looping_; + std::atomic quit_; + std::thread::id threadId_; + LockFreeQueue> pendingFunctors_; + bool callingPendingFunctors_; + + static int createEventfd() + { + int fd = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC); + if (fd < 0) { + LOG_FATAL << "Failed to create eventfd: " << strerror(errno); + abort(); + } + LOG_DEBUG << "Created eventfd=" << fd; + return fd; + } + + void wakeup() + { + uint64_t one = 1; + ssize_t n = write(wakeupFd_, &one, sizeof(one)); + if (n != sizeof(one)) { + LOG_ERROR << "EventLoop wakeup writes " << n << " bytes instead of 8"; + } + } + + void handleRead() + { + uint64_t one; + ssize_t n = read(wakeupFd_, &one, sizeof(one)); + if (n != sizeof(one)) { + LOG_ERROR << "EventLoop handleRead reads " << n << " bytes instead of 8"; + } + } + + void doPendingFunctors() + { + callingPendingFunctors_ = true; + + std::function functor; + int count = 0; + while (pendingFunctors_.dequeue(functor)) { + functor(); + ++count; + } + + if (count > 0) { + LOG_TRACE << "EventLoop executed " << count << " pending functors"; + } + + callingPendingFunctors_ = false; + } + +public: + EventLoop() + : poller_(std::make_unique()), + timerQueue_(std::make_unique(this)), + wakeupFd_(createEventfd()), + wakeupChannel_(std::make_unique(this, wakeupFd_)), + looping_(false), quit_(false), + threadId_(std::this_thread::get_id()), + callingPendingFunctors_(false) + { + wakeupChannel_->setReadCallback([this]() { handleRead(); }); + wakeupChannel_->enableReading(); + LOG_INFO << "EventLoop created in thread " << threadId_; + } + + ~EventLoop() + { + wakeupChannel_->disableAll(); + wakeupChannel_->remove(); + close(wakeupFd_); + LOG_INFO << "EventLoop destroyed"; + } + + void loop() + { + assert(!looping_); + assertInLoopThread(); + looping_ = true; + quit_ = false; + + LOG_INFO << "EventLoop started looping"; + + while (!quit_) { + auto activeChannels = poller_->poll(10000); + for (auto channel : activeChannels) { + channel->handleEvent(); + } + doPendingFunctors(); + } + + looping_ = false; + LOG_INFO << "EventLoop stopped looping"; + } + + void quit() + { + quit_ = true; + if (!isInLoopThread()) { + wakeup(); + } + LOG_DEBUG << "EventLoop quit requested"; + } + + void runInLoop(std::function cb) + { + if (isInLoopThread()) { + cb(); + } else { + queueInLoop(std::move(cb)); + } + } + + void queueInLoop(std::function cb) + { + pendingFunctors_.enqueue(std::move(cb)); + + if (!isInLoopThread() || callingPendingFunctors_) { + wakeup(); + } + } + + uint64_t runAt(TimePoint when, TimerCallback cb) + { + return timerQueue_->addTimer(std::move(cb), when); + } + + uint64_t runAfter(Duration delay, TimerCallback cb) + { + return runAt(std::chrono::steady_clock::now() + delay, std::move(cb)); + } + + uint64_t runEvery(Duration interval, TimerCallback cb) + { + auto when = std::chrono::steady_clock::now() + interval; + return timerQueue_->addTimer(std::move(cb), when, interval); + } + + void cancel(uint64_t timerId) { timerQueue_->cancel(timerId); } + + void updateChannel(Channel* channel) { poller_->updateChannel(channel); } + void removeChannel(Channel* channel) { poller_->removeChannel(channel); } + + bool isInLoopThread() const { return threadId_ == std::this_thread::get_id(); } + + void assertInLoopThread() const + { + if (!isInLoopThread()) { + LOG_FATAL << "EventLoop was created in thread " << threadId_ + << " but accessed from thread " << std::this_thread::get_id(); + abort(); + } + } +}; + +inline void Channel::update() +{ + loop_->updateChannel(this); +} + +inline void Channel::remove() +{ + assert(isNoneEvent()); + loop_->removeChannel(this); +} + +} // namespace reactor diff --git a/lib/EventLoopThread.hpp b/lib/EventLoopThread.hpp new file mode 100644 index 0000000..3a208d1 --- /dev/null +++ b/lib/EventLoopThread.hpp @@ -0,0 +1,110 @@ +#pragma once + +#include "Core.hpp" +#include "Utilities.hpp" +#include +#include +#include +#include +#include +#include + +namespace reactor { + +class EventLoopThread : public NonCopyable +{ +private: + std::thread thread_; + EventLoop* loop_; + std::mutex mutex_; + std::condition_variable cond_; + std::string name_; + + void threadFunc() + { + LOG_DEBUG << "EventLoopThread '" << name_ << "' starting"; + EventLoop loop; + { + std::lock_guard lock(mutex_); + loop_ = &loop; + cond_.notify_one(); + } + + loop.loop(); + + std::lock_guard lock(mutex_); + loop_ = nullptr; + LOG_DEBUG << "EventLoopThread '" << name_ << "' finished"; + } + +public: + explicit EventLoopThread(const std::string& name = "EventLoopThread") + : loop_(nullptr), name_(name) + { + thread_ = std::thread([this]() { threadFunc(); }); + std::unique_lock lock(mutex_); + cond_.wait(lock, [this]() { return loop_ != nullptr; }); + LOG_INFO << "EventLoopThread '" << name_ << "' initialized"; + } + + ~EventLoopThread() + { + if (loop_) { + loop_->quit(); + thread_.join(); + } + LOG_INFO << "EventLoopThread '" << name_ << "' destroyed"; + } + + EventLoop* getLoop() { return loop_; } + const std::string& name() const { return name_; } +}; + +class EventLoopThreadPool : public NonCopyable +{ +private: + std::vector> threads_; + std::vector loops_; + std::atomic next_; + std::string baseThreadName_; + +public: + explicit EventLoopThreadPool(size_t numThreads, const std::string& baseName = "EventLoopThread") + : next_(0), baseThreadName_(baseName) + { + LOG_INFO << "Creating EventLoopThreadPool with " << numThreads << " threads"; + + for (size_t i = 0; i < numThreads; ++i) { + std::string threadName = baseThreadName_ + "-" + std::to_string(i); + auto thread = std::make_unique(threadName); + loops_.push_back(thread->getLoop()); + threads_.push_back(std::move(thread)); + } + + LOG_INFO << "EventLoopThreadPool created with " << numThreads << " threads"; + } + + ~EventLoopThreadPool() + { + LOG_INFO << "EventLoopThreadPool destroying " << threads_.size() << " threads"; + } + + EventLoop* getNextLoop() + { + if (loops_.empty()) { + LOG_WARN << "EventLoopThreadPool has no loops available"; + return nullptr; + } + + size_t index = next_++ % loops_.size(); + LOG_TRACE << "EventLoopThreadPool returning loop " << index; + return loops_[index]; + } + + std::vector getAllLoops() const { return loops_; } + size_t size() const { return loops_.size(); } + + const std::string& getBaseName() const { return baseThreadName_; } +}; + +} // namespace reactor diff --git a/lib/InetAddress.hpp b/lib/InetAddress.hpp new file mode 100644 index 0000000..23216be --- /dev/null +++ b/lib/InetAddress.hpp @@ -0,0 +1,170 @@ +#pragma once + +#include "Utilities.hpp" +#include +#include +#include +#include + +namespace reactor { + +class InetAddress +{ +private: + union + { + sockaddr_in addr4_; + sockaddr_in6 addr6_; + }; + bool isIpV6_; + +public: + explicit InetAddress(uint16_t port = 0, bool ipv6 = false, bool loopback = false) + : isIpV6_(ipv6) + { + if (ipv6) { + memset(&addr6_, 0, sizeof(addr6_)); + addr6_.sin6_family = AF_INET6; + addr6_.sin6_addr = loopback ? in6addr_loopback : in6addr_any; + addr6_.sin6_port = htons(port); + } else { + memset(&addr4_, 0, sizeof(addr4_)); + addr4_.sin_family = AF_INET; + addr4_.sin_addr.s_addr = htonl(loopback ? INADDR_LOOPBACK : INADDR_ANY); + addr4_.sin_port = htons(port); + } + LOG_TRACE << "InetAddress created: " << toIpPort(); + } + + InetAddress(const std::string& ip, uint16_t port) + { + if (ip.find(':') != std::string::npos) { + isIpV6_ = true; + memset(&addr6_, 0, sizeof(addr6_)); + addr6_.sin6_family = AF_INET6; + addr6_.sin6_port = htons(port); + if (inet_pton(AF_INET6, ip.c_str(), &addr6_.sin6_addr) <= 0) { + LOG_ERROR << "Invalid IPv6 address: " << ip; + } + } else { + isIpV6_ = false; + memset(&addr4_, 0, sizeof(addr4_)); + addr4_.sin_family = AF_INET; + addr4_.sin_port = htons(port); + if (inet_pton(AF_INET, ip.c_str(), &addr4_.sin_addr) <= 0) { + LOG_ERROR << "Invalid IPv4 address: " << ip; + } + } + LOG_TRACE << "InetAddress created from ip:port: " << toIpPort(); + } + + explicit InetAddress(const sockaddr_in& addr) : addr4_(addr), isIpV6_(false) + { + LOG_TRACE << "InetAddress created from sockaddr_in: " << toIpPort(); + } + + explicit InetAddress(const sockaddr_in6& addr) : addr6_(addr), isIpV6_(true) + { + LOG_TRACE << "InetAddress created from sockaddr_in6: " << toIpPort(); + } + + const sockaddr* getSockAddr() const + { + if (isIpV6_) { + return reinterpret_cast(&addr6_); + } else { + return reinterpret_cast(&addr4_); + } + } + + socklen_t getSockLen() const { return isIpV6_ ? sizeof(addr6_) : sizeof(addr4_); } + bool isIpV6() const { return isIpV6_; } + uint16_t port() const { return ntohs(isIpV6_ ? addr6_.sin6_port : addr4_.sin_port); } + + std::string toIp() const + { + char buf[INET6_ADDRSTRLEN]; + if (isIpV6_) { + inet_ntop(AF_INET6, &addr6_.sin6_addr, buf, sizeof(buf)); + } else { + inet_ntop(AF_INET, &addr4_.sin_addr, buf, sizeof(buf)); + } + return std::string(buf); + } + + std::string toIpPort() const + { + return isIpV6_ ? "[" + toIp() + "]:" + std::to_string(port()) + : toIp() + ":" + std::to_string(port()); + } + + bool operator==(const InetAddress& other) const + { + if (isIpV6_ != other.isIpV6_) return false; + + if (isIpV6_) { + return memcmp(&addr6_, &other.addr6_, sizeof(addr6_)) == 0; + } else { + return memcmp(&addr4_, &other.addr4_, sizeof(addr4_)) == 0; + } + } + + bool operator!=(const InetAddress& other) const + { + return !(*this == other); + } + + bool operator<(const InetAddress& other) const + { + if (isIpV6_ != other.isIpV6_) { + return !isIpV6_; + } + + if (isIpV6_) { + return memcmp(&addr6_, &other.addr6_, sizeof(addr6_)) < 0; + } else { + return memcmp(&addr4_, &other.addr4_, sizeof(addr4_)) < 0; + } + } + + std::string familyToString() const + { + return isIpV6_ ? "IPv6" : "IPv4"; + } + + static bool resolve(const std::string& hostname, InetAddress& result) + { + // Simple resolution - in a real implementation you'd use getaddrinfo + if (hostname == "localhost") { + result = InetAddress(0, false, true); + return true; + } + + // Try to parse as IP address directly + InetAddress addr(hostname, 0); + if (addr.toIp() != "0.0.0.0" && addr.toIp() != "::") { + result = addr; + return true; + } + + LOG_WARN << "Could not resolve hostname: " << hostname; + return false; + } +}; + +} // namespace reactor + +namespace std { +template<> +struct hash +{ + size_t operator()(const reactor::InetAddress& addr) const + { + size_t seed = 0; + reactor::hashCombine(seed, addr.toIp()); + reactor::hashCombine(seed, addr.port()); + reactor::hashCombine(seed, addr.isIpV6()); + return seed; + } +}; +} // namespace std diff --git a/lib/Socket.hpp b/lib/Socket.hpp new file mode 100644 index 0000000..6105715 --- /dev/null +++ b/lib/Socket.hpp @@ -0,0 +1,318 @@ +#pragma once + +#include "InetAddress.hpp" +#include "Utilities.hpp" +#include +#include +#include +#include +#include + +namespace reactor { + +class Socket : public NonCopyable +{ +private: + int fd_; + + void setNonBlockAndCloseOnExec() + { + int flags = fcntl(fd_, F_GETFL, 0); + flags |= O_NONBLOCK; + fcntl(fd_, F_SETFL, flags); + + flags = fcntl(fd_, F_GETFD, 0); + flags |= FD_CLOEXEC; + fcntl(fd_, F_SETFD, flags); + } + +public: + explicit Socket(int fd) : fd_(fd) + { + LOG_TRACE << "Socket created with fd=" << fd_; + } + + ~Socket() + { + if (fd_ >= 0) { + close(fd_); + LOG_TRACE << "Socket fd=" << fd_ << " closed"; + } + } + + Socket(Socket&& other) noexcept : fd_(other.fd_) + { + other.fd_ = -1; + LOG_TRACE << "Socket moved fd=" << fd_; + } + + Socket& operator=(Socket&& other) noexcept + { + if (this != &other) { + if (fd_ >= 0) { + close(fd_); + LOG_TRACE << "Socket fd=" << fd_ << " closed in move assignment"; + } + fd_ = other.fd_; + other.fd_ = -1; + } + return *this; + } + + static Socket createTcp(bool ipv6 = false) + { + int fd = socket(ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0); + if (fd < 0) { + LOG_FATAL << "Failed to create TCP socket: " << strerror(errno); + abort(); + } + LOG_DEBUG << "Created TCP socket fd=" << fd << " ipv6=" << ipv6; + return Socket(fd); + } + + static Socket createUdp(bool ipv6 = false) + { + int fd = socket(ipv6 ? AF_INET6 : AF_INET, SOCK_DGRAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0); + if (fd < 0) { + LOG_FATAL << "Failed to create UDP socket: " << strerror(errno); + abort(); + } + LOG_DEBUG << "Created UDP socket fd=" << fd << " ipv6=" << ipv6; + return Socket(fd); + } + + void bind(const InetAddress& addr) + { + int ret = ::bind(fd_, addr.getSockAddr(), addr.getSockLen()); + if (ret < 0) { + LOG_FATAL << "Socket bind to " << addr.toIpPort() << " failed: " << strerror(errno); + abort(); + } + LOG_INFO << "Socket fd=" << fd_ << " bound to " << addr.toIpPort(); + } + + void listen(int backlog = SOMAXCONN) + { + int ret = ::listen(fd_, backlog); + if (ret < 0) { + LOG_FATAL << "Socket listen failed: " << strerror(errno); + abort(); + } + LOG_INFO << "Socket fd=" << fd_ << " listening with backlog=" << backlog; + } + + int accept(InetAddress& peerAddr) + { + sockaddr_in6 addr; + socklen_t len = sizeof(addr); + int connfd = accept4(fd_, reinterpret_cast(&addr), &len, SOCK_NONBLOCK | SOCK_CLOEXEC); + + if (connfd >= 0) { + if (addr.sin6_family == AF_INET) { + peerAddr = InetAddress(*reinterpret_cast(&addr)); + } else { + peerAddr = InetAddress(addr); + } + LOG_DEBUG << "Socket fd=" << fd_ << " accepted connection fd=" << connfd + << " from " << peerAddr.toIpPort(); + } else if (errno != EAGAIN && errno != EWOULDBLOCK) { + LOG_ERROR << "Socket accept failed: " << strerror(errno); + } + + return connfd; + } + + int connect(const InetAddress& addr) + { + int ret = ::connect(fd_, addr.getSockAddr(), addr.getSockLen()); + if (ret < 0 && errno != EINPROGRESS) { + LOG_ERROR << "Socket connect to " << addr.toIpPort() << " failed: " << strerror(errno); + } else { + LOG_DEBUG << "Socket fd=" << fd_ << " connecting to " << addr.toIpPort(); + } + return ret; + } + + void setReuseAddr(bool on = true) + { + int optval = on ? 1 : 0; + if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)) < 0) { + LOG_ERROR << "setsockopt SO_REUSEADDR failed: " << strerror(errno); + } else { + LOG_TRACE << "Socket fd=" << fd_ << " SO_REUSEADDR=" << on; + } + } + + void setReusePort(bool on = true) + { + int optval = on ? 1 : 0; + if (setsockopt(fd_, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval)) < 0) { + LOG_ERROR << "setsockopt SO_REUSEPORT failed: " << strerror(errno); + } else { + LOG_TRACE << "Socket fd=" << fd_ << " SO_REUSEPORT=" << on; + } + } + + void setTcpNoDelay(bool on = true) + { + int optval = on ? 1 : 0; + if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &optval, sizeof(optval)) < 0) { + LOG_ERROR << "setsockopt TCP_NODELAY failed: " << strerror(errno); + } else { + LOG_TRACE << "Socket fd=" << fd_ << " TCP_NODELAY=" << on; + } + } + + void setKeepAlive(bool on = true) + { + int optval = on ? 1 : 0; + if (setsockopt(fd_, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) < 0) { + LOG_ERROR << "setsockopt SO_KEEPALIVE failed: " << strerror(errno); + } else { + LOG_TRACE << "Socket fd=" << fd_ << " SO_KEEPALIVE=" << on; + } + } + + void setTcpKeepAlive(int idle, int interval, int count) + { + if (setsockopt(fd_, IPPROTO_TCP, TCP_KEEPIDLE, &idle, sizeof(idle)) < 0 || + setsockopt(fd_, IPPROTO_TCP, TCP_KEEPINTVL, &interval, sizeof(interval)) < 0 || + setsockopt(fd_, IPPROTO_TCP, TCP_KEEPCNT, &count, sizeof(count)) < 0) { + LOG_ERROR << "setsockopt TCP_KEEP* failed: " << strerror(errno); + } else { + LOG_TRACE << "Socket fd=" << fd_ << " TCP keepalive: idle=" << idle + << " interval=" << interval << " count=" << count; + } + } + + void setRecvBuffer(int size) + { + if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size)) < 0) { + LOG_ERROR << "setsockopt SO_RCVBUF failed: " << strerror(errno); + } else { + LOG_TRACE << "Socket fd=" << fd_ << " SO_RCVBUF=" << size; + } + } + + void setSendBuffer(int size) + { + if (setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size)) < 0) { + LOG_ERROR << "setsockopt SO_SNDBUF failed: " << strerror(errno); + } else { + LOG_TRACE << "Socket fd=" << fd_ << " SO_SNDBUF=" << size; + } + } + + ssize_t read(void* buf, size_t len) + { + ssize_t n = ::read(fd_, buf, len); + if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK) { + LOG_ERROR << "Socket read failed: " << strerror(errno); + } + return n; + } + + ssize_t write(const void* buf, size_t len) + { + ssize_t n = ::write(fd_, buf, len); + if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK) { + LOG_ERROR << "Socket write failed: " << strerror(errno); + } + return n; + } + + ssize_t sendTo(const void* buf, size_t len, const InetAddress& addr) + { + ssize_t n = sendto(fd_, buf, len, 0, addr.getSockAddr(), addr.getSockLen()); + if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK) { + LOG_ERROR << "Socket sendto failed: " << strerror(errno); + } + return n; + } + + ssize_t recvFrom(void* buf, size_t len, InetAddress& addr) + { + sockaddr_in6 sockaddr; + socklen_t addrlen = sizeof(sockaddr); + ssize_t n = recvfrom(fd_, buf, len, 0, reinterpret_cast(&sockaddr), &addrlen); + + if (n >= 0) { + if (sockaddr.sin6_family == AF_INET) { + addr = InetAddress(*reinterpret_cast(&sockaddr)); + } else { + addr = InetAddress(sockaddr); + } + } else if (errno != EAGAIN && errno != EWOULDBLOCK) { + LOG_ERROR << "Socket recvfrom failed: " << strerror(errno); + } + + return n; + } + + void shutdownWrite() + { + if (shutdown(fd_, SHUT_WR) < 0) { + LOG_ERROR << "Socket shutdown write failed: " << strerror(errno); + } else { + LOG_DEBUG << "Socket fd=" << fd_ << " shutdown write"; + } + } + + void shutdownRead() + { + if (shutdown(fd_, SHUT_RD) < 0) { + LOG_ERROR << "Socket shutdown read failed: " << strerror(errno); + } else { + LOG_DEBUG << "Socket fd=" << fd_ << " shutdown read"; + } + } + + int getSocketError() + { + int optval; + socklen_t optlen = sizeof(optval); + if (getsockopt(fd_, SOL_SOCKET, SO_ERROR, &optval, &optlen) < 0) { + return errno; + } + return optval; + } + + int fd() const { return fd_; } + + static InetAddress getLocalAddr(int sockfd) + { + sockaddr_in6 addr; + socklen_t addrlen = sizeof(addr); + if (getsockname(sockfd, reinterpret_cast(&addr), &addrlen) < 0) { + LOG_ERROR << "getsockname failed: " << strerror(errno); + return InetAddress(); + } + + if (addr.sin6_family == AF_INET) { + return InetAddress(*reinterpret_cast(&addr)); + } + return InetAddress(addr); + } + + static InetAddress getPeerAddr(int sockfd) + { + sockaddr_in6 addr; + socklen_t addrlen = sizeof(addr); + if (getpeername(sockfd, reinterpret_cast(&addr), &addrlen) < 0) { + LOG_ERROR << "getpeername failed: " << strerror(errno); + return InetAddress(); + } + + if (addr.sin6_family == AF_INET) { + return InetAddress(*reinterpret_cast(&addr)); + } + return InetAddress(addr); + } + + bool isSelfConnected() + { + return getLocalAddr(fd_) == getPeerAddr(fd_); + } +}; + +} // namespace reactor diff --git a/lib/TcpConnection.hpp b/lib/TcpConnection.hpp new file mode 100644 index 0000000..6ed535a --- /dev/null +++ b/lib/TcpConnection.hpp @@ -0,0 +1,364 @@ +#pragma once + +#include "Core.hpp" +#include "Socket.hpp" +#include "Buffer.hpp" +#include "Utilities.hpp" +#include +#include +#include +#include + +namespace reactor { + +class TcpConnection; +using TcpConnectionPtr = std::shared_ptr; +using MessageCallback = std::function; +using ConnectionCallback = std::function; +using WriteCompleteCallback = std::function; +using HighWaterMarkCallback = std::function; + +class TcpConnection : public NonCopyable, public std::enable_shared_from_this +{ +public: + enum StateE { kDisconnected, kConnecting, kConnected, kDisconnecting }; + +private: + EventLoop* loop_; + Socket socket_; + std::unique_ptr channel_; + InetAddress localAddr_; + InetAddress peerAddr_; + std::string name_; + StateE state_; + Buffer inputBuffer_; + Buffer outputBuffer_; + MessageCallback messageCallback_; + ConnectionCallback connectionCallback_; + ConnectionCallback closeCallback_; + WriteCompleteCallback writeCompleteCallback_; + HighWaterMarkCallback highWaterMarkCallback_; + size_t highWaterMark_; + + void setState(StateE s) { state_ = s; } + + void handleRead() + { + loop_->assertInLoopThread(); + int savedErrno = 0; + ssize_t n = inputBuffer_.readFd(socket_.fd(), &savedErrno); + + if (n > 0) { + LOG_TRACE << "TcpConnection " << name_ << " read " << n << " bytes"; + if (messageCallback_) { + messageCallback_(shared_from_this(), inputBuffer_); + } + } else if (n == 0) { + LOG_DEBUG << "TcpConnection " << name_ << " peer closed"; + handleClose(); + } else { + errno = savedErrno; + LOG_ERROR << "TcpConnection " << name_ << " read error: " << strerror(savedErrno); + handleError(); + } + } + + void handleWrite() + { + loop_->assertInLoopThread(); + if (channel_->isWriting()) { + ssize_t n = socket_.write(outputBuffer_.peek(), outputBuffer_.readableBytes()); + if (n > 0) { + outputBuffer_.retrieve(n); + LOG_TRACE << "TcpConnection " << name_ << " wrote " << n << " bytes, " + << outputBuffer_.readableBytes() << " bytes left"; + + if (outputBuffer_.readableBytes() == 0) { + channel_->disableWriting(); + if (writeCompleteCallback_) { + loop_->queueInLoop([self = shared_from_this()]() { + self->writeCompleteCallback_(self); + }); + } + if (state_ == kDisconnecting) { + shutdownInLoop(); + } + } + } else { + LOG_ERROR << "TcpConnection " << name_ << " write error: " << strerror(errno); + } + } else { + LOG_TRACE << "TcpConnection " << name_ << " not writing, ignore"; + } + } + + void handleClose() + { + loop_->assertInLoopThread(); + LOG_DEBUG << "TcpConnection " << name_ << " state=" << stateToString(); + assert(state_ == kConnected || state_ == kDisconnecting); + setState(kDisconnected); + channel_->disableAll(); + + auto guardThis = shared_from_this(); + if (connectionCallback_) { + connectionCallback_(guardThis); + } + if (closeCallback_) { + closeCallback_(guardThis); + } + } + + void handleError() + { + int err = socket_.getSocketError(); + LOG_ERROR << "TcpConnection " << name_ << " SO_ERROR=" << err << " " << strerror(err); + handleClose(); + } + + void sendInLoop(const char* data, size_t len) + { + loop_->assertInLoopThread(); + ssize_t nwrote = 0; + size_t remaining = len; + bool faultError = false; + + if (state_ == kDisconnected) { + LOG_WARN << "TcpConnection " << name_ << " disconnected, give up writing"; + return; + } + + if (!channel_->isWriting() && outputBuffer_.readableBytes() == 0) { + nwrote = socket_.write(data, len); + if (nwrote >= 0) { + remaining = len - nwrote; + if (remaining == 0 && writeCompleteCallback_) { + loop_->queueInLoop([self = shared_from_this()]() { + self->writeCompleteCallback_(self); + }); + } + } else { + nwrote = 0; + if (errno != EWOULDBLOCK) { + LOG_ERROR << "TcpConnection " << name_ << " send error: " << strerror(errno); + if (errno == EPIPE || errno == ECONNRESET) { + faultError = true; + } + } + } + } + + assert(remaining <= len); + if (!faultError && remaining > 0) { + size_t oldLen = outputBuffer_.readableBytes(); + if (oldLen + remaining >= highWaterMark_ && + oldLen < highWaterMark_ && + highWaterMarkCallback_) { + loop_->queueInLoop([self = shared_from_this(), mark = oldLen + remaining]() { + self->highWaterMarkCallback_(self, mark); + }); + } + outputBuffer_.append(data + nwrote, remaining); + if (!channel_->isWriting()) { + channel_->enableWriting(); + } + } + } + + void shutdownInLoop() + { + loop_->assertInLoopThread(); + if (!channel_->isWriting()) { + socket_.shutdownWrite(); + } + } + + void forceCloseInLoop() + { + loop_->assertInLoopThread(); + if (state_ == kConnected || state_ == kDisconnecting) { + handleClose(); + } + } + + std::string stateToString() const + { + switch (state_) { + case kDisconnected: return "kDisconnected"; + case kConnecting: return "kConnecting"; + case kConnected: return "kConnected"; + case kDisconnecting: return "kDisconnecting"; + default: return "unknown state"; + } + } + +public: + TcpConnection(EventLoop* loop, const std::string& name, int sockfd, + const InetAddress& localAddr, const InetAddress& peerAddr) + : loop_(loop), socket_(sockfd), channel_(std::make_unique(loop, sockfd)), + localAddr_(localAddr), peerAddr_(peerAddr), name_(name), state_(kConnecting), + highWaterMark_(64*1024*1024) + { + channel_->setReadCallback([this]() { handleRead(); }); + channel_->setWriteCallback([this]() { handleWrite(); }); + channel_->setCloseCallback([this]() { handleClose(); }); + channel_->setErrorCallback([this]() { handleError(); }); + + socket_.setKeepAlive(true); + socket_.setTcpNoDelay(true); + + LOG_INFO << "TcpConnection " << name_ << " created from " + << localAddr_.toIpPort() << " to " << peerAddr_.toIpPort() << " fd=" << sockfd; + } + + ~TcpConnection() + { + LOG_INFO << "TcpConnection " << name_ << " destroyed state=" << stateToString(); + assert(state_ == kDisconnected); + } + + void connectEstablished() + { + loop_->assertInLoopThread(); + assert(state_ == kConnecting); + setState(kConnected); + channel_->tie(shared_from_this()); + channel_->enableReading(); + + if (connectionCallback_) { + connectionCallback_(shared_from_this()); + } + + LOG_INFO << "TcpConnection " << name_ << " established"; + } + + void connectDestroyed() + { + loop_->assertInLoopThread(); + if (state_ == kConnected) { + setState(kDisconnected); + channel_->disableAll(); + if (connectionCallback_) { + connectionCallback_(shared_from_this()); + } + } + channel_->remove(); + LOG_INFO << "TcpConnection " << name_ << " destroyed"; + } + + const std::string& name() const { return name_; } + const InetAddress& localAddr() const { return localAddr_; } + const InetAddress& peerAddr() const { return peerAddr_; } + bool connected() const { return state_ == kConnected; } + bool disconnected() const { return state_ == kDisconnected; } + EventLoop* getLoop() const { return loop_; } + + void send(const std::string& message) + { + if (state_ == kConnected) { + if (loop_->isInLoopThread()) { + sendInLoop(message.data(), message.size()); + } else { + loop_->runInLoop([self = shared_from_this(), message]() { + self->sendInLoop(message.data(), message.size()); + }); + } + } + } + + void send(const char* data, size_t len) + { + if (state_ == kConnected) { + if (loop_->isInLoopThread()) { + sendInLoop(data, len); + } else { + std::string message(data, len); + loop_->runInLoop([self = shared_from_this(), message]() { + self->sendInLoop(message.data(), message.size()); + }); + } + } + } + + void send(Buffer& buffer) + { + if (state_ == kConnected) { + if (loop_->isInLoopThread()) { + sendInLoop(buffer.peek(), buffer.readableBytes()); + buffer.retrieveAll(); + } else { + std::string message = buffer.readAll(); + loop_->runInLoop([self = shared_from_this(), message]() { + self->sendInLoop(message.data(), message.size()); + }); + } + } + } + + void shutdown() + { + if (state_ == kConnected) { + setState(kDisconnecting); + loop_->runInLoop([self = shared_from_this()]() { + self->shutdownInLoop(); + }); + } + } + + void forceClose() + { + if (state_ == kConnected || state_ == kDisconnecting) { + setState(kDisconnecting); + loop_->queueInLoop([self = shared_from_this()]() { + self->forceCloseInLoop(); + }); + } + } + + void forceCloseWithDelay(double seconds) + { + if (state_ == kConnected || state_ == kDisconnecting) { + setState(kDisconnecting); + loop_->runAfter(Duration(static_cast(seconds * 1000)), + [self = shared_from_this()]() { + self->forceClose(); + }); + } + } + + void setTcpNoDelay(bool on) { socket_.setTcpNoDelay(on); } + void setTcpKeepAlive(bool on) { socket_.setKeepAlive(on); } + + void setMessageCallback(MessageCallback cb) { messageCallback_ = std::move(cb); } + void setConnectionCallback(ConnectionCallback cb) { connectionCallback_ = std::move(cb); } + void setCloseCallback(ConnectionCallback cb) { closeCallback_ = std::move(cb); } + void setWriteCompleteCallback(WriteCompleteCallback cb) { writeCompleteCallback_ = std::move(cb); } + void setHighWaterMarkCallback(HighWaterMarkCallback cb, size_t highWaterMark) + { + highWaterMarkCallback_ = std::move(cb); + highWaterMark_ = highWaterMark; + } + + Buffer* inputBuffer() { return &inputBuffer_; } + Buffer* outputBuffer() { return &outputBuffer_; } + + void startRead() + { + loop_->runInLoop([self = shared_from_this()]() { + if (!self->channel_->isReading()) { + self->channel_->enableReading(); + } + }); + } + + void stopRead() + { + loop_->runInLoop([self = shared_from_this()]() { + if (self->channel_->isReading()) { + self->channel_->disableReading(); + } + }); + } +}; + +} // namespace reactor diff --git a/lib/TcpServer.hpp b/lib/TcpServer.hpp new file mode 100644 index 0000000..75a2e69 --- /dev/null +++ b/lib/TcpServer.hpp @@ -0,0 +1,236 @@ +#pragma once + +#include "Core.hpp" +#include "Socket.hpp" +#include "TcpConnection.hpp" +#include "EventLoopThread.hpp" +#include "Utilities.hpp" +#include +#include +#include +#include + +namespace reactor { + +using NewConnectionCallback = std::function; + +class Acceptor : public NonCopyable +{ +private: + EventLoop* loop_; + Socket acceptSocket_; + std::unique_ptr acceptChannel_; + NewConnectionCallback newConnectionCallback_; + bool listening_; + int idleFd_; + + void handleRead() + { + loop_->assertInLoopThread(); + InetAddress peerAddr; + int connfd = acceptSocket_.accept(peerAddr); + + if (connfd >= 0) { + if (newConnectionCallback_) { + newConnectionCallback_(connfd, peerAddr); + } else { + close(connfd); + LOG_WARN << "Acceptor no callback for new connection, closing fd=" << connfd; + } + } else { + LOG_ERROR << "Acceptor accept failed: " << strerror(errno); + if (errno == EMFILE) { + close(idleFd_); + idleFd_ = ::accept(acceptSocket_.fd(), nullptr, nullptr); + close(idleFd_); + idleFd_ = ::open("/dev/null", O_RDONLY | O_CLOEXEC); + } + } + } + +public: + Acceptor(EventLoop* loop, const InetAddress& listenAddr, bool reusePort = true) + : loop_(loop), acceptSocket_(Socket::createTcp(listenAddr.isIpV6())), + acceptChannel_(std::make_unique(loop, acceptSocket_.fd())), + listening_(false), idleFd_(::open("/dev/null", O_RDONLY | O_CLOEXEC)) + { + acceptSocket_.setReuseAddr(true); + if (reusePort) { + acceptSocket_.setReusePort(true); + } + acceptSocket_.bind(listenAddr); + acceptChannel_->setReadCallback([this]() { handleRead(); }); + LOG_INFO << "Acceptor created for " << listenAddr.toIpPort(); + } + + ~Acceptor() + { + acceptChannel_->disableAll(); + acceptChannel_->remove(); + close(idleFd_); + LOG_INFO << "Acceptor destroyed"; + } + + void listen() + { + loop_->assertInLoopThread(); + listening_ = true; + acceptSocket_.listen(); + acceptChannel_->enableReading(); + LOG_INFO << "Acceptor listening"; + } + + bool listening() const { return listening_; } + + void setNewConnectionCallback(NewConnectionCallback cb) + { + newConnectionCallback_ = std::move(cb); + } +}; + +class TcpServer : public NonCopyable +{ +private: + EventLoop* loop_; + std::string name_; + std::unique_ptr acceptor_; + std::unique_ptr threadPool_; + MessageCallback messageCallback_; + ConnectionCallback connectionCallback_; + WriteCompleteCallback writeCompleteCallback_; + + std::unordered_map connections_; + std::atomic nextConnId_; + bool started_; + + void newConnection(int sockfd, const InetAddress& peerAddr) + { + loop_->assertInLoopThread(); + EventLoop* ioLoop = threadPool_->getNextLoop(); + if (!ioLoop) ioLoop = loop_; + + std::string connName = name_ + "-" + peerAddr.toIpPort() + "#" + std::to_string(nextConnId_++); + InetAddress localAddr = Socket::getLocalAddr(sockfd); + + LOG_INFO << "TcpServer new connection " << connName << " from " << peerAddr.toIpPort(); + + auto conn = std::make_shared(ioLoop, connName, sockfd, localAddr, peerAddr); + connections_[connName] = conn; + + conn->setMessageCallback(messageCallback_); + conn->setConnectionCallback(connectionCallback_); + conn->setWriteCompleteCallback(writeCompleteCallback_); + conn->setCloseCallback([this](const TcpConnectionPtr& conn) { + removeConnection(conn); + }); + + ioLoop->runInLoop([conn]() { conn->connectEstablished(); }); + } + + void removeConnection(const TcpConnectionPtr& conn) + { + loop_->runInLoop([this, conn]() { + LOG_INFO << "TcpServer removing connection " << conn->name(); + size_t n = connections_.erase(conn->name()); + assert(n == 1); + + EventLoop* ioLoop = conn->getLoop(); + ioLoop->queueInLoop([conn]() { conn->connectDestroyed(); }); + }); + } + + void removeConnectionInLoop(const TcpConnectionPtr& conn) + { + loop_->assertInLoopThread(); + LOG_INFO << "TcpServer removing connection " << conn->name(); + size_t n = connections_.erase(conn->name()); + assert(n == 1); + + EventLoop* ioLoop = conn->getLoop(); + ioLoop->queueInLoop([conn]() { conn->connectDestroyed(); }); + } + +public: + TcpServer(EventLoop* loop, const InetAddress& listenAddr, const std::string& name, + bool reusePort = true) + : loop_(loop), name_(name), + acceptor_(std::make_unique(loop, listenAddr, reusePort)), + threadPool_(std::make_unique(0, name + "-EventLoop")), + nextConnId_(1), started_(false) + { + acceptor_->setNewConnectionCallback([this](int sockfd, const InetAddress& addr) { + newConnection(sockfd, addr); + }); + LOG_INFO << "TcpServer " << name_ << " created for " << listenAddr.toIpPort(); + } + + ~TcpServer() + { + loop_->assertInLoopThread(); + LOG_INFO << "TcpServer " << name_ << " destructing with " << connections_.size() << " connections"; + + for (auto& item : connections_) { + auto conn = item.second; + auto ioLoop = conn->getLoop(); + ioLoop->runInLoop([conn]() { conn->forceClose(); }); + } + } + + void setThreadNum(int numThreads) + { + assert(0 <= numThreads); + threadPool_ = std::make_unique(numThreads, name_ + "-EventLoop"); + LOG_INFO << "TcpServer " << name_ << " set thread pool size to " << numThreads; + } + + void start() + { + if (!started_) { + started_ = true; + if (!acceptor_->listening()) { + loop_->runInLoop([this]() { acceptor_->listen(); }); + } + LOG_INFO << "TcpServer " << name_ << " started with " << threadPool_->size() << " threads"; + } + } + + void setMessageCallback(MessageCallback cb) { messageCallback_ = std::move(cb); } + void setConnectionCallback(ConnectionCallback cb) { connectionCallback_ = std::move(cb); } + void setWriteCompleteCallback(WriteCompleteCallback cb) { writeCompleteCallback_ = std::move(cb); } + + const std::string& name() const { return name_; } + const char* ipPort() const { return acceptor_ ? "listening" : "not-listening"; } + EventLoop* getLoop() const { return loop_; } + + size_t numConnections() const + { + return connections_.size(); + } + + std::vector getConnections() const + { + std::vector result; + result.reserve(connections_.size()); + for (const auto& item : connections_) { + result.push_back(item.second); + } + return result; + } + + TcpConnectionPtr getConnection(const std::string& name) const + { + auto it = connections_.find(name); + return it != connections_.end() ? it->second : TcpConnectionPtr(); + } + + void forceCloseAllConnections() + { + for (auto& item : connections_) { + auto conn = item.second; + auto ioLoop = conn->getLoop(); + ioLoop->runInLoop([conn]() { conn->forceClose(); }); + } + } +}; + +} // namespace reactor diff --git a/lib/Utilities.hpp b/lib/Utilities.hpp new file mode 100644 index 0000000..2689c37 --- /dev/null +++ b/lib/Utilities.hpp @@ -0,0 +1,346 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace reactor { + +// NonCopyable base class +class NonCopyable +{ +protected: + NonCopyable() = default; + ~NonCopyable() = default; + NonCopyable(const NonCopyable&) = delete; + NonCopyable& operator=(const NonCopyable&) = delete; + NonCopyable(NonCopyable&&) noexcept = default; + NonCopyable& operator=(NonCopyable&&) noexcept = default; +}; + +// Network byte order utilities +inline uint64_t hton64(uint64_t n) +{ + static const int one = 1; + static const char sig = *(char*)&one; + if (sig == 0) return n; + char* ptr = reinterpret_cast(&n); + std::reverse(ptr, ptr + sizeof(uint64_t)); + return n; +} + +inline uint64_t ntoh64(uint64_t n) { return hton64(n); } + +// Lock-free MPSC queue +template +class LockFreeQueue : public NonCopyable +{ +private: + struct Node + { + Node() = default; + Node(const T& data) : data_(std::make_unique(data)) {} + Node(T&& data) : data_(std::make_unique(std::move(data))) {} + std::unique_ptr data_; + std::atomic next_{nullptr}; + }; + + std::atomic head_; + std::atomic tail_; + +public: + LockFreeQueue() : head_(new Node), tail_(head_.load()) {} + + ~LockFreeQueue() + { + T output; + while (dequeue(output)) {} + delete head_.load(); + } + + void enqueue(T&& input) + { + Node* node = new Node(std::move(input)); + Node* prevhead = head_.exchange(node, std::memory_order_acq_rel); + prevhead->next_.store(node, std::memory_order_release); + } + + void enqueue(const T& input) + { + Node* node = new Node(input); + Node* prevhead = head_.exchange(node, std::memory_order_acq_rel); + prevhead->next_.store(node, std::memory_order_release); + } + + bool dequeue(T& output) + { + Node* tail = tail_.load(std::memory_order_relaxed); + Node* next = tail->next_.load(std::memory_order_acquire); + + if (next == nullptr) return false; + + output = std::move(*next->data_); + tail_.store(next, std::memory_order_release); + delete tail; + return true; + } + + bool empty() + { + Node* tail = tail_.load(std::memory_order_relaxed); + Node* next = tail->next_.load(std::memory_order_acquire); + return next == nullptr; + } +}; + +// Object Pool +template +class ObjectPool : public NonCopyable, public std::enable_shared_from_this> +{ +private: + std::vector objects_; + std::mutex mutex_; + +public: + std::shared_ptr getObject() + { + static_assert(!std::is_pointer_v, "ObjectPool type cannot be pointer"); + + T* p = nullptr; + { + std::lock_guard lock(mutex_); + if (!objects_.empty()) { + p = objects_.back(); + objects_.pop_back(); + } + } + + if (!p) p = new T; + + std::weak_ptr> weakPtr = this->shared_from_this(); + return std::shared_ptr(p, [weakPtr](T* ptr) { + auto self = weakPtr.lock(); + if (self) { + std::lock_guard lock(self->mutex_); + self->objects_.push_back(ptr); + } else { + delete ptr; + } + }); + } +}; + +// Simple Logger +enum class LogLevel { TRACE, DEBUG, INFO, WARN, ERROR, FATAL }; + +class Logger : public NonCopyable +{ +private: + static inline LogLevel level_ = LogLevel::INFO; + static inline std::unique_ptr file_; + static inline std::mutex mutex_; + + std::ostringstream stream_; + LogLevel msgLevel_; + + static const char* levelString(LogLevel level) + { + switch (level) { + case LogLevel::TRACE: return "TRACE"; + case LogLevel::DEBUG: return "DEBUG"; + case LogLevel::INFO: return "INFO "; + case LogLevel::WARN: return "WARN "; + case LogLevel::ERROR: return "ERROR"; + case LogLevel::FATAL: return "FATAL"; + } + return "UNKNOWN"; + } + + static std::string timestamp() + { + auto now = std::chrono::system_clock::now(); + auto time_t = std::chrono::system_clock::to_time_t(now); + auto ms = std::chrono::duration_cast( + now.time_since_epoch()) % 1000; + + char buf[64]; + std::strftime(buf, sizeof(buf), "%Y%m%d %H:%M:%S", std::localtime(&time_t)); + return std::string(buf) + "." + std::to_string(ms.count()); + } + +public: + Logger(LogLevel level) : msgLevel_(level) + { + if (level >= level_) { + stream_ << timestamp() << " " << levelString(level) << " "; + } + } + + ~Logger() + { + if (msgLevel_ >= level_) { + stream_ << "\n"; + std::lock_guard lock(mutex_); + if (file_ && file_->is_open()) { + *file_ << stream_.str(); + file_->flush(); + } else { + std::cout << stream_.str(); + } + } + } + + template + Logger& operator<<(const T& value) + { + if (msgLevel_ >= level_) { + stream_ << value; + } + return *this; + } + + static void setLevel(LogLevel level) { level_ = level; } + static void setLogFile(const std::string& filename) + { + std::lock_guard lock(mutex_); + file_ = std::make_unique(filename, std::ios::app); + } +}; + +// Task Queue interface +class TaskQueue : public NonCopyable +{ +public: + virtual ~TaskQueue() = default; + virtual void runTaskInQueue(const std::function& task) = 0; + virtual void runTaskInQueue(std::function&& task) = 0; + virtual std::string getName() const { return ""; } + + void syncTaskInQueue(const std::function& task) + { + std::promise promise; + auto future = promise.get_future(); + runTaskInQueue([&]() { + task(); + promise.set_value(); + }); + future.wait(); + } +}; + +// Concurrent Task Queue +class ConcurrentTaskQueue : public TaskQueue +{ +private: + std::vector threads_; + std::queue> taskQueue_; + std::mutex taskMutex_; + std::condition_variable taskCond_; + std::atomic stop_{false}; + std::string name_; + + void workerThread(int threadId) + { + while (!stop_) { + std::function task; + { + std::unique_lock lock(taskMutex_); + taskCond_.wait(lock, [this]() { return stop_ || !taskQueue_.empty(); }); + + if (taskQueue_.empty()) continue; + + task = std::move(taskQueue_.front()); + taskQueue_.pop(); + } + task(); + } + } + +public: + ConcurrentTaskQueue(size_t threadNum, const std::string& name = "ConcurrentTaskQueue") + : name_(name) + { + for (size_t i = 0; i < threadNum; ++i) { + threads_.emplace_back(&ConcurrentTaskQueue::workerThread, this, i); + } + } + + ~ConcurrentTaskQueue() + { + stop_ = true; + taskCond_.notify_all(); + for (auto& t : threads_) { + if (t.joinable()) t.join(); + } + } + + void runTaskInQueue(const std::function& task) override + { + std::lock_guard lock(taskMutex_); + taskQueue_.push(task); + taskCond_.notify_one(); + } + + void runTaskInQueue(std::function&& task) override + { + std::lock_guard lock(taskMutex_); + taskQueue_.push(std::move(task)); + taskCond_.notify_one(); + } + + std::string getName() const override { return name_; } + + size_t getTaskCount() + { + std::lock_guard lock(taskMutex_); + return taskQueue_.size(); + } +}; + +// Logging macros +#define LOG_TRACE Logger(LogLevel::TRACE) +#define LOG_DEBUG Logger(LogLevel::DEBUG) +#define LOG_INFO Logger(LogLevel::INFO) +#define LOG_WARN Logger(LogLevel::WARN) +#define LOG_ERROR Logger(LogLevel::ERROR) +#define LOG_FATAL Logger(LogLevel::FATAL) + +// Utility functions +template +void hashCombine(std::size_t& seed, const T& value) +{ + std::hash hasher; + seed ^= hasher(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +inline std::vector splitString(const std::string& s, const std::string& delimiter) +{ + std::vector result; + size_t start = 0; + size_t end = s.find(delimiter); + + while (end != std::string::npos) { + result.push_back(s.substr(start, end - start)); + start = end + delimiter.length(); + end = s.find(delimiter, start); + } + result.push_back(s.substr(start)); + return result; +} + +} // namespace reactor