diff --git a/src/Analyzer/Passes/OptimizeGroupByFunctionKeysPass.cpp b/src/Analyzer/Passes/OptimizeGroupByFunctionKeysPass.cpp index a61a2388d76..55ed931885f 100644 --- a/src/Analyzer/Passes/OptimizeGroupByFunctionKeysPass.cpp +++ b/src/Analyzer/Passes/OptimizeGroupByFunctionKeysPass.cpp @@ -42,22 +42,29 @@ public: } private: + struct NodeWithInfo + { + QueryTreeNodePtr node; + bool parents_are_only_deterministic; + }; + static bool canBeEliminated(QueryTreeNodePtr & node, const QueryTreeNodePtrWithHashSet & group_by_keys) { auto * function = node->as(); if (!function || function->getArguments().getNodes().empty()) return false; - QueryTreeNodes candidates; + std::vector candidates; auto & function_arguments = function->getArguments().getNodes(); + bool is_deterministic = function->getFunction()->isDeterministicInScopeOfQuery(); for (auto it = function_arguments.rbegin(); it != function_arguments.rend(); ++it) - candidates.push_back(*it); + candidates.push_back({ *it, is_deterministic }); // Using DFS we traverse function tree and try to find if it uses other keys as function arguments. // TODO: Also process CONSTANT here. We can simplify GROUP BY x, x + 1 to GROUP BY x. while (!candidates.empty()) { - auto candidate = candidates.back(); + auto [candidate, deterministic_context] = candidates.back(); candidates.pop_back(); bool found = group_by_keys.contains(candidate); @@ -73,8 +80,9 @@ private: if (!found) { + bool is_deterministic_function = deterministic_context && function->getFunction()->isDeterministicInScopeOfQuery(); for (auto it = arguments.rbegin(); it != arguments.rend(); ++it) - candidates.push_back(*it); + candidates.push_back({ *it, is_deterministic_function }); } break; } @@ -82,6 +90,10 @@ private: if (!found) return false; break; + case QueryTreeNodeType::CONSTANT: + if (!deterministic_context) + return false; + break; default: return false; } diff --git a/tests/queries/0_stateless/02481_analyzer_optimize_grouping_sets_keys.reference b/tests/queries/0_stateless/02481_analyzer_optimize_grouping_sets_keys.reference index fdab24700ac..d01bb5715ad 100644 --- a/tests/queries/0_stateless/02481_analyzer_optimize_grouping_sets_keys.reference +++ b/tests/queries/0_stateless/02481_analyzer_optimize_grouping_sets_keys.reference @@ -246,3 +246,40 @@ QUERY id: 0, group_by_type: grouping_sets ARGUMENTS LIST id: 53, nodes: 1 COLUMN id: 54, column_name: number, result_type: UInt64, source_id: 11 +QUERY id: 0, group_by_type: grouping_sets + PROJECTION COLUMNS + count() UInt64 + PROJECTION + LIST id: 1, nodes: 1 + FUNCTION id: 2, function_name: count, function_type: aggregate, result_type: UInt64 + JOIN TREE + TABLE_FUNCTION id: 3, table_function_name: numbers + ARGUMENTS + LIST id: 4, nodes: 1 + CONSTANT id: 5, constant_value: UInt64_1000, constant_value_type: UInt16 + GROUP BY + LIST id: 6, nodes: 3 + LIST id: 7, nodes: 1 + COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 3 + LIST id: 9, nodes: 2 + FUNCTION id: 10, function_name: modulo, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 11, nodes: 2 + COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 3 + CONSTANT id: 12, constant_value: UInt64_2, constant_value_type: UInt8 + FUNCTION id: 13, function_name: modulo, function_type: ordinary, result_type: UInt8 + ARGUMENTS + LIST id: 14, nodes: 2 + COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 3 + CONSTANT id: 15, constant_value: UInt64_3, constant_value_type: UInt8 + LIST id: 16, nodes: 2 + FUNCTION id: 17, function_name: divide, function_type: ordinary, result_type: Float64 + ARGUMENTS + LIST id: 18, nodes: 2 + COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 3 + CONSTANT id: 19, constant_value: UInt64_2, constant_value_type: UInt8 + FUNCTION id: 20, function_name: divide, function_type: ordinary, result_type: Float64 + ARGUMENTS + LIST id: 21, nodes: 2 + COLUMN id: 8, column_name: number, result_type: UInt64, source_id: 3 + CONSTANT id: 22, constant_value: UInt64_3, constant_value_type: UInt8 diff --git a/tests/queries/0_stateless/02481_analyzer_optimize_grouping_sets_keys.sql b/tests/queries/0_stateless/02481_analyzer_optimize_grouping_sets_keys.sql index 0c757cb111c..b51233f734c 100644 --- a/tests/queries/0_stateless/02481_analyzer_optimize_grouping_sets_keys.sql +++ b/tests/queries/0_stateless/02481_analyzer_optimize_grouping_sets_keys.sql @@ -15,3 +15,12 @@ SELECT avg(log(2) * number) AS k FROM numbers(10000000) GROUP BY GROUPING SETS (((number % 2) * (number % 3), number % 3), (number % 2)) HAVING avg(log(2) * number) > 3465735.3 ORDER BY k; + +EXPLAIN QUERY TREE run_passes=1 +SELECT count() FROM numbers(1000) +GROUP BY GROUPING SETS + ( + (number, number + 1, number +2), + (number % 2, number % 3), + (number / 2, number / 3) + );