mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 23:21:59 +00:00
dbms: added function 'arrayReduce' [#METR-19264].
This commit is contained in:
parent
9060931654
commit
1e4def963d
@ -18,6 +18,12 @@
|
||||
#include <DB/Interpreters/AggregationCommon.h>
|
||||
#include <DB/Functions/NumberTraits.h>
|
||||
#include <DB/Functions/FunctionsConditional.h>
|
||||
#include <DB/AggregateFunctions/IAggregateFunction.h>
|
||||
#include <DB/AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <DB/Parsers/ExpressionListParsers.h>
|
||||
#include <DB/Parsers/parseQuery.h>
|
||||
#include <DB/Parsers/ASTExpressionList.h>
|
||||
#include <DB/Parsers/ASTLiteral.h>
|
||||
|
||||
#include <ext/range.hpp>
|
||||
|
||||
@ -47,6 +53,8 @@ namespace DB
|
||||
* - для кортежей из элементов на соответствующих позициях в нескольких массивах.
|
||||
*
|
||||
* emptyArrayToSingle(arr) - заменить пустые массивы на массивы из одного элемента со значением "по-умолчанию".
|
||||
*
|
||||
* arrayReduce('agg', arr1, ...) - применить агрегатную функцию agg к массивам arr1...
|
||||
*/
|
||||
|
||||
|
||||
@ -2320,6 +2328,178 @@ private:
|
||||
};
|
||||
|
||||
|
||||
/** Применяет к массиву агрегатную функцию и возвращает её результат.
|
||||
* Также может быть применена к нескольким массивам одинаковых размеров, если агрегатная функция принимает несколько аргументов.
|
||||
*/
|
||||
class FunctionArrayReduce : public IFunction
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "arrayReduce";
|
||||
static IFunction * create(const Context & context) { return new FunctionArrayReduce; }
|
||||
|
||||
/// Получить имя функции.
|
||||
String getName() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
void getReturnTypeAndPrerequisites(
|
||||
const ColumnsWithTypeAndName & arguments,
|
||||
DataTypePtr & out_return_type,
|
||||
std::vector<ExpressionAction> & out_prerequisites) override
|
||||
{
|
||||
/// Первый аргумент - константная строка с именем агрегатной функции (возможно, с параметрами в скобках, например: "quantile(0.99)").
|
||||
|
||||
if (arguments.size() < 2)
|
||||
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
|
||||
+ toString(arguments.size()) + ", should be at least 2.",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
const ColumnConstString * aggregate_function_name_column = typeid_cast<const ColumnConstString *>(arguments[0].column.get());
|
||||
if (!aggregate_function_name_column)
|
||||
throw Exception("First argument for function " + getName() + " must be constant string: name of aggregate function.",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
DataTypes argument_types(arguments.size() - 1);
|
||||
for (size_t i = 1, size = arguments.size(); i < size; ++i)
|
||||
{
|
||||
const DataTypeArray * arg = typeid_cast<const DataTypeArray *>(arguments[i].type.get());
|
||||
if (!arg)
|
||||
throw Exception("Argument " + toString(i) + " for function " + getName() + " must be array.",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
argument_types[i - 1] = arg->getNestedType()->clone();
|
||||
}
|
||||
|
||||
if (!aggregate_function)
|
||||
{
|
||||
const String & aggregate_function_name_with_params = aggregate_function_name_column->getData();
|
||||
|
||||
if (aggregate_function_name_with_params.empty())
|
||||
throw Exception("First argument for function " + getName() + " (name of aggregate function) cannot be empty.",
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
bool has_parameters = ')' == aggregate_function_name_with_params.back();
|
||||
|
||||
String aggregate_function_name = aggregate_function_name_with_params;
|
||||
String parameters;
|
||||
Array params_row;
|
||||
|
||||
if (has_parameters)
|
||||
{
|
||||
size_t pos = aggregate_function_name_with_params.find('(');
|
||||
if (pos == std::string::npos || pos + 2 >= aggregate_function_name_with_params.size())
|
||||
throw Exception("First argument for function " + getName() + " doesn't look like aggregate function name.",
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
aggregate_function_name = aggregate_function_name_with_params.substr(0, pos);
|
||||
parameters = aggregate_function_name_with_params.substr(pos + 1, aggregate_function_name_with_params.size() - pos - 2);
|
||||
|
||||
if (aggregate_function_name.empty())
|
||||
throw Exception("First argument for function " + getName() + " doesn't look like aggregate function name.",
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
ParserExpressionList params_parser(false);
|
||||
ASTPtr args_ast = parseQuery(params_parser,
|
||||
parameters.data(), parameters.data() + parameters.size(),
|
||||
"parameters of aggregate function");
|
||||
|
||||
ASTExpressionList & args_list = typeid_cast<ASTExpressionList &>(*args_ast);
|
||||
|
||||
if (args_list.children.empty())
|
||||
throw Exception("Incorrect list of parameters to aggregate function "
|
||||
+ aggregate_function_name, ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
params_row.reserve(args_list.children.size());
|
||||
for (const auto & child : args_list.children)
|
||||
{
|
||||
const ASTLiteral * lit = typeid_cast<const ASTLiteral *>(child.get());
|
||||
if (!lit)
|
||||
throw Exception("Parameters to aggregate functions must be literals",
|
||||
ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS);
|
||||
|
||||
params_row.push_back(lit->value);
|
||||
}
|
||||
}
|
||||
|
||||
aggregate_function = AggregateFunctionFactory().get(aggregate_function_name, argument_types);
|
||||
if (has_parameters)
|
||||
aggregate_function->setParameters(params_row);
|
||||
aggregate_function->setArguments(argument_types);
|
||||
}
|
||||
|
||||
out_return_type = aggregate_function->getReturnType();
|
||||
}
|
||||
|
||||
|
||||
void execute(Block & block, const ColumnNumbers & arguments, size_t result) override
|
||||
{
|
||||
IAggregateFunction & agg_func = *aggregate_function.get();
|
||||
std::unique_ptr<char[]> place_holder { new char[agg_func.sizeOfData()] };
|
||||
char * place = place_holder.get();
|
||||
|
||||
size_t rows = block.rowsInFirstColumn();
|
||||
|
||||
/// Агрегатные функции не поддерживают константные столбцы. Поэтому, материализуем их.
|
||||
std::vector<ColumnPtr> materialized_columns;
|
||||
|
||||
std::vector<const IColumn *> aggregate_arguments_vec(arguments.size() - 1);
|
||||
|
||||
for (size_t i = 0, size = arguments.size() - 1; i < size; ++i)
|
||||
{
|
||||
const IColumn * col = block.unsafeGetByPosition(arguments[i + 1]).column.get();
|
||||
if (const ColumnArray * arr = typeid_cast<const ColumnArray *>(col))
|
||||
{
|
||||
aggregate_arguments_vec[i] = arr->getDataPtr().get();
|
||||
}
|
||||
else if (const ColumnConstArray * arr = typeid_cast<const ColumnConstArray *>(col))
|
||||
{
|
||||
materialized_columns.emplace_back(arr->convertToFullColumn());
|
||||
aggregate_arguments_vec[i] = typeid_cast<const ColumnArray &>(*materialized_columns.back().get()).getDataPtr().get();
|
||||
}
|
||||
else
|
||||
throw Exception("Illegal column " + col->getName() + " as argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
|
||||
|
||||
}
|
||||
const IColumn ** aggregate_arguments = aggregate_arguments_vec.data();
|
||||
|
||||
const ColumnArray::Offsets_t & offsets = typeid_cast<const ColumnArray &>(!materialized_columns.empty()
|
||||
? *materialized_columns.front().get()
|
||||
: *block.unsafeGetByPosition(arguments[1]).column.get()).getOffsets();
|
||||
|
||||
ColumnPtr result_holder = block.getByPosition(result).type->createColumn();
|
||||
block.getByPosition(result).column = result_holder;
|
||||
IColumn & res_col = *result_holder.get();
|
||||
|
||||
ColumnArray::Offset_t current_offset = 0;
|
||||
for (size_t i = 0; i < rows; ++i)
|
||||
{
|
||||
agg_func.create(place);
|
||||
ColumnArray::Offset_t next_offset = offsets[i];
|
||||
|
||||
try
|
||||
{
|
||||
for (size_t j = current_offset; j < next_offset; ++j)
|
||||
agg_func.add(place, aggregate_arguments, j);
|
||||
|
||||
agg_func.insertResultInto(place, res_col);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
agg_func.destroy(place);
|
||||
throw;
|
||||
}
|
||||
|
||||
agg_func.destroy(place);
|
||||
current_offset = next_offset;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
AggregateFunctionPtr aggregate_function;
|
||||
};
|
||||
|
||||
|
||||
/// Holds the function name "has" as a compile-time constant.
/// Presumably passed as a template argument to a shared implementation — confirm at the use site.
struct NameHas
{
    static constexpr const char * name = "has";
};
|
||||
/// Holds the function name "indexOf" as a compile-time constant.
/// Presumably passed as a template argument to a shared implementation — confirm at the use site.
struct NameIndexOf
{
    static constexpr const char * name = "indexOf";
};
|
||||
/// Holds the function name "countEqual" as a compile-time constant.
/// Presumably passed as a template argument to a shared implementation — confirm at the use site.
struct NameCountEqual
{
    static constexpr const char * name = "countEqual";
};
|
||||
|
@ -29,6 +29,7 @@ void registerFunctionsArray(FunctionFactory & factory)
|
||||
factory.registerFunction<FunctionEmptyArrayString>();
|
||||
factory.registerFunction<FunctionEmptyArrayToSingle>();
|
||||
factory.registerFunction<FunctionRange>();
|
||||
factory.registerFunction<FunctionArrayReduce>();
|
||||
}
|
||||
|
||||
}
|
||||
|
13
dbms/tests/queries/0_stateless/00291_array_reduce.reference
Normal file
13
dbms/tests/queries/0_stateless/00291_array_reduce.reference
Normal file
@ -0,0 +1,13 @@
|
||||
2 4 4 3
|
||||
[nan,nan] []
|
||||
[0,0] [0]
|
||||
[0.5,0.9] [0,1]
|
||||
[1,1.8] [0,1,2]
|
||||
[1.5,2.7] [0,1,2,3]
|
||||
[2,3.6] [0,1,2,3,4]
|
||||
[2.5,4.5] [0,1,2,3,4,5]
|
||||
[3,5.4] [0,1,2,3,4,5,6]
|
||||
[3.5,6.3] [0,1,2,3,4,5,6,7]
|
||||
[4,7.2] [0,1,2,3,4,5,6,7,8]
|
||||
[4.5,8.1] [0,1,2,3,4,5,6,7,8,9]
|
||||
[5,9] [0,1,2,3,4,5,6,7,8,9,10]
|
7
dbms/tests/queries/0_stateless/00291_array_reduce.sql
Normal file
7
dbms/tests/queries/0_stateless/00291_array_reduce.sql
Normal file
@ -0,0 +1,7 @@
|
||||
-- arrayReduce with the uniq family of aggregate functions:
--  a: a single constant array;
--  b: two constant arrays of equal length (a two-argument aggregate function);
--  c: an aggregate function with a parameter, 'uniqUpTo(5)', and one non-constant array (materialize);
--  d: 'uniqExactIf', where the second array appears to act as the condition -- TODO confirm.
SELECT
|
||||
arrayReduce('uniq', [1, 2, 1]) AS a,
|
||||
arrayReduce('uniq', [1, 2, 2, 1], ['hello', 'world', '', '']) AS b,
|
||||
arrayReduce('uniqUpTo(5)', [1, 2, 2, 1], materialize(['hello', 'world', '', ''])) AS c,
|
||||
arrayReduce('uniqExactIf', [1, 2, 3, 4], [1, 0, 1, 1]) AS d;
|
||||
|
||||
-- An aggregate function with several parameters, applied to non-constant arrays of growing length,
-- starting with the empty array for number = 0.
SELECT arrayReduce('quantiles(0.5, 0.9)', range(number) AS r), r FROM system.numbers LIMIT 12;
|
Loading…
Reference in New Issue
Block a user