diff --git a/src/Functions/initializeAggregation.cpp b/src/Functions/initializeAggregation.cpp new file mode 100644 index 00000000000..81bfa19a55a --- /dev/null +++ b/src/Functions/initializeAggregation.cpp @@ -0,0 +1,161 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int ILLEGAL_COLUMN; + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int BAD_ARGUMENTS; +} + + +class FunctionInitializeAggregation : public IFunction +{ +public: + static constexpr auto name = "initializeAggregation"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + bool isVariadic() const override { return true; } + size_t getNumberOfArguments() const override { return 0; } + + bool useDefaultImplementationForConstants() const override { return true; } + ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {0}; } + + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override; + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override; + +private: + mutable AggregateFunctionPtr aggregate_function; +}; + + +DataTypePtr FunctionInitializeAggregation::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const +{ + if (arguments.size() < 2) + throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " + + toString(arguments.size()) + ", should be at least 2.", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + const ColumnConst * aggregate_function_name_column = checkAndGetColumnConst(arguments[0].column.get()); + if (!aggregate_function_name_column) + throw Exception("First argument for function " + getName() + " must be constant string: name of aggregate function.", + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + DataTypes argument_types(arguments.size() - 1); + for (size_t i = 1, size = arguments.size(); i < size; ++i) + { + argument_types[i - 1] = arguments[i].type; + } + + if (!aggregate_function) + { + String aggregate_function_name_with_params = aggregate_function_name_column->getValue(); + + if (aggregate_function_name_with_params.empty()) + throw Exception("First argument for function " + getName() + " (name of aggregate function) cannot be empty.", + ErrorCodes::BAD_ARGUMENTS); + + String aggregate_function_name; + Array params_row; + getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params, + aggregate_function_name, params_row, "function " + getName()); + + AggregateFunctionProperties properties; + aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties); + } + + return aggregate_function->getReturnType(); +} + + +void FunctionInitializeAggregation::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) +{ + IAggregateFunction & agg_func = *aggregate_function; + std::unique_ptr arena = std::make_unique(); + + const size_t num_arguments_columns = arguments.size() - 1; + + std::vector materialized_columns(num_arguments_columns); + std::vector aggregate_arguments_vec(num_arguments_columns); + + for (size_t i = 0; i < num_arguments_columns; ++i) + { + const IColumn * col = block.getByPosition(arguments[i + 1]).column.get(); + materialized_columns.emplace_back(col->convertToFullColumnIfConst()); + aggregate_arguments_vec[i] = &(*materialized_columns.back()); + } + + const IColumn ** aggregate_arguments = aggregate_arguments_vec.data(); + + MutableColumnPtr result_holder = block.getByPosition(result).type->createColumn(); + IColumn & res_col = *result_holder; + + /// AggregateFunction's states should be inserted into column using specific way + auto * res_col_aggregate_function = typeid_cast(&res_col); + + if (!res_col_aggregate_function && agg_func.isState()) + throw Exception("State function " + agg_func.getName() + " inserts results into non-state column " + + block.getByPosition(result).type->getName(), ErrorCodes::ILLEGAL_COLUMN); + + PODArray places(input_rows_count); + for (size_t i = 0; i < input_rows_count; ++i) + { + places[i] = arena->alignedAlloc(agg_func.sizeOfData(), agg_func.alignOfData()); + try + { + agg_func.create(places[i]); + } + catch (...) + { + for (size_t j = 0; j < i; ++j) + agg_func.destroy(places[j]); + throw; + } + } + + SCOPE_EXIT({ + for (size_t i = 0; i < input_rows_count; ++i) + agg_func.destroy(places[i]); + }); + + { + auto * that = &agg_func; + /// Unnest consecutive trailing -State combinators + while (auto * func = typeid_cast(that)) + that = func->getNestedFunction().get(); + that->addBatch(input_rows_count, places.data(), 0, aggregate_arguments, arena.get()); + } + + for (size_t i = 0; i < input_rows_count; ++i) + if (!res_col_aggregate_function) + agg_func.insertResultInto(places[i], res_col, arena.get()); + else + res_col_aggregate_function->insertFrom(places[i]); + block.getByPosition(result).column = std::move(result_holder); +} + + +void registerFunctionInitializeAggregation(FunctionFactory & factory) +{ + factory.registerFunction(); +} + +} diff --git a/src/Functions/registerFunctionsMiscellaneous.cpp b/src/Functions/registerFunctionsMiscellaneous.cpp index 5eb1e3e47c0..697eb5ecb64 100644 --- a/src/Functions/registerFunctionsMiscellaneous.cpp +++ b/src/Functions/registerFunctionsMiscellaneous.cpp @@ -58,6 +58,7 @@ void registerFunctionGetMacro(FunctionFactory &); void registerFunctionGetScalar(FunctionFactory &); void registerFunctionIsConstant(FunctionFactory &); void registerFunctionGlobalVariable(FunctionFactory &); +void registerFunctionInitializeAggregation(FunctionFactory &); #if USE_ICU void registerFunctionConvertCharset(FunctionFactory &); @@ -116,6 +117,7 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory) registerFunctionGetScalar(factory); registerFunctionIsConstant(factory); registerFunctionGlobalVariable(factory); + registerFunctionInitializeAggregation(factory); #if USE_ICU registerFunctionConvertCharset(factory); diff --git a/tests/queries/0_stateless/01356_initialize_aggregation.reference b/tests/queries/0_stateless/01356_initialize_aggregation.reference new file mode 100644 index 00000000000..63ebb1717d6 --- /dev/null +++ b/tests/queries/0_stateless/01356_initialize_aggregation.reference @@ -0,0 +1,4 @@ +3 +[999,998,997,996,995,994,993,992,991,990] +[1] +[990,991,992,993,994,995,996,997,998,999] diff --git a/tests/queries/0_stateless/01356_initialize_aggregation.sql b/tests/queries/0_stateless/01356_initialize_aggregation.sql new file mode 100644 index 00000000000..07a5ca1892b --- /dev/null +++ b/tests/queries/0_stateless/01356_initialize_aggregation.sql @@ -0,0 +1,4 @@ +SELECT uniqMerge(state) FROM (SELECT initializeAggregation('uniqState', number % 3) AS state FROM system.numbers LIMIT 10000); +SELECT topKWeightedMerge(10)(state) FROM (SELECT initializeAggregation('topKWeightedState(10)', number, number) AS state FROM system.numbers LIMIT 1000); +SELECT topKWeightedMerge(10)(state) FROM (SELECT initializeAggregation('topKWeightedState(10)', 1, number) AS state FROM system.numbers LIMIT 1000); +SELECT topKWeightedMerge(10)(state) FROM (SELECT initializeAggregation('topKWeightedState(10)', number, 1) AS state FROM system.numbers LIMIT 1000);