mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-20 16:50:48 +00:00
Allow constexpr parameters for aggregate functions
This commit is contained in:
parent
7a993404b4
commit
0e621788c7
@ -4,6 +4,8 @@
|
||||
#include <Parsers/ExpressionListParsers.h>
|
||||
#include <Parsers/parseQuery.h>
|
||||
|
||||
#include <Interpreters/evaluateConstantExpression.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -15,7 +17,7 @@ namespace ErrorCodes
|
||||
extern const int PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS;
|
||||
}
|
||||
|
||||
Array getAggregateFunctionParametersArray(const ASTPtr & expression_list, const std::string & error_context)
|
||||
Array getAggregateFunctionParametersArray(const ASTPtr & expression_list, const std::string & error_context, ContextPtr context)
|
||||
{
|
||||
const ASTs & parameters = expression_list->children;
|
||||
if (parameters.empty())
|
||||
@ -25,25 +27,25 @@ Array getAggregateFunctionParametersArray(const ASTPtr & expression_list, const
|
||||
|
||||
for (size_t i = 0; i < parameters.size(); ++i)
|
||||
{
|
||||
const auto * literal = parameters[i]->as<ASTLiteral>();
|
||||
|
||||
ASTPtr func_literal;
|
||||
if (!literal)
|
||||
if (const auto * func = parameters[i]->as<ASTFunction>())
|
||||
if ((func_literal = func->toLiteral()))
|
||||
literal = func_literal->as<ASTLiteral>();
|
||||
|
||||
if (!literal)
|
||||
ASTPtr literal;
|
||||
try
|
||||
{
|
||||
throw Exception(
|
||||
ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS,
|
||||
"Parameters to aggregate functions must be literals. "
|
||||
"Got parameter '{}'{}",
|
||||
parameters[i]->formatForErrorMessage(),
|
||||
(error_context.empty() ? "" : " (in " + error_context +")"));
|
||||
literal = evaluateConstantExpressionAsLiteral(parameters[i], context);
|
||||
}
|
||||
catch (Exception & e)
|
||||
{
|
||||
if (e.code() == ErrorCodes::BAD_ARGUMENTS)
|
||||
throw Exception(
|
||||
ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS,
|
||||
"Parameters to aggregate functions must be literals. "
|
||||
"Got parameter '{}'{}",
|
||||
parameters[i]->formatForErrorMessage(),
|
||||
(error_context.empty() ? "" : " (in " + error_context +")"));
|
||||
|
||||
throw;
|
||||
}
|
||||
|
||||
params_row[i] = literal->value;
|
||||
params_row[i] = literal->as<ASTLiteral>()->value;
|
||||
}
|
||||
|
||||
return params_row;
|
||||
@ -54,7 +56,8 @@ void getAggregateFunctionNameAndParametersArray(
|
||||
const std::string & aggregate_function_name_with_params,
|
||||
std::string & aggregate_function_name,
|
||||
Array & aggregate_function_parameters,
|
||||
const std::string & error_context)
|
||||
const std::string & error_context,
|
||||
ContextPtr context)
|
||||
{
|
||||
if (aggregate_function_name_with_params.back() != ')')
|
||||
{
|
||||
@ -84,7 +87,7 @@ void getAggregateFunctionNameAndParametersArray(
|
||||
throw Exception("Incorrect list of parameters to aggregate function "
|
||||
+ aggregate_function_name, ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
aggregate_function_parameters = getAggregateFunctionParametersArray(args_ast);
|
||||
aggregate_function_parameters = getAggregateFunctionParametersArray(args_ast, error_context, context);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,19 +1,23 @@
|
||||
#pragma once
|
||||
|
||||
#include <Parsers/ASTLiteral.h>
|
||||
#include <Parsers/ASTExpressionList.h>
|
||||
#include <Interpreters/Context_fwd.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
Array getAggregateFunctionParametersArray(const ASTPtr & expression_list, const std::string & error_context = "");
|
||||
|
||||
Array getAggregateFunctionParametersArray(
|
||||
const ASTPtr & expression_list,
|
||||
const std::string & error_context,
|
||||
ContextPtr context);
|
||||
|
||||
void getAggregateFunctionNameAndParametersArray(
|
||||
const std::string & aggregate_function_name_with_params,
|
||||
std::string & aggregate_function_name,
|
||||
Array & aggregate_function_parameters,
|
||||
const std::string & error_context);
|
||||
const std::string & error_context,
|
||||
ContextPtr context);
|
||||
|
||||
}
|
||||
|
@ -33,11 +33,12 @@ namespace ErrorCodes
|
||||
* arrayReduce('agg', arr1, ...) - apply the aggregate function `agg` to arrays `arr1...`
|
||||
* If multiple arrays passed, then elements on corresponding positions are passed as multiple arguments to the aggregate function.
|
||||
*/
|
||||
class FunctionArrayReduce : public IFunction
|
||||
class FunctionArrayReduce : public IFunction, private WithContext
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "arrayReduce";
|
||||
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayReduce>(); }
|
||||
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionArrayReduce>(context_); }
|
||||
FunctionArrayReduce(ContextPtr context_) : WithContext(context_) {}
|
||||
|
||||
String getName() const override { return name; }
|
||||
|
||||
@ -95,7 +96,7 @@ DataTypePtr FunctionArrayReduce::getReturnTypeImpl(const ColumnsWithTypeAndName
|
||||
String aggregate_function_name;
|
||||
Array params_row;
|
||||
getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
|
||||
aggregate_function_name, params_row, "function " + getName());
|
||||
aggregate_function_name, params_row, "function " + getName(), getContext());
|
||||
|
||||
AggregateFunctionProperties properties;
|
||||
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
|
||||
|
@ -35,12 +35,13 @@ namespace ErrorCodes
|
||||
*
|
||||
* arrayReduceInRanges('agg', indices, lengths, arr1, ...)
|
||||
*/
|
||||
class FunctionArrayReduceInRanges : public IFunction
|
||||
class FunctionArrayReduceInRanges : public IFunction, private WithContext
|
||||
{
|
||||
public:
|
||||
static const size_t minimum_step = 64;
|
||||
static constexpr auto name = "arrayReduceInRanges";
|
||||
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayReduceInRanges>(); }
|
||||
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionArrayReduceInRanges>(context_); }
|
||||
FunctionArrayReduceInRanges(ContextPtr context_) : WithContext(context_) {}
|
||||
|
||||
String getName() const override { return name; }
|
||||
|
||||
@ -113,7 +114,7 @@ DataTypePtr FunctionArrayReduceInRanges::getReturnTypeImpl(const ColumnsWithType
|
||||
String aggregate_function_name;
|
||||
Array params_row;
|
||||
getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
|
||||
aggregate_function_name, params_row, "function " + getName());
|
||||
aggregate_function_name, params_row, "function " + getName(), getContext());
|
||||
|
||||
AggregateFunctionProperties properties;
|
||||
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
|
||||
|
@ -25,11 +25,12 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
class FunctionInitializeAggregation : public IFunction
|
||||
class FunctionInitializeAggregation : public IFunction, private WithContext
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "initializeAggregation";
|
||||
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionInitializeAggregation>(); }
|
||||
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionInitializeAggregation>(context_); }
|
||||
FunctionInitializeAggregation(ContextPtr context_) : WithContext(context_) {}
|
||||
|
||||
String getName() const override { return name; }
|
||||
|
||||
@ -78,7 +79,7 @@ DataTypePtr FunctionInitializeAggregation::getReturnTypeImpl(const ColumnsWithTy
|
||||
String aggregate_function_name;
|
||||
Array params_row;
|
||||
getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
|
||||
aggregate_function_name, params_row, "function " + getName());
|
||||
aggregate_function_name, params_row, "function " + getName(), getContext());
|
||||
|
||||
AggregateFunctionProperties properties;
|
||||
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
|
||||
|
@ -468,7 +468,7 @@ bool ExpressionAnalyzer::makeAggregateDescriptions(ActionsDAGPtr & actions)
|
||||
}
|
||||
|
||||
AggregateFunctionProperties properties;
|
||||
aggregate.parameters = (node->parameters) ? getAggregateFunctionParametersArray(node->parameters) : Array();
|
||||
aggregate.parameters = (node->parameters) ? getAggregateFunctionParametersArray(node->parameters, "", getContext()) : Array();
|
||||
aggregate.function = AggregateFunctionFactory::instance().get(node->name, types, aggregate.parameters, properties);
|
||||
|
||||
aggregate_descriptions.push_back(aggregate);
|
||||
@ -651,7 +651,7 @@ void ExpressionAnalyzer::makeWindowDescriptions(ActionsDAGPtr actions)
|
||||
window_function.function_parameters
|
||||
= window_function.function_node->parameters
|
||||
? getAggregateFunctionParametersArray(
|
||||
window_function.function_node->parameters)
|
||||
window_function.function_node->parameters, "", getContext())
|
||||
: Array();
|
||||
|
||||
// Requiring a constant reference to a shared pointer to non-const AST
|
||||
|
@ -49,17 +49,20 @@ std::pair<Field, std::shared_ptr<const IDataType>> evaluateConstantExpression(co
|
||||
expr_for_constant_folding->execute(block_with_constants);
|
||||
|
||||
if (!block_with_constants || block_with_constants.rows() == 0)
|
||||
throw Exception("Logical error: empty block after evaluation of constant expression for IN, VALUES or LIMIT", ErrorCodes::LOGICAL_ERROR);
|
||||
throw Exception("Logical error: empty block after evaluation of constant expression for IN, VALUES or LIMIT or aggregate function parameter",
|
||||
ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
if (!block_with_constants.has(name))
|
||||
throw Exception("Element of set in IN, VALUES or LIMIT is not a constant expression (result column not found): " + name, ErrorCodes::BAD_ARGUMENTS);
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS,
|
||||
"Element of set in IN, VALUES or LIMIT or aggregate function parameter is not a constant expression (result column not found): {}", name);
|
||||
|
||||
const ColumnWithTypeAndName & result = block_with_constants.getByName(name);
|
||||
const IColumn & result_column = *result.column;
|
||||
|
||||
/// Expressions like rand() or now() are not constant
|
||||
if (!isColumnConst(result_column))
|
||||
throw Exception("Element of set in IN, VALUES or LIMIT is not a constant expression (result column is not const): " + name, ErrorCodes::BAD_ARGUMENTS);
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS,
|
||||
"Element of set in IN, VALUES or LIMIT or aggregate function parameter is not a constant expression (result column is not const): {}", name);
|
||||
|
||||
return std::make_pair(result_column[0], result.type);
|
||||
}
|
||||
|
@ -116,8 +116,11 @@ static bool compareRetentions(const Graphite::Retention & a, const Graphite::Ret
|
||||
* </default>
|
||||
* </graphite_rollup>
|
||||
*/
|
||||
static void
|
||||
appendGraphitePattern(const Poco::Util::AbstractConfiguration & config, const String & config_element, Graphite::Patterns & patterns)
|
||||
static void appendGraphitePattern(
|
||||
const Poco::Util::AbstractConfiguration & config,
|
||||
const String & config_element,
|
||||
Graphite::Patterns & out_patterns,
|
||||
ContextPtr context)
|
||||
{
|
||||
Graphite::Pattern pattern;
|
||||
|
||||
@ -137,7 +140,7 @@ appendGraphitePattern(const Poco::Util::AbstractConfiguration & config, const St
|
||||
String aggregate_function_name;
|
||||
Array params_row;
|
||||
getAggregateFunctionNameAndParametersArray(
|
||||
aggregate_function_name_with_params, aggregate_function_name, params_row, "GraphiteMergeTree storage initialization");
|
||||
aggregate_function_name_with_params, aggregate_function_name, params_row, "GraphiteMergeTree storage initialization", context);
|
||||
|
||||
/// TODO Not only Float64
|
||||
AggregateFunctionProperties properties;
|
||||
@ -181,7 +184,7 @@ appendGraphitePattern(const Poco::Util::AbstractConfiguration & config, const St
|
||||
if (pattern.type & pattern.TypeRetention) /// TypeRetention or TypeAll
|
||||
std::sort(pattern.retentions.begin(), pattern.retentions.end(), compareRetentions);
|
||||
|
||||
patterns.emplace_back(pattern);
|
||||
out_patterns.emplace_back(pattern);
|
||||
}
|
||||
|
||||
static void setGraphitePatternsFromConfig(ContextPtr context, const String & config_element, Graphite::Params & params)
|
||||
@ -204,7 +207,7 @@ static void setGraphitePatternsFromConfig(ContextPtr context, const String & con
|
||||
{
|
||||
if (startsWith(key, "pattern"))
|
||||
{
|
||||
appendGraphitePattern(config, config_element + "." + key, params.patterns);
|
||||
appendGraphitePattern(config, config_element + "." + key, params.patterns, context);
|
||||
}
|
||||
else if (key == "default")
|
||||
{
|
||||
@ -219,7 +222,7 @@ static void setGraphitePatternsFromConfig(ContextPtr context, const String & con
|
||||
}
|
||||
|
||||
if (config.has(config_element + ".default"))
|
||||
appendGraphitePattern(config, config_element + "." + ".default", params.patterns);
|
||||
appendGraphitePattern(config, config_element + "." + ".default", params.patterns, context);
|
||||
}
|
||||
|
||||
|
||||
|
@ -0,0 +1,2 @@
|
||||
[0,1,2,3,4]
|
||||
[0,1,2,3,4]
|
@ -0,0 +1,11 @@
|
||||
SELECT groupArray(2 + 3)(number) FROM numbers(10);
|
||||
SELECT groupArray('5'::UInt8)(number) FROM numbers(10);
|
||||
|
||||
SELECT groupArray()(number) FROM numbers(10); -- { serverError 36 }
|
||||
SELECT groupArray(NULL)(number) FROM numbers(10); -- { serverError 36 }
|
||||
SELECT groupArray(NULL + NULL)(number) FROM numbers(10); -- { serverError 36 }
|
||||
SELECT groupArray([])(number) FROM numbers(10); -- { serverError 36 }
|
||||
SELECT groupArray(throwIf(1))(number) FROM numbers(10); -- { serverError 395 }
|
||||
|
||||
-- Not the best error message, can be improved.
|
||||
SELECT groupArray(number)(number) FROM numbers(10); -- { serverError 47 }
|
Loading…
Reference in New Issue
Block a user