dbms: added function 'arrayReduce' [#METR-19264].

This commit is contained in:
Alexey Milovidov 2015-12-13 13:43:49 +03:00
parent 9060931654
commit 1e4def963d
4 changed files with 201 additions and 0 deletions

View File

@ -18,6 +18,12 @@
#include <DB/Interpreters/AggregationCommon.h>
#include <DB/Functions/NumberTraits.h>
#include <DB/Functions/FunctionsConditional.h>
#include <DB/AggregateFunctions/IAggregateFunction.h>
#include <DB/AggregateFunctions/AggregateFunctionFactory.h>
#include <DB/Parsers/ExpressionListParsers.h>
#include <DB/Parsers/parseQuery.h>
#include <DB/Parsers/ASTExpressionList.h>
#include <DB/Parsers/ASTLiteral.h>
#include <ext/range.hpp>
@ -47,6 +53,8 @@ namespace DB
* - для кортежей из элементов на соответствующих позициях в нескольких массивах.
*
* emptyArrayToSingle(arr) - заменить пустые массивы на массивы из одного элемента со значением "по-умолчанию".
*
* arrayReduce('agg', arr1, ...) - применить агрегатную функцию agg к массивам arr1...
*/
@ -2320,6 +2328,178 @@ private:
};
/** Применяет к массиву агрегатную функцию и возвращает её результат.
* Также может быть применена к нескольким массивам одинаковых размеров, если агрегатная функция принимает несколько аргументов.
*/
class FunctionArrayReduce : public IFunction
{
public:
static constexpr auto name = "arrayReduce";
static IFunction * create(const Context & context) { return new FunctionArrayReduce; }
/// Получить имя функции.
String getName() const override
{
return name;
}
void getReturnTypeAndPrerequisites(
const ColumnsWithTypeAndName & arguments,
DataTypePtr & out_return_type,
std::vector<ExpressionAction> & out_prerequisites) override
{
/// Первый аргумент - константная строка с именем агрегатной функции (возможно, с параметрами в скобках, например: "quantile(0.99)").
if (arguments.size() < 2)
throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
+ toString(arguments.size()) + ", should be at least 2.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
const ColumnConstString * aggregate_function_name_column = typeid_cast<const ColumnConstString *>(arguments[0].column.get());
if (!aggregate_function_name_column)
throw Exception("First argument for function " + getName() + " must be constant string: name of aggregate function.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
DataTypes argument_types(arguments.size() - 1);
for (size_t i = 1, size = arguments.size(); i < size; ++i)
{
const DataTypeArray * arg = typeid_cast<const DataTypeArray *>(arguments[i].type.get());
if (!arg)
throw Exception("Argument " + toString(i) + " for function " + getName() + " must be array.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
argument_types[i - 1] = arg->getNestedType()->clone();
}
if (!aggregate_function)
{
const String & aggregate_function_name_with_params = aggregate_function_name_column->getData();
if (aggregate_function_name_with_params.empty())
throw Exception("First argument for function " + getName() + " (name of aggregate function) cannot be empty.",
ErrorCodes::BAD_ARGUMENTS);
bool has_parameters = ')' == aggregate_function_name_with_params.back();
String aggregate_function_name = aggregate_function_name_with_params;
String parameters;
Array params_row;
if (has_parameters)
{
size_t pos = aggregate_function_name_with_params.find('(');
if (pos == std::string::npos || pos + 2 >= aggregate_function_name_with_params.size())
throw Exception("First argument for function " + getName() + " doesn't look like aggregate function name.",
ErrorCodes::BAD_ARGUMENTS);
aggregate_function_name = aggregate_function_name_with_params.substr(0, pos);
parameters = aggregate_function_name_with_params.substr(pos + 1, aggregate_function_name_with_params.size() - pos - 2);
if (aggregate_function_name.empty())
throw Exception("First argument for function " + getName() + " doesn't look like aggregate function name.",
ErrorCodes::BAD_ARGUMENTS);
ParserExpressionList params_parser(false);
ASTPtr args_ast = parseQuery(params_parser,
parameters.data(), parameters.data() + parameters.size(),
"parameters of aggregate function");
ASTExpressionList & args_list = typeid_cast<ASTExpressionList &>(*args_ast);
if (args_list.children.empty())
throw Exception("Incorrect list of parameters to aggregate function "
+ aggregate_function_name, ErrorCodes::BAD_ARGUMENTS);
params_row.reserve(args_list.children.size());
for (const auto & child : args_list.children)
{
const ASTLiteral * lit = typeid_cast<const ASTLiteral *>(child.get());
if (!lit)
throw Exception("Parameters to aggregate functions must be literals",
ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS);
params_row.push_back(lit->value);
}
}
aggregate_function = AggregateFunctionFactory().get(aggregate_function_name, argument_types);
if (has_parameters)
aggregate_function->setParameters(params_row);
aggregate_function->setArguments(argument_types);
}
out_return_type = aggregate_function->getReturnType();
}
void execute(Block & block, const ColumnNumbers & arguments, size_t result) override
{
IAggregateFunction & agg_func = *aggregate_function.get();
std::unique_ptr<char[]> place_holder { new char[agg_func.sizeOfData()] };
char * place = place_holder.get();
size_t rows = block.rowsInFirstColumn();
/// Агрегатные функции не поддерживают константные столбцы. Поэтому, материализуем их.
std::vector<ColumnPtr> materialized_columns;
std::vector<const IColumn *> aggregate_arguments_vec(arguments.size() - 1);
for (size_t i = 0, size = arguments.size() - 1; i < size; ++i)
{
const IColumn * col = block.unsafeGetByPosition(arguments[i + 1]).column.get();
if (const ColumnArray * arr = typeid_cast<const ColumnArray *>(col))
{
aggregate_arguments_vec[i] = arr->getDataPtr().get();
}
else if (const ColumnConstArray * arr = typeid_cast<const ColumnConstArray *>(col))
{
materialized_columns.emplace_back(arr->convertToFullColumn());
aggregate_arguments_vec[i] = typeid_cast<const ColumnArray &>(*materialized_columns.back().get()).getDataPtr().get();
}
else
throw Exception("Illegal column " + col->getName() + " as argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
}
const IColumn ** aggregate_arguments = aggregate_arguments_vec.data();
const ColumnArray::Offsets_t & offsets = typeid_cast<const ColumnArray &>(!materialized_columns.empty()
? *materialized_columns.front().get()
: *block.unsafeGetByPosition(arguments[1]).column.get()).getOffsets();
ColumnPtr result_holder = block.getByPosition(result).type->createColumn();
block.getByPosition(result).column = result_holder;
IColumn & res_col = *result_holder.get();
ColumnArray::Offset_t current_offset = 0;
for (size_t i = 0; i < rows; ++i)
{
agg_func.create(place);
ColumnArray::Offset_t next_offset = offsets[i];
try
{
for (size_t j = current_offset; j < next_offset; ++j)
agg_func.add(place, aggregate_arguments, j);
agg_func.insertResultInto(place, res_col);
}
catch (...)
{
agg_func.destroy(place);
throw;
}
agg_func.destroy(place);
current_offset = next_offset;
}
}
private:
AggregateFunctionPtr aggregate_function;
};
/// Holds the name of the function "has".
struct NameHas
{
    static constexpr const char * name = "has";
};
/// Holds the name of the function "indexOf".
struct NameIndexOf
{
    static constexpr const char * name = "indexOf";
};
/// Holds the name of the function "countEqual".
struct NameCountEqual
{
    static constexpr const char * name = "countEqual";
};

View File

@ -29,6 +29,7 @@ void registerFunctionsArray(FunctionFactory & factory)
factory.registerFunction<FunctionEmptyArrayString>();
factory.registerFunction<FunctionEmptyArrayToSingle>();
factory.registerFunction<FunctionRange>();
factory.registerFunction<FunctionArrayReduce>();
}
}

View File

@ -0,0 +1,13 @@
2 4 4 3
[nan,nan] []
[0,0] [0]
[0.5,0.9] [0,1]
[1,1.8] [0,1,2]
[1.5,2.7] [0,1,2,3]
[2,3.6] [0,1,2,3,4]
[2.5,4.5] [0,1,2,3,4,5]
[3,5.4] [0,1,2,3,4,5,6]
[3.5,6.3] [0,1,2,3,4,5,6,7]
[4,7.2] [0,1,2,3,4,5,6,7,8]
[4.5,8.1] [0,1,2,3,4,5,6,7,8,9]
[5,9] [0,1,2,3,4,5,6,7,8,9,10]

View File

@ -0,0 +1,7 @@
-- arrayReduce with a single array, with two arrays, with a parametric aggregate function
-- ('uniqUpTo(5)') combined with a materialize()d array (presumably to exercise the
-- non-constant column path), and with 'uniqExactIf' taking a second array of 0/1 values.
SELECT
arrayReduce('uniq', [1, 2, 1]) AS a,
arrayReduce('uniq', [1, 2, 2, 1], ['hello', 'world', '', '']) AS b,
arrayReduce('uniqUpTo(5)', [1, 2, 2, 1], materialize(['hello', 'world', '', ''])) AS c,
arrayReduce('uniqExactIf', [1, 2, 3, 4], [1, 0, 1, 1]) AS d;
-- arrayReduce with a multi-parameter aggregate function over per-row arrays range(number),
-- for the first 12 rows of system.numbers; the array itself is selected alongside the result.
SELECT arrayReduce('quantiles(0.5, 0.9)', range(number) AS r), r FROM system.numbers LIMIT 12;