fixup! Add groupSortedArray() function

This commit is contained in:
Pablo Alegre 2022-02-07 14:25:23 +01:00
parent 7f553d55ae
commit 9c38b1a031
2 changed files with 37 additions and 69 deletions

View File

@ -42,73 +42,35 @@ namespace
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<T>()); }
};
template <bool expr_sorted, typename TColumnB, bool is_plain_b>
static IAggregateFunction * createWithExtraTypes(const DataTypes & argument_types, UInt64 threshold, const Array & params)
template <bool expr_sorted = false, typename TColumnB = UInt64, bool is_plain_b = false>
AggregateFunctionPtr
createAggregateFunctionGroupSortedArrayTyped(const DataTypes & argument_types, const Array & params, UInt64 threshold)
{
if (argument_types.empty())
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Got empty arguments list");
#define DISPATCH(A, CLS, B) \
if (which.idx == TypeIndex::A) \
return AggregateFunctionPtr(new CLS<B, expr_sorted, TColumnB, is_plain_b>(threshold, argument_types, params));
#define DISPATCH_NUMERIC(A) DISPATCH(A, AggregateFunctionGroupSortedArrayNumeric, A)
WhichDataType which(argument_types[0]);
if (which.idx == TypeIndex::Date)
return new AggregateFunctionGroupSortedArrayFieldType<DataTypeDate, expr_sorted, TColumnB, is_plain_b>(
threshold, argument_types, params);
if (which.idx == TypeIndex::DateTime)
return new AggregateFunctionGroupSortedArrayFieldType<DataTypeDateTime, expr_sorted, TColumnB, is_plain_b>(
threshold, argument_types, params);
FOR_NUMERIC_TYPES(DISPATCH_NUMERIC)
DISPATCH(Enum8, AggregateFunctionGroupSortedArrayNumeric, Int8)
DISPATCH(Enum16, AggregateFunctionGroupSortedArrayNumeric, Int16)
DISPATCH(Date, AggregateFunctionGroupSortedArrayFieldType, DataTypeDate)
DISPATCH(DateTime, AggregateFunctionGroupSortedArrayFieldType, DataTypeDateTime)
#undef DISPATCH
#undef DISPATCH_NUMERIC
if (argument_types[0]->isValueUnambiguouslyRepresentedInContiguousMemoryRegion())
{
return new AggregateFunctionGroupSortedArray<StringRef, true, expr_sorted, TColumnB, is_plain_b>(
threshold, argument_types, params);
return AggregateFunctionPtr(new AggregateFunctionGroupSortedArray<StringRef, true, expr_sorted, TColumnB, is_plain_b>(
threshold, argument_types, params));
}
else
{
return new AggregateFunctionGroupSortedArray<StringRef, false, expr_sorted, TColumnB, is_plain_b>(
threshold, argument_types, params);
return AggregateFunctionPtr(new AggregateFunctionGroupSortedArray<StringRef, false, expr_sorted, TColumnB, is_plain_b>(
threshold, argument_types, params));
}
}
template <
template <typename, bool, typename, bool>
class AggregateFunctionTemplate,
bool bool_param,
typename TColumnB,
bool is_plain_b,
typename... TArgs>
static IAggregateFunction * createWithNumericType2(const IDataType & argument_type, TArgs &&... args)
{
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<TYPE, bool_param, TColumnB, is_plain_b>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Enum8)
return new AggregateFunctionTemplate<Int8, bool_param, TColumnB, is_plain_b>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::Enum16)
return new AggregateFunctionTemplate<Int16, bool_param, TColumnB, is_plain_b>(std::forward<TArgs>(args)...);
return nullptr;
}
template <bool expr_sorted = false, typename TColumnB = UInt64, bool is_plain_b = false>
AggregateFunctionPtr createAggregateFunctionGroupSortedArrayTyped(
const std::string & name, const DataTypes & argument_types, const Array & params, UInt64 threshold)
{
AggregateFunctionPtr res(createWithNumericType2<AggregateFunctionGroupSortedArrayNumeric, expr_sorted, TColumnB, is_plain_b>(
*argument_types[0], threshold, argument_types, params));
if (!res)
res = AggregateFunctionPtr(createWithExtraTypes<expr_sorted, TColumnB, is_plain_b>(argument_types, threshold, params));
if (!res)
throw Exception(
"Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
AggregateFunctionPtr createAggregateFunctionGroupSortedArray(
const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *)
@ -136,26 +98,32 @@ namespace
if (argument_types.size() == 2)
{
if (isUnsignedInteger(argument_types[1]))
if (isNumber(argument_types[1]))
{
return createAggregateFunctionGroupSortedArrayTyped<true, UInt64>(name, argument_types, params, threshold);
}
else if (isInteger(argument_types[1]))
{
return createAggregateFunctionGroupSortedArrayTyped<true, Int64>(name, argument_types, params, threshold);
#define DISPATCH2(A, B) \
if (which.idx == TypeIndex::A) \
return createAggregateFunctionGroupSortedArrayTyped<true, B>(argument_types, params, threshold);
#define DISPATCH(A) DISPATCH2(A, A)
WhichDataType which(argument_types[1]);
FOR_NUMERIC_TYPES(DISPATCH)
DISPATCH2(Enum8, Int8)
DISPATCH2(Enum16, Int16)
#undef DISPATCH
#undef DISPATCH2
throw Exception("Invalid parameter type.", ErrorCodes::BAD_ARGUMENTS);
}
else if (argument_types[1]->isValueUnambiguouslyRepresentedInContiguousMemoryRegion())
{
return createAggregateFunctionGroupSortedArrayTyped<true, StringRef, true>(name, argument_types, params, threshold);
return createAggregateFunctionGroupSortedArrayTyped<true, StringRef, true>(argument_types, params, threshold);
}
else
{
return createAggregateFunctionGroupSortedArrayTyped<true, StringRef, false>(name, argument_types, params, threshold);
return createAggregateFunctionGroupSortedArrayTyped<true, StringRef, false>(argument_types, params, threshold);
}
}
else if (argument_types.size() == 1)
{
return createAggregateFunctionGroupSortedArrayTyped<>(name, argument_types, params, threshold);
return createAggregateFunctionGroupSortedArrayTyped<>(argument_types, params, threshold);
}
else
{

View File

@ -180,7 +180,7 @@ public:
if constexpr (use_column_b)
{
forFirstRows<TColumnB, is_plain_b>(
batch_size, columns, 1, arena, if_argument_pos, [columns, &arena, &data](size_t row, const TColumnB * values)
batch_size, columns, 1, arena, if_argument_pos, [columns, &arena, &data](size_t row, const TColumnB * values)
{
data.add(readItem<TColumnA, is_plain_a>(columns[0], arena, row), values[row]);
});
@ -188,9 +188,9 @@ public:
else
{
forFirstRows<TColumnA, is_plain_a>(
batch_size, columns, 0, arena, if_argument_pos, [&data](size_t row, const TColumnA * values)
{
data.add(values[row]);
batch_size, columns, 0, arena, if_argument_pos, [&data](size_t row, const TColumnA * values)
{
data.add(values[row]);
});
}
}