diff --git a/src/Analyzer/InDepthQueryTreeVisitor.h b/src/Analyzer/InDepthQueryTreeVisitor.h index 1cc48fb1e53..59ee57996c4 100644 --- a/src/Analyzer/InDepthQueryTreeVisitor.h +++ b/src/Analyzer/InDepthQueryTreeVisitor.h @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -90,26 +91,25 @@ private: template using ConstInDepthQueryTreeVisitor = InDepthQueryTreeVisitor; -/** Same as InDepthQueryTreeVisitor and additionally keeps track of current scope context. +/** Same as InDepthQueryTreeVisitor (but has a different interface) and additionally keeps track of current scope context. * This can be useful if your visitor has special logic that depends on current scope context. + * + * To specify behavior of the visitor you can implement following methods in derived class: + * 1. needChildVisit – This methods allows to skip subtree. + * 2. enterImpl – This method is called before children are processed. + * 3. leaveImpl – This method is called after children are processed. */ template class InDepthQueryTreeVisitorWithContext { public: - using VisitQueryTreeNodeType = std::conditional_t; + using VisitQueryTreeNodeType = QueryTreeNodePtr; explicit InDepthQueryTreeVisitorWithContext(ContextPtr context, size_t initial_subquery_depth = 0) : current_context(std::move(context)) , subquery_depth(initial_subquery_depth) {} - /// Return true if visitor should traverse tree top to bottom, false otherwise - bool shouldTraverseTopToBottom() const - { - return true; - } - /// Return true if visitor should visit child, false otherwise bool needChildVisit(VisitQueryTreeNodeType & parent [[maybe_unused]], VisitQueryTreeNodeType & child [[maybe_unused]]) { @@ -146,18 +146,16 @@ public: ++subquery_depth; - bool traverse_top_to_bottom = getDerived().shouldTraverseTopToBottom(); - if (!traverse_top_to_bottom) - visitChildren(query_tree_node); + getDerived().enterImpl(query_tree_node); - getDerived().visitImpl(query_tree_node); - - if (traverse_top_to_bottom) - visitChildren(query_tree_node); + visitChildren(query_tree_node); getDerived().leaveImpl(query_tree_node); } + void enterImpl(VisitQueryTreeNodeType & node [[maybe_unused]]) + {} + void leaveImpl(VisitQueryTreeNodeType & node [[maybe_unused]]) {} private: @@ -171,17 +169,31 @@ private: return *static_cast(this); } + bool shouldSkipSubtree( + VisitQueryTreeNodeType & parent, + VisitQueryTreeNodeType & child, + size_t subtree_index) + { + bool need_visit_child = getDerived().needChildVisit(parent, child); + if (!need_visit_child) + return true; + + if (auto * table_function_node = parent->as()) + { + const auto & unresolved_indexes = table_function_node->getUnresolvedArgumentIndexes(); + return std::find(unresolved_indexes.begin(), unresolved_indexes.end(), subtree_index) != unresolved_indexes.end(); + } + return false; + } + void visitChildren(VisitQueryTreeNodeType & expression) { + size_t index = 0; for (auto & child : expression->getChildren()) { - if (!child) - continue; - - bool need_visit_child = getDerived().needChildVisit(expression, child); - - if (need_visit_child) + if (child && !shouldSkipSubtree(expression, child, index)) visit(child); + ++index; } } @@ -189,50 +201,4 @@ private: size_t subquery_depth = 0; }; -template -using ConstInDepthQueryTreeVisitorWithContext = InDepthQueryTreeVisitorWithContext; - -/** Visitor that use another visitor to visit node only if condition for visiting node is true. - * For example, your visitor need to visit only query tree nodes or union nodes. - * - * Condition interface: - * struct Condition - * { - * bool operator()(VisitQueryTreeNodeType & node) - * { - * return shouldNestedVisitorVisitNode(node); - * } - * } - */ -template -class InDepthQueryTreeConditionalVisitor : public InDepthQueryTreeVisitor, const_visitor> -{ -public: - using Base = InDepthQueryTreeVisitor, const_visitor>; - using VisitQueryTreeNodeType = typename Base::VisitQueryTreeNodeType; - - explicit InDepthQueryTreeConditionalVisitor(Visitor & visitor_, Condition & condition_) - : visitor(visitor_) - , condition(condition_) - { - } - - bool shouldTraverseTopToBottom() const - { - return visitor.shouldTraverseTopToBottom(); - } - - void visitImpl(VisitQueryTreeNodeType & query_tree_node) - { - if (condition(query_tree_node)) - visitor.visit(query_tree_node); - } - - Visitor & visitor; - Condition & condition; -}; - -template -using ConstInDepthQueryTreeConditionalVisitor = InDepthQueryTreeConditionalVisitor; - } diff --git a/src/Analyzer/Passes/AggregateFunctionsArithmericOperationsPass.cpp b/src/Analyzer/Passes/AggregateFunctionsArithmericOperationsPass.cpp index 1476a66c892..3615a632374 100644 --- a/src/Analyzer/Passes/AggregateFunctionsArithmericOperationsPass.cpp +++ b/src/Analyzer/Passes/AggregateFunctionsArithmericOperationsPass.cpp @@ -51,13 +51,7 @@ public: using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; - /// Traverse tree bottom to top - static bool shouldTraverseTopToBottom() - { - return false; - } - - void visitImpl(QueryTreeNodePtr & node) + void leaveImpl(QueryTreeNodePtr & node) { if (!getSettings().optimize_arithmetic_operations_in_aggregate_functions) return; diff --git a/src/Analyzer/Passes/ArrayExistsToHasPass.cpp b/src/Analyzer/Passes/ArrayExistsToHasPass.cpp index c0f958588f1..a95bcea4fac 100644 --- a/src/Analyzer/Passes/ArrayExistsToHasPass.cpp +++ b/src/Analyzer/Passes/ArrayExistsToHasPass.cpp @@ -22,7 +22,7 @@ public: using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { if (!getSettings().optimize_rewrite_array_exists_to_has) return; diff --git a/src/Analyzer/Passes/AutoFinalOnQueryPass.cpp b/src/Analyzer/Passes/AutoFinalOnQueryPass.cpp index 15326ca1dc8..2c89ec9dc20 100644 --- a/src/Analyzer/Passes/AutoFinalOnQueryPass.cpp +++ b/src/Analyzer/Passes/AutoFinalOnQueryPass.cpp @@ -20,7 +20,7 @@ public: using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { if (!getSettings().final) return; diff --git a/src/Analyzer/Passes/ConvertOrLikeChainPass.cpp b/src/Analyzer/Passes/ConvertOrLikeChainPass.cpp index 7d7362fb742..1fada88a21c 100644 --- a/src/Analyzer/Passes/ConvertOrLikeChainPass.cpp +++ b/src/Analyzer/Passes/ConvertOrLikeChainPass.cpp @@ -50,7 +50,7 @@ public: && settings.max_hyperscan_regexp_total_length == 0; } - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { auto * function_node = node->as(); if (!function_node || function_node->getFunctionName() != "or") diff --git a/src/Analyzer/Passes/ConvertQueryToCNFPass.cpp b/src/Analyzer/Passes/ConvertQueryToCNFPass.cpp index 4d32c96b845..724448ad742 100644 --- a/src/Analyzer/Passes/ConvertQueryToCNFPass.cpp +++ b/src/Analyzer/Passes/ConvertQueryToCNFPass.cpp @@ -688,7 +688,7 @@ public: using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { auto * query_node = node->as(); if (!query_node) diff --git a/src/Analyzer/Passes/CountDistinctPass.cpp b/src/Analyzer/Passes/CountDistinctPass.cpp index 945295f5cbc..dc58747221e 100644 --- a/src/Analyzer/Passes/CountDistinctPass.cpp +++ b/src/Analyzer/Passes/CountDistinctPass.cpp @@ -22,7 +22,7 @@ public: using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { if (!getSettings().count_distinct_optimization) return; diff --git a/src/Analyzer/Passes/CrossToInnerJoinPass.cpp b/src/Analyzer/Passes/CrossToInnerJoinPass.cpp index d4877d23f28..b5ece1a4c49 100644 --- a/src/Analyzer/Passes/CrossToInnerJoinPass.cpp +++ b/src/Analyzer/Passes/CrossToInnerJoinPass.cpp @@ -193,7 +193,7 @@ public: return true; } - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { if (!isEnabled()) return; diff --git a/src/Analyzer/Passes/FunctionToSubcolumnsPass.cpp b/src/Analyzer/Passes/FunctionToSubcolumnsPass.cpp index 696483862e0..cd635f87e0e 100644 --- a/src/Analyzer/Passes/FunctionToSubcolumnsPass.cpp +++ b/src/Analyzer/Passes/FunctionToSubcolumnsPass.cpp @@ -29,7 +29,7 @@ public: using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; - void visitImpl(QueryTreeNodePtr & node) const + void enterImpl(QueryTreeNodePtr & node) const { if (!getSettings().optimize_functions_to_subcolumns) return; diff --git a/src/Analyzer/Passes/FuseFunctionsPass.cpp b/src/Analyzer/Passes/FuseFunctionsPass.cpp index 14082697955..2cb7afa4ad6 100644 --- a/src/Analyzer/Passes/FuseFunctionsPass.cpp +++ b/src/Analyzer/Passes/FuseFunctionsPass.cpp @@ -37,7 +37,7 @@ public: , names_to_collect(names_to_collect_) {} - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { if (!getSettings().optimize_syntax_fuse_functions) return; diff --git a/src/Analyzer/Passes/GroupingFunctionsResolvePass.cpp b/src/Analyzer/Passes/GroupingFunctionsResolvePass.cpp index 0cf5310a3ad..577bca8d1ae 100644 --- a/src/Analyzer/Passes/GroupingFunctionsResolvePass.cpp +++ b/src/Analyzer/Passes/GroupingFunctionsResolvePass.cpp @@ -46,7 +46,7 @@ public: { } - void visitImpl(const QueryTreeNodePtr & node) + void enterImpl(const QueryTreeNodePtr & node) { auto * function_node = node->as(); if (!function_node || function_node->getFunctionName() != "grouping") diff --git a/src/Analyzer/Passes/IfChainToMultiIfPass.cpp b/src/Analyzer/Passes/IfChainToMultiIfPass.cpp index 1f97e012331..b0018d474d5 100644 --- a/src/Analyzer/Passes/IfChainToMultiIfPass.cpp +++ b/src/Analyzer/Passes/IfChainToMultiIfPass.cpp @@ -23,7 +23,7 @@ public: , multi_if_function_ptr(std::move(multi_if_function_ptr_)) {} - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { if (!getSettings().optimize_if_chain_to_multiif) return; diff --git a/src/Analyzer/Passes/IfTransformStringsToEnumPass.cpp b/src/Analyzer/Passes/IfTransformStringsToEnumPass.cpp index 562aff4cf05..901867b8889 100644 --- a/src/Analyzer/Passes/IfTransformStringsToEnumPass.cpp +++ b/src/Analyzer/Passes/IfTransformStringsToEnumPass.cpp @@ -113,7 +113,7 @@ public: using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { if (!getSettings().optimize_if_transform_strings_to_enum) return; diff --git a/src/Analyzer/Passes/LogicalExpressionOptimizerPass.cpp b/src/Analyzer/Passes/LogicalExpressionOptimizerPass.cpp index 13f8025f5ea..46056aeaf6f 100644 --- a/src/Analyzer/Passes/LogicalExpressionOptimizerPass.cpp +++ b/src/Analyzer/Passes/LogicalExpressionOptimizerPass.cpp @@ -19,7 +19,7 @@ public: : Base(std::move(context)) {} - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { auto * function_node = node->as(); diff --git a/src/Analyzer/Passes/MultiIfToIfPass.cpp b/src/Analyzer/Passes/MultiIfToIfPass.cpp index 4672351bcfb..85dd33af8bb 100644 --- a/src/Analyzer/Passes/MultiIfToIfPass.cpp +++ b/src/Analyzer/Passes/MultiIfToIfPass.cpp @@ -21,7 +21,7 @@ public: , if_function_ptr(std::move(if_function_ptr_)) {} - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { if (!getSettings().optimize_multiif_to_if) return; diff --git a/src/Analyzer/Passes/NormalizeCountVariantsPass.cpp b/src/Analyzer/Passes/NormalizeCountVariantsPass.cpp index d36be98751c..c85b863a203 100644 --- a/src/Analyzer/Passes/NormalizeCountVariantsPass.cpp +++ b/src/Analyzer/Passes/NormalizeCountVariantsPass.cpp @@ -20,7 +20,7 @@ public: using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { if (!getSettings().optimize_normalize_count_variants) return; diff --git a/src/Analyzer/Passes/OptimizeGroupByFunctionKeysPass.cpp b/src/Analyzer/Passes/OptimizeGroupByFunctionKeysPass.cpp index 5ed52f1210b..2e3f207fdeb 100644 --- a/src/Analyzer/Passes/OptimizeGroupByFunctionKeysPass.cpp +++ b/src/Analyzer/Passes/OptimizeGroupByFunctionKeysPass.cpp @@ -26,7 +26,7 @@ public: return !child->as(); } - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { if (!getSettings().optimize_group_by_function_keys) return; diff --git a/src/Analyzer/Passes/OptimizeRedundantFunctionsInOrderByPass.cpp b/src/Analyzer/Passes/OptimizeRedundantFunctionsInOrderByPass.cpp index c6d312d0ecf..875d0c8b5fb 100644 --- a/src/Analyzer/Passes/OptimizeRedundantFunctionsInOrderByPass.cpp +++ b/src/Analyzer/Passes/OptimizeRedundantFunctionsInOrderByPass.cpp @@ -28,7 +28,7 @@ public: return true; } - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { if (!getSettings().optimize_redundant_functions_in_order_by) return; diff --git a/src/Analyzer/Passes/QueryAnalysisPass.cpp b/src/Analyzer/Passes/QueryAnalysisPass.cpp index 2888b7a5bec..9d1904d4351 100644 --- a/src/Analyzer/Passes/QueryAnalysisPass.cpp +++ b/src/Analyzer/Passes/QueryAnalysisPass.cpp @@ -6451,7 +6451,7 @@ void QueryAnalyzer::resolveTableFunction(QueryTreeNodePtr & table_function_node, table_function_ptr->parseArguments(table_function_ast, scope_context); auto table_function_storage = scope_context->getQueryContext()->executeTableFunction(table_function_ast, table_function_ptr); - table_function_node_typed.resolve(std::move(table_function_ptr), std::move(table_function_storage), scope_context); + table_function_node_typed.resolve(std::move(table_function_ptr), std::move(table_function_storage), scope_context, std::move(skip_analysis_arguments_indexes)); } /// Resolve array join node in scope diff --git a/src/Analyzer/Passes/RewriteAggregateFunctionWithIfPass.cpp b/src/Analyzer/Passes/RewriteAggregateFunctionWithIfPass.cpp index de264948d4c..38f2fbfa274 100644 --- a/src/Analyzer/Passes/RewriteAggregateFunctionWithIfPass.cpp +++ b/src/Analyzer/Passes/RewriteAggregateFunctionWithIfPass.cpp @@ -26,7 +26,7 @@ public: using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { if (!getSettings().optimize_rewrite_aggregate_function_with_if) return; diff --git a/src/Analyzer/Passes/ShardNumColumnToFunctionPass.cpp b/src/Analyzer/Passes/ShardNumColumnToFunctionPass.cpp index b28816e8ff3..52c30b7b35d 100644 --- a/src/Analyzer/Passes/ShardNumColumnToFunctionPass.cpp +++ b/src/Analyzer/Passes/ShardNumColumnToFunctionPass.cpp @@ -24,7 +24,7 @@ public: using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; - void visitImpl(QueryTreeNodePtr & node) const + void enterImpl(QueryTreeNodePtr & node) const { auto * column_node = node->as(); if (!column_node) diff --git a/src/Analyzer/Passes/SumIfToCountIfPass.cpp b/src/Analyzer/Passes/SumIfToCountIfPass.cpp index d55af278152..cff9ba1111c 100644 --- a/src/Analyzer/Passes/SumIfToCountIfPass.cpp +++ b/src/Analyzer/Passes/SumIfToCountIfPass.cpp @@ -26,7 +26,7 @@ public: using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { if (!getSettings().optimize_rewrite_sum_if_to_count_if) return; diff --git a/src/Analyzer/Passes/UniqInjectiveFunctionsEliminationPass.cpp b/src/Analyzer/Passes/UniqInjectiveFunctionsEliminationPass.cpp index 5c4484457e8..179bd1c38e4 100644 --- a/src/Analyzer/Passes/UniqInjectiveFunctionsEliminationPass.cpp +++ b/src/Analyzer/Passes/UniqInjectiveFunctionsEliminationPass.cpp @@ -31,7 +31,7 @@ public: using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { if (!getSettings().optimize_injective_functions_inside_uniq) return; diff --git a/src/Analyzer/TableFunctionNode.cpp b/src/Analyzer/TableFunctionNode.cpp index c130503d660..e5158a06373 100644 --- a/src/Analyzer/TableFunctionNode.cpp +++ b/src/Analyzer/TableFunctionNode.cpp @@ -27,12 +27,13 @@ TableFunctionNode::TableFunctionNode(String table_function_name_) children[arguments_child_index] = std::make_shared(); } -void TableFunctionNode::resolve(TableFunctionPtr table_function_value, StoragePtr storage_value, ContextPtr context) +void TableFunctionNode::resolve(TableFunctionPtr table_function_value, StoragePtr storage_value, ContextPtr context, std::vector unresolved_arguments_indexes_) { table_function = std::move(table_function_value); storage = std::move(storage_value); storage_id = storage->getStorageID(); storage_snapshot = storage->getStorageSnapshot(storage->getInMemoryMetadataPtr(), context); + unresolved_arguments_indexes = std::move(unresolved_arguments_indexes_); } const StorageID & TableFunctionNode::getStorageID() const @@ -132,6 +133,7 @@ QueryTreeNodePtr TableFunctionNode::cloneImpl() const result->storage_snapshot = storage_snapshot; result->table_expression_modifiers = table_expression_modifiers; result->settings_changes = settings_changes; + result->unresolved_arguments_indexes = unresolved_arguments_indexes; return result; } diff --git a/src/Analyzer/TableFunctionNode.h b/src/Analyzer/TableFunctionNode.h index 7786ba62205..69237ac8416 100644 --- a/src/Analyzer/TableFunctionNode.h +++ b/src/Analyzer/TableFunctionNode.h @@ -98,7 +98,7 @@ public: } /// Resolve table function with table function, storage and context - void resolve(TableFunctionPtr table_function_value, StoragePtr storage_value, ContextPtr context); + void resolve(TableFunctionPtr table_function_value, StoragePtr storage_value, ContextPtr context, std::vector unresolved_arguments_indexes_); /// Get storage id, throws exception if function node is not resolved const StorageID & getStorageID() const; @@ -106,6 +106,11 @@ public: /// Get storage snapshot, throws exception if function node is not resolved const StorageSnapshotPtr & getStorageSnapshot() const; + const std::vector & getUnresolvedArgumentIndexes() const + { + return unresolved_arguments_indexes; + } + /// Return true if table function node has table expression modifiers, false otherwise bool hasTableExpressionModifiers() const { @@ -164,6 +169,7 @@ private: StoragePtr storage; StorageID storage_id; StorageSnapshotPtr storage_snapshot; + std::vector unresolved_arguments_indexes; std::optional table_expression_modifiers; SettingsChanges settings_changes; diff --git a/src/Storages/buildQueryTreeForShard.cpp b/src/Storages/buildQueryTreeForShard.cpp index 1ee7d747fcc..9929b5bb39b 100644 --- a/src/Storages/buildQueryTreeForShard.cpp +++ b/src/Storages/buildQueryTreeForShard.cpp @@ -130,7 +130,7 @@ public: return true; } - void visitImpl(QueryTreeNodePtr & node) + void enterImpl(QueryTreeNodePtr & node) { auto * function_node = node->as(); auto * join_node = node->as();