ClickHouse/src/Functions/array/arrayReduce.cpp
2021-05-17 10:30:42 +03:00

204 lines
8.0 KiB
C++

#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeArray.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnAggregateFunction.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionState.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/parseAggregateFunctionParameters.h>
#include <Common/Arena.h>
#include <ext/scope_guard_safe.h>
namespace DB
{
/// Error codes raised by arrayReduce; they are defined elsewhere in the project.
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
}
/** arrayReduce('agg', arr1, ...) runs the aggregate function `agg` over the elements of the arrays `arr1, ...`
  * and produces one aggregated value per row.
  * When `agg` takes several arguments, several arrays of equal length are passed; the elements found at the same
  * position in each array are fed together as the arguments of one aggregation step.
  */
class FunctionArrayReduce : public IFunction
{
public:
    static constexpr auto name = "arrayReduce";

    static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayReduce>(); }

    String getName() const override { return name; }

    /// Variable number of arguments: the aggregate function name followed by the arrays.
    bool isVariadic() const override { return true; }
    size_t getNumberOfArguments() const override { return 0; }

    /// Argument 0 (the aggregate function name) must always be a constant.
    ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {0}; }
    bool useDefaultImplementationForConstants() const override { return true; }

    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override;

    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;

private:
    /// Resolved lazily inside getReturnTypeImpl and later used by executeImpl.
    /// TODO: init in OverloadResolver
    mutable AggregateFunctionPtr aggregate_function;
};
/// Validates the arguments of arrayReduce and returns the result type of the requested aggregate function.
/// Argument 0 must be a constant string holding the aggregate function name, optionally with parameters
/// in parentheses (for example "quantile(0.99)"); each remaining argument must be an Array, and the
/// element types of those arrays become the aggregate function's argument types.
/// On the first call the aggregate function is resolved and cached in `aggregate_function`;
/// later calls reuse the cached instance.
DataTypePtr FunctionArrayReduce::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
{
    const size_t arguments_count = arguments.size();
    if (arguments_count < 2)
        throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
            + toString(arguments_count) + ", should be at least 2.",
            ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    const auto * name_column = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
    if (name_column == nullptr)
        throw Exception("First argument for function " + getName() + " must be constant string: name of aggregate function.",
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

    /// Collect the element type of every array argument; these are the aggregate function's argument types.
    DataTypes nested_types;
    nested_types.reserve(arguments_count - 1);
    for (size_t idx = 1; idx < arguments_count; ++idx)
    {
        const auto & argument = arguments[idx];
        const auto * array_type = checkAndGetDataType<DataTypeArray>(argument.type.get());
        if (!array_type)
            throw Exception("Argument " + toString(idx) + " for function " + getName() + " must be an array but it has type "
                + argument.type->getName() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
        nested_types.push_back(array_type->getNestedType());
    }

    /// Already resolved by an earlier call: reuse it.
    if (aggregate_function)
        return aggregate_function->getReturnType();

    const String name_with_params = name_column->getValue<String>();
    if (name_with_params.empty())
        throw Exception("First argument for function " + getName() + " (name of aggregate function) cannot be empty.",
            ErrorCodes::BAD_ARGUMENTS);

    String function_name;
    Array parameters;
    getAggregateFunctionNameAndParametersArray(name_with_params, function_name, parameters, "function " + getName());

    AggregateFunctionProperties properties;
    aggregate_function = AggregateFunctionFactory::instance().get(function_name, nested_types, parameters, properties);
    return aggregate_function->getReturnType();
}
/// Computes the aggregate function resolved in getReturnTypeImpl once per row, over the elements of that row's arrays.
/// All array arguments must have identical offsets (the same length in every row), otherwise
/// SIZES_OF_ARRAYS_DOESNT_MATCH is thrown. Constant array columns are expanded to full columns first.
/// If the result column is a ColumnAggregateFunction, states are inserted with insertFrom;
/// otherwise the final values are inserted with insertResultInto.
ColumnPtr FunctionArrayReduce::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
{
    IAggregateFunction & agg = *aggregate_function;
    auto arena = std::make_unique<Arena>();

    /// Aggregate functions cannot consume constant columns, so constant arrays are expanded into full
    /// columns. This vector keeps those temporary columns alive while we hold pointers into them.
    std::vector<ColumnPtr> holders;

    const size_t array_count = arguments.size() - 1;
    std::vector<const IColumn *> nested_columns(array_count);
    const ColumnArray::Offsets * common_offsets = nullptr;

    for (size_t arg_idx = 0; arg_idx < array_count; ++arg_idx)
    {
        const IColumn * source = arguments[arg_idx + 1].column.get();
        const ColumnArray * array_column = checkAndGetColumn<ColumnArray>(source);

        if (!array_column)
        {
            const ColumnConst * const_column = checkAndGetColumnConst<ColumnArray>(source);
            if (!const_column)
                throw Exception("Illegal column " + source->getName() + " as argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);

            holders.emplace_back(const_column->convertToFullColumn());
            array_column = &typeid_cast<const ColumnArray &>(*holders.back());
        }

        nested_columns[arg_idx] = &array_column->getData();
        const ColumnArray::Offsets & current_offsets = array_column->getOffsets();

        /// The first array defines the row boundaries; every other array must match them exactly.
        if (arg_idx == 0)
            common_offsets = &current_offsets;
        else if (current_offsets != *common_offsets)
            throw Exception("Lengths of all arrays passed to " + getName() + " must be equal.",
                ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
    }

    MutableColumnPtr result_holder = result_type->createColumn();
    IColumn & res_col = *result_holder;

    /// States must be inserted into a ColumnAggregateFunction in a specific way (see the final loop).
    auto * state_column = typeid_cast<ColumnAggregateFunction *>(&res_col);
    if (!state_column && agg.isState())
        throw Exception("State function " + agg.getName() + " inserts results into non-state column "
            + result_type->getName(), ErrorCodes::ILLEGAL_COLUMN);

    /// One aggregation state per row, allocated in the arena.
    PODArray<AggregateDataPtr> places(input_rows_count);
    for (size_t row = 0; row < input_rows_count; ++row)
    {
        places[row] = arena->alignedAlloc(agg.sizeOfData(), agg.alignOfData());
        try
        {
            agg.create(places[row]);
        }
        catch (...)
        {
            /// Destroy only the states created before the failing one, then rethrow.
            for (size_t created = 0; created < row; ++created)
                agg.destroy(places[created]);
            throw;
        }
    }

    /// From here on every state exists; destroy all of them however this function exits.
    SCOPE_EXIT_MEMORY_SAFE({
        for (size_t row = 0; row < input_rows_count; ++row)
            agg.destroy(places[row]);
    });

    /// Strip trailing -State combinators and feed the rows to the innermost function directly.
    /// The same `places` are passed, so this relies on a -State wrapper sharing its nested function's state layout.
    auto * target = &agg;
    while (auto * state_wrapper = typeid_cast<AggregateFunctionState *>(target))
        target = state_wrapper->getNestedFunction().get();
    target->addBatchArray(input_rows_count, places.data(), 0, nested_columns.data(), common_offsets->data(), arena.get());

    for (size_t row = 0; row < input_rows_count; ++row)
    {
        if (state_column)
            state_column->insertFrom(places[row]);
        else
            agg.insertResultInto(places[row], res_col, arena.get());
    }

    return result_holder;
}
/// Registers FunctionArrayReduce in the given function factory.
void registerFunctionArrayReduce(FunctionFactory & factory)
{
    factory.registerFunction<FunctionArrayReduce>();
}
}