426 lines
11 KiB
C++
426 lines
11 KiB
C++
#pragma once
|
|
|
|
#include <errno.h>

#include <atomic>
#include <cassert>
#include <cstring>
#include <functional>
#include <memory>
#include <string>

#include "Core.hpp"
#include "Socket.hpp"
#include "Buffer.hpp"
#include "Utilities.hpp"
|
|
|
|
namespace reactor
|
|
{
|
|
|
|
class TcpConnection;
|
|
using TcpConnectionPtr = std::shared_ptr<TcpConnection>;
|
|
using MessageCallback = std::function<void(const TcpConnectionPtr&, Buffer&)>;
|
|
using ConnectionCallback = std::function<void(const TcpConnectionPtr&)>;
|
|
using WriteCompleteCallback = std::function<void(const TcpConnectionPtr&)>;
|
|
using HighWaterMarkCallback = std::function<void(const TcpConnectionPtr&, size_t)>;
|
|
|
|
/*
|
|
* Represents a single TCP connection.
|
|
* Manages the lifetime of a socket.
|
|
*/
|
|
class TcpConnection : public NonCopyable, public std::enable_shared_from_this<TcpConnection>
|
|
{
|
|
public:
|
|
enum StateE { kDisconnected, kConnecting, kConnected, kDisconnecting };
|
|
|
|
/*
|
|
* Constructs a TcpConnection, taking ownership of an existing socket.
|
|
*/
|
|
TcpConnection(EventLoop* loop, const std::string& name, Socket&& socket,
|
|
const InetAddress& localAddr, const InetAddress& peerAddr)
|
|
: loop_(loop),
|
|
socket_(std::move(socket)),
|
|
channel_(std::make_unique<Channel>(loop, socket_.fd())),
|
|
localAddr_(localAddr),
|
|
peerAddr_(peerAddr),
|
|
name_(name),
|
|
state_(kConnecting),
|
|
highWaterMark_(64 * 1024 * 1024)
|
|
{
|
|
channel_->setReadCallback([this] { handleRead(); });
|
|
channel_->setWriteCallback([this] { handleWrite(); });
|
|
channel_->setCloseCallback([this] { handleClose(); });
|
|
channel_->setErrorCallback([this] { handleError(); });
|
|
|
|
socket_.setKeepAlive(true);
|
|
socket_.setTcpNoDelay(true);
|
|
|
|
LOG_INFO << "TcpConnection " << name_ << " created from "
|
|
<< localAddr_.toIpPort() << " to " << peerAddr_.toIpPort() << " fd=" << socket_.fd();
|
|
}
|
|
|
|
/*
|
|
* Destroys the TcpConnection.
|
|
*/
|
|
~TcpConnection()
|
|
{
|
|
LOG_INFO << "TcpConnection " << name_ << " destroyed state=" << stateToString();
|
|
assert(state_ == kDisconnected);
|
|
}
|
|
|
|
/*
|
|
* Called when the connection is established.
|
|
*/
|
|
void connectEstablished()
|
|
{
|
|
loop_->assertInLoopThread();
|
|
assert(state_ == kConnecting);
|
|
setState(kConnected);
|
|
channel_->tie(shared_from_this());
|
|
channel_->enableReading();
|
|
if (connectionCallback_) {
|
|
connectionCallback_(shared_from_this());
|
|
}
|
|
LOG_INFO << "TcpConnection " << name_ << " established";
|
|
}
|
|
|
|
/*
|
|
* Called when the connection is to be destroyed.
|
|
*/
|
|
void connectDestroyed()
|
|
{
|
|
loop_->assertInLoopThread();
|
|
if (state_ == kConnected) {
|
|
setState(kDisconnected);
|
|
channel_->disableAll();
|
|
if (connectionCallback_) {
|
|
connectionCallback_(shared_from_this());
|
|
}
|
|
}
|
|
channel_->remove();
|
|
LOG_INFO << "TcpConnection " << name_ << " destroyed";
|
|
}
|
|
|
|
/*
|
|
* Sends data. This is thread-safe.
|
|
*/
|
|
void send(const std::string& message)
|
|
{
|
|
if (state_ == kConnected) {
|
|
if (loop_->isInLoopThread()) {
|
|
sendInLoop(message.data(), message.size());
|
|
} else {
|
|
loop_->runInLoop([self = shared_from_this(), message]() {
|
|
self->sendInLoop(message.data(), message.size());
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Sends data. This is thread-safe.
|
|
*/
|
|
void send(const char* data, size_t len)
|
|
{
|
|
if (state_ == kConnected) {
|
|
if (loop_->isInLoopThread()) {
|
|
sendInLoop(data, len);
|
|
} else {
|
|
std::string message(data, len);
|
|
loop_->runInLoop([self = shared_from_this(), message]() {
|
|
self->sendInLoop(message.data(), message.size());
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Sends data from a buffer. This is thread-safe.
|
|
*/
|
|
void send(Buffer& buffer)
|
|
{
|
|
if (state_ == kConnected) {
|
|
if (loop_->isInLoopThread()) {
|
|
sendInLoop(buffer.peek(), buffer.readableBytes());
|
|
buffer.retrieveAll();
|
|
} else {
|
|
std::string message = buffer.readAll();
|
|
loop_->runInLoop([self = shared_from_this(), message]() {
|
|
self->sendInLoop(message.data(), message.size());
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Shuts down the write-half of the connection.
|
|
*/
|
|
void shutdown()
|
|
{
|
|
if (state_ == kConnected) {
|
|
setState(kDisconnecting);
|
|
loop_->runInLoop([self = shared_from_this()]() {
|
|
self->shutdownInLoop();
|
|
});
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Forcibly closes the connection.
|
|
*/
|
|
void forceClose()
|
|
{
|
|
if (state_ == kConnected || state_ == kDisconnecting) {
|
|
setState(kDisconnecting);
|
|
loop_->queueInLoop([self = shared_from_this()]() {
|
|
self->forceCloseInLoop();
|
|
});
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Forcibly closes the connection after a delay.
|
|
*/
|
|
void forceCloseWithDelay(double seconds)
|
|
{
|
|
if (state_ == kConnected || state_ == kDisconnecting) {
|
|
setState(kDisconnecting);
|
|
loop_->runAfter(Duration(static_cast<int>(seconds * 1000)),
|
|
[self = shared_from_this()]() {
|
|
self->forceClose();
|
|
});
|
|
}
|
|
}
|
|
|
|
void startRead()
|
|
{
|
|
loop_->runInLoop([self = shared_from_this()]() {
|
|
if (!self->channel_->isReading()) {
|
|
self->channel_->enableReading();
|
|
}
|
|
});
|
|
}
|
|
|
|
void stopRead()
|
|
{
|
|
loop_->runInLoop([self = shared_from_this()]() {
|
|
if (self->channel_->isReading()) {
|
|
self->channel_->disableReading();
|
|
}
|
|
});
|
|
}
|
|
|
|
const std::string& name() const { return name_; }
|
|
const InetAddress& localAddr() const { return localAddr_; }
|
|
const InetAddress& peerAddr() const { return peerAddr_; }
|
|
bool connected() const { return state_ == kConnected; }
|
|
bool disconnected() const { return state_ == kDisconnected; }
|
|
EventLoop* getLoop() const { return loop_; }
|
|
|
|
void setTcpNoDelay(bool on) { socket_.setTcpNoDelay(on); }
|
|
void setTcpKeepAlive(bool on) { socket_.setKeepAlive(on); }
|
|
|
|
void setMessageCallback(MessageCallback cb) { messageCallback_ = std::move(cb); }
|
|
void setConnectionCallback(ConnectionCallback cb) { connectionCallback_ = std::move(cb); }
|
|
void setCloseCallback(ConnectionCallback cb) { closeCallback_ = std::move(cb); }
|
|
void setWriteCompleteCallback(WriteCompleteCallback cb) { writeCompleteCallback_ = std::move(cb); }
|
|
void setHighWaterMarkCallback(HighWaterMarkCallback cb, size_t highWaterMark)
|
|
{
|
|
highWaterMarkCallback_ = std::move(cb);
|
|
highWaterMark_ = highWaterMark;
|
|
}
|
|
|
|
Buffer* inputBuffer() { return &inputBuffer_; }
|
|
Buffer* outputBuffer() { return &outputBuffer_; }
|
|
|
|
private:
|
|
/*
|
|
* Sets the internal state of the connection.
|
|
*/
|
|
void setState(StateE s)
|
|
{
|
|
state_ = s;
|
|
}
|
|
|
|
/*
|
|
* Handles readable events on the socket.
|
|
*/
|
|
void handleRead()
|
|
{
|
|
loop_->assertInLoopThread();
|
|
int savedErrno = 0;
|
|
ssize_t n = inputBuffer_.readFd(socket_.fd(), &savedErrno);
|
|
if (n > 0) {
|
|
LOG_TRACE << "TcpConnection " << name_ << " read " << n << " bytes";
|
|
if (messageCallback_) {
|
|
messageCallback_(shared_from_this(), inputBuffer_);
|
|
}
|
|
} else if (n == 0) {
|
|
LOG_DEBUG << "TcpConnection " << name_ << " peer closed";
|
|
handleClose();
|
|
} else {
|
|
errno = savedErrno;
|
|
LOG_ERROR << "TcpConnection " << name_ << " read error: " << strerror(savedErrno);
|
|
handleError();
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Handles writable events on the socket.
|
|
*/
|
|
void handleWrite()
|
|
{
|
|
loop_->assertInLoopThread();
|
|
if (channel_->isWriting()) {
|
|
ssize_t n = socket_.write(outputBuffer_.peek(), outputBuffer_.readableBytes());
|
|
if (n > 0) {
|
|
outputBuffer_.retrieve(n);
|
|
LOG_TRACE << "TcpConnection " << name_ << " wrote " << n << " bytes, "
|
|
<< outputBuffer_.readableBytes() << " bytes left";
|
|
if (outputBuffer_.readableBytes() == 0) {
|
|
channel_->disableWriting();
|
|
if (writeCompleteCallback_) {
|
|
loop_->queueInLoop([self = shared_from_this()]() {
|
|
self->writeCompleteCallback_(self);
|
|
});
|
|
}
|
|
if (state_ == kDisconnecting) {
|
|
shutdownInLoop();
|
|
}
|
|
}
|
|
} else {
|
|
LOG_ERROR << "TcpConnection " << name_ << " write error: " << strerror(errno);
|
|
}
|
|
} else {
|
|
LOG_TRACE << "TcpConnection " << name_ << " not writing, ignore";
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Handles connection close events.
|
|
*/
|
|
void handleClose()
|
|
{
|
|
loop_->assertInLoopThread();
|
|
LOG_DEBUG << "TcpConnection " << name_ << " state=" << stateToString();
|
|
assert(state_ == kConnected || state_ == kDisconnecting);
|
|
setState(kDisconnected);
|
|
channel_->disableAll();
|
|
auto guardThis = shared_from_this();
|
|
if (connectionCallback_) {
|
|
connectionCallback_(guardThis);
|
|
}
|
|
if (closeCallback_) {
|
|
closeCallback_(guardThis);
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Handles socket error events.
|
|
*/
|
|
void handleError()
|
|
{
|
|
int err = socket_.getSocketError();
|
|
LOG_ERROR << "TcpConnection " << name_ << " SO_ERROR=" << err << " " << strerror(err);
|
|
handleClose();
|
|
}
|
|
|
|
/*
|
|
* Sends data within the event loop.
|
|
*/
|
|
void sendInLoop(const char* data, size_t len)
|
|
{
|
|
loop_->assertInLoopThread();
|
|
ssize_t nwrote = 0;
|
|
size_t remaining = len;
|
|
bool faultError = false;
|
|
if (state_ == kDisconnected) {
|
|
LOG_WARN << "TcpConnection " << name_ << " disconnected, give up writing";
|
|
return;
|
|
}
|
|
if (!channel_->isWriting() && outputBuffer_.readableBytes() == 0) {
|
|
nwrote = socket_.write(data, len);
|
|
if (nwrote >= 0) {
|
|
remaining = len - nwrote;
|
|
if (remaining == 0 && writeCompleteCallback_) {
|
|
loop_->queueInLoop([self = shared_from_this()]() {
|
|
self->writeCompleteCallback_(self);
|
|
});
|
|
}
|
|
} else {
|
|
nwrote = 0;
|
|
if (errno != EWOULDBLOCK) {
|
|
LOG_ERROR << "TcpConnection " << name_ << " send error: " << strerror(errno);
|
|
if (errno == EPIPE || errno == ECONNRESET) {
|
|
faultError = true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
assert(remaining <= len);
|
|
if (!faultError && remaining > 0) {
|
|
size_t oldLen = outputBuffer_.readableBytes();
|
|
if (oldLen + remaining >= highWaterMark_ &&
|
|
oldLen < highWaterMark_ &&
|
|
highWaterMarkCallback_) {
|
|
loop_->queueInLoop([self = shared_from_this(), mark = oldLen + remaining]() {
|
|
self->highWaterMarkCallback_(self, mark);
|
|
});
|
|
}
|
|
outputBuffer_.append(data + nwrote, remaining);
|
|
if (!channel_->isWriting()) {
|
|
channel_->enableWriting();
|
|
}
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Shuts down the connection within the event loop.
|
|
*/
|
|
void shutdownInLoop()
|
|
{
|
|
loop_->assertInLoopThread();
|
|
if (!channel_->isWriting()) {
|
|
socket_.shutdownWrite();
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Forcibly closes the connection within the event loop.
|
|
*/
|
|
void forceCloseInLoop()
|
|
{
|
|
loop_->assertInLoopThread();
|
|
if (state_ == kConnected || state_ == kDisconnecting) {
|
|
handleClose();
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Converts the current state to a string for logging.
|
|
*/
|
|
std::string stateToString() const
|
|
{
|
|
switch (state_) {
|
|
case kDisconnected: return "kDisconnected";
|
|
case kConnecting: return "kConnecting";
|
|
case kConnected: return "kConnected";
|
|
case kDisconnecting: return "kDisconnecting";
|
|
default: return "unknown state";
|
|
}
|
|
}
|
|
|
|
EventLoop* loop_;
|
|
Socket socket_;
|
|
std::unique_ptr<Channel> channel_;
|
|
InetAddress localAddr_;
|
|
InetAddress peerAddr_;
|
|
std::string name_;
|
|
StateE state_;
|
|
Buffer inputBuffer_;
|
|
Buffer outputBuffer_;
|
|
MessageCallback messageCallback_;
|
|
ConnectionCallback connectionCallback_;
|
|
ConnectionCallback closeCallback_;
|
|
WriteCompleteCallback writeCompleteCallback_;
|
|
HighWaterMarkCallback highWaterMarkCallback_;
|
|
size_t highWaterMark_;
|
|
};
|
|
|
|
} // namespace reactor
|