fix cutIPv6 function and IPv4 hashing

This commit is contained in:
Yakov Olkhovskiy 2023-01-07 21:48:02 +00:00
parent 1f32ffedf8
commit d77178194f
2 changed files with 77 additions and 60 deletions

View File

@ -139,12 +139,15 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{ {
const auto * ptr = checkAndGetDataType<DataTypeFixedString>(arguments[0].get()); if (!checkAndGetDataType<DataTypeIPv6>(arguments[0].get()))
if (!ptr || ptr->getN() != IPV6_BINARY_LENGTH) {
throw Exception("Illegal type " + arguments[0]->getName() + const auto * ptr = checkAndGetDataType<DataTypeFixedString>(arguments[0].get());
" of argument 1 of function " + getName() + if (!ptr || ptr->getN() != IPV6_BINARY_LENGTH)
", expected FixedString(" + toString(IPV6_BINARY_LENGTH) + ")", throw Exception("Illegal type " + arguments[0]->getName() +
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); " of argument 1 of function " + getName() +
", expected FixedString(" + toString(IPV6_BINARY_LENGTH) + ")",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
if (!WhichDataType(arguments[1]).isUInt8()) if (!WhichDataType(arguments[1]).isUInt8())
throw Exception("Illegal type " + arguments[1]->getName() + throw Exception("Illegal type " + arguments[1]->getName() +
@ -162,7 +165,7 @@ public:
bool useDefaultImplementationForConstants() const override { return true; } bool useDefaultImplementationForConstants() const override { return true; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1, 2}; } ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1, 2}; }
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{ {
const auto & col_type_name = arguments[0]; const auto & col_type_name = arguments[0];
const ColumnPtr & column = col_type_name.column; const ColumnPtr & column = col_type_name.column;
@ -172,68 +175,82 @@ public:
const auto & col_ipv4_zeroed_tail_bytes_type = arguments[2]; const auto & col_ipv4_zeroed_tail_bytes_type = arguments[2];
const auto & col_ipv4_zeroed_tail_bytes = col_ipv4_zeroed_tail_bytes_type.column; const auto & col_ipv4_zeroed_tail_bytes = col_ipv4_zeroed_tail_bytes_type.column;
if (const auto * col_in = checkAndGetColumn<ColumnFixedString>(column.get())) const auto * col_in_str = checkAndGetColumn<ColumnFixedString>(column.get());
const auto * col_in_ip = checkAndGetColumn<ColumnIPv6>(column.get());
if (!col_in_str && !col_in_ip)
throw Exception("Illegal column " + arguments[0].column->getName()
+ " of argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
if (col_in_str && col_in_str->getN() != IPV6_BINARY_LENGTH)
throw Exception("Illegal type " + col_type_name.type->getName() +
" of column " + col_in_str->getName() +
" argument of function " + getName() +
", expected FixedString(" + toString(IPV6_BINARY_LENGTH) + ")",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const auto * ipv6_zeroed_tail_bytes = checkAndGetColumnConst<ColumnVector<UInt8>>(col_ipv6_zeroed_tail_bytes.get());
if (!ipv6_zeroed_tail_bytes)
throw Exception("Illegal type " + col_ipv6_zeroed_tail_bytes_type.type->getName() +
" of argument 2 of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
UInt8 ipv6_zeroed_tail_bytes_count = ipv6_zeroed_tail_bytes->getValue<UInt8>();
if (ipv6_zeroed_tail_bytes_count > IPV6_BINARY_LENGTH)
throw Exception("Illegal value for argument 2 " + col_ipv6_zeroed_tail_bytes_type.type->getName() +
" of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const auto * ipv4_zeroed_tail_bytes = checkAndGetColumnConst<ColumnVector<UInt8>>(col_ipv4_zeroed_tail_bytes.get());
if (!ipv4_zeroed_tail_bytes)
throw Exception("Illegal type " + col_ipv4_zeroed_tail_bytes_type.type->getName() +
" of argument 3 of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
UInt8 ipv4_zeroed_tail_bytes_count = ipv4_zeroed_tail_bytes->getValue<UInt8>();
if (ipv4_zeroed_tail_bytes_count > IPV6_BINARY_LENGTH)
throw Exception("Illegal value for argument 3 " + col_ipv4_zeroed_tail_bytes_type.type->getName() +
" of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
auto col_res = ColumnString::create();
ColumnString::Chars & vec_res = col_res->getChars();
ColumnString::Offsets & offsets_res = col_res->getOffsets();
vec_res.resize(input_rows_count * (IPV6_MAX_TEXT_LENGTH + 1));
offsets_res.resize(input_rows_count);
auto * begin = reinterpret_cast<char *>(vec_res.data());
auto * pos = begin;
if (col_in_str)
{ {
if (col_in->getN() != IPV6_BINARY_LENGTH) const auto & vec_in = col_in_str->getChars();
throw Exception("Illegal type " + col_type_name.type->getName() +
" of column " + col_in->getName() +
" argument of function " + getName() +
", expected FixedString(" + toString(IPV6_BINARY_LENGTH) + ")",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const auto * ipv6_zeroed_tail_bytes = checkAndGetColumnConst<ColumnVector<UInt8>>(col_ipv6_zeroed_tail_bytes.get()); for (size_t offset = 0, i = 0; i < input_rows_count; offset += IPV6_BINARY_LENGTH, ++i)
if (!ipv6_zeroed_tail_bytes)
throw Exception("Illegal type " + col_ipv6_zeroed_tail_bytes_type.type->getName() +
" of argument 2 of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
UInt8 ipv6_zeroed_tail_bytes_count = ipv6_zeroed_tail_bytes->getValue<UInt8>();
if (ipv6_zeroed_tail_bytes_count > IPV6_BINARY_LENGTH)
throw Exception("Illegal value for argument 2 " + col_ipv6_zeroed_tail_bytes_type.type->getName() +
" of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const auto * ipv4_zeroed_tail_bytes = checkAndGetColumnConst<ColumnVector<UInt8>>(col_ipv4_zeroed_tail_bytes.get());
if (!ipv4_zeroed_tail_bytes)
throw Exception("Illegal type " + col_ipv4_zeroed_tail_bytes_type.type->getName() +
" of argument 3 of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
UInt8 ipv4_zeroed_tail_bytes_count = ipv4_zeroed_tail_bytes->getValue<UInt8>();
if (ipv4_zeroed_tail_bytes_count > IPV6_BINARY_LENGTH)
throw Exception("Illegal value for argument 3 " + col_ipv4_zeroed_tail_bytes_type.type->getName() +
" of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const auto size = col_in->size();
const auto & vec_in = col_in->getChars();
auto col_res = ColumnString::create();
ColumnString::Chars & vec_res = col_res->getChars();
ColumnString::Offsets & offsets_res = col_res->getOffsets();
vec_res.resize(size * (IPV6_MAX_TEXT_LENGTH + 1));
offsets_res.resize(size);
auto * begin = reinterpret_cast<char *>(vec_res.data());
auto * pos = begin;
for (size_t offset = 0, i = 0; offset < vec_in.size(); offset += IPV6_BINARY_LENGTH, ++i)
{ {
const auto * address = &vec_in[offset]; const auto * address = &vec_in[offset];
UInt8 zeroed_tail_bytes_count = isIPv4Mapped(address) ? ipv4_zeroed_tail_bytes_count : ipv6_zeroed_tail_bytes_count; UInt8 zeroed_tail_bytes_count = isIPv4Mapped(address) ? ipv4_zeroed_tail_bytes_count : ipv6_zeroed_tail_bytes_count;
cutAddress(reinterpret_cast<const unsigned char *>(address), pos, zeroed_tail_bytes_count); cutAddress(reinterpret_cast<const unsigned char *>(address), pos, zeroed_tail_bytes_count);
offsets_res[i] = pos - begin; offsets_res[i] = pos - begin;
} }
vec_res.resize(pos - begin);
return col_res;
} }
else else
throw Exception("Illegal column " + arguments[0].column->getName() {
+ " of argument of function " + getName(), const auto & vec_in = col_in_ip->getData();
ErrorCodes::ILLEGAL_COLUMN);
for (size_t i = 0; i < input_rows_count; ++i)
{
const auto * address = reinterpret_cast<const UInt8 *>(&vec_in[i]);
UInt8 zeroed_tail_bytes_count = isIPv4Mapped(address) ? ipv4_zeroed_tail_bytes_count : ipv6_zeroed_tail_bytes_count;
cutAddress(reinterpret_cast<const unsigned char *>(address), pos, zeroed_tail_bytes_count);
offsets_res[i] = pos - begin;
}
}
vec_res.resize(pos - begin);
return col_res;
} }
private: private:

View File

@ -1136,7 +1136,7 @@ private:
else if (which.isInt128()) executeBigIntType<Int128, first>(icolumn, vec_to); else if (which.isInt128()) executeBigIntType<Int128, first>(icolumn, vec_to);
else if (which.isInt256()) executeBigIntType<Int256, first>(icolumn, vec_to); else if (which.isInt256()) executeBigIntType<Int256, first>(icolumn, vec_to);
else if (which.isUUID()) executeBigIntType<UUID, first>(icolumn, vec_to); else if (which.isUUID()) executeBigIntType<UUID, first>(icolumn, vec_to);
else if (which.isIPv4()) executeBigIntType<IPv4, first>(icolumn, vec_to); else if (which.isIPv4()) executeIntType<IPv4, first>(icolumn, vec_to);
else if (which.isIPv6()) executeBigIntType<IPv6, first>(icolumn, vec_to); else if (which.isIPv6()) executeBigIntType<IPv6, first>(icolumn, vec_to);
else if (which.isEnum8()) executeIntType<Int8, first>(icolumn, vec_to); else if (which.isEnum8()) executeIntType<Int8, first>(icolumn, vec_to);
else if (which.isEnum16()) executeIntType<Int16, first>(icolumn, vec_to); else if (which.isEnum16()) executeIntType<Int16, first>(icolumn, vec_to);