diff --git a/src/Analyzer/Passes/AnyFunctionPass.cpp b/src/Analyzer/Passes/AnyFunctionPass.cpp index aada2d3a4a7..b785df7fb05 100644 --- a/src/Analyzer/Passes/AnyFunctionPass.cpp +++ b/src/Analyzer/Passes/AnyFunctionPass.cpp @@ -5,7 +5,9 @@ #include #include +#include #include +#include namespace DB { @@ -15,6 +17,39 @@ namespace class AnyFunctionVisitor : public InDepthQueryTreeVisitorWithContext { +private: + bool canRewrite(const FunctionNode * function_node) + { + for (auto & argument : function_node->getArguments().getNodes()) + { + /// arrayJoin() is special and should not be optimized (think about + /// it as an aggregate function), otherwise a wrong result will be + /// produced: + /// SELECT *, any(arrayJoin([[], []])) FROM numbers(1) GROUP BY number + /// ┌─number─┬─arrayJoin(array(array(), array()))─┐ + /// │ 0 │ [] │ + /// │ 0 │ [] │ + /// └────────┴────────────────────────────────────┘ + /// While it should be: + /// ┌─number─┬─any(arrayJoin(array(array(), array())))─┐ + /// │ 0 │ [] │ + /// └────────┴─────────────────────────────────────────┘ + if (argument->as()) + return false; + + if (argument->as()) + return false; + + if (const auto * inside_function = argument->as()) + { + if (!canRewrite(inside_function)) + return false; + } + } + + return true; + } + public: using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; @@ -24,6 +59,12 @@ public: if (!getSettings().optimize_move_functions_out_of_any) return; + if (rewritten.count(node.get())) + { + node = rewritten.at(node.get()); + return; + } + auto * function_node = node->as(); if (!function_node) return; @@ -40,8 +81,11 @@ public: auto * inside_function_node = arguments[0]->as(); /// check argument is a function and can not be arrayJoin or lambda - if (!inside_function_node || inside_function_node->getFunctionName() == "arrayJoin" - || inside_function_node->getFunctionName() == "lambda") + if (!inside_function_node) + return; + + /// check arguments can not contain arrayJoin or lambda + 
if (!canRewrite(inside_function_node)) return; auto & inside_arguments = inside_function_node->getArguments().getNodes(); @@ -50,12 +94,6 @@ public: if (inside_arguments.empty()) return; - if (rewritten.count(node.get())) - { - node = rewritten.at(node.get()); - return; - } - /// checking done, rewrite function bool pushed = false; for (auto & inside_argument : inside_arguments)