#pragma once #include "Core.hpp" #include "Socket.hpp" #include "Buffer.hpp" #include "Utilities.hpp" #include #include #include #include namespace reactor { class TcpConnection; using TcpConnectionPtr = std::shared_ptr; using MessageCallback = std::function; using ConnectionCallback = std::function; using WriteCompleteCallback = std::function; using HighWaterMarkCallback = std::function; /* * Represents a single TCP connection. * Manages the lifetime of a socket. */ class TcpConnection : public NonCopyable, public std::enable_shared_from_this { public: enum StateE { kDisconnected, kConnecting, kConnected, kDisconnecting }; /* * Constructs a TcpConnection, taking ownership of an existing socket. */ TcpConnection(EventLoop* loop, const std::string& name, Socket&& socket, const InetAddress& localAddr, const InetAddress& peerAddr) : loop_(loop), socket_(std::move(socket)), channel_(std::make_unique(loop, socket_.fd())), localAddr_(localAddr), peerAddr_(peerAddr), name_(name), state_(kConnecting), highWaterMark_(64 * 1024 * 1024) { channel_->setReadCallback([this] { handleRead(); }); channel_->setWriteCallback([this] { handleWrite(); }); channel_->setCloseCallback([this] { handleClose(); }); channel_->setErrorCallback([this] { handleError(); }); socket_.setKeepAlive(true); socket_.setTcpNoDelay(true); LOG_INFO << "TcpConnection " << name_ << " created from " << localAddr_.toIpPort() << " to " << peerAddr_.toIpPort() << " fd=" << socket_.fd(); } /* * Destroys the TcpConnection. */ ~TcpConnection() { LOG_INFO << "TcpConnection " << name_ << " destroyed state=" << stateToString(); assert(state_ == kDisconnected); } /* * Called when the connection is established. */ void connectEstablished() { loop_->assertInLoopThread(); assert(state_ == kConnecting); setState(kConnected); channel_->tie(shared_from_this()); channel_->enableReading(); if (connectionCallback_) { connectionCallback_(shared_from_this()); } LOG_INFO << "TcpConnection " << name_ << " established"; } /* * Called when the connection is to be destroyed. */ void connectDestroyed() { loop_->assertInLoopThread(); if (state_ == kConnected) { setState(kDisconnected); channel_->disableAll(); if (connectionCallback_) { connectionCallback_(shared_from_this()); } } channel_->remove(); LOG_INFO << "TcpConnection " << name_ << " destroyed"; } /* * Sends data. This is thread-safe. */ void send(const std::string& message) { if (state_ == kConnected) { if (loop_->isInLoopThread()) { sendInLoop(message.data(), message.size()); } else { loop_->runInLoop([self = shared_from_this(), message]() { self->sendInLoop(message.data(), message.size()); }); } } } /* * Sends data. This is thread-safe. */ void send(const char* data, size_t len) { if (state_ == kConnected) { if (loop_->isInLoopThread()) { sendInLoop(data, len); } else { std::string message(data, len); loop_->runInLoop([self = shared_from_this(), message]() { self->sendInLoop(message.data(), message.size()); }); } } } /* * Sends data from a buffer. This is thread-safe. */ void send(Buffer& buffer) { if (state_ == kConnected) { if (loop_->isInLoopThread()) { sendInLoop(buffer.peek(), buffer.readableBytes()); buffer.retrieveAll(); } else { std::string message = buffer.readAll(); loop_->runInLoop([self = shared_from_this(), message]() { self->sendInLoop(message.data(), message.size()); }); } } } /* * Shuts down the write-half of the connection. */ void shutdown() { if (state_ == kConnected) { setState(kDisconnecting); loop_->runInLoop([self = shared_from_this()]() { self->shutdownInLoop(); }); } } /* * Forcibly closes the connection. */ void forceClose() { if (state_ == kConnected || state_ == kDisconnecting) { setState(kDisconnecting); loop_->queueInLoop([self = shared_from_this()]() { self->forceCloseInLoop(); }); } } /* * Forcibly closes the connection after a delay. */ void forceCloseWithDelay(double seconds) { if (state_ == kConnected || state_ == kDisconnecting) { setState(kDisconnecting); loop_->runAfter(Duration(static_cast(seconds * 1000)), [self = shared_from_this()]() { self->forceClose(); }); } } void startRead() { loop_->runInLoop([self = shared_from_this()]() { if (!self->channel_->isReading()) { self->channel_->enableReading(); } }); } void stopRead() { loop_->runInLoop([self = shared_from_this()]() { if (self->channel_->isReading()) { self->channel_->disableReading(); } }); } const std::string& name() const { return name_; } const InetAddress& localAddr() const { return localAddr_; } const InetAddress& peerAddr() const { return peerAddr_; } bool connected() const { return state_ == kConnected; } bool disconnected() const { return state_ == kDisconnected; } EventLoop* getLoop() const { return loop_; } void setTcpNoDelay(bool on) { socket_.setTcpNoDelay(on); } void setTcpKeepAlive(bool on) { socket_.setKeepAlive(on); } void setMessageCallback(MessageCallback cb) { messageCallback_ = std::move(cb); } void setConnectionCallback(ConnectionCallback cb) { connectionCallback_ = std::move(cb); } void setCloseCallback(ConnectionCallback cb) { closeCallback_ = std::move(cb); } void setWriteCompleteCallback(WriteCompleteCallback cb) { writeCompleteCallback_ = std::move(cb); } void setHighWaterMarkCallback(HighWaterMarkCallback cb, size_t highWaterMark) { highWaterMarkCallback_ = std::move(cb); highWaterMark_ = highWaterMark; } Buffer* inputBuffer() { return &inputBuffer_; } Buffer* outputBuffer() { return &outputBuffer_; } private: /* * Sets the internal state of the connection. */ void setState(StateE s) { state_ = s; } /* * Handles readable events on the socket. */ void handleRead() { loop_->assertInLoopThread(); int savedErrno = 0; ssize_t n = inputBuffer_.readFd(socket_.fd(), &savedErrno); if (n > 0) { LOG_TRACE << "TcpConnection " << name_ << " read " << n << " bytes"; if (messageCallback_) { messageCallback_(shared_from_this(), inputBuffer_); } } else if (n == 0) { LOG_DEBUG << "TcpConnection " << name_ << " peer closed"; handleClose(); } else { errno = savedErrno; LOG_ERROR << "TcpConnection " << name_ << " read error: " << strerror(savedErrno); handleError(); } } /* * Handles writable events on the socket. */ void handleWrite() { loop_->assertInLoopThread(); if (channel_->isWriting()) { ssize_t n = socket_.write(outputBuffer_.peek(), outputBuffer_.readableBytes()); if (n > 0) { outputBuffer_.retrieve(n); LOG_TRACE << "TcpConnection " << name_ << " wrote " << n << " bytes, " << outputBuffer_.readableBytes() << " bytes left"; if (outputBuffer_.readableBytes() == 0) { channel_->disableWriting(); if (writeCompleteCallback_) { loop_->queueInLoop([self = shared_from_this()]() { self->writeCompleteCallback_(self); }); } if (state_ == kDisconnecting) { shutdownInLoop(); } } } else { LOG_ERROR << "TcpConnection " << name_ << " write error: " << strerror(errno); } } else { LOG_TRACE << "TcpConnection " << name_ << " not writing, ignore"; } } /* * Handles connection close events. */ void handleClose() { loop_->assertInLoopThread(); LOG_DEBUG << "TcpConnection " << name_ << " state=" << stateToString(); assert(state_ == kConnected || state_ == kDisconnecting); setState(kDisconnected); channel_->disableAll(); auto guardThis = shared_from_this(); if (connectionCallback_) { connectionCallback_(guardThis); } if (closeCallback_) { closeCallback_(guardThis); } } /* * Handles socket error events. */ void handleError() { int err = socket_.getSocketError(); LOG_ERROR << "TcpConnection " << name_ << " SO_ERROR=" << err << " " << strerror(err); handleClose(); } /* * Sends data within the event loop. */ void sendInLoop(const char* data, size_t len) { loop_->assertInLoopThread(); ssize_t nwrote = 0; size_t remaining = len; bool faultError = false; if (state_ == kDisconnected) { LOG_WARN << "TcpConnection " << name_ << " disconnected, give up writing"; return; } if (!channel_->isWriting() && outputBuffer_.readableBytes() == 0) { nwrote = socket_.write(data, len); if (nwrote >= 0) { remaining = len - nwrote; if (remaining == 0 && writeCompleteCallback_) { loop_->queueInLoop([self = shared_from_this()]() { self->writeCompleteCallback_(self); }); } } else { nwrote = 0; if (errno != EWOULDBLOCK) { LOG_ERROR << "TcpConnection " << name_ << " send error: " << strerror(errno); if (errno == EPIPE || errno == ECONNRESET) { faultError = true; } } } } assert(remaining <= len); if (!faultError && remaining > 0) { size_t oldLen = outputBuffer_.readableBytes(); if (oldLen + remaining >= highWaterMark_ && oldLen < highWaterMark_ && highWaterMarkCallback_) { loop_->queueInLoop([self = shared_from_this(), mark = oldLen + remaining]() { self->highWaterMarkCallback_(self, mark); }); } outputBuffer_.append(data + nwrote, remaining); if (!channel_->isWriting()) { channel_->enableWriting(); } } } /* * Shuts down the connection within the event loop. */ void shutdownInLoop() { loop_->assertInLoopThread(); if (!channel_->isWriting()) { socket_.shutdownWrite(); } } /* * Forcibly closes the connection within the event loop. */ void forceCloseInLoop() { loop_->assertInLoopThread(); if (state_ == kConnected || state_ == kDisconnecting) { handleClose(); } } /* * Converts the current state to a string for logging. */ std::string stateToString() const { switch (state_) { case kDisconnected: return "kDisconnected"; case kConnecting: return "kConnecting"; case kConnected: return "kConnected"; case kDisconnecting: return "kDisconnecting"; default: return "unknown state"; } } EventLoop* loop_; Socket socket_; std::unique_ptr channel_; InetAddress localAddr_; InetAddress peerAddr_; std::string name_; StateE state_; Buffer inputBuffer_; Buffer outputBuffer_; MessageCallback messageCallback_; ConnectionCallback connectionCallback_; ConnectionCallback closeCallback_; WriteCompleteCallback writeCompleteCallback_; HighWaterMarkCallback highWaterMarkCallback_; size_t highWaterMark_; }; } // namespace reactor