Add generic implementation of function transform

This commit is contained in:
Alexey Milovidov 2023-06-24 06:52:28 +02:00
parent 396eb70426
commit cb2d395410
3 changed files with 148 additions and 39 deletions

View File

@ -16,6 +16,7 @@
#include <Interpreters/convertFieldToType.h> #include <Interpreters/convertFieldToType.h>
#include <Common/HashTable/HashMap.h> #include <Common/HashTable/HashMap.h>
#include <Common/typeid_cast.h> #include <Common/typeid_cast.h>
#include <Common/FieldVisitorsAccurateComparison.h>
namespace DB namespace DB
@ -79,15 +80,6 @@ namespace
args_size); args_size);
const DataTypePtr & type_x = arguments[0]; const DataTypePtr & type_x = arguments[0];
const auto & type_x_nn = removeNullable(type_x);
if (!type_x_nn->isValueRepresentedByNumber() && !isString(type_x_nn) && !isNothing(type_x_nn))
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Unsupported type {} of first argument "
"of function {}, must be numeric type or Date/DateTime or String",
type_x->getName(),
getName());
const DataTypeArray * type_arr_from = checkAndGetDataType<DataTypeArray>(arguments[1].get()); const DataTypeArray * type_arr_from = checkAndGetDataType<DataTypeArray>(arguments[1].get());
@ -99,14 +91,13 @@ namespace
const auto type_arr_from_nested = type_arr_from->getNestedType(); const auto type_arr_from_nested = type_arr_from->getNestedType();
if ((type_x->isValueRepresentedByNumber() != type_arr_from_nested->isValueRepresentedByNumber()) auto src = tryGetLeastSupertype(DataTypes{type_x, type_arr_from_nested});
|| (isString(type_x) != isString(type_arr_from_nested))) if (!src)
{ {
throw Exception( throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"First argument and elements of array " "First argument and elements of array "
"of second argument of function {} must have compatible types: " "of the second argument of function {} must have compatible types",
"both numeric or both strings.",
getName()); getName());
} }
@ -157,8 +148,8 @@ namespace
} }
} }
ColumnPtr ColumnPtr executeImpl(
executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
{ {
initialize(arguments, result_type); initialize(arguments, result_type);
@ -172,6 +163,8 @@ namespace
default_non_const = castColumn(arguments[3], result_type); default_non_const = castColumn(arguments[3], result_type);
auto column_result = result_type->createColumn(); auto column_result = result_type->createColumn();
if (cache.table_num_to_idx)
{
if (!executeNum<ColumnVector<UInt8>>(in, *column_result, default_non_const) if (!executeNum<ColumnVector<UInt8>>(in, *column_result, default_non_const)
&& !executeNum<ColumnVector<UInt16>>(in, *column_result, default_non_const) && !executeNum<ColumnVector<UInt16>>(in, *column_result, default_non_const)
&& !executeNum<ColumnVector<UInt32>>(in, *column_result, default_non_const) && !executeNum<ColumnVector<UInt32>>(in, *column_result, default_non_const)
@ -183,11 +176,23 @@ namespace
&& !executeNum<ColumnVector<Float32>>(in, *column_result, default_non_const) && !executeNum<ColumnVector<Float32>>(in, *column_result, default_non_const)
&& !executeNum<ColumnVector<Float64>>(in, *column_result, default_non_const) && !executeNum<ColumnVector<Float64>>(in, *column_result, default_non_const)
&& !executeNum<ColumnDecimal<Decimal32>>(in, *column_result, default_non_const) && !executeNum<ColumnDecimal<Decimal32>>(in, *column_result, default_non_const)
&& !executeNum<ColumnDecimal<Decimal64>>(in, *column_result, default_non_const) && !executeNum<ColumnDecimal<Decimal64>>(in, *column_result, default_non_const))
&& !executeString(in, *column_result, default_non_const))
{ {
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function {}", in->getName(), getName()); throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function {}", in->getName(), getName());
} }
}
else if (cache.table_string_to_idx)
{
if (!executeString(in, *column_result, default_non_const))
executeContiguous(in, *column_result, default_non_const);
}
else if (cache.table_anything_to_idx)
{
executeAnything(in, *column_result, default_non_const);
}
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "State of the function `transform` is not initialized");
return column_result; return column_result;
} }
@ -204,6 +209,47 @@ namespace
return impl->execute(args, result_type, input_rows_count); return impl->execute(args, result_type, input_rows_count);
} }
void executeAnything(const IColumn * in, IColumn & column_result, const ColumnPtr default_non_const) const
{
const size_t size = in->size();
const auto & table = *cache.table_anything_to_idx;
column_result.reserve(size);
for (size_t i = 0; i < size; ++i)
{
SipHash hash;
in->updateHashWithValue(i, hash);
const auto * it = table.find(hash.get128());
if (it)
column_result.insertFrom(*cache.to_column, it->getMapped());
else if (cache.default_column)
column_result.insertFrom(*cache.default_column, 0);
else if (default_non_const)
column_result.insertFrom(*default_non_const, i);
else
column_result.insertFrom(*in, i);
}
}
void executeContiguous(const IColumn * in, IColumn & column_result, const ColumnPtr default_non_const) const
{
const size_t size = in->size();
const auto & table = *cache.table_string_to_idx;
column_result.reserve(size);
for (size_t i = 0; i < size; ++i)
{
const auto * it = table.find(in->getDataAt(i));
if (it)
column_result.insertFrom(*cache.to_column, it->getMapped());
else if (cache.default_column)
column_result.insertFrom(*cache.default_column, 0);
else if (default_non_const)
column_result.insertFrom(*default_non_const, i);
else
column_result.insertFrom(*in, i);
}
}
template <typename T> template <typename T>
bool executeNum(const IColumn * in_untyped, IColumn & column_result, const ColumnPtr default_non_const) const bool executeNum(const IColumn * in_untyped, IColumn & column_result, const ColumnPtr default_non_const) const
{ {
@ -593,9 +639,11 @@ namespace
{ {
using NumToIdx = HashMap<UInt64, size_t, HashCRC32<UInt64>>; using NumToIdx = HashMap<UInt64, size_t, HashCRC32<UInt64>>;
using StringToIdx = HashMap<StringRef, size_t, StringRefHash>; using StringToIdx = HashMap<StringRef, size_t, StringRefHash>;
using AnythingToIdx = HashMap<UInt128, size_t>;
std::unique_ptr<NumToIdx> table_num_to_idx; std::unique_ptr<NumToIdx> table_num_to_idx;
std::unique_ptr<StringToIdx> table_string_to_idx; std::unique_ptr<StringToIdx> table_string_to_idx;
std::unique_ptr<AnythingToIdx> table_anything_to_idx;
ColumnPtr from_column; ColumnPtr from_column;
ColumnPtr to_column; ColumnPtr to_column;
@ -648,18 +696,16 @@ namespace
std::lock_guard lock(cache.mutex); std::lock_guard lock(cache.mutex);
ColumnPtr from_column_or_null_ptr = castColumnAccurateOrNull( const ColumnPtr & from_column_uncasted = array_from->getDataPtr();
cache.from_column = castColumn(
{ {
array_from->getDataPtr(), from_column_uncasted,
typeid_cast<const DataTypeArray &>(*arguments[1].type).getNestedType(), typeid_cast<const DataTypeArray &>(*arguments[1].type).getNestedType(),
arguments[1].name arguments[1].name
}, },
from_type); from_type);
const ColumnNullable & from_column_or_null = assert_cast<const ColumnNullable &>(*from_column_or_null_ptr);
cache.from_column = from_column_or_null.getNestedColumnPtr();
cache.to_column = castColumn( cache.to_column = castColumn(
{ {
array_to->getDataPtr(), array_to->getDataPtr(),
@ -696,13 +742,14 @@ namespace
/// Note: Doesn't check the duplicates in the `from` array. /// Note: Doesn't check the duplicates in the `from` array.
if (from_type->isValueRepresentedByNumber()) WhichDataType which(from_type);
if (isNativeNumber(which) || which.isDecimal32() || which.isDecimal64())
{ {
cache.table_num_to_idx = std::make_unique<Cache::NumToIdx>(); cache.table_num_to_idx = std::make_unique<Cache::NumToIdx>();
auto & table = *cache.table_num_to_idx; auto & table = *cache.table_num_to_idx;
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
{ {
if (!from_column_or_null.isNullAt(i)) if (applyVisitor(FieldVisitorAccurateEquals(), (*cache.from_column)[i], (*from_column_uncasted)[i]))
{ {
/// Field may be of Float type, but for the purpose of bitwise equality we can treat them as UInt64 /// Field may be of Float type, but for the purpose of bitwise equality we can treat them as UInt64
StringRef ref = cache.from_column->getDataAt(i); StringRef ref = cache.from_column->getDataAt(i);
@ -718,7 +765,7 @@ namespace
auto & table = *cache.table_string_to_idx; auto & table = *cache.table_string_to_idx;
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
{ {
if (!from_column_or_null.isNullAt(i)) if (applyVisitor(FieldVisitorAccurateEquals(), (*cache.from_column)[i], (*from_column_uncasted)[i]))
{ {
StringRef ref = cache.from_column->getDataAt(i); StringRef ref = cache.from_column->getDataAt(i);
table[ref] = i; table[ref] = i;
@ -726,7 +773,19 @@ namespace
} }
} }
else else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected data type {} as the first argument in function `transform`", from_type->getName()); {
cache.table_anything_to_idx = std::make_unique<Cache::AnythingToIdx>();
auto & table = *cache.table_anything_to_idx;
for (size_t i = 0; i < size; ++i)
{
if (applyVisitor(FieldVisitorAccurateEquals(), (*cache.from_column)[i], (*from_column_uncasted)[i]))
{
SipHash hash;
cache.from_column->updateHashWithValue(i, hash);
table[hash.get128()] = i;
}
}
}
cache.initialized = true; cache.initialized = true;
} }

View File

@ -0,0 +1,38 @@
def
def
def
hello
def
world
def
abc!
def
def
hello
world
abc
hello
world
abc
123
2023-03-03 00:00:00.000
2023-02-02 00:00:00.000
2023-01-01 00:00:00.000
1 1
42 42
42
42

View File

@ -0,0 +1,12 @@
SELECT transform((number, toString(number)), [(3, '3'), (5, '5'), (7, '7')], ['hello', 'world', 'abc!'], 'def') FROM system.numbers LIMIT 10;
SELECT transform(toNullable(toInt256(number)), [3, 5, 7], ['hello', 'world', 'abc'], '') FROM system.numbers LIMIT 10;
SELECT transform(toUInt256(number), [3, 5, 7], ['hello', 'world', 'abc'], '') FROM system.numbers LIMIT 10;
select case 1::Nullable(Int32) when 1 then 123 else 0 end;
SELECT transform(arrayJoin(['c', 'b', 'a']), ['a', 'b'], [toDateTime64('2023-01-01', 3), toDateTime64('2023-02-02', 3)], toDateTime64('2023-03-03', 3));
SELECT transform(1, [1], [toDecimal32(1, 2)]), toDecimal32(1, 2);
select transform(1, [1], [toDecimal32(42, 2)]), toDecimal32(42, 2);
SELECT transform(1, [1], [toDecimal32(42, 2)], 0);
SELECT transform(1, [1], [toDecimal32(42, 2)], toDecimal32(0, 2));