365 lines
9.6 KiB
C++
365 lines
9.6 KiB
C++
#pragma once
|
|
|
|
#include "Core.hpp"
#include "Socket.hpp"
#include "Buffer.hpp"
#include "Utilities.hpp"

#include <cassert>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <utility>

#include <errno.h>
|
|
|
|
namespace reactor {
|
|
|
|
class TcpConnection;
|
|
using TcpConnectionPtr = std::shared_ptr<TcpConnection>;
|
|
using MessageCallback = std::function<void(const TcpConnectionPtr&, Buffer&)>;
|
|
using ConnectionCallback = std::function<void(const TcpConnectionPtr&)>;
|
|
using WriteCompleteCallback = std::function<void(const TcpConnectionPtr&)>;
|
|
using HighWaterMarkCallback = std::function<void(const TcpConnectionPtr&, size_t)>;
|
|
|
|
class TcpConnection : public NonCopyable, public std::enable_shared_from_this<TcpConnection>
|
|
{
|
|
public:
|
|
enum StateE { kDisconnected, kConnecting, kConnected, kDisconnecting };
|
|
|
|
private:
|
|
EventLoop* loop_;
|
|
Socket socket_;
|
|
std::unique_ptr<Channel> channel_;
|
|
InetAddress localAddr_;
|
|
InetAddress peerAddr_;
|
|
std::string name_;
|
|
StateE state_;
|
|
Buffer inputBuffer_;
|
|
Buffer outputBuffer_;
|
|
MessageCallback messageCallback_;
|
|
ConnectionCallback connectionCallback_;
|
|
ConnectionCallback closeCallback_;
|
|
WriteCompleteCallback writeCompleteCallback_;
|
|
HighWaterMarkCallback highWaterMarkCallback_;
|
|
size_t highWaterMark_;
|
|
|
|
void setState(StateE s) { state_ = s; }
|
|
|
|
void handleRead()
|
|
{
|
|
loop_->assertInLoopThread();
|
|
int savedErrno = 0;
|
|
ssize_t n = inputBuffer_.readFd(socket_.fd(), &savedErrno);
|
|
|
|
if (n > 0) {
|
|
LOG_TRACE << "TcpConnection " << name_ << " read " << n << " bytes";
|
|
if (messageCallback_) {
|
|
messageCallback_(shared_from_this(), inputBuffer_);
|
|
}
|
|
} else if (n == 0) {
|
|
LOG_DEBUG << "TcpConnection " << name_ << " peer closed";
|
|
handleClose();
|
|
} else {
|
|
errno = savedErrno;
|
|
LOG_ERROR << "TcpConnection " << name_ << " read error: " << strerror(savedErrno);
|
|
handleError();
|
|
}
|
|
}
|
|
|
|
void handleWrite()
|
|
{
|
|
loop_->assertInLoopThread();
|
|
if (channel_->isWriting()) {
|
|
ssize_t n = socket_.write(outputBuffer_.peek(), outputBuffer_.readableBytes());
|
|
if (n > 0) {
|
|
outputBuffer_.retrieve(n);
|
|
LOG_TRACE << "TcpConnection " << name_ << " wrote " << n << " bytes, "
|
|
<< outputBuffer_.readableBytes() << " bytes left";
|
|
|
|
if (outputBuffer_.readableBytes() == 0) {
|
|
channel_->disableWriting();
|
|
if (writeCompleteCallback_) {
|
|
loop_->queueInLoop([self = shared_from_this()]() {
|
|
self->writeCompleteCallback_(self);
|
|
});
|
|
}
|
|
if (state_ == kDisconnecting) {
|
|
shutdownInLoop();
|
|
}
|
|
}
|
|
} else {
|
|
LOG_ERROR << "TcpConnection " << name_ << " write error: " << strerror(errno);
|
|
}
|
|
} else {
|
|
LOG_TRACE << "TcpConnection " << name_ << " not writing, ignore";
|
|
}
|
|
}
|
|
|
|
void handleClose()
|
|
{
|
|
loop_->assertInLoopThread();
|
|
LOG_DEBUG << "TcpConnection " << name_ << " state=" << stateToString();
|
|
assert(state_ == kConnected || state_ == kDisconnecting);
|
|
setState(kDisconnected);
|
|
channel_->disableAll();
|
|
|
|
auto guardThis = shared_from_this();
|
|
if (connectionCallback_) {
|
|
connectionCallback_(guardThis);
|
|
}
|
|
if (closeCallback_) {
|
|
closeCallback_(guardThis);
|
|
}
|
|
}
|
|
|
|
void handleError()
|
|
{
|
|
int err = socket_.getSocketError();
|
|
LOG_ERROR << "TcpConnection " << name_ << " SO_ERROR=" << err << " " << strerror(err);
|
|
handleClose();
|
|
}
|
|
|
|
void sendInLoop(const char* data, size_t len)
|
|
{
|
|
loop_->assertInLoopThread();
|
|
ssize_t nwrote = 0;
|
|
size_t remaining = len;
|
|
bool faultError = false;
|
|
|
|
if (state_ == kDisconnected) {
|
|
LOG_WARN << "TcpConnection " << name_ << " disconnected, give up writing";
|
|
return;
|
|
}
|
|
|
|
if (!channel_->isWriting() && outputBuffer_.readableBytes() == 0) {
|
|
nwrote = socket_.write(data, len);
|
|
if (nwrote >= 0) {
|
|
remaining = len - nwrote;
|
|
if (remaining == 0 && writeCompleteCallback_) {
|
|
loop_->queueInLoop([self = shared_from_this()]() {
|
|
self->writeCompleteCallback_(self);
|
|
});
|
|
}
|
|
} else {
|
|
nwrote = 0;
|
|
if (errno != EWOULDBLOCK) {
|
|
LOG_ERROR << "TcpConnection " << name_ << " send error: " << strerror(errno);
|
|
if (errno == EPIPE || errno == ECONNRESET) {
|
|
faultError = true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
assert(remaining <= len);
|
|
if (!faultError && remaining > 0) {
|
|
size_t oldLen = outputBuffer_.readableBytes();
|
|
if (oldLen + remaining >= highWaterMark_ &&
|
|
oldLen < highWaterMark_ &&
|
|
highWaterMarkCallback_) {
|
|
loop_->queueInLoop([self = shared_from_this(), mark = oldLen + remaining]() {
|
|
self->highWaterMarkCallback_(self, mark);
|
|
});
|
|
}
|
|
outputBuffer_.append(data + nwrote, remaining);
|
|
if (!channel_->isWriting()) {
|
|
channel_->enableWriting();
|
|
}
|
|
}
|
|
}
|
|
|
|
void shutdownInLoop()
|
|
{
|
|
loop_->assertInLoopThread();
|
|
if (!channel_->isWriting()) {
|
|
socket_.shutdownWrite();
|
|
}
|
|
}
|
|
|
|
void forceCloseInLoop()
|
|
{
|
|
loop_->assertInLoopThread();
|
|
if (state_ == kConnected || state_ == kDisconnecting) {
|
|
handleClose();
|
|
}
|
|
}
|
|
|
|
std::string stateToString() const
|
|
{
|
|
switch (state_) {
|
|
case kDisconnected: return "kDisconnected";
|
|
case kConnecting: return "kConnecting";
|
|
case kConnected: return "kConnected";
|
|
case kDisconnecting: return "kDisconnecting";
|
|
default: return "unknown state";
|
|
}
|
|
}
|
|
|
|
public:
|
|
TcpConnection(EventLoop* loop, const std::string& name, int sockfd,
|
|
const InetAddress& localAddr, const InetAddress& peerAddr)
|
|
: loop_(loop), socket_(sockfd), channel_(std::make_unique<Channel>(loop, sockfd)),
|
|
localAddr_(localAddr), peerAddr_(peerAddr), name_(name), state_(kConnecting),
|
|
highWaterMark_(64*1024*1024)
|
|
{
|
|
channel_->setReadCallback([this]() { handleRead(); });
|
|
channel_->setWriteCallback([this]() { handleWrite(); });
|
|
channel_->setCloseCallback([this]() { handleClose(); });
|
|
channel_->setErrorCallback([this]() { handleError(); });
|
|
|
|
socket_.setKeepAlive(true);
|
|
socket_.setTcpNoDelay(true);
|
|
|
|
LOG_INFO << "TcpConnection " << name_ << " created from "
|
|
<< localAddr_.toIpPort() << " to " << peerAddr_.toIpPort() << " fd=" << sockfd;
|
|
}
|
|
|
|
~TcpConnection()
|
|
{
|
|
LOG_INFO << "TcpConnection " << name_ << " destroyed state=" << stateToString();
|
|
assert(state_ == kDisconnected);
|
|
}
|
|
|
|
void connectEstablished()
|
|
{
|
|
loop_->assertInLoopThread();
|
|
assert(state_ == kConnecting);
|
|
setState(kConnected);
|
|
channel_->tie(shared_from_this());
|
|
channel_->enableReading();
|
|
|
|
if (connectionCallback_) {
|
|
connectionCallback_(shared_from_this());
|
|
}
|
|
|
|
LOG_INFO << "TcpConnection " << name_ << " established";
|
|
}
|
|
|
|
void connectDestroyed()
|
|
{
|
|
loop_->assertInLoopThread();
|
|
if (state_ == kConnected) {
|
|
setState(kDisconnected);
|
|
channel_->disableAll();
|
|
if (connectionCallback_) {
|
|
connectionCallback_(shared_from_this());
|
|
}
|
|
}
|
|
channel_->remove();
|
|
LOG_INFO << "TcpConnection " << name_ << " destroyed";
|
|
}
|
|
|
|
const std::string& name() const { return name_; }
|
|
const InetAddress& localAddr() const { return localAddr_; }
|
|
const InetAddress& peerAddr() const { return peerAddr_; }
|
|
bool connected() const { return state_ == kConnected; }
|
|
bool disconnected() const { return state_ == kDisconnected; }
|
|
EventLoop* getLoop() const { return loop_; }
|
|
|
|
void send(const std::string& message)
|
|
{
|
|
if (state_ == kConnected) {
|
|
if (loop_->isInLoopThread()) {
|
|
sendInLoop(message.data(), message.size());
|
|
} else {
|
|
loop_->runInLoop([self = shared_from_this(), message]() {
|
|
self->sendInLoop(message.data(), message.size());
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
void send(const char* data, size_t len)
|
|
{
|
|
if (state_ == kConnected) {
|
|
if (loop_->isInLoopThread()) {
|
|
sendInLoop(data, len);
|
|
} else {
|
|
std::string message(data, len);
|
|
loop_->runInLoop([self = shared_from_this(), message]() {
|
|
self->sendInLoop(message.data(), message.size());
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
void send(Buffer& buffer)
|
|
{
|
|
if (state_ == kConnected) {
|
|
if (loop_->isInLoopThread()) {
|
|
sendInLoop(buffer.peek(), buffer.readableBytes());
|
|
buffer.retrieveAll();
|
|
} else {
|
|
std::string message = buffer.readAll();
|
|
loop_->runInLoop([self = shared_from_this(), message]() {
|
|
self->sendInLoop(message.data(), message.size());
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
void shutdown()
|
|
{
|
|
if (state_ == kConnected) {
|
|
setState(kDisconnecting);
|
|
loop_->runInLoop([self = shared_from_this()]() {
|
|
self->shutdownInLoop();
|
|
});
|
|
}
|
|
}
|
|
|
|
void forceClose()
|
|
{
|
|
if (state_ == kConnected || state_ == kDisconnecting) {
|
|
setState(kDisconnecting);
|
|
loop_->queueInLoop([self = shared_from_this()]() {
|
|
self->forceCloseInLoop();
|
|
});
|
|
}
|
|
}
|
|
|
|
void forceCloseWithDelay(double seconds)
|
|
{
|
|
if (state_ == kConnected || state_ == kDisconnecting) {
|
|
setState(kDisconnecting);
|
|
loop_->runAfter(Duration(static_cast<int>(seconds * 1000)),
|
|
[self = shared_from_this()]() {
|
|
self->forceClose();
|
|
});
|
|
}
|
|
}
|
|
|
|
void setTcpNoDelay(bool on) { socket_.setTcpNoDelay(on); }
|
|
void setTcpKeepAlive(bool on) { socket_.setKeepAlive(on); }
|
|
|
|
void setMessageCallback(MessageCallback cb) { messageCallback_ = std::move(cb); }
|
|
void setConnectionCallback(ConnectionCallback cb) { connectionCallback_ = std::move(cb); }
|
|
void setCloseCallback(ConnectionCallback cb) { closeCallback_ = std::move(cb); }
|
|
void setWriteCompleteCallback(WriteCompleteCallback cb) { writeCompleteCallback_ = std::move(cb); }
|
|
void setHighWaterMarkCallback(HighWaterMarkCallback cb, size_t highWaterMark)
|
|
{
|
|
highWaterMarkCallback_ = std::move(cb);
|
|
highWaterMark_ = highWaterMark;
|
|
}
|
|
|
|
Buffer* inputBuffer() { return &inputBuffer_; }
|
|
Buffer* outputBuffer() { return &outputBuffer_; }
|
|
|
|
void startRead()
|
|
{
|
|
loop_->runInLoop([self = shared_from_this()]() {
|
|
if (!self->channel_->isReading()) {
|
|
self->channel_->enableReading();
|
|
}
|
|
});
|
|
}
|
|
|
|
void stopRead()
|
|
{
|
|
loop_->runInLoop([self = shared_from_this()]() {
|
|
if (self->channel_->isReading()) {
|
|
self->channel_->disableReading();
|
|
}
|
|
});
|
|
}
|
|
};
|
|
|
|
} // namespace reactor
|