Merge pull request #46740 from ClickHouse/query-tree-visitor

Refactor Query Tree visitor
This commit is contained in:
Dmitry Novik 2023-07-29 15:10:07 +02:00 committed by GitHub
commit 8f4527d9e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 66 additions and 98 deletions

View File

@ -7,6 +7,7 @@
#include <Analyzer/IQueryTreeNode.h>
#include <Analyzer/QueryNode.h>
#include <Analyzer/TableFunctionNode.h>
#include <Analyzer/UnionNode.h>
#include <Interpreters/Context.h>
@ -90,26 +91,25 @@ private:
template <typename Derived>
using ConstInDepthQueryTreeVisitor = InDepthQueryTreeVisitor<Derived, true /*const_visitor*/>;
/** Same as InDepthQueryTreeVisitor and additionally keeps track of current scope context.
/** Same as InDepthQueryTreeVisitor (but has a different interface) and additionally keeps track of current scope context.
* This can be useful if your visitor has special logic that depends on current scope context.
*
* To specify behavior of the visitor you can implement following methods in derived class:
* 1. needChildVisit This methods allows to skip subtree.
* 2. enterImpl This method is called before children are processed.
* 3. leaveImpl This method is called after children are processed.
*/
template <typename Derived, bool const_visitor = false>
class InDepthQueryTreeVisitorWithContext
{
public:
using VisitQueryTreeNodeType = std::conditional_t<const_visitor, const QueryTreeNodePtr, QueryTreeNodePtr>;
using VisitQueryTreeNodeType = QueryTreeNodePtr;
explicit InDepthQueryTreeVisitorWithContext(ContextPtr context, size_t initial_subquery_depth = 0)
: current_context(std::move(context))
, subquery_depth(initial_subquery_depth)
{}
/// Return true if visitor should traverse tree top to bottom, false otherwise
bool shouldTraverseTopToBottom() const
{
return true;
}
/// Return true if visitor should visit child, false otherwise
bool needChildVisit(VisitQueryTreeNodeType & parent [[maybe_unused]], VisitQueryTreeNodeType & child [[maybe_unused]])
{
@ -146,18 +146,16 @@ public:
++subquery_depth;
bool traverse_top_to_bottom = getDerived().shouldTraverseTopToBottom();
if (!traverse_top_to_bottom)
visitChildren(query_tree_node);
getDerived().enterImpl(query_tree_node);
getDerived().visitImpl(query_tree_node);
if (traverse_top_to_bottom)
visitChildren(query_tree_node);
visitChildren(query_tree_node);
getDerived().leaveImpl(query_tree_node);
}
void enterImpl(VisitQueryTreeNodeType & node [[maybe_unused]])
{}
void leaveImpl(VisitQueryTreeNodeType & node [[maybe_unused]])
{}
private:
@ -171,17 +169,31 @@ private:
return *static_cast<Derived *>(this);
}
bool shouldSkipSubtree(
VisitQueryTreeNodeType & parent,
VisitQueryTreeNodeType & child,
size_t subtree_index)
{
bool need_visit_child = getDerived().needChildVisit(parent, child);
if (!need_visit_child)
return true;
if (auto * table_function_node = parent->as<TableFunctionNode>())
{
const auto & unresolved_indexes = table_function_node->getUnresolvedArgumentIndexes();
return std::find(unresolved_indexes.begin(), unresolved_indexes.end(), subtree_index) != unresolved_indexes.end();
}
return false;
}
void visitChildren(VisitQueryTreeNodeType & expression)
{
size_t index = 0;
for (auto & child : expression->getChildren())
{
if (!child)
continue;
bool need_visit_child = getDerived().needChildVisit(expression, child);
if (need_visit_child)
if (child && !shouldSkipSubtree(expression, child, index))
visit(child);
++index;
}
}
@ -189,50 +201,4 @@ private:
size_t subquery_depth = 0;
};
template <typename Derived>
using ConstInDepthQueryTreeVisitorWithContext = InDepthQueryTreeVisitorWithContext<Derived, true /*const_visitor*/>;
/** Visitor that use another visitor to visit node only if condition for visiting node is true.
* For example, your visitor need to visit only query tree nodes or union nodes.
*
* Condition interface:
* struct Condition
* {
* bool operator()(VisitQueryTreeNodeType & node)
* {
* return shouldNestedVisitorVisitNode(node);
* }
* }
*/
template <typename Visitor, typename Condition, bool const_visitor = false>
class InDepthQueryTreeConditionalVisitor : public InDepthQueryTreeVisitor<InDepthQueryTreeConditionalVisitor<Visitor, Condition, const_visitor>, const_visitor>
{
public:
using Base = InDepthQueryTreeVisitor<InDepthQueryTreeConditionalVisitor<Visitor, Condition, const_visitor>, const_visitor>;
using VisitQueryTreeNodeType = typename Base::VisitQueryTreeNodeType;
explicit InDepthQueryTreeConditionalVisitor(Visitor & visitor_, Condition & condition_)
: visitor(visitor_)
, condition(condition_)
{
}
bool shouldTraverseTopToBottom() const
{
return visitor.shouldTraverseTopToBottom();
}
void visitImpl(VisitQueryTreeNodeType & query_tree_node)
{
if (condition(query_tree_node))
visitor.visit(query_tree_node);
}
Visitor & visitor;
Condition & condition;
};
template <typename Visitor, typename Condition>
using ConstInDepthQueryTreeConditionalVisitor = InDepthQueryTreeConditionalVisitor<Visitor, Condition, true /*const_visitor*/>;
}

View File

@ -51,13 +51,7 @@ public:
using Base = InDepthQueryTreeVisitorWithContext<AggregateFunctionsArithmericOperationsVisitor>;
using Base::Base;
/// Traverse tree bottom to top
static bool shouldTraverseTopToBottom()
{
return false;
}
void visitImpl(QueryTreeNodePtr & node)
void leaveImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_arithmetic_operations_in_aggregate_functions)
return;

View File

@ -22,7 +22,7 @@ public:
using Base = InDepthQueryTreeVisitorWithContext<RewriteArrayExistsToHasVisitor>;
using Base::Base;
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_rewrite_array_exists_to_has)
return;

View File

@ -20,7 +20,7 @@ public:
using Base = InDepthQueryTreeVisitorWithContext<AutoFinalOnQueryPassVisitor>;
using Base::Base;
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().final)
return;

View File

@ -50,7 +50,7 @@ public:
&& settings.max_hyperscan_regexp_total_length == 0;
}
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
auto * function_node = node->as<FunctionNode>();
if (!function_node || function_node->getFunctionName() != "or")

View File

@ -688,7 +688,7 @@ public:
using Base = InDepthQueryTreeVisitorWithContext<ConvertQueryToCNFVisitor>;
using Base::Base;
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
auto * query_node = node->as<QueryNode>();
if (!query_node)

View File

@ -22,7 +22,7 @@ public:
using Base = InDepthQueryTreeVisitorWithContext<CountDistinctVisitor>;
using Base::Base;
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().count_distinct_optimization)
return;

View File

@ -193,7 +193,7 @@ public:
return true;
}
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
if (!isEnabled())
return;

View File

@ -29,7 +29,7 @@ public:
using Base = InDepthQueryTreeVisitorWithContext<FunctionToSubcolumnsVisitor>;
using Base::Base;
void visitImpl(QueryTreeNodePtr & node) const
void enterImpl(QueryTreeNodePtr & node) const
{
if (!getSettings().optimize_functions_to_subcolumns)
return;

View File

@ -37,7 +37,7 @@ public:
, names_to_collect(names_to_collect_)
{}
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_syntax_fuse_functions)
return;

View File

@ -46,7 +46,7 @@ public:
{
}
void visitImpl(const QueryTreeNodePtr & node)
void enterImpl(const QueryTreeNodePtr & node)
{
auto * function_node = node->as<FunctionNode>();
if (!function_node || function_node->getFunctionName() != "grouping")

View File

@ -23,7 +23,7 @@ public:
, multi_if_function_ptr(std::move(multi_if_function_ptr_))
{}
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_if_chain_to_multiif)
return;

View File

@ -113,7 +113,7 @@ public:
using Base = InDepthQueryTreeVisitorWithContext<ConvertStringsToEnumVisitor>;
using Base::Base;
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_if_transform_strings_to_enum)
return;

View File

@ -19,7 +19,7 @@ public:
: Base(std::move(context))
{}
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
auto * function_node = node->as<FunctionNode>();

View File

@ -21,7 +21,7 @@ public:
, if_function_ptr(std::move(if_function_ptr_))
{}
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_multiif_to_if)
return;

View File

@ -20,7 +20,7 @@ public:
using Base = InDepthQueryTreeVisitorWithContext<NormalizeCountVariantsVisitor>;
using Base::Base;
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_normalize_count_variants)
return;

View File

@ -26,7 +26,7 @@ public:
return !child->as<FunctionNode>();
}
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_group_by_function_keys)
return;

View File

@ -28,7 +28,7 @@ public:
return true;
}
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_redundant_functions_in_order_by)
return;

View File

@ -6451,7 +6451,7 @@ void QueryAnalyzer::resolveTableFunction(QueryTreeNodePtr & table_function_node,
table_function_ptr->parseArguments(table_function_ast, scope_context);
auto table_function_storage = scope_context->getQueryContext()->executeTableFunction(table_function_ast, table_function_ptr);
table_function_node_typed.resolve(std::move(table_function_ptr), std::move(table_function_storage), scope_context);
table_function_node_typed.resolve(std::move(table_function_ptr), std::move(table_function_storage), scope_context, std::move(skip_analysis_arguments_indexes));
}
/// Resolve array join node in scope

View File

@ -26,7 +26,7 @@ public:
using Base = InDepthQueryTreeVisitorWithContext<RewriteAggregateFunctionWithIfVisitor>;
using Base::Base;
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_rewrite_aggregate_function_with_if)
return;

View File

@ -24,7 +24,7 @@ public:
using Base = InDepthQueryTreeVisitorWithContext<ShardNumColumnToFunctionVisitor>;
using Base::Base;
void visitImpl(QueryTreeNodePtr & node) const
void enterImpl(QueryTreeNodePtr & node) const
{
auto * column_node = node->as<ColumnNode>();
if (!column_node)

View File

@ -26,7 +26,7 @@ public:
using Base = InDepthQueryTreeVisitorWithContext<SumIfToCountIfVisitor>;
using Base::Base;
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_rewrite_sum_if_to_count_if)
return;

View File

@ -31,7 +31,7 @@ public:
using Base = InDepthQueryTreeVisitorWithContext<UniqInjectiveFunctionsEliminationVisitor>;
using Base::Base;
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_injective_functions_inside_uniq)
return;

View File

@ -27,12 +27,13 @@ TableFunctionNode::TableFunctionNode(String table_function_name_)
children[arguments_child_index] = std::make_shared<ListNode>();
}
void TableFunctionNode::resolve(TableFunctionPtr table_function_value, StoragePtr storage_value, ContextPtr context)
void TableFunctionNode::resolve(TableFunctionPtr table_function_value, StoragePtr storage_value, ContextPtr context, std::vector<size_t> unresolved_arguments_indexes_)
{
table_function = std::move(table_function_value);
storage = std::move(storage_value);
storage_id = storage->getStorageID();
storage_snapshot = storage->getStorageSnapshot(storage->getInMemoryMetadataPtr(), context);
unresolved_arguments_indexes = std::move(unresolved_arguments_indexes_);
}
const StorageID & TableFunctionNode::getStorageID() const
@ -132,6 +133,7 @@ QueryTreeNodePtr TableFunctionNode::cloneImpl() const
result->storage_snapshot = storage_snapshot;
result->table_expression_modifiers = table_expression_modifiers;
result->settings_changes = settings_changes;
result->unresolved_arguments_indexes = unresolved_arguments_indexes;
return result;
}

View File

@ -98,7 +98,7 @@ public:
}
/// Resolve table function with table function, storage and context
void resolve(TableFunctionPtr table_function_value, StoragePtr storage_value, ContextPtr context);
void resolve(TableFunctionPtr table_function_value, StoragePtr storage_value, ContextPtr context, std::vector<size_t> unresolved_arguments_indexes_);
/// Get storage id, throws exception if function node is not resolved
const StorageID & getStorageID() const;
@ -106,6 +106,11 @@ public:
/// Get storage snapshot, throws exception if function node is not resolved
const StorageSnapshotPtr & getStorageSnapshot() const;
const std::vector<size_t> & getUnresolvedArgumentIndexes() const
{
return unresolved_arguments_indexes;
}
/// Return true if table function node has table expression modifiers, false otherwise
bool hasTableExpressionModifiers() const
{
@ -164,6 +169,7 @@ private:
StoragePtr storage;
StorageID storage_id;
StorageSnapshotPtr storage_snapshot;
std::vector<size_t> unresolved_arguments_indexes;
std::optional<TableExpressionModifiers> table_expression_modifiers;
SettingsChanges settings_changes;

View File

@ -130,7 +130,7 @@ public:
return true;
}
void visitImpl(QueryTreeNodePtr & node)
void enterImpl(QueryTreeNodePtr & node)
{
auto * function_node = node->as<FunctionNode>();
auto * join_node = node->as<JoinNode>();