From 9c38b1a0310db616caf342b9595441b2f0081525 Mon Sep 17 00:00:00 2001 From: Pablo Alegre Date: Mon, 7 Feb 2022 14:25:23 +0100 Subject: [PATCH] fixup! Add groupSortedArray() function --- .../AggregateFunctionGroupSortedArray.cpp | 98 +++++++------------ .../AggregateFunctionGroupSortedArray.h | 8 +- 2 files changed, 37 insertions(+), 69 deletions(-) diff --git a/src/AggregateFunctions/AggregateFunctionGroupSortedArray.cpp b/src/AggregateFunctions/AggregateFunctionGroupSortedArray.cpp index d47aab2b5be..ef52428aecd 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupSortedArray.cpp +++ b/src/AggregateFunctions/AggregateFunctionGroupSortedArray.cpp @@ -42,73 +42,35 @@ namespace DataTypePtr getReturnType() const override { return std::make_shared(std::make_shared()); } }; - template - static IAggregateFunction * createWithExtraTypes(const DataTypes & argument_types, UInt64 threshold, const Array & params) + template + AggregateFunctionPtr + createAggregateFunctionGroupSortedArrayTyped(const DataTypes & argument_types, const Array & params, UInt64 threshold) { - if (argument_types.empty()) - throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Got empty arguments list"); - +#define DISPATCH(A, CLS, B) \ + if (which.idx == TypeIndex::A) \ + return AggregateFunctionPtr(new CLS(threshold, argument_types, params)); +#define DISPATCH_NUMERIC(A) DISPATCH(A, AggregateFunctionGroupSortedArrayNumeric, A) WhichDataType which(argument_types[0]); - if (which.idx == TypeIndex::Date) - return new AggregateFunctionGroupSortedArrayFieldType( - threshold, argument_types, params); - if (which.idx == TypeIndex::DateTime) - return new AggregateFunctionGroupSortedArrayFieldType( - threshold, argument_types, params); + FOR_NUMERIC_TYPES(DISPATCH_NUMERIC) + DISPATCH(Enum8, AggregateFunctionGroupSortedArrayNumeric, Int8) + DISPATCH(Enum16, AggregateFunctionGroupSortedArrayNumeric, Int16) + DISPATCH(Date, AggregateFunctionGroupSortedArrayFieldType, DataTypeDate) + DISPATCH(DateTime, AggregateFunctionGroupSortedArrayFieldType, DataTypeDateTime) +#undef DISPATCH +#undef DISPATCH_NUMERIC if (argument_types[0]->isValueUnambiguouslyRepresentedInContiguousMemoryRegion()) { - return new AggregateFunctionGroupSortedArray( - threshold, argument_types, params); + return AggregateFunctionPtr(new AggregateFunctionGroupSortedArray( + threshold, argument_types, params)); } else - { - return new AggregateFunctionGroupSortedArray( - threshold, argument_types, params); + return AggregateFunctionPtr(new AggregateFunctionGroupSortedArray( + threshold, argument_types, params)); } } - template < - template - class AggregateFunctionTemplate, - bool bool_param, - typename TColumnB, - bool is_plain_b, - typename... TArgs> - static IAggregateFunction * createWithNumericType2(const IDataType & argument_type, TArgs &&... args) - { - WhichDataType which(argument_type); -#define DISPATCH(TYPE) \ - if (which.idx == TypeIndex::TYPE) \ - return new AggregateFunctionTemplate(std::forward(args)...); - FOR_NUMERIC_TYPES(DISPATCH) -#undef DISPATCH - if (which.idx == TypeIndex::Enum8) - return new AggregateFunctionTemplate(std::forward(args)...); - if (which.idx == TypeIndex::Enum16) - return new AggregateFunctionTemplate(std::forward(args)...); - return nullptr; - } - - template - AggregateFunctionPtr createAggregateFunctionGroupSortedArrayTyped( - const std::string & name, const DataTypes & argument_types, const Array & params, UInt64 threshold) - { - AggregateFunctionPtr res(createWithNumericType2( - *argument_types[0], threshold, argument_types, params)); - - if (!res) - res = AggregateFunctionPtr(createWithExtraTypes(argument_types, threshold, params)); - - if (!res) - throw Exception( - "Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - - return res; - } - AggregateFunctionPtr createAggregateFunctionGroupSortedArray( const std::string & name, const DataTypes & argument_types, const Array & params, const Settings *) @@ -136,26 +98,32 @@ namespace if (argument_types.size() == 2) { - if (isUnsignedInteger(argument_types[1])) + if (isNumber(argument_types[1])) { - return createAggregateFunctionGroupSortedArrayTyped(name, argument_types, params, threshold); - } - else if (isInteger(argument_types[1])) - { - return createAggregateFunctionGroupSortedArrayTyped(name, argument_types, params, threshold); +#define DISPATCH2(A, B) \ + if (which.idx == TypeIndex::A) \ + return createAggregateFunctionGroupSortedArrayTyped(argument_types, params, threshold); +#define DISPATCH(A) DISPATCH2(A, A) + WhichDataType which(argument_types[1]); + FOR_NUMERIC_TYPES(DISPATCH) + DISPATCH2(Enum8, Int8) + DISPATCH2(Enum16, Int16) +#undef DISPATCH +#undef DISPATCH2 + throw Exception("Invalid parameter type.", ErrorCodes::BAD_ARGUMENTS); } else if (argument_types[1]->isValueUnambiguouslyRepresentedInContiguousMemoryRegion()) { - return createAggregateFunctionGroupSortedArrayTyped(name, argument_types, params, threshold); + return createAggregateFunctionGroupSortedArrayTyped(argument_types, params, threshold); } else { - return createAggregateFunctionGroupSortedArrayTyped(name, argument_types, params, threshold); + return createAggregateFunctionGroupSortedArrayTyped(argument_types, params, threshold); } } else if (argument_types.size() == 1) { - return createAggregateFunctionGroupSortedArrayTyped<>(name, argument_types, params, threshold); + return createAggregateFunctionGroupSortedArrayTyped<>(argument_types, params, threshold); } else { diff --git a/src/AggregateFunctions/AggregateFunctionGroupSortedArray.h b/src/AggregateFunctions/AggregateFunctionGroupSortedArray.h index 5f9a67a21ae..79eeb5202f9 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupSortedArray.h +++ b/src/AggregateFunctions/AggregateFunctionGroupSortedArray.h @@ -180,7 +180,7 @@ public: if constexpr (use_column_b) { forFirstRows( - batch_size, columns, 1, arena, if_argument_pos, [columns, &arena, &data](size_t row, const TColumnB * values) + batch_size, columns, 1, arena, if_argument_pos, [columns, &arena, &data](size_t row, const TColumnB * values) { data.add(readItem(columns[0], arena, row), values[row]); }); @@ -188,9 +188,9 @@ public: else { forFirstRows( - batch_size, columns, 0, arena, if_argument_pos, [&data](size_t row, const TColumnA * values) - { - data.add(values[row]); + batch_size, columns, 0, arena, if_argument_pos, [&data](size_t row, const TColumnA * values) + { + data.add(values[row]); }); } }