Merge pull request #47225 from kitaisreal/array-exists-to-has-pass-fix

Fix RewriteArrayExistsToHasPass
This commit is contained in:
Maksim Kita 2023-03-08 17:44:24 +03:00 committed by GitHub
commit fb45fd758d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 87 additions and 61 deletions

View File

@ -1,3 +1,5 @@
#include <Analyzer/Passes/ArrayExistsToHasPass.h>
#include <Functions/FunctionFactory.h>
#include <Interpreters/Context.h>
@ -8,71 +10,85 @@
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/LambdaNode.h>
#include "ArrayExistsToHasPass.h"
namespace DB
{
namespace
{
/// Query tree visitor that rewrites arrayExists(x -> x = elem, arr) (or elem = x) into has(arr, elem).
/// NOTE(review): this span comes from a diff rendering without +/- markers; the duplicated class line,
/// the two visitImpl headers and the two bodies below are presumably the pre-change and post-change
/// versions of the same code interleaved — confirm against the repository before editing.
class RewriteArrayExistsToHasVisitor : public InDepthQueryTreeVisitorWithContext<RewriteArrayExistsToHasVisitor>
class RewriteArrayExistsToHasVisitor : public InDepthQueryTreeVisitorWithContext<RewriteArrayExistsToHasVisitor>
{
public:
using Base = InDepthQueryTreeVisitorWithContext<RewriteArrayExistsToHasVisitor>;
using Base::Base;
/// Called with a node of the query tree; rewrites the node in place when it matches the supported pattern.
void visitImpl(QueryTreeNodePtr & node)
{
public:
using Base = InDepthQueryTreeVisitorWithContext<RewriteArrayExistsToHasVisitor>;
using Base::Base;
/// Do nothing unless the optimize_rewrite_array_exists_to_has setting is enabled.
if (!getSettings().optimize_rewrite_array_exists_to_has)
return;
void visitImpl(QueryTreeNodePtr & node)
/// Only a FunctionNode named "arrayExists" is a candidate.
auto * array_exists_function_node = node->as<FunctionNode>();
if (!array_exists_function_node || array_exists_function_node->getFunctionName() != "arrayExists")
return;
/// Only the two-argument form is handled: argument 0 is the lambda, argument 1 becomes the first argument of has.
auto & array_exists_function_arguments_nodes = array_exists_function_node->getArguments().getNodes();
if (array_exists_function_arguments_nodes.size() != 2)
return;
/// lambda function must be like: x -> x = elem
auto * lambda_node = array_exists_function_arguments_nodes[0]->as<LambdaNode>();
if (!lambda_node)
return;
/// The lambda must take exactly one argument.
auto & lambda_arguments_nodes = lambda_node->getArguments().getNodes();
if (lambda_arguments_nodes.size() != 1)
return;
/// The lambda argument must be a COLUMN node; it is compared against the operands of `equals` below.
const auto & lambda_argument_column_node = lambda_arguments_nodes[0];
if (lambda_argument_column_node->getNodeType() != QueryTreeNodeType::COLUMN)
return;
/// The lambda body must be a call to `equals`.
auto * filter_node = lambda_node->getExpression()->as<FunctionNode>();
if (!filter_node || filter_node->getFunctionName() != "equals")
return;
/// `equals` must have exactly two operands.
const auto & filter_arguments_nodes = filter_node->getArguments().getNodes();
if (filter_arguments_nodes.size() != 2)
return;
const auto & filter_lhs_argument_node = filter_arguments_nodes[0];
auto filter_lhs_argument_node_type = filter_lhs_argument_node->getNodeType();
const auto & filter_rhs_argument_node = filter_arguments_nodes[1];
auto filter_rhs_argument_node_type = filter_rhs_argument_node->getNodeType();
/// Set to the constant operand of `equals` once one of the two supported operand orders matches.
QueryTreeNodePtr has_constant_element_argument;
/// Case 1: lambda argument on the left, constant on the right (x = elem).
if (filter_lhs_argument_node_type == QueryTreeNodeType::COLUMN &&
filter_rhs_argument_node_type == QueryTreeNodeType::CONSTANT &&
filter_lhs_argument_node->isEqual(*lambda_argument_column_node))
{
/// NOTE(review): the lines from here down to the first `};` look like the body of the previous
/// implementation (it uses the shorter names function_node / column_node) — confirm.
if (!getSettings().optimize_rewrite_array_exists_to_has)
return;
auto * function_node = node->as<FunctionNode>();
if (!function_node || function_node->getFunctionName() != "arrayExists")
return;
auto & function_arguments_nodes = function_node->getArguments().getNodes();
if (function_arguments_nodes.size() != 2)
return;
/// lambda function must be like: x -> x = elem
auto * lambda_node = function_arguments_nodes[0]->as<LambdaNode>();
if (!lambda_node)
return;
auto & lambda_arguments_nodes = lambda_node->getArguments().getNodes();
if (lambda_arguments_nodes.size() != 1)
return;
/// NOTE(review): other as<>() results in this block are null-checked, but column_node is
/// dereferenced below (column_node->getColumnName()) without any check.
auto * column_node = lambda_arguments_nodes[0]->as<ColumnNode>();
auto * filter_node = lambda_node->getExpression()->as<FunctionNode>();
if (!filter_node || filter_node->getFunctionName() != "equals")
return;
/// Declared with plain `auto`, i.e. not as a reference to the nodes container of filter_node.
auto filter_arguments_nodes = filter_node->getArguments().getNodes();
if (filter_arguments_nodes.size() != 2)
return;
/// This version matches the lambda argument by column name rather than with isEqual().
ColumnNode * filter_column_node = nullptr;
if (filter_arguments_nodes[1]->as<ConstantNode>() && (filter_column_node = filter_arguments_nodes[0]->as<ColumnNode>())
&& filter_column_node->getColumnName() == column_node->getColumnName())
{
/// Rewrite arrayExists(x -> x = elem, arr) -> has(arr, elem)
function_arguments_nodes[0] = std::move(function_arguments_nodes[1]);
function_arguments_nodes[1] = std::move(filter_arguments_nodes[1]);
function_node->resolveAsFunction(
FunctionFactory::instance().get("has", getContext())->build(function_node->getArgumentColumns()));
}
else if (
filter_arguments_nodes[0]->as<ConstantNode>() && (filter_column_node = filter_arguments_nodes[1]->as<ColumnNode>())
&& filter_column_node->getColumnName() == column_node->getColumnName())
{
/// Rewrite arrayExists(x -> elem = x, arr) -> has(arr, elem)
function_arguments_nodes[0] = std::move(function_arguments_nodes[1]);
function_arguments_nodes[1] = std::move(filter_arguments_nodes[0]);
function_node->resolveAsFunction(
FunctionFactory::instance().get("has", getContext())->build(function_node->getArgumentColumns()));
}
/// Rewrite arrayExists(x -> x = elem, arr) -> has(arr, elem)
has_constant_element_argument = filter_rhs_argument_node;
}
};
/// Case 2: constant on the left, lambda argument on the right (elem = x).
else if (filter_lhs_argument_node_type == QueryTreeNodeType::CONSTANT &&
filter_rhs_argument_node_type == QueryTreeNodeType::COLUMN &&
filter_rhs_argument_node->isEqual(*lambda_argument_column_node))
{
/// Rewrite arrayExists(x -> elem = x, arr) -> has(arr, elem)
has_constant_element_argument = filter_lhs_argument_node;
}
else
{
/// Any other combination of operand node types is left unchanged.
return;
}
/// Reorder the arguments in place to (arr, elem) and re-resolve the same node as `has`.
auto has_function = FunctionFactory::instance().get("has", getContext());
array_exists_function_arguments_nodes[0] = std::move(array_exists_function_arguments_nodes[1]);
array_exists_function_arguments_nodes[1] = std::move(has_constant_element_argument);
array_exists_function_node->resolveAsFunction(has_function->build(array_exists_function_node->getArgumentColumns()));
}
};
}

View File

@ -4,8 +4,15 @@
namespace DB
{
/// Rewrite possible 'arrayExists(func, arr)' to 'has(arr, elem)' to improve performance
/// arrayExists(x -> x = 1, arr) -> has(arr, 1)
/** Rewrite 'arrayExists(func, arr)' to 'has(arr, elem)' where possible, to improve performance.
*
* Example: SELECT arrayExists(x -> x = 1, arr);
* Result: SELECT has(arr, 1);
*
* Example: SELECT arrayExists(x -> 1 = x, arr);
* Result: SELECT has(arr, 1);
*/
class RewriteArrayExistsToHasPass final : public IQueryTreePass
{
public:
@ -15,4 +22,5 @@ public:
void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

View File

@ -22,8 +22,7 @@ public:
void visitImpl(QueryTreeNodePtr & node)
{
const auto & context = getContext();
if (!context->getSettingsRef().final)
if (!getSettings().final)
return;
const auto * query_node = node->as<QueryNode>();

View File

@ -6,6 +6,9 @@ namespace DB
{
/** Rewrite _shard_num column into shardNum() function.
*
* Example: SELECT _shard_num FROM distributed_table;
* Result: SELECT shardNum() FROM distributed_table;
*/
class ShardNumColumnToFunctionPass final : public IQueryTreePass
{