diff --git a/src/Interpreters/ActionsDAG.cpp b/src/Interpreters/ActionsDAG.cpp index a788d6f84e3..0256090abc6 100644 --- a/src/Interpreters/ActionsDAG.cpp +++ b/src/Interpreters/ActionsDAG.cpp @@ -24,7 +24,7 @@ namespace ErrorCodes ActionsDAG::ActionsDAG(const NamesAndTypesList & inputs) { for (const auto & input : inputs) - addInput(input.name, input.type); + addInput(input.name, input.type, true); } ActionsDAG::ActionsDAG(const ColumnsWithTypeAndName & inputs) @@ -32,9 +32,9 @@ ActionsDAG::ActionsDAG(const ColumnsWithTypeAndName & inputs) for (const auto & input : inputs) { if (input.column && isColumnConst(*input.column)) - addInput(input); + addInput(input, true); else - addInput(input.name, input.type); + addInput(input.name, input.type, true); } } @@ -46,6 +46,9 @@ ActionsDAG::Node & ActionsDAG::addNode(Node node, bool can_replace) auto & res = nodes.emplace_back(std::move(node)); + if (res.type == ActionType::INPUT) + inputs.emplace_back(&res); + index.replace(&res); return res; } @@ -59,17 +62,17 @@ ActionsDAG::Node & ActionsDAG::getNode(const std::string & name) return **it; } -const ActionsDAG::Node & ActionsDAG::addInput(std::string name, DataTypePtr type) +const ActionsDAG::Node & ActionsDAG::addInput(std::string name, DataTypePtr type, bool can_replace) { Node node; node.type = ActionType::INPUT; node.result_type = std::move(type); node.result_name = std::move(name); - return addNode(std::move(node)); + return addNode(std::move(node), can_replace); } -const ActionsDAG::Node & ActionsDAG::addInput(ColumnWithTypeAndName column) +const ActionsDAG::Node & ActionsDAG::addInput(ColumnWithTypeAndName column, bool can_replace) { Node node; node.type = ActionType::INPUT; @@ -77,7 +80,7 @@ const ActionsDAG::Node & ActionsDAG::addInput(ColumnWithTypeAndName column) node.result_name = std::move(column.name); node.column = std::move(column.column); - return addNode(std::move(node)); + return addNode(std::move(node), can_replace); } const ActionsDAG::Node & ActionsDAG::addColumn(ColumnWithTypeAndName column) @@ -144,6 +147,14 @@ const ActionsDAG::Node & ActionsDAG::addFunction( compilation_cache = context.getCompiledExpressionCache(); #endif + return addFunction(function, argument_names, std::move(result_name)); +} + +const ActionsDAG::Node & ActionsDAG::addFunction( + const FunctionOverloadResolverPtr & function, + const Names & argument_names, + std::string result_name) +{ size_t num_arguments = argument_names.size(); Node node; @@ -231,9 +242,8 @@ const ActionsDAG::Node & ActionsDAG::addFunction( NamesAndTypesList ActionsDAG::getRequiredColumns() const { NamesAndTypesList result; - for (const auto & node : nodes) - if (node.type == ActionType::INPUT) - result.emplace_back(node.result_name, node.result_type); + for (const auto & input : inputs) + result.emplace_back(input->result_name, input->result_type); return result; } @@ -347,6 +357,8 @@ void ActionsDAG::removeUnusedActions() } nodes.remove_if([&](const Node & node) { return visited_nodes.count(&node) == 0; }); + auto it = std::remove_if(inputs.begin(), inputs.end(), [&](const Node * node) { return visited_nodes.count(node) == 0; }); + inputs.erase(it, inputs.end()); } void ActionsDAG::addAliases(const NamesWithAliases & aliases, std::vector & result_nodes) @@ -442,6 +454,9 @@ ActionsDAGPtr ActionsDAG::clone() const for (const auto & node : index) actions->index.insert(copy_map[node]); + for (const auto & node : inputs) + actions->inputs.push_back(copy_map[node]); + return actions; } @@ -540,6 +555,7 @@ ActionsDAGPtr ActionsDAG::splitActionsBeforeArrayJoin(const NameSet & array_join std::list split_nodes; Index this_index; Index split_index; + Inputs new_inputs; struct Frame { @@ -627,6 +643,7 @@ ActionsDAGPtr ActionsDAG::splitActionsBeforeArrayJoin(const NameSet & array_join input_node.result_type = child->result_type; input_node.result_name = child->result_name; // getUniqueNameForIndex(index, child->result_name); child_data.to_this = &this_nodes.emplace_back(std::move(input_node)); + new_inputs.push_back(child_data.to_this); /// This node is needed for current action, so put it to index also. split_index.replace(child_data.to_split); @@ -658,6 +675,7 @@ ActionsDAGPtr ActionsDAG::splitActionsBeforeArrayJoin(const NameSet & array_join input_node.result_type = node.result_type; input_node.result_name = node.result_name; cur_data.to_this = &this_nodes.emplace_back(std::move(input_node)); + new_inputs.push_back(cur_data.to_this); } } } @@ -676,12 +694,28 @@ ActionsDAGPtr ActionsDAG::splitActionsBeforeArrayJoin(const NameSet & array_join if (split_actions_are_empty) return {}; + Inputs this_inputs; + Inputs split_inputs; + + for (auto * input : inputs) + { + const auto & cur = data[input]; + if (cur.to_this) + this_inputs.push_back(cur.to_this); + else + split_inputs.push_back(cur.to_split); + } + + this_inputs.insert(this_inputs.end(), new_inputs.begin(), new_inputs.end()); + index.swap(this_index); nodes.swap(this_nodes); + inputs.swap(this_inputs); auto split_actions = cloneEmpty(); split_actions->nodes.swap(split_nodes); split_actions->index.swap(split_index); + split_actions->inputs.swap(split_inputs); split_actions->settings.project_input = false; return split_actions; diff --git a/src/Interpreters/ActionsDAG.h b/src/Interpreters/ActionsDAG.h index 4765456ca4f..5a5dbebdedd 100644 --- a/src/Interpreters/ActionsDAG.h +++ b/src/Interpreters/ActionsDAG.h @@ -151,6 +151,7 @@ public: }; using Nodes = std::list; + using Inputs = std::vector; struct ActionsSettings { @@ -165,6 +166,7 @@ public: private: Nodes nodes; Index index; + Inputs inputs; ActionsSettings settings; @@ -181,6 +183,7 @@ public: const Nodes & getNodes() const { return nodes; } const Index & getIndex() const { return index; } + const Inputs & getInputs() const { return inputs; } NamesAndTypesList getRequiredColumns() const; ColumnsWithTypeAndName getResultColumns() const; @@ -190,11 +193,15 @@ public: std::string dumpNames() const; std::string dumpDAG() const; - const Node & addInput(std::string name, DataTypePtr type); - const Node & addInput(ColumnWithTypeAndName column); + const Node & addInput(std::string name, DataTypePtr type, bool can_replace = false); + const Node & addInput(ColumnWithTypeAndName column, bool can_replace = false); const Node & addColumn(ColumnWithTypeAndName column); const Node & addAlias(const std::string & name, std::string alias, bool can_replace = false); const Node & addArrayJoin(const std::string & source_name, std::string result_name); + const Node & addFunction( + const FunctionOverloadResolverPtr & function, + const Names & argument_names, + std::string result_name); const Node & addFunction( const FunctionOverloadResolverPtr & function, const Names & argument_names, diff --git a/src/Interpreters/ExpressionActions.cpp b/src/Interpreters/ExpressionActions.cpp index 53c08481fc2..4c332036b41 100644 --- a/src/Interpreters/ExpressionActions.cpp +++ b/src/Interpreters/ExpressionActions.cpp @@ -83,6 +83,7 @@ void ExpressionActions::linearizeActions() const auto & nodes = getNodes(); const auto & index = actions_dag->getIndex(); + const auto & inputs = actions_dag->getInputs(); std::vector data(nodes.size()); std::unordered_map reverse_index; @@ -163,11 +164,11 @@ void ExpressionActions::linearizeActions() { /// Argument for input is special. It contains the position from required columns. ExpressionActions::Argument argument; - argument.pos = required_columns.size(); + // argument.pos = required_columns.size(); argument.needed_later = !cur.parents.empty(); arguments.emplace_back(argument); - required_columns.push_back({node->result_name, node->result_type}); + //required_columns.push_back({node->result_name, node->result_type}); } actions.push_back({node, arguments, free_position}); @@ -199,6 +200,15 @@ void ExpressionActions::linearizeActions() ColumnWithTypeAndName col{node->column, node->result_type, node->result_name}; sample_block.insert(std::move(col)); } + + for (const auto * input : inputs) + { + const auto & cur = data[reverse_index[input]]; + auto pos = required_columns.size(); + actions[cur.position].arguments.front().pos = pos; + required_columns.push_back({input->result_name, input->result_type}); + input_positions[input->result_name].emplace_back(pos); + } } @@ -412,7 +422,24 @@ void ExpressionActions::execute(Block & block, size_t & num_rows, bool dry_run) .num_rows = num_rows, }; - execution_context.inputs_pos.reserve(required_columns.size()); + execution_context.inputs_pos.assign(required_columns.size(), -1); + + for (size_t pos = 0; pos < block.columns(); ++pos) + { + const auto & col = block.getByPosition(pos); + auto it = input_positions.find(col.name); + if (it != input_positions.end()) + { + for (auto input_pos : it->second) + { + if (execution_context.inputs_pos[input_pos] < 0) + { + execution_context.inputs_pos[input_pos] = pos; + break; + } + } + } + } for (const auto & column : required_columns) { diff --git a/src/Interpreters/ExpressionActions.h b/src/Interpreters/ExpressionActions.h index f2f5862856b..2b1aa5e2456 100644 --- a/src/Interpreters/ExpressionActions.h +++ b/src/Interpreters/ExpressionActions.h @@ -44,10 +44,10 @@ public: struct Argument { /// Position in ExecutionContext::columns - size_t pos; + size_t pos = 0; /// True if there is another action which will use this column. /// Otherwise column will be removed. - bool needed_later; + bool needed_later = false; }; using Arguments = std::vector; @@ -63,6 +63,11 @@ public: using Actions = std::vector; + /// This map helps to find input position bu it's name. + /// Key is a view to input::result_name. + /// Result is a list because it is allowed for inputs to have same names. + using NameToInputMap = std::unordered_map>; + private: ActionsDAGPtr actions_dag; @@ -70,6 +75,7 @@ private: size_t num_columns = 0; NamesAndTypesList required_columns; + NameToInputMap input_positions; ColumnNumbers result_positions; Block sample_block;