skip rewriting for lambda and arrayJoin

This commit is contained in:
JackyWoo 2023-07-14 17:27:32 +08:00
parent eb6c1cb549
commit a2dce9663e

View File

@ -5,7 +5,9 @@
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/FunctionNode.h>
#include <Analyzer/LambdaNode.h>
#include <Analyzer/ConstantNode.h>
#include <Analyzer/ArrayJoinNode.h>
namespace DB
{
@ -15,6 +17,39 @@ namespace
class AnyFunctionVisitor : public InDepthQueryTreeVisitorWithContext<AnyFunctionVisitor>
{
private:
bool canRewrite(const FunctionNode * function_node)
{
for (auto & argument : function_node->getArguments().getNodes())
{
/// arrayJoin() is special and should not be optimized (think about
/// it as a an aggregate function), otherwise wrong result will be
/// produced:
/// SELECT *, any(arrayJoin([[], []])) FROM numbers(1) GROUP BY number
/// ┌─number─┬─arrayJoin(array(array(), array()))─┐
/// │ 0 │ [] │
/// │ 0 │ [] │
/// └────────┴────────────────────────────────────┘
/// While should be:
/// ┌─number─┬─any(arrayJoin(array(array(), array())))─┐
/// │ 0 │ [] │
/// └────────┴─────────────────────────────────────────┘
if (argument->as<LambdaNode>())
return false;
if (argument->as<ArrayJoinNode>())
return false;
if (const auto * inside_function = argument->as<FunctionNode>())
{
if (!canRewrite(inside_function))
return false;
}
}
return true;
}
public:
using Base = InDepthQueryTreeVisitorWithContext<AnyFunctionVisitor>;
using Base::Base;
@ -24,6 +59,12 @@ public:
if (!getSettings().optimize_move_functions_out_of_any)
return;
if (rewritten.count(node.get()))
{
node = rewritten.at(node.get());
return;
}
auto * function_node = node->as<FunctionNode>();
if (!function_node)
return;
@ -40,8 +81,11 @@ public:
auto * inside_function_node = arguments[0]->as<FunctionNode>();
/// check argument is a function and can not be arrayJoin or lambda
if (!inside_function_node || inside_function_node->getFunctionName() == "arrayJoin"
|| inside_function_node->getFunctionName() == "lambda")
if (!inside_function_node)
return;
/// check arguments can not contain arrayJoin or lambda
if (!canRewrite(inside_function_node))
return;
auto & inside_arguments = inside_function_node->getArguments().getNodes();
@ -50,12 +94,6 @@ public:
if (inside_arguments.empty())
return;
if (rewritten.count(node.get()))
{
node = rewritten.at(node.get());
return;
}
/// checking done, rewrite function
bool pushed = false;
for (auto & inside_argument : inside_arguments)