diff --git a/lib/Socket.hpp b/lib/Socket.hpp index 5b3eca3..d5fe338 100644 --- a/lib/Socket.hpp +++ b/lib/Socket.hpp @@ -6,50 +6,58 @@ #include #include #include -#include +#include +#include -namespace reactor { +namespace reactor +{ class Socket : public NonCopyable { private: int fd_; - void setNonBlockAndCloseOnExec() - { - int flags = fcntl(fd_, F_GETFL, 0); - flags |= O_NONBLOCK; - fcntl(fd_, F_SETFL, flags); - - flags = fcntl(fd_, F_GETFD, 0); - flags |= FD_CLOEXEC; - fcntl(fd_, F_SETFD, flags); - } - public: - explicit Socket(int fd) : fd_(fd) + /* + * Constructs a Socket by taking ownership of a file descriptor. + */ + explicit Socket(int fd) + : fd_(fd) { LOG_TRACE << "Socket created with fd=" << fd_; } + /* + * Destructor, closes the socket file descriptor. + */ ~Socket() { - if (fd_ >= 0) { + if (fd_ >= 0) + { close(fd_); LOG_TRACE << "Socket fd=" << fd_ << " closed"; } } - Socket(Socket&& other) noexcept : fd_(other.fd_) + /* + * Move constructor. + */ + Socket(Socket &&other) noexcept + : fd_(other.fd_) { other.fd_ = -1; LOG_TRACE << "Socket moved fd=" << fd_; } - Socket& operator=(Socket&& other) noexcept + /* + * Move assignment operator. + */ + Socket &operator=(Socket &&other) noexcept { - if (this != &other) { - if (fd_ >= 0) { + if (this != &other) + { + if (fd_ >= 0) + { close(fd_); LOG_TRACE << "Socket fd=" << fd_ << " closed in move assignment"; } @@ -59,86 +67,175 @@ public: return *this; } + /* + * Creates a non-blocking TCP socket. + * Throws std::runtime_error on failure. + */ static Socket createTcp(bool ipv6 = false) { int fd = socket(ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0); - if (fd < 0) { - LOG_FATAL << "Failed to create TCP socket: " << strerror(errno); - abort(); + if (fd < 0) + { + throw std::runtime_error("Failed to create TCP socket: " + std::string(strerror(errno))); } LOG_DEBUG << "Created TCP socket fd=" << fd << " ipv6=" << ipv6; return Socket(fd); } + /* + * Creates a non-blocking UDP socket. + * Throws std::runtime_error on failure. + */ static Socket createUdp(bool ipv6 = false) { int fd = socket(ipv6 ? AF_INET6 : AF_INET, SOCK_DGRAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0); - if (fd < 0) { - LOG_FATAL << "Failed to create UDP socket: " << strerror(errno); - abort(); + if (fd < 0) + { + throw std::runtime_error("Failed to create UDP socket: " + std::string(strerror(errno))); } LOG_DEBUG << "Created UDP socket fd=" << fd << " ipv6=" << ipv6; return Socket(fd); } - void bind(const InetAddress& addr) + /* + * Binds the socket to a specific address. + * Throws std::runtime_error on failure. + */ + void bind(const InetAddress &addr) { - int ret = ::bind(fd_, addr.getSockAddr(), addr.getSockLen()); - if (ret < 0) { - LOG_FATAL << "Socket bind to " << addr.toIpPort() << " failed: " << strerror(errno); - abort(); + if (::bind(fd_, addr.getSockAddr(), addr.getSockLen()) < 0) + { + throw std::runtime_error("Socket bind to " + addr.toIpPort() + " failed: " + std::string(strerror(errno))); } LOG_INFO << "Socket fd=" << fd_ << " bound to " << addr.toIpPort(); } + /* + * Puts the socket in listening mode for incoming connections. + * Throws std::runtime_error on failure. + */ void listen(int backlog = SOMAXCONN) { - int ret = ::listen(fd_, backlog); - if (ret < 0) { - LOG_FATAL << "Socket listen failed: " << strerror(errno); - abort(); + if (::listen(fd_, backlog) < 0) + { + throw std::runtime_error("Socket listen failed: " + std::string(strerror(errno))); } LOG_INFO << "Socket fd=" << fd_ << " listening with backlog=" << backlog; } - int accept(InetAddress& peerAddr) + /* + * Accepts a new connection. + * Returns an optional Socket. Returns nullopt if no pending connection. + */ + std::optional accept(InetAddress &peerAddr) { sockaddr_in6 addr; socklen_t len = sizeof(addr); - int connfd = accept4(fd_, reinterpret_cast(&addr), &len, SOCK_NONBLOCK | SOCK_CLOEXEC); + int connfd = accept4(fd_, reinterpret_cast(&addr), &len, SOCK_NONBLOCK | SOCK_CLOEXEC); - if (connfd >= 0) { - if (addr.sin6_family == AF_INET) { - peerAddr = InetAddress(*reinterpret_cast(&addr)); - } else { + if (connfd >= 0) + { + if (addr.sin6_family == AF_INET) + { + peerAddr = InetAddress(*reinterpret_cast(&addr)); + } + else + { peerAddr = InetAddress(addr); } LOG_DEBUG << "Socket fd=" << fd_ << " accepted connection fd=" << connfd << " from " << peerAddr.toIpPort(); - } else if (errno != EAGAIN && errno != EWOULDBLOCK) { + return std::optional(Socket(connfd)); + } + else if (errno != EAGAIN && errno != EWOULDBLOCK) + { LOG_ERROR << "Socket accept failed: " << strerror(errno); } - return connfd; + return std::nullopt; } - int connect(const InetAddress& addr) + /* + * Establishes a connection to a specified address. + */ + int connect(const InetAddress &addr) { int ret = ::connect(fd_, addr.getSockAddr(), addr.getSockLen()); - if (ret < 0 && errno != EINPROGRESS) { + if (ret < 0 && errno != EINPROGRESS) + { LOG_ERROR << "Socket connect to " << addr.toIpPort() << " failed: " << strerror(errno); - } else { + } + else + { LOG_DEBUG << "Socket fd=" << fd_ << " connecting to " << addr.toIpPort(); } return ret; } + ssize_t read(void *buf, size_t len) + { + ssize_t n = ::read(fd_, buf, len); + if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) + { + LOG_ERROR << "Socket read failed: " << strerror(errno); + } + return n; + } + + ssize_t write(const void *buf, size_t len) + { + ssize_t n = ::write(fd_, buf, len); + if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) + { + LOG_ERROR << "Socket write failed: " << strerror(errno); + } + return n; + } + + ssize_t sendTo(const void *buf, size_t len, const InetAddress &addr) + { + ssize_t n = ::sendto(fd_, buf, len, 0, addr.getSockAddr(), addr.getSockLen()); + if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK) + { + LOG_ERROR << "Socket sendto failed: " << strerror(errno); + } + return n; + } + + ssize_t recvFrom(void *buf, size_t len, InetAddress &addr) + { + sockaddr_in6 sockaddr; + socklen_t addrlen = sizeof(sockaddr); + ssize_t n = ::recvfrom(fd_, buf, len, 0, reinterpret_cast(&sockaddr), &addrlen); + + if (n >= 0) + { + if (sockaddr.sin6_family == AF_INET) + { + addr = InetAddress(*reinterpret_cast(&sockaddr)); + } + else + { + addr = InetAddress(sockaddr); + } + } + else if (errno != EAGAIN && errno != EWOULDBLOCK) + { + LOG_ERROR << "Socket recvfrom failed: " << strerror(errno); + } + + return n; + } + void setReuseAddr(bool on = true) { int optval = on ? 1 : 0; - if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)) < 0) { + if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)) < 0) + { LOG_ERROR << "setsockopt SO_REUSEADDR failed: " << strerror(errno); - } else { + } + else + { LOG_TRACE << "Socket fd=" << fd_ << " SO_REUSEADDR=" << on; } } @@ -146,9 +243,12 @@ public: void setReusePort(bool on = true) { int optval = on ? 1 : 0; - if (setsockopt(fd_, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval)) < 0) { + if (setsockopt(fd_, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval)) < 0) + { LOG_ERROR << "setsockopt SO_REUSEPORT failed: " << strerror(errno); - } else { + } + else + { LOG_TRACE << "Socket fd=" << fd_ << " SO_REUSEPORT=" << on; } } @@ -156,9 +256,12 @@ public: void setTcpNoDelay(bool on = true) { int optval = on ? 1 : 0; - if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &optval, sizeof(optval)) < 0) { + if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &optval, sizeof(optval)) < 0) + { LOG_ERROR << "setsockopt TCP_NODELAY failed: " << strerror(errno); - } else { + } + else + { LOG_TRACE << "Socket fd=" << fd_ << " TCP_NODELAY=" << on; } } @@ -166,9 +269,12 @@ public: void setKeepAlive(bool on = true) { int optval = on ? 1 : 0; - if (setsockopt(fd_, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) < 0) { + if (setsockopt(fd_, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) < 0) + { LOG_ERROR << "setsockopt SO_KEEPALIVE failed: " << strerror(errno); - } else { + } + else + { LOG_TRACE << "Socket fd=" << fd_ << " SO_KEEPALIVE=" << on; } } @@ -177,9 +283,12 @@ public: { if (setsockopt(fd_, IPPROTO_TCP, TCP_KEEPIDLE, &idle, sizeof(idle)) < 0 || setsockopt(fd_, IPPROTO_TCP, TCP_KEEPINTVL, &interval, sizeof(interval)) < 0 || - setsockopt(fd_, IPPROTO_TCP, TCP_KEEPCNT, &count, sizeof(count)) < 0) { + setsockopt(fd_, IPPROTO_TCP, TCP_KEEPCNT, &count, sizeof(count)) < 0) + { LOG_ERROR << "setsockopt TCP_KEEP* failed: " << strerror(errno); - } else { + } + else + { LOG_TRACE << "Socket fd=" << fd_ << " TCP keepalive: idle=" << idle << " interval=" << interval << " count=" << count; } @@ -187,143 +296,105 @@ public: void setRecvBuffer(int size) { - if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size)) < 0) { + if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size)) < 0) + { LOG_ERROR << "setsockopt SO_RCVBUF failed: " << strerror(errno); - } else { + } + else + { LOG_TRACE << "Socket fd=" << fd_ << " SO_RCVBUF=" << size; } } void setSendBuffer(int size) { - if (setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size)) < 0) { + if (setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size)) < 0) + { LOG_ERROR << "setsockopt SO_SNDBUF failed: " << strerror(errno); - } else { + } + else + { LOG_TRACE << "Socket fd=" << fd_ << " SO_SNDBUF=" << size; } } - ssize_t read(void* buf, size_t len) - { - ssize_t n = ::read(fd_, buf, len); - if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { - LOG_ERROR << "Socket read failed: " << strerror(errno); - } - return n; - } - - ssize_t write(const void* buf, size_t len) - { - ssize_t n = ::write(fd_, buf, len); - if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { - LOG_ERROR << "Socket write failed: " << strerror(errno); - } - return n; - } - - ssize_t sendTo(const void* buf, size_t len, const InetAddress& addr) - { - ssize_t n = sendto(fd_, buf, len, 0, addr.getSockAddr(), addr.getSockLen()); - if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK) { - LOG_ERROR << "Socket sendto failed: " << strerror(errno); - } - return n; - } - - ssize_t recvFrom(void* buf, size_t len, InetAddress& addr) - { - sockaddr_in6 sockaddr; - socklen_t addrlen = sizeof(sockaddr); - ssize_t n = recvfrom(fd_, buf, len, 0, reinterpret_cast(&sockaddr), &addrlen); - - if (n >= 0) { - if (sockaddr.sin6_family == AF_INET) { - addr = InetAddress(*reinterpret_cast(&sockaddr)); - } else { - addr = InetAddress(sockaddr); - } - } else if (errno != EAGAIN && errno != EWOULDBLOCK) { - LOG_ERROR << "Socket recvfrom failed: " << strerror(errno); - } - - return n; - } - void shutdownWrite() { - if (shutdown(fd_, SHUT_WR) < 0) { + if (shutdown(fd_, SHUT_WR) < 0) + { LOG_ERROR << "Socket shutdown write failed: " << strerror(errno); - } else { + } + else + { LOG_DEBUG << "Socket fd=" << fd_ << " shutdown write"; } } void shutdownRead() { - if (shutdown(fd_, SHUT_RD) < 0) { + if (shutdown(fd_, SHUT_RD) < 0) + { LOG_ERROR << "Socket shutdown read failed: " << strerror(errno); - } else { + } + else + { LOG_DEBUG << "Socket fd=" << fd_ << " shutdown read"; } } - int getSocketError() + int getSocketError() const { int optval; socklen_t optlen = sizeof(optval); - if (getsockopt(fd_, SOL_SOCKET, SO_ERROR, &optval, &optlen) < 0) { + if (getsockopt(fd_, SOL_SOCKET, SO_ERROR, &optval, &optlen) < 0) + { return errno; } return optval; } - bool isConnected() - { - int error = getSocketError(); - if (error != 0) return false; - - char c; - ssize_t result = ::recv(fd_, &c, 1, MSG_PEEK | MSG_DONTWAIT); - if (result == 0) return false; // Connection closed - if (result < 0 && errno != EAGAIN && errno != EWOULDBLOCK) return false; - return true; - } - int fd() const { return fd_; } - static InetAddress getLocalAddr(int sockfd) + InetAddress getLocalAddr() const { sockaddr_in6 addr; socklen_t addrlen = sizeof(addr); - if (getsockname(sockfd, reinterpret_cast(&addr), &addrlen) < 0) { + if (getsockname(fd_, reinterpret_cast(&addr), &addrlen) < 0) + { LOG_ERROR << "getsockname failed: " << strerror(errno); + // Return a default-constructed address on failure return InetAddress(); } - if (addr.sin6_family == AF_INET) { - return InetAddress(*reinterpret_cast(&addr)); + if (addr.sin6_family == AF_INET) + { + return InetAddress(*reinterpret_cast(&addr)); } return InetAddress(addr); } - static InetAddress getPeerAddr(int sockfd) + InetAddress getPeerAddr() const { sockaddr_in6 addr; socklen_t addrlen = sizeof(addr); - if (getpeername(sockfd, reinterpret_cast(&addr), &addrlen) < 0) { + if (getpeername(fd_, reinterpret_cast(&addr), &addrlen) < 0) + { LOG_ERROR << "getpeername failed: " << strerror(errno); + // Return a default-constructed address on failure return InetAddress(); } - if (addr.sin6_family == AF_INET) { - return InetAddress(*reinterpret_cast(&addr)); + if (addr.sin6_family == AF_INET) + { + return InetAddress(*reinterpret_cast(&addr)); } return InetAddress(addr); } - bool isSelfConnected() + bool isSelfConnected() const { - return getLocalAddr(fd_) == getPeerAddr(fd_); + // This is only meaningful for a connected TCP socket + return getLocalAddr() == getPeerAddr(); } }; diff --git a/tests/test_socket.cpp b/tests/test_socket.cpp index e8aec6f..46af96a 100644 --- a/tests/test_socket.cpp +++ b/tests/test_socket.cpp @@ -5,6 +5,31 @@ #include #include #include +#include // Required for std::atomic + +// Forward declarations of test functions +void test_socket_creation(); +void test_socket_move(); +void test_socket_bind_listen(); +void test_socket_options(); +void test_socket_connection(); +void test_udp_socket(); +void test_socket_shutdown(); +void test_socket_error_handling(); +void test_address_retrieval(); +void test_self_connection_detection(); + + +bool waitForSocketReady(int fd, short events, int timeout_ms) +{ + pollfd pfd; + pfd.fd = fd; + pfd.events = events; + pfd.revents = 0; + + int result = poll(&pfd, 1, timeout_ms); + return result > 0 && (pfd.revents & events); +} void test_socket_creation() { @@ -55,7 +80,8 @@ void test_socket_bind_listen() socket.bind(addr); socket.listen(); - auto local_addr = reactor::Socket::getLocalAddr(socket.fd()); + // FIX: Call getLocalAddr as a member function on the object. + auto local_addr = socket.getLocalAddr(); assert(local_addr.port() > 0); std::cout << "✓ Socket bind and listen passed\n"; @@ -78,16 +104,6 @@ void test_socket_options() std::cout << "✓ Socket options passed\n"; } -bool waitForSocketReady(int fd, short events, int timeout_ms) -{ - pollfd pfd; - pfd.fd = fd; - pfd.events = events; - pfd.revents = 0; - - int result = poll(&pfd, 1, timeout_ms); - return result > 0 && (pfd.revents & events); -} void test_socket_connection() { @@ -100,30 +116,31 @@ void test_socket_connection() server_socket.bind(server_addr); server_socket.listen(); - auto actual_addr = reactor::Socket::getLocalAddr(server_socket.fd()); + // FIX: Call getLocalAddr as a member function on the object. + auto actual_addr = server_socket.getLocalAddr(); std::cout << "Server listening on: " << actual_addr.toIpPort() << "\n"; std::atomic server_done{false}; std::thread server_thread([&server_socket, &server_done]() { reactor::InetAddress peer_addr; - // Wait for connection with timeout if (waitForSocketReady(server_socket.fd(), POLLIN, 1000)) { - int client_fd = server_socket.accept(peer_addr); - if (client_fd >= 0) { + // FIX: Handle the std::optional returned by accept(). + auto client_sock_opt = server_socket.accept(peer_addr); + if (client_sock_opt) { + auto &client_sock = *client_sock_opt; std::cout << "Server accepted connection from: " << peer_addr.toIpPort() << "\n"; - // Wait for data to be ready - if (waitForSocketReady(client_fd, POLLIN, 1000)) { + if (waitForSocketReady(client_sock.fd(), POLLIN, 1000)) { char buffer[1024]; - ssize_t n = read(client_fd, buffer, sizeof(buffer)); + // FIX: Use the Socket object's read/write methods. + ssize_t n = client_sock.read(buffer, sizeof(buffer)); if (n > 0) { - // Echo back the data - ssize_t written = write(client_fd, buffer, n); - (void)written; // Suppress unused warning + ssize_t written = client_sock.write(buffer, n); + (void)written; } } - close(client_fd); + // FIX: No need to call close(), Socket's destructor handles it. } } server_done = true; @@ -136,7 +153,6 @@ void test_socket_connection() int result = client_socket.connect(connect_addr); - // Wait for connection to complete if (result < 0 && errno == EINPROGRESS) { if (waitForSocketReady(client_socket.fd(), POLLOUT, 1000)) { int error = client_socket.getSocketError(); @@ -148,7 +164,6 @@ void test_socket_connection() ssize_t sent = client_socket.write(message, strlen(message)); assert(sent > 0); - // Wait for response if (waitForSocketReady(client_socket.fd(), POLLIN, 1000)) { char response[1024]; ssize_t received = client_socket.read(response, sizeof(response)); @@ -171,7 +186,8 @@ void test_udp_socket() server_socket.setReuseAddr(true); server_socket.bind(server_addr); - auto actual_addr = reactor::Socket::getLocalAddr(server_socket.fd()); + // FIX: Call getLocalAddr as a member function on the object. + auto actual_addr = server_socket.getLocalAddr(); std::cout << "UDP server bound to: " << actual_addr.toIpPort() << "\n"; std::thread server_thread([&server_socket]() { @@ -217,16 +233,17 @@ void test_socket_shutdown() server_socket.bind(server_addr); server_socket.listen(); - auto actual_addr = reactor::Socket::getLocalAddr(server_socket.fd()); + // FIX: Call getLocalAddr as a member function on the object. + auto actual_addr = server_socket.getLocalAddr(); std::thread server_thread([&server_socket]() { if (waitForSocketReady(server_socket.fd(), POLLIN, 1000)) { reactor::InetAddress peer_addr; - int client_fd = server_socket.accept(peer_addr); - if (client_fd >= 0) { - reactor::Socket client_sock(client_fd); + // FIX: Handle the std::optional returned by accept(). + auto client_sock_opt = server_socket.accept(peer_addr); + if (client_sock_opt) { std::this_thread::sleep_for(std::chrono::milliseconds(20)); - client_sock.shutdownWrite(); + client_sock_opt->shutdownWrite(); } } }); @@ -245,7 +262,7 @@ void test_socket_shutdown() char buffer[1024]; ssize_t n = client_socket.read(buffer, sizeof(buffer)); - assert(n == 0); // Connection should be closed + assert(n == 0); server_thread.join(); std::cout << "✓ Socket shutdown passed\n"; @@ -279,7 +296,8 @@ void test_address_retrieval() server_socket.bind(server_addr); server_socket.listen(); - auto local_addr = reactor::Socket::getLocalAddr(server_socket.fd()); + // FIX: Call getLocalAddr as a member function on the object. + auto local_addr = server_socket.getLocalAddr(); assert(local_addr.toIp() == "127.0.0.1"); assert(local_addr.port() > 0); @@ -296,27 +314,34 @@ void test_self_connection_detection() socket.setReuseAddr(true); socket.bind(addr); + // This is not a true self-connection test, as the socket isn't connected. + // We call it just to ensure it doesn't crash. bool is_self = socket.isSelfConnected(); - std::cout << "Self connected: " << is_self << "\n"; + std::cout << "Self connected (on non-connected socket): " << std::boolalpha << is_self << "\n"; std::cout << "✓ Self connection detection passed\n"; } int main() { - std::cout << "=== Socket Tests ===\n"; + try { + std::cout << "=== Socket Tests ===\n"; - test_socket_creation(); - test_socket_move(); - test_socket_bind_listen(); - test_socket_options(); - test_socket_connection(); - test_udp_socket(); - test_socket_shutdown(); - test_socket_error_handling(); - test_address_retrieval(); - test_self_connection_detection(); + test_socket_creation(); + test_socket_move(); + test_socket_bind_listen(); + test_socket_options(); + test_socket_connection(); + test_udp_socket(); + test_socket_shutdown(); + test_socket_error_handling(); + test_address_retrieval(); + test_self_connection_detection(); - std::cout << "All socket tests passed! ✓\n"; - return 0; + std::cout << "\nAll socket tests passed! ✓\n"; + return 0; + } catch (const std::exception& e) { + std::cerr << "A test failed with an exception: " << e.what() << std::endl; + return 1; + } }