Refactored CompileDAG build algorithm

This commit is contained in:
Maksim Kita 2021-05-05 18:39:26 +03:00
parent 8828599380
commit 16a07f61ae
3 changed files with 202 additions and 158 deletions

View File

@ -144,30 +144,6 @@ static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr
return nativeCast(b, from, value, n_to);
}
static inline llvm::Constant * getColumnNativeConstant(llvm::Type * native_type, WhichDataType column_data_type, const IColumn & column)
{
llvm::Constant * result = nullptr;
if (column_data_type.isFloat32())
{
result = llvm::ConstantFP::get(native_type, column.getFloat32(0));
}
else if (column_data_type.isFloat64())
{
result = llvm::ConstantFP::get(native_type, column.getFloat64(0));
}
else if (column_data_type.isNativeInt())
{
result = llvm::ConstantInt::get(native_type, column.getInt(0));
}
else if (column_data_type.isNativeUInt())
{
result = llvm::ConstantInt::get(native_type, column.getUInt(0));
}
return result;
}
static inline llvm::Constant * getColumnNativeValue(llvm::IRBuilderBase & builder, const DataTypePtr & column_type, const IColumn & column, size_t index)
{
WhichDataType column_data_type(column_type);
@ -210,8 +186,6 @@ static inline llvm::Constant * getColumnNativeValue(llvm::IRBuilderBase & builde
return nullptr;
}
}
#endif

View File

@ -5,6 +5,7 @@
#include <optional>
#include <stack>
#include <common/logger_useful.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnVector.h>
@ -26,6 +27,7 @@ namespace DB
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int NOT_IMPLEMENTED;
}
static CHJIT & getJITInstance()
@ -34,18 +36,21 @@ static CHJIT & getJITInstance()
return jit;
}
static Poco::Logger * getLogger()
{
static Poco::Logger & logger = Poco::Logger::get("ExpressionJIT");
return &logger;
}
class LLVMExecutableFunction : public IExecutableFunctionImpl
{
std::string name;
void * function = nullptr;
public:
explicit LLVMExecutableFunction(const std::string & name_)
explicit LLVMExecutableFunction(const std::string & name_, void * function_)
: name(name_)
, function(function_)
{
function = getJITInstance().findCompiledFunction(name);
if (!function)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot find compiled function {}", name);
}
String getName() const override { return name; }
@ -150,7 +155,12 @@ public:
ExecutableFunctionImplPtr prepare(const ColumnsWithTypeAndName &) const override
{
return std::make_unique<LLVMExecutableFunction>(name);
void * function = getJITInstance().findCompiledFunction(name);
if (!function)
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot find compiled function {}", name);
return std::make_unique<LLVMExecutableFunction>(name, function);
}
bool isDeterministic() const override
@ -266,15 +276,18 @@ static FunctionBasePtr compile(
size_t compiled_size = llvm_function->getCompiledSize();
FunctionBasePtr llvm_function_wrapper = std::make_shared<FunctionBaseAdaptor>(std::move(llvm_function));
CompiledFunction compiled_function
CompiledFunction function
{
.function = llvm_function_wrapper,
.compiled_size = compiled_size
};
return std::make_shared<CompiledFunction>(compiled_function);
return std::make_shared<CompiledFunction>(function);
});
if (was_inserted)
LOG_INFO(getLogger(), "Compiled expression {}", compiled_function->function->getName());
fn = compiled_function->function;
}
else
@ -282,11 +295,23 @@ static FunctionBasePtr compile(
fn = std::make_shared<FunctionBaseAdaptor>(std::make_unique<LLVMFunction>(dag));
}
LOG_INFO(getLogger(), "Use compiled expression {}", fn->getName());
return fn;
}
static bool isCompilable(const IFunctionBase & function)
static bool isCompilableConstant(const ActionsDAG::Node & node)
{
return node.column && isColumnConst(*node.column) && canBeNativeType(*node.result_type) && node.allow_constant_folding;
}
static bool isCompilableFunction(const ActionsDAG::Node & node)
{
if (node.type != ActionsDAG::ActionType::FUNCTION)
return false;
auto & function = *node.function_base;
if (!canBeNativeType(*function.getResultType()))
return false;
@ -299,24 +324,21 @@ static bool isCompilable(const IFunctionBase & function)
return function.isCompilable();
}
static bool isCompilableConstant(const ActionsDAG::Node & node)
static bool isCompilableInput(const ActionsDAG::Node & node)
{
return node.column && isColumnConst(*node.column) && canBeNativeType(*node.result_type) && node.allow_constant_folding;
}
static bool isCompilableFunction(const ActionsDAG::Node & node)
{
return node.type == ActionsDAG::ActionType::FUNCTION && isCompilable(*node.function_base);
return node.type == ActionsDAG::ActionType::INPUT || node.type == ActionsDAG::ActionType::ALIAS;
}
static CompileDAG getCompilableDAG(
const ActionsDAG::Node * root,
ActionsDAG::NodeRawConstPtrs & children,
const std::unordered_set<const ActionsDAG::Node *> & used_in_result)
ActionsDAG::NodeRawConstPtrs & children)
{
/// Extract CompileDag from root actions dag node, it is important that each root child is compilable.
CompileDAG dag;
std::unordered_map<const ActionsDAG::Node *, size_t> positions;
std::unordered_map<const ActionsDAG::Node *, size_t> visited_node_to_compile_dag_position;
struct Frame
{
const ActionsDAG::Node * node;
@ -324,56 +346,57 @@ static CompileDAG getCompilableDAG(
};
std::stack<Frame> stack;
stack.push(Frame{.node = root});
stack.emplace(Frame{.node = root});
while (!stack.empty())
{
auto & frame = stack.top();
bool is_const = isCompilableConstant(*frame.node);
bool can_inline = stack.size() == 1 || !used_in_result.count(frame.node);
bool is_compilable_function = !is_const && can_inline && isCompilableFunction(*frame.node);
const auto * node = frame.node;
while (is_compilable_function && frame.next_child_to_visit < frame.node->children.size())
while (frame.next_child_to_visit < node->children.size())
{
const auto * child = frame.node->children[frame.next_child_to_visit];
const auto & child = node->children[frame.next_child_to_visit];
if (positions.count(child))
if (visited_node_to_compile_dag_position.contains(child))
{
++frame.next_child_to_visit;
else
{
stack.emplace(Frame{.node = child});
break;
continue;
}
stack.emplace(Frame{.node = child});
break;
}
if (!is_compilable_function || frame.next_child_to_visit == frame.node->children.size())
bool should_visit_children_first = frame.next_child_to_visit < node->children.size();
if (should_visit_children_first)
continue;
CompileDAG::Node compile_node;
compile_node.function = node->function_base;
compile_node.result_type = node->result_type;
if (node->type == ActionsDAG::ActionType::FUNCTION)
{
CompileDAG::Node node;
node.function = frame.node->function_base;
node.result_type = frame.node->result_type;
if (is_compilable_function)
{
node.type = CompileDAG::CompileType::FUNCTION;
for (const auto * child : frame.node->children)
node.arguments.push_back(positions[child]);
}
else if (is_const)
{
node.type = CompileDAG::CompileType::CONSTANT;
node.column = frame.node->column;
}
else
{
node.type = CompileDAG::CompileType::INPUT;
children.emplace_back(frame.node);
}
positions[frame.node] = dag.getNodesCount();
dag.addNode(std::move(node));
stack.pop();
compile_node.type = CompileDAG::CompileType::FUNCTION;
for (const auto * child : node->children)
compile_node.arguments.push_back(visited_node_to_compile_dag_position[child]);
}
else if (isCompilableConstant(*node))
{
compile_node.type = CompileDAG::CompileType::CONSTANT;
compile_node.column = node->column;
}
else
{
compile_node.type = CompileDAG::CompileType::INPUT;
children.emplace_back(node);
}
visited_node_to_compile_dag_position[node] = dag.getNodesCount();
dag.addNode(std::move(compile_node));
stack.pop();
}
return dag;
@ -383,24 +406,21 @@ void ActionsDAG::compileFunctions(size_t min_count_to_compile_expression)
{
struct Data
{
bool is_compilable = false;
bool is_compilable_in_isolation = false;
bool all_children_compilable = false;
bool all_parents_compilable = true;
size_t num_inlineable_nodes = 0;
size_t children_size = 0;
};
std::unordered_map<const Node *, Data> data;
std::unordered_set<const Node *> used_in_result;
std::unordered_map<const Node *, Data> node_to_data;
/// Check which nodes can be compiled in isolation
for (const auto & node : nodes)
data[&node].is_compilable = isCompilableConstant(node) || isCompilableFunction(node);
for (const auto & node : nodes)
if (!data[&node].is_compilable)
for (const auto * child : node.children)
data[child].all_parents_compilable = false;
for (const auto * node : index)
used_in_result.insert(node);
{
bool node_is_compilable_in_isolation = isCompilableConstant(node) || isCompilableFunction(node) || isCompilableInput(node);
node_to_data[&node].is_compilable_in_isolation = node_is_compilable_in_isolation;
}
struct Frame
{
@ -409,84 +429,125 @@ void ActionsDAG::compileFunctions(size_t min_count_to_compile_expression)
};
std::stack<Frame> stack;
std::unordered_set<const Node *> visited;
std::unordered_set<const Node *> visited_nodes;
/** Algorithm is go throught each node in ActionsDAG, and for each node iterate thought all children.
* and update data with their compilable status.
* After this procedure in data for each node is initialized.
*/
for (auto & node : nodes)
{
if (visited.count(&node))
if (visited_nodes.contains(&node))
continue;
stack.emplace(Frame{.node = &node});
while (!stack.empty())
{
auto & frame = stack.top();
auto & current_frame = stack.top();
auto & current_node = current_frame.node;
while (frame.next_child_to_visit < frame.node->children.size())
while (current_frame.next_child_to_visit < current_node->children.size())
{
const auto * child = frame.node->children[frame.next_child_to_visit];
const auto & child = node.children[current_frame.next_child_to_visit];
if (visited.count(child))
++frame.next_child_to_visit;
else
if (visited_nodes.contains(child))
{
stack.emplace(Frame{.node = child});
break;
}
}
if (frame.next_child_to_visit == frame.node->children.size())
{
auto & cur = data[frame.node];
if (cur.is_compilable)
{
cur.num_inlineable_nodes = 1;
if (!isCompilableConstant(*frame.node))
for (const auto * child : frame.node->children)
if (!used_in_result.count(child))
cur.num_inlineable_nodes += data[child].num_inlineable_nodes;
/// Check if we should inline current node.
bool should_compile = true;
/// Inline parents instead of node is possible.
if (!used_in_result.count(frame.node) && cur.all_parents_compilable)
should_compile = false;
/// There is no reason to inline single node.
/// The result of compiling function in isolation is pretty much the same as its `execute` method.
if (cur.num_inlineable_nodes <= 1)
should_compile = false;
if (should_compile)
{
NodeRawConstPtrs new_children;
auto dag = getCompilableDAG(frame.node, new_children, used_in_result);
if (dag.getInputNodesCount() > 0)
{
if (auto fn = compile(dag, min_count_to_compile_expression))
{
ColumnsWithTypeAndName arguments;
arguments.reserve(new_children.size());
for (const auto * child : new_children)
arguments.emplace_back(child->column, child->result_type, child->result_name);
auto * frame_node = const_cast<Node *>(frame.node);
frame_node->type = ActionsDAG::ActionType::FUNCTION;
frame_node->function_base = fn;
frame_node->function = fn->prepare(arguments);
frame_node->children.swap(new_children);
frame_node->is_function_compiled = true;
frame_node->column = nullptr;
}
}
}
++current_frame.next_child_to_visit;
continue;
}
visited.insert(frame.node);
stack.pop();
stack.emplace(Frame{.node=child});
break;
}
bool should_visit_children_first = current_frame.next_child_to_visit < current_node->children.size();
if (should_visit_children_first)
continue;
auto & current_node_data = node_to_data[current_node];
current_node_data.all_children_compilable = true;
if (current_node_data.is_compilable_in_isolation)
{
for (const auto * child : current_node->children)
{
current_node_data.all_children_compilable &= node_to_data[child].is_compilable_in_isolation;
current_node_data.all_children_compilable &= node_to_data[child].all_children_compilable;
current_node_data.children_size += node_to_data[child].children_size;
}
current_node_data.children_size += current_node->children.size();
}
visited_nodes.insert(current_node);
stack.pop();
}
}
for (const auto & node : nodes)
{
auto & node_data = node_to_data[&node];
bool is_compilable = node_data.is_compilable_in_isolation && node_data.all_children_compilable;
for (const auto & child : node.children)
node_to_data[child].all_parents_compilable &= is_compilable;
}
for (const auto & node : index)
{
/// Force result nodes to compile
node_to_data[node].all_parents_compilable = false;
}
std::vector<Node *> nodes_to_compile;
for (auto & node : nodes)
{
auto & node_data = node_to_data[&node];
bool can_be_compiled = node_data.is_compilable_in_isolation && node_data.all_children_compilable && node.children.size() > 1;
/// If all parents are compilable then this node should not be standalone compiled
bool should_compile = can_be_compiled && !node_data.all_parents_compilable;
if (!should_compile)
continue;
nodes_to_compile.emplace_back(&node);
}
/** Sort nodes before compilation using their children size to avoid compiling subexpression before compile parent expression.
* For example we have expression SELECT a + 1 FROM test_table WHERE a + 1 > 0
* We can compile a + 1 and a + 1 > 0, but we should compile a + 1 after a + 1 > 0, because during compilation
* we change actions dag node children size.
*/
std::sort(nodes_to_compile.begin(), nodes_to_compile.end(), [&](const Node * lhs, const Node * rhs) { return node_to_data[lhs].children_size > node_to_data[rhs].children_size; });
for (auto & node : nodes_to_compile)
{
NodeRawConstPtrs new_children;
auto dag = getCompilableDAG(node, new_children);
if (dag.getInputNodesCount() == 0)
continue;
if (auto fn = compile(dag, min_count_to_compile_expression))
{
ColumnsWithTypeAndName arguments;
arguments.reserve(new_children.size());
for (const auto * child : new_children)
arguments.emplace_back(child->column, child->result_type, child->result_name);
node->type = ActionsDAG::ActionType::FUNCTION;
node->function_base = fn;
node->function = fn->prepare(arguments);
node->children.swap(new_children);
node->is_function_compiled = true;
node->column = nullptr;
}
}
}

View File

@ -1,5 +1,7 @@
#include "compileFunction.h"
#if USE_EMBEDDED_COMPILER
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
@ -30,6 +32,11 @@ namespace ProfileEvents
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
ColumnData getColumnData(const IColumn * column)
{
ColumnData result;
@ -158,3 +165,5 @@ CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBaseImpl &
}
}
#endif