Allow constexpr parameters for aggregate functions

This commit is contained in:
Alexey Milovidov 2021-07-02 03:53:08 +03:00
parent 7a993404b4
commit 0e621788c7
10 changed files with 72 additions and 43 deletions

View File

@ -4,6 +4,8 @@
#include <Parsers/ExpressionListParsers.h>
#include <Parsers/parseQuery.h>
#include <Interpreters/evaluateConstantExpression.h>
namespace DB
{
@ -15,7 +17,7 @@ namespace ErrorCodes
extern const int PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS;
}
Array getAggregateFunctionParametersArray(const ASTPtr & expression_list, const std::string & error_context)
Array getAggregateFunctionParametersArray(const ASTPtr & expression_list, const std::string & error_context, ContextPtr context)
{
const ASTs & parameters = expression_list->children;
if (parameters.empty())
@ -25,25 +27,25 @@ Array getAggregateFunctionParametersArray(const ASTPtr & expression_list, const
for (size_t i = 0; i < parameters.size(); ++i)
{
const auto * literal = parameters[i]->as<ASTLiteral>();
ASTPtr func_literal;
if (!literal)
if (const auto * func = parameters[i]->as<ASTFunction>())
if ((func_literal = func->toLiteral()))
literal = func_literal->as<ASTLiteral>();
if (!literal)
ASTPtr literal;
try
{
throw Exception(
ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS,
"Parameters to aggregate functions must be literals. "
"Got parameter '{}'{}",
parameters[i]->formatForErrorMessage(),
(error_context.empty() ? "" : " (in " + error_context +")"));
literal = evaluateConstantExpressionAsLiteral(parameters[i], context);
}
catch (Exception & e)
{
if (e.code() == ErrorCodes::BAD_ARGUMENTS)
throw Exception(
ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS,
"Parameters to aggregate functions must be literals. "
"Got parameter '{}'{}",
parameters[i]->formatForErrorMessage(),
(error_context.empty() ? "" : " (in " + error_context +")"));
throw;
}
params_row[i] = literal->value;
params_row[i] = literal->as<ASTLiteral>()->value;
}
return params_row;
@ -54,7 +56,8 @@ void getAggregateFunctionNameAndParametersArray(
const std::string & aggregate_function_name_with_params,
std::string & aggregate_function_name,
Array & aggregate_function_parameters,
const std::string & error_context)
const std::string & error_context,
ContextPtr context)
{
if (aggregate_function_name_with_params.back() != ')')
{
@ -84,7 +87,7 @@ void getAggregateFunctionNameAndParametersArray(
throw Exception("Incorrect list of parameters to aggregate function "
+ aggregate_function_name, ErrorCodes::BAD_ARGUMENTS);
aggregate_function_parameters = getAggregateFunctionParametersArray(args_ast);
aggregate_function_parameters = getAggregateFunctionParametersArray(args_ast, error_context, context);
}
}

View File

@ -1,19 +1,23 @@
#pragma once
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTExpressionList.h>
#include <Interpreters/Context_fwd.h>
namespace DB
{
struct Settings;
Array getAggregateFunctionParametersArray(const ASTPtr & expression_list, const std::string & error_context = "");
Array getAggregateFunctionParametersArray(
const ASTPtr & expression_list,
const std::string & error_context,
ContextPtr context);
void getAggregateFunctionNameAndParametersArray(
const std::string & aggregate_function_name_with_params,
std::string & aggregate_function_name,
Array & aggregate_function_parameters,
const std::string & error_context);
const std::string & error_context,
ContextPtr context);
}

View File

@ -33,11 +33,12 @@ namespace ErrorCodes
* arrayReduce('agg', arr1, ...) - apply the aggregate function `agg` to arrays `arr1...`
* If multiple arrays passed, then elements on corresponding positions are passed as multiple arguments to the aggregate function.
*/
class FunctionArrayReduce : public IFunction
class FunctionArrayReduce : public IFunction, private WithContext
{
public:
static constexpr auto name = "arrayReduce";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayReduce>(); }
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionArrayReduce>(context_); }
FunctionArrayReduce(ContextPtr context_) : WithContext(context_) {}
String getName() const override { return name; }
@ -95,7 +96,7 @@ DataTypePtr FunctionArrayReduce::getReturnTypeImpl(const ColumnsWithTypeAndName
String aggregate_function_name;
Array params_row;
getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
aggregate_function_name, params_row, "function " + getName());
aggregate_function_name, params_row, "function " + getName(), getContext());
AggregateFunctionProperties properties;
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);

View File

@ -35,12 +35,13 @@ namespace ErrorCodes
*
* arrayReduceInRanges('agg', indices, lengths, arr1, ...)
*/
class FunctionArrayReduceInRanges : public IFunction
class FunctionArrayReduceInRanges : public IFunction, private WithContext
{
public:
static const size_t minimum_step = 64;
static constexpr auto name = "arrayReduceInRanges";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayReduceInRanges>(); }
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionArrayReduceInRanges>(context_); }
FunctionArrayReduceInRanges(ContextPtr context_) : WithContext(context_) {}
String getName() const override { return name; }
@ -113,7 +114,7 @@ DataTypePtr FunctionArrayReduceInRanges::getReturnTypeImpl(const ColumnsWithType
String aggregate_function_name;
Array params_row;
getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
aggregate_function_name, params_row, "function " + getName());
aggregate_function_name, params_row, "function " + getName(), getContext());
AggregateFunctionProperties properties;
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);

View File

@ -25,11 +25,12 @@ namespace ErrorCodes
namespace
{
class FunctionInitializeAggregation : public IFunction
class FunctionInitializeAggregation : public IFunction, private WithContext
{
public:
static constexpr auto name = "initializeAggregation";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionInitializeAggregation>(); }
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionInitializeAggregation>(context_); }
FunctionInitializeAggregation(ContextPtr context_) : WithContext(context_) {}
String getName() const override { return name; }
@ -78,7 +79,7 @@ DataTypePtr FunctionInitializeAggregation::getReturnTypeImpl(const ColumnsWithTy
String aggregate_function_name;
Array params_row;
getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
aggregate_function_name, params_row, "function " + getName());
aggregate_function_name, params_row, "function " + getName(), getContext());
AggregateFunctionProperties properties;
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);

View File

@ -468,7 +468,7 @@ bool ExpressionAnalyzer::makeAggregateDescriptions(ActionsDAGPtr & actions)
}
AggregateFunctionProperties properties;
aggregate.parameters = (node->parameters) ? getAggregateFunctionParametersArray(node->parameters) : Array();
aggregate.parameters = (node->parameters) ? getAggregateFunctionParametersArray(node->parameters, "", getContext()) : Array();
aggregate.function = AggregateFunctionFactory::instance().get(node->name, types, aggregate.parameters, properties);
aggregate_descriptions.push_back(aggregate);
@ -651,7 +651,7 @@ void ExpressionAnalyzer::makeWindowDescriptions(ActionsDAGPtr actions)
window_function.function_parameters
= window_function.function_node->parameters
? getAggregateFunctionParametersArray(
window_function.function_node->parameters)
window_function.function_node->parameters, "", getContext())
: Array();
// Requiring a constant reference to a shared pointer to non-const AST

View File

@ -49,17 +49,20 @@ std::pair<Field, std::shared_ptr<const IDataType>> evaluateConstantExpression(co
expr_for_constant_folding->execute(block_with_constants);
if (!block_with_constants || block_with_constants.rows() == 0)
throw Exception("Logical error: empty block after evaluation of constant expression for IN, VALUES or LIMIT", ErrorCodes::LOGICAL_ERROR);
throw Exception("Logical error: empty block after evaluation of constant expression for IN, VALUES or LIMIT or aggregate function parameter",
ErrorCodes::LOGICAL_ERROR);
if (!block_with_constants.has(name))
throw Exception("Element of set in IN, VALUES or LIMIT is not a constant expression (result column not found): " + name, ErrorCodes::BAD_ARGUMENTS);
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Element of set in IN, VALUES or LIMIT or aggregate function parameter is not a constant expression (result column not found): {}", name);
const ColumnWithTypeAndName & result = block_with_constants.getByName(name);
const IColumn & result_column = *result.column;
/// Expressions like rand() or now() are not constant
if (!isColumnConst(result_column))
throw Exception("Element of set in IN, VALUES or LIMIT is not a constant expression (result column is not const): " + name, ErrorCodes::BAD_ARGUMENTS);
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Element of set in IN, VALUES or LIMIT or aggregate function parameter is not a constant expression (result column is not const): {}", name);
return std::make_pair(result_column[0], result.type);
}

View File

@ -116,8 +116,11 @@ static bool compareRetentions(const Graphite::Retention & a, const Graphite::Ret
* </default>
* </graphite_rollup>
*/
static void
appendGraphitePattern(const Poco::Util::AbstractConfiguration & config, const String & config_element, Graphite::Patterns & patterns)
static void appendGraphitePattern(
const Poco::Util::AbstractConfiguration & config,
const String & config_element,
Graphite::Patterns & out_patterns,
ContextPtr context)
{
Graphite::Pattern pattern;
@ -137,7 +140,7 @@ appendGraphitePattern(const Poco::Util::AbstractConfiguration & config, const St
String aggregate_function_name;
Array params_row;
getAggregateFunctionNameAndParametersArray(
aggregate_function_name_with_params, aggregate_function_name, params_row, "GraphiteMergeTree storage initialization");
aggregate_function_name_with_params, aggregate_function_name, params_row, "GraphiteMergeTree storage initialization", context);
/// TODO Not only Float64
AggregateFunctionProperties properties;
@ -181,7 +184,7 @@ appendGraphitePattern(const Poco::Util::AbstractConfiguration & config, const St
if (pattern.type & pattern.TypeRetention) /// TypeRetention or TypeAll
std::sort(pattern.retentions.begin(), pattern.retentions.end(), compareRetentions);
patterns.emplace_back(pattern);
out_patterns.emplace_back(pattern);
}
static void setGraphitePatternsFromConfig(ContextPtr context, const String & config_element, Graphite::Params & params)
@ -204,7 +207,7 @@ static void setGraphitePatternsFromConfig(ContextPtr context, const String & con
{
if (startsWith(key, "pattern"))
{
appendGraphitePattern(config, config_element + "." + key, params.patterns);
appendGraphitePattern(config, config_element + "." + key, params.patterns, context);
}
else if (key == "default")
{
@ -219,7 +222,7 @@ static void setGraphitePatternsFromConfig(ContextPtr context, const String & con
}
if (config.has(config_element + ".default"))
appendGraphitePattern(config, config_element + "." + ".default", params.patterns);
appendGraphitePattern(config, config_element + "." + ".default", params.patterns, context);
}

View File

@ -0,0 +1,2 @@
[0,1,2,3,4]
[0,1,2,3,4]

View File

@ -0,0 +1,11 @@
SELECT groupArray(2 + 3)(number) FROM numbers(10);
SELECT groupArray('5'::UInt8)(number) FROM numbers(10);
SELECT groupArray()(number) FROM numbers(10); -- { serverError 36 }
SELECT groupArray(NULL)(number) FROM numbers(10); -- { serverError 36 }
SELECT groupArray(NULL + NULL)(number) FROM numbers(10); -- { serverError 36 }
SELECT groupArray([])(number) FROM numbers(10); -- { serverError 36 }
SELECT groupArray(throwIf(1))(number) FROM numbers(10); -- { serverError 395 }
-- Not the best error message, can be improved.
SELECT groupArray(number)(number) FROM numbers(10); -- { serverError 47 }