Merge pull request #56520 from kitaisreal/analyzer-move-functions-out-of-any-pass-refactoring

Analyzer MoveFunctionsOutOfAnyPass refactoring
This commit is contained in:
Nikolai Kochetov 2023-11-10 12:27:14 +01:00 committed by GitHub
commit 0898cf3e06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 85 additions and 87 deletions

View File

@ -1,4 +1,4 @@
#include <Analyzer/Passes/AnyFunctionPass.h> #include <Analyzer/Passes/MoveFunctionsOutOfAnyPass.h>
#include <AggregateFunctions/AggregateFunctionFactory.h> #include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/IAggregateFunction.h> #include <AggregateFunctions/IAggregateFunction.h>
@ -14,8 +14,80 @@ namespace DB
namespace namespace
{ {
class AnyFunctionVisitor : public InDepthQueryTreeVisitorWithContext<AnyFunctionVisitor> class AnyFunctionViMoveFunctionsOutOfAnyVisitor : public InDepthQueryTreeVisitorWithContext<AnyFunctionViMoveFunctionsOutOfAnyVisitor>
{ {
public:
using Base = InDepthQueryTreeVisitorWithContext<AnyFunctionViMoveFunctionsOutOfAnyVisitor>;
using Base::Base;
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_move_functions_out_of_any)
return;
auto * function_node = node->as<FunctionNode>();
if (!function_node)
return;
/// check function is any
const auto & function_name = function_node->getFunctionName();
if (function_name != "any" && function_name != "anyLast")
return;
auto & arguments = function_node->getArguments().getNodes();
if (arguments.size() != 1)
return;
auto * inside_function_node = arguments[0]->as<FunctionNode>();
/// check argument is a function
if (!inside_function_node)
return;
/// check arguments can not contain arrayJoin or lambda
if (!canRewrite(inside_function_node))
return;
auto & inside_function_node_arguments = inside_function_node->getArguments().getNodes();
/// case any(f())
if (inside_function_node_arguments.empty())
return;
auto it = node_to_rewritten_node.find(node.get());
if (it != node_to_rewritten_node.end())
{
node = it->second;
return;
}
/// checking done, rewrite function
bool changed_argument = false;
for (auto & inside_argument : inside_function_node_arguments)
{
if (inside_argument->as<ConstantNode>()) /// skip constant node
break;
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(function_name, {inside_argument->getResultType()}, {}, properties);
auto any_function = std::make_shared<FunctionNode>(function_name);
any_function->resolveAsAggregateFunction(std::move(aggregate_function));
auto & any_function_arguments = any_function->getArguments().getNodes();
any_function_arguments.push_back(std::move(inside_argument));
inside_argument = std::move(any_function);
changed_argument = true;
}
if (changed_argument)
{
node_to_rewritten_node.emplace(node.get(), arguments[0]);
node = arguments[0];
}
}
private: private:
bool canRewrite(const FunctionNode * function_node) bool canRewrite(const FunctionNode * function_node)
{ {
@ -45,90 +117,17 @@ private:
return true; return true;
} }
public: /// After query analysis, alias identifier will be resolved to node whose memory address is same with the original one.
using Base = InDepthQueryTreeVisitorWithContext<AnyFunctionVisitor>; /// So we can reuse the rewritten function.
using Base::Base; std::unordered_map<IQueryTreeNode *, QueryTreeNodePtr> node_to_rewritten_node;
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_move_functions_out_of_any)
return;
auto * function_node = node->as<FunctionNode>();
if (!function_node)
return;
/// check function is any
const auto & function_name = function_node->getFunctionName();
if (!(function_name == "any" || function_name == "anyLast"))
return;
auto & arguments = function_node->getArguments().getNodes();
if (arguments.size() != 1)
return;
auto * inside_function_node = arguments[0]->as<FunctionNode>();
/// check argument is a function
if (!inside_function_node)
return;
/// check arguments can not contain arrayJoin or lambda
if (!canRewrite(inside_function_node))
return;
auto & inside_arguments = inside_function_node->getArguments().getNodes();
/// case any(f())
if (inside_arguments.empty())
return;
if (rewritten.contains(node.get()))
{
node = rewritten.at(node.get());
return;
}
/// checking done, rewrite function
bool pushed = false;
for (auto & inside_argument : inside_arguments)
{
if (inside_argument->as<ConstantNode>()) /// skip constant node
break;
AggregateFunctionProperties properties;
auto aggregate_function = AggregateFunctionFactory::instance().get(function_name, {inside_argument->getResultType()}, {}, properties);
auto any_function = std::make_shared<FunctionNode>(function_name);
any_function->resolveAsAggregateFunction(std::move(aggregate_function));
auto & any_function_arguments = any_function->getArguments().getNodes();
any_function_arguments.push_back(std::move(inside_argument));
inside_argument = std::move(any_function);
pushed = true;
}
if (pushed)
{
rewritten.insert({node.get(), arguments[0]});
node = arguments[0];
}
}
private:
/// After query analysis alias will be rewritten to QueryTreeNode
/// whose memory address is same with the original one.
/// So we can reuse the rewritten one.
std::unordered_map<IQueryTreeNode *, QueryTreeNodePtr > rewritten;
}; };
} }
void AnyFunctionPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context) void MoveFunctionsOutOfAnyPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
{ {
AnyFunctionVisitor visitor(context); AnyFunctionViMoveFunctionsOutOfAnyVisitor visitor(context);
visitor.visit(query_tree_node); visitor.visit(query_tree_node);
} }

View File

@ -7,13 +7,13 @@ namespace DB
/** Rewrite 'any' and 'anyLast' functions pushing them inside original function. /** Rewrite 'any' and 'anyLast' functions pushing them inside original function.
* *
* Example: any(f(x, y, g(z))) * Example: SELECT any(f(x, y, g(z)));
* Result: f(any(x), any(y), g(any(z))) * Result: SELECT f(any(x), any(y), g(any(z)));
*/ */
class AnyFunctionPass final : public IQueryTreePass class MoveFunctionsOutOfAnyPass final : public IQueryTreePass
{ {
public: public:
String getName() override { return "AnyFunction"; } String getName() override { return "MoveFunctionsOutOfAnyPass"; }
String getDescription() override String getDescription() override
{ {

View File

@ -43,7 +43,7 @@
#include <Analyzer/Passes/CrossToInnerJoinPass.h> #include <Analyzer/Passes/CrossToInnerJoinPass.h>
#include <Analyzer/Passes/ShardNumColumnToFunctionPass.h> #include <Analyzer/Passes/ShardNumColumnToFunctionPass.h>
#include <Analyzer/Passes/ConvertQueryToCNFPass.h> #include <Analyzer/Passes/ConvertQueryToCNFPass.h>
#include <Analyzer/Passes/AnyFunctionPass.h> #include <Analyzer/Passes/MoveFunctionsOutOfAnyPass.h>
#include <Analyzer/Passes/OptimizeDateOrDateTimeConverterWithPreimagePass.h> #include <Analyzer/Passes/OptimizeDateOrDateTimeConverterWithPreimagePass.h>
@ -164,7 +164,6 @@ private:
* *
* TODO: Support setting optimize_substitute_columns. * TODO: Support setting optimize_substitute_columns.
* TODO: Support GROUP BY injective function elimination. * TODO: Support GROUP BY injective function elimination.
* TODO: Support setting optimize_move_functions_out_of_any.
* TODO: Support setting optimize_aggregators_of_group_by_keys. * TODO: Support setting optimize_aggregators_of_group_by_keys.
* TODO: Support setting optimize_monotonous_functions_in_order_by. * TODO: Support setting optimize_monotonous_functions_in_order_by.
* TODO: Add optimizations based on function semantics. Example: SELECT * FROM test_table WHERE id != id. (id is not nullable column). * TODO: Add optimizations based on function semantics. Example: SELECT * FROM test_table WHERE id != id. (id is not nullable column).
@ -283,7 +282,7 @@ void addQueryTreePasses(QueryTreePassManager & manager)
manager.addPass(std::make_unique<CrossToInnerJoinPass>()); manager.addPass(std::make_unique<CrossToInnerJoinPass>());
manager.addPass(std::make_unique<ShardNumColumnToFunctionPass>()); manager.addPass(std::make_unique<ShardNumColumnToFunctionPass>());
manager.addPass(std::make_unique<AnyFunctionPass>()); manager.addPass(std::make_unique<MoveFunctionsOutOfAnyPass>());
manager.addPass(std::make_unique<OptimizeDateOrDateTimeConverterWithPreimagePass>()); manager.addPass(std::make_unique<OptimizeDateOrDateTimeConverterWithPreimagePass>());
} }