Reuse some functions for IPAddressContainedIn

This commit is contained in:
vdimir 2021-03-29 12:04:05 +03:00
parent 26dc629366
commit 24aa25d7dc
No known key found for this signature in database
GPG Key ID: F57B3E10A21DBB31
4 changed files with 157 additions and 340 deletions

View File

@ -15,10 +15,6 @@ namespace DB
constexpr size_t IPV6_MASKS_COUNT = 256; constexpr size_t IPV6_MASKS_COUNT = 256;
using RawMaskArrayV6 = std::array<uint8_t, IPV6_BINARY_LENGTH>; using RawMaskArrayV6 = std::array<uint8_t, IPV6_BINARY_LENGTH>;
/// Same for IPv4
constexpr size_t IPV4_MASKS_COUNT = 256;
using RawMaskArrayV4 = std::array<uint8_t, IPV4_BINARY_LENGTH>;
void IPv6ToRawBinary(const Poco::Net::IPAddress & address, char * res) void IPv6ToRawBinary(const Poco::Net::IPAddress & address, char * res)
{ {
if (Poco::Net::IPAddress::IPv6 == address.family()) if (Poco::Net::IPAddress::IPv6 == address.family())
@ -75,10 +71,55 @@ const std::array<uint8_t, 16> & getCIDRMaskIPv6(UInt8 prefix_len)
return IPV6_RAW_MASK_ARRAY[prefix_len]; return IPV6_RAW_MASK_ARRAY[prefix_len];
} }
const std::array<uint8_t, 4> & getCIDRMaskIPv4(UInt8 prefix_len) bool matchIPv4Subnet(UInt32 addr, UInt32 cidr_addr, UInt8 prefix)
{ {
static constexpr auto IPV4_RAW_MASK_ARRAY = generateBitMasks<RawMaskArrayV4, IPV4_MASKS_COUNT>(); UInt32 mask = (prefix >= 32) ? 0xffffffffu : ~(0xffffffffu >> prefix);
return IPV4_RAW_MASK_ARRAY[prefix_len]; return (addr & mask) == (cidr_addr & mask);
} }
#if defined(__SSE2__)
#include <emmintrin.h>
bool matchIPv6Subnet(const uint8_t * addr, const uint8_t * cidr_addr, UInt8 prefix)
{
uint16_t mask = _mm_movemask_epi8(_mm_cmpeq_epi8(
_mm_loadu_si128(reinterpret_cast<const __m128i *>(addr)),
_mm_loadu_si128(reinterpret_cast<const __m128i *>(cidr_addr))));
mask = ~mask;
if (mask)
{
auto offset = __builtin_ctz(mask);
if (prefix / 8 != offset)
return prefix / 8 < offset;
auto cmpmask = ~(0xff >> (prefix % 8));
return (addr[offset] & cmpmask) == (cidr_addr[offset] & cmpmask);
}
return true;
}
# else
bool matchIPv6Subnet(const uint8_t * addr, const uint8_t * cidr_addr, UInt8 prefix)
{
if (prefix > IPV6_BINARY_LENGTH * 8U)
prefix = IPV6_BINARY_LENGTH * 8U;
size_t i = 0;
for (; prefix >= 8; ++i, prefix -= 8)
{
if (target[i] != cidr_addr[i])
return false;
}
if (prefix == 0)
return true;
auto mask = ~(0xff >> prefix);
return (addr[i] & mask) == (cidr_addr[i] & mask);
}
#endif // __SSE2__
} }

View File

@ -19,7 +19,8 @@ std::array<char, 16> IPv6ToBinary(const Poco::Net::IPAddress & address);
/// Values of prefix_len greater than 128 interpreted as 128 exactly. /// Values of prefix_len greater than 128 interpreted as 128 exactly.
const std::array<uint8_t, 16> & getCIDRMaskIPv6(UInt8 prefix_len); const std::array<uint8_t, 16> & getCIDRMaskIPv6(UInt8 prefix_len);
/// This is identical to getCIDRMaskIPv6 except it's for IPv4 addresses. /// Check that address contained in CIDR range
const std::array<uint8_t, 4> & getCIDRMaskIPv4(UInt8 prefix_len); bool matchIPv4Subnet(UInt32 addr, UInt32 cidr_addr, UInt8 prefix);
bool matchIPv6Subnet(const uint8_t * addr, const uint8_t * cidr_addr, UInt8 prefix);
} }

View File

@ -4,19 +4,17 @@
#include <Common/assert_cast.h> #include <Common/assert_cast.h>
#include <Common/IPv6ToBinary.h> #include <Common/IPv6ToBinary.h>
#include <Common/memcmpSmall.h> #include <Common/memcmpSmall.h>
#include <Common/memcpySmall.h>
#include <Common/typeid_cast.h> #include <Common/typeid_cast.h>
#include <DataTypes/DataTypeFixedString.h> #include <DataTypes/DataTypeFixedString.h>
#include <DataTypes/DataTypeString.h> #include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesDecimal.h> #include <DataTypes/DataTypesDecimal.h>
#include <IO/WriteIntText.h>
#include <Poco/ByteOrder.h> #include <Poco/ByteOrder.h>
#include <Common/formatIPv6.h> #include <Common/formatIPv6.h>
#include <common/itoa.h> #include <common/itoa.h>
#include <ext/map.h> #include <ext/map.h>
#include <ext/range.h> #include <ext/range.h>
#include "DictionaryBlockInputStream.h" #include <Dictionaries/DictionaryBlockInputStream.h>
#include "DictionaryFactory.h" #include <Dictionaries/DictionaryFactory.h>
#include <Functions/FunctionHelpers.h> #include <Functions/FunctionHelpers.h>
namespace DB namespace DB
@ -191,57 +189,6 @@ inline static void mapIPv4ToIPv6(UInt32 addr, uint8_t * buf)
memcpy(&buf[12], &addr, 4); memcpy(&buf[12], &addr, 4);
} }
static bool matchIPv4Subnet(UInt32 target, UInt32 addr, UInt8 prefix)
{
UInt32 mask = (prefix >= 32) ? 0xffffffffu : ~(0xffffffffu >> prefix);
return (target & mask) == addr;
}
#if defined(__SSE2__)
#include <emmintrin.h>
static bool matchIPv6Subnet(const uint8_t * target, const uint8_t * addr, UInt8 prefix)
{
uint16_t mask = _mm_movemask_epi8(_mm_cmpeq_epi8(
_mm_loadu_si128(reinterpret_cast<const __m128i *>(target)),
_mm_loadu_si128(reinterpret_cast<const __m128i *>(addr))));
mask = ~mask;
if (mask)
{
auto offset = __builtin_ctz(mask);
if (prefix / 8 != offset)
return prefix / 8 < offset;
auto cmpmask = ~(0xff >> (prefix % 8));
return (target[offset] & cmpmask) == addr[offset];
}
return true;
}
# else
static bool matchIPv6Subnet(const uint8_t * target, const uint8_t * addr, UInt8 prefix)
{
if (prefix > IPV6_BINARY_LENGTH * 8U)
prefix = IPV6_BINARY_LENGTH * 8U;
size_t i = 0;
for (; prefix >= 8; ++i, prefix -= 8)
{
if (target[i] != addr[i])
return false;
}
if (prefix == 0)
return true;
auto mask = ~(0xff >> prefix);
return (target[i] & mask) == addr[i];
}
#endif // __SSE2__
IPAddressDictionary::IPAddressDictionary( IPAddressDictionary::IPAddressDictionary(
const StorageID & dict_id_, const StorageID & dict_id_,
const DictionaryStructure & dict_struct_, const DictionaryStructure & dict_struct_,

View File

@ -4,291 +4,120 @@
#include <Columns/ColumnsNumber.h> #include <Columns/ColumnsNumber.h>
#include <Common/IPv6ToBinary.h> #include <Common/IPv6ToBinary.h>
#include <Common/formatIPv6.h> #include <Common/formatIPv6.h>
#include <Common/IPv6ToBinary.h>
#include <DataTypes/DataTypeNullable.h> #include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <Functions/IFunctionImpl.h> #include <Functions/IFunctionImpl.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h> #include <Functions/FunctionHelpers.h>
#include <variant> #include <variant>
#include <charconv>
namespace DB
#include <common/logger_useful.h>
namespace DB::ErrorCodes
{ {
namespace ErrorCodes
{
extern const int CANNOT_PARSE_TEXT; extern const int CANNOT_PARSE_TEXT;
extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int ILLEGAL_TYPE_OF_ARGUMENT;
} extern const int CANNOT_PARSE_NUMBER;
} }
namespace ipaddr namespace
{ {
class Address
{
public:
Address() = delete;
explicit Address(const StringRef & in) class IPAddressVariant
{
public:
explicit IPAddressVariant(const StringRef & address_str)
{ {
// IP address parser functions require that the input is // IP address parser functions require that the input is
// NULL-terminated so we need to copy it. // NULL-terminated so we need to copy it.
const auto in_copy = std::string(in); const auto address_str_copy = std::string(address_str);
UInt32 v4; UInt32 v4;
if (DB::parseIPv4(in_copy.c_str(), reinterpret_cast<unsigned char *>(&v4))) if (DB::parseIPv4(address_str_copy.c_str(), reinterpret_cast<unsigned char *>(&v4)))
{ {
addr = V4(v4); addr = v4;
} }
else else
{ {
V6 v6; addr = IPv6AddrType();
if (DB::parseIPv6(in_copy.c_str(), v6.addr.data())) bool success = DB::parseIPv6(address_str_copy.c_str(), std::get<IPv6AddrType>(addr).data());
addr = std::move(v6); if (!success)
else throw DB::Exception("Neither IPv4 nor IPv6 address: '" + address_str_copy + "'",
throw DB::Exception("Neither IPv4 nor IPv6 address: " + in_copy,
DB::ErrorCodes::CANNOT_PARSE_TEXT); DB::ErrorCodes::CANNOT_PARSE_TEXT);
} }
} }
template <typename ConcreteType, size_t numOctets> UInt32 asV4() const
struct IPVersionBase
{ {
IPVersionBase() if (const auto * val = std::get_if<IPv4AddrType>(&addr))
: addr {} {} return *val;
return 0;
explicit IPVersionBase(const std::array<uint8_t, numOctets> & octets)
: addr(octets) {}
constexpr size_t numBits() const
{
return numOctets * 8;
} }
uint8_t operator[] (size_t i) const const uint8_t * asV6() const
{ {
assert(i >= 0 && i < numOctets); if (const auto * val = std::get_if<IPv6AddrType>(&addr))
return addr[i]; return val->data();
return nullptr;
} }
uint8_t & operator[] (size_t i) private:
{ using IPv4AddrType = UInt32;
assert(i >= 0 && i < numOctets); using IPv6AddrType = std::array<uint8_t, IPV6_BINARY_LENGTH>;
return addr[i];
}
bool operator<= (const ConcreteType & rhs) const std::variant<IPv4AddrType, IPv6AddrType> addr;
{ };
for (size_t i = 0; i < numOctets; i++)
{
if ((*this)[i] < rhs[i]) return true;
if ((*this)[i] > rhs[i]) return false;
}
return true;
}
bool operator>= (const ConcreteType & rhs) const struct IPAddressCIDR
{ {
for (size_t i = 0; i < numOctets; i++) IPAddressVariant address;
{ UInt8 prefix;
if ((*this)[i] > rhs[i]) return true; };
if ((*this)[i] < rhs[i]) return false;
}
return true;
}
ConcreteType operator& (const ConcreteType & rhs) const IPAddressCIDR parseIPWithCIDR(const StringRef cidr_str)
{ {
ConcreteType lhs(addr); std::string_view cidr_str_view(cidr_str);
size_t pos_slash = cidr_str_view.find('/');
for (size_t i = 0; i < numOctets; i++)
lhs[i] &= rhs[i];
return lhs;
}
ConcreteType operator| (const ConcreteType & rhs) const
{
ConcreteType lhs(addr);
for (size_t i = 0; i < numOctets; i++)
lhs[i] |= rhs[i];
return lhs;
}
ConcreteType operator~ () const
{
ConcreteType tmp(addr);
for (size_t i = 0; i < numOctets; i++)
tmp[i] = ~tmp[i];
return tmp;
}
private:
// Big-endian
std::array<uint8_t, numOctets> addr;
friend class Address;
};
struct V4 : public IPVersionBase<V4, 4>
{
V4() = default;
explicit V4(UInt32 addr_)
{
addr[0] = (addr_ >> 24) & 0xFF;
addr[1] = (addr_ >> 16) & 0xFF;
addr[2] = (addr_ >> 8) & 0xFF;
addr[3] = addr_ & 0xFF;
}
explicit V4(const std::array<uint8_t, 4> & components)
: IPVersionBase(components) {}
};
struct V6 : public IPVersionBase<V6, 16>
{
V6() = default;
explicit V6(const std::array<uint8_t, 16> & components)
: IPVersionBase(components) {}
};
constexpr const std::variant<V4, V6> & variant() const
{
return addr;
}
private:
std::variant<V4, V6> addr;
};
class CIDR
{
public:
CIDR() = delete;
explicit CIDR(const StringRef & in)
{
const auto in_view = std::string_view(in);
const auto pos_slash = in_view.find('/');
if (pos_slash == 0)
throw DB::Exception("Error parsing IP address with prefix: " + std::string(cidr_str), DB::ErrorCodes::CANNOT_PARSE_TEXT);
if (pos_slash == std::string_view::npos) if (pos_slash == std::string_view::npos)
throw DB::Exception("The text does not contain '/': " + std::string(in_view), throw DB::Exception("The text does not contain '/': " + std::string(cidr_str), DB::ErrorCodes::CANNOT_PARSE_TEXT);
DB::ErrorCodes::CANNOT_PARSE_TEXT);
prefix = Address(StringRef(in_view.substr(0, pos_slash))); std::string_view addr_str = cidr_str_view.substr(0, pos_slash);
IPAddressVariant addr(StringRef{addr_str.data(), addr_str.size()});
// DB::parse<Uint8>() in <IO/ReadHelpers.h> ignores UInt8 prefix = 0;
// non-digit characters. std::stoi() skips whitespaces. We auto prefix_str = cidr_str_view.substr(pos_slash+1);
// need to parse the prefix bits in a strict way.
if (pos_slash + 1 == in_view.size()) const auto * prefix_str_end = prefix_str.data() + prefix_str.size();
throw DB::Exception("The CIDR has no prefix bits: " + std::string(in_view), auto [parse_end, parse_error] = std::from_chars(prefix_str.data(), prefix_str_end, prefix);
DB::ErrorCodes::CANNOT_PARSE_TEXT); UInt8 max_prefix = (addr.asV6() ? IPV6_BINARY_LENGTH : IPV4_BINARY_LENGTH) * 8;
bool has_error = parse_error != std::errc() || parse_end != prefix_str_end || prefix > max_prefix;
if (has_error)
throw DB::Exception("The CIDR has a malformed prefix bits: " + std::string(cidr_str), DB::ErrorCodes::CANNOT_PARSE_TEXT);
bits = 0; return {addr, prefix};
for (size_t i = pos_slash + 1; i < in_view.size(); i++) }
inline bool isAddressInRange(const IPAddressVariant & address, const IPAddressCIDR & cidr)
{
if (const auto * cidr_v6 = cidr.address.asV6())
{ {
const auto c = in_view[i]; if (const auto * addr_v6 = address.asV6())
if (c >= '0' && c <= '9') return DB::matchIPv6Subnet(addr_v6, cidr_v6, cidr.prefix);
{
bits *= 10;
bits += c - '0';
} }
else else
{ {
throw DB::Exception("The CIDR has a malformed prefix bits: " + std::string(in_view), if (!address.asV6())
DB::ErrorCodes::CANNOT_PARSE_TEXT); return DB::matchIPv4Subnet(address.asV4(), cidr.address.asV4(), cidr.prefix);
}
} }
return false;
}
const size_t max_bits
= std::visit([&](const auto & addr_v) -> size_t
{
return addr_v.numBits();
}, prefix->variant());
if (bits > max_bits)
throw DB::Exception("The CIDR has an invalid prefix bits: " + std::string(in_view),
DB::ErrorCodes::CANNOT_PARSE_TEXT);
}
private:
template <typename PrefixT>
static PrefixT toMask(uint8_t bits)
{
if constexpr (std::is_same_v<PrefixT, Address::V4>)
{
return PrefixT(DB::getCIDRMaskIPv4(bits));
}
else
{
return PrefixT(DB::getCIDRMaskIPv6(bits));
}
}
template <typename PrefixT>
static PrefixT startOf(const PrefixT & prefix, uint8_t bits)
{
return prefix & toMask<PrefixT>(bits);
}
template <typename PrefixT>
static PrefixT endOf(const PrefixT & prefix, uint8_t bits)
{
return prefix | ~toMask<PrefixT>(bits);
}
/* Convert a CIDR notation into an IP address range [start, end]. */
template <typename PrefixT>
static std::pair<PrefixT, PrefixT> toRange(const PrefixT & prefix, uint8_t bits)
{
return std::make_pair(startOf(prefix, bits), endOf(prefix, bits));
}
public:
bool contains(const Address & addr) const
{
return std::visit([&](const auto & addr_v) -> bool
{
return std::visit([&](const auto & prefix_v) -> bool
{
using AddrT = std::decay_t<decltype(addr_v)>;
using PrefixT = std::decay_t<decltype(prefix_v)>;
if constexpr (std::is_same_v<AddrT, Address::V4>)
{
if constexpr (std::is_same_v<PrefixT, Address::V4>)
{
const auto range = toRange(prefix_v, bits);
return addr_v >= range.first && addr_v <= range.second;
}
else
{
return false; // IP version mismatch is not an error.
}
}
else
{
if constexpr (std::is_same_v<PrefixT, Address::V6>)
{
const auto range = toRange(prefix_v, bits);
return addr_v >= range.first && addr_v <= range.second;
}
else
{
return false; // IP version mismatch is not an error.
}
}
}, prefix->variant());
}, addr.variant());
}
private:
std::optional<Address> prefix; // Guaranteed to have a value after construction.
uint8_t bits;
};
} }
namespace DB namespace DB
@ -309,9 +138,8 @@ namespace DB
if (const auto * col_addr_const = checkAndGetAnyColumnConst(col_addr)) if (const auto * col_addr_const = checkAndGetAnyColumnConst(col_addr))
{ {
// col_addr_const is constant and is either String or // col_addr_const is constant and is either String or Nullable(String).
// Nullable(String). We don't care which one it exactly is. // We don't care which one it exactly is.
if (const auto * col_cidr_const = checkAndGetAnyColumnConst(col_cidr)) if (const auto * col_cidr_const = checkAndGetAnyColumnConst(col_cidr))
return executeImpl(*col_addr_const, *col_cidr_const, return_type, input_rows_count); return executeImpl(*col_addr_const, *col_cidr_const, return_type, input_rows_count);
else else
@ -359,13 +187,13 @@ namespace DB
} }
else else
{ {
const auto addr = ipaddr::Address(col_addr.getDataAt(0)); const auto addr = IPAddressVariant(col_addr.getDataAt(0));
const auto cidr = ipaddr::CIDR(col_cidr.getDataAt(0)); const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(0));
ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(1); ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(1);
ColumnUInt8::Container & vec_res = col_res->getData(); ColumnUInt8::Container & vec_res = col_res->getData();
vec_res[0] = cidr.contains(addr) ? 1 : 0; vec_res[0] = isAddressInRange(addr, cidr) ? 1 : 0;
if (return_type->isNullable()) if (return_type->isNullable())
{ {
@ -396,7 +224,7 @@ namespace DB
} }
else else
{ {
const auto addr = ipaddr::Address(col_addr.getDataAt(0)); const auto addr = IPAddressVariant(col_addr.getDataAt (0));
ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count); ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count);
ColumnUInt8::Container & vec_res = col_res->getData(); ColumnUInt8::Container & vec_res = col_res->getData();
@ -414,8 +242,8 @@ namespace DB
} }
else else
{ {
const auto cidr = ipaddr::CIDR(col_cidr.getDataAt(i)); const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i));
vec_res[i] = cidr.contains(addr) ? 1 : 0; vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
vec_null_map_res[i] = false; vec_null_map_res[i] = false;
} }
} }
@ -426,8 +254,8 @@ namespace DB
{ {
for (size_t i = 0; i < input_rows_count; i++) for (size_t i = 0; i < input_rows_count; i++)
{ {
const auto cidr = ipaddr::CIDR(col_cidr.getDataAt(i)); const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i));
vec_res[i] = cidr.contains(addr) ? 1 : 0; vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
} }
return col_res; return col_res;
@ -448,7 +276,7 @@ namespace DB
} }
else else
{ {
const auto cidr = ipaddr::CIDR(col_cidr.getDataAt(0)); const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(0));
ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count); ColumnUInt8::MutablePtr col_res = ColumnUInt8::create(input_rows_count);
ColumnUInt8::Container & vec_res = col_res->getData(); ColumnUInt8::Container & vec_res = col_res->getData();
@ -466,8 +294,8 @@ namespace DB
} }
else else
{ {
const auto addr = ipaddr::Address(col_addr.getDataAt(i)); const auto addr = IPAddressVariant(col_addr.getDataAt(i));
vec_res[i] = cidr.contains(addr) ? 1 : 0; vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
vec_null_map_res[i] = false; vec_null_map_res[i] = false;
} }
} }
@ -478,8 +306,8 @@ namespace DB
{ {
for (size_t i = 0; i < input_rows_count; i++) for (size_t i = 0; i < input_rows_count; i++)
{ {
const auto addr = ipaddr::Address(col_addr.getDataAt(i)); const auto addr = IPAddressVariant(col_addr.getDataAt(i));
vec_res[i] = cidr.contains(addr) ? 1 : 0; vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
} }
return col_res; return col_res;
@ -506,10 +334,10 @@ namespace DB
} }
else else
{ {
const auto addr = ipaddr::Address(col_addr.getDataAt(i)); const auto addr = IPAddressVariant(col_addr.getDataAt(i));
const auto cidr = ipaddr::CIDR(col_cidr.getDataAt(i)); const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i));
vec_res[i] = cidr.contains(addr) ? 1 : 0; vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
vec_null_map_res[i] = false; vec_null_map_res[i] = false;
} }
} }
@ -520,10 +348,10 @@ namespace DB
{ {
for (size_t i = 0; i < input_rows_count; i++) for (size_t i = 0; i < input_rows_count; i++)
{ {
const auto addr = ipaddr::Address(col_addr.getDataAt(i)); const auto addr = IPAddressVariant(col_addr.getDataAt(i));
const auto cidr = ipaddr::CIDR(col_cidr.getDataAt(i)); const auto cidr = parseIPWithCIDR(col_cidr.getDataAt(i));
vec_res[i] = cidr.contains(addr) ? 1 : 0; vec_res[i] = isAddressInRange(addr, cidr) ? 1 : 0;
} }
return col_res; return col_res;