From a67f5b780f0582f88951bbd38e9f884ab874bbd7 Mon Sep 17 00:00:00 2001 From: vdimir Date: Sun, 8 Nov 2020 19:01:12 +0300 Subject: [PATCH] Use sorted ip array instead of trie in TrieDictionary --- programs/server/config.xml | 2 +- src/Common/IPv6ToBinary.cpp | 27 ++ src/Common/IPv6ToBinary.h | 5 + src/Dictionaries/TrieDictionary.cpp | 274 ++++++++---------- src/Dictionaries/TrieDictionary.h | 37 ++- .../01018_ddl_dictionaries_special.reference | 5 + .../01018_ddl_dictionaries_special.sql | 8 +- 7 files changed, 200 insertions(+), 158 deletions(-) diff --git a/programs/server/config.xml b/programs/server/config.xml index e17b59671af..a03270aa7b9 100644 --- a/programs/server/config.xml +++ b/programs/server/config.xml @@ -675,7 +675,7 @@ *_dictionary.xml diff --git a/src/Common/IPv6ToBinary.cpp b/src/Common/IPv6ToBinary.cpp index bfa6992de9e..00c1b520a7a 100644 --- a/src/Common/IPv6ToBinary.cpp +++ b/src/Common/IPv6ToBinary.cpp @@ -1,5 +1,7 @@ #include "IPv6ToBinary.h" #include +#include + #include @@ -28,4 +30,29 @@ std::array IPv6ToBinary(const Poco::Net::IPAddress & address) return res; } + +UInt32 IPv4ToBinary(const Poco::Net::IPAddress & address, bool & success) +{ + if (!address.isIPv4Mapped()) + { + success = false; + return 0; + } + + success = true; + if (Poco::Net::IPAddress::IPv6 == address.family()) + { + auto raw = reinterpret_cast(address.addr()); + return *reinterpret_cast(&raw[12]); + } + else if (Poco::Net::IPAddress::IPv4 == address.family()) + { + auto raw = reinterpret_cast(address.addr()); + return *reinterpret_cast(raw); + } + + success = false; + return 0; +} + } diff --git a/src/Common/IPv6ToBinary.h b/src/Common/IPv6ToBinary.h index e95dfa10223..4f2cdd0ea21 100644 --- a/src/Common/IPv6ToBinary.h +++ b/src/Common/IPv6ToBinary.h @@ -1,5 +1,6 @@ #pragma once #include +#include namespace Poco { namespace Net { class IPAddress; }} @@ -9,4 +10,8 @@ namespace DB /// Convert IP address to 16-byte array with IPv6 data (big endian). If it's an IPv4, map it to IPv6. std::array IPv6ToBinary(const Poco::Net::IPAddress & address); +/// Convert IP address to UInt32 (big endian) if it's IPv4 or IPv4 mapped to IPv6. +/// Sets success variable to true if succeed. +UInt32 IPv4ToBinary(const Poco::Net::IPAddress & address, bool & success); + } diff --git a/src/Dictionaries/TrieDictionary.cpp b/src/Dictionaries/TrieDictionary.cpp index d8267047b92..a2e2bcf0bd8 100644 --- a/src/Dictionaries/TrieDictionary.cpp +++ b/src/Dictionaries/TrieDictionary.cpp @@ -3,11 +3,11 @@ #include #include #include +#include #include #include #include #include -#include #include #include #include @@ -15,14 +15,6 @@ #include "DictionaryBlockInputStream.h" #include "DictionaryFactory.h" -#ifdef __clang__ - #pragma clang diagnostic ignored "-Wold-style-cast" - #pragma clang diagnostic ignored "-Wnewline-eof" -#endif - -#include - - namespace DB { namespace ErrorCodes @@ -31,7 +23,6 @@ namespace ErrorCodes extern const int TYPE_MISMATCH; extern const int BAD_ARGUMENTS; extern const int DICTIONARY_IS_EMPTY; - extern const int NOT_IMPLEMENTED; } static void validateKeyTypes(const DataTypes & key_types) @@ -45,6 +36,18 @@ static void validateKeyTypes(const DataTypes & key_types) throw Exception{"Key does not match, expected either UInt32 or FixedString(16)", ErrorCodes::TYPE_MISMATCH}; } +/// Create IPAddress from 16 byte array converting to ipv4 if possible +static Poco::Net::IPAddress ip4or6fromBytes(const uint8_t * data) +{ + Poco::Net::IPAddress ipaddr(reinterpret_cast(data), IPV6_BINARY_LENGTH); + + // try to consider as ipv4 + bool is_v4 = false; + if (auto addr_v4 = IPv4ToBinary(ipaddr, is_v4); is_v4) + return Poco::Net::IPAddress(reinterpret_cast(&addr_v4), IPV4_BINARY_LENGTH); + + return ipaddr; +} TrieDictionary::TrieDictionary( const StorageID & dict_id_, @@ -57,17 +60,16 @@ TrieDictionary::TrieDictionary( , source_ptr{std::move(source_ptr_)} , dict_lifetime(dict_lifetime_) , require_nonempty(require_nonempty_) + , total_ip_length(0) , logger(&Poco::Logger::get("TrieDictionary")) { createAttributes(); - trie = btrie_create(); loadData(); calculateBytesAllocated(); } TrieDictionary::~TrieDictionary() { - btrie_destroy(trie); } #define DECLARE(TYPE) \ @@ -305,7 +307,8 @@ void TrieDictionary::loadData() /// created upfront to avoid excess allocations const auto keys_size = dict_struct.key->size(); - StringRefs keys(keys_size); + + ip_records.reserve(keys_size); const auto attributes_size = attributes.size(); @@ -331,11 +334,42 @@ void TrieDictionary::loadData() { const auto & attribute_column = *attribute_column_ptrs[attribute_idx]; auto & attribute = attributes[attribute_idx]; - setAttributeValue(attribute, key_column->getDataAt(row_idx), attribute_column[row_idx]); + + setAttributeValue(attribute, attribute_column[row_idx]); } + + size_t row_number = ip_records.size(); + + std::string addr_str(key_column->getDataAt(row_idx).toString()); + size_t pos = addr_str.find('/'); + if (pos != std::string::npos) + { + IPAddress addr(addr_str.substr(0, pos)); + UInt8 prefix = std::stoi(addr_str.substr(pos + 1), nullptr, 10); + addr = addr & IPAddress(prefix, addr.family()); + ip_records.emplace_back(IPRecord{addr, prefix, row_number}); + } + else + { + IPAddress addr(addr_str); + UInt8 prefix = addr.length() * 8; + ip_records.emplace_back(IPRecord{addr, prefix, row_number}); + } + total_ip_length += ip_records.back().addr.length(); } } + LOG_TRACE(logger, "{} ip records are read", ip_records.size()); + + std::sort(ip_records.begin(), ip_records.end(), [](const auto & a, const auto & b) + { + if (a.addr.family() != b.addr.family()) + return a.addr.family() < b.addr.family(); + if (a.addr == b.addr) + return a.prefix > b.prefix; + return a.addr < b.addr; + }); + stream->readSuffix(); if (require_nonempty && 0 == element_count) @@ -352,6 +386,8 @@ void TrieDictionary::addAttributeSize(const Attribute & attribute) void TrieDictionary::calculateBytesAllocated() { + bytes_allocated += ip_records.size() * sizeof(ip_records.front()); + bytes_allocated += total_ip_length; bytes_allocated += attributes.size() * sizeof(attributes.front()); for (const auto & attribute : attributes) @@ -411,8 +447,6 @@ void TrieDictionary::calculateBytesAllocated() } } } - - bytes_allocated += btrie_allocated(trie); } @@ -494,16 +528,15 @@ void TrieDictionary::getItemsImpl( const auto first_column = key_columns.front(); const auto rows = first_column->size(); + if (first_column->isNumeric()) { for (const auto i : ext::range(0, rows)) { - auto addr = Int32(first_column->get64(i)); - uintptr_t slot = btrie_find(trie, addr); -#pragma GCC diagnostic push -#pragma GCC diagnostic warning "-Wold-style-cast" - set_value(i, slot != BTRIE_NULL ? static_cast(vec[slot]) : get_default(i)); -#pragma GCC diagnostic pop + auto addr = Poco::ByteOrder::toNetwork(UInt32(first_column->get64(i))); + auto ipaddr = IPAddress(reinterpret_cast(&addr), IPV4_BINARY_LENGTH); + auto found = lookupIPRecord(ipaddr); + set_value(i, (found != ipRecordNotFound()) ? static_cast(vec[found->row]) : get_default(i)); } } else @@ -511,107 +544,66 @@ void TrieDictionary::getItemsImpl( for (const auto i : ext::range(0, rows)) { auto addr = first_column->getDataAt(i); - if (addr.size != 16) + if (addr.size != IPV6_BINARY_LENGTH) throw Exception("Expected key to be FixedString(16)", ErrorCodes::LOGICAL_ERROR); - uintptr_t slot = btrie_find_a6(trie, reinterpret_cast(addr.data)); -#pragma GCC diagnostic push -#pragma GCC diagnostic warning "-Wold-style-cast" - set_value(i, slot != BTRIE_NULL ? static_cast(vec[slot]) : get_default(i)); -#pragma GCC diagnostic pop + auto ipaddr = ip4or6fromBytes(reinterpret_cast(addr.data)); + auto found = lookupIPRecord(ipaddr); + set_value(i, (found != ipRecordNotFound()) ? static_cast(vec[found->row]) : get_default(i)); } } query_count.fetch_add(rows, std::memory_order_relaxed); } - template -bool TrieDictionary::setAttributeValueImpl(Attribute & attribute, const StringRef key, const T value) +void TrieDictionary::setAttributeValueImpl(Attribute & attribute, const T value) { - // Insert value into appropriate vector type auto & vec = std::get>(attribute.maps); - size_t row = vec.size(); vec.push_back(value); - - // Parse IP address and subnet length from string (e.g. 2a02:6b8::3/64) - Poco::Net::IPAddress addr, mask; - std::string addr_str(key.toString()); - size_t pos = addr_str.find('/'); - if (pos != std::string::npos) - { - addr = Poco::Net::IPAddress(addr_str.substr(0, pos)); - mask = Poco::Net::IPAddress(std::stoi(addr_str.substr(pos + 1), nullptr, 10), addr.family()); - } - else - { - addr = Poco::Net::IPAddress(addr_str); - mask = Poco::Net::IPAddress(addr.length() * 8, addr.family()); - } - - /* - * Here we might overwrite the same key with the same slot as each key can map to multiple attributes. - * However, all columns have equal number of rows so it is okay to store only row number for each key - * instead of building a trie for each column. This comes at the cost of additional lookup in attribute - * vector on lookup time to return cell from row + column. The reason for this is to save space, - * and build only single trie instead of trie for each column. - */ - if (addr.family() == Poco::Net::IPAddress::IPv4) - { - UInt32 addr_v4 = Poco::ByteOrder::toNetwork(*reinterpret_cast(addr.addr())); - UInt32 mask_v4 = Poco::ByteOrder::toNetwork(*reinterpret_cast(mask.addr())); - return btrie_insert(trie, addr_v4, mask_v4, row) == 0; - } - - const uint8_t * addr_v6 = reinterpret_cast(addr.addr()); - const uint8_t * mask_v6 = reinterpret_cast(mask.addr()); - return btrie_insert_a6(trie, addr_v6, mask_v6, row) == 0; } -bool TrieDictionary::setAttributeValue(Attribute & attribute, const StringRef key, const Field & value) +void TrieDictionary::setAttributeValue(Attribute & attribute, const Field & value) { switch (attribute.type) { case AttributeUnderlyingType::utUInt8: - return setAttributeValueImpl(attribute, key, value.get()); + return setAttributeValueImpl(attribute, value.get()); case AttributeUnderlyingType::utUInt16: - return setAttributeValueImpl(attribute, key, value.get()); + return setAttributeValueImpl(attribute, value.get()); case AttributeUnderlyingType::utUInt32: - return setAttributeValueImpl(attribute, key, value.get()); + return setAttributeValueImpl(attribute, value.get()); case AttributeUnderlyingType::utUInt64: - return setAttributeValueImpl(attribute, key, value.get()); + return setAttributeValueImpl(attribute, value.get()); case AttributeUnderlyingType::utUInt128: - return setAttributeValueImpl(attribute, key, value.get()); + return setAttributeValueImpl(attribute, value.get()); case AttributeUnderlyingType::utInt8: - return setAttributeValueImpl(attribute, key, value.get()); + return setAttributeValueImpl(attribute, value.get()); case AttributeUnderlyingType::utInt16: - return setAttributeValueImpl(attribute, key, value.get()); + return setAttributeValueImpl(attribute, value.get()); case AttributeUnderlyingType::utInt32: - return setAttributeValueImpl(attribute, key, value.get()); + return setAttributeValueImpl(attribute, value.get()); case AttributeUnderlyingType::utInt64: - return setAttributeValueImpl(attribute, key, value.get()); + return setAttributeValueImpl(attribute, value.get()); case AttributeUnderlyingType::utFloat32: - return setAttributeValueImpl(attribute, key, value.get()); + return setAttributeValueImpl(attribute, value.get()); case AttributeUnderlyingType::utFloat64: - return setAttributeValueImpl(attribute, key, value.get()); + return setAttributeValueImpl(attribute, value.get()); case AttributeUnderlyingType::utDecimal32: - return setAttributeValueImpl(attribute, key, value.get()); + return setAttributeValueImpl(attribute, value.get()); case AttributeUnderlyingType::utDecimal64: - return setAttributeValueImpl(attribute, key, value.get()); + return setAttributeValueImpl(attribute, value.get()); case AttributeUnderlyingType::utDecimal128: - return setAttributeValueImpl(attribute, key, value.get()); + return setAttributeValueImpl(attribute, value.get()); case AttributeUnderlyingType::utString: { const auto & string = value.get(); const auto * string_in_arena = attribute.string_arena->insert(string.data(), string.size()); - setAttributeValueImpl(attribute, key, StringRef{string_in_arena, string.size()}); - return true; + return setAttributeValueImpl(attribute, StringRef{string_in_arena, string.size()}); } } - - return {}; } const TrieDictionary::Attribute & TrieDictionary::getAttribute(const std::string & attribute_name) const @@ -633,11 +625,9 @@ void TrieDictionary::has(const Attribute &, const Columns & key_columns, PaddedP for (const auto i : ext::range(0, rows)) { auto addr = Int32(first_column->get64(i)); - uintptr_t slot = btrie_find(trie, addr); -#pragma GCC diagnostic push -#pragma GCC diagnostic warning "-Wold-style-cast" - out[i] = (slot != BTRIE_NULL); -#pragma GCC diagnostic pop + auto ipaddr = IPAddress(reinterpret_cast(&addr), IPV4_BINARY_LENGTH); + auto found = lookupIPRecord(ipaddr); + out[i] = (found != ipRecordNotFound()); } } else @@ -648,78 +638,27 @@ void TrieDictionary::has(const Attribute &, const Columns & key_columns, PaddedP if (unlikely(addr.size != 16)) throw Exception("Expected key to be FixedString(16)", ErrorCodes::LOGICAL_ERROR); - uintptr_t slot = btrie_find_a6(trie, reinterpret_cast(addr.data)); -#pragma GCC diagnostic push -#pragma GCC diagnostic warning "-Wold-style-cast" - out[i] = (slot != BTRIE_NULL); -#pragma GCC diagnostic pop + auto ipaddr = ip4or6fromBytes(reinterpret_cast(addr.data)); + auto found = lookupIPRecord(ipaddr); + out[i] = (found != ipRecordNotFound()); } } query_count.fetch_add(rows, std::memory_order_relaxed); } -template -static void trieTraverse(const btrie_t * trie, Getter && getter) -{ - KeyType key = 0; - const KeyType high_bit = ~((~key) >> 1); - - btrie_node_t * node; - node = trie->root; - - std::stack stack; - while (node) - { - stack.push(node); - node = node->left; - } - - auto get_bit = [&high_bit](size_t size) { return size ? (high_bit >> (size - 1)) : 0; }; - - while (!stack.empty()) - { - node = stack.top(); - stack.pop(); -#pragma GCC diagnostic push -#pragma GCC diagnostic warning "-Wold-style-cast" - if (node && node->value != BTRIE_NULL) -#pragma GCC diagnostic pop - getter(key, stack.size()); - - if (node && node->right) - { - stack.push(nullptr); - key |= get_bit(stack.size()); - stack.push(node->right); - while (stack.top()->left) - stack.push(stack.top()->left); - } - else - key &= ~get_bit(stack.size()); - } -} - Columns TrieDictionary::getKeyColumns() const { auto ip_column = ColumnFixedString::create(IPV6_BINARY_LENGTH); auto mask_column = ColumnVector::create(); -#if defined(__SIZEOF_INT128__) - auto getter = [&ip_column, &mask_column](__uint128_t ip, size_t mask) + for (const auto & record : ip_records) { - Poco::UInt64 * ip_array = reinterpret_cast(&ip); // Poco:: for old poco + macos - ip_array[0] = Poco::ByteOrder::fromNetwork(ip_array[0]); - ip_array[1] = Poco::ByteOrder::fromNetwork(ip_array[1]); - std::swap(ip_array[0], ip_array[1]); - ip_column->insertData(reinterpret_cast(ip_array), IPV6_BINARY_LENGTH); - mask_column->insertValue(static_cast(mask)); - }; + auto ip_array = IPv6ToBinary(record.addr); + ip_column->insertData(ip_array.data(), IPV6_BINARY_LENGTH); + mask_column->insertValue(record.prefix); + } - trieTraverse(trie, std::move(getter)); -#else - throw Exception("TrieDictionary::getKeyColumns is not implemented for 32bit arch", ErrorCodes::NOT_IMPLEMENTED); -#endif return {std::move(ip_column), std::move(mask_column)}; } @@ -755,6 +694,45 @@ BlockInputStreamPtr TrieDictionary::getBlockInputStream(const Names & column_nam shared_from_this(), max_block_size, getKeyColumns(), column_names, std::move(get_keys), std::move(get_view)); } +int TrieDictionary::matchIPAddrWithRecord(const IPAddress & ipaddr, const IPRecord & record) const +{ + if (ipaddr.family() != record.addr.family()) + return ipaddr.family() < record.addr.family() ? -1 : 1; + + auto masked_ipaddr = ipaddr & IPAddress(record.prefix, record.addr.family()); + if (masked_ipaddr < record.addr) + return -1; + if (masked_ipaddr == record.addr) + return 0; + return 1; +} + +TrieDictionary::IPRecordConstIt TrieDictionary::ipRecordNotFound() const +{ + return ip_records.end(); +} + +TrieDictionary::IPRecordConstIt TrieDictionary::lookupIPRecord(const IPAddress & target) const +{ + if (ip_records.empty()) + return ipRecordNotFound(); + + auto comp = [&](const IPAddress & needle, const IPRecord & record) -> bool + { + return matchIPAddrWithRecord(needle, record) < 0; + }; + + auto next_it = std::upper_bound(ip_records.begin(), ip_records.end(), target, comp); + + if (next_it == ip_records.begin()) + return ipRecordNotFound(); + + auto found = next_it - 1; + if (matchIPAddrWithRecord(target, *found) == 0) + return found; + + return ipRecordNotFound(); +} void registerDictionaryTrie(DictionaryFactory & factory) { diff --git a/src/Dictionaries/TrieDictionary.h b/src/Dictionaries/TrieDictionary.h index 1849f161935..891dadd3be8 100644 --- a/src/Dictionaries/TrieDictionary.h +++ b/src/Dictionaries/TrieDictionary.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -14,9 +15,6 @@ #include "IDictionary.h" #include "IDictionarySource.h" -struct btrie_s; -typedef struct btrie_s btrie_t; - namespace DB { class TrieDictionary final : public IDictionaryBase @@ -150,9 +148,22 @@ public: BlockInputStreamPtr getBlockInputStream(const Names & column_names, size_t max_block_size) const override; private: + template using ContainerType = std::vector; + using IPAddress = Poco::Net::IPAddress; + + struct IPRecord; + using IPRecordConstIt = ContainerType::const_iterator; + + struct IPRecord final + { + IPAddress addr; + UInt8 prefix; + size_t row; + }; + struct Attribute final { AttributeUnderlyingType type; @@ -212,11 +223,10 @@ private: void getItemsImpl(const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const; - template - bool setAttributeValueImpl(Attribute & attribute, const StringRef key, const T value); + void setAttributeValueImpl(Attribute & attribute, const T value); - bool setAttributeValue(Attribute & attribute, const StringRef key, const Field & value); + void setAttributeValue(Attribute & attribute, const Field & value); const Attribute & getAttribute(const std::string & attribute_name) const; @@ -225,14 +235,27 @@ private: Columns getKeyColumns() const; + /** + * Compare ip addresses. + * + * @return negative value if ipaddr less than address in record + * @return zero if ipaddr in record subnet + * @return positive value if ipaddr greater than address in record + */ + int matchIPAddrWithRecord(const IPAddress & ipaddr, const IPRecord & record) const; + + IPRecordConstIt ipRecordNotFound() const; + IPRecordConstIt lookupIPRecord(const IPAddress & target) const; + const DictionaryStructure dict_struct; const DictionarySourcePtr source_ptr; const DictionaryLifetime dict_lifetime; const bool require_nonempty; const std::string key_description{dict_struct.getKeyDescription()}; + ContainerType ip_records; + size_t total_ip_length; - btrie_t * trie = nullptr; std::map attribute_index_by_name; std::vector attributes; diff --git a/tests/queries/0_stateless/01018_ddl_dictionaries_special.reference b/tests/queries/0_stateless/01018_ddl_dictionaries_special.reference index c6c6993faa8..a6332b85f4e 100644 --- a/tests/queries/0_stateless/01018_ddl_dictionaries_special.reference +++ b/tests/queries/0_stateless/01018_ddl_dictionaries_special.reference @@ -10,6 +10,11 @@ 0 ***ip trie dict*** 17501 +17501 +17502 +0 +11211 +11211 NP ***hierarchy dict*** Moscow diff --git a/tests/queries/0_stateless/01018_ddl_dictionaries_special.sql b/tests/queries/0_stateless/01018_ddl_dictionaries_special.sql index ede5897bdf7..6c4a325a3b5 100644 --- a/tests/queries/0_stateless/01018_ddl_dictionaries_special.sql +++ b/tests/queries/0_stateless/01018_ddl_dictionaries_special.sql @@ -82,8 +82,7 @@ CREATE TABLE database_for_dict.table_ip_trie ) engine = TinyLog; -INSERT INTO database_for_dict.table_ip_trie VALUES ('202.79.32.0/20', 17501, 'NP'), ('2620:0:870::/48', 3856, 'US'), ('2a02:6b8:1::/48', 13238, 'RU'), ('2001:db8::/32', 65536, 'ZZ'); - +INSERT INTO database_for_dict.table_ip_trie VALUES ('202.79.32.0/20', 17501, 'NP'), ('202.79.32.2', 17502, 'NP'), ('101.79.55.22', 11211, 'UK'), ('2620:0:870::/48', 3856, 'US'), ('2a02:6b8:1::/48', 13238, 'RU'), ('2001:db8::/32', 65536, 'ZZ'); CREATE DICTIONARY database_for_dict.dict_ip_trie ( @@ -97,6 +96,11 @@ LAYOUT(IP_TRIE()) LIFETIME(MIN 10 MAX 100); SELECT dictGetUInt32('database_for_dict.dict_ip_trie', 'asn', tuple(IPv4StringToNum('202.79.32.0'))); +SELECT dictGetUInt32('database_for_dict.dict_ip_trie', 'asn', tuple(IPv4StringToNum('202.79.32.1'))); +SELECT dictGetUInt32('database_for_dict.dict_ip_trie', 'asn', tuple(IPv4StringToNum('202.79.32.2'))); +SELECT dictHas('database_for_dict.dict_ip_trie', tuple(IPv6StringToNum('654f:3716::'))); +SELECT dictGetUInt32('database_for_dict.dict_ip_trie', 'asn', tuple(IPv6StringToNum('::ffff:654f:3716'))); +SELECT dictGetUInt32('database_for_dict.dict_ip_trie', 'asn', tuple(IPv6StringToNum('::ffff:101.79.55.22'))); SELECT dictGetString('database_for_dict.dict_ip_trie', 'cca2', tuple(IPv4StringToNum('202.79.32.0'))); SELECT '***hierarchy dict***';