diff --git a/src/Analyzer/Passes/AnyFunctionPass.cpp b/src/Analyzer/Passes/MoveFunctionsOutOfAnyPass.cpp similarity index 73% rename from src/Analyzer/Passes/AnyFunctionPass.cpp rename to src/Analyzer/Passes/MoveFunctionsOutOfAnyPass.cpp index 75f12bc7d46..51edbcc6bd0 100644 --- a/src/Analyzer/Passes/AnyFunctionPass.cpp +++ b/src/Analyzer/Passes/MoveFunctionsOutOfAnyPass.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -14,8 +14,80 @@ namespace DB namespace { -class AnyFunctionVisitor : public InDepthQueryTreeVisitorWithContext +class AnyFunctionViMoveFunctionsOutOfAnyVisitor : public InDepthQueryTreeVisitorWithContext { +public: + using Base = InDepthQueryTreeVisitorWithContext; + using Base::Base; + + void enterImpl(QueryTreeNodePtr & node) + { + if (!getSettings().optimize_move_functions_out_of_any) + return; + + auto * function_node = node->as(); + if (!function_node) + return; + + /// check function is any + const auto & function_name = function_node->getFunctionName(); + if (function_name != "any" && function_name != "anyLast") + return; + + auto & arguments = function_node->getArguments().getNodes(); + if (arguments.size() != 1) + return; + + auto * inside_function_node = arguments[0]->as(); + + /// check argument is a function + if (!inside_function_node) + return; + + /// check arguments can not contain arrayJoin or lambda + if (!canRewrite(inside_function_node)) + return; + + auto & inside_function_node_arguments = inside_function_node->getArguments().getNodes(); + + /// case any(f()) + if (inside_function_node_arguments.empty()) + return; + + auto it = node_to_rewritten_node.find(node.get()); + if (it != node_to_rewritten_node.end()) + { + node = it->second; + return; + } + + /// checking done, rewrite function + bool changed_argument = false; + for (auto & inside_argument : inside_function_node_arguments) + { + if (inside_argument->as()) /// skip constant node + break; + + AggregateFunctionProperties properties; + auto aggregate_function = AggregateFunctionFactory::instance().get(function_name, {inside_argument->getResultType()}, {}, properties); + + auto any_function = std::make_shared(function_name); + any_function->resolveAsAggregateFunction(std::move(aggregate_function)); + + auto & any_function_arguments = any_function->getArguments().getNodes(); + any_function_arguments.push_back(std::move(inside_argument)); + + inside_argument = std::move(any_function); + changed_argument = true; + } + + if (changed_argument) + { + node_to_rewritten_node.emplace(node.get(), arguments[0]); + node = arguments[0]; + } + } + private: bool canRewrite(const FunctionNode * function_node) { @@ -45,90 +117,17 @@ private: return true; } -public: - using Base = InDepthQueryTreeVisitorWithContext; - using Base::Base; - - void enterImpl(QueryTreeNodePtr & node) - { - if (!getSettings().optimize_move_functions_out_of_any) - return; - - auto * function_node = node->as(); - if (!function_node) - return; - - /// check function is any - const auto & function_name = function_node->getFunctionName(); - if (!(function_name == "any" || function_name == "anyLast")) - return; - - auto & arguments = function_node->getArguments().getNodes(); - if (arguments.size() != 1) - return; - - auto * inside_function_node = arguments[0]->as(); - - /// check argument is a function - if (!inside_function_node) - return; - - /// check arguments can not contain arrayJoin or lambda - if (!canRewrite(inside_function_node)) - return; - - auto & inside_arguments = inside_function_node->getArguments().getNodes(); - - /// case any(f()) - if (inside_arguments.empty()) - return; - - if (rewritten.contains(node.get())) - { - node = rewritten.at(node.get()); - return; - } - - /// checking done, rewrite function - bool pushed = false; - for (auto & inside_argument : inside_arguments) - { - if (inside_argument->as()) /// skip constant node - break; - - AggregateFunctionProperties properties; - auto aggregate_function = AggregateFunctionFactory::instance().get(function_name, {inside_argument->getResultType()}, {}, properties); - - auto any_function = std::make_shared(function_name); - any_function->resolveAsAggregateFunction(std::move(aggregate_function)); - - auto & any_function_arguments = any_function->getArguments().getNodes(); - any_function_arguments.push_back(std::move(inside_argument)); - - inside_argument = std::move(any_function); - pushed = true; - } - - if (pushed) - { - rewritten.insert({node.get(), arguments[0]}); - node = arguments[0]; - } - } - -private: - /// After query analysis alias will be rewritten to QueryTreeNode - /// whose memory address is same with the original one. - /// So we can reuse the rewritten one. - std::unordered_map rewritten; + /// After query analysis, alias identifier will be resolved to node whose memory address is same with the original one. + /// So we can reuse the rewritten function. + std::unordered_map node_to_rewritten_node; }; } -void AnyFunctionPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context) +void MoveFunctionsOutOfAnyPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context) { - AnyFunctionVisitor visitor(context); + AnyFunctionViMoveFunctionsOutOfAnyVisitor visitor(context); visitor.visit(query_tree_node); } diff --git a/src/Analyzer/Passes/AnyFunctionPass.h b/src/Analyzer/Passes/MoveFunctionsOutOfAnyPass.h similarity index 64% rename from src/Analyzer/Passes/AnyFunctionPass.h rename to src/Analyzer/Passes/MoveFunctionsOutOfAnyPass.h index 0cc65d238dd..09a53f2b9e0 100644 --- a/src/Analyzer/Passes/AnyFunctionPass.h +++ b/src/Analyzer/Passes/MoveFunctionsOutOfAnyPass.h @@ -7,13 +7,13 @@ namespace DB /** Rewrite 'any' and 'anyLast' functions pushing them inside original function. * - * Example: any(f(x, y, g(z))) - * Result: f(any(x), any(y), g(any(z))) + * Example: SELECT any(f(x, y, g(z))); + * Result: SELECT f(any(x), any(y), g(any(z))); */ -class AnyFunctionPass final : public IQueryTreePass +class MoveFunctionsOutOfAnyPass final : public IQueryTreePass { public: - String getName() override { return "AnyFunction"; } + String getName() override { return "MoveFunctionsOutOfAnyPass"; } String getDescription() override { diff --git a/src/Analyzer/QueryTreePassManager.cpp b/src/Analyzer/QueryTreePassManager.cpp index 66082591890..08474c4100a 100644 --- a/src/Analyzer/QueryTreePassManager.cpp +++ b/src/Analyzer/QueryTreePassManager.cpp @@ -43,7 +43,7 @@ #include #include #include -#include +#include #include @@ -164,7 +164,6 @@ private: * * TODO: Support setting optimize_substitute_columns. * TODO: Support GROUP BY injective function elimination. - * TODO: Support setting optimize_move_functions_out_of_any. * TODO: Support setting optimize_aggregators_of_group_by_keys. * TODO: Support setting optimize_monotonous_functions_in_order_by. * TODO: Add optimizations based on function semantics. Example: SELECT * FROM test_table WHERE id != id. (id is not nullable column). @@ -283,7 +282,7 @@ void addQueryTreePasses(QueryTreePassManager & manager) manager.addPass(std::make_unique()); manager.addPass(std::make_unique()); - manager.addPass(std::make_unique()); + manager.addPass(std::make_unique()); manager.addPass(std::make_unique()); }