inet address improvements

This commit is contained in:
Sky Johnson 2025-06-27 20:15:08 -05:00
parent 86eff0b230
commit 6df70256ec
2 changed files with 153 additions and 81 deletions

View File

@ -5,144 +5,202 @@
#include <arpa/inet.h> #include <arpa/inet.h>
#include <string> #include <string>
#include <cstring> #include <cstring>
#include <variant> // Use std::variant
namespace reactor { namespace reactor
{
class InetAddress class InetAddress
{ {
private: private:
union std::variant<sockaddr_in, sockaddr_in6> addr_;
{
sockaddr_in addr4_;
sockaddr_in6 addr6_;
};
bool isIpV6_;
public: public:
/*
* Constructs an address.
*/
explicit InetAddress(uint16_t port = 0, bool ipv6 = false, bool loopback = false) explicit InetAddress(uint16_t port = 0, bool ipv6 = false, bool loopback = false)
: isIpV6_(ipv6)
{ {
if (ipv6) { if (ipv6)
memset(&addr6_, 0, sizeof(addr6_)); {
addr6_.sin6_family = AF_INET6; sockaddr_in6 addr6;
addr6_.sin6_addr = loopback ? in6addr_loopback : in6addr_any; memset(&addr6, 0, sizeof(addr6));
addr6_.sin6_port = htons(port); addr6.sin6_family = AF_INET6;
} else { addr6.sin6_addr = loopback ? in6addr_loopback : in6addr_any;
memset(&addr4_, 0, sizeof(addr4_)); addr6.sin6_port = htons(port);
addr4_.sin_family = AF_INET; addr_ = addr6;
addr4_.sin_addr.s_addr = htonl(loopback ? INADDR_LOOPBACK : INADDR_ANY); }
addr4_.sin_port = htons(port); else
{
sockaddr_in addr4;
memset(&addr4, 0, sizeof(addr4));
addr4.sin_family = AF_INET;
addr4.sin_addr.s_addr = htonl(loopback ? INADDR_LOOPBACK : INADDR_ANY);
addr4.sin_port = htons(port);
addr_ = addr4;
} }
LOG_TRACE << "InetAddress created: " << toIpPort(); LOG_TRACE << "InetAddress created: " << toIpPort();
} }
InetAddress(const std::string& ip, uint16_t port) /*
* Constructs an address from an IP and port.
*/
InetAddress(const std::string &ip, uint16_t port)
{ {
if (ip.find(':') != std::string::npos) { if (ip.find(':') != std::string::npos)
isIpV6_ = true; {
memset(&addr6_, 0, sizeof(addr6_)); sockaddr_in6 addr6;
addr6_.sin6_family = AF_INET6; memset(&addr6, 0, sizeof(addr6));
addr6_.sin6_port = htons(port); addr6.sin6_family = AF_INET6;
if (inet_pton(AF_INET6, ip.c_str(), &addr6_.sin6_addr) <= 0) { addr6.sin6_port = htons(port);
if (inet_pton(AF_INET6, ip.c_str(), &addr6.sin6_addr) <= 0)
{
LOG_ERROR << "Invalid IPv6 address: " << ip; LOG_ERROR << "Invalid IPv6 address: " << ip;
} }
} else { addr_ = addr6;
isIpV6_ = false; }
memset(&addr4_, 0, sizeof(addr4_)); else
addr4_.sin_family = AF_INET; {
addr4_.sin_port = htons(port); sockaddr_in addr4;
if (inet_pton(AF_INET, ip.c_str(), &addr4_.sin_addr) <= 0) { memset(&addr4, 0, sizeof(addr4));
addr4.sin_family = AF_INET;
addr4.sin_port = htons(port);
if (inet_pton(AF_INET, ip.c_str(), &addr4.sin_addr) <= 0)
{
LOG_ERROR << "Invalid IPv4 address: " << ip; LOG_ERROR << "Invalid IPv4 address: " << ip;
} }
addr_ = addr4;
} }
LOG_TRACE << "InetAddress created from ip:port: " << toIpPort(); LOG_TRACE << "InetAddress created from ip:port: " << toIpPort();
} }
explicit InetAddress(const sockaddr_in& addr) : addr4_(addr), isIpV6_(false) /*
* Constructs an address from a sockaddr_in struct.
*/
explicit InetAddress(const sockaddr_in &addr) : addr_(addr)
{ {
LOG_TRACE << "InetAddress created from sockaddr_in: " << toIpPort(); LOG_TRACE << "InetAddress created from sockaddr_in: " << toIpPort();
} }
explicit InetAddress(const sockaddr_in6& addr) : addr6_(addr), isIpV6_(true) /*
* Constructs an address from a sockaddr_in6 struct.
*/
explicit InetAddress(const sockaddr_in6 &addr) : addr_(addr)
{ {
LOG_TRACE << "InetAddress created from sockaddr_in6: " << toIpPort(); LOG_TRACE << "InetAddress created from sockaddr_in6: " << toIpPort();
} }
const sockaddr* getSockAddr() const const sockaddr *getSockAddr() const
{ {
if (isIpV6_) { // std::visit gets the pointer from the active variant member
return reinterpret_cast<const sockaddr*>(&addr6_); return std::visit(
} else { [](const auto &addr)
return reinterpret_cast<const sockaddr*>(&addr4_); {
} return reinterpret_cast<const sockaddr *>(&addr);
},
addr_);
} }
socklen_t getSockLen() const { return isIpV6_ ? sizeof(addr6_) : sizeof(addr4_); } socklen_t getSockLen() const
bool isIpV6() const { return isIpV6_; } {
uint16_t port() const { return ntohs(isIpV6_ ? addr6_.sin6_port : addr4_.sin_port); } return std::visit([](const auto &addr)
{ return sizeof(addr); }, addr_);
}
bool isIpV6() const { return std::holds_alternative<sockaddr_in6>(addr_); }
uint16_t port() const
{
// Use if constexpr to access members with different names
return std::visit(
[](const auto &addr)
{
if constexpr (std::is_same_v<std::decay_t<decltype(addr)>, sockaddr_in>)
{
return ntohs(addr.sin_port);
}
else
{
return ntohs(addr.sin6_port);
}
},
addr_);
}
std::string toIp() const std::string toIp() const
{ {
char buf[INET6_ADDRSTRLEN]; char buf[INET6_ADDRSTRLEN];
if (isIpV6_) { std::visit(
inet_ntop(AF_INET6, &addr6_.sin6_addr, buf, sizeof(buf)); [&buf](const auto &addr)
} else { {
inet_ntop(AF_INET, &addr4_.sin_addr, buf, sizeof(buf)); if constexpr (std::is_same_v<std::decay_t<decltype(addr)>, sockaddr_in>)
} {
inet_ntop(AF_INET, &addr.sin_addr, buf, sizeof(buf));
}
else
{
inet_ntop(AF_INET6, &addr.sin6_addr, buf, sizeof(buf));
}
},
addr_);
return std::string(buf); return std::string(buf);
} }
std::string toIpPort() const std::string toIpPort() const
{ {
return isIpV6_ ? "[" + toIp() + "]:" + std::to_string(port()) return isIpV6() ? "[" + toIp() + "]:" + std::to_string(port())
: toIp() + ":" + std::to_string(port()); : toIp() + ":" + std::to_string(port());
} }
bool operator==(const InetAddress& other) const bool operator==(const InetAddress &other) const
{ {
if (isIpV6_ != other.isIpV6_) return false; if (addr_.index() != other.addr_.index()) return false;
return std::visit(
if (isIpV6_) { [&other](const auto &self_addr)
return memcmp(&addr6_, &other.addr6_, sizeof(addr6_)) == 0; {
} else { const auto &other_addr = std::get<std::decay_t<decltype(self_addr)>>(other.addr_);
return memcmp(&addr4_, &other.addr4_, sizeof(addr4_)) == 0; return memcmp(&self_addr, &other_addr, sizeof(self_addr)) == 0;
} },
addr_);
} }
bool operator!=(const InetAddress& other) const bool operator!=(const InetAddress &other) const
{ {
return !(*this == other); return !(*this == other);
} }
bool operator<(const InetAddress& other) const bool operator<(const InetAddress &other) const
{ {
if (isIpV6_ != other.isIpV6_) { if (isIpV6() != other.isIpV6())
return !isIpV6_; {
return !isIpV6();
} }
if (isIpV6_) { return std::visit(
return memcmp(&addr6_, &other.addr6_, sizeof(addr6_)) < 0; [&other](const auto &self_addr)
} else { {
return memcmp(&addr4_, &other.addr4_, sizeof(addr4_)) < 0; const auto &other_addr = std::get<std::decay_t<decltype(self_addr)>>(other.addr_);
} return memcmp(&self_addr, &other_addr, sizeof(self_addr)) < 0;
},
addr_);
} }
std::string familyToString() const std::string familyToString() const
{ {
return isIpV6_ ? "IPv6" : "IPv4"; return isIpV6() ? "IPv6" : "IPv4";
} }
static bool resolve(const std::string& hostname, InetAddress& result) static bool resolve(const std::string &hostname, InetAddress &result)
{ {
// Simple resolution - in a real implementation you'd use getaddrinfo // Simple resolution - in a real implementation you'd use getaddrinfo
if (hostname == "localhost") { if (hostname == "localhost")
{
result = InetAddress(0, false, true); result = InetAddress(0, false, true);
return true; return true;
} }
// Try to parse as IP address directly // Try to parse as IP address directly
InetAddress addr(hostname, 0); InetAddress addr(hostname, 0);
if (addr.toIp() != "0.0.0.0" && addr.toIp() != "::") { if (addr.toIp() != "0.0.0.0" && addr.toIp() != "::")
{
result = addr; result = addr;
return true; return true;
} }
@ -150,21 +208,34 @@ public:
LOG_WARN << "Could not resolve hostname: " << hostname; LOG_WARN << "Could not resolve hostname: " << hostname;
return false; return false;
} }
// Make variant accessible for hashing
const std::variant<sockaddr_in, sockaddr_in6> &getVariant() const
{
return addr_;
}
}; };
} // namespace reactor } // namespace reactor
namespace std { namespace std
template<> {
template <>
struct hash<reactor::InetAddress> struct hash<reactor::InetAddress>
{ {
size_t operator()(const reactor::InetAddress& addr) const size_t operator()(const reactor::InetAddress &addr) const
{ {
size_t seed = 0; // Hash the raw bytes of the underlying struct for better performance
reactor::hashCombine(seed, addr.toIp()); // than converting to a string first.
reactor::hashCombine(seed, addr.port()); return std::visit(
reactor::hashCombine(seed, addr.isIpV6()); [](const auto &a)
return seed; {
size_t seed = 0;
std::string_view bytes(reinterpret_cast<const char *>(&a), sizeof(a));
reactor::hashCombine(seed, std::hash<std::string_view>{}(bytes));
return seed;
},
addr.getVariant());
} }
}; };
} // namespace std } // namespace std

View File

@ -74,9 +74,10 @@ void test_address_ordering()
reactor::InetAddress addr3("192.168.1.1", 8080); reactor::InetAddress addr3("192.168.1.1", 8080);
reactor::InetAddress addr4("::1", 8080); reactor::InetAddress addr4("::1", 8080);
std::cout << "addr1 < addr2: " << (addr1 < addr2) << "\n"; // Use asserts for proper testing
std::cout << "addr1 < addr3: " << (addr1 < addr3) << "\n"; assert(addr1 < addr2);
std::cout << "addr1 < addr4: " << (addr1 < addr4) << "\n"; assert(addr1 < addr3);
assert(addr1 < addr4); // IPv4 is less than IPv6 in this implementation
std::cout << "✓ Address ordering passed\n"; std::cout << "✓ Address ordering passed\n";
} }