diff --git a/src/Analyzer/Passes/UniqInjectiveFunctionsEliminationPass.cpp b/src/Analyzer/Passes/UniqInjectiveFunctionsEliminationPass.cpp index 7d920721dcb..91186db0e0c 100644 --- a/src/Analyzer/Passes/UniqInjectiveFunctionsEliminationPass.cpp +++ b/src/Analyzer/Passes/UniqInjectiveFunctionsEliminationPass.cpp @@ -1,4 +1,3 @@ -#include #include #include @@ -8,8 +7,6 @@ #include #include -#include -#include namespace DB @@ -46,15 +43,14 @@ public: bool replaced_argument = false; auto replaced_uniq_function_arguments_nodes = function_node->getArguments().getNodes(); - DataTypes new_argument_types; - new_argument_types.reserve(replaced_uniq_function_arguments_nodes.size()); - + /// Replace injective function with its single argument auto remove_injective_function = [&replaced_argument](QueryTreeNodePtr & arg) -> bool { auto * arg_typed = arg->as(); if (!arg_typed || !arg_typed->isOrdinaryFunction()) return false; + /// Do not apply optimization if injective function contains multiple arguments auto & arg_arguments_nodes = arg_typed->getArguments().getNodes(); if (arg_arguments_nodes.size() != 1) return false; @@ -71,27 +67,32 @@ public: { while (remove_injective_function(uniq_function_argument_node)) ; - new_argument_types.emplace_back(uniq_function_argument_node->getResultType()); } if (!replaced_argument) return; + DataTypes replaced_argument_types; + replaced_argument_types.reserve(replaced_uniq_function_arguments_nodes.size()); + + for (const auto & function_node_argument : replaced_uniq_function_arguments_nodes) + replaced_argument_types.emplace_back(function_node_argument->getResultType()); + auto current_aggregate_function = function_node->getAggregateFunction(); AggregateFunctionProperties properties; - auto new_aggregate_function = AggregateFunctionFactory::instance().get( + auto replaced_aggregate_function = AggregateFunctionFactory::instance().get( function_node->getFunctionName(), NullsAction::EMPTY, - new_argument_types, + replaced_argument_types, current_aggregate_function->getParameters(), properties); /// uniqCombined returns nullable with nullable arguments so the result type might change which breaks the pass - if (!new_aggregate_function->getResultType()->equals(*current_aggregate_function->getResultType())) + if (!replaced_aggregate_function->getResultType()->equals(*current_aggregate_function->getResultType())) return; function_node->getArguments().getNodes() = std::move(replaced_uniq_function_arguments_nodes); - function_node->resolveAsAggregateFunction(std::move(new_aggregate_function)); + function_node->resolveAsAggregateFunction(std::move(replaced_aggregate_function)); } }; diff --git a/tests/queries/0_stateless/02493_analyzer_uniq_injective_functions_elimination.reference b/tests/queries/0_stateless/02493_analyzer_uniq_injective_functions_elimination.reference index 74e5da04993..eb036e1b0c1 100644 --- a/tests/queries/0_stateless/02493_analyzer_uniq_injective_functions_elimination.reference +++ b/tests/queries/0_stateless/02493_analyzer_uniq_injective_functions_elimination.reference @@ -17,7 +17,6 @@ QUERY id: 0 ARGUMENTS LIST id: 9, nodes: 1 CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8 - SETTINGS allow_experimental_analyzer=1 optimize_injective_functions_inside_uniq=1 1 QUERY id: 0 PROJECTION COLUMNS @@ -33,5 +32,4 @@ QUERY id: 0 ARGUMENTS LIST id: 6, nodes: 1 CONSTANT id: 7, constant_value: UInt64_10, constant_value_type: UInt8 - SETTINGS allow_experimental_analyzer=1 optimize_injective_functions_inside_uniq=1 10 diff --git a/tests/queries/0_stateless/02493_analyzer_uniq_injective_functions_elimination.sql b/tests/queries/0_stateless/02493_analyzer_uniq_injective_functions_elimination.sql index a389e025527..48fb0198991 100644 --- a/tests/queries/0_stateless/02493_analyzer_uniq_injective_functions_elimination.sql +++ b/tests/queries/0_stateless/02493_analyzer_uniq_injective_functions_elimination.sql @@ -1,7 +1,9 @@ -EXPLAIN QUERY TREE SELECT uniqCombined(tuple('')) FROM numbers(1) SETTINGS allow_experimental_analyzer = 1, optimize_injective_functions_inside_uniq = 1; +SET allow_experimental_analyzer = 1, optimize_injective_functions_inside_uniq = 1; -SELECT uniqCombined(tuple('')) FROM numbers(1) SETTINGS allow_experimental_analyzer = 1, optimize_injective_functions_inside_uniq = 1; +EXPLAIN QUERY TREE SELECT uniqCombined(tuple('')) FROM numbers(1); -EXPLAIN QUERY TREE SELECT uniqCombined(tuple(materialize(tuple(number)))) FROM numbers(10) SETTINGS allow_experimental_analyzer = 1, optimize_injective_functions_inside_uniq = 1; +SELECT uniqCombined(tuple('')) FROM numbers(1); -SELECT uniqCombined(tuple(materialize(tuple(number)))) FROM numbers(10) SETTINGS allow_experimental_analyzer = 1, optimize_injective_functions_inside_uniq = 1; +EXPLAIN QUERY TREE SELECT uniqCombined(tuple(materialize(tuple(number)))) FROM numbers(10); + +SELECT uniqCombined(tuple(materialize(tuple(number)))) FROM numbers(10);