diff --git a/dbms/src/Access/AllowedClientHosts.cpp b/dbms/src/Access/AllowedClientHosts.cpp index 735411c5657..de720df1fe4 100644 --- a/dbms/src/Access/AllowedClientHosts.cpp +++ b/dbms/src/Access/AllowedClientHosts.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include @@ -23,29 +24,64 @@ namespace ErrorCodes namespace { using IPAddress = Poco::Net::IPAddress; + using IPSubnet = AllowedClientHosts::IPSubnet; + const IPSubnet ALL_ADDRESSES{IPAddress{IPAddress::IPv6}, IPAddress{IPAddress::IPv6}}; - const AllowedClientHosts::IPSubnet ALL_ADDRESSES = AllowedClientHosts::IPSubnet{IPAddress{IPAddress::IPv6}, IPAddress{IPAddress::IPv6}}; - - IPAddress toIPv6(const IPAddress & addr) + const IPAddress & getIPV6Loopback() { - if (addr.family() == IPAddress::IPv6) - return addr; - - if (addr.isLoopback()) - return IPAddress("::1"); - - return IPAddress("::FFFF:" + addr.toString()); + static const IPAddress ip("::1"); + return ip; } - IPAddress maskToIPv6(const IPAddress & mask) + bool isIPV4LoopbackMappedToIPV6(const IPAddress & ip) { - if (mask.family() == IPAddress::IPv6) - return mask; - - return IPAddress(96, IPAddress::IPv6) | toIPv6(mask); + static const IPAddress prefix("::ffff:127.0.0.0"); + /// 104 == 128 - 24, we have to reset the lowest 24 bits of 128 before comparing with `prefix` + /// (IPv4 loopback means any IP from 127.0.0.0 to 127.255.255.255). + return (ip & IPAddress(104, IPAddress::IPv6)) == prefix; } + /// Converts an address to IPv6. + /// The loopback address "127.0.0.1" (or any "127.x.y.z") is converted to "::1". + IPAddress toIPv6(const IPAddress & ip) + { + IPAddress v6; + if (ip.family() == IPAddress::IPv6) + v6 = ip; + else + v6 = IPAddress("::ffff:" + ip.toString()); + // ::ffff:127.XX.XX.XX -> ::1 + if (isIPV4LoopbackMappedToIPV6(v6)) + v6 = getIPV6Loopback(); + + return v6; + } + + /// Converts a subnet to IPv6. + IPSubnet toIPv6(const IPSubnet & subnet) + { + IPSubnet v6; + if (subnet.prefix.family() == IPAddress::IPv6) + v6.prefix = subnet.prefix; + else + v6.prefix = IPAddress("::ffff:" + subnet.prefix.toString()); + + if (subnet.mask.family() == IPAddress::IPv6) + v6.mask = subnet.mask; + else + v6.mask = IPAddress(96, IPAddress::IPv6) | IPAddress("::ffff:" + subnet.mask.toString()); + + v6.prefix = v6.prefix & v6.mask; + + // ::ffff:127.XX.XX.XX -> ::1 + if (isIPV4LoopbackMappedToIPV6(v6.prefix)) + v6 = {getIPV6Loopback(), IPAddress(128, IPAddress::IPv6)}; + + return v6; + } + + /// Helper function for isAddressOfHost(). bool isAddressOfHostImpl(const IPAddress & address, const String & host) { IPAddress addr_v6 = toIPv6(address); @@ -93,15 +129,15 @@ namespace return false; } - - /// Cached version of isAddressOfHostImpl(). We need to cache DNS requests. + /// Whether a specified address is one of the addresses of a specified host. bool isAddressOfHost(const IPAddress & address, const String & host) { + /// We need to cache DNS requests. static SimpleCache cache; return cache(address, host); } - + /// Helper function for isAddressOfLocalhost(). std::vector getAddressesOfLocalhostImpl() { std::vector addresses; @@ -114,7 +150,7 @@ namespace int err = getifaddrs(&ifa_begin); if (err) - return {IPAddress{"127.0.0.1"}, IPAddress{"::1"}}; + return {getIPV6Loopback()}; for (const ifaddrs * ifa = ifa_begin; ifa; ifa = ifa->ifa_next) { @@ -134,15 +170,15 @@ namespace return addresses; } - - /// Checks if a specified address pointers to the localhost. - bool isLocalAddress(const IPAddress & address) + /// Whether a specified address is one of the addresses of the localhost. + bool isAddressOfLocalhost(const IPAddress & address) { + /// We need to cache DNS requests. static const std::vector local_addresses = getAddressesOfLocalhostImpl(); - return boost::range::find(local_addresses, address) != local_addresses.end(); + return boost::range::find(local_addresses, toIPv6(address)) != local_addresses.end(); } - + /// Helper function for getHostByAddress(). String getHostByAddressImpl(const IPAddress & address) { Poco::Net::SocketAddress sock_addr(address, 0); @@ -160,10 +196,10 @@ namespace return host; } - - /// Cached version of getHostByAddressImpl(). We need to cache DNS requests. + /// Returns the host name by its address. String getHostByAddress(const IPAddress & address) { + /// We need to cache DNS requests. static SimpleCache cache; return cache(address); } @@ -203,7 +239,7 @@ AllowedClientHosts::AllowedClientHosts(const AllowedClientHosts & src) AllowedClientHosts & AllowedClientHosts::operator =(const AllowedClientHosts & src) { addresses = src.addresses; - loopback = src.loopback; + localhost = src.localhost; subnets = src.subnets; host_names = src.host_names; host_regexps = src.host_regexps; @@ -212,28 +248,14 @@ AllowedClientHosts & AllowedClientHosts::operator =(const AllowedClientHosts & s } -AllowedClientHosts::AllowedClientHosts(AllowedClientHosts && src) -{ - *this = src; -} - - -AllowedClientHosts & AllowedClientHosts::operator =(AllowedClientHosts && src) -{ - addresses = std::move(src.addresses); - loopback = src.loopback; - subnets = std::move(src.subnets); - host_names = std::move(src.host_names); - host_regexps = std::move(src.host_regexps); - compiled_host_regexps = std::move(src.compiled_host_regexps); - return *this; -} +AllowedClientHosts::AllowedClientHosts(AllowedClientHosts && src) = default; +AllowedClientHosts & AllowedClientHosts::operator =(AllowedClientHosts && src) = default; void AllowedClientHosts::clear() { addresses.clear(); - loopback = false; + localhost = false; subnets.clear(); host_names.clear(); host_regexps.clear(); @@ -250,10 +272,11 @@ bool AllowedClientHosts::empty() const void AllowedClientHosts::addAddress(const IPAddress & address) { IPAddress addr_v6 = toIPv6(address); - if (boost::range::find(addresses, addr_v6) == addresses.end()) - addresses.push_back(addr_v6); + if (boost::range::find(addresses, addr_v6) != addresses.end()) + return; + addresses.push_back(addr_v6); if (addr_v6.isLoopback()) - loopback = true; + localhost = true; } @@ -265,9 +288,7 @@ void AllowedClientHosts::addAddress(const String & address) void AllowedClientHosts::addSubnet(const IPSubnet & subnet) { - IPSubnet subnet_v6; - subnet_v6.prefix = toIPv6(subnet.prefix); - subnet_v6.mask = maskToIPv6(subnet.mask); + IPSubnet subnet_v6 = toIPv6(subnet); if (subnet_v6.mask == IPAddress(128, IPAddress::IPv6)) { @@ -275,8 +296,6 @@ void AllowedClientHosts::addSubnet(const IPSubnet & subnet) return; } - subnet_v6.prefix = subnet_v6.prefix & subnet_v6.mask; - if (boost::range::find(subnets, subnet_v6) == subnets.end()) subnets.push_back(subnet_v6); } @@ -314,8 +333,11 @@ void AllowedClientHosts::addSubnet(const String & subnet) void AllowedClientHosts::addHostName(const String & host_name) { - if (boost::range::find(host_names, host_name) == host_names.end()) - host_names.push_back(host_name); + if (boost::range::find(host_names, host_name) != host_names.end()) + return; + host_names.push_back(host_name); + if (boost::iequals(host_name, "localhost")) + localhost = true; } @@ -360,7 +382,7 @@ bool AllowedClientHosts::contains(const IPAddress & address) const if (boost::range::find(addresses, addr_v6) != addresses.end()) return true; - if (loopback && isLocalAddress(addr_v6)) + if (localhost && isAddressOfLocalhost(addr_v6)) return true; /// Check `ip_subnets`. diff --git a/dbms/src/Access/AllowedClientHosts.h b/dbms/src/Access/AllowedClientHosts.h index 17f8be878a1..34abd22c3bf 100644 --- a/dbms/src/Access/AllowedClientHosts.h +++ b/dbms/src/Access/AllowedClientHosts.h @@ -94,7 +94,7 @@ private: void compileRegexps() const; std::vector addresses; - bool loopback = false; + bool localhost = false; std::vector subnets; std::vector host_names; std::vector host_regexps;