Reactor/lib/TcpConnection.hpp
2025-06-27 18:01:00 -05:00

365 lines
9.6 KiB
C++

#pragma once
#include "Core.hpp"
#include "Socket.hpp"
#include "Buffer.hpp"
#include "Utilities.hpp"
#include <atomic>
#include <cassert>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <errno.h>
namespace reactor {
class TcpConnection;

/// Shared-ownership handle through which a connection is handed to every callback.
using TcpConnectionPtr = std::shared_ptr<TcpConnection>;

/// Connection-level notification taking only the connection itself.
using ConnectionCallback = std::function<void (const TcpConnectionPtr &)>;

/// Same signature as ConnectionCallback (identical std::function type).
using WriteCompleteCallback = ConnectionCallback;

/// Receives the connection together with the Buffer holding the bytes read from it.
using MessageCallback = std::function<void (const TcpConnectionPtr &, Buffer &)>;

/// Receives the connection and a byte count (the output-buffer size that crossed
/// the configured high-water mark).
using HighWaterMarkCallback = std::function<void (const TcpConnectionPtr &, size_t)>;
/// One TCP connection driven by a single EventLoop.
///
/// Lifetime is managed through TcpConnectionPtr: every task this class posts to
/// the loop captures a shared_ptr (or, for delayed work, a weak_ptr) to itself.
/// The private handlers assert they run on the loop thread; the public
/// send()/shutdown()/forceClose*() entry points may be called from other threads
/// and forward their work to the loop. The destructor asserts the connection has
/// reached kDisconnected.
class TcpConnection : public NonCopyable, public std::enable_shared_from_this<TcpConnection>
{
public:
    /// Lifecycle states as driven by this class:
    /// kConnecting -> kConnected -> (kDisconnecting) -> kDisconnected.
    enum StateE { kDisconnected, kConnecting, kConnected, kDisconnecting };

private:
    EventLoop* loop_;                    // owning loop; handlers assert they run on its thread
    Socket socket_;
    std::unique_ptr<Channel> channel_;
    InetAddress localAddr_;
    InetAddress peerAddr_;
    std::string name_;
    // Atomic because send()/shutdown()/forceClose*()/connected() read or update the
    // state from arbitrary threads while the loop thread changes it concurrently.
    std::atomic<StateE> state_;
    Buffer inputBuffer_;
    Buffer outputBuffer_;                // bytes accepted by send() but not yet written to the socket
    MessageCallback messageCallback_;
    ConnectionCallback connectionCallback_;
    ConnectionCallback closeCallback_;
    WriteCompleteCallback writeCompleteCallback_;
    HighWaterMarkCallback highWaterMarkCallback_;
    size_t highWaterMark_;               // outputBuffer_ size that triggers highWaterMarkCallback_

    void setState(StateE s) { state_.store(s); }

    /// Atomically moves kConnected or kDisconnecting to kDisconnecting.
    /// Returns false and changes nothing when the connection is in any other state,
    /// so a close already performed on the loop thread (kDisconnected) is never
    /// overwritten by a racing caller.
    bool beginDisconnecting()
    {
        StateE current = state_.load();
        while (current == kConnected || current == kDisconnecting) {
            // On failure `current` is reloaded and the loop re-checks it.
            if (state_.compare_exchange_weak(current, kDisconnecting)) {
                return true;
            }
        }
        return false;
    }

    /// Loop-thread read handler: reads from the socket into inputBuffer_ via
    /// Buffer::readFd, then dispatches to messageCallback_ (data), handleClose()
    /// (read returned 0) or handleError() (read failed).
    void handleRead()
    {
        loop_->assertInLoopThread();
        int savedErrno = 0;
        ssize_t n = inputBuffer_.readFd(socket_.fd(), &savedErrno);
        if (n > 0) {
            LOG_TRACE << "TcpConnection " << name_ << " read " << n << " bytes";
            if (messageCallback_) {
                messageCallback_(shared_from_this(), inputBuffer_);
            }
        } else if (n == 0) {
            LOG_DEBUG << "TcpConnection " << name_ << " peer closed";
            handleClose();
        } else {
            errno = savedErrno;
            LOG_ERROR << "TcpConnection " << name_ << " read error: " << strerror(savedErrno);
            handleError();
        }
    }

    /// Loop-thread write handler: flushes as much of outputBuffer_ as the socket
    /// accepts. Once the buffer is empty it stops watching for writability, queues
    /// writeCompleteCallback_, and finishes a pending shutdown() if one was requested.
    void handleWrite()
    {
        loop_->assertInLoopThread();
        if (!channel_->isWriting()) {
            LOG_TRACE << "TcpConnection " << name_ << " not writing, ignore";
            return;
        }
        ssize_t n = socket_.write(outputBuffer_.peek(), outputBuffer_.readableBytes());
        if (n > 0) {
            outputBuffer_.retrieve(n);
            LOG_TRACE << "TcpConnection " << name_ << " wrote " << n << " bytes, "
                      << outputBuffer_.readableBytes() << " bytes left";
            if (outputBuffer_.readableBytes() == 0) {
                channel_->disableWriting();
                if (writeCompleteCallback_) {
                    loop_->queueInLoop([self = shared_from_this()]() {
                        self->writeCompleteCallback_(self);
                    });
                }
                if (state_ == kDisconnecting) {
                    shutdownInLoop();
                }
            }
        } else if (n < 0) {
            // Capture errno before logging can overwrite it. EAGAIN/EWOULDBLOCK only
            // means the socket is not writable yet; the channel will fire again.
            const int savedErrno = errno;
            if (savedErrno != EWOULDBLOCK && savedErrno != EAGAIN) {
                LOG_ERROR << "TcpConnection " << name_ << " write error: " << strerror(savedErrno);
            }
        }
    }

    /// Loop-thread close handler: marks the connection kDisconnected, disables all
    /// channel events, then notifies connectionCallback_ and closeCallback_.
    /// Idempotent: it can be reached more than once for one socket (e.g. through
    /// handleError() and again through handleRead() or the channel's close
    /// callback), and closeCallback_ must fire at most once.
    void handleClose()
    {
        loop_->assertInLoopThread();
        LOG_DEBUG << "TcpConnection " << name_ << " state=" << stateToString();
        if (state_ == kDisconnected) {
            return;
        }
        assert(state_ == kConnected || state_ == kDisconnecting);
        setState(kDisconnected);
        channel_->disableAll();
        // Keep *this alive while user callbacks run; closeCallback_ may release the
        // last reference held elsewhere.
        auto guardThis = shared_from_this();
        if (connectionCallback_) {
            connectionCallback_(guardThis);
        }
        if (closeCallback_) {
            closeCallback_(guardThis);
        }
    }

    /// Logs the socket's pending error (Socket::getSocketError) and closes the connection.
    void handleError()
    {
        int err = socket_.getSocketError();
        LOG_ERROR << "TcpConnection " << name_ << " SO_ERROR=" << err << " " << strerror(err);
        handleClose();
    }

    /// Loop-thread send. Tries a direct write when nothing is queued, then appends
    /// whatever was not written to outputBuffer_ and enables write notifications.
    /// Queues highWaterMarkCallback_ when this append makes the buffered size cross
    /// highWaterMark_. Gives up (drops the data) if the connection is already
    /// disconnected or the direct write failed with EPIPE/ECONNRESET.
    void sendInLoop(const char* data, size_t len)
    {
        loop_->assertInLoopThread();
        ssize_t nwrote = 0;
        size_t remaining = len;
        bool faultError = false;
        if (state_ == kDisconnected) {
            LOG_WARN << "TcpConnection " << name_ << " disconnected, give up writing";
            return;
        }
        if (!channel_->isWriting() && outputBuffer_.readableBytes() == 0) {
            nwrote = socket_.write(data, len);
            if (nwrote >= 0) {
                remaining = len - static_cast<size_t>(nwrote);
                if (remaining == 0 && writeCompleteCallback_) {
                    loop_->queueInLoop([self = shared_from_this()]() {
                        self->writeCompleteCallback_(self);
                    });
                }
            } else {
                // Save errno first: the logging below may overwrite it before the
                // EPIPE/ECONNRESET check.
                const int savedErrno = errno;
                nwrote = 0;
                if (savedErrno != EWOULDBLOCK && savedErrno != EAGAIN) {
                    LOG_ERROR << "TcpConnection " << name_ << " send error: " << strerror(savedErrno);
                    if (savedErrno == EPIPE || savedErrno == ECONNRESET) {
                        faultError = true;
                    }
                }
            }
        }
        assert(remaining <= len);
        if (!faultError && remaining > 0) {
            size_t oldLen = outputBuffer_.readableBytes();
            if (oldLen + remaining >= highWaterMark_ &&
                oldLen < highWaterMark_ &&
                highWaterMarkCallback_) {
                loop_->queueInLoop([self = shared_from_this(), mark = oldLen + remaining]() {
                    self->highWaterMarkCallback_(self, mark);
                });
            }
            outputBuffer_.append(data + nwrote, remaining);
            if (!channel_->isWriting()) {
                channel_->enableWriting();
            }
        }
    }

    /// Half-closes the write side, but only once outputBuffer_ has been flushed;
    /// otherwise handleWrite() calls this again after the last byte is written.
    void shutdownInLoop()
    {
        loop_->assertInLoopThread();
        if (!channel_->isWriting()) {
            socket_.shutdownWrite();
        }
    }

    /// Loop-thread part of forceClose(): closes immediately unless already closed.
    void forceCloseInLoop()
    {
        loop_->assertInLoopThread();
        if (state_ == kConnected || state_ == kDisconnecting) {
            handleClose();
        }
    }

    /// Human-readable name of the current state, for logging.
    std::string stateToString() const
    {
        switch (state_.load()) {
            case kDisconnected: return "kDisconnected";
            case kConnecting: return "kConnecting";
            case kConnected: return "kConnected";
            case kDisconnecting: return "kDisconnecting";
            default: return "unknown state";
        }
    }

public:
    /// Wraps an accepted/connected socket. The connection starts in kConnecting;
    /// connectEstablished() must be called on the loop thread to start reading.
    TcpConnection(EventLoop* loop, const std::string& name, int sockfd,
                  const InetAddress& localAddr, const InetAddress& peerAddr)
        : loop_(loop), socket_(sockfd), channel_(std::make_unique<Channel>(loop, sockfd)),
          localAddr_(localAddr), peerAddr_(peerAddr), name_(name), state_(kConnecting),
          highWaterMark_(64*1024*1024)
    {
        // Raw `this` captures: channel_ is owned by *this, and connectEstablished()
        // ties the channel to the shared_ptr before events are enabled.
        channel_->setReadCallback([this]() { handleRead(); });
        channel_->setWriteCallback([this]() { handleWrite(); });
        channel_->setCloseCallback([this]() { handleClose(); });
        channel_->setErrorCallback([this]() { handleError(); });
        socket_.setKeepAlive(true);
        socket_.setTcpNoDelay(true);
        LOG_INFO << "TcpConnection " << name_ << " created from "
                 << localAddr_.toIpPort() << " to " << peerAddr_.toIpPort() << " fd=" << sockfd;
    }

    ~TcpConnection()
    {
        LOG_INFO << "TcpConnection " << name_ << " destroyed state=" << stateToString();
        assert(state_ == kDisconnected);
    }

    /// Loop-thread: moves kConnecting -> kConnected, ties the channel to this
    /// object, starts reading and notifies connectionCallback_.
    void connectEstablished()
    {
        loop_->assertInLoopThread();
        assert(state_ == kConnecting);
        setState(kConnected);
        channel_->tie(shared_from_this());
        channel_->enableReading();
        if (connectionCallback_) {
            connectionCallback_(shared_from_this());
        }
        LOG_INFO << "TcpConnection " << name_ << " established";
    }

    /// Loop-thread final teardown: if the connection is still open (kConnected, or
    /// kDisconnecting after shutdown()/forceCloseWithDelay() without a close event)
    /// it is marked kDisconnected and connectionCallback_ is notified; then the
    /// channel is removed from the loop.
    void connectDestroyed()
    {
        loop_->assertInLoopThread();
        // kDisconnecting must be handled too, otherwise the destructor's
        // kDisconnected assertion fails for half-closed connections.
        if (state_ == kConnected || state_ == kDisconnecting) {
            setState(kDisconnected);
            channel_->disableAll();
            if (connectionCallback_) {
                connectionCallback_(shared_from_this());
            }
        }
        channel_->remove();
        LOG_INFO << "TcpConnection " << name_ << " destroyed";
    }

    const std::string& name() const { return name_; }
    const InetAddress& localAddr() const { return localAddr_; }
    const InetAddress& peerAddr() const { return peerAddr_; }
    bool connected() const { return state_ == kConnected; }
    bool disconnected() const { return state_ == kDisconnected; }
    EventLoop* getLoop() const { return loop_; }

    /// Thread-safe send. Ignored unless kConnected. Off the loop thread the payload
    /// is copied into the posted task, so the caller's buffer need not outlive the call.
    void send(const std::string& message)
    {
        if (state_ != kConnected) {
            return;
        }
        if (loop_->isInLoopThread()) {
            sendInLoop(message.data(), message.size());
        } else {
            loop_->runInLoop([self = shared_from_this(), message]() {
                self->sendInLoop(message.data(), message.size());
            });
        }
    }

    /// Thread-safe send of `len` bytes starting at `data`; same rules as send(string).
    void send(const char* data, size_t len)
    {
        if (state_ != kConnected) {
            return;
        }
        if (loop_->isInLoopThread()) {
            sendInLoop(data, len);
        } else {
            std::string message(data, len);
            loop_->runInLoop([self = shared_from_this(), message = std::move(message)]() {
                self->sendInLoop(message.data(), message.size());
            });
        }
    }

    /// Thread-safe send of all readable bytes in `buffer`; the buffer is consumed
    /// (retrieveAll on the loop thread, readAll otherwise).
    void send(Buffer& buffer)
    {
        if (state_ != kConnected) {
            return;
        }
        if (loop_->isInLoopThread()) {
            sendInLoop(buffer.peek(), buffer.readableBytes());
            buffer.retrieveAll();
        } else {
            std::string message = buffer.readAll();
            loop_->runInLoop([self = shared_from_this(), message = std::move(message)]() {
                self->sendInLoop(message.data(), message.size());
            });
        }
    }

    /// Requests a graceful half-close (write side) once pending output is flushed.
    /// Only acts on a kConnected connection; the state check-and-set is atomic.
    void shutdown()
    {
        StateE expected = kConnected;
        if (state_.compare_exchange_strong(expected, kDisconnecting)) {
            loop_->runInLoop([self = shared_from_this()]() {
                self->shutdownInLoop();
            });
        }
    }

    /// Closes the connection on the loop thread without waiting for output to drain.
    void forceClose()
    {
        if (beginDisconnecting()) {
            loop_->queueInLoop([self = shared_from_this()]() {
                self->forceCloseInLoop();
            });
        }
    }

    /// Schedules forceClose() after `seconds` (converted to a Duration of
    /// seconds * 1000 — presumably milliseconds; confirm against Duration).
    /// The timer holds only a weak reference so it does not keep an
    /// already-closed connection (and its socket) alive until it fires.
    void forceCloseWithDelay(double seconds)
    {
        if (beginDisconnecting()) {
            std::weak_ptr<TcpConnection> weakSelf(shared_from_this());
            loop_->runAfter(Duration(static_cast<int>(seconds * 1000)),
                            [weakSelf]() {
                                if (TcpConnectionPtr self = weakSelf.lock()) {
                                    self->forceClose();
                                }
                            });
        }
    }

    void setTcpNoDelay(bool on) { socket_.setTcpNoDelay(on); }
    void setTcpKeepAlive(bool on) { socket_.setKeepAlive(on); }
    void setMessageCallback(MessageCallback cb) { messageCallback_ = std::move(cb); }
    void setConnectionCallback(ConnectionCallback cb) { connectionCallback_ = std::move(cb); }
    void setCloseCallback(ConnectionCallback cb) { closeCallback_ = std::move(cb); }
    void setWriteCompleteCallback(WriteCompleteCallback cb) { writeCompleteCallback_ = std::move(cb); }

    /// Installs the high-water-mark callback and the buffered-byte threshold that triggers it.
    void setHighWaterMarkCallback(HighWaterMarkCallback cb, size_t highWaterMark)
    {
        highWaterMarkCallback_ = std::move(cb);
        highWaterMark_ = highWaterMark;
    }

    Buffer* inputBuffer() { return &inputBuffer_; }
    Buffer* outputBuffer() { return &outputBuffer_; }

    /// Re-enables read notifications on the loop thread (no-op if already reading).
    void startRead()
    {
        loop_->runInLoop([self = shared_from_this()]() {
            if (!self->channel_->isReading()) {
                self->channel_->enableReading();
            }
        });
    }

    /// Disables read notifications on the loop thread (no-op if not reading).
    void stopRead()
    {
        loop_->runInLoop([self = shared_from_this()]() {
            if (self->channel_->isReading()) {
                self->channel_->disableReading();
            }
        });
    }
};
} // namespace reactor