diff --git a/src/Interpreters/ExpressionActions.cpp b/src/Interpreters/ExpressionActions.cpp index 3772ec4fcf7..16c01b8747a 100644 --- a/src/Interpreters/ExpressionActions.cpp +++ b/src/Interpreters/ExpressionActions.cpp @@ -1251,7 +1251,7 @@ void ExpressionActionsChain::addStep() throw Exception("Cannot add action to empty ExpressionActionsChain", ErrorCodes::LOGICAL_ERROR); ColumnsWithTypeAndName columns = steps.back()->getResultColumns(); - steps.push_back(std::make_unique(std::make_shared(columns, context))); + steps.push_back(std::make_unique(std::make_shared(columns))); } void ExpressionActionsChain::finalize() @@ -1398,12 +1398,17 @@ void ExpressionActionsChain::JoinStep::finalize(const Names & required_output_) std::swap(result_columns, new_result_columns); } -ExpressionActionsPtr & ExpressionActionsChain::Step::actions() +ActionsDAGPtr & ExpressionActionsChain::Step::actions() { - return typeid_cast(this)->actions; + return typeid_cast(this)->actions_dag; } -const ExpressionActionsPtr & ExpressionActionsChain::Step::actions() const +const ActionsDAGPtr & ExpressionActionsChain::Step::actions() const +{ + return typeid_cast(this)->actions_dag; +} + +ExpressionActionsPtr ExpressionActionsChain::Step::getExpression() const { return typeid_cast(this)->actions; } @@ -1420,13 +1425,18 @@ ActionsDAG::ActionsDAG(const ColumnsWithTypeAndName & inputs) addInput(input.name, input.type); } -ActionsDAG::Node & ActionsDAG::addNode(Node node) +ActionsDAG::Node & ActionsDAG::addNode(Node node, bool can_replace) { - if (index.count(node.result_name) != 0) + auto it = index.find(node.result_name); + if (it != index.end() && !can_replace) throw Exception("Column '" + node.result_name + "' already exists", ErrorCodes::DUPLICATE_COLUMN); auto & res = nodes.emplace_back(std::move(node)); index[res.result_name] = &res; + + if (it != index.end()) + it->second->renaming_parent = &res; + return res; } @@ -1460,7 +1470,7 @@ const ActionsDAG::Node & ActionsDAG::addColumn(ColumnWithTypeAndName column) return addNode(std::move(node)); } -const ActionsDAG::Node & ActionsDAG::addAlias(const std::string & name, std::string alias) +const ActionsDAG::Node & ActionsDAG::addAlias(const std::string & name, std::string alias, bool can_replace) { auto & child = getNode(name); @@ -1472,7 +1482,7 @@ const ActionsDAG::Node & ActionsDAG::addAlias(const std::string & name, std::str node.allow_constant_folding = child.allow_constant_folding; node.children.emplace_back(&child); - return addNode(std::move(node)); + return addNode(std::move(node), can_replace); } const ActionsDAG::Node & ActionsDAG::addArrayJoin(const std::string & source_name, std::string result_name) @@ -1591,9 +1601,10 @@ const ActionsDAG::Node & ActionsDAG::addFunction( ColumnsWithTypeAndName ActionsDAG::getResultColumns() const { ColumnsWithTypeAndName result; - result.reserve(nodes.size()); + result.reserve(index.size()); for (const auto & node : nodes) - result.emplace_back(node.column, node.result_type, node.result_name); + if (!node.renaming_parent) + result.emplace_back(node.column, node.result_type, node.result_name); return result; } @@ -1602,11 +1613,23 @@ NamesAndTypesList ActionsDAG::getNamesAndTypesList() const { NamesAndTypesList result; for (const auto & node : nodes) - result.emplace_back(node.result_name, node.result_type); + if (!node.renaming_parent) + result.emplace_back(node.result_name, node.result_type); return result; } +Names ActionsDAG::getNames() const +{ + Names names; + names.reserve(index.size()); + for (const auto & node : nodes) + if (!node.renaming_parent) + names.emplace_back(node.result_name); + + return names; +} + std::string ActionsDAG::dumpNames() const { WriteBufferFromOwnString out; @@ -1625,7 +1648,9 @@ ExpressionActionsPtr ActionsDAG::buildExpressions(const Context & context) { Node * node = nullptr; size_t num_created_children = 0; + size_t num_expected_children = 0; std::vector parents; + Node * renamed_child = nullptr; }; std::vector data(nodes.size()); @@ -1643,13 +1668,38 @@ ExpressionActionsPtr ActionsDAG::buildExpressions(const Context & context) for (auto & node : nodes) { + data[reverse_index[&node]].num_expected_children += node.children.size(); + for (const auto & child : node.children) data[reverse_index[child]].parents.emplace_back(&node); - if (node.children.empty()) + if (node.renaming_parent) + { + + auto & cur = data[reverse_index[node.renaming_parent]]; + cur.renamed_child = &node; + cur.num_expected_children += 1; + } + } + + for (auto & node : nodes) + { + if (node.children.empty() && data[reverse_index[&node]].renamed_child == nullptr) ready_nodes.emplace(&node); } + auto update_parent = [&](Node * parent) + { + auto & cur = data[reverse_index[parent]]; + ++cur.num_created_children; + + if (cur.num_created_children == cur.num_expected_children) + { + auto & push_stack = parent->type == Type::ARRAY_JOIN ? ready_array_joins : ready_nodes; + push_stack.push(parent); + } + }; + auto expressions = std::make_shared(NamesAndTypesList(), context); while (!ready_nodes.empty() || !ready_array_joins.empty()) @@ -1662,6 +1712,8 @@ ExpressionActionsPtr ActionsDAG::buildExpressions(const Context & context) for (const auto & child : node->children) argument_names.emplace_back(child->result_name); + auto & cur = data[reverse_index[node]]; + switch (node->type) { case Type::INPUT: @@ -1671,7 +1723,7 @@ ExpressionActionsPtr ActionsDAG::buildExpressions(const Context & context) expressions->add(ExpressionAction::addColumn({node->column, node->result_type, node->result_name})); break; case Type::ALIAS: - expressions->add(ExpressionAction::copyColumn(argument_names.at(0), node->result_name)); + expressions->add(ExpressionAction::copyColumn(argument_names.at(0), node->result_name, cur.renamed_child != nullptr)); break; case Type::ARRAY_JOIN: expressions->add(ExpressionAction::arrayJoin(argument_names.at(0), node->result_name)); @@ -1681,17 +1733,11 @@ ExpressionActionsPtr ActionsDAG::buildExpressions(const Context & context) break; } - for (const auto & parent : data[reverse_index[node]].parents) - { - auto & cur = data[reverse_index[parent]]; - ++cur.num_created_children; + for (const auto & parent : cur.parents) + update_parent(parent); - if (parent->children.size() == cur.num_created_children) - { - auto & push_stack = parent->type == Type::ARRAY_JOIN ? ready_array_joins : ready_nodes; - push_stack.push(parent); - } - } + if (node->renaming_parent) + update_parent(node->renaming_parent); } return expressions; diff --git a/src/Interpreters/ExpressionActions.h b/src/Interpreters/ExpressionActions.h index 79107d3baa9..e6e5c038ac3 100644 --- a/src/Interpreters/ExpressionActions.h +++ b/src/Interpreters/ExpressionActions.h @@ -160,6 +160,8 @@ public: struct Node { std::vector children; + /// This field is filled if current node is replaced by existing node with the same name. + Node * renaming_parent = nullptr; Type type; @@ -192,16 +194,16 @@ public: ActionsDAG(const NamesAndTypesList & inputs); ActionsDAG(const ColumnsWithTypeAndName & inputs); - const std::list & getNodes() const; const Index & getIndex() const { return index; } ColumnsWithTypeAndName getResultColumns() const; NamesAndTypesList getNamesAndTypesList() const; + Names getNames() const; std::string dumpNames() const; const Node & addInput(std::string name, DataTypePtr type); const Node & addColumn(ColumnWithTypeAndName column); - const Node & addAlias(const std::string & name, std::string alias); + const Node & addAlias(const std::string & name, std::string alias, bool can_replace); const Node & addArrayJoin(const std::string & source_name, std::string result_name); const Node & addFunction( const FunctionOverloadResolverPtr & function, @@ -212,10 +214,12 @@ public: ExpressionActionsPtr buildExpressions(const Context & context); private: - Node & addNode(Node node); + Node & addNode(Node node, bool can_replace = false); Node & getNode(const std::string & name); }; +using ActionsDAGPtr = std::shared_ptr; + /** Contains a sequence of actions on the block. */ class ExpressionActions @@ -363,17 +367,19 @@ struct ExpressionActionsChain virtual std::string dump() const = 0; /// Only for ExpressionActionsStep - ExpressionActionsPtr & actions(); - const ExpressionActionsPtr & actions() const; + ActionsDAGPtr & actions(); + const ActionsDAGPtr & actions() const; + ExpressionActionsPtr getExpression() const; }; struct ExpressionActionsStep : public Step { + ActionsDAGPtr actions_dag; ExpressionActionsPtr actions; - explicit ExpressionActionsStep(ExpressionActionsPtr actions_, Names required_output_ = Names()) + explicit ExpressionActionsStep(ActionsDAGPtr actions_, Names required_output_ = Names()) : Step(std::move(required_output_)) - , actions(std::move(actions_)) + , actions_dag(std::move(actions_)) { } @@ -458,7 +464,9 @@ struct ExpressionActionsChain throw Exception("Empty ExpressionActionsChain", ErrorCodes::LOGICAL_ERROR); } - return steps.back()->actions(); + auto * step = typeid_cast(&steps.back()); + step->actions = step->actions_dag->buildExpressions(context); + return step->actions; } Step & getLastStep() @@ -472,7 +480,7 @@ struct ExpressionActionsChain Step & lastStep(const NamesAndTypesList & columns) { if (steps.empty()) - steps.emplace_back(std::make_unique(std::make_shared(columns, context))); + steps.emplace_back(std::make_unique(std::make_shared(columns))); return *steps.back(); } diff --git a/src/Interpreters/ExpressionAnalyzer.cpp b/src/Interpreters/ExpressionAnalyzer.cpp index 053d353bdfb..779c9ee7bf7 100644 --- a/src/Interpreters/ExpressionAnalyzer.cpp +++ b/src/Interpreters/ExpressionAnalyzer.cpp @@ -472,8 +472,8 @@ ArrayJoinActionPtr SelectQueryExpressionAnalyzer::appendArrayJoin(ExpressionActi getRootActions(array_join_expression_list, only_types, step.actions()); - before_array_join = chain.getLastActions(); auto array_join = addMultipleArrayJoinAction(step.actions(), is_array_join_left); + before_array_join = chain.getLastActions(); chain.steps.push_back(std::make_unique( array_join, step.getResultColumns())); @@ -615,13 +615,14 @@ JoinPtr SelectQueryExpressionAnalyzer::makeTableJoin(const ASTTablesInSelectQuer return subquery_for_join.join; } -bool SelectQueryExpressionAnalyzer::appendPrewhere( +ExpressionActionsPtr SelectQueryExpressionAnalyzer::appendPrewhere( ExpressionActionsChain & chain, bool only_types, const Names & additional_required_columns) { const auto * select_query = getSelectQuery(); + ExpressionActionsPtr prewhere_actions; if (!select_query->prewhere()) - return false; + return prewhere_actions; auto & step = chain.lastStep(sourceColumns()); getRootActions(select_query->prewhere(), only_types, step.actions()); @@ -629,15 +630,16 @@ bool SelectQueryExpressionAnalyzer::appendPrewhere( step.required_output.push_back(prewhere_column_name); step.can_remove_required_output.push_back(true); - auto filter_type = step.actions()->getSampleBlock().getByName(prewhere_column_name).type; + auto filter_type = step.actions()->getIndex().find(prewhere_column_name)->second->result_type; if (!filter_type->canBeUsedInBooleanContext()) throw Exception("Invalid type for filter in PREWHERE: " + filter_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_COLUMN_FOR_FILTER); { /// Remove unused source_columns from prewhere actions. - auto tmp_actions = std::make_shared(sourceColumns(), context); - getRootActions(select_query->prewhere(), only_types, tmp_actions); + auto tmp_actions_dag = std::make_shared(sourceColumns()); + getRootActions(select_query->prewhere(), only_types, tmp_actions_dag); + auto tmp_actions = tmp_actions_dag->buildExpressions(context); tmp_actions->finalize({prewhere_column_name}); auto required_columns = tmp_actions->getRequiredColumns(); NameSet required_source_columns(required_columns.begin(), required_columns.end()); @@ -653,7 +655,7 @@ bool SelectQueryExpressionAnalyzer::appendPrewhere( } } - auto names = step.actions()->getSampleBlock().getNames(); + auto names = step.actions()->getNames(); NameSet name_set(names.begin(), names.end()); for (const auto & column : sourceColumns()) @@ -661,7 +663,8 @@ bool SelectQueryExpressionAnalyzer::appendPrewhere( name_set.erase(column.name); Names required_output(name_set.begin(), name_set.end()); - step.actions()->finalize(required_output); + prewhere_actions = chain.getLastActions(); + prewhere_actions->finalize(required_output); } { @@ -672,8 +675,8 @@ bool SelectQueryExpressionAnalyzer::appendPrewhere( /// 2. Store side columns which were calculated during prewhere actions execution if they are used. /// Example: select F(A) prewhere F(A) > 0. F(A) can be saved from prewhere step. /// 3. Check if we can remove filter column at prewhere step. If we can, action will store single REMOVE_COLUMN. - ColumnsWithTypeAndName columns = step.actions()->getSampleBlock().getColumnsWithTypeAndName(); - auto required_columns = step.actions()->getRequiredColumns(); + ColumnsWithTypeAndName columns = prewhere_actions->getSampleBlock().getColumnsWithTypeAndName(); + auto required_columns = prewhere_actions->getRequiredColumns(); NameSet prewhere_input_names(required_columns.begin(), required_columns.end()); NameSet unused_source_columns; @@ -687,11 +690,13 @@ bool SelectQueryExpressionAnalyzer::appendPrewhere( } chain.steps.emplace_back(std::make_unique( - std::make_shared(std::move(columns), context))); + std::make_shared(std::move(columns)))); chain.steps.back()->additional_input = std::move(unused_source_columns); + chain.getLastActions(); + chain.addStep(); } - return true; + return prewhere_actions; } void SelectQueryExpressionAnalyzer::appendPreliminaryFilter(ExpressionActionsChain & chain, ExpressionActionsPtr actions, String column_name) @@ -699,7 +704,8 @@ void SelectQueryExpressionAnalyzer::appendPreliminaryFilter(ExpressionActionsCha ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns()); // FIXME: assert(filter_info); - step.actions() = std::move(actions); + auto * expression_step = typeid_cast(&step); + expression_step->actions = std::move(actions); step.required_output.push_back(std::move(column_name)); step.can_remove_required_output = {true}; @@ -721,7 +727,7 @@ bool SelectQueryExpressionAnalyzer::appendWhere(ExpressionActionsChain & chain, getRootActions(select_query->where(), only_types, step.actions()); - auto filter_type = step.actions()->getSampleBlock().getByName(where_column_name).type; + auto filter_type = step.actions()->getIndex().find(where_column_name)->second->result_type; if (!filter_type->canBeUsedInBooleanContext()) throw Exception("Invalid type for filter in WHERE: " + filter_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_COLUMN_FOR_FILTER); @@ -750,8 +756,9 @@ bool SelectQueryExpressionAnalyzer::appendGroupBy(ExpressionActionsChain & chain { for (auto & child : asts) { - group_by_elements_actions.emplace_back(std::make_shared(columns_after_join, context)); - getRootActions(child, only_types, group_by_elements_actions.back()); + auto actions_dag = std::make_shared(columns_after_join); + getRootActions(child, only_types, actions_dag); + group_by_elements_actions.emplace_back(actions_dag->buildExpressions(context)); } } @@ -838,8 +845,9 @@ bool SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain { for (auto & child : select_query->orderBy()->children) { - order_by_elements_actions.emplace_back(std::make_shared(columns_after_join, context)); - getRootActions(child, only_types, order_by_elements_actions.back()); + auto actions_dag = std::make_shared(columns_after_join); + getRootActions(child, only_types, actions_dag); + order_by_elements_actions.emplace_back(actions_dag->buildExpressions(context)); } } return true; @@ -919,7 +927,7 @@ void SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActionsChain & } } - step.actions()->add(ExpressionAction::project(result_columns)); + chain.getLastActions()->add(ExpressionAction::project(result_columns)); } @@ -933,7 +941,7 @@ void ExpressionAnalyzer::appendExpression(ExpressionActionsChain & chain, const ExpressionActionsPtr ExpressionAnalyzer::getActions(bool add_aliases, bool project_result) { - ExpressionActionsPtr actions = std::make_shared(aggregated_columns, context); + auto actions_dag = std::make_shared(aggregated_columns); NamesWithAliases result_columns; Names result_names; @@ -954,9 +962,11 @@ ExpressionActionsPtr ExpressionAnalyzer::getActions(bool add_aliases, bool proje alias = name; result_columns.emplace_back(name, alias); result_names.push_back(alias); - getRootActions(ast, false, actions); + getRootActions(ast, false, actions_dag); } + auto actions = actions_dag->buildExpressions(context); + if (add_aliases) { if (project_result) @@ -980,10 +990,10 @@ ExpressionActionsPtr ExpressionAnalyzer::getActions(bool add_aliases, bool proje ExpressionActionsPtr ExpressionAnalyzer::getConstActions() { - ExpressionActionsPtr actions = std::make_shared(NamesAndTypesList(), context); + auto actions = std::make_shared(NamesAndTypesList()); getRootActions(query, true, actions, true); - return actions; + return actions->buildExpressions(context); } ExpressionActionsPtr SelectQueryExpressionAnalyzer::simpleSelectActions() @@ -1064,10 +1074,9 @@ ExpressionAnalysisResult::ExpressionAnalysisResult( query_analyzer.appendPreliminaryFilter(chain, filter_info->actions, filter_info->column_name); } - if (query_analyzer.appendPrewhere(chain, !first_stage, additional_required_columns_after_prewhere)) + if (auto actions = query_analyzer.appendPrewhere(chain, !first_stage, additional_required_columns_after_prewhere)) { - prewhere_info = std::make_shared( - chain.steps.front()->actions(), query.prewhere()->getColumnName()); + prewhere_info = std::make_shared(actions, query.prewhere()->getColumnName()); if (allowEarlyConstantFolding(*prewhere_info->prewhere_actions, settings)) { @@ -1081,7 +1090,6 @@ ExpressionAnalysisResult::ExpressionAnalysisResult( prewhere_constant_filter_description = ConstantFilterDescription(*column_elem.column); } } - chain.addStep(); } array_join = query_analyzer.appendArrayJoin(chain, before_array_join, only_types || !first_stage); diff --git a/src/Interpreters/ExpressionAnalyzer.h b/src/Interpreters/ExpressionAnalyzer.h index bf4a4f564a4..7728cd9e6ea 100644 --- a/src/Interpreters/ExpressionAnalyzer.h +++ b/src/Interpreters/ExpressionAnalyzer.h @@ -319,7 +319,7 @@ private: void appendPreliminaryFilter(ExpressionActionsChain & chain, ExpressionActionsPtr actions, String column_name); /// remove_filter is set in ExpressionActionsChain::finalize(); /// Columns in `additional_required_columns` will not be removed (they can be used for e.g. sampling or FINAL modifier). - bool appendPrewhere(ExpressionActionsChain & chain, bool only_types, const Names & additional_required_columns); + ExpressionActionsPtr appendPrewhere(ExpressionActionsChain & chain, bool only_types, const Names & additional_required_columns); bool appendWhere(ExpressionActionsChain & chain, bool only_types); bool appendGroupBy(ExpressionActionsChain & chain, bool only_types, bool optimize_aggregation_in_order, ManyExpressionActions &); void appendAggregateFunctionsArguments(ExpressionActionsChain & chain, bool only_types); diff --git a/src/Interpreters/MutationsInterpreter.cpp b/src/Interpreters/MutationsInterpreter.cpp index 089e3d1c23f..2639c94a9ca 100644 --- a/src/Interpreters/MutationsInterpreter.cpp +++ b/src/Interpreters/MutationsInterpreter.cpp @@ -612,8 +612,8 @@ ASTPtr MutationsInterpreter::prepareInterpreterSelectQuery(std::vector & for (const auto & kv : stage.column_to_updated) { - actions_chain.getLastActions()->add(ExpressionAction::copyColumn( - kv.second->getColumnName(), kv.first, /* can_replace = */ true)); + actions_chain.getLastStep().actions()->addAlias( + kv.second->getColumnName(), kv.first, /* can_replace = */ true); } } @@ -624,7 +624,7 @@ ASTPtr MutationsInterpreter::prepareInterpreterSelectQuery(std::vector & actions_chain.finalize(); /// Propagate information about columns needed as input. - for (const auto & column : actions_chain.steps.front()->actions()->getRequiredColumnsWithTypes()) + for (const auto & column : actions_chain.steps.front()->getRequiredColumns()) prepared_stages[i - 1].output_columns.insert(column.name); } @@ -670,7 +670,7 @@ void MutationsInterpreter::addStreamsForLaterStages(const std::vector & p /// Execute DELETEs. pipeline.addSimpleTransform([&](const Block & header) { - return std::make_shared(header, step->actions(), stage.filter_column_names[i], false); + return std::make_shared(header, step->getExpression(), stage.filter_column_names[i], false); }); } else @@ -678,7 +678,7 @@ void MutationsInterpreter::addStreamsForLaterStages(const std::vector & p /// Execute UPDATE or final projection. pipeline.addSimpleTransform([&](const Block & header) { - return std::make_shared(header, step->actions()); + return std::make_shared(header, step->getExpression()); }); } }