ClickHouse/src/Functions/initializeAggregation.cpp

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

159 lines
5.8 KiB
C++
Raw Normal View History

2021-05-17 07:30:42 +00:00
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnAggregateFunction.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionState.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/parseAggregateFunctionParameters.h>
#include <Common/Arena.h>
2022-04-27 15:05:45 +00:00
#include <Common/scope_guard_safe.h>
namespace DB
{
namespace ErrorCodes
{
    /// Error codes thrown by getReturnTypeImpl(); each is defined in another translation unit.
    extern const int BAD_ARGUMENTS;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
2020-09-07 18:00:37 +00:00
namespace
{
class FunctionInitializeAggregation : public IFunction, private WithContext
{
public:
static constexpr auto name = "initializeAggregation";
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionInitializeAggregation>(context_); }
explicit FunctionInitializeAggregation(ContextPtr context_) : WithContext(context_) {}
String getName() const override { return name; }
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
2021-06-22 16:21:23 +00:00
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
bool useDefaultImplementationForConstants() const override { return true; }
bool useDefaultImplementationForNulls() const override { return false; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {0}; }
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override;
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
private:
2021-06-06 21:49:55 +00:00
/// TODO Rewrite with FunctionBuilder.
mutable AggregateFunctionPtr aggregate_function;
};
/// Resolves the aggregate function named by the first argument, using the types of the
/// remaining arguments, and returns that aggregate function's result type.
///
/// On the first call the resolved function is cached in `aggregate_function`. Later calls
/// re-validate the first argument and return the cached function's result type.
///
/// Throws:
///   NUMBER_OF_ARGUMENTS_DOESNT_MATCH - fewer than two arguments;
///   ILLEGAL_TYPE_OF_ARGUMENT         - the first argument is not a constant String column;
///   BAD_ARGUMENTS                    - the aggregate function name string is empty.
DataTypePtr FunctionInitializeAggregation::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
{
    const size_t argument_count = arguments.size();
    if (argument_count < 2)
        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
            "Number of arguments for function {} doesn't match: passed {}, should be at least 2.",
            getName(), argument_count);

    const auto * name_column = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
    if (name_column == nullptr)
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be constant string: "
            "name of aggregate function.", getName());

    /// Already resolved by an earlier call: reuse it.
    if (aggregate_function)
        return aggregate_function->getResultType();

    String name_with_params = name_column->getValue<String>();
    if (name_with_params.empty())
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "First argument for function {} (name of aggregate function) cannot be empty.", getName());

    /// Every argument after the name is passed on to the aggregate function.
    DataTypes argument_types;
    argument_types.reserve(argument_count - 1);
    for (auto it = arguments.begin() + 1; it != arguments.end(); ++it)
        argument_types.push_back(it->type);

    /// Split e.g. "quantile(0.5)" into the bare name and its parameter list.
    String aggregate_function_name;
    Array params_row;
    getAggregateFunctionNameAndParametersArray(name_with_params,
        aggregate_function_name, params_row, "function " + getName(), getContext());

    AggregateFunctionProperties properties;
    aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
    return aggregate_function->getResultType();
}
/// For every input row, applies the aggregate function resolved in getReturnTypeImpl() to that
/// single row's arguments and writes the outcome into a column of `result_type`.
///
/// Each row gets its own freshly created aggregate state. The states are destroyed on every exit
/// path, whether the function returns normally or throws. Rows are added through the innermost
/// function beneath any chain of trailing -State combinators. Results are written with
/// insertMergeResultInto().
///
/// Expects getReturnTypeImpl() to have run first so that `aggregate_function` is set.
ColumnPtr FunctionInitializeAggregation::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
{
    const IAggregateFunction & agg_func = *aggregate_function;
    std::unique_ptr<Arena> arena = std::make_unique<Arena>();

    /// arguments[0] is the aggregate function name; the rest are the aggregate's own arguments.
    const size_t num_arguments_columns = arguments.size() - 1;

    /// `materialized_columns` keeps the full (non-const) argument columns alive while raw
    /// pointers to them are handed to the aggregate function. Both vectors are sized once and
    /// filled by index, so each holds exactly one entry per aggregate argument.
    std::vector<ColumnPtr> materialized_columns(num_arguments_columns);
    std::vector<const IColumn *> aggregate_arguments_vec(num_arguments_columns);

    for (size_t i = 0; i < num_arguments_columns; ++i)
    {
        materialized_columns[i] = arguments[i + 1].column->convertToFullColumnIfConst();
        aggregate_arguments_vec[i] = materialized_columns[i].get();
    }

    const IColumn ** aggregate_arguments = aggregate_arguments_vec.data();

    MutableColumnPtr result_holder = result_type->createColumn();
    IColumn & res_col = *result_holder;

    /// State size and alignment are the same for every row, so query them once.
    const size_t state_size = agg_func.sizeOfData();
    const size_t state_alignment = agg_func.alignOfData();

    PODArray<AggregateDataPtr> places(input_rows_count);
    for (size_t i = 0; i < input_rows_count; ++i)
    {
        places[i] = arena->alignedAlloc(state_size, state_alignment);
        try
        {
            agg_func.create(places[i]);
        }
        catch (...)
        {
            /// Row i's state was not created; destroy only the states of rows [0, i).
            for (size_t j = 0; j < i; ++j)
                agg_func.destroy(places[j]);
            throw;
        }
    }

    /// All states are live from here on; destroy them however this function exits.
    SCOPE_EXIT_MEMORY_SAFE({
        for (size_t i = 0; i < input_rows_count; ++i)
            agg_func.destroy(places[i]);
    });

    {
        /// Unnest consecutive trailing -State combinators and add the rows through the
        /// innermost non-State function.
        const auto * that = &agg_func;
        while (const auto * func = typeid_cast<const AggregateFunctionState *>(that))
            that = func->getNestedFunction().get();
        /// `places` holds one state per row, so each row is aggregated into its own state.
        that->addBatch(0, input_rows_count, places.data(), 0, aggregate_arguments, arena.get());
    }

    for (size_t i = 0; i < input_rows_count; ++i)
    {
        /// We should use insertMergeResultInto to insert result into ColumnAggregateFunction
        /// correctly if result contains AggregateFunction's states
        agg_func.insertMergeResultInto(places[i], res_col, arena.get());
    }

    return result_holder;
}
2020-09-07 18:00:37 +00:00
}
/// Registers `initializeAggregation` with the function factory.
REGISTER_FUNCTION(InitializeAggregation)
{
    using InitializeAggregationImpl = FunctionInitializeAggregation;
    factory.registerFunction<InitializeAggregationImpl>();
}
}