better ActionsVisitor extraction [CLICKHOUSE-3996]

This commit is contained in:
chertus 2018-10-16 15:34:20 +03:00
parent e6e28d2451
commit 4071155043
4 changed files with 86 additions and 82 deletions

View File

@ -220,7 +220,31 @@ const Block & ScopeStack::getSampleBlock() const
}
void ActionsVisitor::visit(const ASTPtr & ast, ScopeStack & actions_stack, ProjectionManipulatorPtr projection_manipulator)
ActionsVisitor::ActionsVisitor(
const Context & context_, SizeLimits set_size_limit_, bool is_conditional_tree, size_t subquery_depth_,
const NamesAndTypesList & source_columns_, const ExpressionActionsPtr & actions,
PreparedSets & prepared_sets_, SubqueriesForSets & subqueries_for_sets_,
bool no_subqueries_, bool only_consts_, bool no_storage_or_local_, std::ostream * ostr_)
: context(context_),
set_size_limit(set_size_limit_),
subquery_depth(subquery_depth_),
source_columns(source_columns_),
prepared_sets(prepared_sets_),
subqueries_for_sets(subqueries_for_sets_),
no_subqueries(no_subqueries_),
only_consts(only_consts_),
no_storage_or_local(no_storage_or_local_),
visit_depth(0),
ostr(ostr_),
actions_stack(actions, context)
{
if (is_conditional_tree)
projection_manipulator = std::make_shared<ConditionalTree>(actions_stack, context);
else
projection_manipulator = std::make_shared<DefaultProjectionManipulator>(actions_stack);
}
void ActionsVisitor::visit(const ASTPtr & ast)
{
DumpASTNode dump(*ast, ostr, visit_depth, "getActions");
@ -267,7 +291,7 @@ void ActionsVisitor::visit(const ASTPtr & ast, ScopeStack & actions_stack, Proje
throw Exception("arrayJoin requires exactly 1 argument", ErrorCodes::TYPE_MISMATCH);
ASTPtr arg = node->arguments->children.at(0);
visit(arg, actions_stack, projection_manipulator);
visit(arg);
if (!only_consts)
{
String result_name = projection_manipulator->getColumnName(getColumnName());
@ -283,7 +307,7 @@ void ActionsVisitor::visit(const ASTPtr & ast, ScopeStack & actions_stack, Proje
if (functionIsInOrGlobalInOperator(node->name))
{
/// Let's find the type of the first argument (then getActionsImpl will be called again and will not affect anything).
visit(node->arguments->children.at(0), actions_stack, projection_manipulator);
visit(node->arguments->children.at(0));
if (!no_subqueries)
{
@ -387,7 +411,7 @@ void ActionsVisitor::visit(const ASTPtr & ast, ScopeStack & actions_stack, Proje
{
/// If the argument is not a lambda expression, call it recursively and find out its type.
projection_action->preArgumentAction();
visit(child, actions_stack, projection_manipulator);
visit(child);
std::string name = projection_manipulator->getColumnName(child_column_name);
projection_action->postArgumentAction(child_column_name);
if (actions_stack.getSampleBlock().has(name))
@ -442,7 +466,7 @@ void ActionsVisitor::visit(const ASTPtr & ast, ScopeStack & actions_stack, Proje
projection_action->preArgumentAction();
actions_stack.pushLevel(lambda_arguments);
visit(lambda->arguments->children.at(1), actions_stack, projection_manipulator);
visit(lambda->arguments->children.at(1));
ExpressionActionsPtr lambda_actions = actions_stack.popLevel();
String result_name = projection_manipulator->getColumnName(lambda->arguments->children.at(1)->getColumnName());
@ -515,7 +539,7 @@ void ActionsVisitor::visit(const ASTPtr & ast, ScopeStack & actions_stack, Proje
/// Do not go to FROM, JOIN, UNION.
if (!typeid_cast<const ASTTableExpression *>(child.get())
&& !typeid_cast<const ASTSelectQuery *>(child.get()))
visit(child, actions_stack, projection_manipulator);
visit(child);
}
}
}

View File

@ -8,12 +8,11 @@ namespace DB
class Context;
class ASTFunction;
struct ProjectionManipulatorBase;
class ExpressionActions;
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
struct ProjectionManipulatorBase;
using ProjectionManipulatorPtr = std::shared_ptr<ProjectionManipulatorBase>;
class Set;
using SetPtr = std::shared_ptr<Set>;
@ -87,27 +86,19 @@ struct ScopeStack
};
/// TODO: There sould be some description, but...
/// Collect ExpressionAction from AST. Returns PreparedSets and SubqueriesForSets too.
/// After AST is visited source ExpressionActions should be updated with popActionsLevel() method.
class ActionsVisitor
{
public:
ActionsVisitor(const Context & context_, SizeLimits set_size_limit_, size_t subquery_depth_,
const NamesAndTypesList & source_columns_, PreparedSets & prepared_sets_, SubqueriesForSets & subqueries_for_sets_,
bool no_subqueries_, bool only_consts_, bool no_storage_or_local_, std::ostream * ostr_ = nullptr)
: context(context_),
set_size_limit(set_size_limit_),
subquery_depth(subquery_depth_),
source_columns(source_columns_),
prepared_sets(prepared_sets_),
subqueries_for_sets(subqueries_for_sets_),
no_subqueries(no_subqueries_),
only_consts(only_consts_),
no_storage_or_local(no_storage_or_local_),
visit_depth(0),
ostr(ostr_)
{}
ActionsVisitor(const Context & context_, SizeLimits set_size_limit_, bool is_conditional_tree, size_t subquery_depth_,
const NamesAndTypesList & source_columns_, const ExpressionActionsPtr & actions,
PreparedSets & prepared_sets_, SubqueriesForSets & subqueries_for_sets_,
bool no_subqueries_, bool only_consts_, bool no_storage_or_local_, std::ostream * ostr_ = nullptr);
void visit(const ASTPtr & ast, ScopeStack & actions_stack, ProjectionManipulatorPtr projection_manipulator);
void visit(const ASTPtr & ast);
ExpressionActionsPtr popActionsLevel() { return actions_stack.popLevel(); }
private:
const Context & context;
@ -121,6 +112,8 @@ private:
const bool no_storage_or_local;
mutable size_t visit_depth;
std::ostream * ostr;
ScopeStack actions_stack;
std::shared_ptr<ProjectionManipulatorBase> projection_manipulator;
void makeSet(const ASTFunction * node, const Block & sample_block);
};

View File

@ -29,7 +29,6 @@
#include <Interpreters/ExternalDictionaries.h>
#include <Interpreters/Set.h>
#include <Interpreters/Join.h>
#include <Interpreters/ProjectionManipulation.h>
#include <Interpreters/TranslateQualifiedNamesVisitor.h>
#include <Interpreters/ExecuteScalarSubqueriesVisitor.h>
@ -439,7 +438,7 @@ void ExpressionAnalyzer::analyzeAggregation()
if (select_query && select_query->array_join_expression_list())
{
getRootActions(select_query->array_join_expression_list(), true, false, temp_actions);
getRootActions(select_query->array_join_expression_list(), true, temp_actions);
addMultipleArrayJoinAction(temp_actions);
array_join_columns = temp_actions->getSampleBlock().getNamesAndTypesList();
}
@ -451,10 +450,10 @@ void ExpressionAnalyzer::analyzeAggregation()
{
const auto table_join = static_cast<const ASTTableJoin &>(*join->table_join);
if (table_join.using_expression_list)
getRootActions(table_join.using_expression_list, true, false, temp_actions);
getRootActions(table_join.using_expression_list, true, temp_actions);
if (table_join.on_expression)
for (const auto & key_ast : analyzed_join.key_asts_left)
getRootActions(key_ast, true, false, temp_actions);
getRootActions(key_ast, true, temp_actions);
addJoinAction(temp_actions, true);
}
@ -474,7 +473,7 @@ void ExpressionAnalyzer::analyzeAggregation()
for (ssize_t i = 0; i < ssize_t(group_asts.size()); ++i)
{
ssize_t size = group_asts.size();
getRootActions(group_asts[i], true, false, temp_actions);
getRootActions(group_asts[i], true, temp_actions);
const auto & column_name = group_asts[i]->getColumnName();
const auto & block = temp_actions->getSampleBlock();
@ -1053,7 +1052,7 @@ void ExpressionAnalyzer::makeSetsForIndexImpl(const ASTPtr & node, const Block &
for (const auto & joined_column : analyzed_join.columns_added_by_join)
temp_columns.push_back(joined_column.name_and_type);
ExpressionActionsPtr temp_actions = std::make_shared<ExpressionActions>(temp_columns, context);
getRootActions(func->arguments->children.at(0), true, false, temp_actions);
getRootActions(func->arguments->children.at(0), true, temp_actions);
Block sample_block_with_calculated_columns = temp_actions->getSampleBlock();
if (sample_block_with_calculated_columns.has(args.children.at(0)->getColumnName()))
@ -1065,25 +1064,6 @@ void ExpressionAnalyzer::makeSetsForIndexImpl(const ASTPtr & node, const Block &
}
void ExpressionAnalyzer::getRootActions(const ASTPtr & ast, bool no_subqueries, bool only_consts, ExpressionActionsPtr & actions)
{
ScopeStack scopes(actions, context);
ProjectionManipulatorPtr projection_manipulator;
if (!isThereArrayJoin(ast) && settings.enable_conditional_computation && !only_consts)
projection_manipulator = std::make_shared<ConditionalTree>(scopes, context);
else
projection_manipulator = std::make_shared<DefaultProjectionManipulator>(scopes);
LogAST log;
ActionsVisitor actions_visitor(context, getSetSizeLimits(settings), subquery_depth,
source_columns, prepared_sets, subqueries_for_sets,
no_subqueries, only_consts, noStorageOrLocal(), log.stream());
actions_visitor.visit(ast, scopes, projection_manipulator);
actions = scopes.popLevel();
}
void ExpressionAnalyzer::getArrayJoinedColumns()
{
if (select_query && select_query->array_join_expression_list())
@ -1239,31 +1219,39 @@ bool ExpressionAnalyzer::isThereArrayJoin(const ASTPtr & ast)
}
}
void ExpressionAnalyzer::getActionsFromJoinKeys(const ASTTableJoin & table_join, bool no_subqueries, bool only_consts,
ExpressionActionsPtr & actions)
{
ScopeStack scopes(actions, context);
ProjectionManipulatorPtr projection_manipulator;
if (!isThereArrayJoin(query) && settings.enable_conditional_computation && !only_consts)
projection_manipulator = std::make_shared<ConditionalTree>(scopes, context);
else
projection_manipulator = std::make_shared<DefaultProjectionManipulator>(scopes);
void ExpressionAnalyzer::getRootActions(const ASTPtr & ast, bool no_subqueries, ExpressionActionsPtr & actions, bool only_consts)
{
bool is_conditional_tree = !isThereArrayJoin(ast) && settings.enable_conditional_computation && !only_consts;
LogAST log;
ActionsVisitor actions_visitor(context, getSetSizeLimits(settings), subquery_depth,
source_columns, prepared_sets, subqueries_for_sets,
ActionsVisitor actions_visitor(context, getSetSizeLimits(settings), is_conditional_tree, subquery_depth,
source_columns, actions, prepared_sets, subqueries_for_sets,
no_subqueries, only_consts, noStorageOrLocal(), log.stream());
actions_visitor.visit(ast);
actions = actions_visitor.popActionsLevel();
}
void ExpressionAnalyzer::getActionsFromJoinKeys(const ASTTableJoin & table_join, bool no_subqueries, ExpressionActionsPtr & actions)
{
bool only_consts = false;
bool is_conditional_tree = !isThereArrayJoin(query) && settings.enable_conditional_computation && !only_consts;
LogAST log;
ActionsVisitor actions_visitor(context, getSetSizeLimits(settings), is_conditional_tree, subquery_depth,
source_columns, actions, prepared_sets, subqueries_for_sets,
no_subqueries, only_consts, noStorageOrLocal(), log.stream());
if (table_join.using_expression_list)
actions_visitor.visit(table_join.using_expression_list, scopes, projection_manipulator);
actions_visitor.visit(table_join.using_expression_list);
else if (table_join.on_expression)
{
for (const auto & ast : analyzed_join.key_asts_left)
actions_visitor.visit(ast, scopes, projection_manipulator);
actions_visitor.visit(ast);
}
actions = scopes.popLevel();
actions = actions_visitor.popActionsLevel();
}
@ -1304,7 +1292,7 @@ void ExpressionAnalyzer::getAggregates(const ASTPtr & ast, ExpressionActionsPtr
/// There can not be other aggregate functions within the aggregate functions.
assertNoAggregates(arguments[i], "inside another aggregate function");
getRootActions(arguments[i], true, false, actions);
getRootActions(arguments[i], true, actions);
const std::string & name = arguments[i]->getColumnName();
types[i] = actions->getSampleBlock().getByName(name).type;
aggregate.argument_names[i] = name;
@ -1387,7 +1375,7 @@ bool ExpressionAnalyzer::appendArrayJoin(ExpressionActionsChain & chain, bool on
initChain(chain, source_columns);
ExpressionActionsChain::Step & step = chain.steps.back();
getRootActions(select_query->array_join_expression_list(), only_types, false, step.actions);
getRootActions(select_query->array_join_expression_list(), only_types, step.actions);
addMultipleArrayJoinAction(step.actions);
@ -1519,7 +1507,7 @@ bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_ty
const auto & table_to_join = static_cast<const ASTTableExpression &>(*join_element.table_expression);
getActionsFromJoinKeys(join_params, only_types, false, step.actions);
getActionsFromJoinKeys(join_params, only_types, step.actions);
/// Two JOINs are not supported with the same subquery, but different USINGs.
auto join_hash = join_element.getTreeHash();
@ -1634,7 +1622,7 @@ bool ExpressionAnalyzer::appendPrewhere(ExpressionActionsChain & chain, bool onl
initChain(chain, source_columns);
auto & step = chain.getLastStep();
getRootActions(select_query->prewhere_expression, only_types, false, step.actions);
getRootActions(select_query->prewhere_expression, only_types, step.actions);
String prewhere_column_name = select_query->prewhere_expression->getColumnName();
step.required_output.push_back(prewhere_column_name);
step.can_remove_required_output.push_back(true);
@ -1642,7 +1630,7 @@ bool ExpressionAnalyzer::appendPrewhere(ExpressionActionsChain & chain, bool onl
{
/// Remove unused source_columns from prewhere actions.
auto tmp_actions = std::make_shared<ExpressionActions>(source_columns, context);
getRootActions(select_query->prewhere_expression, only_types, false, tmp_actions);
getRootActions(select_query->prewhere_expression, only_types, tmp_actions);
tmp_actions->finalize({prewhere_column_name});
auto required_columns = tmp_actions->getRequiredColumns();
NameSet required_source_columns(required_columns.begin(), required_columns.end());
@ -1711,7 +1699,7 @@ bool ExpressionAnalyzer::appendWhere(ExpressionActionsChain & chain, bool only_t
step.required_output.push_back(select_query->where_expression->getColumnName());
step.can_remove_required_output = {true};
getRootActions(select_query->where_expression, only_types, false, step.actions);
getRootActions(select_query->where_expression, only_types, step.actions);
return true;
}
@ -1730,7 +1718,7 @@ bool ExpressionAnalyzer::appendGroupBy(ExpressionActionsChain & chain, bool only
for (size_t i = 0; i < asts.size(); ++i)
{
step.required_output.push_back(asts[i]->getColumnName());
getRootActions(asts[i], only_types, false, step.actions);
getRootActions(asts[i], only_types, step.actions);
}
return true;
@ -1771,7 +1759,7 @@ bool ExpressionAnalyzer::appendHaving(ExpressionActionsChain & chain, bool only_
ExpressionActionsChain::Step & step = chain.steps.back();
step.required_output.push_back(select_query->having_expression->getColumnName());
getRootActions(select_query->having_expression, only_types, false, step.actions);
getRootActions(select_query->having_expression, only_types, step.actions);
return true;
}
@ -1783,7 +1771,7 @@ void ExpressionAnalyzer::appendSelect(ExpressionActionsChain & chain, bool only_
initChain(chain, aggregated_columns);
ExpressionActionsChain::Step & step = chain.steps.back();
getRootActions(select_query->select_expression_list, only_types, false, step.actions);
getRootActions(select_query->select_expression_list, only_types, step.actions);
for (const auto & child : select_query->select_expression_list->children)
step.required_output.push_back(child->getColumnName());
@ -1799,7 +1787,7 @@ bool ExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain, bool only
initChain(chain, aggregated_columns);
ExpressionActionsChain::Step & step = chain.steps.back();
getRootActions(select_query->order_expression_list, only_types, false, step.actions);
getRootActions(select_query->order_expression_list, only_types, step.actions);
ASTs asts = select_query->order_expression_list->children;
for (size_t i = 0; i < asts.size(); ++i)
@ -1824,7 +1812,7 @@ bool ExpressionAnalyzer::appendLimitBy(ExpressionActionsChain & chain, bool only
initChain(chain, aggregated_columns);
ExpressionActionsChain::Step & step = chain.steps.back();
getRootActions(select_query->limit_by_expression_list, only_types, false, step.actions);
getRootActions(select_query->limit_by_expression_list, only_types, step.actions);
for (const auto & child : select_query->limit_by_expression_list->children)
step.required_output.push_back(child->getColumnName());
@ -1861,7 +1849,7 @@ void ExpressionAnalyzer::appendExpression(ExpressionActionsChain & chain, const
{
initChain(chain, source_columns);
ExpressionActionsChain::Step & step = chain.steps.back();
getRootActions(expr, only_types, false, step.actions);
getRootActions(expr, only_types, step.actions);
step.required_output.push_back(expr->getColumnName());
}
@ -1872,7 +1860,7 @@ void ExpressionAnalyzer::getActionsBeforeAggregation(const ASTPtr & ast, Express
if (node && AggregateFunctionFactory::instance().isAggregateFunctionName(node->name))
for (auto & argument : node->arguments->children)
getRootActions(argument, no_subqueries, false, actions);
getRootActions(argument, no_subqueries, actions);
else
for (auto & child : ast->children)
getActionsBeforeAggregation(child, actions, no_subqueries);
@ -1902,7 +1890,7 @@ ExpressionActionsPtr ExpressionAnalyzer::getActions(bool add_aliases, bool proje
alias = name;
result_columns.emplace_back(name, alias);
result_names.push_back(alias);
getRootActions(asts[i], false, false, actions);
getRootActions(asts[i], false, actions);
}
if (add_aliases)
@ -1930,8 +1918,7 @@ ExpressionActionsPtr ExpressionAnalyzer::getConstActions()
{
ExpressionActionsPtr actions = std::make_shared<ExpressionActions>(NamesAndTypesList(), context);
getRootActions(query, true, true, actions);
getRootActions(query, true, actions, true);
return actions;
}

View File

@ -313,9 +313,9 @@ private:
bool isThereArrayJoin(const ASTPtr & ast);
/// If ast is ASTSelectQuery with JOIN, add actions for JOIN key columns.
void getActionsFromJoinKeys(const ASTTableJoin & table_join, bool no_subqueries, bool only_consts, ExpressionActionsPtr & actions);
void getActionsFromJoinKeys(const ASTTableJoin & table_join, bool no_subqueries, ExpressionActionsPtr & actions);
void getRootActions(const ASTPtr & ast, bool no_subqueries, bool only_consts, ExpressionActionsPtr & actions);
void getRootActions(const ASTPtr & ast, bool no_subqueries, ExpressionActionsPtr & actions, bool only_consts = false);
void getActionsBeforeAggregation(const ASTPtr & ast, ExpressionActionsPtr & actions, bool no_subqueries);