From 994f9f3cc6f21beacf0e0f643daf62cd0beda5af Mon Sep 17 00:00:00 2001 From: chertus Date: Thu, 15 Aug 2019 16:54:59 +0300 Subject: [PATCH] unify ActionsVisitor: rewrite as InDepthNodeVisitor --- dbms/src/Interpreters/ActionsVisitor.cpp | 551 ++++++++++--------- dbms/src/Interpreters/ActionsVisitor.h | 80 ++- dbms/src/Interpreters/ExpressionAnalyzer.cpp | 7 +- 3 files changed, 339 insertions(+), 299 deletions(-) diff --git a/dbms/src/Interpreters/ActionsVisitor.cpp b/dbms/src/Interpreters/ActionsVisitor.cpp index 523343a288e..7c6f97d5ed5 100644 --- a/dbms/src/Interpreters/ActionsVisitor.cpp +++ b/dbms/src/Interpreters/ActionsVisitor.cpp @@ -1,3 +1,5 @@ +#include + #include #include @@ -19,8 +21,6 @@ #include #include -#include -#include #include #include #include @@ -228,346 +228,351 @@ const Block & ScopeStack::getSampleBlock() const return stack.back().actions->getSampleBlock(); } - -ActionsVisitor::ActionsVisitor( - const Context & context_, SizeLimits set_size_limit_, size_t subquery_depth_, - const NamesAndTypesList & source_columns_, const ExpressionActionsPtr & actions, - PreparedSets & prepared_sets_, SubqueriesForSets & subqueries_for_sets_, - bool no_subqueries_, bool only_consts_, bool no_storage_or_local_, std::ostream * ostr_) -: context(context_), - set_size_limit(set_size_limit_), - subquery_depth(subquery_depth_), - source_columns(source_columns_), - prepared_sets(prepared_sets_), - subqueries_for_sets(subqueries_for_sets_), - no_subqueries(no_subqueries_), - only_consts(only_consts_), - no_storage_or_local(no_storage_or_local_), - visit_depth(0), - ostr(ostr_), - actions_stack(actions, context) +struct CachedColumnName { + String cached; + + const String & get(const ASTPtr & ast) + { + if (cached.empty()) + cached = ast->getColumnName(); + return cached; + } +}; + +bool ActionsMatcher::needChildVisit(const ASTPtr & node, const ASTPtr & child) +{ + /// Visit children themself + if (node->as() || + node->as() || + node->as()) + return false; + + /// Do not go to FROM, JOIN, UNION. + if (child->as() || + child->as()) + return false; + + return true; } -void ActionsVisitor::visit(const ASTPtr & ast) +void ActionsMatcher::visit(const ASTPtr & ast, Data & data) { - DumpASTNode dump(*ast, ostr, visit_depth, "getActions"); + if (const auto * identifier = ast->as()) + visit(*identifier, ast, data); + else if (const auto * node = ast->as()) + visit(*node, ast, data); + else if (const auto * literal = ast->as()) + visit(*literal, ast, data); +} - String ast_column_name; - auto getColumnName = [&ast, &ast_column_name]() +void ActionsMatcher::visit(const ASTIdentifier & identifier, const ASTPtr & ast, Data & data) +{ + CachedColumnName column_name; + + if (!data.only_consts && !data.actions_stack.getSampleBlock().has(column_name.get(ast))) { - if (ast_column_name.empty()) - ast_column_name = ast->getColumnName(); + /// The requested column is not in the block. + /// If such a column exists in the table, then the user probably forgot to surround it with an aggregate function or add it to GROUP BY. - return ast_column_name; - }; + bool found = false; + for (const auto & column_name_type : data.source_columns) + if (column_name_type.name == column_name.get(ast)) + found = true; - /// If the result of the calculation already exists in the block. - if ((ast->as() || ast->as()) && actions_stack.getSampleBlock().has(getColumnName())) + if (found) + throw Exception("Column " + column_name.get(ast) + " is not under aggregate function and not in GROUP BY.", + ErrorCodes::NOT_AN_AGGREGATE); + + /// Special check for WITH statement alias. Add alias action to be able to use this alias. + if (identifier.prefer_alias_to_column_name && !identifier.alias.empty()) + data.actions_stack.addAction(ExpressionAction::addAliases({{identifier.name, identifier.alias}})); + } +} + +void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data & data) +{ + CachedColumnName column_name; + + if (data.hasColumn(column_name.get(ast))) return; - if (const auto * identifier = ast->as()) + if (node.name == "lambda") + throw Exception("Unexpected lambda expression", ErrorCodes::UNEXPECTED_EXPRESSION); + + /// Function arrayJoin. + if (node.name == "arrayJoin") { - if (!only_consts && !actions_stack.getSampleBlock().has(getColumnName())) + if (node.arguments->children.size() != 1) + throw Exception("arrayJoin requires exactly 1 argument", ErrorCodes::TYPE_MISMATCH); + + ASTPtr arg = node.arguments->children.at(0); + visit(arg, data); + if (!data.only_consts) { - /// The requested column is not in the block. - /// If such a column exists in the table, then the user probably forgot to surround it with an aggregate function or add it to GROUP BY. - - bool found = false; - for (const auto & column_name_type : source_columns) - if (column_name_type.name == getColumnName()) - found = true; - - if (found) - throw Exception("Column " + getColumnName() + " is not under aggregate function and not in GROUP BY.", - ErrorCodes::NOT_AN_AGGREGATE); - - /// Special check for WITH statement alias. Add alias action to be able to use this alias. - if (identifier->prefer_alias_to_column_name && !identifier->alias.empty()) - actions_stack.addAction(ExpressionAction::addAliases({{identifier->name, identifier->alias}})); + String result_name = column_name.get(ast); + data.actions_stack.addAction(ExpressionAction::copyColumn(arg->getColumnName(), result_name)); + NameSet joined_columns; + joined_columns.insert(result_name); + data.actions_stack.addAction(ExpressionAction::arrayJoin(joined_columns, false, data.context)); } + + return; } - else if (const auto * node = ast->as()) + + SetPtr prepared_set; + if (functionIsInOrGlobalInOperator(node.name)) { - if (node->name == "lambda") - throw Exception("Unexpected lambda expression", ErrorCodes::UNEXPECTED_EXPRESSION); + /// Let's find the type of the first argument (then getActionsImpl will be called again and will not affect anything). + visit(node.arguments->children.at(0), data); - /// Function arrayJoin. - if (node->name == "arrayJoin") + if (!data.no_subqueries) { - if (node->arguments->children.size() != 1) - throw Exception("arrayJoin requires exactly 1 argument", ErrorCodes::TYPE_MISMATCH); - - ASTPtr arg = node->arguments->children.at(0); - visit(arg); - if (!only_consts) + /// Transform tuple or subquery into a set. + prepared_set = makeSet(node, data); + } + else + { + if (!data.only_consts) { - String result_name = getColumnName(); - actions_stack.addAction(ExpressionAction::copyColumn(arg->getColumnName(), result_name)); - NameSet joined_columns; - joined_columns.insert(result_name); - actions_stack.addAction(ExpressionAction::arrayJoin(joined_columns, false, context)); - } + /// We are in the part of the tree that we are not going to compute. You just need to define types. + /// Do not subquery and create sets. We treat "IN" as "ignoreExceptNull" function. + data.actions_stack.addAction(ExpressionAction::applyFunction( + FunctionFactory::instance().get("ignoreExceptNull", data.context), + { node.arguments->children.at(0)->getColumnName() }, + column_name.get(ast))); + } return; } + } - SetPtr prepared_set; - if (functionIsInOrGlobalInOperator(node->name)) + /// A special function `indexHint`. Everything that is inside it is not calculated + /// (and is used only for index analysis, see KeyCondition). + if (node.name == "indexHint") + { + data.actions_stack.addAction(ExpressionAction::addColumn(ColumnWithTypeAndName( + ColumnConst::create(ColumnUInt8::create(1, 1), 1), std::make_shared(), + column_name.get(ast)))); + return; + } + + if (AggregateFunctionFactory::instance().isAggregateFunctionName(node.name)) + return; + + /// Context object that we pass to function should live during query. + const Context & function_context = data.context.hasQueryContext() + ? data.context.getQueryContext() + : data.context; + + FunctionBuilderPtr function_builder; + try + { + function_builder = FunctionFactory::instance().get(node.name, function_context); + } + catch (DB::Exception & e) + { + auto hints = AggregateFunctionFactory::instance().getHints(node.name); + if (!hints.empty()) + e.addMessage("Or unknown aggregate function " + node.name + ". Maybe you meant: " + toString(hints)); + e.rethrow(); + } + + Names argument_names; + DataTypes argument_types; + bool arguments_present = true; + + /// If the function has an argument-lambda expression, you need to determine its type before the recursive call. + bool has_lambda_arguments = false; + + for (size_t arg = 0; arg < node.arguments->children.size(); ++arg) + { + auto & child = node.arguments->children[arg]; + auto child_column_name = child->getColumnName(); + + const auto * lambda = child->as(); + if (lambda && lambda->name == "lambda") { - /// Let's find the type of the first argument (then getActionsImpl will be called again and will not affect anything). - visit(node->arguments->children.at(0)); + /// If the argument is a lambda expression, just remember its approximate type. + if (lambda->arguments->children.size() != 2) + throw Exception("lambda requires two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - if (!no_subqueries) + const auto * lambda_args_tuple = lambda->arguments->children.at(0)->as(); + + if (!lambda_args_tuple || lambda_args_tuple->name != "tuple") + throw Exception("First argument of lambda must be a tuple", ErrorCodes::TYPE_MISMATCH); + + has_lambda_arguments = true; + argument_types.emplace_back(std::make_shared(DataTypes(lambda_args_tuple->arguments->children.size()))); + /// Select the name in the next cycle. + argument_names.emplace_back(); + } + else if (functionIsInOrGlobalInOperator(node.name) && arg == 1 && prepared_set) + { + ColumnWithTypeAndName column; + column.type = std::make_shared(); + + /// If the argument is a set given by an enumeration of values (so, the set was already built), give it a unique name, + /// so that sets with the same literal representation do not fuse together (they can have different types). + if (!prepared_set->empty()) + column.name = getUniqueName(data.actions_stack.getSampleBlock(), "__set"); + else + column.name = child_column_name; + + if (!data.actions_stack.getSampleBlock().has(column.name)) { - /// Transform tuple or subquery into a set. - prepared_set = makeSet(node, actions_stack.getSampleBlock()); + column.column = ColumnSet::create(1, prepared_set); + + data.actions_stack.addAction(ExpressionAction::addColumn(column)); + } + + argument_types.push_back(column.type); + argument_names.push_back(column.name); + } + else + { + /// If the argument is not a lambda expression, call it recursively and find out its type. + visit(child, data); + std::string name = child_column_name; + if (data.actions_stack.getSampleBlock().has(name)) + { + argument_types.push_back(data.actions_stack.getSampleBlock().getByName(name).type); + argument_names.push_back(name); } else { - if (!only_consts) - { - /// We are in the part of the tree that we are not going to compute. You just need to define types. - /// Do not subquery and create sets. We treat "IN" as "ignoreExceptNull" function. - - actions_stack.addAction(ExpressionAction::applyFunction( - FunctionFactory::instance().get("ignoreExceptNull", context), - { node->arguments->children.at(0)->getColumnName() }, - getColumnName())); - } - return; + if (data.only_consts) + arguments_present = false; + else + throw Exception("Unknown identifier: " + name, ErrorCodes::UNKNOWN_IDENTIFIER); } } + } - /// A special function `indexHint`. Everything that is inside it is not calculated - /// (and is used only for index analysis, see KeyCondition). - if (node->name == "indexHint") + if (data.only_consts && !arguments_present) + return; + + if (has_lambda_arguments && !data.only_consts) + { + function_builder->getLambdaArgumentTypes(argument_types); + + /// Call recursively for lambda expressions. + for (size_t i = 0; i < node.arguments->children.size(); ++i) { - actions_stack.addAction(ExpressionAction::addColumn(ColumnWithTypeAndName( - ColumnConst::create(ColumnUInt8::create(1, 1), 1), std::make_shared(), - getColumnName()))); - return; - } - - if (AggregateFunctionFactory::instance().isAggregateFunctionName(node->name)) - return; - - /// Context object that we pass to function should live during query. - const Context & function_context = context.hasQueryContext() - ? context.getQueryContext() - : context; - - FunctionBuilderPtr function_builder; - try - { - function_builder = FunctionFactory::instance().get(node->name, function_context); - } - catch (DB::Exception & e) - { - auto hints = AggregateFunctionFactory::instance().getHints(node->name); - if (!hints.empty()) - e.addMessage("Or unknown aggregate function " + node->name + ". Maybe you meant: " + toString(hints)); - e.rethrow(); - } - - Names argument_names; - DataTypes argument_types; - bool arguments_present = true; - - /// If the function has an argument-lambda expression, you need to determine its type before the recursive call. - bool has_lambda_arguments = false; - - for (size_t arg = 0; arg < node->arguments->children.size(); ++arg) - { - auto & child = node->arguments->children[arg]; - auto child_column_name = child->getColumnName(); + ASTPtr child = node.arguments->children[i]; const auto * lambda = child->as(); if (lambda && lambda->name == "lambda") { - /// If the argument is a lambda expression, just remember its approximate type. - if (lambda->arguments->children.size() != 2) - throw Exception("lambda requires two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - + const DataTypeFunction * lambda_type = typeid_cast(argument_types[i].get()); const auto * lambda_args_tuple = lambda->arguments->children.at(0)->as(); + const ASTs & lambda_arg_asts = lambda_args_tuple->arguments->children; + NamesAndTypesList lambda_arguments; - if (!lambda_args_tuple || lambda_args_tuple->name != "tuple") - throw Exception("First argument of lambda must be a tuple", ErrorCodes::TYPE_MISMATCH); - - has_lambda_arguments = true; - argument_types.emplace_back(std::make_shared(DataTypes(lambda_args_tuple->arguments->children.size()))); - /// Select the name in the next cycle. - argument_names.emplace_back(); - } - else if (functionIsInOrGlobalInOperator(node->name) && arg == 1 && prepared_set) - { - ColumnWithTypeAndName column; - column.type = std::make_shared(); - - /// If the argument is a set given by an enumeration of values (so, the set was already built), give it a unique name, - /// so that sets with the same literal representation do not fuse together (they can have different types). - if (!prepared_set->empty()) - column.name = getUniqueName(actions_stack.getSampleBlock(), "__set"); - else - column.name = child_column_name; - - if (!actions_stack.getSampleBlock().has(column.name)) + for (size_t j = 0; j < lambda_arg_asts.size(); ++j) { - column.column = ColumnSet::create(1, prepared_set); + auto opt_arg_name = tryGetIdentifierName(lambda_arg_asts[j]); + if (!opt_arg_name) + throw Exception("lambda argument declarations must be identifiers", ErrorCodes::TYPE_MISMATCH); - actions_stack.addAction(ExpressionAction::addColumn(column)); + lambda_arguments.emplace_back(*opt_arg_name, lambda_type->getArgumentTypes()[j]); } - argument_types.push_back(column.type); - argument_names.push_back(column.name); + data.actions_stack.pushLevel(lambda_arguments); + visit(lambda->arguments->children.at(1), data); + ExpressionActionsPtr lambda_actions = data.actions_stack.popLevel(); + + String result_name = lambda->arguments->children.at(1)->getColumnName(); + lambda_actions->finalize(Names(1, result_name)); + DataTypePtr result_type = lambda_actions->getSampleBlock().getByName(result_name).type; + + Names captured; + Names required = lambda_actions->getRequiredColumns(); + for (const auto & required_arg : required) + if (findColumn(required_arg, lambda_arguments) == lambda_arguments.end()) + captured.push_back(required_arg); + + /// We can not name `getColumnName()`, + /// because it does not uniquely define the expression (the types of arguments can be different). + String lambda_name = getUniqueName(data.actions_stack.getSampleBlock(), "__lambda"); + + auto function_capture = std::make_shared( + lambda_actions, captured, lambda_arguments, result_type, result_name); + data.actions_stack.addAction(ExpressionAction::applyFunction(function_capture, captured, lambda_name)); + + argument_types[i] = std::make_shared(lambda_type->getArgumentTypes(), result_type); + argument_names[i] = lambda_name; } - else - { - /// If the argument is not a lambda expression, call it recursively and find out its type. - visit(child); - std::string name = child_column_name; - if (actions_stack.getSampleBlock().has(name)) - { - argument_types.push_back(actions_stack.getSampleBlock().getByName(name).type); - argument_names.push_back(name); - } - else - { - if (only_consts) - arguments_present = false; - else - throw Exception("Unknown identifier: " + name, ErrorCodes::UNKNOWN_IDENTIFIER); - } - } - } - - if (only_consts && !arguments_present) - return; - - if (has_lambda_arguments && !only_consts) - { - function_builder->getLambdaArgumentTypes(argument_types); - - /// Call recursively for lambda expressions. - for (size_t i = 0; i < node->arguments->children.size(); ++i) - { - ASTPtr child = node->arguments->children[i]; - - const auto * lambda = child->as(); - if (lambda && lambda->name == "lambda") - { - const DataTypeFunction * lambda_type = typeid_cast(argument_types[i].get()); - const auto * lambda_args_tuple = lambda->arguments->children.at(0)->as(); - const ASTs & lambda_arg_asts = lambda_args_tuple->arguments->children; - NamesAndTypesList lambda_arguments; - - for (size_t j = 0; j < lambda_arg_asts.size(); ++j) - { - auto opt_arg_name = tryGetIdentifierName(lambda_arg_asts[j]); - if (!opt_arg_name) - throw Exception("lambda argument declarations must be identifiers", ErrorCodes::TYPE_MISMATCH); - - lambda_arguments.emplace_back(*opt_arg_name, lambda_type->getArgumentTypes()[j]); - } - - actions_stack.pushLevel(lambda_arguments); - visit(lambda->arguments->children.at(1)); - ExpressionActionsPtr lambda_actions = actions_stack.popLevel(); - - String result_name = lambda->arguments->children.at(1)->getColumnName(); - lambda_actions->finalize(Names(1, result_name)); - DataTypePtr result_type = lambda_actions->getSampleBlock().getByName(result_name).type; - - Names captured; - Names required = lambda_actions->getRequiredColumns(); - for (const auto & required_arg : required) - if (findColumn(required_arg, lambda_arguments) == lambda_arguments.end()) - captured.push_back(required_arg); - - /// We can not name `getColumnName()`, - /// because it does not uniquely define the expression (the types of arguments can be different). - String lambda_name = getUniqueName(actions_stack.getSampleBlock(), "__lambda"); - - auto function_capture = std::make_shared( - lambda_actions, captured, lambda_arguments, result_type, result_name); - actions_stack.addAction(ExpressionAction::applyFunction(function_capture, captured, lambda_name)); - - argument_types[i] = std::make_shared(lambda_type->getArgumentTypes(), result_type); - argument_names[i] = lambda_name; - } - } - } - - if (only_consts) - { - for (const auto & argument_name : argument_names) - { - if (!actions_stack.getSampleBlock().has(argument_name)) - { - arguments_present = false; - break; - } - } - } - - if (arguments_present) - { - actions_stack.addAction( - ExpressionAction::applyFunction(function_builder, argument_names, getColumnName())); } } - else if (const auto * literal = ast->as()) - { - DataTypePtr type = applyVisitor(FieldToDataType(), literal->value); - ColumnWithTypeAndName column; - column.column = type->createColumnConst(1, convertFieldToType(literal->value, *type)); - column.type = type; - column.name = getColumnName(); - - actions_stack.addAction(ExpressionAction::addColumn(column)); - } - else + if (data.only_consts) { - for (auto & child : ast->children) + for (const auto & argument_name : argument_names) { - /// Do not go to FROM, JOIN, UNION. - if (!child->as() && !child->as()) - visit(child); + if (!data.actions_stack.getSampleBlock().has(argument_name)) + { + arguments_present = false; + break; + } } } + + if (arguments_present) + { + data.actions_stack.addAction( + ExpressionAction::applyFunction(function_builder, argument_names, column_name.get(ast))); + } } -SetPtr ActionsVisitor::makeSet(const ASTFunction * node, const Block & sample_block) +void ActionsMatcher::visit(const ASTLiteral & literal, const ASTPtr & ast, Data & data) +{ + CachedColumnName column_name; + + if (data.hasColumn(column_name.get(ast))) + return; + + DataTypePtr type = applyVisitor(FieldToDataType(), literal.value); + + ColumnWithTypeAndName column; + column.column = type->createColumnConst(1, convertFieldToType(literal.value, *type)); + column.type = type; + column.name = column_name.get(ast); + + data.actions_stack.addAction(ExpressionAction::addColumn(column)); +} + +SetPtr ActionsMatcher::makeSet(const ASTFunction & node, Data & data) { /** You need to convert the right argument to a set. * This can be a table name, a value, a value enumeration, or a subquery. * The enumeration of values is parsed as a function `tuple`. */ - const IAST & args = *node->arguments; + const IAST & args = *node.arguments; const ASTPtr & arg = args.children.at(1); + const Block & sample_block = data.actions_stack.getSampleBlock(); /// If the subquery or table name for SELECT. const auto * identifier = arg->as(); if (arg->as() || identifier) { auto set_key = PreparedSetKey::forSubquery(*arg); - if (prepared_sets.count(set_key)) - return prepared_sets.at(set_key); + if (data.prepared_sets.count(set_key)) + return data.prepared_sets.at(set_key); /// A special case is if the name of the table is specified on the right side of the IN statement, /// and the table has the type Set (a previously prepared set). if (identifier) { DatabaseAndTableWithAlias database_table(*identifier); - StoragePtr table = context.tryGetTable(database_table.database, database_table.table); + StoragePtr table = data.context.tryGetTable(database_table.database, database_table.table); if (table) { StorageSet * storage_set = dynamic_cast(table.get()); if (storage_set) { - prepared_sets[set_key] = storage_set->getSet(); + data.prepared_sets[set_key] = storage_set->getSet(); return storage_set->getSet(); } } @@ -576,25 +581,25 @@ SetPtr ActionsVisitor::makeSet(const ASTFunction * node, const Block & sample_bl /// We get the stream of blocks for the subquery. Create Set and put it in place of the subquery. String set_id = arg->getColumnName(); - SubqueryForSet & subquery_for_set = subqueries_for_sets[set_id]; + SubqueryForSet & subquery_for_set = data.subqueries_for_sets[set_id]; /// If you already created a Set with the same subquery / table. if (subquery_for_set.set) { - prepared_sets[set_key] = subquery_for_set.set; + data.prepared_sets[set_key] = subquery_for_set.set; return subquery_for_set.set; } - SetPtr set = std::make_shared(set_size_limit, false); + SetPtr set = std::make_shared(data.set_size_limit, false); /** The following happens for GLOBAL INs: * - in the addExternalStorage function, the IN (SELECT ...) subquery is replaced with IN _data1, * in the subquery_for_set object, this subquery is set as source and the temporary table _data1 as the table. * - this function shows the expression IN_data1. */ - if (!subquery_for_set.source && no_storage_or_local) + if (!subquery_for_set.source && data.no_storage_or_local) { - auto interpreter = interpretSubquery(arg, context, subquery_depth, {}); + auto interpreter = interpretSubquery(arg, data.context, data.subquery_depth, {}); subquery_for_set.source = std::make_shared( interpreter->getSampleBlock(), [interpreter]() mutable { return interpreter->execute().in; }); @@ -627,13 +632,13 @@ SetPtr ActionsVisitor::makeSet(const ASTFunction * node, const Block & sample_bl } subquery_for_set.set = set; - prepared_sets[set_key] = set; + data.prepared_sets[set_key] = set; return set; } else { /// An explicit enumeration of values in parentheses. - return makeExplicitSet(node, sample_block, false, context, set_size_limit, prepared_sets); + return makeExplicitSet(&node, sample_block, false, data.context, data.set_size_limit, data.prepared_sets); } } diff --git a/dbms/src/Interpreters/ActionsVisitor.h b/dbms/src/Interpreters/ActionsVisitor.h index 4d03f758f61..963dd9f8675 100644 --- a/dbms/src/Interpreters/ActionsVisitor.h +++ b/dbms/src/Interpreters/ActionsVisitor.h @@ -4,6 +4,7 @@ #include #include #include +#include namespace DB @@ -52,38 +53,71 @@ struct ScopeStack const Block & getSampleBlock() const; }; +class ASTIdentifier; +class ASTFunction; +class ASTLiteral; /// Collect ExpressionAction from AST. Returns PreparedSets and SubqueriesForSets too. -class ActionsVisitor +class ActionsMatcher { public: - ActionsVisitor(const Context & context_, SizeLimits set_size_limit_, size_t subquery_depth_, - const NamesAndTypesList & source_columns_, const ExpressionActionsPtr & actions, - PreparedSets & prepared_sets_, SubqueriesForSets & subqueries_for_sets_, - bool no_subqueries_, bool only_consts_, bool no_storage_or_local_, std::ostream * ostr_ = nullptr); + using Visitor = ConstInDepthNodeVisitor; - void visit(const ASTPtr & ast, ExpressionActionsPtr & actions) + struct Data { - visit(ast); - actions = actions_stack.popLevel(); - } + const Context & context; + SizeLimits set_size_limit; + size_t subquery_depth; + const NamesAndTypesList & source_columns; + PreparedSets & prepared_sets; + SubqueriesForSets & subqueries_for_sets; + bool no_subqueries; + bool only_consts; + bool no_storage_or_local; + size_t visit_depth; + ScopeStack actions_stack; + + Data(const Context & context_, SizeLimits set_size_limit_, size_t subquery_depth_, + const NamesAndTypesList & source_columns_, const ExpressionActionsPtr & actions, + PreparedSets & prepared_sets_, SubqueriesForSets & subqueries_for_sets_, + bool no_subqueries_, bool only_consts_, bool no_storage_or_local_) + : context(context_), + set_size_limit(set_size_limit_), + subquery_depth(subquery_depth_), + source_columns(source_columns_), + prepared_sets(prepared_sets_), + subqueries_for_sets(subqueries_for_sets_), + no_subqueries(no_subqueries_), + only_consts(only_consts_), + no_storage_or_local(no_storage_or_local_), + visit_depth(0), + actions_stack(actions, context) + {} + + void updateActions(ExpressionActionsPtr & actions) + { + actions = actions_stack.popLevel(); + } + + /// Does result of the calculation already exists in the block. + bool hasColumn(const String & columnName) const + { + return actions_stack.getSampleBlock().has(columnName); + } + }; + + static void visit(const ASTPtr & ast, Data & data); + static bool needChildVisit(const ASTPtr & node, const ASTPtr & child); private: - const Context & context; - SizeLimits set_size_limit; - size_t subquery_depth; - const NamesAndTypesList & source_columns; - PreparedSets & prepared_sets; - SubqueriesForSets & subqueries_for_sets; - const bool no_subqueries; - const bool only_consts; - const bool no_storage_or_local; - mutable size_t visit_depth; - std::ostream * ostr; - ScopeStack actions_stack; - void visit(const ASTPtr & ast); - SetPtr makeSet(const ASTFunction * node, const Block & sample_block); + static void visit(const ASTIdentifier & identifier, const ASTPtr & ast, Data & data); + static void visit(const ASTFunction & node, const ASTPtr & ast, Data & data); + static void visit(const ASTLiteral & literal, const ASTPtr & ast, Data & data); + + static SetPtr makeSet(const ASTFunction & node, Data & data); }; +using ActionsVisitor = ActionsMatcher::Visitor; + } diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.cpp b/dbms/src/Interpreters/ExpressionAnalyzer.cpp index eaba3d568e3..2132259ecaa 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.cpp +++ b/dbms/src/Interpreters/ExpressionAnalyzer.cpp @@ -315,10 +315,11 @@ void SelectQueryExpressionAnalyzer::makeSetsForIndex(const ASTPtr & node) void ExpressionAnalyzer::getRootActions(const ASTPtr & ast, bool no_subqueries, ExpressionActionsPtr & actions, bool only_consts) { LogAST log; - ActionsVisitor actions_visitor(context, settings.size_limits_for_set, subquery_depth, + ActionsVisitor::Data visitor_data(context, settings.size_limits_for_set, subquery_depth, sourceColumns(), actions, prepared_sets, subqueries_for_sets, - no_subqueries, only_consts, !isRemoteStorage(), log.stream()); - actions_visitor.visit(ast, actions); + no_subqueries, only_consts, !isRemoteStorage()); + ActionsVisitor(visitor_data, log.stream()).visit(ast); + visitor_data.updateActions(actions); }