diff --git a/src/Core/Settings.h b/src/Core/Settings.h index d1e36e9f3de..e1d64a783d3 100644 --- a/src/Core/Settings.h +++ b/src/Core/Settings.h @@ -411,6 +411,7 @@ struct Settings : public SettingsCollection M(SettingBool, enable_scalar_subquery_optimization, true, "If it is set to true, prevent scalar subqueries from (de)serializing large scalar values and possibly avoid running the same subquery more than once.", 0) \ M(SettingBool, optimize_trivial_count_query, true, "Process trivial 'SELECT count() FROM table' query from metadata.", 0) \ M(SettingUInt64, mutations_sync, 0, "Wait for synchronous execution of ALTER TABLE UPDATE/DELETE queries (mutations). 0 - execute asynchronously. 1 - wait current server. 2 - wait all replicas if they exist.", 0) \ + M(SettingBool, optimize_arithmetic_operations_in_agr_func, true, "Removing arithmetic operations from aggregation functions", 0) \ M(SettingBool, optimize_if_chain_to_miltiif, false, "Replace if(cond1, then1, if(cond2, ...)) chains to multiIf. Currently it's not beneficial for numeric types.", 0) \ M(SettingBool, allow_experimental_alter_materialized_view_structure, false, "Allow atomic alter on Materialized views. Work in progress.", 0) \ M(SettingBool, enable_early_constant_folding, true, "Enable query optimization where we analyze function and subqueries results and rewrite query if there're constants there", 0) \ diff --git a/src/Interpreters/ArithmeticOperationsInAgrFuncOptimize.cpp b/src/Interpreters/ArithmeticOperationsInAgrFuncOptimize.cpp new file mode 100644 index 00000000000..665c2febd9d --- /dev/null +++ b/src/Interpreters/ArithmeticOperationsInAgrFuncOptimize.cpp @@ -0,0 +1,313 @@ +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int UNEXPECTED_AST_STRUCTURE; +} +namespace +{ + constexpr const char * min = "min"; + constexpr const char * max = "max"; + constexpr const char * mul = "multiply"; + constexpr const char * plus = "plus"; + constexpr const char * sum = "sum"; +} + +bool isConstantField(const Field & field) +{ + return field.getType() == Field::Types::Int64 || + field.getType() == Field::Types::UInt64 || + field.getType() == Field::Types::Int128 || + field.getType() == Field::Types::UInt128; +} + +bool onlyConstsInside(const ASTFunction * func_node) +{ + return !(func_node->arguments->children[0]->as()) && + (func_node->arguments->children.size() == 2 && + !(func_node->arguments->children[1]->as())); +} + +bool inappropriateNameInside(const ASTFunction * func_node, const char * inter_func_name) +{ + return (func_node->arguments->children[0]->as() && + inter_func_name != func_node->arguments->children[0]->as()->name) || + (func_node->arguments->children.size() == 2 && + func_node->arguments->children[1]->as() && + inter_func_name != func_node->arguments->children[1]->as()->name); +} + +bool isInappropriate(const ASTPtr & node, const char * inter_func_name) +{ + return !node->as() || inter_func_name != node->as()->name; +} + +ASTFunction * getInternalFunction(const ASTFunction * f_n) +{ + const auto * function_args = f_n->arguments->as(); + if (!function_args || function_args->children.size() != 1) + throw Exception("Wrong number of arguments for function" + f_n->name + "(" + toString(function_args->children.size()) + " instead of 1)", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + return f_n->arguments->children[0]->as(); +} + +ASTFunction * treeFiller(ASTFunction * old_tree, const ASTs & nodes_array, size_t size, const char * name) +{ + for (size_t i = 0; i < size; ++i) + { + old_tree->arguments->children = {}; + old_tree->arguments->children.push_back(nodes_array[i]); + + old_tree->arguments->children.push_back(makeASTFunction(name)); + old_tree = old_tree->arguments->children[1]->as(); + } + return old_tree; +} + +/// scalar values from the first level +std::pair tryGetConst(const char * name, const ASTs & arguments) +{ + ASTs const_num; + ASTs not_const; + + for (const auto & arg : arguments) + { + if (const auto * literal = arg->as()) + { + if (isConstantField(literal->value)) + const_num.push_back(arg); + else + not_const.push_back(arg); + } + else + not_const.push_back(arg); + } + + if ((name == plus || name == mul) && const_num.size() + not_const.size() != 2) + { + throw Exception("Wrong number of arguments for function 'plus' or 'multiply' (" + toString(const_num.size() + not_const.size()) + " instead of 2)", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + } + + return {const_num, not_const}; +} + +std::pair findAllConsts(const ASTFunction * func_node, const char * inter_func_name) +{ + if (!func_node->arguments) + return {}; + + if (onlyConstsInside(func_node)) + return tryGetConst(func_node->name.c_str(), func_node->arguments->children); + else if (inappropriateNameInside(func_node, inter_func_name)) + { + bool first_child_is_const = func_node->arguments->children[0]->as() && + isConstantField(func_node->arguments->children[0]->as()->value); + bool second_child_is_const = func_node->arguments->children.size() == 2 && + func_node->arguments->children[1]->as() && + isConstantField(func_node->arguments->children[1]->as()->value); + if (first_child_is_const) + return {{func_node->arguments->children[0]}, {func_node->arguments->children[1]}}; + else if (second_child_is_const) + return {{func_node->arguments->children[1]}, {func_node->arguments->children[0]}}; + + if (isInappropriate(func_node->arguments->children[0], inter_func_name) && isInappropriate(func_node->arguments->children[1], inter_func_name)) + return {{}, {func_node->arguments->children[0], func_node->arguments->children[1]}}; + else if (isInappropriate(func_node->arguments->children[0], inter_func_name)) + { + std::pair ans = findAllConsts(func_node->arguments->children[1]->as(), inter_func_name); + ans.second.push_back(func_node->arguments->children[0]); + return ans; + } + + std::pair ans = findAllConsts(func_node->arguments->children[0]->as(), inter_func_name); + ans.second.push_back(func_node->arguments->children[1]); + return ans; + } + + std::pair fl = tryGetConst(func_node->name.c_str(), func_node->arguments->children); + ASTs first_lvl_consts = fl.first; + ASTs first_lvl_not_consts = fl.second; + if (!first_lvl_not_consts[0]->as()) + return {first_lvl_consts, first_lvl_not_consts}; + + std::pair ans = findAllConsts(first_lvl_not_consts[0]->as(), inter_func_name); + ASTs all_consts = ans.first; + ASTs all_not_consts = ans.second; + + if (first_lvl_consts.size() == 1) + { + if (!first_lvl_not_consts[0]->as()) + all_not_consts.push_back(first_lvl_not_consts[0]); + + all_consts.push_back(first_lvl_consts[0]); + } + else if (first_lvl_consts.empty()) + { + /// if node is inappropriate to go into it, we just add this node to all_not_consts vector + bool first_node_inappropriate_to_go_into = isInappropriate(first_lvl_not_consts[0], inter_func_name); + bool second_node_inappropriate_to_go_into = first_lvl_not_consts.size() == 2 && + isInappropriate(first_lvl_not_consts[1], inter_func_name); + if (first_node_inappropriate_to_go_into) + all_not_consts.push_back(first_lvl_not_consts[0]); + + if (second_node_inappropriate_to_go_into) + all_not_consts.push_back(first_lvl_not_consts[1]); + } + else + throw Exception("did not expect that", ErrorCodes::UNEXPECTED_AST_STRUCTURE); + return {all_consts, all_not_consts}; +} + +/// rebuilds tree, all scalar values now outside the main func +void buildTree(ASTFunction * cur_node, const char * func_name, const char * intro_func, const std::pair & tree_comp) +{ + ASTs cons_val = tree_comp.first; + ASTs non_cons = tree_comp.second; + + cur_node->name = intro_func; + cur_node = treeFiller(cur_node, cons_val, cons_val.size(), intro_func); + cur_node->name = func_name; + + if (non_cons.size() == 1) + cur_node->arguments->children.push_back(non_cons[0]); + else + { + cur_node->arguments->children.push_back(makeASTFunction(intro_func)); + cur_node = cur_node->arguments->children[0]->as(); + cur_node = treeFiller(cur_node, non_cons, non_cons.size() - 2, intro_func); + cur_node->arguments->children = {non_cons[non_cons.size() - 2], non_cons[non_cons.size() - 1]}; + } +} + +void sumOptimize(ASTFunction * f_n) +{ + const auto * function_args = f_n->arguments->as(); + + if (!function_args || function_args->children.size() != 1) + throw Exception("Wrong number of arguments for function 'sum' (" + toString(function_args->children.size()) + " instead of 1)", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + ASTFunction * inter_node = getInternalFunction(f_n); + if (inter_node && inter_node->name == mul) + { + std::pair nodes = findAllConsts(f_n, mul); + + if (nodes.first.empty()) + return; + + buildTree(f_n, sum, mul, nodes); + } +} + +void minOptimize(ASTFunction * f_n) +{ + ASTFunction * inter_node = getInternalFunction(f_n); + if (inter_node && inter_node->name == mul) + { + int sign = 1; + std::pair nodes = findAllConsts(f_n, mul); + + if (nodes.first.empty()) + return; + + for (const auto & arg : nodes.first) + { + Int128 num = applyVisitor(FieldVisitorConvertToNumber(), arg->as()->value); + + /// if multiplication is negative, min function becomes max + + if ((arg->as()->value.getType() == Field::Types::Int64 || + arg->as()->value.getType() == Field::Types::Int128) && num < static_cast(0)) + sign *= -1; + } + + if (sign == -1) + buildTree(f_n, max, mul, nodes); + else + buildTree(f_n, min, mul, nodes); + } + else if (inter_node && inter_node->name == plus) + { + std::pair nodes = findAllConsts(f_n, plus); + buildTree(f_n, min, plus, nodes); + } +} + +void maxOptimize(ASTFunction * f_n) +{ + ASTFunction * inter_node = getInternalFunction(f_n); + if (inter_node && inter_node->name == mul) + { + int sign = 1; + std::pair nodes = findAllConsts(f_n, mul); + + if (nodes.first.empty()) + return; + + for (const auto & arg: nodes.first) + { + Int128 num = applyVisitor(FieldVisitorConvertToNumber(), arg->as()->value); + + /// if multiplication is negative, max function becomes min + if ((arg->as()->value.getType() == Field::Types::Int64 || + arg->as()->value.getType() == Field::Types::Int128) && num < static_cast(0)) + sign *= -1; + } + + if (sign == -1) + buildTree(f_n, min, mul, nodes); + else + buildTree(f_n, max, mul, nodes); + } + else if (inter_node && inter_node->name == plus) + { + std::pair nodes = findAllConsts(f_n, plus); + buildTree(f_n, max, plus, nodes); + } +} + +/// optimize for min, max, sum is ready, ToDo: groupBitAnd, groupBitOr, groupBitXor +void ArithmeticOperationsInAgrFuncMatcher::visit(ASTFunction * function_node, Data data) +{ + data = {}; + if (function_node->name == "sum") + sumOptimize(function_node); + else if (function_node->name == "min") + minOptimize(function_node); + else if (function_node->name == "max") + maxOptimize(function_node); +} + +void ArithmeticOperationsInAgrFuncMatcher::visit(const ASTPtr & current_ast, Data data) +{ + if (!current_ast) + return; + + if (auto * function_node = current_ast->as()) + visit(function_node, data); +} + +bool ArithmeticOperationsInAgrFuncMatcher::needChildVisit(const ASTPtr & node, const ASTPtr & child) +{ + if (!child) + throw Exception("AST item should not have nullptr in children", ErrorCodes::LOGICAL_ERROR); + + if (node->as() || node->as()) + return false; // NOLINT + + return true; +} + +} diff --git a/src/Interpreters/ArithmeticOperationsInAgrFuncOptimize.h b/src/Interpreters/ArithmeticOperationsInAgrFuncOptimize.h new file mode 100644 index 00000000000..46af1e272db --- /dev/null +++ b/src/Interpreters/ArithmeticOperationsInAgrFuncOptimize.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +namespace DB +{ + +/// It converts some arithmetic. Optimization due to the linearity property of some aggregate functions. +/// Function collects const and not const nodes and rebuilds old tree. +class ArithmeticOperationsInAgrFuncMatcher +{ +public: + struct Data {}; + + static void visit(const ASTPtr & ast, Data data); + static void visit(ASTFunction *, Data data); + static bool needChildVisit(const ASTPtr & node, const ASTPtr & child); + +}; +using ArithmeticOperationsInAgrFuncVisitor = InDepthNodeVisitor; +} diff --git a/src/Interpreters/SyntaxAnalyzer.cpp b/src/Interpreters/SyntaxAnalyzer.cpp index b5f86b87fdc..831379090ad 100644 --- a/src/Interpreters/SyntaxAnalyzer.cpp +++ b/src/Interpreters/SyntaxAnalyzer.cpp @@ -22,6 +22,7 @@ #include /// getSmallestColumn() #include #include +#include #include #include @@ -429,6 +430,16 @@ void optimizeIf(ASTPtr & query, Aliases & aliases, bool if_chain_to_miltiif) OptimizeIfChainsVisitor().visit(query); } +void optimizeArithmeticOperationsInAgr(ASTPtr & query, bool optimize_arithmetic_operations_in_agr_func) +{ + if (optimize_arithmetic_operations_in_agr_func) + { + /// Removing arithmetic operations from functions + ArithmeticOperationsInAgrFuncVisitor::Data data = {}; + ArithmeticOperationsInAgrFuncVisitor(data).visit(query); + } +} + void getArrayJoinedColumns(ASTPtr & query, SyntaxAnalyzerResult & result, const ASTSelectQuery * select_query, const NamesAndTypesList & source_columns, const NameSet & source_columns_set) { @@ -811,6 +822,8 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyzeSelect( { optimizeIf(query, result.aliases, settings.optimize_if_chain_to_miltiif); + optimizeArithmeticOperationsInAgr(query, settings.optimize_arithmetic_operations_in_agr_func); + /// Push the predicate expression down to the subqueries. result.rewrite_subqueries = PredicateExpressionsOptimizer(context, tables_with_column_names, settings).optimize(*select_query); diff --git a/src/Interpreters/ya.make b/src/Interpreters/ya.make index b210a1c5b8c..178c3ee3125 100644 --- a/src/Interpreters/ya.make +++ b/src/Interpreters/ya.make @@ -19,6 +19,8 @@ SRCS( addMissingDefaults.cpp addTypeConversionToAST.cpp Aggregator.cpp + ArithmeticOperationsInAgrFuncOptimize.cpp + ArithmeticOperationsInAgrFuncOptimize.h ArrayJoinAction.cpp AsynchronousMetrics.cpp BloomFilter.cpp diff --git a/tests/performance/arithmetic_operations_in_aggr_func.xml b/tests/performance/arithmetic_operations_in_aggr_func.xml new file mode 100644 index 00000000000..28f13823731 --- /dev/null +++ b/tests/performance/arithmetic_operations_in_aggr_func.xml @@ -0,0 +1,22 @@ + + + + + 10 + + + + + SELECT max(-1 * (((-2 * (number * -3)) * -4) * -5)) FROM numbers(120000000) + + SELECT min(-1 * (((-2 * (number * -3)) * -4) * -5)) FROM numbers(120000000) + + SELECT sum(-1 * (((-2 * (number * -3)) * -4) * -5)) FROM numbers(120000000) + + SELECT min(-1 + (((-2 + (number + -3)) + -4) + -5)) FROM numbers(120000000) + + SELECT max(-1 + (((-2 + (number + -3)) + -4) + -5)) FROM numbers(120000000) + + SELECT max(((((number) * 10) * -2) * 3) * 2) + min(((((number) * 10) * -2) * 3) * 2) FROM numbers(120000000) + + diff --git a/tests/queries/0_stateless/01271_optimize_arithmetic_operations_in_aggr_func.reference b/tests/queries/0_stateless/01271_optimize_arithmetic_operations_in_aggr_func.reference new file mode 100644 index 00000000000..2fe897e3819 --- /dev/null +++ b/tests/queries/0_stateless/01271_optimize_arithmetic_operations_in_aggr_func.reference @@ -0,0 +1,6 @@ +-150000044999994 +6931471.112452272 +24580677 +-150000044999994 +6931471.112452272 +24580677 diff --git a/tests/queries/0_stateless/01271_optimize_arithmetic_operations_in_aggr_func.sql b/tests/queries/0_stateless/01271_optimize_arithmetic_operations_in_aggr_func.sql new file mode 100644 index 00000000000..3550ed64e8c --- /dev/null +++ b/tests/queries/0_stateless/01271_optimize_arithmetic_operations_in_aggr_func.sql @@ -0,0 +1,9 @@ +set optimize_arithmetic_operations_in_agr_func = 1; +SELECT sum(number * -3) + min(2 * number * -3) - max(-1 * -2 * number * -3) FROM numbers(10000000); +SELECT max(log(2) * number) FROM numbers(10000000); +SELECT round(max(log(2) * 3 * sin(0.3) * number * 4)) FROM numbers(10000000); + +set optimize_arithmetic_operations_in_agr_func = 0; +SELECT sum(number * -3) + min(2 * number * -3) - max(-1 * -2 * number * -3) FROM numbers(10000000); +SELECT max(log(2) * number) FROM numbers(10000000); +SELECT round(max(log(2) * 3 * sin(0.3) * number * 4)) FROM numbers(10000000);