From e9ea1307e085f425b041c049a66edcf5fd76e609 Mon Sep 17 00:00:00 2001 From: Dmitry Novik Date: Tue, 23 Jan 2024 13:03:09 +0000 Subject: [PATCH] Make optimization more general --- .../OptimizeGroupByInjectiveFunctionsPass.cpp | 54 +++---- src/Analyzer/Passes/QueryAnalysisPass.cpp | 7 +- src/Core/Settings.h | 1 + ...er_eliminate_injective_functions.reference | 142 ++++++++++++++++++ ...analyzer_eliminate_injective_functions.sql | 31 ++++ 5 files changed, 201 insertions(+), 34 deletions(-) create mode 100644 tests/queries/0_stateless/02969_analyzer_eliminate_injective_functions.reference create mode 100644 tests/queries/0_stateless/02969_analyzer_eliminate_injective_functions.sql diff --git a/src/Analyzer/Passes/OptimizeGroupByInjectiveFunctionsPass.cpp b/src/Analyzer/Passes/OptimizeGroupByInjectiveFunctionsPass.cpp index 6dd36733edc..864752cdbeb 100644 --- a/src/Analyzer/Passes/OptimizeGroupByInjectiveFunctionsPass.cpp +++ b/src/Analyzer/Passes/OptimizeGroupByInjectiveFunctionsPass.cpp @@ -40,6 +40,9 @@ public: void enterImpl(QueryTreeNodePtr & node) { + if (!getSettings().optimize_injective_functions_in_group_by) + return; + auto * query = node->as(); if (!query) return; @@ -72,55 +75,42 @@ private: new_group_by_keys.reserve(grouping_set.size()); for (auto & group_by_elem : grouping_set) { - if (auto const * function_node = group_by_elem->as()) + std::queue nodes_to_process; + nodes_to_process.push(group_by_elem); + + while (!nodes_to_process.empty()) { - bool can_be_eliminated = false; - if (possibly_injective_function_names.contains(function_node->getFunctionName())) + auto node_to_process = nodes_to_process.front(); + nodes_to_process.pop(); + + auto const * function_node = node_to_process->as(); + if (!function_node) { - can_be_eliminated = canBeEliminated(function_node, context); - } - else - { - auto function = function_node->getFunctionOrThrow(); - can_be_eliminated = function->isInjective(function_node->getArgumentColumns()); + // Constant aggregation keys 
are removed in PlannerExpressionAnalysis.cpp + new_group_by_keys.push_back(node_to_process); + continue; } + // Aggregate functions are not allowed in GROUP BY clause + auto function = function_node->getFunctionOrThrow(); + bool can_be_eliminated = function->isInjective(function_node->getArgumentColumns()); + if (can_be_eliminated) { for (auto const & argument : function_node->getArguments()) { + // We can skip constants here because aggregation key is already not a constant. if (argument->getNodeType() != QueryTreeNodeType::CONSTANT) - new_group_by_keys.push_back(argument); + nodes_to_process.push(argument); } } else - new_group_by_keys.push_back(group_by_elem); + new_group_by_keys.push_back(node_to_process); } - else - new_group_by_keys.push_back(group_by_elem); } grouping_set = std::move(new_group_by_keys); } - - bool canBeEliminated(const FunctionNode * function_node, const ContextPtr & context) - { - const auto & function_arguments = function_node->getArguments().getNodes(); - auto const * dict_name_arg = function_arguments[0]->as(); - if (!dict_name_arg || !isString(dict_name_arg->getResultType())) - return false; - auto dict_name = dict_name_arg->getValue().safeGet(); - - const auto & dict_ptr = context->getExternalDictionariesLoader().getDictionary(dict_name, context); - - auto const * attr_name_arg = function_arguments[1]->as(); - if (!attr_name_arg || !isString(attr_name_arg->getResultType())) - return false; - auto attr_name = attr_name_arg->getValue().safeGet(); - - return dict_ptr->isInjective(attr_name); - } - }; } diff --git a/src/Analyzer/Passes/QueryAnalysisPass.cpp b/src/Analyzer/Passes/QueryAnalysisPass.cpp index c71eb9e3aca..840b4dbb96e 100644 --- a/src/Analyzer/Passes/QueryAnalysisPass.cpp +++ b/src/Analyzer/Passes/QueryAnalysisPass.cpp @@ -2315,11 +2315,15 @@ std::pair QueryAnalyzer::recursivelyCollectMaxOrdinaryExpressions( */ void QueryAnalyzer::expandGroupByAll(QueryNode & query_tree_node_typed) { + if 
(!query_tree_node_typed.isGroupByAll()) + return; + auto & group_by_nodes = query_tree_node_typed.getGroupBy().getNodes(); auto & projection_list = query_tree_node_typed.getProjection(); for (auto & node : projection_list.getNodes()) recursivelyCollectMaxOrdinaryExpressions(node, group_by_nodes); + query_tree_node_typed.setIsGroupByAll(false); } void QueryAnalyzer::expandOrderByAll(QueryNode & query_tree_node_typed) @@ -7380,8 +7384,7 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier node->removeAlias(); } - if (query_node_typed.isGroupByAll()) - expandGroupByAll(query_node_typed); + expandGroupByAll(query_node_typed); validateFilters(query_node); validateAggregates(query_node, { .group_by_use_nulls = scope.group_by_use_nulls }); diff --git a/src/Core/Settings.h b/src/Core/Settings.h index 292e945a29c..def1a1a80d5 100644 --- a/src/Core/Settings.h +++ b/src/Core/Settings.h @@ -699,6 +699,7 @@ class IColumn; M(SetOperationMode, intersect_default_mode, SetOperationMode::ALL, "Set default mode in INTERSECT query. Possible values: empty string, 'ALL', 'DISTINCT'. If empty, query without mode will throw exception.", 0) \ M(SetOperationMode, except_default_mode, SetOperationMode::ALL, "Set default mode in EXCEPT query. Possible values: empty string, 'ALL', 'DISTINCT'. 
If empty, query without mode will throw exception.", 0) \ M(Bool, optimize_aggregators_of_group_by_keys, true, "Eliminates min/max/any/anyLast aggregators of GROUP BY keys in SELECT section", 0) \ + M(Bool, optimize_injective_functions_in_group_by, true, "Replaces injective functions with their arguments in GROUP BY section", 0) \ M(Bool, optimize_group_by_function_keys, true, "Eliminates functions of other keys in GROUP BY section", 0) \ M(Bool, optimize_group_by_constant_keys, true, "Optimize GROUP BY when all keys in block are constant", 0) \ M(Bool, legacy_column_name_of_tuple_literal, false, "List all names of element of large tuple literals in their column names instead of hash. This settings exists only for compatibility reasons. It makes sense to set to 'true', while doing rolling update of cluster from version lower than 21.7 to higher.", 0) \ diff --git a/tests/queries/0_stateless/02969_analyzer_eliminate_injective_functions.reference b/tests/queries/0_stateless/02969_analyzer_eliminate_injective_functions.reference new file mode 100644 index 00000000000..72d83e5cf6a --- /dev/null +++ b/tests/queries/0_stateless/02969_analyzer_eliminate_injective_functions.reference @@ -0,0 +1,142 @@ +QUERY id: 0 + PROJECTION COLUMNS + val String + count() UInt64 + PROJECTION + LIST id: 1, nodes: 2 + FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String + ARGUMENTS + LIST id: 3, nodes: 1 + FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String + ARGUMENTS + LIST id: 5, nodes: 1 + FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 7, nodes: 2 + COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9 + CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8 + FUNCTION id: 11, function_name: count, function_type: aggregate, result_type: UInt64 + JOIN TREE + TABLE_FUNCTION id: 9, alias: __table1, table_function_name: numbers + ARGUMENTS 
+ LIST id: 12, nodes: 1 + CONSTANT id: 13, constant_value: UInt64_2, constant_value_type: UInt8 + GROUP BY + LIST id: 14, nodes: 1 + FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 7, nodes: 2 + COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9 + CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8 + ORDER BY + LIST id: 15, nodes: 1 + SORT id: 16, sort_direction: ASCENDING, with_fill: 0 + EXPRESSION + FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String + ARGUMENTS + LIST id: 3, nodes: 1 + FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String + ARGUMENTS + LIST id: 5, nodes: 1 + FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 7, nodes: 2 + COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9 + CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8 +1 1 +2 1 +QUERY id: 0 + PROJECTION COLUMNS + val String + count() UInt64 + PROJECTION + LIST id: 1, nodes: 2 + FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String + ARGUMENTS + LIST id: 3, nodes: 1 + FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String + ARGUMENTS + LIST id: 5, nodes: 1 + FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 7, nodes: 2 + COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9 + CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8 + FUNCTION id: 11, function_name: count, function_type: aggregate, result_type: UInt64 + JOIN TREE + TABLE_FUNCTION id: 9, alias: __table1, table_function_name: numbers + ARGUMENTS + LIST id: 12, nodes: 1 + CONSTANT id: 13, constant_value: UInt64_2, constant_value_type: UInt8 + GROUP BY + LIST id: 14, nodes: 1 + FUNCTION id: 6, function_name: plus, function_type: ordinary, 
result_type: UInt64 + ARGUMENTS + LIST id: 7, nodes: 2 + COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9 + CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8 + ORDER BY + LIST id: 15, nodes: 1 + SORT id: 16, sort_direction: ASCENDING, with_fill: 0 + EXPRESSION + FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String + ARGUMENTS + LIST id: 3, nodes: 1 + FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String + ARGUMENTS + LIST id: 5, nodes: 1 + FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 7, nodes: 2 + COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9 + CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8 +CHECK WITH TOTALS +QUERY id: 0, is_group_by_with_totals: 1 + PROJECTION COLUMNS + val String + count() UInt64 + PROJECTION + LIST id: 1, nodes: 2 + FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String + ARGUMENTS + LIST id: 3, nodes: 1 + FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String + ARGUMENTS + LIST id: 5, nodes: 1 + FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 7, nodes: 2 + COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9 + CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8 + FUNCTION id: 11, function_name: count, function_type: aggregate, result_type: UInt64 + JOIN TREE + TABLE_FUNCTION id: 9, alias: __table1, table_function_name: numbers + ARGUMENTS + LIST id: 12, nodes: 1 + CONSTANT id: 13, constant_value: UInt64_2, constant_value_type: UInt8 + GROUP BY + LIST id: 14, nodes: 1 + FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 7, nodes: 2 + COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9 + CONSTANT id: 10, 
constant_value: UInt64_1, constant_value_type: UInt8 + ORDER BY + LIST id: 15, nodes: 1 + SORT id: 16, sort_direction: ASCENDING, with_fill: 0 + EXPRESSION + FUNCTION id: 2, function_name: toString, function_type: ordinary, result_type: String + ARGUMENTS + LIST id: 3, nodes: 1 + FUNCTION id: 4, function_name: toString, function_type: ordinary, result_type: String + ARGUMENTS + LIST id: 5, nodes: 1 + FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: UInt64 + ARGUMENTS + LIST id: 7, nodes: 2 + COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 9 + CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8 +1 1 +2 1 + +0 2 diff --git a/tests/queries/0_stateless/02969_analyzer_eliminate_injective_functions.sql b/tests/queries/0_stateless/02969_analyzer_eliminate_injective_functions.sql new file mode 100644 index 00000000000..15f2550a63e --- /dev/null +++ b/tests/queries/0_stateless/02969_analyzer_eliminate_injective_functions.sql @@ -0,0 +1,31 @@ +set allow_experimental_analyzer = 1; + +EXPLAIN QUERY TREE +SELECT toString(toString(number + 1)) as val, count() +FROM numbers(2) +GROUP BY val +ORDER BY val; + +SELECT toString(toString(number + 1)) as val, count() +FROM numbers(2) +GROUP BY ALL +ORDER BY val; + +EXPLAIN QUERY TREE +SELECT toString(toString(number + 1)) as val, count() +FROM numbers(2) +GROUP BY ALL +ORDER BY val; + +SELECT 'CHECK WITH TOTALS'; + +EXPLAIN QUERY TREE +SELECT toString(toString(number + 1)) as val, count() +FROM numbers(2) +GROUP BY val WITH TOTALS +ORDER BY val; + +SELECT toString(toString(number + 1)) as val, count() +FROM numbers(2) +GROUP BY val WITH TOTALS +ORDER BY val;