diff --git a/lib/InetAddress.hpp b/lib/InetAddress.hpp index 23216be..10767a9 100644 --- a/lib/InetAddress.hpp +++ b/lib/InetAddress.hpp @@ -5,144 +5,202 @@ #include #include #include +#include // Use std::variant -namespace reactor { +namespace reactor +{ class InetAddress { private: - union - { - sockaddr_in addr4_; - sockaddr_in6 addr6_; - }; - bool isIpV6_; + std::variant addr_; public: + /* + * Constructs an address. + */ explicit InetAddress(uint16_t port = 0, bool ipv6 = false, bool loopback = false) - : isIpV6_(ipv6) { - if (ipv6) { - memset(&addr6_, 0, sizeof(addr6_)); - addr6_.sin6_family = AF_INET6; - addr6_.sin6_addr = loopback ? in6addr_loopback : in6addr_any; - addr6_.sin6_port = htons(port); - } else { - memset(&addr4_, 0, sizeof(addr4_)); - addr4_.sin_family = AF_INET; - addr4_.sin_addr.s_addr = htonl(loopback ? INADDR_LOOPBACK : INADDR_ANY); - addr4_.sin_port = htons(port); + if (ipv6) + { + sockaddr_in6 addr6; + memset(&addr6, 0, sizeof(addr6)); + addr6.sin6_family = AF_INET6; + addr6.sin6_addr = loopback ? in6addr_loopback : in6addr_any; + addr6.sin6_port = htons(port); + addr_ = addr6; + } + else + { + sockaddr_in addr4; + memset(&addr4, 0, sizeof(addr4)); + addr4.sin_family = AF_INET; + addr4.sin_addr.s_addr = htonl(loopback ? INADDR_LOOPBACK : INADDR_ANY); + addr4.sin_port = htons(port); + addr_ = addr4; } LOG_TRACE << "InetAddress created: " << toIpPort(); } - InetAddress(const std::string& ip, uint16_t port) + /* + * Constructs an address from an IP and port. + */ + InetAddress(const std::string &ip, uint16_t port) { - if (ip.find(':') != std::string::npos) { - isIpV6_ = true; - memset(&addr6_, 0, sizeof(addr6_)); - addr6_.sin6_family = AF_INET6; - addr6_.sin6_port = htons(port); - if (inet_pton(AF_INET6, ip.c_str(), &addr6_.sin6_addr) <= 0) { + if (ip.find(':') != std::string::npos) + { + sockaddr_in6 addr6; + memset(&addr6, 0, sizeof(addr6)); + addr6.sin6_family = AF_INET6; + addr6.sin6_port = htons(port); + if (inet_pton(AF_INET6, ip.c_str(), &addr6.sin6_addr) <= 0) + { LOG_ERROR << "Invalid IPv6 address: " << ip; } - } else { - isIpV6_ = false; - memset(&addr4_, 0, sizeof(addr4_)); - addr4_.sin_family = AF_INET; - addr4_.sin_port = htons(port); - if (inet_pton(AF_INET, ip.c_str(), &addr4_.sin_addr) <= 0) { + addr_ = addr6; + } + else + { + sockaddr_in addr4; + memset(&addr4, 0, sizeof(addr4)); + addr4.sin_family = AF_INET; + addr4.sin_port = htons(port); + if (inet_pton(AF_INET, ip.c_str(), &addr4.sin_addr) <= 0) + { LOG_ERROR << "Invalid IPv4 address: " << ip; } + addr_ = addr4; } LOG_TRACE << "InetAddress created from ip:port: " << toIpPort(); } - explicit InetAddress(const sockaddr_in& addr) : addr4_(addr), isIpV6_(false) + /* + * Constructs an address from a sockaddr_in struct. + */ + explicit InetAddress(const sockaddr_in &addr) : addr_(addr) { LOG_TRACE << "InetAddress created from sockaddr_in: " << toIpPort(); } - explicit InetAddress(const sockaddr_in6& addr) : addr6_(addr), isIpV6_(true) + /* + * Constructs an address from a sockaddr_in6 struct. + */ + explicit InetAddress(const sockaddr_in6 &addr) : addr_(addr) { LOG_TRACE << "InetAddress created from sockaddr_in6: " << toIpPort(); } - const sockaddr* getSockAddr() const + const sockaddr *getSockAddr() const { - if (isIpV6_) { - return reinterpret_cast(&addr6_); - } else { - return reinterpret_cast(&addr4_); - } + // std::visit gets the pointer from the active variant member + return std::visit( + [](const auto &addr) + { + return reinterpret_cast(&addr); + }, + addr_); } - socklen_t getSockLen() const { return isIpV6_ ? sizeof(addr6_) : sizeof(addr4_); } - bool isIpV6() const { return isIpV6_; } - uint16_t port() const { return ntohs(isIpV6_ ? addr6_.sin6_port : addr4_.sin_port); } + socklen_t getSockLen() const + { + return std::visit([](const auto &addr) + { return sizeof(addr); }, addr_); + } + bool isIpV6() const { return std::holds_alternative(addr_); } + uint16_t port() const + { + // Use if constexpr to access members with different names + return std::visit( + [](const auto &addr) + { + if constexpr (std::is_same_v, sockaddr_in>) + { + return ntohs(addr.sin_port); + } + else + { + return ntohs(addr.sin6_port); + } + }, + addr_); + } std::string toIp() const { char buf[INET6_ADDRSTRLEN]; - if (isIpV6_) { - inet_ntop(AF_INET6, &addr6_.sin6_addr, buf, sizeof(buf)); - } else { - inet_ntop(AF_INET, &addr4_.sin_addr, buf, sizeof(buf)); - } + std::visit( + [&buf](const auto &addr) + { + if constexpr (std::is_same_v, sockaddr_in>) + { + inet_ntop(AF_INET, &addr.sin_addr, buf, sizeof(buf)); + } + else + { + inet_ntop(AF_INET6, &addr.sin6_addr, buf, sizeof(buf)); + } + }, + addr_); return std::string(buf); } std::string toIpPort() const { - return isIpV6_ ? "[" + toIp() + "]:" + std::to_string(port()) - : toIp() + ":" + std::to_string(port()); + return isIpV6() ? "[" + toIp() + "]:" + std::to_string(port()) + : toIp() + ":" + std::to_string(port()); } - bool operator==(const InetAddress& other) const + bool operator==(const InetAddress &other) const { - if (isIpV6_ != other.isIpV6_) return false; - - if (isIpV6_) { - return memcmp(&addr6_, &other.addr6_, sizeof(addr6_)) == 0; - } else { - return memcmp(&addr4_, &other.addr4_, sizeof(addr4_)) == 0; - } + if (addr_.index() != other.addr_.index()) return false; + return std::visit( + [&other](const auto &self_addr) + { + const auto &other_addr = std::get>(other.addr_); + return memcmp(&self_addr, &other_addr, sizeof(self_addr)) == 0; + }, + addr_); } - bool operator!=(const InetAddress& other) const + bool operator!=(const InetAddress &other) const { return !(*this == other); } - bool operator<(const InetAddress& other) const + bool operator<(const InetAddress &other) const { - if (isIpV6_ != other.isIpV6_) { - return !isIpV6_; + if (isIpV6() != other.isIpV6()) + { + return !isIpV6(); } - if (isIpV6_) { - return memcmp(&addr6_, &other.addr6_, sizeof(addr6_)) < 0; - } else { - return memcmp(&addr4_, &other.addr4_, sizeof(addr4_)) < 0; - } + return std::visit( + [&other](const auto &self_addr) + { + const auto &other_addr = std::get>(other.addr_); + return memcmp(&self_addr, &other_addr, sizeof(self_addr)) < 0; + }, + addr_); } std::string familyToString() const { - return isIpV6_ ? "IPv6" : "IPv4"; + return isIpV6() ? "IPv6" : "IPv4"; } - static bool resolve(const std::string& hostname, InetAddress& result) + static bool resolve(const std::string &hostname, InetAddress &result) { // Simple resolution - in a real implementation you'd use getaddrinfo - if (hostname == "localhost") { + if (hostname == "localhost") + { result = InetAddress(0, false, true); return true; } // Try to parse as IP address directly InetAddress addr(hostname, 0); - if (addr.toIp() != "0.0.0.0" && addr.toIp() != "::") { + if (addr.toIp() != "0.0.0.0" && addr.toIp() != "::") + { result = addr; return true; } @@ -150,21 +208,34 @@ public: LOG_WARN << "Could not resolve hostname: " << hostname; return false; } + + // Make variant accessible for hashing + const std::variant &getVariant() const + { + return addr_; + } }; } // namespace reactor -namespace std { -template<> +namespace std +{ +template <> struct hash { - size_t operator()(const reactor::InetAddress& addr) const + size_t operator()(const reactor::InetAddress &addr) const { - size_t seed = 0; - reactor::hashCombine(seed, addr.toIp()); - reactor::hashCombine(seed, addr.port()); - reactor::hashCombine(seed, addr.isIpV6()); - return seed; + // Hash the raw bytes of the underlying struct for better performance + // than converting to a string first. + return std::visit( + [](const auto &a) + { + size_t seed = 0; + std::string_view bytes(reinterpret_cast(&a), sizeof(a)); + reactor::hashCombine(seed, std::hash{}(bytes)); + return seed; + }, + addr.getVariant()); } }; } // namespace std diff --git a/tests/test_inet_address.cpp b/tests/test_inet_address.cpp index 9a9826a..8a0ebdf 100644 --- a/tests/test_inet_address.cpp +++ b/tests/test_inet_address.cpp @@ -74,9 +74,10 @@ void test_address_ordering() reactor::InetAddress addr3("192.168.1.1", 8080); reactor::InetAddress addr4("::1", 8080); - std::cout << "addr1 < addr2: " << (addr1 < addr2) << "\n"; - std::cout << "addr1 < addr3: " << (addr1 < addr3) << "\n"; - std::cout << "addr1 < addr4: " << (addr1 < addr4) << "\n"; + // Use asserts for proper testing + assert(addr1 < addr2); + assert(addr1 < addr3); + assert(addr1 < addr4); // IPv4 is less than IPv6 in this implementation std::cout << "✓ Address ordering passed\n"; }