Added CHJIT documentation

This commit is contained in:
Maksim Kita 2021-05-07 21:36:07 +03:00
parent 57d5f55d7f
commit 21d8684aaf
7 changed files with 144 additions and 36 deletions

View File

@ -681,9 +681,13 @@ std::string ActionsDAG::dumpDAG() const
out << " " << (node.column ? node.column->getName() : "(no column)");
out << " " << (node.result_type ? node.result_type->getName() : "(no type)");
out << " " << (!node.result_name.empty() ? node.result_name : "(no name)");
if (node.function_base)
out << " [" << node.function_base->getName() << "]";
if (node.is_function_compiled)
out << " [compiled]";
out << "\n";
}

View File

@ -45,9 +45,10 @@ static Poco::Logger * getLogger()
class LLVMExecutableFunction : public IExecutableFunctionImpl
{
std::string name;
void * function = nullptr;
JITCompiledFunction function = nullptr;
public:
explicit LLVMExecutableFunction(const std::string & name_, void * function_)
explicit LLVMExecutableFunction(const std::string & name_, JITCompiledFunction function_)
: name(name_)
, function(function_)
{
@ -81,8 +82,7 @@ public:
}
columns[arguments.size()] = getColumnData(result_column.get());
auto * function_typed = reinterpret_cast<void (*) (size_t, ColumnData *)>(function);
function_typed(input_rows_count, columns.data());
function(input_rows_count, columns.data());
#if defined(MEMORY_SANITIZER)
/// Memory sanitizer doesn't know about stores from JIT-ed code.
@ -160,7 +160,8 @@ public:
if (!function)
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot find compiled function {}", name);
return std::make_unique<LLVMExecutableFunction>(name, function);
JITCompiledFunction function_typed = reinterpret_cast<JITCompiledFunction>(function);
return std::make_unique<LLVMExecutableFunction>(name, function_typed);
}
bool isDeterministic() const override
@ -511,6 +512,8 @@ void ActionsDAG::compileFunctions(size_t min_count_to_compile_expression)
node_to_data[node].all_parents_compilable = false;
}
std::vector<Node *> nodes_to_compile;
for (auto & node : nodes)
{
auto & node_data = node_to_data[&node];
@ -523,8 +526,18 @@ void ActionsDAG::compileFunctions(size_t min_count_to_compile_expression)
if (!should_compile)
continue;
nodes_to_compile.emplace_back(&node);
}
/** Sort nodes before compilation by their children size in descending order, to avoid compiling a subexpression before its parent expression.
* This is needed to avoid compiling the same expression more than once with different names because of compilation order.
*/
std::sort(nodes_to_compile.begin(), nodes_to_compile.end(), [&](const Node * lhs, const Node * rhs) { return node_to_data[lhs].children_size > node_to_data[rhs].children_size; });
for (auto & node : nodes_to_compile)
{
NodeRawConstPtrs new_children;
auto dag = getCompilableDAG(&node, new_children);
auto dag = getCompilableDAG(node, new_children);
if (dag.getInputNodesCount() == 0)
continue;
@ -536,12 +549,12 @@ void ActionsDAG::compileFunctions(size_t min_count_to_compile_expression)
for (const auto * child : new_children)
arguments.emplace_back(child->column, child->result_type, child->result_name);
node.type = ActionsDAG::ActionType::FUNCTION;
node.function_base = fn;
node.function = fn->prepare(arguments);
node.children.swap(new_children);
node.is_function_compiled = true;
node.column = nullptr;
node->type = ActionsDAG::ActionType::FUNCTION;
node->function_base = fn;
node->function = fn->prepare(arguments);
node->children.swap(new_children);
node->is_function_compiled = true;
node->column = nullptr;
}
}
}

View File

@ -33,6 +33,9 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}
/** Simple module to object file compiler.
* Result object cannot be used as machine code directly, it should be passed to linker.
*/
class JITCompiler
{
public:
@ -77,6 +80,10 @@ private:
llvm::TargetMachine & target_machine;
};
/** MemoryManager for module.
* Keep total allocated size during RuntimeDyld linker execution.
* Release compiled module memory in destructor.
*/
class JITModuleMemoryManager
{
class DefaultMMapper final : public llvm::SectionMemoryManager::MemoryMapper
@ -141,6 +148,7 @@ private:
std::unordered_map<std::string, void *> symbol_name_to_symbol_address;
};
/// GDB JITEventListener. Can be used if the resulting machine code needs to be debugged.
// class JITEventListener
// {
// public:
@ -171,6 +179,8 @@ CHJIT::CHJIT()
, compiler(std::make_unique<JITCompiler>(*machine))
, symbol_resolver(std::make_unique<JITSymbolResolver>())
{
/// Define common symbols that can be generated during compilation
/// Necessary for valid linker symbol resolution
symbol_resolver->registerSymbol("memset", reinterpret_cast<void *>(&memset));
symbol_resolver->registerSymbol("memcpy", reinterpret_cast<void *>(&memcpy));
symbol_resolver->registerSymbol("memcmp", reinterpret_cast<void *>(&memcmp));
@ -245,10 +255,8 @@ CHJIT::CompiledModuleInfo CHJIT::compileModule(std::unique_ptr<llvm::Module> mod
module_info.compiled_functions.emplace_back(std::move(function_name));
}
auto module_identifier = module->getModuleIdentifier();
module_info.size = module_memory_manager->getAllocatedSize();
module_info.module_identifier = current_module_key;
module_info.identifier = current_module_key;
module_identifier_to_memory_manager[current_module_key] = std::move(module_memory_manager);
@ -261,9 +269,9 @@ void CHJIT::deleteCompiledModule(const CHJIT::CompiledModuleInfo & module_info)
{
std::lock_guard<std::mutex> lock(jit_lock);
auto module_it = module_identifier_to_memory_manager.find(module_info.module_identifier);
auto module_it = module_identifier_to_memory_manager.find(module_info.identifier);
if (module_it == module_identifier_to_memory_manager.end())
throw Exception(ErrorCodes::LOGICAL_ERROR, "There is no compiled module with identifier {}", module_info.module_identifier);
throw Exception(ErrorCodes::LOGICAL_ERROR, "There is no compiled module with identifier {}", module_info.identifier);
for (const auto & function : module_info.compiled_functions)
name_to_symbol.erase(function);
@ -276,7 +284,7 @@ void * CHJIT::findCompiledFunction(const CompiledModuleInfo & module_info, const
{
std::lock_guard<std::mutex> lock(jit_lock);
std::string symbol_name = std::to_string(module_info.module_identifier) + '_' + function_name;
std::string symbol_name = std::to_string(module_info.identifier) + '_' + function_name;
auto it = name_to_symbol.find(symbol_name);
if (it != name_to_symbol.end())
return it->second;
@ -366,7 +374,7 @@ std::unique_ptr<llvm::TargetMachine> CHJIT::getTargetMachine()
options,
llvm::None,
llvm::None,
llvm::CodeGenOpt::Default,
llvm::CodeGenOpt::Aggressive,
jit);
if (!target_machine)

View File

@ -20,7 +20,29 @@ class JITModuleMemoryManager;
class JITSymbolResolver;
class JITCompiler;
/// TODO: Add documentation
/** Custom jit implementation
* Main use cases:
* 1. Compile functions in module.
* 2. Release memory for compiled functions.
*
* In LLVM library there are 2 main JIT stacks MCJIT and ORCv2.
*
* Main reasons for custom implementation vs MCJIT
* MCJIT keeps both the llvm::Module and the compiled (not yet linked) object code after the module was compiled.
* llvm::Module can be removed, but compiled object code cannot be removed. Memory for compiled code
* will be released only during MCJIT instance destruction. It is too expensive to create an MCJIT
* instance for each compiled module.
*
* Main reasons for custom implementation vs ORCv2
* ORC compiles on request; we do not need support for asynchronous compilation.
* It was possible to remove compiled code with ORCv1, but it was deprecated.
* In ORCv2 this probably can be done only with a custom layer and materialization unit.
* But it is inconvenient: discard is only called by JITDylib for materialization units that are not yet materialized.
*
* CHJIT interface is thread safe: all functions can be called from multiple threads and the state of the CHJIT instance
* will not be broken.
* It is the client's responsibility to make sure that compiled code is not used after it was released.
*/
class CHJIT
{
public:
@ -30,19 +52,39 @@ public:
struct CompiledModuleInfo
{
/// Size of compiled module code in bytes
size_t size;
uint64_t module_identifier;
/// Module identifier. Should not be changed by client
uint64_t identifier;
/// Vector of compiled function names. Should not be changed by client.
std::vector<std::string> compiled_functions;
};
/** Compile module. Inside compile_function it is the client's responsibility to fill the module with the necessary
* IR code; the module will then be compiled by the CHJIT instance.
* Returns compiled module info.
*/
CompiledModuleInfo compileModule(std::function<void (llvm::Module &)> compile_function);
/** Delete compiled module. Pointers to functions from module become invalid after this call.
* It is client responsibility to be sure that there are no pointers to compiled module code.
*/
void deleteCompiledModule(const CompiledModuleInfo & module_info);
/** Find compiled function using module_info, and function_name.
* It is the client's responsibility to cast the result function to the right signature.
* After call to deleteCompiledModule compiled functions from module become invalid.
*/
void * findCompiledFunction(const CompiledModuleInfo & module_info, const std::string & function_name) const;
/** Register external symbol for CHJIT instance to use, during linking.
* It can be function, or global constant.
* It is client responsibility to be sure that address of symbol is valid during CHJIT instance lifetime.
*/
void registerExternalSymbol(const std::string & symbol_name, void * address);
/** Total compiled code size for modules that are currently valid.
*/
inline size_t getCompiledCodeSize() const { return compiled_code_size.load(std::memory_order_relaxed); }
private:

View File

@ -24,9 +24,14 @@ namespace llvm
namespace DB
{
/// DAG is represented as list of nodes stored in in-order traverse order.
/// Expression (a + 1) + (b + 1) will be represented like chain: a, 1, a + 1, b, b + 1, (a + 1) + (b + 1).
/// TODO: Consider to rename in CompileStack
/** This class is needed to compile part of ActionsDAG.
* For example we have expression (a + 1) + (b + 1) in actions dag.
* It must be added into CompileDAG in order of compile evaluation.
* Input a, Constant 1, Function add(a, 1), Input b, Constant 1, Function add(b, 1), Function add(add(a, 1), add(b, 1)).
*
* Compile function must be called with input_nodes_values equal to input nodes count.
* During compile function call CompileDAG is compiled in order of added nodes.
*/
class CompileDAG
{
public:
@ -69,7 +74,6 @@ public:
inline Node & operator[](size_t index) { return nodes[index]; }
inline const Node & operator[](size_t index) const { return nodes[index]; }
inline Node & front() { return nodes.front(); }
inline const Node & front() const { return nodes.front(); }

View File

@ -47,7 +47,7 @@ ColumnData getColumnData(const IColumn * column)
if (const auto * nullable = typeid_cast<const ColumnNullable *>(column))
{
result.null = nullable->getNullMapColumn().getRawData().data;
result.null_data = nullable->getNullMapColumn().getRawData().data;
column = & nullable->getNestedColumn();
}
@ -56,39 +56,51 @@ ColumnData getColumnData(const IColumn * column)
return result;
}
static void compileFunction(llvm::Module & module, const IFunctionBaseImpl & f)
static void compileFunction(llvm::Module & module, const IFunctionBaseImpl & function)
{
/** Algorithm is to create a loop that iterates over ColumnDataRowsSize size_t argument and
* over ColumnData data and null_data. On each step compiled expression from function
* will be executed over column data and null_data row.
*/
ProfileEvents::increment(ProfileEvents::CompileFunction);
const auto & arg_types = f.getArgumentTypes();
const auto & arg_types = function.getArgumentTypes();
llvm::IRBuilder<> b(module.getContext());
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
auto * data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
auto * func_type = llvm::FunctionType::get(b.getVoidTy(), { size_type, data_type->getPointerTo() }, /*isVarArg=*/false);
auto * func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, f.getName(), module);
/// Create function in module
auto * func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, function.getName(), module);
auto * args = func->args().begin();
llvm::Value * counter_arg = &*args++;
llvm::Value * columns_arg = &*args++;
/// Initialize ColumnDataPlaceholder llvm representation of ColumnData
/// The last ColumnDataPlaceholder in columns is the result column
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", func);
b.SetInsertPoint(entry);
std::vector<ColumnDataPlaceholder> columns(arg_types.size() + 1);
for (size_t i = 0; i <= arg_types.size(); ++i)
{
const auto & type = i == arg_types.size() ? f.getResultType() : arg_types[i];
const auto & type = i == arg_types.size() ? function.getResultType() : arg_types[i];
auto * data = b.CreateLoad(b.CreateConstInBoundsGEP1_32(data_type, columns_arg, i));
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(type))->getPointerTo());
columns[i].null_init = type->isNullable() ? b.CreateExtractValue(data, {1}) : nullptr;
}
/// assume nonzero initial value in `counter_arg`
/// Initialize loop
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", func);
b.CreateBr(loop);
b.SetInsertPoint(loop);
auto * counter_phi = b.CreatePHI(counter_arg->getType(), 2);
counter_phi->addIncoming(counter_arg, entry);
for (auto & col : columns)
{
col.data = b.CreatePHI(col.data_init->getType(), 2);
@ -100,6 +112,8 @@ static void compileFunction(llvm::Module & module, const IFunctionBaseImpl & f)
}
}
/// Initialize column row values
Values arguments;
arguments.reserve(arg_types.size());
@ -121,7 +135,9 @@ static void compileFunction(llvm::Module & module, const IFunctionBaseImpl & f)
arguments.emplace_back(nullable_value);
}
auto * result = f.compile(b, std::move(arguments));
/// Compile values for column rows and store compiled value in result column
auto * result = function.compile(b, std::move(arguments));
if (columns.back().null)
{
b.CreateStore(b.CreateExtractValue(result, {0}), columns.back().data);
@ -132,6 +148,8 @@ static void compileFunction(llvm::Module & module, const IFunctionBaseImpl & f)
b.CreateStore(result, columns.back().data);
}
/// End of loop
auto * cur_block = b.GetInsertBlock();
for (auto & col : columns)
{
@ -148,13 +166,13 @@ static void compileFunction(llvm::Module & module, const IFunctionBaseImpl & f)
b.CreateRetVoid();
}
CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBaseImpl & f)
CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBaseImpl & function)
{
Stopwatch watch;
auto compiled_module_info = jit.compileModule([&](llvm::Module & module)
{
compileFunction(module, f);
compileFunction(module, function);
});
ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());

View File

@ -12,15 +12,34 @@
namespace DB
{
/** ColumnData structure to pass into compiled function.
* data is raw column data.
* null_data is null map column raw data.
*/
struct ColumnData
{
const char * data = nullptr;
const char * null = nullptr;
const char * null_data = nullptr;
};
/** Returns ColumnData for column.
* If constant column is passed, LOGICAL_ERROR will be thrown.
*/
ColumnData getColumnData(const IColumn * column);
CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBaseImpl & f);
using ColumnDataRowsSize = size_t;
using JITCompiledFunction = void (*)(ColumnDataRowsSize, ColumnData *);
/** Compile function to native jit code using CHJIT instance.
* Function is compiled as single module.
* After this function execution, code for function will be compiled and can be queried using
* findCompiledFunction with function name.
* Compiled function can be safely cast to JITCompiledFunction type and must be called with
* valid ColumnData and ColumnDataRowsSize.
* It is important that the last ColumnData element passed to JITCompiledFunction is the result column,
* and it will be filled by the compiled function.
*/
CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBaseImpl & function);
}