Add converting logic to ActionsDAG.

This commit is contained in:
Nikolai Kochetov 2020-11-17 17:51:05 +03:00
parent 71d726ea21
commit 54f0338e22
2 changed files with 142 additions and 19 deletions

View File

@ -1,7 +1,10 @@
#include <Interpreters/ActionsDAG.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeString.h>
#include <Functions/IFunction.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/FunctionsConversion.h>
#include <Interpreters/Context.h>
#include <Interpreters/ExpressionJIT.h>
#include <IO/WriteBufferFromString.h>
@ -18,6 +21,8 @@ namespace ErrorCodes
extern const int DUPLICATE_COLUMN;
extern const int UNKNOWN_IDENTIFIER;
extern const int TYPE_MISMATCH;
extern const int NUMBER_OF_COLUMNS_DOESNT_MATCH;
extern const int THERE_IS_NO_COLUMN;
}
@ -83,7 +88,7 @@ const ActionsDAG::Node & ActionsDAG::addInput(ColumnWithTypeAndName column, bool
return addNode(std::move(node), can_replace);
}
const ActionsDAG::Node & ActionsDAG::addColumn(ColumnWithTypeAndName column)
const ActionsDAG::Node & ActionsDAG::addColumn(ColumnWithTypeAndName column, bool can_replace)
{
if (!column.column)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot add column {} because it is nullptr", column.name);
@ -94,7 +99,7 @@ const ActionsDAG::Node & ActionsDAG::addColumn(ColumnWithTypeAndName column)
node.result_name = std::move(column.name);
node.column = std::move(column.column);
return addNode(std::move(node));
return addNode(std::move(node), can_replace);
}
const ActionsDAG::Node & ActionsDAG::addAlias(const std::string & name, std::string alias, bool can_replace)
@ -147,28 +152,33 @@ const ActionsDAG::Node & ActionsDAG::addFunction(
compilation_cache = context.getCompiledExpressionCache();
#endif
return addFunction(function, argument_names, std::move(result_name));
Inputs children;
children.reserve(argument_names.size());
for (const auto & name : argument_names)
children.push_back(&getNode(name));
addFunction(function, children, std::move(result_name), false);
}
const ActionsDAG::Node & ActionsDAG::addFunction(
ActionsDAG::Node & ActionsDAG::addFunction(
const FunctionOverloadResolverPtr & function,
const Names & argument_names,
std::string result_name)
Inputs children,
std::string result_name,
bool can_replace)
{
size_t num_arguments = argument_names.size();
size_t num_arguments = children.size();
Node node;
node.type = ActionType::FUNCTION;
node.function_builder = function;
node.children.reserve(num_arguments);
node.children = std::move(children);
bool all_const = true;
ColumnsWithTypeAndName arguments(num_arguments);
for (size_t i = 0; i < num_arguments; ++i)
{
auto & child = getNode(argument_names[i]);
node.children.emplace_back(&child);
auto & child = *node.children[i];
node.allow_constant_folding = node.allow_constant_folding && child.allow_constant_folding;
ColumnWithTypeAndName argument;
@ -224,18 +234,18 @@ const ActionsDAG::Node & ActionsDAG::addFunction(
if (result_name.empty())
{
result_name = function->getName() + "(";
for (size_t i = 0; i < argument_names.size(); ++i)
for (size_t i = 0; i < num_arguments; ++i)
{
if (i)
result_name += ", ";
result_name += argument_names[i];
result_name += node.children[i]->result_name;
}
result_name += ")";
}
node.result_name = std::move(result_name);
return addNode(std::move(node));
return addNode(std::move(node), can_replace);
}
@ -546,6 +556,102 @@ bool ActionsDAG::empty() const
return true;
}
/// Build a DAG which converts columns with the structure of `source` into columns
/// with the structure of `result`.
///
/// For every result column a source column is chosen according to `mode`:
///  - Position: the source column with the same index (column counts must be equal);
///  - Name: the next not yet used source column with the same name.
/// The chosen column is then wrapped into CAST if its type differs from the result type
/// and aliased if its name differs from the result name.
///
/// If a result column is constant, the matched source column must be constant too.
/// When `ignore_constant_values` is set the constant from `result` is used as is,
/// otherwise the two constant values must be equal.
///
/// Throws NUMBER_OF_COLUMNS_DOESNT_MATCH, THERE_IS_NO_COLUMN or ILLEGAL_COLUMN
/// when the structures cannot be matched.
ActionsDAGPtr ActionsDAG::makeConvertingActions(
    const ColumnsWithTypeAndName & source,
    const ColumnsWithTypeAndName & result,
    MatchColumnsMode mode,
    bool ignore_constant_values)
{
    size_t num_input_columns = source.size();
    size_t num_result_columns = result.size();

    if (mode == MatchColumnsMode::Position && num_input_columns != num_result_columns)
        throw Exception("Number of columns doesn't match", ErrorCodes::NUMBER_OF_COLUMNS_DOESNT_MATCH);

    auto actions_dag = std::make_shared<ActionsDAG>(source);
    std::vector<Node *> projection(num_result_columns);

    FunctionOverloadResolverPtr func_builder_cast =
        std::make_shared<FunctionOverloadResolverAdaptor>(CastOverloadResolver::createImpl(false));

    /// Positions of DAG inputs grouped by name. A name may occur several times;
    /// every occurrence is handed out at most once, in order of appearance.
    /// The keys are views into the input nodes' names.
    std::map<std::string_view, std::list<size_t>> inputs;
    if (mode == MatchColumnsMode::Name)
    {
        for (size_t pos = 0; pos < actions_dag->inputs.size(); ++pos)
            inputs[actions_dag->inputs[pos]->result_name].push_back(pos);
    }

    for (size_t result_col_num = 0; result_col_num < num_result_columns; ++result_col_num)
    {
        const auto & res_elem = result[result_col_num];
        Node * src_node = nullptr;

        switch (mode)
        {
            case MatchColumnsMode::Position:
            {
                src_node = actions_dag->inputs[result_col_num];
                break;
            }

            case MatchColumnsMode::Name:
            {
                /// Use find() instead of operator[] so that unknown names do not add empty entries.
                auto it = inputs.find(res_elem.name);
                if (it == inputs.end() || it->second.empty())
                    throw Exception("Cannot find column " + backQuoteIfNeed(res_elem.name) + " in source stream",
                                    ErrorCodes::THERE_IS_NO_COLUMN);

                src_node = actions_dag->inputs[it->second.front()];
                it->second.pop_front();
                break;
            }
        }

        /// Check constants.
        if (const auto * res_const = typeid_cast<const ColumnConst *>(res_elem.column.get()))
        {
            if (const auto * src_const = typeid_cast<const ColumnConst *>(src_node->column.get()))
            {
                if (ignore_constant_values)
                    src_node = const_cast<Node *>(&actions_dag->addColumn(res_elem, true));
                else if (res_const->getField() != src_const->getField())
                    throw Exception("Cannot convert column " + backQuoteIfNeed(res_elem.name) + " because "
                                    "it is constant but values of constants are different in source and result",
                                    ErrorCodes::ILLEGAL_COLUMN);
            }
            else
                throw Exception("Cannot convert column " + backQuoteIfNeed(res_elem.name) + " because "
                                "it is non constant in source stream but must be constant in result",
                                ErrorCodes::ILLEGAL_COLUMN);
        }

        /// Add CAST function to convert into result type if needed.
        /// The result type must be compared with the type of the chosen source node:
        /// comparing it with itself is always true and would never insert the conversion.
        if (!res_elem.type->equals(*src_node->result_type))
        {
            /// Second argument of CAST: the name of the target type as a constant string column.
            ColumnWithTypeAndName column;
            column.name = res_elem.type->getName();
            column.column = DataTypeString().createColumnConst(0, column.name);
            column.type = std::make_shared<DataTypeString>();

            auto * right_arg = const_cast<Node *>(&actions_dag->addColumn(std::move(column), true));
            auto * left_arg = src_node;

            Inputs children = { left_arg, right_arg };
            src_node = &actions_dag->addFunction(func_builder_cast, children, "", true);
        }

        /// Give the column the name expected in the result.
        if (src_node->result_name != res_elem.name)
            src_node = const_cast<Node *>(&actions_dag->addAlias(src_node->result_name, res_elem.name, true));

        projection[result_col_num] = src_node;
    }

    /// Presumably leaves only the nodes required for `projection` and makes the DAG
    /// output exactly the projected columns - see removeUnusedActions / projectInput.
    actions_dag->removeUnusedActions(projection);
    actions_dag->projectInput();

    return actions_dag;
}
ActionsDAGPtr ActionsDAG::splitActionsBeforeArrayJoin(const NameSet & array_joined_columns)
{
/// Split DAG into two parts.
@ -702,7 +808,7 @@ ActionsDAGPtr ActionsDAG::splitActionsBeforeArrayJoin(const NameSet & array_join
const auto & cur = data[input];
if (cur.to_this)
this_inputs.push_back(cur.to_this);
else
else if (cur.to_split)
split_inputs.push_back(cur.to_split);
}

View File

@ -195,13 +195,9 @@ public:
const Node & addInput(std::string name, DataTypePtr type, bool can_replace = false);
const Node & addInput(ColumnWithTypeAndName column, bool can_replace = false);
const Node & addColumn(ColumnWithTypeAndName column);
const Node & addColumn(ColumnWithTypeAndName column, bool can_replace = false);
const Node & addAlias(const std::string & name, std::string alias, bool can_replace = false);
const Node & addArrayJoin(const std::string & source_name, std::string result_name);
const Node & addFunction(
const FunctionOverloadResolverPtr & function,
const Names & argument_names,
std::string result_name);
const Node & addFunction(
const FunctionOverloadResolverPtr & function,
const Names & argument_names,
@ -234,10 +230,31 @@ public:
ActionsDAGPtr clone() const;
/// How columns of the result are matched with columns of the source in makeConvertingActions.
enum class MatchColumnsMode
{
/// Require the same number of columns in source and result. Match columns by their positions, regardless of names.
Position,
/// Find columns in source by their names. Extra columns in source are allowed.
Name,
};
/// Create a DAG which takes columns with the structure of `source` and returns one column
/// for every column of `result`, named as in `result`.
/// Source columns are matched with result columns according to `mode`.
/// If a result column is constant, the matched source column must be constant as well.
static ActionsDAGPtr makeConvertingActions(
const ColumnsWithTypeAndName & source,
const ColumnsWithTypeAndName & result,
MatchColumnsMode mode,
bool ignore_constant_values = false); /// Do not check that constants are the same. Use the value from `result`.
private:
Node & addNode(Node node, bool can_replace = false);
Node & getNode(const std::string & name);
ActionsDAG::Node & addFunction(
const FunctionOverloadResolverPtr & function,
Inputs children,
std::string result_name,
bool can_replace);
ActionsDAGPtr cloneEmpty() const
{
auto actions = std::make_shared<ActionsDAG>();