Merge pull request #17569 from vdimir/speedup-apply-cidr-mask-v6

Speedup applyCIDRMask for IPv6
This commit is contained in:
alexey-milovidov 2020-12-31 03:31:43 +03:00 committed by GitHub
commit 81f8ee5fd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 12 deletions

View File

@ -2,12 +2,20 @@
#include <Poco/Net/IPAddress.h>
#include <Poco/ByteOrder.h>
#include <Common/formatIPv6.h>
#include <cstring>
namespace DB
{
/// Result array could be indexed with all possible uint8 values without extra check.
/// For values greater than 128 we will store same value as for 128 (all bits set).
constexpr size_t IPV6_MASKS_COUNT = 256;
using RawMaskArray = std::array<uint8_t, IPV6_BINARY_LENGTH>;
void IPv6ToRawBinary(const Poco::Net::IPAddress & address, char * res)
{
if (Poco::Net::IPAddress::IPv6 == address.family())
@ -33,4 +41,33 @@ std::array<char, 16> IPv6ToBinary(const Poco::Net::IPAddress & address)
return res;
}
static constexpr RawMaskArray generateBitMask(size_t prefix)
{
if (prefix >= 128)
prefix = 128;
RawMaskArray arr{0};
size_t i = 0;
for (; prefix >= 8; ++i, prefix -= 8)
arr[i] = 0xff;
if (prefix > 0)
arr[i++] = ~(0xff >> prefix);
while (i < 16)
arr[i++] = 0x00;
return arr;
}
static constexpr std::array<RawMaskArray, IPV6_MASKS_COUNT> generateBitMasks()
{
std::array<RawMaskArray, IPV6_MASKS_COUNT> arr{};
for (size_t i = 0; i < IPV6_MASKS_COUNT; ++i)
arr[i] = generateBitMask(i);
return arr;
}
const uint8_t * getCIDRMaskIPv6(UInt8 prefix_len)
{
static constexpr std::array<RawMaskArray, IPV6_MASKS_COUNT> IPV6_RAW_MASK_ARRAY = generateBitMasks();
return IPV6_RAW_MASK_ARRAY[prefix_len].data();
}
}

View File

@ -14,4 +14,9 @@ void IPv6ToRawBinary(const Poco::Net::IPAddress & address, char * res);
/// Convert IP address to 16-byte array with IPv6 data (big endian). If it's an IPv4, map it to IPv6.
std::array<char, 16> IPv6ToBinary(const Poco::Net::IPAddress & address);
/// Returns pointer to 16-byte array containing mask with first `prefix_len` bits set to `1` and `128 - prefix_len` to `0`.
/// Pointer is valid during all program execution time and doesn't require freeing.
/// Values of prefix_len greater than 128 interpreted as 128 exactly.
const uint8_t * getCIDRMaskIPv6(UInt8 prefix_len);
}

View File

@ -1,7 +1,8 @@
#pragma once
#include <Common/hex.h>
#include <Common/formatIPv6.h>
#include <Common/hex.h>
#include <Common/IPv6ToBinary.h>
#include <Common/typeid_cast.h>
#include <IO/WriteHelpers.h>
#include <DataTypes/DataTypeFactory.h>
@ -1617,20 +1618,28 @@ public:
class FunctionIPv6CIDRToRange : public IFunction
{
private:
/// TODO Inefficient.
#if defined(__SSE2__)
#include <emmintrin.h>
static inline void applyCIDRMask(const UInt8 * __restrict src, UInt8 * __restrict dst_lower, UInt8 * __restrict dst_upper, UInt8 bits_to_keep)
{
__m128i mask = _mm_loadu_si128(reinterpret_cast<const __m128i *>(getCIDRMaskIPv6(bits_to_keep)));
__m128i lower = _mm_and_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(src)), mask);
_mm_storeu_si128(reinterpret_cast<__m128i *>(dst_lower), lower);
__m128i inv_mask = _mm_xor_si128(mask, _mm_cmpeq_epi32(_mm_setzero_si128(), _mm_setzero_si128()));
__m128i upper = _mm_or_si128(lower, inv_mask);
_mm_storeu_si128(reinterpret_cast<__m128i *>(dst_upper), upper);
}
#else
/// NOTE IPv6 is stored in memory in big endian format that makes some difficulties.
static void applyCIDRMask(const UInt8 * __restrict src, UInt8 * __restrict dst_lower, UInt8 * __restrict dst_upper, UInt8 bits_to_keep)
{
UInt8 mask[16]{};
UInt8 bytes_to_keep = bits_to_keep / 8;
UInt8 bits_to_keep_in_last_byte = bits_to_keep % 8;
for (size_t i = 0; i < bits_to_keep / 8; ++i)
mask[i] = 0xFFU;
if (bits_to_keep_in_last_byte)
mask[bytes_to_keep] = 0xFFU << (8 - bits_to_keep_in_last_byte);
const auto * mask = getCIDRMaskIPv6(bits_to_keep);
for (size_t i = 0; i < 16; ++i)
{
@ -1639,6 +1648,8 @@ private:
}
}
#endif
public:
static constexpr auto name = "IPv6CIDRToRange";
static FunctionPtr create(const Context &) { return std::make_shared<FunctionIPv6CIDRToRange>(); }