diff --git a/src/Analyzer/Passes/ArrayExistsToHasPass.cpp b/src/Analyzer/Passes/ArrayExistsToHasPass.cpp index b4b8b5b4579..c0f958588f1 100644 --- a/src/Analyzer/Passes/ArrayExistsToHasPass.cpp +++ b/src/Analyzer/Passes/ArrayExistsToHasPass.cpp @@ -1,3 +1,5 @@ +#include + #include #include @@ -8,71 +10,85 @@ #include #include -#include "ArrayExistsToHasPass.h" - namespace DB { + namespace { - class RewriteArrayExistsToHasVisitor : public InDepthQueryTreeVisitorWithContext + +class RewriteArrayExistsToHasVisitor : public InDepthQueryTreeVisitorWithContext +{ +public: + using Base = InDepthQueryTreeVisitorWithContext; + using Base::Base; + + void visitImpl(QueryTreeNodePtr & node) { - public: - using Base = InDepthQueryTreeVisitorWithContext; - using Base::Base; + if (!getSettings().optimize_rewrite_array_exists_to_has) + return; - void visitImpl(QueryTreeNodePtr & node) + auto * array_exists_function_node = node->as(); + if (!array_exists_function_node || array_exists_function_node->getFunctionName() != "arrayExists") + return; + + auto & array_exists_function_arguments_nodes = array_exists_function_node->getArguments().getNodes(); + if (array_exists_function_arguments_nodes.size() != 2) + return; + + /// lambda function must be like: x -> x = elem + auto * lambda_node = array_exists_function_arguments_nodes[0]->as(); + if (!lambda_node) + return; + + auto & lambda_arguments_nodes = lambda_node->getArguments().getNodes(); + if (lambda_arguments_nodes.size() != 1) + return; + + const auto & lambda_argument_column_node = lambda_arguments_nodes[0]; + if (lambda_argument_column_node->getNodeType() != QueryTreeNodeType::COLUMN) + return; + + auto * filter_node = lambda_node->getExpression()->as(); + if (!filter_node || filter_node->getFunctionName() != "equals") + return; + + const auto & filter_arguments_nodes = filter_node->getArguments().getNodes(); + if (filter_arguments_nodes.size() != 2) + return; + + const auto & filter_lhs_argument_node = filter_arguments_nodes[0]; + auto filter_lhs_argument_node_type = filter_lhs_argument_node->getNodeType(); + + const auto & filter_rhs_argument_node = filter_arguments_nodes[1]; + auto filter_rhs_argument_node_type = filter_rhs_argument_node->getNodeType(); + + QueryTreeNodePtr has_constant_element_argument; + + if (filter_lhs_argument_node_type == QueryTreeNodeType::COLUMN && + filter_rhs_argument_node_type == QueryTreeNodeType::CONSTANT && + filter_lhs_argument_node->isEqual(*lambda_argument_column_node)) { - if (!getSettings().optimize_rewrite_array_exists_to_has) - return; - - auto * function_node = node->as(); - if (!function_node || function_node->getFunctionName() != "arrayExists") - return; - - auto & function_arguments_nodes = function_node->getArguments().getNodes(); - if (function_arguments_nodes.size() != 2) - return; - - /// lambda function must be like: x -> x = elem - auto * lambda_node = function_arguments_nodes[0]->as(); - if (!lambda_node) - return; - - auto & lambda_arguments_nodes = lambda_node->getArguments().getNodes(); - if (lambda_arguments_nodes.size() != 1) - return; - auto * column_node = lambda_arguments_nodes[0]->as(); - - auto * filter_node = lambda_node->getExpression()->as(); - if (!filter_node || filter_node->getFunctionName() != "equals") - return; - - auto filter_arguments_nodes = filter_node->getArguments().getNodes(); - if (filter_arguments_nodes.size() != 2) - return; - - ColumnNode * filter_column_node = nullptr; - if (filter_arguments_nodes[1]->as() && (filter_column_node = filter_arguments_nodes[0]->as()) - && filter_column_node->getColumnName() == column_node->getColumnName()) - { - /// Rewrite arrayExists(x -> x = elem, arr) -> has(arr, elem) - function_arguments_nodes[0] = std::move(function_arguments_nodes[1]); - function_arguments_nodes[1] = std::move(filter_arguments_nodes[1]); - function_node->resolveAsFunction( - FunctionFactory::instance().get("has", getContext())->build(function_node->getArgumentColumns())); - } - else if ( - filter_arguments_nodes[0]->as() && (filter_column_node = filter_arguments_nodes[1]->as()) - && filter_column_node->getColumnName() == column_node->getColumnName()) - { - /// Rewrite arrayExists(x -> elem = x, arr) -> has(arr, elem) - function_arguments_nodes[0] = std::move(function_arguments_nodes[1]); - function_arguments_nodes[1] = std::move(filter_arguments_nodes[0]); - function_node->resolveAsFunction( - FunctionFactory::instance().get("has", getContext())->build(function_node->getArgumentColumns())); - } + /// Rewrite arrayExists(x -> x = elem, arr) -> has(arr, elem) + has_constant_element_argument = filter_rhs_argument_node; } - }; + else if (filter_lhs_argument_node_type == QueryTreeNodeType::CONSTANT && + filter_rhs_argument_node_type == QueryTreeNodeType::COLUMN && + filter_rhs_argument_node->isEqual(*lambda_argument_column_node)) + { + /// Rewrite arrayExists(x -> elem = x, arr) -> has(arr, elem) + has_constant_element_argument = filter_lhs_argument_node; + } + else + { + return; + } + + auto has_function = FunctionFactory::instance().get("has", getContext()); + array_exists_function_arguments_nodes[0] = std::move(array_exists_function_arguments_nodes[1]); + array_exists_function_arguments_nodes[1] = std::move(has_constant_element_argument); + array_exists_function_node->resolveAsFunction(has_function->build(array_exists_function_node->getArgumentColumns())); + } +}; } diff --git a/src/Analyzer/Passes/ArrayExistsToHasPass.h b/src/Analyzer/Passes/ArrayExistsToHasPass.h index 7d9d1cf3d68..8f4623116e3 100644 --- a/src/Analyzer/Passes/ArrayExistsToHasPass.h +++ b/src/Analyzer/Passes/ArrayExistsToHasPass.h @@ -4,8 +4,15 @@ namespace DB { -/// Rewrite possible 'arrayExists(func, arr)' to 'has(arr, elem)' to improve performance -/// arrayExists(x -> x = 1, arr) -> has(arr, 1) + +/** Rewrite possible 'arrayExists(func, arr)' to 'has(arr, elem)' to improve performance. + * + * Example: SELECT arrayExists(x -> x = 1, arr); + * Result: SELECT has(arr, 1); + * + * Example: SELECT arrayExists(x -> 1 = x, arr); + * Result: SELECT has(arr, 1); + */ class RewriteArrayExistsToHasPass final : public IQueryTreePass { public: @@ -15,4 +22,5 @@ public: void run(QueryTreeNodePtr query_tree_node, ContextPtr context) override; }; + } diff --git a/src/Analyzer/Passes/AutoFinalOnQueryPass.cpp b/src/Analyzer/Passes/AutoFinalOnQueryPass.cpp index fdf818681d7..fa5fc0e75a8 100644 --- a/src/Analyzer/Passes/AutoFinalOnQueryPass.cpp +++ b/src/Analyzer/Passes/AutoFinalOnQueryPass.cpp @@ -22,8 +22,7 @@ public: void visitImpl(QueryTreeNodePtr & node) { - const auto & context = getContext(); - if (!context->getSettingsRef().final) + if (!getSettings().final) return; const auto * query_node = node->as(); diff --git a/src/Analyzer/Passes/ShardNumColumnToFunctionPass.h b/src/Analyzer/Passes/ShardNumColumnToFunctionPass.h index 83b974954fa..71a038bcf39 100644 --- a/src/Analyzer/Passes/ShardNumColumnToFunctionPass.h +++ b/src/Analyzer/Passes/ShardNumColumnToFunctionPass.h @@ -6,6 +6,9 @@ namespace DB { /** Rewrite _shard_num column into shardNum() function. + * + * Example: SELECT _shard_num FROM distributed_table; + * Result: SELECT shardNum() FROM distributed_table; */ class ShardNumColumnToFunctionPass final : public IQueryTreePass {