Merge pull request #18435 from vdimir/ip-dict-minor-fix

Raise an error if more than one key is provided to ip_dictionary
This commit is contained in:
alexey-milovidov 2020-12-24 20:16:07 +03:00 committed by GitHub
commit 30e3900235
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 100 deletions

View File

@ -404,57 +404,32 @@ void IPAddressDictionary::has(const Columns & key_columns, const DataTypes & key
{
validateKeyTypes(key_types);
const auto & attribute = attributes.front();
switch (attribute.type)
const auto first_column = key_columns.front();
const auto rows = first_column->size();
if (first_column->isNumeric())
{
case AttributeUnderlyingType::utUInt8:
has<UInt8>(attribute, key_columns, out);
break;
case AttributeUnderlyingType::utUInt16:
has<UInt16>(attribute, key_columns, out);
break;
case AttributeUnderlyingType::utUInt32:
has<UInt32>(attribute, key_columns, out);
break;
case AttributeUnderlyingType::utUInt64:
has<UInt64>(attribute, key_columns, out);
break;
case AttributeUnderlyingType::utUInt128:
has<UInt128>(attribute, key_columns, out);
break;
case AttributeUnderlyingType::utInt8:
has<Int8>(attribute, key_columns, out);
break;
case AttributeUnderlyingType::utInt16:
has<Int16>(attribute, key_columns, out);
break;
case AttributeUnderlyingType::utInt32:
has<Int32>(attribute, key_columns, out);
break;
case AttributeUnderlyingType::utInt64:
has<Int64>(attribute, key_columns, out);
break;
case AttributeUnderlyingType::utFloat32:
has<Float32>(attribute, key_columns, out);
break;
case AttributeUnderlyingType::utFloat64:
has<Float64>(attribute, key_columns, out);
break;
case AttributeUnderlyingType::utString:
has<StringRef>(attribute, key_columns, out);
break;
case AttributeUnderlyingType::utDecimal32:
has<Decimal32>(attribute, key_columns, out);
break;
case AttributeUnderlyingType::utDecimal64:
has<Decimal64>(attribute, key_columns, out);
break;
case AttributeUnderlyingType::utDecimal128:
has<Decimal128>(attribute, key_columns, out);
break;
uint8_t addrv6_buf[IPV6_BINARY_LENGTH];
for (const auto i : ext::range(0, rows))
{
auto addrv4 = UInt32(first_column->get64(i));
auto found = tryLookupIPv4(addrv4, addrv6_buf);
out[i] = (found != ipNotFound());
}
}
else
{
for (const auto i : ext::range(0, rows))
{
auto addr = first_column->getDataAt(i);
if (unlikely(addr.size != IPV6_BINARY_LENGTH))
throw Exception("Expected key to be FixedString(16)", ErrorCodes::LOGICAL_ERROR);
auto found = tryLookupIPv6(reinterpret_cast<const uint8_t *>(addr.data));
out[i] = (found != ipNotFound());
}
}
query_count.fetch_add(rows, std::memory_order_relaxed);
}
void IPAddressDictionary::createAttributes()
@ -478,16 +453,10 @@ void IPAddressDictionary::loadData()
auto stream = source_ptr->loadAll();
stream->readPrefix();
/// created upfront to avoid excess allocations
const auto keys_size = dict_struct.key->size();
const auto attributes_size = attributes.size();
std::vector<IPRecord> ip_records;
row_idx.reserve(keys_size);
mask_column.reserve(keys_size);
bool has_ipv6 = false;
while (const auto block = stream->read())
@ -495,19 +464,14 @@ void IPAddressDictionary::loadData()
const auto rows = block.rows();
element_count += rows;
const auto key_column_ptrs = ext::map<Columns>(
ext::range(0, keys_size), [&](const size_t attribute_idx) { return block.safeGetByPosition(attribute_idx).column; });
const ColumnPtr key_column_ptr = block.safeGetByPosition(0).column;
const auto attribute_column_ptrs = ext::map<Columns>(ext::range(0, attributes_size), [&](const size_t attribute_idx)
{
return block.safeGetByPosition(keys_size + attribute_idx).column;
return block.safeGetByPosition(attribute_idx + 1).column;
});
for (const auto row : ext::range(0, rows))
{
/// calculate key once per row
const auto key_column = key_column_ptrs.front();
for (const auto attribute_idx : ext::range(0, attributes_size))
{
const auto & attribute_column = *attribute_column_ptrs[attribute_idx];
@ -516,7 +480,7 @@ void IPAddressDictionary::loadData()
setAttributeValue(attribute, attribute_column[row]);
}
const auto [addr, prefix] = parseIPFromString(std::string_view(key_column->getDataAt(row)));
const auto [addr, prefix] = parseIPFromString(std::string_view(key_column_ptr->getDataAt(row)));
has_ipv6 = has_ipv6 || (addr.family() == Poco::Net::IPAddress::IPv6);
size_t row_number = ip_records.size();
@ -526,6 +490,9 @@ void IPAddressDictionary::loadData()
stream->readSuffix();
row_idx.reserve(ip_records.size());
mask_column.reserve(ip_records.size());
if (has_ipv6)
{
auto deleted_count = sortAndUnique(ip_records,
@ -971,37 +938,6 @@ const IPAddressDictionary::Attribute & IPAddressDictionary::getAttribute(const s
return attributes[it->second];
}
template <typename T>
void IPAddressDictionary::has(const Attribute &, const Columns & key_columns, PaddedPODArray<UInt8> & out) const
{
const auto first_column = key_columns.front();
const auto rows = first_column->size();
if (first_column->isNumeric())
{
uint8_t addrv6_buf[IPV6_BINARY_LENGTH];
for (const auto i : ext::range(0, rows))
{
auto addrv4 = UInt32(first_column->get64(i));
auto found = tryLookupIPv4(addrv4, addrv6_buf);
out[i] = (found != ipNotFound());
}
}
else
{
for (const auto i : ext::range(0, rows))
{
auto addr = first_column->getDataAt(i);
if (unlikely(addr.size != IPV6_BINARY_LENGTH))
throw Exception("Expected key to be FixedString(16)", ErrorCodes::LOGICAL_ERROR);
auto found = tryLookupIPv6(reinterpret_cast<const uint8_t *>(addr.data));
out[i] = (found != ipNotFound());
}
}
query_count.fetch_add(rows, std::memory_order_relaxed);
}
Columns IPAddressDictionary::getKeyColumns() const
{
const auto * ipv4_col = std::get_if<IPv4Container>(&ip_column);
@ -1178,13 +1114,13 @@ void registerDictionaryTrie(DictionaryFactory & factory)
const std::string & config_prefix,
DictionarySourcePtr source_ptr) -> DictionaryPtr
{
if (!dict_struct.key)
throw Exception{"'key' is required for dictionary of layout 'ip_trie'", ErrorCodes::BAD_ARGUMENTS};
if (!dict_struct.key || dict_struct.key->size() != 1)
throw Exception{"Dictionary of layout 'ip_trie' has to have one 'key'", ErrorCodes::BAD_ARGUMENTS};
const auto dict_id = StorageID::fromDictionaryConfig(config, config_prefix);
const DictionaryLifetime dict_lifetime{config, config_prefix + ".lifetime"};
const bool require_nonempty = config.getBool(config_prefix + ".require_nonempty", false);
// This is specialised trie for storing IPv4 and IPv6 prefixes.
// This is specialised dictionary for storing IPv4 and IPv6 prefixes.
return std::make_unique<IPAddressDictionary>(dict_id, dict_struct, std::move(source_ptr), dict_lifetime, require_nonempty);
};
factory.registerLayout("ip_trie", create_layout, true);

View File

@ -228,9 +228,6 @@ private:
const Attribute & getAttribute(const std::string & attribute_name) const;
template <typename T>
void has(const Attribute & attribute, const Columns & key_columns, PaddedPODArray<UInt8> & out) const;
Columns getKeyColumns() const;
RowIdxConstIter ipNotFound() const;
RowIdxConstIter tryLookupIPv4(UInt32 addr, uint8_t * buf) const;