mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-23 08:02:02 +00:00
Remove arithmetic operations from aggregation functions (#10047)
This commit is contained in:
parent
c7d9094a7a
commit
7757c9ec57
@ -411,6 +411,7 @@ struct Settings : public SettingsCollection<Settings>
|
||||
M(SettingBool, enable_scalar_subquery_optimization, true, "If it is set to true, prevent scalar subqueries from (de)serializing large scalar values and possibly avoid running the same subquery more than once.", 0) \
|
||||
M(SettingBool, optimize_trivial_count_query, true, "Process trivial 'SELECT count() FROM table' query from metadata.", 0) \
|
||||
M(SettingUInt64, mutations_sync, 0, "Wait for synchronous execution of ALTER TABLE UPDATE/DELETE queries (mutations). 0 - execute asynchronously. 1 - wait current server. 2 - wait all replicas if they exist.", 0) \
|
||||
M(SettingBool, optimize_arithmetic_operations_in_agr_func, true, "Removing arithmetic operations from aggregation functions", 0) \
|
||||
M(SettingBool, optimize_if_chain_to_miltiif, false, "Replace if(cond1, then1, if(cond2, ...)) chains to multiIf. Currently it's not beneficial for numeric types.", 0) \
|
||||
M(SettingBool, allow_experimental_alter_materialized_view_structure, false, "Allow atomic alter on Materialized views. Work in progress.", 0) \
|
||||
M(SettingBool, enable_early_constant_folding, true, "Enable query optimization where we analyze function and subqueries results and rewrite query if there're constants there", 0) \
|
||||
|
313
src/Interpreters/ArithmeticOperationsInAgrFuncOptimize.cpp
Normal file
313
src/Interpreters/ArithmeticOperationsInAgrFuncOptimize.cpp
Normal file
@ -0,0 +1,313 @@
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Parsers/ASTLiteral.h>
|
||||
#include <Parsers/ASTFunction.h>
|
||||
#include <Interpreters/ArithmeticOperationsInAgrFuncOptimize.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Parsers/ASTTablesInSelectQuery.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int LOGICAL_ERROR;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int UNEXPECTED_AST_STRUCTURE;
|
||||
}
|
||||
namespace
|
||||
{
|
||||
constexpr const char * min = "min";
|
||||
constexpr const char * max = "max";
|
||||
constexpr const char * mul = "multiply";
|
||||
constexpr const char * plus = "plus";
|
||||
constexpr const char * sum = "sum";
|
||||
}
|
||||
|
||||
bool isConstantField(const Field & field)
|
||||
{
|
||||
return field.getType() == Field::Types::Int64 ||
|
||||
field.getType() == Field::Types::UInt64 ||
|
||||
field.getType() == Field::Types::Int128 ||
|
||||
field.getType() == Field::Types::UInt128;
|
||||
}
|
||||
|
||||
bool onlyConstsInside(const ASTFunction * func_node)
|
||||
{
|
||||
return !(func_node->arguments->children[0]->as<ASTFunction>()) &&
|
||||
(func_node->arguments->children.size() == 2 &&
|
||||
!(func_node->arguments->children[1]->as<ASTFunction>()));
|
||||
}
|
||||
|
||||
bool inappropriateNameInside(const ASTFunction * func_node, const char * inter_func_name)
|
||||
{
|
||||
return (func_node->arguments->children[0]->as<ASTFunction>() &&
|
||||
inter_func_name != func_node->arguments->children[0]->as<ASTFunction>()->name) ||
|
||||
(func_node->arguments->children.size() == 2 &&
|
||||
func_node->arguments->children[1]->as<ASTFunction>() &&
|
||||
inter_func_name != func_node->arguments->children[1]->as<ASTFunction>()->name);
|
||||
}
|
||||
|
||||
bool isInappropriate(const ASTPtr & node, const char * inter_func_name)
|
||||
{
|
||||
return !node->as<ASTFunction>() || inter_func_name != node->as<ASTFunction>()->name;
|
||||
}
|
||||
|
||||
ASTFunction * getInternalFunction(const ASTFunction * f_n)
|
||||
{
|
||||
const auto * function_args = f_n->arguments->as<ASTExpressionList>();
|
||||
if (!function_args || function_args->children.size() != 1)
|
||||
throw Exception("Wrong number of arguments for function" + f_n->name + "(" + toString(function_args->children.size()) + " instead of 1)",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
return f_n->arguments->children[0]->as<ASTFunction>();
|
||||
}
|
||||
|
||||
ASTFunction * treeFiller(ASTFunction * old_tree, const ASTs & nodes_array, size_t size, const char * name)
|
||||
{
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
old_tree->arguments->children = {};
|
||||
old_tree->arguments->children.push_back(nodes_array[i]);
|
||||
|
||||
old_tree->arguments->children.push_back(makeASTFunction(name));
|
||||
old_tree = old_tree->arguments->children[1]->as<ASTFunction>();
|
||||
}
|
||||
return old_tree;
|
||||
}
|
||||
|
||||
/// scalar values from the first level
|
||||
std::pair<ASTs, ASTs> tryGetConst(const char * name, const ASTs & arguments)
|
||||
{
|
||||
ASTs const_num;
|
||||
ASTs not_const;
|
||||
|
||||
for (const auto & arg : arguments)
|
||||
{
|
||||
if (const auto * literal = arg->as<ASTLiteral>())
|
||||
{
|
||||
if (isConstantField(literal->value))
|
||||
const_num.push_back(arg);
|
||||
else
|
||||
not_const.push_back(arg);
|
||||
}
|
||||
else
|
||||
not_const.push_back(arg);
|
||||
}
|
||||
|
||||
if ((name == plus || name == mul) && const_num.size() + not_const.size() != 2)
|
||||
{
|
||||
throw Exception("Wrong number of arguments for function 'plus' or 'multiply' (" + toString(const_num.size() + not_const.size()) + " instead of 2)",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
}
|
||||
|
||||
return {const_num, not_const};
|
||||
}
|
||||
|
||||
std::pair<ASTs, ASTs> findAllConsts(const ASTFunction * func_node, const char * inter_func_name)
|
||||
{
|
||||
if (!func_node->arguments)
|
||||
return {};
|
||||
|
||||
if (onlyConstsInside(func_node))
|
||||
return tryGetConst(func_node->name.c_str(), func_node->arguments->children);
|
||||
else if (inappropriateNameInside(func_node, inter_func_name))
|
||||
{
|
||||
bool first_child_is_const = func_node->arguments->children[0]->as<ASTLiteral>() &&
|
||||
isConstantField(func_node->arguments->children[0]->as<ASTLiteral>()->value);
|
||||
bool second_child_is_const = func_node->arguments->children.size() == 2 &&
|
||||
func_node->arguments->children[1]->as<ASTLiteral>() &&
|
||||
isConstantField(func_node->arguments->children[1]->as<ASTLiteral>()->value);
|
||||
if (first_child_is_const)
|
||||
return {{func_node->arguments->children[0]}, {func_node->arguments->children[1]}};
|
||||
else if (second_child_is_const)
|
||||
return {{func_node->arguments->children[1]}, {func_node->arguments->children[0]}};
|
||||
|
||||
if (isInappropriate(func_node->arguments->children[0], inter_func_name) && isInappropriate(func_node->arguments->children[1], inter_func_name))
|
||||
return {{}, {func_node->arguments->children[0], func_node->arguments->children[1]}};
|
||||
else if (isInappropriate(func_node->arguments->children[0], inter_func_name))
|
||||
{
|
||||
std::pair<ASTs, ASTs> ans = findAllConsts(func_node->arguments->children[1]->as<ASTFunction>(), inter_func_name);
|
||||
ans.second.push_back(func_node->arguments->children[0]);
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::pair<ASTs, ASTs> ans = findAllConsts(func_node->arguments->children[0]->as<ASTFunction>(), inter_func_name);
|
||||
ans.second.push_back(func_node->arguments->children[1]);
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::pair<ASTs, ASTs> fl = tryGetConst(func_node->name.c_str(), func_node->arguments->children);
|
||||
ASTs first_lvl_consts = fl.first;
|
||||
ASTs first_lvl_not_consts = fl.second;
|
||||
if (!first_lvl_not_consts[0]->as<ASTFunction>())
|
||||
return {first_lvl_consts, first_lvl_not_consts};
|
||||
|
||||
std::pair<ASTs, ASTs> ans = findAllConsts(first_lvl_not_consts[0]->as<ASTFunction>(), inter_func_name);
|
||||
ASTs all_consts = ans.first;
|
||||
ASTs all_not_consts = ans.second;
|
||||
|
||||
if (first_lvl_consts.size() == 1)
|
||||
{
|
||||
if (!first_lvl_not_consts[0]->as<ASTFunction>())
|
||||
all_not_consts.push_back(first_lvl_not_consts[0]);
|
||||
|
||||
all_consts.push_back(first_lvl_consts[0]);
|
||||
}
|
||||
else if (first_lvl_consts.empty())
|
||||
{
|
||||
/// if node is inappropriate to go into it, we just add this node to all_not_consts vector
|
||||
bool first_node_inappropriate_to_go_into = isInappropriate(first_lvl_not_consts[0], inter_func_name);
|
||||
bool second_node_inappropriate_to_go_into = first_lvl_not_consts.size() == 2 &&
|
||||
isInappropriate(first_lvl_not_consts[1], inter_func_name);
|
||||
if (first_node_inappropriate_to_go_into)
|
||||
all_not_consts.push_back(first_lvl_not_consts[0]);
|
||||
|
||||
if (second_node_inappropriate_to_go_into)
|
||||
all_not_consts.push_back(first_lvl_not_consts[1]);
|
||||
}
|
||||
else
|
||||
throw Exception("did not expect that", ErrorCodes::UNEXPECTED_AST_STRUCTURE);
|
||||
return {all_consts, all_not_consts};
|
||||
}
|
||||
|
||||
/// rebuilds tree, all scalar values now outside the main func
|
||||
void buildTree(ASTFunction * cur_node, const char * func_name, const char * intro_func, const std::pair<ASTs, ASTs> & tree_comp)
|
||||
{
|
||||
ASTs cons_val = tree_comp.first;
|
||||
ASTs non_cons = tree_comp.second;
|
||||
|
||||
cur_node->name = intro_func;
|
||||
cur_node = treeFiller(cur_node, cons_val, cons_val.size(), intro_func);
|
||||
cur_node->name = func_name;
|
||||
|
||||
if (non_cons.size() == 1)
|
||||
cur_node->arguments->children.push_back(non_cons[0]);
|
||||
else
|
||||
{
|
||||
cur_node->arguments->children.push_back(makeASTFunction(intro_func));
|
||||
cur_node = cur_node->arguments->children[0]->as<ASTFunction>();
|
||||
cur_node = treeFiller(cur_node, non_cons, non_cons.size() - 2, intro_func);
|
||||
cur_node->arguments->children = {non_cons[non_cons.size() - 2], non_cons[non_cons.size() - 1]};
|
||||
}
|
||||
}
|
||||
|
||||
void sumOptimize(ASTFunction * f_n)
|
||||
{
|
||||
const auto * function_args = f_n->arguments->as<ASTExpressionList>();
|
||||
|
||||
if (!function_args || function_args->children.size() != 1)
|
||||
throw Exception("Wrong number of arguments for function 'sum' (" + toString(function_args->children.size()) + " instead of 1)",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
ASTFunction * inter_node = getInternalFunction(f_n);
|
||||
if (inter_node && inter_node->name == mul)
|
||||
{
|
||||
std::pair<ASTs, ASTs> nodes = findAllConsts(f_n, mul);
|
||||
|
||||
if (nodes.first.empty())
|
||||
return;
|
||||
|
||||
buildTree(f_n, sum, mul, nodes);
|
||||
}
|
||||
}
|
||||
|
||||
void minOptimize(ASTFunction * f_n)
|
||||
{
|
||||
ASTFunction * inter_node = getInternalFunction(f_n);
|
||||
if (inter_node && inter_node->name == mul)
|
||||
{
|
||||
int sign = 1;
|
||||
std::pair<ASTs, ASTs> nodes = findAllConsts(f_n, mul);
|
||||
|
||||
if (nodes.first.empty())
|
||||
return;
|
||||
|
||||
for (const auto & arg : nodes.first)
|
||||
{
|
||||
Int128 num = applyVisitor(FieldVisitorConvertToNumber<Int128>(), arg->as<ASTLiteral>()->value);
|
||||
|
||||
/// if multiplication is negative, min function becomes max
|
||||
|
||||
if ((arg->as<ASTLiteral>()->value.getType() == Field::Types::Int64 ||
|
||||
arg->as<ASTLiteral>()->value.getType() == Field::Types::Int128) && num < static_cast<Int128>(0))
|
||||
sign *= -1;
|
||||
}
|
||||
|
||||
if (sign == -1)
|
||||
buildTree(f_n, max, mul, nodes);
|
||||
else
|
||||
buildTree(f_n, min, mul, nodes);
|
||||
}
|
||||
else if (inter_node && inter_node->name == plus)
|
||||
{
|
||||
std::pair<ASTs, ASTs> nodes = findAllConsts(f_n, plus);
|
||||
buildTree(f_n, min, plus, nodes);
|
||||
}
|
||||
}
|
||||
|
||||
void maxOptimize(ASTFunction * f_n)
|
||||
{
|
||||
ASTFunction * inter_node = getInternalFunction(f_n);
|
||||
if (inter_node && inter_node->name == mul)
|
||||
{
|
||||
int sign = 1;
|
||||
std::pair<ASTs, ASTs> nodes = findAllConsts(f_n, mul);
|
||||
|
||||
if (nodes.first.empty())
|
||||
return;
|
||||
|
||||
for (const auto & arg: nodes.first)
|
||||
{
|
||||
Int128 num = applyVisitor(FieldVisitorConvertToNumber<Int128>(), arg->as<ASTLiteral>()->value);
|
||||
|
||||
/// if multiplication is negative, max function becomes min
|
||||
if ((arg->as<ASTLiteral>()->value.getType() == Field::Types::Int64 ||
|
||||
arg->as<ASTLiteral>()->value.getType() == Field::Types::Int128) && num < static_cast<Int128>(0))
|
||||
sign *= -1;
|
||||
}
|
||||
|
||||
if (sign == -1)
|
||||
buildTree(f_n, min, mul, nodes);
|
||||
else
|
||||
buildTree(f_n, max, mul, nodes);
|
||||
}
|
||||
else if (inter_node && inter_node->name == plus)
|
||||
{
|
||||
std::pair<ASTs, ASTs> nodes = findAllConsts(f_n, plus);
|
||||
buildTree(f_n, max, plus, nodes);
|
||||
}
|
||||
}
|
||||
|
||||
/// optimize for min, max, sum is ready, ToDo: groupBitAnd, groupBitOr, groupBitXor
|
||||
void ArithmeticOperationsInAgrFuncMatcher::visit(ASTFunction * function_node, Data data)
|
||||
{
|
||||
data = {};
|
||||
if (function_node->name == "sum")
|
||||
sumOptimize(function_node);
|
||||
else if (function_node->name == "min")
|
||||
minOptimize(function_node);
|
||||
else if (function_node->name == "max")
|
||||
maxOptimize(function_node);
|
||||
}
|
||||
|
||||
void ArithmeticOperationsInAgrFuncMatcher::visit(const ASTPtr & current_ast, Data data)
|
||||
{
|
||||
if (!current_ast)
|
||||
return;
|
||||
|
||||
if (auto * function_node = current_ast->as<ASTFunction>())
|
||||
visit(function_node, data);
|
||||
}
|
||||
|
||||
bool ArithmeticOperationsInAgrFuncMatcher::needChildVisit(const ASTPtr & node, const ASTPtr & child)
|
||||
{
|
||||
if (!child)
|
||||
throw Exception("AST item should not have nullptr in children", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
if (node->as<ASTTableExpression>() || node->as<ASTArrayJoin>())
|
||||
return false; // NOLINT
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
22
src/Interpreters/ArithmeticOperationsInAgrFuncOptimize.h
Normal file
22
src/Interpreters/ArithmeticOperationsInAgrFuncOptimize.h
Normal file
@ -0,0 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
#include <Parsers/IAST.h>
|
||||
#include <Interpreters/InDepthNodeVisitor.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// It converts some arithmetic. Optimization due to the linearity property of some aggregate functions.
|
||||
/// Function collects const and not const nodes and rebuilds old tree.
|
||||
class ArithmeticOperationsInAgrFuncMatcher
|
||||
{
|
||||
public:
|
||||
struct Data {};
|
||||
|
||||
static void visit(const ASTPtr & ast, Data data);
|
||||
static void visit(ASTFunction *, Data data);
|
||||
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child);
|
||||
|
||||
};
|
||||
using ArithmeticOperationsInAgrFuncVisitor = InDepthNodeVisitor<ArithmeticOperationsInAgrFuncMatcher, true>;
|
||||
}
|
@ -22,6 +22,7 @@
|
||||
#include <Interpreters/ExpressionActions.h> /// getSmallestColumn()
|
||||
#include <Interpreters/getTableExpressions.h>
|
||||
#include <Interpreters/OptimizeIfChains.h>
|
||||
#include <Interpreters/ArithmeticOperationsInAgrFuncOptimize.h>
|
||||
|
||||
#include <Parsers/ASTExpressionList.h>
|
||||
#include <Parsers/ASTFunction.h>
|
||||
@ -429,6 +430,16 @@ void optimizeIf(ASTPtr & query, Aliases & aliases, bool if_chain_to_miltiif)
|
||||
OptimizeIfChainsVisitor().visit(query);
|
||||
}
|
||||
|
||||
void optimizeArithmeticOperationsInAgr(ASTPtr & query, bool optimize_arithmetic_operations_in_agr_func)
|
||||
{
|
||||
if (optimize_arithmetic_operations_in_agr_func)
|
||||
{
|
||||
/// Removing arithmetic operations from functions
|
||||
ArithmeticOperationsInAgrFuncVisitor::Data data = {};
|
||||
ArithmeticOperationsInAgrFuncVisitor(data).visit(query);
|
||||
}
|
||||
}
|
||||
|
||||
void getArrayJoinedColumns(ASTPtr & query, SyntaxAnalyzerResult & result, const ASTSelectQuery * select_query,
|
||||
const NamesAndTypesList & source_columns, const NameSet & source_columns_set)
|
||||
{
|
||||
@ -811,6 +822,8 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyzeSelect(
|
||||
{
|
||||
optimizeIf(query, result.aliases, settings.optimize_if_chain_to_miltiif);
|
||||
|
||||
optimizeArithmeticOperationsInAgr(query, settings.optimize_arithmetic_operations_in_agr_func);
|
||||
|
||||
/// Push the predicate expression down to the subqueries.
|
||||
result.rewrite_subqueries = PredicateExpressionsOptimizer(context, tables_with_column_names, settings).optimize(*select_query);
|
||||
|
||||
|
@ -19,6 +19,8 @@ SRCS(
|
||||
addMissingDefaults.cpp
|
||||
addTypeConversionToAST.cpp
|
||||
Aggregator.cpp
|
||||
ArithmeticOperationsInAgrFuncOptimize.cpp
|
||||
ArithmeticOperationsInAgrFuncOptimize.h
|
||||
ArrayJoinAction.cpp
|
||||
AsynchronousMetrics.cpp
|
||||
BloomFilter.cpp
|
||||
|
22
tests/performance/arithmetic_operations_in_aggr_func.xml
Normal file
22
tests/performance/arithmetic_operations_in_aggr_func.xml
Normal file
@ -0,0 +1,22 @@
|
||||
<test>
|
||||
|
||||
<stop_conditions>
|
||||
<all_of>
|
||||
<iterations>10</iterations>
|
||||
</all_of>
|
||||
</stop_conditions>
|
||||
|
||||
|
||||
<query>SELECT max(-1 * (((-2 * (number * -3)) * -4) * -5)) FROM numbers(120000000)</query>
|
||||
|
||||
<query>SELECT min(-1 * (((-2 * (number * -3)) * -4) * -5)) FROM numbers(120000000)</query>
|
||||
|
||||
<query>SELECT sum(-1 * (((-2 * (number * -3)) * -4) * -5)) FROM numbers(120000000)</query>
|
||||
|
||||
<query>SELECT min(-1 + (((-2 + (number + -3)) + -4) + -5)) FROM numbers(120000000)</query>
|
||||
|
||||
<query>SELECT max(-1 + (((-2 + (number + -3)) + -4) + -5)) FROM numbers(120000000)</query>
|
||||
|
||||
<query>SELECT max(((((number) * 10) * -2) * 3) * 2) + min(((((number) * 10) * -2) * 3) * 2) FROM numbers(120000000)</query>
|
||||
|
||||
</test>
|
@ -0,0 +1,6 @@
|
||||
-150000044999994
|
||||
6931471.112452272
|
||||
24580677
|
||||
-150000044999994
|
||||
6931471.112452272
|
||||
24580677
|
@ -0,0 +1,9 @@
|
||||
set optimize_arithmetic_operations_in_agr_func = 1;
|
||||
SELECT sum(number * -3) + min(2 * number * -3) - max(-1 * -2 * number * -3) FROM numbers(10000000);
|
||||
SELECT max(log(2) * number) FROM numbers(10000000);
|
||||
SELECT round(max(log(2) * 3 * sin(0.3) * number * 4)) FROM numbers(10000000);
|
||||
|
||||
set optimize_arithmetic_operations_in_agr_func = 0;
|
||||
SELECT sum(number * -3) + min(2 * number * -3) - max(-1 * -2 * number * -3) FROM numbers(10000000);
|
||||
SELECT max(log(2) * number) FROM numbers(10000000);
|
||||
SELECT round(max(log(2) * 3 * sin(0.3) * number * 4)) FROM numbers(10000000);
|
Loading…
Reference in New Issue
Block a user