diff --git a/src/Common/IPv6ToBinary.cpp b/src/Common/IPv6ToBinary.cpp index 54630eda6f3..a5ceaa72022 100644 --- a/src/Common/IPv6ToBinary.cpp +++ b/src/Common/IPv6ToBinary.cpp @@ -15,10 +15,6 @@ namespace DB constexpr size_t IPV6_MASKS_COUNT = 256; using RawMaskArrayV6 = std::array; -/// Same for IPv4 -constexpr size_t IPV4_MASKS_COUNT = 256; -using RawMaskArrayV4 = std::array; - void IPv6ToRawBinary(const Poco::Net::IPAddress & address, char * res) { if (Poco::Net::IPAddress::IPv6 == address.family()) @@ -75,10 +71,55 @@ const std::array & getCIDRMaskIPv6(UInt8 prefix_len) return IPV6_RAW_MASK_ARRAY[prefix_len]; } -const std::array & getCIDRMaskIPv4(UInt8 prefix_len) +bool matchIPv4Subnet(UInt32 addr, UInt32 cidr_addr, UInt8 prefix) { - static constexpr auto IPV4_RAW_MASK_ARRAY = generateBitMasks(); - return IPV4_RAW_MASK_ARRAY[prefix_len]; + UInt32 mask = (prefix >= 32) ? 0xffffffffu : ~(0xffffffffu >> prefix); + return (addr & mask) == (cidr_addr & mask); } +#if defined(__SSE2__) +#include + +bool matchIPv6Subnet(const uint8_t * addr, const uint8_t * cidr_addr, UInt8 prefix) +{ + uint16_t mask = _mm_movemask_epi8(_mm_cmpeq_epi8( + _mm_loadu_si128(reinterpret_cast(addr)), + _mm_loadu_si128(reinterpret_cast(cidr_addr)))); + mask = ~mask; + + if (mask) + { + auto offset = __builtin_ctz(mask); + + if (prefix / 8 != offset) + return prefix / 8 < offset; + + auto cmpmask = ~(0xff >> (prefix % 8)); + return (addr[offset] & cmpmask) == (cidr_addr[offset] & cmpmask); + } + return true; +} + +# else /* no SSE2: portable byte-by-byte comparison */ + +bool matchIPv6Subnet(const uint8_t * addr, const uint8_t * cidr_addr, UInt8 prefix) +{ + if (prefix > IPV6_BINARY_LENGTH * 8U) + prefix = IPV6_BINARY_LENGTH * 8U; + + size_t i = 0; + for (; prefix >= 8; ++i, prefix -= 8) + { + if (addr[i] != cidr_addr[i]) /* every whole byte covered by the prefix must match */ + return false; + } + if (prefix == 0) + return true; + + auto mask = ~(0xff >> prefix); /* keeps only the top `prefix` bits of the partially covered byte */ + return (addr[i] & mask) == (cidr_addr[i] & mask); +} + +#endif // __SSE2__ + } diff --git a/src/Common/IPv6ToBinary.h 
b/src/Common/IPv6ToBinary.h index 6d2d3d33e97..d766d408359 100644 --- a/src/Common/IPv6ToBinary.h +++ b/src/Common/IPv6ToBinary.h @@ -19,7 +19,8 @@ std::array IPv6ToBinary(const Poco::Net::IPAddress & address); /// Values of prefix_len greater than 128 interpreted as 128 exactly. const std::array & getCIDRMaskIPv6(UInt8 prefix_len); -/// This is identical to getCIDRMaskIPv6 except it's for IPv4 addresses. -const std::array & getCIDRMaskIPv4(UInt8 prefix_len); +/// Check whether `addr` lies within the subnet given by `cidr_addr` and its first `prefix` bits. +bool matchIPv4Subnet(UInt32 addr, UInt32 cidr_addr, UInt8 prefix); +bool matchIPv6Subnet(const uint8_t * addr, const uint8_t * cidr_addr, UInt8 prefix); } diff --git a/src/Dictionaries/IPAddressDictionary.cpp b/src/Dictionaries/IPAddressDictionary.cpp index 165fa3a000d..4b51d94f0d8 100644 --- a/src/Dictionaries/IPAddressDictionary.cpp +++ b/src/Dictionaries/IPAddressDictionary.cpp @@ -4,19 +4,17 @@ #include #include #include -#include #include #include #include #include -#include #include #include #include #include #include -#include "DictionaryBlockInputStream.h" -#include "DictionaryFactory.h" +#include +#include #include namespace DB @@ -191,57 +189,6 @@ inline static void mapIPv4ToIPv6(UInt32 addr, uint8_t * buf) memcpy(&buf[12], &addr, 4); } -static bool matchIPv4Subnet(UInt32 target, UInt32 addr, UInt8 prefix) -{ - UInt32 mask = (prefix >= 32) ? 
0xffffffffu : ~(0xffffffffu >> prefix); - return (target & mask) == addr; -} - -#if defined(__SSE2__) -#include - -static bool matchIPv6Subnet(const uint8_t * target, const uint8_t * addr, UInt8 prefix) -{ - uint16_t mask = _mm_movemask_epi8(_mm_cmpeq_epi8( - _mm_loadu_si128(reinterpret_cast(target)), - _mm_loadu_si128(reinterpret_cast(addr)))); - mask = ~mask; - - if (mask) - { - auto offset = __builtin_ctz(mask); - - if (prefix / 8 != offset) - return prefix / 8 < offset; - - auto cmpmask = ~(0xff >> (prefix % 8)); - return (target[offset] & cmpmask) == addr[offset]; - } - return true; -} - -# else - -static bool matchIPv6Subnet(const uint8_t * target, const uint8_t * addr, UInt8 prefix) -{ - if (prefix > IPV6_BINARY_LENGTH * 8U) - prefix = IPV6_BINARY_LENGTH * 8U; - - size_t i = 0; - for (; prefix >= 8; ++i, prefix -= 8) - { - if (target[i] != addr[i]) - return false; - } - if (prefix == 0) - return true; - - auto mask = ~(0xff >> prefix); - return (target[i] & mask) == addr[i]; -} - -#endif // __SSE2__ - IPAddressDictionary::IPAddressDictionary( const StorageID & dict_id_, const DictionaryStructure & dict_struct_, diff --git a/src/Functions/isIPAddressContainedIn.cpp b/src/Functions/isIPAddressContainedIn.cpp index 49a0c5b74de..45e2f15ed66 100644 --- a/src/Functions/isIPAddressContainedIn.cpp +++ b/src/Functions/isIPAddressContainedIn.cpp @@ -4,291 +4,120 @@ #include #include #include +#include #include #include #include #include #include #include +#include -namespace DB + +#include +namespace DB::ErrorCodes { - namespace ErrorCodes - { - extern const int CANNOT_PARSE_TEXT; - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - } + extern const int CANNOT_PARSE_TEXT; + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int CANNOT_PARSE_NUMBER; } -namespace ipaddr +namespace { - class Address + +class IPAddressVariant +{ +public: + + explicit IPAddressVariant(const StringRef & address_str) { - public: - Address() = delete; + // IP address parser functions require 
that the input is + // NULL-terminated so we need to copy it. + const auto address_str_copy = std::string(address_str); - explicit Address(const StringRef & in) + UInt32 v4; + if (DB::parseIPv4(address_str_copy.c_str(), reinterpret_cast(&v4))) { - // IP address parser functions require that the input is - // NULL-terminated so we need to copy it. - const auto in_copy = std::string(in); - - UInt32 v4; - if (DB::parseIPv4(in_copy.c_str(), reinterpret_cast(&v4))) - { - addr = V4(v4); - } - else - { - V6 v6; - if (DB::parseIPv6(in_copy.c_str(), v6.addr.data())) - addr = std::move(v6); - else - throw DB::Exception("Neither IPv4 nor IPv6 address: " + in_copy, - DB::ErrorCodes::CANNOT_PARSE_TEXT); - } + addr = v4; } - - template - struct IPVersionBase + else { - IPVersionBase() - : addr {} {} - - explicit IPVersionBase(const std::array & octets) - : addr(octets) {} - - constexpr size_t numBits() const - { - return numOctets * 8; - } - - uint8_t operator[] (size_t i) const - { - assert(i >= 0 && i < numOctets); - return addr[i]; - } - - uint8_t & operator[] (size_t i) - { - assert(i >= 0 && i < numOctets); - return addr[i]; - } - - bool operator<= (const ConcreteType & rhs) const - { - for (size_t i = 0; i < numOctets; i++) - { - if ((*this)[i] < rhs[i]) return true; - if ((*this)[i] > rhs[i]) return false; - } - return true; - } - - bool operator>= (const ConcreteType & rhs) const - { - for (size_t i = 0; i < numOctets; i++) - { - if ((*this)[i] > rhs[i]) return true; - if ((*this)[i] < rhs[i]) return false; - } - return true; - } - - ConcreteType operator& (const ConcreteType & rhs) const - { - ConcreteType lhs(addr); - - for (size_t i = 0; i < numOctets; i++) - lhs[i] &= rhs[i]; - - return lhs; - } - - ConcreteType operator| (const ConcreteType & rhs) const - { - ConcreteType lhs(addr); - - for (size_t i = 0; i < numOctets; i++) - lhs[i] |= rhs[i]; - - return lhs; - } - - ConcreteType operator~ () const - { - ConcreteType tmp(addr); - - for (size_t i = 0; i < numOctets; 
i++) - tmp[i] = ~tmp[i]; - - return tmp; - } - - private: - // Big-endian - std::array addr; - friend class Address; - }; - - struct V4 : public IPVersionBase - { - V4() = default; - - explicit V4(UInt32 addr_) - { - addr[0] = (addr_ >> 24) & 0xFF; - addr[1] = (addr_ >> 16) & 0xFF; - addr[2] = (addr_ >> 8) & 0xFF; - addr[3] = addr_ & 0xFF; - } - - explicit V4(const std::array & components) - : IPVersionBase(components) {} - }; - - struct V6 : public IPVersionBase - { - V6() = default; - - explicit V6(const std::array & components) - : IPVersionBase(components) {} - }; - - constexpr const std::variant & variant() const - { - return addr; + addr = IPv6AddrType(); + bool success = DB::parseIPv6(address_str_copy.c_str(), std::get(addr).data()); + if (!success) + throw DB::Exception("Neither IPv4 nor IPv6 address: '" + address_str_copy + "'", + DB::ErrorCodes::CANNOT_PARSE_TEXT); } + } - private: - std::variant addr; - }; - - class CIDR + UInt32 asV4() const { - public: - CIDR() = delete; + if (const auto * val = std::get_if(&addr)) + return *val; + return 0; + } - explicit CIDR(const StringRef & in) - { - const auto in_view = std::string_view(in); - const auto pos_slash = in_view.find('/'); + const uint8_t * asV6() const + { + if (const auto * val = std::get_if(&addr)) + return val->data(); + return nullptr; + } - if (pos_slash == std::string_view::npos) - throw DB::Exception("The text does not contain '/': " + std::string(in_view), - DB::ErrorCodes::CANNOT_PARSE_TEXT); +private: + using IPv4AddrType = UInt32; + using IPv6AddrType = std::array; - prefix = Address(StringRef(in_view.substr(0, pos_slash))); + std::variant addr; +}; - // DB::parse() in ignores - // non-digit characters. std::stoi() skips whitespaces. We - // need to parse the prefix bits in a strict way. 
+struct IPAddressCIDR +{ + IPAddressVariant address; + UInt8 prefix; +}; - if (pos_slash + 1 == in_view.size()) - throw DB::Exception("The CIDR has no prefix bits: " + std::string(in_view), - DB::ErrorCodes::CANNOT_PARSE_TEXT); +IPAddressCIDR parseIPWithCIDR(const StringRef cidr_str) +{ + std::string_view cidr_str_view(cidr_str); + size_t pos_slash = cidr_str_view.find('/'); - bits = 0; - for (size_t i = pos_slash + 1; i < in_view.size(); i++) - { - const auto c = in_view[i]; - if (c >= '0' && c <= '9') - { - bits *= 10; - bits += c - '0'; - } - else - { - throw DB::Exception("The CIDR has a malformed prefix bits: " + std::string(in_view), - DB::ErrorCodes::CANNOT_PARSE_TEXT); - } - } + if (pos_slash == 0) + throw DB::Exception("Error parsing IP address with prefix: " + std::string(cidr_str), DB::ErrorCodes::CANNOT_PARSE_TEXT); + if (pos_slash == std::string_view::npos) + throw DB::Exception("The text does not contain '/': " + std::string(cidr_str), DB::ErrorCodes::CANNOT_PARSE_TEXT); - const size_t max_bits - = std::visit([&](const auto & addr_v) -> size_t - { - return addr_v.numBits(); - }, prefix->variant()); - if (bits > max_bits) - throw DB::Exception("The CIDR has an invalid prefix bits: " + std::string(in_view), - DB::ErrorCodes::CANNOT_PARSE_TEXT); - } + std::string_view addr_str = cidr_str_view.substr(0, pos_slash); + IPAddressVariant addr(StringRef{addr_str.data(), addr_str.size()}); - private: - template - static PrefixT toMask(uint8_t bits) - { - if constexpr (std::is_same_v) - { - return PrefixT(DB::getCIDRMaskIPv4(bits)); - } - else - { - return PrefixT(DB::getCIDRMaskIPv6(bits)); - } - } + UInt8 prefix = 0; + auto prefix_str = cidr_str_view.substr(pos_slash+1); - template - static PrefixT startOf(const PrefixT & prefix, uint8_t bits) - { - return prefix & toMask(bits); - } + const auto * prefix_str_end = prefix_str.data() + prefix_str.size(); + auto [parse_end, parse_error] = std::from_chars(prefix_str.data(), prefix_str_end, prefix); + UInt8 
max_prefix = (addr.asV6() ? IPV6_BINARY_LENGTH : IPV4_BINARY_LENGTH) * 8; + bool has_error = parse_error != std::errc() || parse_end != prefix_str_end || prefix > max_prefix; + if (has_error) + throw DB::Exception("The CIDR has a malformed prefix bits: " + std::string(cidr_str), DB::ErrorCodes::CANNOT_PARSE_TEXT); - template - static PrefixT endOf(const PrefixT & prefix, uint8_t bits) - { - return prefix | ~toMask(bits); - } + return {addr, prefix}; +} - /* Convert a CIDR notation into an IP address range [start, end]. */ - template - static std::pair toRange(const PrefixT & prefix, uint8_t bits) - { - return std::make_pair(startOf(prefix, bits), endOf(prefix, bits)); - } +inline bool isAddressInRange(const IPAddressVariant & address, const IPAddressCIDR & cidr) +{ + if (const auto * cidr_v6 = cidr.address.asV6()) + { + if (const auto * addr_v6 = address.asV6()) + return DB::matchIPv6Subnet(addr_v6, cidr_v6, cidr.prefix); + } + else + { + if (!address.asV6()) + return DB::matchIPv4Subnet(address.asV4(), cidr.address.asV4(), cidr.prefix); + } + return false; +} - public: - bool contains(const Address & addr) const - { - return std::visit([&](const auto & addr_v) -> bool - { - return std::visit([&](const auto & prefix_v) -> bool - { - using AddrT = std::decay_t; - using PrefixT = std::decay_t; - - if constexpr (std::is_same_v) - { - if constexpr (std::is_same_v) - { - const auto range = toRange(prefix_v, bits); - return addr_v >= range.first && addr_v <= range.second; - } - else - { - return false; // IP version mismatch is not an error. - } - } - else - { - if constexpr (std::is_same_v) - { - const auto range = toRange(prefix_v, bits); - return addr_v >= range.first && addr_v <= range.second; - } - else - { - return false; // IP version mismatch is not an error. - } - } - }, prefix->variant()); - }, addr.variant()); - } - - private: - std::optional
prefix; // Guaranteed to have a value after construction. - uint8_t bits; - }; } namespace DB @@ -309,9 +138,8 @@ namespace DB if (const auto * col_addr_const = checkAndGetAnyColumnConst(col_addr)) { - // col_addr_const is constant and is either String or - // Nullable(String). We don't care which one it exactly is. - + // col_addr_const is constant and is either String or Nullable(String). + // We don't care which one it exactly is. if (const auto * col_cidr_const = checkAndGetAnyColumnConst(col_cidr)) return executeImpl(*col_addr_const, *col_cidr_const, return_type, input_rows_count); else @@ -359,13 +187,13 @@ namespace DB } else { - const auto addr = ipaddr::Address(col_addr.getDataAt(0)); - const auto cidr = ipaddr::CIDR(col_cidr.getDataAt(0)); + const auto addr = IPAddressVariant(col_addr.getDataAt(0)); + const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(0)); ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(1); ColumnUInt8::Container & vec_res = col_res->getData(); - vec_res[0] = cidr.contains(addr) ? 1 : 0; + vec_res[0] = isAddressInRange(addr, cidr) ? 1 : 0; if (return_type->isNullable()) { @@ -396,7 +224,7 @@ namespace DB } else { - const auto addr = ipaddr::Address(col_addr.getDataAt(0)); + const auto addr = IPAddressVariant(col_addr.getDataAt (0)); ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count); ColumnUInt8::Container & vec_res = col_res->getData(); @@ -414,8 +242,8 @@ namespace DB } else { - const auto cidr = ipaddr::CIDR(col_cidr.getDataAt(i)); - vec_res[i] = cidr.contains(addr) ? 1 : 0; + const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i)); + vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0; vec_null_map_res[i] = false; } } @@ -426,8 +254,8 @@ namespace DB { for (size_t i = 0; i < input_rows_count; i++) { - const auto cidr = ipaddr::CIDR(col_cidr.getDataAt(i)); - vec_res[i] = cidr.contains(addr) ? 1 : 0; + const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i)); + vec_res[i] = isAddressInRange(addr, cidr) ? 
1 : 0; } return col_res; @@ -448,7 +276,7 @@ namespace DB } else { - const auto cidr = ipaddr::CIDR(col_cidr.getDataAt(0)); + const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(0)); ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count); ColumnUInt8::Container & vec_res = col_res->getData(); @@ -466,8 +294,8 @@ namespace DB } else { - const auto addr = ipaddr::Address(col_addr.getDataAt(i)); - vec_res[i] = cidr.contains(addr) ? 1 : 0; + const auto addr = IPAddressVariant(col_addr.getDataAt(i)); + vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0; vec_null_map_res[i] = false; } } @@ -478,8 +306,8 @@ namespace DB { for (size_t i = 0; i < input_rows_count; i++) { - const auto addr = ipaddr::Address(col_addr.getDataAt(i)); - vec_res[i] = cidr.contains(addr) ? 1 : 0; + const auto addr = IPAddressVariant(col_addr.getDataAt(i)); + vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0; } return col_res; @@ -506,10 +334,10 @@ namespace DB } else { - const auto addr = ipaddr::Address(col_addr.getDataAt(i)); - const auto cidr = ipaddr::CIDR(col_cidr.getDataAt(i)); + const auto addr = IPAddressVariant(col_addr.getDataAt(i)); + const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i)); - vec_res[i] = cidr.contains(addr) ? 1 : 0; + vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0; vec_null_map_res[i] = false; } } @@ -520,10 +348,10 @@ namespace DB { for (size_t i = 0; i < input_rows_count; i++) { - const auto addr = ipaddr::Address(col_addr.getDataAt(i)); - const auto cidr = ipaddr::CIDR(col_cidr.getDataAt(i)); + const auto addr = IPAddressVariant(col_addr.getDataAt(i)); + const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i)); - vec_res[i] = cidr.contains(addr) ? 1 : 0; + vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0; } return col_res;