Use sorted ip array instead of trie in TrieDictionary

This commit is contained in:
vdimir 2020-11-08 19:01:12 +03:00
parent 7fb53b205c
commit a67f5b780f
No known key found for this signature in database
GPG Key ID: 4F25F52AFAF0C2C0
7 changed files with 200 additions and 158 deletions

View File

@ -675,7 +675,7 @@
<!-- Configuration of external dictionaries. See:
https://clickhouse.yandex/docs/en/dicts/external_dicts/
https://clickhouse.tech/docs/en/sql-reference/dictionaries/external-dictionaries/external-dicts
-->
<dictionaries_config>*_dictionary.xml</dictionaries_config>

View File

@ -1,5 +1,7 @@
#include "IPv6ToBinary.h"
#include <Poco/Net/IPAddress.h>
#include <Poco/ByteOrder.h>
#include <cstring>
@ -28,4 +30,29 @@ std::array<char, 16> IPv6ToBinary(const Poco::Net::IPAddress & address)
return res;
}
/// Extract the four IPv4 bytes of `address` as a UInt32, keeping the byte order in which
/// the address stores them (network byte order, i.e. big endian).
///
/// Accepts IPv4 addresses and IPv6 addresses that Poco reports as IPv4-mapped.
/// `success` is set to true when the bytes were extracted; otherwise it is set to false
/// and 0 is returned.
UInt32 IPv4ToBinary(const Poco::Net::IPAddress & address, bool & success)
{
    success = false;

    if (!address.isIPv4Mapped())
        return 0;

    /// Offset of the IPv4 part inside the raw address bytes.
    size_t offset = 0;
    if (Poco::Net::IPAddress::IPv6 == address.family())
        offset = 12; /// the IPv4 part is read from the last four of the sixteen bytes
    else if (Poco::Net::IPAddress::IPv4 != address.family())
        return 0;

    /// Copy the bytes instead of dereferencing a casted pointer: reading a UInt32 through
    /// a pointer into a byte buffer violates strict aliasing and may be an unaligned access.
    UInt32 res = 0;
    memcpy(&res, static_cast<const char *>(address.addr()) + offset, sizeof(res));

    success = true;
    return res;
}
}

View File

@ -1,5 +1,6 @@
#pragma once
#include <array>
#include <common/types.h>
namespace Poco { namespace Net { class IPAddress; }}
@ -9,4 +10,8 @@ namespace DB
/// Convert IP address to 16-byte array with IPv6 data (big endian). If it's an IPv4, map it to IPv6.
std::array<char, 16> IPv6ToBinary(const Poco::Net::IPAddress & address);
/// Convert an IP address to a UInt32 holding its four bytes in network (big endian) order.
/// Works for IPv4 addresses and for IPv4 addresses mapped to IPv6.
/// Sets `success` to true if the conversion succeeded; otherwise sets it to false and returns 0.
UInt32 IPv4ToBinary(const Poco::Net::IPAddress & address, bool & success);
}

View File

@ -3,11 +3,11 @@
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnVector.h>
#include <Common/assert_cast.h>
#include <Common/IPv6ToBinary.h>
#include <DataTypes/DataTypeFixedString.h>
#include <DataTypes/DataTypeString.h>
#include <IO/WriteIntText.h>
#include <Poco/ByteOrder.h>
#include <Poco/Net/IPAddress.h>
#include <Common/formatIPv6.h>
#include <common/itoa.h>
#include <ext/map.h>
@ -15,14 +15,6 @@
#include "DictionaryBlockInputStream.h"
#include "DictionaryFactory.h"
#ifdef __clang__
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wnewline-eof"
#endif
#include <btrie.h>
namespace DB
{
namespace ErrorCodes
@ -31,7 +23,6 @@ namespace ErrorCodes
extern const int TYPE_MISMATCH;
extern const int BAD_ARGUMENTS;
extern const int DICTIONARY_IS_EMPTY;
extern const int NOT_IMPLEMENTED;
}
static void validateKeyTypes(const DataTypes & key_types)
@ -45,6 +36,18 @@ static void validateKeyTypes(const DataTypes & key_types)
throw Exception{"Key does not match, expected either UInt32 or FixedString(16)", ErrorCodes::TYPE_MISMATCH};
}
/// Build an IPAddress from IPV6_BINARY_LENGTH raw bytes.
/// If IPv4ToBinary succeeds on that address, the result is an IPv4 address built
/// from the extracted four bytes; otherwise the IPv6 address is returned as is.
static Poco::Net::IPAddress ip4or6fromBytes(const uint8_t * data)
{
    Poco::Net::IPAddress address_v6(data, IPV6_BINARY_LENGTH);

    bool converted_to_v4 = false;
    const auto v4_bytes = IPv4ToBinary(address_v6, converted_to_v4);
    if (!converted_to_v4)
        return address_v6;

    return Poco::Net::IPAddress(&v4_bytes, IPV4_BINARY_LENGTH);
}
TrieDictionary::TrieDictionary(
const StorageID & dict_id_,
@ -57,17 +60,16 @@ TrieDictionary::TrieDictionary(
, source_ptr{std::move(source_ptr_)}
, dict_lifetime(dict_lifetime_)
, require_nonempty(require_nonempty_)
, total_ip_length(0)
, logger(&Poco::Logger::get("TrieDictionary"))
{
createAttributes();
trie = btrie_create();
loadData();
calculateBytesAllocated();
}
TrieDictionary::~TrieDictionary()
{
btrie_destroy(trie);
}
#define DECLARE(TYPE) \
@ -305,7 +307,8 @@ void TrieDictionary::loadData()
/// created upfront to avoid excess allocations
const auto keys_size = dict_struct.key->size();
StringRefs keys(keys_size);
ip_records.reserve(keys_size);
const auto attributes_size = attributes.size();
@ -331,11 +334,42 @@ void TrieDictionary::loadData()
{
const auto & attribute_column = *attribute_column_ptrs[attribute_idx];
auto & attribute = attributes[attribute_idx];
setAttributeValue(attribute, key_column->getDataAt(row_idx), attribute_column[row_idx]);
setAttributeValue(attribute, attribute_column[row_idx]);
}
size_t row_number = ip_records.size();
std::string addr_str(key_column->getDataAt(row_idx).toString());
size_t pos = addr_str.find('/');
if (pos != std::string::npos)
{
IPAddress addr(addr_str.substr(0, pos));
UInt8 prefix = std::stoi(addr_str.substr(pos + 1), nullptr, 10);
addr = addr & IPAddress(prefix, addr.family());
ip_records.emplace_back(IPRecord{addr, prefix, row_number});
}
else
{
IPAddress addr(addr_str);
UInt8 prefix = addr.length() * 8;
ip_records.emplace_back(IPRecord{addr, prefix, row_number});
}
total_ip_length += ip_records.back().addr.length();
}
}
LOG_TRACE(logger, "{} ip records are read", ip_records.size());
std::sort(ip_records.begin(), ip_records.end(), [](const auto & a, const auto & b)
{
if (a.addr.family() != b.addr.family())
return a.addr.family() < b.addr.family();
if (a.addr == b.addr)
return a.prefix > b.prefix;
return a.addr < b.addr;
});
stream->readSuffix();
if (require_nonempty && 0 == element_count)
@ -352,6 +386,8 @@ void TrieDictionary::addAttributeSize(const Attribute & attribute)
void TrieDictionary::calculateBytesAllocated()
{
bytes_allocated += ip_records.size() * sizeof(ip_records.front());
bytes_allocated += total_ip_length;
bytes_allocated += attributes.size() * sizeof(attributes.front());
for (const auto & attribute : attributes)
@ -411,8 +447,6 @@ void TrieDictionary::calculateBytesAllocated()
}
}
}
bytes_allocated += btrie_allocated(trie);
}
@ -494,16 +528,15 @@ void TrieDictionary::getItemsImpl(
const auto first_column = key_columns.front();
const auto rows = first_column->size();
if (first_column->isNumeric())
{
for (const auto i : ext::range(0, rows))
{
auto addr = Int32(first_column->get64(i));
uintptr_t slot = btrie_find(trie, addr);
#pragma GCC diagnostic push
#pragma GCC diagnostic warning "-Wold-style-cast"
set_value(i, slot != BTRIE_NULL ? static_cast<OutputType>(vec[slot]) : get_default(i));
#pragma GCC diagnostic pop
auto addr = Poco::ByteOrder::toNetwork(UInt32(first_column->get64(i)));
auto ipaddr = IPAddress(reinterpret_cast<const uint8_t *>(&addr), IPV4_BINARY_LENGTH);
auto found = lookupIPRecord(ipaddr);
set_value(i, (found != ipRecordNotFound()) ? static_cast<OutputType>(vec[found->row]) : get_default(i));
}
}
else
@ -511,107 +544,66 @@ void TrieDictionary::getItemsImpl(
for (const auto i : ext::range(0, rows))
{
auto addr = first_column->getDataAt(i);
if (addr.size != 16)
if (addr.size != IPV6_BINARY_LENGTH)
throw Exception("Expected key to be FixedString(16)", ErrorCodes::LOGICAL_ERROR);
uintptr_t slot = btrie_find_a6(trie, reinterpret_cast<const uint8_t *>(addr.data));
#pragma GCC diagnostic push
#pragma GCC diagnostic warning "-Wold-style-cast"
set_value(i, slot != BTRIE_NULL ? static_cast<OutputType>(vec[slot]) : get_default(i));
#pragma GCC diagnostic pop
auto ipaddr = ip4or6fromBytes(reinterpret_cast<const uint8_t *>(addr.data));
auto found = lookupIPRecord(ipaddr);
set_value(i, (found != ipRecordNotFound()) ? static_cast<OutputType>(vec[found->row]) : get_default(i));
}
}
query_count.fetch_add(rows, std::memory_order_relaxed);
}
template <typename T>
bool TrieDictionary::setAttributeValueImpl(Attribute & attribute, const StringRef key, const T value)
void TrieDictionary::setAttributeValueImpl(Attribute & attribute, const T value)
{
// Insert value into appropriate vector type
auto & vec = std::get<ContainerType<T>>(attribute.maps);
size_t row = vec.size();
vec.push_back(value);
// Parse IP address and subnet length from string (e.g. 2a02:6b8::3/64)
Poco::Net::IPAddress addr, mask;
std::string addr_str(key.toString());
size_t pos = addr_str.find('/');
if (pos != std::string::npos)
{
addr = Poco::Net::IPAddress(addr_str.substr(0, pos));
mask = Poco::Net::IPAddress(std::stoi(addr_str.substr(pos + 1), nullptr, 10), addr.family());
}
else
{
addr = Poco::Net::IPAddress(addr_str);
mask = Poco::Net::IPAddress(addr.length() * 8, addr.family());
}
/*
* Here we might overwrite the same key with the same slot as each key can map to multiple attributes.
* However, all columns have equal number of rows so it is okay to store only row number for each key
* instead of building a trie for each column. This comes at the cost of additional lookup in attribute
* vector on lookup time to return cell from row + column. The reason for this is to save space,
* and build only single trie instead of trie for each column.
*/
if (addr.family() == Poco::Net::IPAddress::IPv4)
{
UInt32 addr_v4 = Poco::ByteOrder::toNetwork(*reinterpret_cast<const UInt32 *>(addr.addr()));
UInt32 mask_v4 = Poco::ByteOrder::toNetwork(*reinterpret_cast<const UInt32 *>(mask.addr()));
return btrie_insert(trie, addr_v4, mask_v4, row) == 0;
}
const uint8_t * addr_v6 = reinterpret_cast<const uint8_t *>(addr.addr());
const uint8_t * mask_v6 = reinterpret_cast<const uint8_t *>(mask.addr());
return btrie_insert_a6(trie, addr_v6, mask_v6, row) == 0;
}
bool TrieDictionary::setAttributeValue(Attribute & attribute, const StringRef key, const Field & value)
void TrieDictionary::setAttributeValue(Attribute & attribute, const Field & value)
{
switch (attribute.type)
{
case AttributeUnderlyingType::utUInt8:
return setAttributeValueImpl<UInt8>(attribute, key, value.get<UInt64>());
return setAttributeValueImpl<UInt8>(attribute, value.get<UInt64>());
case AttributeUnderlyingType::utUInt16:
return setAttributeValueImpl<UInt16>(attribute, key, value.get<UInt64>());
return setAttributeValueImpl<UInt16>(attribute, value.get<UInt64>());
case AttributeUnderlyingType::utUInt32:
return setAttributeValueImpl<UInt32>(attribute, key, value.get<UInt64>());
return setAttributeValueImpl<UInt32>(attribute, value.get<UInt64>());
case AttributeUnderlyingType::utUInt64:
return setAttributeValueImpl<UInt64>(attribute, key, value.get<UInt64>());
return setAttributeValueImpl<UInt64>(attribute, value.get<UInt64>());
case AttributeUnderlyingType::utUInt128:
return setAttributeValueImpl<UInt128>(attribute, key, value.get<UInt128>());
return setAttributeValueImpl<UInt128>(attribute, value.get<UInt128>());
case AttributeUnderlyingType::utInt8:
return setAttributeValueImpl<Int8>(attribute, key, value.get<Int64>());
return setAttributeValueImpl<Int8>(attribute, value.get<Int64>());
case AttributeUnderlyingType::utInt16:
return setAttributeValueImpl<Int16>(attribute, key, value.get<Int64>());
return setAttributeValueImpl<Int16>(attribute, value.get<Int64>());
case AttributeUnderlyingType::utInt32:
return setAttributeValueImpl<Int32>(attribute, key, value.get<Int64>());
return setAttributeValueImpl<Int32>(attribute, value.get<Int64>());
case AttributeUnderlyingType::utInt64:
return setAttributeValueImpl<Int64>(attribute, key, value.get<Int64>());
return setAttributeValueImpl<Int64>(attribute, value.get<Int64>());
case AttributeUnderlyingType::utFloat32:
return setAttributeValueImpl<Float32>(attribute, key, value.get<Float64>());
return setAttributeValueImpl<Float32>(attribute, value.get<Float64>());
case AttributeUnderlyingType::utFloat64:
return setAttributeValueImpl<Float64>(attribute, key, value.get<Float64>());
return setAttributeValueImpl<Float64>(attribute, value.get<Float64>());
case AttributeUnderlyingType::utDecimal32:
return setAttributeValueImpl<Decimal32>(attribute, key, value.get<Decimal32>());
return setAttributeValueImpl<Decimal32>(attribute, value.get<Decimal32>());
case AttributeUnderlyingType::utDecimal64:
return setAttributeValueImpl<Decimal64>(attribute, key, value.get<Decimal64>());
return setAttributeValueImpl<Decimal64>(attribute, value.get<Decimal64>());
case AttributeUnderlyingType::utDecimal128:
return setAttributeValueImpl<Decimal128>(attribute, key, value.get<Decimal128>());
return setAttributeValueImpl<Decimal128>(attribute, value.get<Decimal128>());
case AttributeUnderlyingType::utString:
{
const auto & string = value.get<String>();
const auto * string_in_arena = attribute.string_arena->insert(string.data(), string.size());
setAttributeValueImpl<StringRef>(attribute, key, StringRef{string_in_arena, string.size()});
return true;
return setAttributeValueImpl<StringRef>(attribute, StringRef{string_in_arena, string.size()});
}
}
return {};
}
const TrieDictionary::Attribute & TrieDictionary::getAttribute(const std::string & attribute_name) const
@ -633,11 +625,9 @@ void TrieDictionary::has(const Attribute &, const Columns & key_columns, PaddedP
for (const auto i : ext::range(0, rows))
{
auto addr = Int32(first_column->get64(i));
uintptr_t slot = btrie_find(trie, addr);
#pragma GCC diagnostic push
#pragma GCC diagnostic warning "-Wold-style-cast"
out[i] = (slot != BTRIE_NULL);
#pragma GCC diagnostic pop
auto ipaddr = IPAddress(reinterpret_cast<const uint8_t *>(&addr), IPV4_BINARY_LENGTH);
auto found = lookupIPRecord(ipaddr);
out[i] = (found != ipRecordNotFound());
}
}
else
@ -648,78 +638,27 @@ void TrieDictionary::has(const Attribute &, const Columns & key_columns, PaddedP
if (unlikely(addr.size != 16))
throw Exception("Expected key to be FixedString(16)", ErrorCodes::LOGICAL_ERROR);
uintptr_t slot = btrie_find_a6(trie, reinterpret_cast<const uint8_t *>(addr.data));
#pragma GCC diagnostic push
#pragma GCC diagnostic warning "-Wold-style-cast"
out[i] = (slot != BTRIE_NULL);
#pragma GCC diagnostic pop
auto ipaddr = ip4or6fromBytes(reinterpret_cast<const uint8_t *>(addr.data));
auto found = lookupIPRecord(ipaddr);
out[i] = (found != ipRecordNotFound());
}
}
query_count.fetch_add(rows, std::memory_order_relaxed);
}
/// Iteratively walks the binary trie with an explicit stack (left subtrees first) and calls
/// getter(key, depth) for every node whose value is not BTRIE_NULL.
/// `key` accumulates one bit per right-hand branch taken on the way to the node, and
/// `depth` is the stack size at the moment the node has been popped.
template <typename Getter, typename KeyType>
static void trieTraverse(const btrie_t * trie, Getter && getter)
{
KeyType key = 0;
/// For an unsigned KeyType this is the value with only the most significant bit set.
const KeyType high_bit = ~((~key) >> 1);
btrie_node_t * node;
node = trie->root;
std::stack<btrie_node_t *> stack;
/// Descend along the left-most path, remembering every node on it.
while (node)
{
stack.push(node);
node = node->left;
}
/// Bit of the key that corresponds to the given stack size (0 for an empty stack).
auto get_bit = [&high_bit](size_t size) { return size ? (high_bit >> (size - 1)) : 0; };
while (!stack.empty())
{
node = stack.top();
stack.pop();
#pragma GCC diagnostic push
#pragma GCC diagnostic warning "-Wold-style-cast"
if (node && node->value != BTRIE_NULL)
#pragma GCC diagnostic pop
getter(key, stack.size());
if (node && node->right)
{
/// A nullptr entry is pushed before the right child; when it is popped later it only
/// clears its bit in the else-branch below (presumably a placeholder that keeps the
/// stack size in step with the depth of the right subtree).
stack.push(nullptr);
key |= get_bit(stack.size());
stack.push(node->right);
/// Then descend along the left-most path of the right subtree.
while (stack.top()->left)
stack.push(stack.top()->left);
}
else
/// No right subtree (or a nullptr entry): clear the bit of the current level.
key &= ~get_bit(stack.size());
}
}
Columns TrieDictionary::getKeyColumns() const
{
auto ip_column = ColumnFixedString::create(IPV6_BINARY_LENGTH);
auto mask_column = ColumnVector<UInt8>::create();
#if defined(__SIZEOF_INT128__)
auto getter = [&ip_column, &mask_column](__uint128_t ip, size_t mask)
for (const auto & record : ip_records)
{
Poco::UInt64 * ip_array = reinterpret_cast<Poco::UInt64 *>(&ip); // Poco:: for old poco + macos
ip_array[0] = Poco::ByteOrder::fromNetwork(ip_array[0]);
ip_array[1] = Poco::ByteOrder::fromNetwork(ip_array[1]);
std::swap(ip_array[0], ip_array[1]);
ip_column->insertData(reinterpret_cast<const char *>(ip_array), IPV6_BINARY_LENGTH);
mask_column->insertValue(static_cast<UInt8>(mask));
};
auto ip_array = IPv6ToBinary(record.addr);
ip_column->insertData(ip_array.data(), IPV6_BINARY_LENGTH);
mask_column->insertValue(record.prefix);
}
trieTraverse<decltype(getter), __uint128_t>(trie, std::move(getter));
#else
throw Exception("TrieDictionary::getKeyColumns is not implemented for 32bit arch", ErrorCodes::NOT_IMPLEMENTED);
#endif
return {std::move(ip_column), std::move(mask_column)};
}
@ -755,6 +694,45 @@ BlockInputStreamPtr TrieDictionary::getBlockInputStream(const Names & column_nam
shared_from_this(), max_block_size, getKeyColumns(), column_names, std::move(get_keys), std::move(get_view));
}
int TrieDictionary::matchIPAddrWithRecord(const IPAddress & ipaddr, const IPRecord & record) const
{
if (ipaddr.family() != record.addr.family())
return ipaddr.family() < record.addr.family() ? -1 : 1;
auto masked_ipaddr = ipaddr & IPAddress(record.prefix, record.addr.family());
if (masked_ipaddr < record.addr)
return -1;
if (masked_ipaddr == record.addr)
return 0;
return 1;
}
/// Sentinel iterator returned by lookupIPRecord when no record matches:
/// the past-the-end iterator of ip_records.
TrieDictionary::IPRecordConstIt TrieDictionary::ipRecordNotFound() const
{
    return ip_records.cend();
}
/// Find the record of the most specific (longest prefix) subnet that contains `target`.
/// Returns ipRecordNotFound() if no stored subnet contains it.
///
/// Checking only the record that directly precedes the upper bound of `target` is not
/// enough: with records 10.0.0.0/8 and 10.1.0.0/16 the address 10.2.0.1 has 10.1.0.0/16
/// as its predecessor, which does not match, although 10.0.0.0/8 does. So every prefix
/// length is tried, from the longest to the shortest, and the record with exactly that
/// (masked address, prefix) pair is located by binary search.
///
/// Relies on ip_records being sorted by family, then by address, then by prefix in
/// descending order, and on record addresses being stored already masked by their prefix.
/// Cost is O(address bits * log(number of records)) per lookup.
TrieDictionary::IPRecordConstIt TrieDictionary::lookupIPRecord(const IPAddress & target) const
{
    if (ip_records.empty())
        return ipRecordNotFound();

    const auto family = target.family();
    const int max_prefix = static_cast<int>(target.length()) * 8;

    for (int prefix = max_prefix; prefix >= 0; --prefix)
    {
        const IPAddress subnet = target & IPAddress(static_cast<unsigned>(prefix), family);

        /// True for records that are ordered strictly before (family, subnet, prefix).
        auto ordered_before = [&](const IPRecord & record) -> bool
        {
            if (record.addr.family() != family)
                return record.addr.family() < family;
            if (record.addr == subnet)
                return record.prefix > prefix;
            return record.addr < subnet;
        };

        auto candidate = std::partition_point(ip_records.begin(), ip_records.end(), ordered_before);

        if (candidate != ip_records.end()
            && candidate->addr.family() == family
            && candidate->prefix == prefix
            && candidate->addr == subnet)
            return candidate;
    }

    return ipRecordNotFound();
}
void registerDictionaryTrie(DictionaryFactory & factory)
{

View File

@ -7,6 +7,7 @@
#include <Columns/ColumnString.h>
#include <Common/Arena.h>
#include <Common/HashTable/HashMap.h>
#include <Poco/Net/IPAddress.h>
#include <common/StringRef.h>
#include <common/logger_useful.h>
#include <ext/range.h>
@ -14,9 +15,6 @@
#include "IDictionary.h"
#include "IDictionarySource.h"
struct btrie_s;
typedef struct btrie_s btrie_t;
namespace DB
{
class TrieDictionary final : public IDictionaryBase
@ -150,9 +148,22 @@ public:
BlockInputStreamPtr getBlockInputStream(const Names & column_names, size_t max_block_size) const override;
private:
template <typename Value>
using ContainerType = std::vector<Value>;
using IPAddress = Poco::Net::IPAddress;
struct IPRecord;
using IPRecordConstIt = ContainerType<IPRecord>::const_iterator;
/// One key of the dictionary: a subnet and the row that holds its attribute values.
/// The member order matters: records are created with aggregate initialization.
struct IPRecord final
{
/// Address of the subnet (loadData stores it already masked with the prefix).
IPAddress addr;
/// Prefix length in bits.
UInt8 prefix;
/// Index of the row in the attribute containers.
size_t row;
};
struct Attribute final
{
AttributeUnderlyingType type;
@ -212,11 +223,10 @@ private:
void
getItemsImpl(const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const;
template <typename T>
bool setAttributeValueImpl(Attribute & attribute, const StringRef key, const T value);
void setAttributeValueImpl(Attribute & attribute, const T value);
bool setAttributeValue(Attribute & attribute, const StringRef key, const Field & value);
void setAttributeValue(Attribute & attribute, const Field & value);
const Attribute & getAttribute(const std::string & attribute_name) const;
@ -225,14 +235,27 @@ private:
Columns getKeyColumns() const;
/**
* Compare an ip address with the subnet stored in a record.
* Addresses of different families are ordered by family; otherwise ipaddr is
* masked with the prefix of the record before it is compared with the record address.
*
* @return negative value if masked ipaddr is less than the address in record
* @return zero if ipaddr belongs to the subnet of the record
* @return positive value if masked ipaddr is greater than the address in record
*/
int matchIPAddrWithRecord(const IPAddress & ipaddr, const IPRecord & record) const;
IPRecordConstIt ipRecordNotFound() const;
IPRecordConstIt lookupIPRecord(const IPAddress & target) const;
const DictionaryStructure dict_struct;
const DictionarySourcePtr source_ptr;
const DictionaryLifetime dict_lifetime;
const bool require_nonempty;
const std::string key_description{dict_struct.getKeyDescription()};
ContainerType<IPRecord> ip_records;
size_t total_ip_length;
btrie_t * trie = nullptr;
std::map<std::string, size_t> attribute_index_by_name;
std::vector<Attribute> attributes;

View File

@ -10,6 +10,11 @@
0
***ip trie dict***
17501
17501
17502
0
11211
11211
NP
***hierarchy dict***
Moscow

View File

@ -82,8 +82,7 @@ CREATE TABLE database_for_dict.table_ip_trie
)
engine = TinyLog;
INSERT INTO database_for_dict.table_ip_trie VALUES ('202.79.32.0/20', 17501, 'NP'), ('2620:0:870::/48', 3856, 'US'), ('2a02:6b8:1::/48', 13238, 'RU'), ('2001:db8::/32', 65536, 'ZZ');
INSERT INTO database_for_dict.table_ip_trie VALUES ('202.79.32.0/20', 17501, 'NP'), ('202.79.32.2', 17502, 'NP'), ('101.79.55.22', 11211, 'UK'), ('2620:0:870::/48', 3856, 'US'), ('2a02:6b8:1::/48', 13238, 'RU'), ('2001:db8::/32', 65536, 'ZZ');
CREATE DICTIONARY database_for_dict.dict_ip_trie
(
@ -97,6 +96,11 @@ LAYOUT(IP_TRIE())
LIFETIME(MIN 10 MAX 100);
SELECT dictGetUInt32('database_for_dict.dict_ip_trie', 'asn', tuple(IPv4StringToNum('202.79.32.0')));
SELECT dictGetUInt32('database_for_dict.dict_ip_trie', 'asn', tuple(IPv4StringToNum('202.79.32.1')));
SELECT dictGetUInt32('database_for_dict.dict_ip_trie', 'asn', tuple(IPv4StringToNum('202.79.32.2')));
SELECT dictHas('database_for_dict.dict_ip_trie', tuple(IPv6StringToNum('654f:3716::')));
SELECT dictGetUInt32('database_for_dict.dict_ip_trie', 'asn', tuple(IPv6StringToNum('::ffff:654f:3716')));
SELECT dictGetUInt32('database_for_dict.dict_ip_trie', 'asn', tuple(IPv6StringToNum('::ffff:101.79.55.22')));
SELECT dictGetString('database_for_dict.dict_ip_trie', 'cca2', tuple(IPv4StringToNum('202.79.32.0')));
SELECT '***hierarchy dict***';