diff --git a/src/Interpreters/ActionsDAG.cpp b/src/Interpreters/ActionsDAG.cpp
index 3ea48af7daa..00ba4d6d21c 100644
--- a/src/Interpreters/ActionsDAG.cpp
+++ b/src/Interpreters/ActionsDAG.cpp
@@ -670,6 +670,113 @@ ActionsDAGPtr ActionsDAG::makeConvertingActions(
     return actions_dag;
 }
 
+/// Creates a single DAG equivalent to executing `lhs` and then `rhs` on the result of `lhs`.
+///
+/// Each input of `rhs` is looked up by name among the columns of the `lhs` result. Columns with
+/// the same name are consumed in the order they appear in the `lhs` index, each at most once.
+/// An input with no matching column becomes an input of the merged DAG; if `lhs` projects its
+/// input, LOGICAL_ERROR is thrown instead.
+///
+/// Both arguments are consumed: the merged DAG is assembled in `lhs` and moved into the result.
+ActionsDAGPtr ActionsDAG::merge(ActionsDAG && lhs, ActionsDAG && rhs)
+{
+    /// lhs: x (1), x (2), y ==> x (2), z, x (3)
+    /// rhs: x (1), x (2), x (3) ==> x (3), x (2), x (1)
+    /// merge: x (1), x (2), x (3), y =(lhs)=> x (3), y, x (2), z, x (4) =(rhs)=> y, z, x (4), x (2), x (3)
+
+    /// Will store merged result in lhs.
+
+    /// Nodes of the lhs result which are consumed as rhs inputs.
+    std::unordered_set<Node *> removed_lhs_result;
+    /// Maps an input node of rhs to the lhs result node which replaces it.
+    std::unordered_map<Node *, Node *> inputs_map;
+
+    /// Update inputs list.
+    {
+        /// The lhs result may have several columns with the same name; keep them in index order.
+        std::unordered_map<std::string_view, std::list<Node *>> lhs_result;
+        for (auto * node : lhs.index)
+            lhs_result[node->result_name].push_back(node);
+
+        for (auto * node : rhs.inputs)
+        {
+            auto it = lhs_result.find(node->result_name);
+            if (it == lhs_result.end() || it->second.empty())
+            {
+                if (lhs.settings.project_input)
+                    throw Exception(ErrorCodes::LOGICAL_ERROR,
+                                    "Cannot find column {} in ActionsDAG result", node->result_name);
+
+                lhs.inputs.push_back(node);
+            }
+            else
+            {
+                auto * replacement = it->second.front();
+                inputs_map[node] = replacement;
+                removed_lhs_result.emplace(replacement);
+                it->second.pop_front();
+            }
+        }
+    }
+
+    /// Returns the node which represents `node` in the merged DAG: an rhs input matched above is
+    /// replaced by its lhs counterpart, any other node is kept.
+    auto resolve = [&inputs_map](Node * node) -> Node *
+    {
+        if (node->type != ActionType::INPUT)
+            return node;
+
+        auto it = inputs_map.find(node);
+        return it == inputs_map.end() ? node : it->second;
+    };
+
+    /// Update index.
+    /// A projecting rhs keeps only its own result, so the whole lhs result is dropped;
+    /// otherwise only the columns consumed as rhs inputs are dropped.
+    const bool drop_lhs_result = rhs.settings.project_input;
+    if (drop_lhs_result)
+        lhs.settings.project_input = true;
+
+    for (auto it = lhs.index.begin(); it != lhs.index.end();)
+    {
+        /// Step forward before removing: `remove` erases the list element under `cur`.
+        auto cur = it;
+        ++it;
+
+        if (drop_lhs_result || removed_lhs_result.count(*cur))
+            lhs.index.remove(cur);
+    }
+
+    /// Append the rhs result. If it contains an rhs INPUT node which was matched above, the lhs
+    /// node must be stored instead: the matched rhs input is not added to the inputs of the
+    /// merged DAG, so leaving it in the index would expose an input which nothing provides.
+    /// NOTE(review): assumes `Index::insert` appends to the end of the index - confirm.
+    for (auto * node : rhs.index)
+        lhs.index.insert(resolve(node));
+
+    /// Replace inputs from rhs to nodes from lhs result.
+    for (auto & node : rhs.nodes)
+        for (auto & child : node.children)
+            child = resolve(child);
+
+    lhs.nodes.splice(lhs.nodes.end(), std::move(rhs.nodes));
+
+#if USE_EMBEDDED_COMPILER
+    if (lhs.compilation_cache == nullptr)
+        lhs.compilation_cache = rhs.compilation_cache;
+#endif
+
+    lhs.settings.max_temporary_columns = std::max(lhs.settings.max_temporary_columns, rhs.settings.max_temporary_columns);
+    lhs.settings.max_temporary_non_const_columns = std::max(lhs.settings.max_temporary_non_const_columns, rhs.settings.max_temporary_non_const_columns);
+    lhs.settings.min_count_to_compile_expression = std::max(lhs.settings.min_count_to_compile_expression, rhs.settings.min_count_to_compile_expression);
+    lhs.settings.projected_output = rhs.settings.projected_output;
+
+    /// Drop unused inputs and, probably, some actions.
+    lhs.removeUnusedActions();
+
+    return std::make_shared<ActionsDAG>(std::move(lhs));
+}
+
 ActionsDAGPtr ActionsDAG::splitActionsBeforeArrayJoin(const NameSet & array_joined_columns)
 {
     /// Split DAG into two parts.
diff --git a/src/Interpreters/ActionsDAG.h b/src/Interpreters/ActionsDAG.h
index ca54dab7231..704415c91d2 100644
--- a/src/Interpreters/ActionsDAG.h
+++ b/src/Interpreters/ActionsDAG.h
@@ -143,6 +143,18 @@ public:
             map.erase(it);
         }
 
+        /// Removes the element pointed to by `it` from the index.
+        /// The name lookup entry is erased only when it refers to this very element, so an entry
+        /// which refers to another element with the same name is left intact.
+        void remove(std::list<Node *>::iterator it)
+        {
+            auto map_it = map.find((*it)->result_name);
+            if (map_it != map.end() && map_it->second == it)
+                map.erase(map_it);
+
+            list.erase(it);
+        }
+
         void swap(Index & other)
         {
             list.swap(other.list);
@@ -176,6 +188,8 @@ private:
 
 public:
     ActionsDAG() = default;
+    /// Movable (copying stays deleted), so that a DAG can be passed to `merge` by rvalue.
+    ActionsDAG(ActionsDAG &&) = default;
     ActionsDAG(const ActionsDAG &) = delete;
     ActionsDAG & operator=(const ActionsDAG &) = delete;
     explicit ActionsDAG(const NamesAndTypesList & inputs_);
@@ -248,6 +262,11 @@ public:
         MatchColumnsMode mode,
         bool ignore_constant_values = false); /// Do not check that constants are same. Use value from result_header.
 
+    /// Create ActionsDAG which represents expression equivalent to applying lhs and rhs actions consecutively.
+    /// Is used to replace `(lhs -> rhs)` expression chain to single `merge(lhs, rhs)` expression.
+    /// Both arguments are taken by rvalue reference and are consumed by the call.
+    static ActionsDAGPtr merge(ActionsDAG && lhs, ActionsDAG && rhs);
+
 private:
     Node & addNode(Node node, bool can_replace = false);
     Node & getNode(const std::string & name);