Add passes for rewriting arrayExists to has

This commit is contained in:
taiyang-li 2023-02-09 16:30:53 +08:00
parent af7a6abf80
commit 19ca0ec4af
7 changed files with 254 additions and 11 deletions

View File

@ -0,0 +1,88 @@
#include <Functions/FunctionFactory.h>
#include <Interpreters/Context.h>
#include <Analyzer/ColumnNode.h>
#include <Analyzer/ConstantNode.h>
#include <Analyzer/FunctionNode.h>
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/LambdaNode.h>
#include "ArrayExistsToHasPass.h"
namespace DB
{
namespace
{
/// Query tree visitor that rewrites
///     arrayExists(x -> x = elem, arr)   or   arrayExists(x -> elem = x, arr)
/// into
///     has(arr, elem)
/// where `elem` is a constant and `x` is the single lambda argument.
///
/// The rewrite is applied only when the setting `optimize_rewrite_array_exists_to_has` is enabled.
///
/// NOTE(review): `equals` and `has` may treat NULL elements differently; this visitor does not
/// special-case a NULL constant. Confirm whether that case needs to be excluded.
class RewriteArrayExistsToHasVisitor : public InDepthQueryTreeVisitorWithContext<RewriteArrayExistsToHasVisitor>
{
public:
    using Base = InDepthQueryTreeVisitorWithContext<RewriteArrayExistsToHasVisitor>;
    using Base::Base;

    /// Inspects a single node and, if it matches the supported `arrayExists` shape,
    /// rewrites it in place into a `has` call. Nodes that do not match are left untouched.
    void visitImpl(QueryTreeNodePtr & node)
    {
        if (!getSettings().optimize_rewrite_array_exists_to_has)
            return;

        auto * function_node = node->as<FunctionNode>();
        if (!function_node || function_node->getFunctionName() != "arrayExists")
            return;

        auto & function_arguments_nodes = function_node->getArguments().getNodes();
        if (function_arguments_nodes.size() != 2)
            return;

        /// lambda function must be like: x -> x = elem
        auto * lambda_node = function_arguments_nodes[0]->as<LambdaNode>();
        if (!lambda_node)
            return;

        auto & lambda_arguments_nodes = lambda_node->getArguments().getNodes();
        if (lambda_arguments_nodes.size() != 1)
            return;

        /// The lambda argument must be a column node; otherwise there is nothing to compare
        /// against and dereferencing the cast result below would be a null pointer access.
        const auto * column_node = lambda_arguments_nodes[0]->as<ColumnNode>();
        if (!column_node)
            return;

        auto * filter_node = lambda_node->getExpression()->as<FunctionNode>();
        if (!filter_node || filter_node->getFunctionName() != "equals")
            return;

        /// Bind by reference: the argument list itself is not modified, so no copy is needed.
        const auto & filter_arguments_nodes = filter_node->getArguments().getNodes();
        if (filter_arguments_nodes.size() != 2)
            return;

        /// True when `candidate` is a column node carrying the same name as the lambda argument.
        const auto is_lambda_argument = [&](const QueryTreeNodePtr & candidate)
        {
            const auto * candidate_column = candidate->as<ColumnNode>();
            return candidate_column && candidate_column->getColumnName() == column_node->getColumnName();
        };

        const auto & filter_lhs = filter_arguments_nodes[0];
        const auto & filter_rhs = filter_arguments_nodes[1];

        /// Take an owning copy of the constant node: overwriting the first argument of
        /// `arrayExists` below releases the lambda, which may own the filter arguments.
        QueryTreeNodePtr element_node;
        if (filter_rhs->as<ConstantNode>() && is_lambda_argument(filter_lhs))
        {
            /// arrayExists(x -> x = elem, arr)
            element_node = filter_rhs;
        }
        else if (filter_lhs->as<ConstantNode>() && is_lambda_argument(filter_rhs))
        {
            /// arrayExists(x -> elem = x, arr)
            element_node = filter_lhs;
        }
        else
        {
            return;
        }

        /// Rewrite to has(arr, elem): the array moves to the first position, the constant to the second.
        function_arguments_nodes[0] = std::move(function_arguments_nodes[1]);
        function_arguments_nodes[1] = std::move(element_node);
        function_node->resolveAsFunction(
            FunctionFactory::instance().get("has", getContext())->build(function_node->getArgumentColumns()));
    }
};
}
/// Entry point of the pass: walks the given query tree with RewriteArrayExistsToHasVisitor,
/// which performs the actual arrayExists -> has rewriting.
void RewriteArrayExistsToHasPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
{
    RewriteArrayExistsToHasVisitor(std::move(context)).visit(query_tree_node);
}
}

View File

@ -0,0 +1,18 @@
#pragma once
#include <Analyzer/IQueryTreePass.h>
namespace DB
{
/** Query tree pass that replaces suitable `arrayExists(func, arr)` calls
  * with `has(arr, elem)` to improve performance.
  *
  * Example: arrayExists(x -> x = 1, arr) -> has(arr, 1)
  */
class RewriteArrayExistsToHasPass final : public IQueryTreePass
{
public:
    /// Name of the pass.
    String getName() override
    {
        return "RewriteArrayExistsToHas";
    }

    /// Human-readable summary of what the pass does.
    String getDescription() override
    {
        return "Rewrite arrayExists(func, arr) functions to has(arr, elem) when logically equivalent";
    }

    /// Applies the rewrite to the given query tree.
    void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override;
};
}

View File

@ -35,6 +35,7 @@
#include <Analyzer/Passes/ConvertOrLikeChainPass.h>
#include <Analyzer/Passes/OptimizeRedundantFunctionsInOrderByPass.h>
#include <Analyzer/Passes/GroupingFunctionsResolvePass.h>
#include <Analyzer/Passes/ArrayExistsToHasPass.h>
namespace DB
{
@ -217,6 +218,7 @@ void addQueryTreePasses(QueryTreePassManager & manager)
manager.addPass(std::make_unique<CountDistinctPass>());
manager.addPass(std::make_unique<RewriteAggregateFunctionWithIfPass>());
manager.addPass(std::make_unique<SumIfToCountIfPass>());
manager.addPass(std::make_unique<RewriteArrayExistsToHasPass>());
manager.addPass(std::make_unique<NormalizeCountVariantsPass>());
manager.addPass(std::make_unique<CustomizeFunctionsPass>());

View File

@ -1,11 +1,10 @@
#include <Interpreters/RewriteArrayExistsFunctionVisitor.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTLiteral.h>
namespace DB
{
void RewriteArrayExistsFunctionMatcher::visit(ASTPtr & ast, Data & data)
{
if (auto * func = ast->as<ASTFunction>())
@ -19,10 +18,7 @@ void RewriteArrayExistsFunctionMatcher::visit(ASTPtr & ast, Data & data)
void RewriteArrayExistsFunctionMatcher::visit(const ASTFunction & func, ASTPtr & ast, Data &)
{
if (!func.arguments || func.arguments->children.empty())
return;
if (func.name != "arrayExists")
if (func.name != "arrayExists" || !func.arguments)
return;
auto & array_exists_arguments = func.arguments->children;
@ -35,6 +31,9 @@ void RewriteArrayExistsFunctionMatcher::visit(const ASTFunction & func, ASTPtr &
return;
const auto & lambda_func_arguments = lambda_func->arguments->children;
if (lambda_func_arguments.size() != 2)
return;
const auto * tuple_func = lambda_func_arguments[0]->as<ASTFunction>();
if (!tuple_func || tuple_func->name != "tuple")
return;
@ -70,7 +69,6 @@ void RewriteArrayExistsFunctionMatcher::visit(const ASTFunction & func, ASTPtr &
(filter_id = filter_arguments[1]->as<ASTIdentifier>()) && (filter_literal = filter_arguments[0]->as<ASTLiteral>())
&& filter_id->full_name == id->full_name)
{
/// arrayExists(x -> elem = x, arr) -> has(arr, elem)
auto new_func = makeASTFunction("has", std::move(array_exists_arguments[1]), std::move(filter_arguments[0]));
new_func->setAlias(func.alias);

View File

@ -1,20 +1,20 @@
#pragma once
#include <Parsers/IAST.h>
#include <Interpreters/InDepthNodeVisitor.h>
#include <Parsers/IAST.h>
namespace DB
{
class ASTFunction;
/// Rewrite possible 'arrayExists(func, arr)' to 'has(arr, elem)' to improve performance
/// arrayExists(x -> x = 1, arr) -> has(arr, 1)
class RewriteArrayExistsFunctionMatcher
{
public:
struct Data{};
struct Data
{
};
static void visit(ASTPtr & ast, Data &);
static void visit(const ASTFunction &, ASTPtr & ast, Data &);

View File

@ -0,0 +1,118 @@
QUERY id: 0
PROJECTION COLUMNS
arrayExists(lambda(tuple(x), equals(x, 5)), materialize(range(10))) UInt8
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: arrayExists, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 3, nodes: 2
LAMBDA id: 4
ARGUMENTS
LIST id: 5, nodes: 1
COLUMN id: 6, column_name: x, result_type: UInt8, source_id: 4
EXPRESSION
FUNCTION id: 7, function_name: equals, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 8, nodes: 2
COLUMN id: 6, column_name: x, result_type: UInt8, source_id: 4
CONSTANT id: 9, constant_value: UInt64_5, constant_value_type: UInt8
FUNCTION id: 10, function_name: materialize, function_type: ordinary, result_type: Array(UInt8)
ARGUMENTS
LIST id: 11, nodes: 1
CONSTANT id: 12, constant_value: Array_[UInt64_0, UInt64_1, UInt64_2, UInt64_3, UInt64_4, UInt64_5, UInt64_6, UInt64_7, UInt64_8, UInt64_9], constant_value_type: Array(UInt8)
EXPRESSION
FUNCTION id: 13, function_name: range, function_type: ordinary, result_type: Array(UInt8)
ARGUMENTS
LIST id: 14, nodes: 1
CONSTANT id: 15, constant_value: UInt64_10, constant_value_type: UInt8
JOIN TREE
TABLE_FUNCTION id: 16, table_function_name: numbers
ARGUMENTS
LIST id: 17, nodes: 1
CONSTANT id: 18, constant_value: UInt64_10, constant_value_type: UInt8
QUERY id: 0
PROJECTION COLUMNS
arrayExists(lambda(tuple(x), equals(5, x)), materialize(range(10))) UInt8
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: arrayExists, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 3, nodes: 2
LAMBDA id: 4
ARGUMENTS
LIST id: 5, nodes: 1
COLUMN id: 6, column_name: x, result_type: UInt8, source_id: 4
EXPRESSION
FUNCTION id: 7, function_name: equals, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 8, nodes: 2
CONSTANT id: 9, constant_value: UInt64_5, constant_value_type: UInt8
COLUMN id: 6, column_name: x, result_type: UInt8, source_id: 4
FUNCTION id: 10, function_name: materialize, function_type: ordinary, result_type: Array(UInt8)
ARGUMENTS
LIST id: 11, nodes: 1
CONSTANT id: 12, constant_value: Array_[UInt64_0, UInt64_1, UInt64_2, UInt64_3, UInt64_4, UInt64_5, UInt64_6, UInt64_7, UInt64_8, UInt64_9], constant_value_type: Array(UInt8)
EXPRESSION
FUNCTION id: 13, function_name: range, function_type: ordinary, result_type: Array(UInt8)
ARGUMENTS
LIST id: 14, nodes: 1
CONSTANT id: 15, constant_value: UInt64_10, constant_value_type: UInt8
JOIN TREE
TABLE_FUNCTION id: 16, table_function_name: numbers
ARGUMENTS
LIST id: 17, nodes: 1
CONSTANT id: 18, constant_value: UInt64_10, constant_value_type: UInt8
QUERY id: 0
PROJECTION COLUMNS
arrayExists(lambda(tuple(x), equals(x, 5)), materialize(range(10))) UInt8
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: has, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 3, nodes: 2
FUNCTION id: 4, function_name: materialize, function_type: ordinary, result_type: Array(UInt8)
ARGUMENTS
LIST id: 5, nodes: 1
CONSTANT id: 6, constant_value: Array_[UInt64_0, UInt64_1, UInt64_2, UInt64_3, UInt64_4, UInt64_5, UInt64_6, UInt64_7, UInt64_8, UInt64_9], constant_value_type: Array(UInt8)
EXPRESSION
FUNCTION id: 7, function_name: range, function_type: ordinary, result_type: Array(UInt8)
ARGUMENTS
LIST id: 8, nodes: 1
CONSTANT id: 9, constant_value: UInt64_10, constant_value_type: UInt8
CONSTANT id: 10, constant_value: UInt64_5, constant_value_type: UInt8
JOIN TREE
TABLE_FUNCTION id: 11, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_10, constant_value_type: UInt8
QUERY id: 0
PROJECTION COLUMNS
arrayExists(lambda(tuple(x), equals(5, x)), materialize(range(10))) UInt8
PROJECTION
LIST id: 1, nodes: 1
FUNCTION id: 2, function_name: has, function_type: ordinary, result_type: UInt8
ARGUMENTS
LIST id: 3, nodes: 2
FUNCTION id: 4, function_name: materialize, function_type: ordinary, result_type: Array(UInt8)
ARGUMENTS
LIST id: 5, nodes: 1
CONSTANT id: 6, constant_value: Array_[UInt64_0, UInt64_1, UInt64_2, UInt64_3, UInt64_4, UInt64_5, UInt64_6, UInt64_7, UInt64_8, UInt64_9], constant_value_type: Array(UInt8)
EXPRESSION
FUNCTION id: 7, function_name: range, function_type: ordinary, result_type: Array(UInt8)
ARGUMENTS
LIST id: 8, nodes: 1
CONSTANT id: 9, constant_value: UInt64_10, constant_value_type: UInt8
CONSTANT id: 10, constant_value: UInt64_5, constant_value_type: UInt8
JOIN TREE
TABLE_FUNCTION id: 11, table_function_name: numbers
ARGUMENTS
LIST id: 12, nodes: 1
CONSTANT id: 13, constant_value: UInt64_10, constant_value_type: UInt8
SELECT arrayExists(x -> (x = 5), materialize(range(10)))
FROM numbers(10)
SELECT arrayExists(x -> (5 = x), materialize(range(10)))
FROM numbers(10)
SELECT has(materialize(range(10)), 5)
FROM numbers(10)
SELECT has(materialize(range(10)), 5)
FROM numbers(10)

View File

@ -0,0 +1,19 @@
-- Checks the arrayExists -> has rewrite for both spellings of the predicate
-- (x = 5 and 5 = x), with the optimization setting switched off and on.

-- Part 1: experimental analyzer enabled; the result is inspected via EXPLAIN QUERY TREE.
set allow_experimental_analyzer = true;
-- Optimization disabled.
set optimize_rewrite_array_exists_to_has = false;
EXPLAIN QUERY TREE run_passes = 1 select arrayExists(x -> x = 5 , materialize(range(10))) from numbers(10);
EXPLAIN QUERY TREE run_passes = 1 select arrayExists(x -> 5 = x , materialize(range(10))) from numbers(10);
-- Optimization enabled.
set optimize_rewrite_array_exists_to_has = true;
EXPLAIN QUERY TREE run_passes = 1 select arrayExists(x -> x = 5 , materialize(range(10))) from numbers(10);
EXPLAIN QUERY TREE run_passes = 1 select arrayExists(x -> 5 = x , materialize(range(10))) from numbers(10);
-- Part 2: experimental analyzer disabled; the result is inspected via EXPLAIN SYNTAX.
set allow_experimental_analyzer = false;
-- Optimization disabled.
set optimize_rewrite_array_exists_to_has = false;
EXPLAIN SYNTAX select arrayExists(x -> x = 5 , materialize(range(10))) from numbers(10);
EXPLAIN SYNTAX select arrayExists(x -> 5 = x , materialize(range(10))) from numbers(10);
-- Optimization enabled.
set optimize_rewrite_array_exists_to_has = true;
EXPLAIN SYNTAX select arrayExists(x -> x = 5 , materialize(range(10))) from numbers(10);
EXPLAIN SYNTAX select arrayExists(x -> 5 = x , materialize(range(10))) from numbers(10);