From 4071155043f8a8adcba13c622e39087aa1701620 Mon Sep 17 00:00:00 2001 From: chertus Date: Tue, 16 Oct 2018 15:34:20 +0300 Subject: [PATCH] better ActionsVisitor extraction [CLICKHOUSE-3996] --- dbms/src/Interpreters/ActionsVisitor.cpp | 36 ++++++-- dbms/src/Interpreters/ActionsVisitor.h | 31 +++---- dbms/src/Interpreters/ExpressionAnalyzer.cpp | 97 +++++++++----------- dbms/src/Interpreters/ExpressionAnalyzer.h | 4 +- 4 files changed, 86 insertions(+), 82 deletions(-) diff --git a/dbms/src/Interpreters/ActionsVisitor.cpp b/dbms/src/Interpreters/ActionsVisitor.cpp index bde379fa69b..29d8f190fbf 100644 --- a/dbms/src/Interpreters/ActionsVisitor.cpp +++ b/dbms/src/Interpreters/ActionsVisitor.cpp @@ -220,7 +220,31 @@ const Block & ScopeStack::getSampleBlock() const } -void ActionsVisitor::visit(const ASTPtr & ast, ScopeStack & actions_stack, ProjectionManipulatorPtr projection_manipulator) +ActionsVisitor::ActionsVisitor( + const Context & context_, SizeLimits set_size_limit_, bool is_conditional_tree, size_t subquery_depth_, + const NamesAndTypesList & source_columns_, const ExpressionActionsPtr & actions, + PreparedSets & prepared_sets_, SubqueriesForSets & subqueries_for_sets_, + bool no_subqueries_, bool only_consts_, bool no_storage_or_local_, std::ostream * ostr_) +: context(context_), + set_size_limit(set_size_limit_), + subquery_depth(subquery_depth_), + source_columns(source_columns_), + prepared_sets(prepared_sets_), + subqueries_for_sets(subqueries_for_sets_), + no_subqueries(no_subqueries_), + only_consts(only_consts_), + no_storage_or_local(no_storage_or_local_), + visit_depth(0), + ostr(ostr_), + actions_stack(actions, context) +{ + if (is_conditional_tree) + projection_manipulator = std::make_shared(actions_stack, context); + else + projection_manipulator = std::make_shared(actions_stack); +} + +void ActionsVisitor::visit(const ASTPtr & ast) { DumpASTNode dump(*ast, ostr, visit_depth, "getActions"); @@ -267,7 +291,7 @@ void ActionsVisitor::visit(const ASTPtr & ast, ScopeStack & actions_stack, Proje throw Exception("arrayJoin requires exactly 1 argument", ErrorCodes::TYPE_MISMATCH); ASTPtr arg = node->arguments->children.at(0); - visit(arg, actions_stack, projection_manipulator); + visit(arg); if (!only_consts) { String result_name = projection_manipulator->getColumnName(getColumnName()); @@ -283,7 +307,7 @@ void ActionsVisitor::visit(const ASTPtr & ast, ScopeStack & actions_stack, Proje if (functionIsInOrGlobalInOperator(node->name)) { /// Let's find the type of the first argument (then getActionsImpl will be called again and will not affect anything). - visit(node->arguments->children.at(0), actions_stack, projection_manipulator); + visit(node->arguments->children.at(0)); if (!no_subqueries) { @@ -387,7 +411,7 @@ void ActionsVisitor::visit(const ASTPtr & ast, ScopeStack & actions_stack, Proje { /// If the argument is not a lambda expression, call it recursively and find out its type. projection_action->preArgumentAction(); - visit(child, actions_stack, projection_manipulator); + visit(child); std::string name = projection_manipulator->getColumnName(child_column_name); projection_action->postArgumentAction(child_column_name); if (actions_stack.getSampleBlock().has(name)) @@ -442,7 +466,7 @@ void ActionsVisitor::visit(const ASTPtr & ast, ScopeStack & actions_stack, Proje projection_action->preArgumentAction(); actions_stack.pushLevel(lambda_arguments); - visit(lambda->arguments->children.at(1), actions_stack, projection_manipulator); + visit(lambda->arguments->children.at(1)); ExpressionActionsPtr lambda_actions = actions_stack.popLevel(); String result_name = projection_manipulator->getColumnName(lambda->arguments->children.at(1)->getColumnName()); @@ -515,7 +539,7 @@ void ActionsVisitor::visit(const ASTPtr & ast, ScopeStack & actions_stack, Proje /// Do not go to FROM, JOIN, UNION. if (!typeid_cast(child.get()) && !typeid_cast(child.get())) - visit(child, actions_stack, projection_manipulator); + visit(child); } } } diff --git a/dbms/src/Interpreters/ActionsVisitor.h b/dbms/src/Interpreters/ActionsVisitor.h index 78b91b64afc..41560c55c7c 100644 --- a/dbms/src/Interpreters/ActionsVisitor.h +++ b/dbms/src/Interpreters/ActionsVisitor.h @@ -8,12 +8,11 @@ namespace DB class Context; class ASTFunction; +struct ProjectionManipulatorBase; class ExpressionActions; using ExpressionActionsPtr = std::shared_ptr; -struct ProjectionManipulatorBase; -using ProjectionManipulatorPtr = std::shared_ptr; class Set; using SetPtr = std::shared_ptr; @@ -87,27 +86,19 @@ struct ScopeStack }; -/// TODO: There sould be some description, but... +/// Collect ExpressionAction from AST. Returns PreparedSets and SubqueriesForSets too. +/// After AST is visited source ExpressionActions should be updated with popActionsLevel() method. class ActionsVisitor { public: - ActionsVisitor(const Context & context_, SizeLimits set_size_limit_, size_t subquery_depth_, - const NamesAndTypesList & source_columns_, PreparedSets & prepared_sets_, SubqueriesForSets & subqueries_for_sets_, - bool no_subqueries_, bool only_consts_, bool no_storage_or_local_, std::ostream * ostr_ = nullptr) - : context(context_), - set_size_limit(set_size_limit_), - subquery_depth(subquery_depth_), - source_columns(source_columns_), - prepared_sets(prepared_sets_), - subqueries_for_sets(subqueries_for_sets_), - no_subqueries(no_subqueries_), - only_consts(only_consts_), - no_storage_or_local(no_storage_or_local_), - visit_depth(0), - ostr(ostr_) - {} + ActionsVisitor(const Context & context_, SizeLimits set_size_limit_, bool is_conditional_tree, size_t subquery_depth_, + const NamesAndTypesList & source_columns_, const ExpressionActionsPtr & actions, + PreparedSets & prepared_sets_, SubqueriesForSets & subqueries_for_sets_, + bool no_subqueries_, bool only_consts_, bool no_storage_or_local_, std::ostream * ostr_ = nullptr); - void visit(const ASTPtr & ast, ScopeStack & actions_stack, ProjectionManipulatorPtr projection_manipulator); + void visit(const ASTPtr & ast); + + ExpressionActionsPtr popActionsLevel() { return actions_stack.popLevel(); } private: const Context & context; @@ -121,6 +112,8 @@ private: const bool no_storage_or_local; mutable size_t visit_depth; std::ostream * ostr; + ScopeStack actions_stack; + std::shared_ptr projection_manipulator; void makeSet(const ASTFunction * node, const Block & sample_block); }; diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.cpp b/dbms/src/Interpreters/ExpressionAnalyzer.cpp index 15938a3f52a..d04f251cb43 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.cpp +++ b/dbms/src/Interpreters/ExpressionAnalyzer.cpp @@ -29,7 +29,6 @@ #include #include #include -#include #include #include @@ -439,7 +438,7 @@ void ExpressionAnalyzer::analyzeAggregation() if (select_query && select_query->array_join_expression_list()) { - getRootActions(select_query->array_join_expression_list(), true, false, temp_actions); + getRootActions(select_query->array_join_expression_list(), true, temp_actions); addMultipleArrayJoinAction(temp_actions); array_join_columns = temp_actions->getSampleBlock().getNamesAndTypesList(); } @@ -451,10 +450,10 @@ void ExpressionAnalyzer::analyzeAggregation() { const auto table_join = static_cast(*join->table_join); if (table_join.using_expression_list) - getRootActions(table_join.using_expression_list, true, false, temp_actions); + getRootActions(table_join.using_expression_list, true, temp_actions); if (table_join.on_expression) for (const auto & key_ast : analyzed_join.key_asts_left) - getRootActions(key_ast, true, false, temp_actions); + getRootActions(key_ast, true, temp_actions); addJoinAction(temp_actions, true); } @@ -474,7 +473,7 @@ void ExpressionAnalyzer::analyzeAggregation() for (ssize_t i = 0; i < ssize_t(group_asts.size()); ++i) { ssize_t size = group_asts.size(); - getRootActions(group_asts[i], true, false, temp_actions); + getRootActions(group_asts[i], true, temp_actions); const auto & column_name = group_asts[i]->getColumnName(); const auto & block = temp_actions->getSampleBlock(); @@ -1053,7 +1052,7 @@ void ExpressionAnalyzer::makeSetsForIndexImpl(const ASTPtr & node, const Block & for (const auto & joined_column : analyzed_join.columns_added_by_join) temp_columns.push_back(joined_column.name_and_type); ExpressionActionsPtr temp_actions = std::make_shared(temp_columns, context); - getRootActions(func->arguments->children.at(0), true, false, temp_actions); + getRootActions(func->arguments->children.at(0), true, temp_actions); Block sample_block_with_calculated_columns = temp_actions->getSampleBlock(); if (sample_block_with_calculated_columns.has(args.children.at(0)->getColumnName())) @@ -1065,25 +1064,6 @@ void ExpressionAnalyzer::makeSetsForIndexImpl(const ASTPtr & node, const Block & } -void ExpressionAnalyzer::getRootActions(const ASTPtr & ast, bool no_subqueries, bool only_consts, ExpressionActionsPtr & actions) -{ - ScopeStack scopes(actions, context); - - ProjectionManipulatorPtr projection_manipulator; - if (!isThereArrayJoin(ast) && settings.enable_conditional_computation && !only_consts) - projection_manipulator = std::make_shared(scopes, context); - else - projection_manipulator = std::make_shared(scopes); - - LogAST log; - ActionsVisitor actions_visitor(context, getSetSizeLimits(settings), subquery_depth, - source_columns, prepared_sets, subqueries_for_sets, - no_subqueries, only_consts, noStorageOrLocal(), log.stream()); - actions_visitor.visit(ast, scopes, projection_manipulator); - - actions = scopes.popLevel(); -} - void ExpressionAnalyzer::getArrayJoinedColumns() { if (select_query && select_query->array_join_expression_list()) @@ -1239,31 +1219,39 @@ bool ExpressionAnalyzer::isThereArrayJoin(const ASTPtr & ast) } } -void ExpressionAnalyzer::getActionsFromJoinKeys(const ASTTableJoin & table_join, bool no_subqueries, bool only_consts, - ExpressionActionsPtr & actions) -{ - ScopeStack scopes(actions, context); - ProjectionManipulatorPtr projection_manipulator; - if (!isThereArrayJoin(query) && settings.enable_conditional_computation && !only_consts) - projection_manipulator = std::make_shared(scopes, context); - else - projection_manipulator = std::make_shared(scopes); +void ExpressionAnalyzer::getRootActions(const ASTPtr & ast, bool no_subqueries, ExpressionActionsPtr & actions, bool only_consts) +{ + bool is_conditional_tree = !isThereArrayJoin(ast) && settings.enable_conditional_computation && !only_consts; LogAST log; - ActionsVisitor actions_visitor(context, getSetSizeLimits(settings), subquery_depth, - source_columns, prepared_sets, subqueries_for_sets, + ActionsVisitor actions_visitor(context, getSetSizeLimits(settings), is_conditional_tree, subquery_depth, + source_columns, actions, prepared_sets, subqueries_for_sets, + no_subqueries, only_consts, noStorageOrLocal(), log.stream()); + actions_visitor.visit(ast); + actions = actions_visitor.popActionsLevel(); +} + + +void ExpressionAnalyzer::getActionsFromJoinKeys(const ASTTableJoin & table_join, bool no_subqueries, ExpressionActionsPtr & actions) +{ + bool only_consts = false; + bool is_conditional_tree = !isThereArrayJoin(query) && settings.enable_conditional_computation && !only_consts; + + LogAST log; + ActionsVisitor actions_visitor(context, getSetSizeLimits(settings), is_conditional_tree, subquery_depth, + source_columns, actions, prepared_sets, subqueries_for_sets, no_subqueries, only_consts, noStorageOrLocal(), log.stream()); if (table_join.using_expression_list) - actions_visitor.visit(table_join.using_expression_list, scopes, projection_manipulator); + actions_visitor.visit(table_join.using_expression_list); else if (table_join.on_expression) { for (const auto & ast : analyzed_join.key_asts_left) - actions_visitor.visit(ast, scopes, projection_manipulator); + actions_visitor.visit(ast); } - actions = scopes.popLevel(); + actions = actions_visitor.popActionsLevel(); } @@ -1304,7 +1292,7 @@ void ExpressionAnalyzer::getAggregates(const ASTPtr & ast, ExpressionActionsPtr /// There can not be other aggregate functions within the aggregate functions. assertNoAggregates(arguments[i], "inside another aggregate function"); - getRootActions(arguments[i], true, false, actions); + getRootActions(arguments[i], true, actions); const std::string & name = arguments[i]->getColumnName(); types[i] = actions->getSampleBlock().getByName(name).type; aggregate.argument_names[i] = name; @@ -1387,7 +1375,7 @@ bool ExpressionAnalyzer::appendArrayJoin(ExpressionActionsChain & chain, bool on initChain(chain, source_columns); ExpressionActionsChain::Step & step = chain.steps.back(); - getRootActions(select_query->array_join_expression_list(), only_types, false, step.actions); + getRootActions(select_query->array_join_expression_list(), only_types, step.actions); addMultipleArrayJoinAction(step.actions); @@ -1519,7 +1507,7 @@ bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_ty const auto & table_to_join = static_cast(*join_element.table_expression); - getActionsFromJoinKeys(join_params, only_types, false, step.actions); + getActionsFromJoinKeys(join_params, only_types, step.actions); /// Two JOINs are not supported with the same subquery, but different USINGs. auto join_hash = join_element.getTreeHash(); @@ -1634,7 +1622,7 @@ bool ExpressionAnalyzer::appendPrewhere(ExpressionActionsChain & chain, bool onl initChain(chain, source_columns); auto & step = chain.getLastStep(); - getRootActions(select_query->prewhere_expression, only_types, false, step.actions); + getRootActions(select_query->prewhere_expression, only_types, step.actions); String prewhere_column_name = select_query->prewhere_expression->getColumnName(); step.required_output.push_back(prewhere_column_name); step.can_remove_required_output.push_back(true); @@ -1642,7 +1630,7 @@ bool ExpressionAnalyzer::appendPrewhere(ExpressionActionsChain & chain, bool onl { /// Remove unused source_columns from prewhere actions. auto tmp_actions = std::make_shared(source_columns, context); - getRootActions(select_query->prewhere_expression, only_types, false, tmp_actions); + getRootActions(select_query->prewhere_expression, only_types, tmp_actions); tmp_actions->finalize({prewhere_column_name}); auto required_columns = tmp_actions->getRequiredColumns(); NameSet required_source_columns(required_columns.begin(), required_columns.end()); @@ -1711,7 +1699,7 @@ bool ExpressionAnalyzer::appendWhere(ExpressionActionsChain & chain, bool only_t step.required_output.push_back(select_query->where_expression->getColumnName()); step.can_remove_required_output = {true}; - getRootActions(select_query->where_expression, only_types, false, step.actions); + getRootActions(select_query->where_expression, only_types, step.actions); return true; } @@ -1730,7 +1718,7 @@ bool ExpressionAnalyzer::appendGroupBy(ExpressionActionsChain & chain, bool only for (size_t i = 0; i < asts.size(); ++i) { step.required_output.push_back(asts[i]->getColumnName()); - getRootActions(asts[i], only_types, false, step.actions); + getRootActions(asts[i], only_types, step.actions); } return true; @@ -1771,7 +1759,7 @@ bool ExpressionAnalyzer::appendHaving(ExpressionActionsChain & chain, bool only_ ExpressionActionsChain::Step & step = chain.steps.back(); step.required_output.push_back(select_query->having_expression->getColumnName()); - getRootActions(select_query->having_expression, only_types, false, step.actions); + getRootActions(select_query->having_expression, only_types, step.actions); return true; } @@ -1783,7 +1771,7 @@ void ExpressionAnalyzer::appendSelect(ExpressionActionsChain & chain, bool only_ initChain(chain, aggregated_columns); ExpressionActionsChain::Step & step = chain.steps.back(); - getRootActions(select_query->select_expression_list, only_types, false, step.actions); + getRootActions(select_query->select_expression_list, only_types, step.actions); for (const auto & child : select_query->select_expression_list->children) step.required_output.push_back(child->getColumnName()); @@ -1799,7 +1787,7 @@ bool ExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain, bool only initChain(chain, aggregated_columns); ExpressionActionsChain::Step & step = chain.steps.back(); - getRootActions(select_query->order_expression_list, only_types, false, step.actions); + getRootActions(select_query->order_expression_list, only_types, step.actions); ASTs asts = select_query->order_expression_list->children; for (size_t i = 0; i < asts.size(); ++i) @@ -1824,7 +1812,7 @@ bool ExpressionAnalyzer::appendLimitBy(ExpressionActionsChain & chain, bool only initChain(chain, aggregated_columns); ExpressionActionsChain::Step & step = chain.steps.back(); - getRootActions(select_query->limit_by_expression_list, only_types, false, step.actions); + getRootActions(select_query->limit_by_expression_list, only_types, step.actions); for (const auto & child : select_query->limit_by_expression_list->children) step.required_output.push_back(child->getColumnName()); @@ -1861,7 +1849,7 @@ void ExpressionAnalyzer::appendExpression(ExpressionActionsChain & chain, const { initChain(chain, source_columns); ExpressionActionsChain::Step & step = chain.steps.back(); - getRootActions(expr, only_types, false, step.actions); + getRootActions(expr, only_types, step.actions); step.required_output.push_back(expr->getColumnName()); } @@ -1872,7 +1860,7 @@ void ExpressionAnalyzer::getActionsBeforeAggregation(const ASTPtr & ast, Express if (node && AggregateFunctionFactory::instance().isAggregateFunctionName(node->name)) for (auto & argument : node->arguments->children) - getRootActions(argument, no_subqueries, false, actions); + getRootActions(argument, no_subqueries, actions); else for (auto & child : ast->children) getActionsBeforeAggregation(child, actions, no_subqueries); @@ -1902,7 +1890,7 @@ ExpressionActionsPtr ExpressionAnalyzer::getActions(bool add_aliases, bool proje alias = name; result_columns.emplace_back(name, alias); result_names.push_back(alias); - getRootActions(asts[i], false, false, actions); + getRootActions(asts[i], false, actions); } if (add_aliases) @@ -1930,8 +1918,7 @@ ExpressionActionsPtr ExpressionAnalyzer::getConstActions() { ExpressionActionsPtr actions = std::make_shared(NamesAndTypesList(), context); - getRootActions(query, true, true, actions); - + getRootActions(query, true, actions, true); return actions; } diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.h b/dbms/src/Interpreters/ExpressionAnalyzer.h index c4619bba078..e816b947431 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.h +++ b/dbms/src/Interpreters/ExpressionAnalyzer.h @@ -313,9 +313,9 @@ private: bool isThereArrayJoin(const ASTPtr & ast); /// If ast is ASTSelectQuery with JOIN, add actions for JOIN key columns. - void getActionsFromJoinKeys(const ASTTableJoin & table_join, bool no_subqueries, bool only_consts, ExpressionActionsPtr & actions); + void getActionsFromJoinKeys(const ASTTableJoin & table_join, bool no_subqueries, ExpressionActionsPtr & actions); - void getRootActions(const ASTPtr & ast, bool no_subqueries, bool only_consts, ExpressionActionsPtr & actions); + void getRootActions(const ASTPtr & ast, bool no_subqueries, ExpressionActionsPtr & actions, bool only_consts = false); void getActionsBeforeAggregation(const ASTPtr & ast, ExpressionActionsPtr & actions, bool no_subqueries);