mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-30 05:30:51 +00:00
Added CHJIT
This commit is contained in:
parent
ca44ff2ede
commit
3ec4409d52
@ -184,6 +184,7 @@ add_object_library(clickhouse_disks Disks)
|
||||
add_object_library(clickhouse_interpreters Interpreters)
|
||||
add_object_library(clickhouse_interpreters_mysql Interpreters/MySQL)
|
||||
add_object_library(clickhouse_interpreters_clusterproxy Interpreters/ClusterProxy)
|
||||
add_object_library(clickhouse_interpreters_jit Interpreters/JIT)
|
||||
add_object_library(clickhouse_columns Columns)
|
||||
add_object_library(clickhouse_storages Storages)
|
||||
add_object_library(clickhouse_storages_distributed Storages/Distributed)
|
||||
|
@ -19,42 +19,11 @@
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <IO/Operators.h>
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
||||
#pragma GCC diagnostic ignored "-Wnon-virtual-dtor"
|
||||
|
||||
#include <llvm/Analysis/TargetTransformInfo.h>
|
||||
#include <llvm/IR/BasicBlock.h>
|
||||
#include <llvm/IR/DataLayout.h>
|
||||
#include <llvm/IR/DerivedTypes.h>
|
||||
#include <llvm/IR/Function.h>
|
||||
#include <llvm/IR/IRBuilder.h>
|
||||
#include <llvm/IR/LLVMContext.h>
|
||||
#include <llvm/IR/Mangler.h>
|
||||
#include <llvm/IR/Module.h>
|
||||
#include <llvm/IR/Type.h>
|
||||
#include <llvm/IR/LegacyPassManager.h>
|
||||
#include <llvm/ExecutionEngine/ExecutionEngine.h>
|
||||
#include <llvm/ExecutionEngine/JITSymbol.h>
|
||||
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
|
||||
#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
|
||||
#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
|
||||
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
|
||||
#include <llvm/ExecutionEngine/Orc/ExecutionUtils.h>
|
||||
#include <llvm/Target/TargetMachine.h>
|
||||
#include <llvm/MC/SubtargetFeature.h>
|
||||
#include <llvm/Support/DynamicLibrary.h>
|
||||
#include <llvm/Support/Host.h>
|
||||
#include <llvm/Support/TargetRegistry.h>
|
||||
#include <llvm/Support/TargetSelect.h>
|
||||
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
|
||||
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
/// 'LegacyRTDyldObjectLinkingLayer' is deprecated: ORCv1 layers (layers with the 'Legacy' prefix) are deprecated. Please use ORCv2
|
||||
/// 'LegacyIRCompileLayer' is deprecated: ORCv1 layers (layers with the 'Legacy' prefix) are deprecated. Please use the ORCv2 IRCompileLayer instead
|
||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
|
||||
#include <Interpreters/JIT/CHJIT.h>
|
||||
|
||||
namespace ProfileEvents
|
||||
{
|
||||
@ -115,156 +84,7 @@ static void applyFunction(IFunctionBase & function, Field & value)
|
||||
col->get(0, value);
|
||||
}
|
||||
|
||||
static llvm::TargetMachine * getNativeMachine()
|
||||
{
|
||||
std::string error;
|
||||
auto cpu = llvm::sys::getHostCPUName();
|
||||
auto triple = llvm::sys::getProcessTriple();
|
||||
const auto * target = llvm::TargetRegistry::lookupTarget(triple, error);
|
||||
if (!target)
|
||||
throw Exception("Could not initialize native target: " + error, ErrorCodes::CANNOT_COMPILE_CODE);
|
||||
llvm::SubtargetFeatures features;
|
||||
llvm::StringMap<bool> feature_map;
|
||||
if (llvm::sys::getHostCPUFeatures(feature_map))
|
||||
for (auto & f : feature_map)
|
||||
features.AddFeature(f.first(), f.second);
|
||||
llvm::TargetOptions options;
|
||||
return target->createTargetMachine(
|
||||
triple, cpu, features.getString(), options, llvm::None,
|
||||
llvm::None, llvm::CodeGenOpt::Aggressive, /*jit=*/true
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
struct SymbolResolver : public llvm::orc::SymbolResolver
|
||||
{
|
||||
llvm::LegacyJITSymbolResolver & impl;
|
||||
|
||||
explicit SymbolResolver(llvm::LegacyJITSymbolResolver & impl_) : impl(impl_) {}
|
||||
|
||||
llvm::orc::SymbolNameSet getResponsibilitySet(const llvm::orc::SymbolNameSet & symbols) final
|
||||
{
|
||||
return symbols;
|
||||
}
|
||||
|
||||
llvm::orc::SymbolNameSet lookup(std::shared_ptr<llvm::orc::AsynchronousSymbolQuery> query, llvm::orc::SymbolNameSet symbols) final
|
||||
{
|
||||
llvm::orc::SymbolNameSet missing;
|
||||
for (const auto & symbol : symbols)
|
||||
{
|
||||
bool has_resolved = false;
|
||||
impl.lookup({*symbol}, [&](llvm::Expected<llvm::JITSymbolResolver::LookupResult> resolved)
|
||||
{
|
||||
if (resolved && !resolved->empty())
|
||||
{
|
||||
query->notifySymbolMetRequiredState(symbol, resolved->begin()->second);
|
||||
has_resolved = true;
|
||||
}
|
||||
});
|
||||
|
||||
if (!has_resolved)
|
||||
missing.insert(symbol);
|
||||
}
|
||||
return missing;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct LLVMContext
|
||||
{
|
||||
std::shared_ptr<llvm::LLVMContext> context {std::make_shared<llvm::LLVMContext>()};
|
||||
std::unique_ptr<llvm::Module> module {std::make_unique<llvm::Module>("jit", *context)};
|
||||
std::unique_ptr<llvm::TargetMachine> machine {getNativeMachine()};
|
||||
llvm::DataLayout layout {machine->createDataLayout()};
|
||||
llvm::IRBuilder<> builder {*context};
|
||||
|
||||
llvm::orc::ExecutionSession execution_session;
|
||||
|
||||
std::shared_ptr<llvm::SectionMemoryManager> memory_manager;
|
||||
llvm::orc::LegacyRTDyldObjectLinkingLayer object_layer;
|
||||
llvm::orc::LegacyIRCompileLayer<decltype(object_layer), llvm::orc::SimpleCompiler> compile_layer;
|
||||
|
||||
std::unordered_map<std::string, void *> symbols;
|
||||
|
||||
LLVMContext()
|
||||
: memory_manager(std::make_shared<llvm::SectionMemoryManager>())
|
||||
, object_layer(execution_session, [this](llvm::orc::VModuleKey)
|
||||
{
|
||||
return llvm::orc::LegacyRTDyldObjectLinkingLayer::Resources{memory_manager, std::make_shared<SymbolResolver>(*memory_manager)};
|
||||
})
|
||||
, compile_layer(object_layer, llvm::orc::SimpleCompiler(*machine))
|
||||
{
|
||||
std::cerr << "LLVMContext::constructor" << std::endl;
|
||||
|
||||
module->setDataLayout(layout);
|
||||
module->setTargetTriple(machine->getTargetTriple().getTriple());
|
||||
}
|
||||
|
||||
/// returns used memory
|
||||
void compileAllFunctionsToNativeCode()
|
||||
{
|
||||
if (module->empty())
|
||||
return;
|
||||
llvm::PassManagerBuilder pass_manager_builder;
|
||||
llvm::legacy::PassManager mpm;
|
||||
llvm::legacy::FunctionPassManager fpm(module.get());
|
||||
pass_manager_builder.OptLevel = 3;
|
||||
pass_manager_builder.SLPVectorize = true;
|
||||
pass_manager_builder.LoopVectorize = true;
|
||||
pass_manager_builder.RerollLoops = true;
|
||||
pass_manager_builder.VerifyInput = true;
|
||||
pass_manager_builder.VerifyOutput = true;
|
||||
machine->adjustPassManager(pass_manager_builder);
|
||||
fpm.add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis()));
|
||||
mpm.add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis()));
|
||||
pass_manager_builder.populateFunctionPassManager(fpm);
|
||||
pass_manager_builder.populateModulePassManager(mpm);
|
||||
fpm.doInitialization();
|
||||
for (auto & function : *module)
|
||||
{
|
||||
// std::cerr << "Run for function " << std::string(function.getName()) << std::endl;
|
||||
fpm.run(function);
|
||||
}
|
||||
fpm.doFinalization();
|
||||
mpm.run(*module);
|
||||
|
||||
std::vector<std::string> functions;
|
||||
functions.reserve(module->size());
|
||||
for (const auto & function : *module)
|
||||
{
|
||||
|
||||
functions.emplace_back(function.getName());
|
||||
}
|
||||
|
||||
std::cerr << "Dump module after compile " << std::endl;
|
||||
|
||||
std::string value;
|
||||
llvm::raw_string_ostream stream(value);
|
||||
module->print(stream, nullptr);
|
||||
|
||||
std::cerr << value << std::endl;
|
||||
|
||||
llvm::orc::VModuleKey module_key = execution_session.allocateVModule();
|
||||
if (compile_layer.addModule(module_key, std::move(module)))
|
||||
throw Exception("Cannot add module to compile layer", ErrorCodes::CANNOT_COMPILE_CODE);
|
||||
|
||||
for (const auto & name : functions)
|
||||
{
|
||||
std::string mangled_name;
|
||||
llvm::raw_string_ostream mangled_name_stream(mangled_name);
|
||||
llvm::Mangler::getNameWithPrefix(mangled_name_stream, name, layout);
|
||||
mangled_name_stream.flush();
|
||||
auto symbol = compile_layer.findSymbol(mangled_name, false);
|
||||
if (!symbol)
|
||||
continue; /// external function (e.g. an intrinsic that calls into libc)
|
||||
auto address = symbol.getAddress();
|
||||
if (!address)
|
||||
throw Exception("Function " + name + " failed to link", ErrorCodes::CANNOT_COMPILE_CODE);
|
||||
symbols[name] = reinterpret_cast<void *>(*address);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static CHJIT jit;
|
||||
|
||||
template <typename... Ts>
|
||||
static bool castToEitherWithNullable(IColumn * column)
|
||||
@ -279,13 +99,12 @@ class LLVMExecutableFunction : public IExecutableFunctionImpl
|
||||
void * function;
|
||||
|
||||
public:
|
||||
LLVMExecutableFunction(const std::string & name_, const std::unordered_map<std::string, void *> & symbols)
|
||||
explicit LLVMExecutableFunction(const std::string & name_)
|
||||
: name(name_)
|
||||
, function(jit.findCompiledFunction(name))
|
||||
{
|
||||
auto it = symbols.find(name);
|
||||
if (symbols.end() == it)
|
||||
throw Exception("Cannot find symbol " + name + " in LLVMContext", ErrorCodes::LOGICAL_ERROR);
|
||||
function = it->second;
|
||||
if (!function)
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot find compiled function {}", name_);
|
||||
}
|
||||
|
||||
String getName() const override { return name; }
|
||||
@ -331,18 +150,18 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
static void compileFunctionToLLVMByteCode(LLVMContext & context, const IFunctionBaseImpl & f)
|
||||
static void compileFunction(llvm::Module & module, const IFunctionBaseImpl & f)
|
||||
{
|
||||
ProfileEvents::increment(ProfileEvents::CompileFunction);
|
||||
|
||||
const auto & arg_types = f.getArgumentTypes();
|
||||
auto & b = context.builder;
|
||||
|
||||
llvm::IRBuilder<> b(module.getContext());
|
||||
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
|
||||
auto * data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy(), size_type);
|
||||
auto * func_type = llvm::FunctionType::get(b.getVoidTy(), { size_type, data_type->getPointerTo() }, /*isVarArg=*/false);
|
||||
|
||||
/// TODO: External linkage
|
||||
auto * func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, f.getName(), context.module.get());
|
||||
auto * func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, f.getName(), module);
|
||||
auto * args = func->args().begin();
|
||||
llvm::Value * counter_arg = &*args++;
|
||||
llvm::Value * columns_arg = &*args++;
|
||||
@ -414,14 +233,6 @@ static void compileFunctionToLLVMByteCode(LLVMContext & context, const IFunction
|
||||
b.CreateCondBr(b.CreateICmpNE(counter_phi, llvm::ConstantInt::get(size_type, 1)), loop, end);
|
||||
b.SetInsertPoint(end);
|
||||
b.CreateRetVoid();
|
||||
|
||||
std::cerr << "Dump module" << std::endl;
|
||||
|
||||
std::string value;
|
||||
llvm::raw_string_ostream mangled_name_stream(value);
|
||||
context.module->print(mangled_name_stream, nullptr);
|
||||
|
||||
std::cerr << value << std::endl;
|
||||
}
|
||||
|
||||
static llvm::Constant * getNativeValue(llvm::Type * type, const IColumn & column, size_t i)
|
||||
@ -474,21 +285,16 @@ static CompilableExpression subexpression(const IFunctionBase & f, std::vector<C
|
||||
};
|
||||
}
|
||||
|
||||
struct LLVMModuleState
|
||||
{
|
||||
std::unordered_map<std::string, void *> symbols;
|
||||
std::shared_ptr<llvm::LLVMContext> major_context;
|
||||
std::shared_ptr<llvm::SectionMemoryManager> memory_manager;
|
||||
};
|
||||
|
||||
LLVMFunction::LLVMFunction(const CompileDAG & dag)
|
||||
: name(dag.dump())
|
||||
, module_state(std::make_unique<LLVMModuleState>())
|
||||
{
|
||||
LLVMContext context;
|
||||
std::vector<CompilableExpression> expressions;
|
||||
expressions.reserve(dag.size());
|
||||
|
||||
auto & context = jit.getContext();
|
||||
llvm::IRBuilder<> builder(context);
|
||||
|
||||
for (const auto & node : dag)
|
||||
{
|
||||
switch (node.type)
|
||||
@ -498,7 +304,7 @@ LLVMFunction::LLVMFunction(const CompileDAG & dag)
|
||||
const auto * col = typeid_cast<const ColumnConst *>(node.column.get());
|
||||
|
||||
/// TODO: implement `getNativeValue` for all types & replace the check with `c.column && toNativeType(...)`
|
||||
if (!getNativeValue(toNativeType(context.builder, node.result_type), col->getDataColumn(), 0))
|
||||
if (!getNativeValue(toNativeType(builder, node.result_type), col->getDataColumn(), 0))
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR,
|
||||
"Cannot compile constant of type {} = {}",
|
||||
node.result_type->getName(),
|
||||
@ -530,12 +336,9 @@ LLVMFunction::LLVMFunction(const CompileDAG & dag)
|
||||
|
||||
expression = std::move(expressions.back());
|
||||
|
||||
compileFunctionToLLVMByteCode(context, *this);
|
||||
context.compileAllFunctionsToNativeCode();
|
||||
|
||||
module_state->symbols = context.symbols;
|
||||
module_state->major_context = context.context;
|
||||
module_state->memory_manager = context.memory_manager;
|
||||
auto module_for_compilation = jit.createModuleForCompilation();
|
||||
compileFunction(*module_for_compilation, *this);
|
||||
jit.compileModule(std::move(module_for_compilation));
|
||||
}
|
||||
|
||||
llvm::Value * LLVMFunction::compile(llvm::IRBuilderBase & builder, ValuePlaceholders values) const
|
||||
@ -543,7 +346,7 @@ llvm::Value * LLVMFunction::compile(llvm::IRBuilderBase & builder, ValuePlacehol
|
||||
return expression(builder, values);
|
||||
}
|
||||
|
||||
ExecutableFunctionImplPtr LLVMFunction::prepare(const ColumnsWithTypeAndName &) const { return std::make_unique<LLVMExecutableFunction>(name, module_state->symbols); }
|
||||
ExecutableFunctionImplPtr LLVMFunction::prepare(const ColumnsWithTypeAndName &) const { return std::make_unique<LLVMExecutableFunction>(name); }
|
||||
|
||||
bool LLVMFunction::isDeterministic() const
|
||||
{
|
||||
@ -791,18 +594,6 @@ static FunctionBasePtr compile(
|
||||
static std::unordered_map<UInt128, UInt32, UInt128Hash> counter;
|
||||
static std::mutex mutex;
|
||||
|
||||
struct LLVMTargetInitializer
|
||||
{
|
||||
LLVMTargetInitializer()
|
||||
{
|
||||
llvm::InitializeNativeTarget();
|
||||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
static LLVMTargetInitializer initializer;
|
||||
|
||||
auto hash_key = dag.hash();
|
||||
{
|
||||
std::lock_guard lock(mutex);
|
||||
@ -834,8 +625,6 @@ static FunctionBasePtr compile(
|
||||
|
||||
void ActionsDAG::compileFunctions(size_t min_count_to_compile_expression)
|
||||
{
|
||||
// std::cerr << "ActionsDAG::compileFunctions before dump " << dumpDAG() << std::endl;
|
||||
|
||||
struct Data
|
||||
{
|
||||
bool is_compilable = false;
|
||||
|
@ -17,8 +17,6 @@ namespace DB
|
||||
|
||||
using CompilableExpression = std::function<llvm::Value * (llvm::IRBuilderBase &, const ValuePlaceholders &)>;
|
||||
|
||||
struct LLVMModuleState;
|
||||
|
||||
class LLVMFunction : public IFunctionBaseImpl
|
||||
{
|
||||
std::string name;
|
||||
@ -27,8 +25,6 @@ class LLVMFunction : public IFunctionBaseImpl
|
||||
std::vector<FunctionBasePtr> originals;
|
||||
CompilableExpression expression;
|
||||
|
||||
std::unique_ptr<LLVMModuleState> module_state;
|
||||
|
||||
public:
|
||||
|
||||
/// LLVMFunction is a compiled part of ActionsDAG.
|
||||
@ -87,7 +83,6 @@ public:
|
||||
|
||||
Monotonicity getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const override;
|
||||
|
||||
const LLVMModuleState * getLLVMModuleState() const { return module_state.get(); }
|
||||
};
|
||||
|
||||
/** This child of LRUCache breaks one of it's invariants: total weight may be changed after insertion.
|
||||
|
333
src/Interpreters/JIT/CHJIT.cpp
Normal file
333
src/Interpreters/JIT/CHJIT.cpp
Normal file
@ -0,0 +1,333 @@
|
||||
#include "CHJIT.h"
|
||||
|
||||
#include <llvm/Analysis/TargetTransformInfo.h>
|
||||
#include <llvm/IR/BasicBlock.h>
|
||||
#include <llvm/IR/DataLayout.h>
|
||||
#include <llvm/IR/DerivedTypes.h>
|
||||
#include <llvm/IR/Function.h>
|
||||
#include <llvm/IR/IRBuilder.h>
|
||||
#include <llvm/IR/Mangler.h>
|
||||
#include <llvm/IR/Type.h>
|
||||
#include <llvm/IR/LegacyPassManager.h>
|
||||
#include <llvm/ExecutionEngine/JITSymbol.h>
|
||||
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
|
||||
#include <llvm/MC/SubtargetFeature.h>
|
||||
#include <llvm/Support/DynamicLibrary.h>
|
||||
#include <llvm/Support/Host.h>
|
||||
#include <llvm/Support/TargetRegistry.h>
|
||||
#include <llvm/Support/TargetSelect.h>
|
||||
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
|
||||
#include <llvm/Support/SmallVectorMemoryBuffer.h>
|
||||
|
||||
#include <Common/Exception.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int CANNOT_COMPILE_CODE;
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
class JITCompiler
|
||||
{
|
||||
public:
|
||||
|
||||
explicit JITCompiler(llvm::TargetMachine &target_machine_)
|
||||
: target_machine(target_machine_)
|
||||
{
|
||||
}
|
||||
|
||||
std::unique_ptr<llvm::MemoryBuffer> compile(llvm::Module & module)
|
||||
{
|
||||
auto materialize_error = module.materializeAll();
|
||||
if (materialize_error)
|
||||
{
|
||||
std::string error_message;
|
||||
handleAllErrors(
|
||||
std::move(materialize_error), [&](const llvm::ErrorInfoBase & error_info) { error_message = error_info.message(); });
|
||||
|
||||
throw Exception(ErrorCodes::CANNOT_COMPILE_CODE, "Cannot materialize module {}", error_message);
|
||||
}
|
||||
|
||||
llvm::SmallVector<char, 4096> object_buffer;
|
||||
|
||||
llvm::raw_svector_ostream object_stream(object_buffer);
|
||||
llvm::legacy::PassManager pass_manager;
|
||||
llvm::MCContext * machine_code_context = nullptr;
|
||||
|
||||
if (target_machine.addPassesToEmitMC(pass_manager, machine_code_context, object_stream))
|
||||
throw Exception(ErrorCodes::CANNOT_COMPILE_CODE, "MachineCode is not supported for the platform");
|
||||
|
||||
pass_manager.run(module);
|
||||
|
||||
std::unique_ptr<llvm::MemoryBuffer> compiled_object_buffer = std::make_unique<llvm::SmallVectorMemoryBuffer>(
|
||||
std::move(object_buffer), "<in memory object compiled from " + module.getModuleIdentifier() + ">");
|
||||
|
||||
return compiled_object_buffer;
|
||||
}
|
||||
|
||||
~JITCompiler() = default;
|
||||
|
||||
private:
|
||||
llvm::TargetMachine & target_machine;
|
||||
};
|
||||
|
||||
class JITModuleMemoryManager : public llvm::SectionMemoryManager
|
||||
{
|
||||
class DefaultMMapper final : public llvm::SectionMemoryManager::MemoryMapper
|
||||
{
|
||||
public:
|
||||
llvm::sys::MemoryBlock allocateMappedMemory(
|
||||
SectionMemoryManager::AllocationPurpose Purpose [[maybe_unused]],
|
||||
size_t NumBytes,
|
||||
const llvm::sys::MemoryBlock * const NearBlock,
|
||||
unsigned Flags,
|
||||
std::error_code & EC) override
|
||||
{
|
||||
auto allocated_memory_block = llvm::sys::Memory::allocateMappedMemory(NumBytes, NearBlock, Flags, EC);
|
||||
allocated_size += allocated_memory_block.allocatedSize();
|
||||
return allocated_memory_block;
|
||||
}
|
||||
|
||||
std::error_code protectMappedMemory(const llvm::sys::MemoryBlock & Block, unsigned Flags) override
|
||||
{
|
||||
return llvm::sys::Memory::protectMappedMemory(Block, Flags);
|
||||
}
|
||||
|
||||
std::error_code releaseMappedMemory(llvm::sys::MemoryBlock & M) override { return llvm::sys::Memory::releaseMappedMemory(M); }
|
||||
|
||||
size_t allocated_size = 0;
|
||||
};
|
||||
|
||||
public:
|
||||
JITModuleMemoryManager() : llvm::SectionMemoryManager(&mmaper) { }
|
||||
|
||||
inline size_t getAllocatedSize() const { return mmaper.allocated_size; }
|
||||
|
||||
~JITModuleMemoryManager() override = default;
|
||||
|
||||
private:
|
||||
DefaultMMapper mmaper;
|
||||
};
|
||||
|
||||
class JITSymbolResolver : public llvm::LegacyJITSymbolResolver
|
||||
{
|
||||
public:
|
||||
llvm::JITSymbol findSymbolInLogicalDylib(const std::string &) override { return nullptr; }
|
||||
|
||||
llvm::JITSymbol findSymbol(const std::string & Name) override
|
||||
{
|
||||
auto address_it = symbol_name_to_symbol_address.find(Name);
|
||||
if (address_it == symbol_name_to_symbol_address.end())
|
||||
throw Exception(ErrorCodes::CANNOT_COMPILE_CODE, "Could not find symbol {}", Name);
|
||||
|
||||
uint64_t symbol_address = reinterpret_cast<uint64_t>(address_it->second);
|
||||
auto jit_symbol = llvm::JITSymbol(symbol_address, llvm::JITSymbolFlags::None);
|
||||
|
||||
return jit_symbol;
|
||||
}
|
||||
|
||||
void registerSymbol(const std::string & symbol_name, void * symbol) { symbol_name_to_symbol_address[symbol_name] = symbol; }
|
||||
|
||||
~JITSymbolResolver() override = default;
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, void *> symbol_name_to_symbol_address;
|
||||
};
|
||||
|
||||
CHJIT::CHJIT()
|
||||
: machine(getTargetMachine())
|
||||
, layout(machine->createDataLayout())
|
||||
, compiler(std::make_unique<JITCompiler>(*machine))
|
||||
, symbol_resolver(std::make_unique<JITSymbolResolver>())
|
||||
{
|
||||
}
|
||||
|
||||
CHJIT::~CHJIT() = default;
|
||||
|
||||
std::unique_ptr<llvm::Module> CHJIT::createModuleForCompilation()
|
||||
{
|
||||
std::unique_ptr<llvm::Module> module = std::make_unique<llvm::Module>("jit " + std::to_string(current_module_key), context);
|
||||
module->setDataLayout(layout);
|
||||
module->setTargetTriple(machine->getTargetTriple().getTriple());
|
||||
|
||||
++current_module_key;
|
||||
|
||||
return module;
|
||||
}
|
||||
|
||||
CHJIT::CompiledModuleInfo CHJIT::compileModule(std::unique_ptr<llvm::Module> module)
|
||||
{
|
||||
runOptimizationPassesOnModule(*module);
|
||||
|
||||
auto buffer = compiler->compile(*module);
|
||||
|
||||
llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> object = llvm::object::ObjectFile::createObjectFile(*buffer);
|
||||
|
||||
if (!object)
|
||||
{
|
||||
std::string error_message;
|
||||
handleAllErrors(object.takeError(), [&](const llvm::ErrorInfoBase & error_info) { error_message = error_info.message(); });
|
||||
|
||||
throw Exception(ErrorCodes::CANNOT_COMPILE_CODE, "Cannot create object file from compiled buffer {}", error_message);
|
||||
}
|
||||
|
||||
std::unique_ptr<JITModuleMemoryManager> module_memory_manager = std::make_unique<JITModuleMemoryManager>();
|
||||
llvm::RuntimeDyld dynamic_linker = {*module_memory_manager, *symbol_resolver};
|
||||
|
||||
std::unique_ptr<llvm::RuntimeDyld::LoadedObjectInfo> linked_object = dynamic_linker.loadObject(*object.get());
|
||||
|
||||
if (dynamic_linker.hasError())
|
||||
throw Exception(ErrorCodes::CANNOT_COMPILE_CODE, "RuntimeDyld error {}", std::string(dynamic_linker.getErrorString()));
|
||||
|
||||
dynamic_linker.resolveRelocations();
|
||||
module_memory_manager->finalizeMemory();
|
||||
|
||||
CompiledModuleInfo module_info;
|
||||
|
||||
for (const auto & function : *module)
|
||||
{
|
||||
if (function.isDeclaration())
|
||||
continue;
|
||||
|
||||
auto function_name = std::string(function.getName());
|
||||
|
||||
auto mangled_name = getMangledName(function_name);
|
||||
auto jit_symbol = dynamic_linker.getSymbol(mangled_name);
|
||||
|
||||
if (!jit_symbol)
|
||||
throw Exception(ErrorCodes::CANNOT_COMPILE_CODE, "DynamicLinker could not found symbol {} after compilation", function_name);
|
||||
|
||||
auto * jit_symbol_address = reinterpret_cast<void *>(jit_symbol.getAddress());
|
||||
name_to_symbol[function_name] = jit_symbol_address;
|
||||
}
|
||||
|
||||
auto module_identifier = module->getModuleIdentifier();
|
||||
|
||||
module_info.size = module_memory_manager->getAllocatedSize();
|
||||
module_info.module_identifier = module_identifier;
|
||||
|
||||
module_identifier_to_memory_manager[module_identifier] = std::move(module_memory_manager);
|
||||
|
||||
compiled_code_size += module_info.size;
|
||||
|
||||
return module_info;
|
||||
}
|
||||
|
||||
void CHJIT::deleteCompiledModule(const CHJIT::CompiledModuleInfo & module_info)
|
||||
{
|
||||
auto module_it = module_identifier_to_memory_manager.find(module_info.module_identifier);
|
||||
if (module_it == module_identifier_to_memory_manager.end())
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "There is no compiled module with identifier {}", module_info.module_identifier);
|
||||
|
||||
for (const auto & function : module_info.compiled_functions)
|
||||
name_to_symbol.erase(function);
|
||||
|
||||
module_identifier_to_memory_manager.erase(module_it);
|
||||
compiled_code_size -= module_info.size;
|
||||
}
|
||||
|
||||
void * CHJIT::findCompiledFunction(const std::string & name) const
|
||||
{
|
||||
auto it = name_to_symbol.find(name);
|
||||
if (it != name_to_symbol.end())
|
||||
return it->second;
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void CHJIT::registerExternalSymbol(const std::string & symbol_name, void * address)
|
||||
{
|
||||
symbol_resolver->registerSymbol(symbol_name, address);
|
||||
}
|
||||
|
||||
std::string CHJIT::getMangledName(const std::string & name_to_mangle) const
|
||||
{
|
||||
std::string mangled_name;
|
||||
llvm::raw_string_ostream mangled_name_stream(mangled_name);
|
||||
llvm::Mangler::getNameWithPrefix(mangled_name_stream, name_to_mangle, layout);
|
||||
mangled_name_stream.flush();
|
||||
|
||||
return mangled_name;
|
||||
}
|
||||
|
||||
void CHJIT::runOptimizationPassesOnModule(llvm::Module & module) const
|
||||
{
|
||||
llvm::PassManagerBuilder pass_manager_builder;
|
||||
llvm::legacy::PassManager mpm;
|
||||
llvm::legacy::FunctionPassManager fpm(&module);
|
||||
pass_manager_builder.OptLevel = 3;
|
||||
pass_manager_builder.SLPVectorize = true;
|
||||
pass_manager_builder.LoopVectorize = true;
|
||||
pass_manager_builder.RerollLoops = true;
|
||||
pass_manager_builder.VerifyInput = true;
|
||||
pass_manager_builder.VerifyOutput = true;
|
||||
machine->adjustPassManager(pass_manager_builder);
|
||||
|
||||
fpm.add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis()));
|
||||
mpm.add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis()));
|
||||
|
||||
pass_manager_builder.populateFunctionPassManager(fpm);
|
||||
pass_manager_builder.populateModulePassManager(mpm);
|
||||
|
||||
fpm.doInitialization();
|
||||
for (auto & function : module)
|
||||
fpm.run(function);
|
||||
fpm.doFinalization();
|
||||
|
||||
mpm.run(module);
|
||||
}
|
||||
|
||||
std::atomic<bool> initialized = false;
|
||||
|
||||
static void initializeLLVMTarget()
|
||||
{
|
||||
if (initialized)
|
||||
return;
|
||||
|
||||
initialized = true;
|
||||
llvm::InitializeNativeTarget();
|
||||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
|
||||
}
|
||||
|
||||
std::unique_ptr<llvm::TargetMachine> CHJIT::getTargetMachine()
|
||||
{
|
||||
initializeLLVMTarget();
|
||||
|
||||
std::string error;
|
||||
auto cpu = llvm::sys::getHostCPUName();
|
||||
auto triple = llvm::sys::getProcessTriple();
|
||||
const auto * target = llvm::TargetRegistry::lookupTarget(triple, error);
|
||||
if (!target)
|
||||
throw Exception(ErrorCodes::CANNOT_COMPILE_CODE, "Cannot find target triple {} error {}", triple, error);
|
||||
|
||||
llvm::SubtargetFeatures features;
|
||||
llvm::StringMap<bool> feature_map;
|
||||
if (llvm::sys::getHostCPUFeatures(feature_map))
|
||||
for (auto & f : feature_map)
|
||||
features.AddFeature(f.first(), f.second);
|
||||
|
||||
llvm::TargetOptions options;
|
||||
|
||||
bool jit = true;
|
||||
auto * target_machine = target->createTargetMachine(triple,
|
||||
cpu,
|
||||
features.getString(),
|
||||
options,
|
||||
llvm::None,
|
||||
llvm::None,
|
||||
llvm::CodeGenOpt::Default,
|
||||
jit);
|
||||
|
||||
if (!target_machine)
|
||||
throw Exception(ErrorCodes::CANNOT_COMPILE_CODE, "Cannot create target machine");
|
||||
|
||||
return std::unique_ptr<llvm::TargetMachine>(target_machine);
|
||||
}
|
||||
|
||||
}
|
70
src/Interpreters/JIT/CHJIT.h
Normal file
70
src/Interpreters/JIT/CHJIT.h
Normal file
@ -0,0 +1,70 @@
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include <llvm/IR/LLVMContext.h>
|
||||
#include <llvm/IR/Module.h>
|
||||
#include <llvm/Target/TargetMachine.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class JITModuleMemoryManager;
|
||||
class JITSymbolResolver;
|
||||
class JITCompiler;
|
||||
|
||||
class CHJIT
|
||||
{
|
||||
public:
|
||||
CHJIT();
|
||||
|
||||
~CHJIT();
|
||||
|
||||
struct CompiledModuleInfo
|
||||
{
|
||||
size_t size;
|
||||
std::string module_identifier;
|
||||
std::vector<std::string> compiled_functions;
|
||||
};
|
||||
|
||||
std::unique_ptr<llvm::Module> createModuleForCompilation();
|
||||
|
||||
CompiledModuleInfo compileModule(std::unique_ptr<llvm::Module> module);
|
||||
|
||||
void deleteCompiledModule(const CompiledModuleInfo & module_info);
|
||||
|
||||
void * findCompiledFunction(const std::string & name) const;
|
||||
|
||||
void registerExternalSymbol(const std::string & symbol_name, void * address);
|
||||
|
||||
llvm::LLVMContext & getContext()
|
||||
{
|
||||
return context;
|
||||
}
|
||||
|
||||
inline size_t getCompiledCodeSize() const
|
||||
{
|
||||
return compiled_code_size;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
std::string getMangledName(const std::string & name_to_mangle) const;
|
||||
|
||||
void runOptimizationPassesOnModule(llvm::Module & module) const;
|
||||
|
||||
static std::unique_ptr<llvm::TargetMachine> getTargetMachine();
|
||||
|
||||
llvm::LLVMContext context;
|
||||
std::unique_ptr<llvm::TargetMachine> machine;
|
||||
llvm::DataLayout layout;
|
||||
std::unique_ptr<JITCompiler> compiler;
|
||||
std::unique_ptr<JITSymbolResolver> symbol_resolver;
|
||||
|
||||
std::unordered_map<std::string, void *> name_to_symbol;
|
||||
std::unordered_map<std::string, std::unique_ptr<JITModuleMemoryManager>> module_identifier_to_memory_manager;
|
||||
size_t current_module_key = 0;
|
||||
size_t compiled_code_size = 0;
|
||||
};
|
||||
|
||||
}
|
@ -1,584 +1,56 @@
|
||||
#include <iostream>
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include <llvm/Analysis/TargetTransformInfo.h>
|
||||
#include <llvm/IR/BasicBlock.h>
|
||||
#include <llvm/IR/DataLayout.h>
|
||||
#include <llvm/IR/DerivedTypes.h>
|
||||
#include <llvm/IR/Function.h>
|
||||
#include <llvm/IR/IRBuilder.h>
|
||||
#include <llvm/IR/LLVMContext.h>
|
||||
#include <llvm/IR/Mangler.h>
|
||||
#include <llvm/IR/Module.h>
|
||||
#include <llvm/IR/Type.h>
|
||||
#include <llvm/IR/LegacyPassManager.h>
|
||||
#include <llvm/ExecutionEngine/ExecutionEngine.h>
|
||||
#include <llvm/ExecutionEngine/JITSymbol.h>
|
||||
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
|
||||
#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
|
||||
#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
|
||||
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
|
||||
#include <llvm/Target/TargetMachine.h>
|
||||
#include <llvm/MC/SubtargetFeature.h>
|
||||
#include <llvm/Support/DynamicLibrary.h>
|
||||
#include <llvm/Support/Host.h>
|
||||
#include <llvm/Support/TargetRegistry.h>
|
||||
#include <llvm/Support/TargetSelect.h>
|
||||
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
|
||||
|
||||
static llvm::TargetMachine * getNativeMachine()
|
||||
{
|
||||
std::string error;
|
||||
auto cpu = llvm::sys::getHostCPUName();
|
||||
auto triple = llvm::sys::getProcessTriple();
|
||||
const auto * target = llvm::TargetRegistry::lookupTarget(triple, error);
|
||||
if (!target)
|
||||
{
|
||||
std::cerr << "No target " << error << std::endl;
|
||||
std::terminate();
|
||||
}
|
||||
|
||||
llvm::SubtargetFeatures features;
|
||||
llvm::StringMap<bool> feature_map;
|
||||
if (llvm::sys::getHostCPUFeatures(feature_map))
|
||||
for (auto & f : feature_map)
|
||||
features.AddFeature(f.first(), f.second);
|
||||
llvm::TargetOptions options;
|
||||
return target->createTargetMachine(
|
||||
triple, cpu, features.getString(), options, llvm::None,
|
||||
llvm::None, llvm::CodeGenOpt::Default, /*jit=*/true
|
||||
);
|
||||
}
|
||||
#include <Interpreters/JIT/CHJIT.h>
|
||||
|
||||
void test_function()
|
||||
{
|
||||
std::cerr << "TestFunction" << std::endl;
|
||||
std::cerr << "Test function" << std::endl;
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Trivial implementation of SectionMemoryManager::MemoryMapper that just calls
|
||||
// into sys::Memory.
|
||||
class CustomMMapper final : public llvm::SectionMemoryManager::MemoryMapper {
|
||||
public:
|
||||
llvm::sys::MemoryBlock
|
||||
allocateMappedMemory(llvm::SectionMemoryManager::AllocationPurpose Purpose,
|
||||
size_t NumBytes, const llvm::sys::MemoryBlock *const NearBlock,
|
||||
unsigned Flags, std::error_code &EC) override {
|
||||
(void)(Purpose);
|
||||
auto result_block = llvm::sys::Memory::allocateMappedMemory(NumBytes, NearBlock, Flags, EC);
|
||||
|
||||
// std::cerr << "CustomMMapper::allocateMappedMemory " << NumBytes << " result block " << result_block.base();
|
||||
// std::cerr << " allocated size " << result_block.allocatedSize() << std::endl;
|
||||
|
||||
return result_block;
|
||||
}
|
||||
|
||||
std::error_code protectMappedMemory(const llvm::sys::MemoryBlock &Block,
|
||||
unsigned Flags) override {
|
||||
// std::cerr << "CustomMMapper::protectMappedMemory " << Block.base() << " " << Block.allocatedSize() << std::endl;
|
||||
return llvm::sys::Memory::protectMappedMemory(Block, Flags);
|
||||
}
|
||||
|
||||
std::error_code releaseMappedMemory(llvm::sys::MemoryBlock &M) override {
|
||||
// std::cerr << "CustomMMapper::releaseMappedMemory " << M.base() << " size " << M.allocatedSize() << std::endl;
|
||||
return llvm::sys::Memory::releaseMappedMemory(M);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
CustomMMapper DefaultMMapperInstance;
|
||||
|
||||
class CustomMemoryManager : public llvm::SectionMemoryManager
|
||||
{
|
||||
public:
|
||||
CustomMemoryManager() : llvm::SectionMemoryManager(&DefaultMMapperInstance)
|
||||
{
|
||||
// std::cerr << "CustomMemoryManager::constructor" << std::endl;
|
||||
}
|
||||
|
||||
uint8_t *allocateCodeSection(uintptr_t Size, unsigned Alignment,
|
||||
unsigned SectionID,
|
||||
llvm::StringRef SectionName) override
|
||||
{
|
||||
// std::cerr << "CustomMemoryManager::allocateCodeSection " << Size << " " << Alignment << std::endl;
|
||||
return llvm::SectionMemoryManager::allocateCodeSection(Size, Alignment, SectionID, SectionName);
|
||||
}
|
||||
|
||||
uint8_t *allocateDataSection(uintptr_t Size, unsigned Alignment,
|
||||
unsigned SectionID, llvm::StringRef SectionName,
|
||||
bool isReadOnly) override
|
||||
{
|
||||
// std::cerr << "CustomMemoryManager::allocateDataSection " << Size << " " << Alignment << std::endl;
|
||||
return llvm::SectionMemoryManager::allocateDataSection(Size, Alignment, SectionID, SectionName, isReadOnly);
|
||||
}
|
||||
|
||||
bool finalizeMemory(std::string *ErrMsg = nullptr) override
|
||||
{
|
||||
// std::cerr << "CustomMemoryManager::finalizeMemory" << std::endl;
|
||||
return llvm::SectionMemoryManager::finalizeMemory(ErrMsg);
|
||||
}
|
||||
|
||||
~CustomMemoryManager() override
|
||||
{
|
||||
// std::cerr << "CustomMemoryManager::destructor" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
class CHCompiler: public llvm::orc::SimpleCompiler
|
||||
{
|
||||
public:
|
||||
using Base = llvm::orc::SimpleCompiler;
|
||||
using Base::Base;
|
||||
|
||||
typename Base::CompileResult operator()(llvm::Module &M)
|
||||
{
|
||||
// std::cerr << "CHCompiler::operator() module " << std::string(M.getName()) << std::endl;
|
||||
auto compile_result = Base::operator()(M);
|
||||
// auto buffer = compile_result->getBuffer();
|
||||
// std::cerr << "Compile result " << static_cast<const void*>(buffer.data()) << " compile size " << buffer.size() << std::endl;
|
||||
return compile_result;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class CHObjectMaterializationUnit : public llvm::orc::BasicObjectLayerMaterializationUnit
|
||||
{
|
||||
public:
|
||||
|
||||
using Base = llvm::orc::BasicObjectLayerMaterializationUnit;
|
||||
using Base::Base;
|
||||
|
||||
static llvm::Expected<std::unique_ptr<BasicObjectLayerMaterializationUnit>>
|
||||
Create(llvm::orc::ObjectLayer &L, llvm::orc::VModuleKey Key, std::unique_ptr<llvm::MemoryBuffer> O) {
|
||||
std::cerr << "CHObjectMaterializationUnit::constructor" << std::endl;
|
||||
|
||||
auto symbol_flags =
|
||||
getObjectSymbolFlags(L.getExecutionSession(), O->getMemBufferRef());
|
||||
|
||||
if (!symbol_flags)
|
||||
return symbol_flags.takeError();
|
||||
|
||||
return std::unique_ptr<BasicObjectLayerMaterializationUnit>(
|
||||
new CHObjectMaterializationUnit(L, Key, std::move(O),
|
||||
std::move(*symbol_flags)));
|
||||
}
|
||||
|
||||
void discard(const llvm::orc::JITDylib &JD, const llvm::orc::SymbolStringPtr &Name) override
|
||||
{
|
||||
std::cerr << "CHObjectMaterializationUnit::discard jd " << JD.getName() << " name " << std::string(*Name) << std::endl;
|
||||
}
|
||||
|
||||
~CHObjectMaterializationUnit() override
|
||||
{
|
||||
std::cerr << "CHObjectMaterializationUnit::destructor" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
class CHRTDyldObjectLinkingLayer: public llvm::orc::RTDyldObjectLinkingLayer
|
||||
{
|
||||
public:
|
||||
using Base = llvm::orc::RTDyldObjectLinkingLayer;
|
||||
|
||||
using Base::Base;
|
||||
|
||||
void emit(llvm::orc::MaterializationResponsibility R, std::unique_ptr<llvm::MemoryBuffer> O) override
|
||||
{
|
||||
std::cerr << "CHRTDyldObjectLinkingLayer::emit jitdylib " << R.getTargetJITDylib().getName() << std::endl;
|
||||
Base::emit(std::move(R), std::move(O));
|
||||
}
|
||||
|
||||
llvm::Error add(llvm::orc::JITDylib &JD, std::unique_ptr<llvm::MemoryBuffer> O, llvm::orc::VModuleKey K = llvm::orc::VModuleKey()) override
|
||||
{
|
||||
std::cerr << "CHRTDyldObjectLinkingLayer::add " << JD.getName() << std::endl;
|
||||
|
||||
auto object_mu = CHObjectMaterializationUnit::Create(*this, std::move(K),
|
||||
std::move(O));
|
||||
if (!object_mu)
|
||||
return object_mu.takeError();
|
||||
|
||||
return JD.define(std::move(*object_mu));
|
||||
}
|
||||
};
|
||||
|
||||
class CHIRLayerMaterializationUnit : public llvm::orc::BasicIRLayerMaterializationUnit
|
||||
{
|
||||
public:
|
||||
using Base = llvm::orc::BasicIRLayerMaterializationUnit;
|
||||
using Base::Base;
|
||||
|
||||
void discard(const llvm::orc::JITDylib &JD, const llvm::orc::SymbolStringPtr &Name) override
|
||||
{
|
||||
std::cerr << "CHIRLayerMaterializationUnit::discard " << JD.getName() << " symbol " << std::string(*Name) << std::endl;
|
||||
}
|
||||
|
||||
~CHIRLayerMaterializationUnit() override
|
||||
{
|
||||
std::cerr << "CHIRLayerMaterializationUnit::~CHIRLayerMaterializationUnit" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
class CHIRCompileLayer: public llvm::orc::IRCompileLayer
|
||||
{
|
||||
public:
|
||||
using Base = llvm::orc::IRCompileLayer;
|
||||
using Base::Base;
|
||||
|
||||
llvm::Error add(llvm::orc::JITDylib &JD, llvm::orc::ThreadSafeModule TSM, llvm::orc::VModuleKey K) override {
|
||||
std::cerr << "CHIRCompileLayer::add " << JD.getName() << std::endl;
|
||||
auto materialization_unit = llvm::make_unique<CHIRLayerMaterializationUnit>(*this, std::move(K), std::move(TSM));
|
||||
auto symbols = materialization_unit->getSymbols();
|
||||
std::cerr << "CHIRCompileLayer:: symbols in materialization unit " << symbols.size() << std::endl;
|
||||
for (const auto & symbol : symbols)
|
||||
{
|
||||
std::cerr << std::string(*symbol.getFirst()) << std::endl;
|
||||
}
|
||||
|
||||
return JD.define(materialization_unit);
|
||||
}
|
||||
|
||||
void emit(llvm::orc::MaterializationResponsibility R, llvm::orc::ThreadSafeModule TSM) override
|
||||
{
|
||||
std::cerr << "CHIRCompileLayer::emit " << R.getTargetJITDylib().getName() << std::endl;
|
||||
Base::emit(std::move(R), std::move(TSM));
|
||||
}
|
||||
};
|
||||
|
||||
class MCJitWrapper
|
||||
{
|
||||
public:
|
||||
llvm::LLVMContext context;
|
||||
std::unique_ptr<llvm::TargetMachine> machine {getNativeMachine()};
|
||||
llvm::orc::SimpleCompiler compiler;
|
||||
|
||||
MCJitWrapper()
|
||||
{
|
||||
// std::cerr << "Engine " << engine << " builder " << *builder.getErrorStr() << std::endl;
|
||||
|
||||
std::unique_ptr<llvm::Module> module = std::make_unique<llvm::Module>("jit", context);
|
||||
generateFunctionInModule(*module);
|
||||
|
||||
// auto builder = llvm::EngineBuilder(std::move(module));
|
||||
// auto * engine = builder
|
||||
// .setEngineKind(llvm::EngineKind::JIT)
|
||||
// .setMemoryManager(std::make_unique<llvm::SectionMemoryManager>())
|
||||
// .create(getNativeMachine());
|
||||
// engine->finalizeObject();
|
||||
// auto test_function = engine->getFunctionAddress("test1");
|
||||
|
||||
// std::cerr << "Test function address " << test_function << std::endl;
|
||||
// auto test_function_typed = reinterpret_cast<int64_t (*)(int64_t)>(test_function);
|
||||
|
||||
// int64_t result = 5;
|
||||
|
||||
// while (result != 15)
|
||||
// {
|
||||
// result = test_function_typed(result);
|
||||
// std::cerr << "Result " << result << std::endl;
|
||||
// }
|
||||
|
||||
llvm::cantFail(module->materializeAll());
|
||||
|
||||
auto buffer = compiler(*module);
|
||||
|
||||
llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> object = llvm::object::ObjectFile::createObjectFile(*buffer);
|
||||
|
||||
if (!object)
|
||||
{
|
||||
logAllUnhandledErrors(object.takeError(), llvm::errs());
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_ptr<RuntimeDyld::LoadedObjectInfo> L = Dyld.loadObject(*LoadedObject.get());
|
||||
|
||||
if (Dyld.hasError())
|
||||
report_fatal_error(Dyld.getErrorString());
|
||||
|
||||
// llvm::legacy::PassManager pass_manager;
|
||||
|
||||
// // The RuntimeDyld will take ownership of this shortly
|
||||
// llvm::SmallVector<char, 4096> object_buffer_vector;
|
||||
// llvm::raw_svector_ostream object_buffer_stream(object_buffer_vector);
|
||||
|
||||
// // Turn the machine code intermediate representation into bytes in memory
|
||||
// // that may be executed.
|
||||
// if (machine->addPassesToEmitMC(pass_manager, nullptr, object_buffer_stream, true))
|
||||
// report_fatal_error("Target does not support MC emission!");
|
||||
|
||||
// Initialize passes.
|
||||
// PM.run(*M);
|
||||
// Flush the output buffer to get the generated code into memory
|
||||
|
||||
// std::unique_ptr<MemoryBuffer> CompiledObjBuffer(
|
||||
// new SmallVectorMemoryBuffer(std::move(ObjBufferSV)));
|
||||
}
|
||||
|
||||
static void generateFunctionInModule(llvm::Module & module)
|
||||
{
|
||||
llvm::IRBuilder<> b(module.getContext());
|
||||
|
||||
auto * func_type = llvm::FunctionType::get(b.getInt64Ty(), { b.getInt64Ty() }, /*isVarArg=*/false);
|
||||
auto * func = llvm::Function::Create(
|
||||
func_type,
|
||||
llvm::Function::LinkageTypes::ExternalLinkage,
|
||||
"test1",
|
||||
module);
|
||||
|
||||
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", func);
|
||||
b.SetInsertPoint(entry);
|
||||
|
||||
// auto * argument = func->args().begin();
|
||||
|
||||
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", func);
|
||||
b.CreateBr(loop);
|
||||
b.SetInsertPoint(loop);
|
||||
|
||||
auto * counter = b.CreatePHI(b.getInt64Ty(), 2);
|
||||
counter->addIncoming(llvm::ConstantInt::get(b.getInt64Ty(), 0), entry);
|
||||
|
||||
auto * add_value = b.CreateAdd(counter, llvm::ConstantInt::get(b.getInt64Ty(), 1));
|
||||
counter->addIncoming(add_value, loop);
|
||||
|
||||
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", func);
|
||||
b.CreateCondBr(b.CreateICmpNE(counter, llvm::ConstantInt::get(b.getInt64Ty(), 5000000)), loop, end);
|
||||
b.SetInsertPoint(end);
|
||||
|
||||
// auto * result = b.CreateAdd(phi, argument);
|
||||
b.CreateRet(counter);
|
||||
|
||||
// module.print(llvm::errs(), nullptr);
|
||||
}
|
||||
|
||||
void optimizeModuleFunctions(llvm::Module & module) const
|
||||
{
|
||||
llvm::PassManagerBuilder pass_manager_builder;
|
||||
llvm::legacy::PassManager mpm;
|
||||
llvm::legacy::FunctionPassManager fpm(&module);
|
||||
pass_manager_builder.OptLevel = 3;
|
||||
pass_manager_builder.SLPVectorize = true;
|
||||
pass_manager_builder.LoopVectorize = true;
|
||||
pass_manager_builder.RerollLoops = true;
|
||||
pass_manager_builder.VerifyInput = true;
|
||||
pass_manager_builder.VerifyOutput = true;
|
||||
machine->adjustPassManager(pass_manager_builder);
|
||||
|
||||
fpm.add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis()));
|
||||
mpm.add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis()));
|
||||
pass_manager_builder.populateFunctionPassManager(fpm);
|
||||
pass_manager_builder.populateModulePassManager(mpm);
|
||||
fpm.doInitialization();
|
||||
for (auto & function : module)
|
||||
fpm.run(function);
|
||||
fpm.doFinalization();
|
||||
mpm.run(module);
|
||||
|
||||
module.print(llvm::errs(), nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
struct LLVMContext
|
||||
{
|
||||
llvm::orc::ThreadSafeContext context { std::make_unique<llvm::LLVMContext>() };
|
||||
std::unique_ptr<llvm::Module> module {std::make_unique<llvm::Module>("jit", *context.getContext())};
|
||||
std::unique_ptr<llvm::TargetMachine> machine {getNativeMachine()};
|
||||
llvm::DataLayout layout {machine->createDataLayout()};
|
||||
llvm::IRBuilder<> builder {*context.getContext()};
|
||||
|
||||
llvm::orc::ExecutionSession execution_session;
|
||||
|
||||
CHRTDyldObjectLinkingLayer object_layer;
|
||||
std::unique_ptr<CHCompiler> compiler;
|
||||
CHIRCompileLayer compile_layer;
|
||||
llvm::orc::MangleAndInterner mangler;
|
||||
|
||||
std::unordered_map<std::string, void *> symbols;
|
||||
|
||||
std::vector<llvm::orc::VModuleKey> modules;
|
||||
|
||||
LLVMContext()
|
||||
: object_layer(execution_session, []() {
|
||||
std::cerr << "CHRTDyldObjectLinkingLayer get SectionMemoryManager" << std::endl;
|
||||
return std::make_unique<CustomMemoryManager>();
|
||||
})
|
||||
, compiler(std::make_unique<CHCompiler>(*machine))
|
||||
, compile_layer(execution_session, object_layer, *compiler)
|
||||
, mangler(execution_session, layout)
|
||||
{
|
||||
module->setDataLayout(layout);
|
||||
module->setTargetTriple(machine->getTargetTriple().getTriple());
|
||||
|
||||
auto pointer = llvm::pointerToJITTargetAddress(&test_function);
|
||||
auto symbol = mangler("test_function");
|
||||
|
||||
llvm::orc::SymbolMap map;
|
||||
map[symbol] = llvm::JITEvaluatedSymbol(pointer, llvm::JITSymbolFlags::Exported | llvm::JITSymbolFlags::Absolute);
|
||||
|
||||
auto error = execution_session.getMainJITDylib().define(llvm::orc::absoluteSymbols(map));
|
||||
bool is_error = static_cast<bool>(error);
|
||||
std::cerr << "Error " << is_error << std::endl;
|
||||
|
||||
// if (error)
|
||||
// {
|
||||
// std::cerr << "Could not define symbols " << error-> << std::endl;
|
||||
// std::terminate();
|
||||
// }
|
||||
}
|
||||
|
||||
/// returns used memory
|
||||
void compileAllFunctionsToNativeCode()
|
||||
{
|
||||
if (module->empty())
|
||||
return;
|
||||
|
||||
llvm::PassManagerBuilder pass_manager_builder;
|
||||
llvm::legacy::PassManager mpm;
|
||||
llvm::legacy::FunctionPassManager fpm(module.get());
|
||||
pass_manager_builder.OptLevel = 3;
|
||||
pass_manager_builder.SLPVectorize = true;
|
||||
pass_manager_builder.LoopVectorize = true;
|
||||
pass_manager_builder.RerollLoops = true;
|
||||
pass_manager_builder.VerifyInput = true;
|
||||
pass_manager_builder.VerifyOutput = true;
|
||||
machine->adjustPassManager(pass_manager_builder);
|
||||
|
||||
fpm.add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis()));
|
||||
mpm.add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis()));
|
||||
pass_manager_builder.populateFunctionPassManager(fpm);
|
||||
pass_manager_builder.populateModulePassManager(mpm);
|
||||
fpm.doInitialization();
|
||||
for (auto & function : *module)
|
||||
fpm.run(function);
|
||||
fpm.doFinalization();
|
||||
mpm.run(*module);
|
||||
|
||||
std::vector<std::string> functions;
|
||||
functions.reserve(module->size());
|
||||
for (const auto & function : *module)
|
||||
functions.emplace_back(function.getName());
|
||||
|
||||
llvm::orc::VModuleKey module_key = execution_session.allocateVModule();
|
||||
llvm::orc::ThreadSafeModule thread_safe_module(std::move(module), context);
|
||||
modules.emplace_back(module_key);
|
||||
|
||||
if (compile_layer.add(execution_session.getMainJITDylib(), std::move(thread_safe_module), module_key))
|
||||
{
|
||||
std::cerr << "Terminate because cannot add module" << std::endl;
|
||||
std::terminate();
|
||||
}
|
||||
|
||||
std::cerr << "Module key " << module_key << std::endl;
|
||||
|
||||
// for (const auto & name : functions)
|
||||
// {
|
||||
// auto symbol = execution_session.lookup({&execution_session.getMainJITDylib()}, mangler(name));
|
||||
// if (!symbol)
|
||||
// continue; /// external function (e.g. an intrinsic that calls into libc)
|
||||
|
||||
// auto address = symbol->getAddress();
|
||||
// if (!address)
|
||||
// {
|
||||
// std::cerr << "Terminate because cannot add module" << std::endl;
|
||||
// std::terminate();
|
||||
// }
|
||||
|
||||
// std::cerr << "Name " << name << " address " << reinterpret_cast<void *>(address) << std::endl;
|
||||
|
||||
// symbols[name] = reinterpret_cast<void *>(address);
|
||||
// }
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
(void)(argc);
|
||||
(void)(argv);
|
||||
|
||||
llvm::InitializeNativeTarget();
|
||||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
LLVMLinkInMCJIT();
|
||||
auto jit = DB::CHJIT();
|
||||
|
||||
// std::string error_message;
|
||||
// bool load_permanently = llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr, &error_message);
|
||||
// std::cerr << "Load " << load_permanently << " error " << error_message << std::endl;
|
||||
jit.registerExternalSymbol("test_function", reinterpret_cast<void *>(&test_function));
|
||||
|
||||
auto wrapper = MCJitWrapper();
|
||||
auto module = jit.createModuleForCompilation();
|
||||
|
||||
// LLVMContext context;
|
||||
// auto & b = context.builder;
|
||||
// auto * integer_type = b.getInt64Ty();
|
||||
// auto * func_type = llvm::FunctionType::get(integer_type, { integer_type }, /*isVarArg=*/false);
|
||||
auto & context = module->getContext();
|
||||
llvm::IRBuilder<> b (module->getContext());
|
||||
|
||||
// std::cerr << "Context module " << context.module.get() << std::endl;
|
||||
auto * func_declaration_type = llvm::FunctionType::get(b.getVoidTy(), { }, /*isVarArg=*/false);
|
||||
auto * func_declaration = llvm::Function::Create(func_declaration_type, llvm::Function::ExternalLinkage, "test_function", module.get());
|
||||
|
||||
// auto * standard_function_type = llvm::FunctionType::get(b.getVoidTy(), {}, false);
|
||||
// auto * standard_function = llvm::Function::Create(
|
||||
// standard_function_type,
|
||||
// llvm::Function::LinkageTypes::ExternalLinkage,
|
||||
// "test_function",
|
||||
// *context.module);
|
||||
// standard_function->setCallingConv(llvm::CallingConv::C);
|
||||
auto * func_type = llvm::FunctionType::get(b.getVoidTy(), { b.getInt64Ty() }, /*isVarArg=*/false);
|
||||
auto * function = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, "test_name", module.get());
|
||||
auto * entry = llvm::BasicBlock::Create(context, "entry", function);
|
||||
|
||||
// auto * func = llvm::Function::Create(func_type, llvm::Function::LinkageTypes::ExternalLinkage, "test1", context.module.get());
|
||||
auto * argument = function->args().begin();
|
||||
b.SetInsertPoint(entry);
|
||||
|
||||
// auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", func);
|
||||
// b.SetInsertPoint(entry);
|
||||
b.CreateCall(func_declaration);
|
||||
|
||||
// auto * argument = func->args().begin();
|
||||
auto * value = b.CreateAdd(argument, argument);
|
||||
b.CreateRet(value);
|
||||
|
||||
// auto * value = llvm::ConstantInt::get(b.getInt64Ty(), 1);
|
||||
// auto * loop_block = llvm::BasicBlock::Create(b.getContext(), "loop", func);
|
||||
// b.CreateBr(loop_block);
|
||||
module->print(llvm::errs(), nullptr);
|
||||
|
||||
// b.SetInsertPoint(loop_block);
|
||||
auto compiled_module_info = jit.compileModule(std::move(module));
|
||||
|
||||
// auto * phi_value = b.CreatePHI(b.getInt64Ty(), 2);
|
||||
// phi_value->addIncoming(value, entry);
|
||||
std::cerr << "Compile module info " << compiled_module_info.module_identifier << " size " << compiled_module_info.size << std::endl;
|
||||
for (const auto & compiled_function_name : compiled_module_info.compiled_functions)
|
||||
{
|
||||
std::cerr << compiled_function_name << std::endl;
|
||||
}
|
||||
|
||||
// b.CreateCall(standard_function);
|
||||
|
||||
// auto * add_value = b.CreateAdd(phi_value, llvm::ConstantInt::get(b.getInt64Ty(), 1));
|
||||
// phi_value->addIncoming(add_value, loop_block);
|
||||
|
||||
// auto * end = llvm::BasicBlock::Create(b.getContext(), "end", func);
|
||||
// b.CreateCondBr(b.CreateICmpNE(phi_value, llvm::ConstantInt::get(b.getInt64Ty(), 10)), loop_block, end);
|
||||
|
||||
// b.SetInsertPoint(end);
|
||||
|
||||
// auto * result = b.CreateAdd(phi_value, argument);
|
||||
// b.CreateRet(result);
|
||||
|
||||
// std::cerr << "Context module " << context.module.get() << std::endl;
|
||||
// if (context.module)
|
||||
// context.module->print(llvm::errs(), nullptr);
|
||||
|
||||
// context.compileAllFunctionsToNativeCode();
|
||||
|
||||
// context.module->print(llvm::errs(), nullptr);
|
||||
|
||||
std::cerr << "ExecutionSession before module release dump " << std::endl;
|
||||
// context.execution_session.dump(llvm::errs());
|
||||
|
||||
// for (auto module_key : context.modules)
|
||||
// context.execution_session.releaseVModule(module_key);
|
||||
|
||||
// llvm::orc::SymbolNameSet set;
|
||||
|
||||
// auto ptr = context.execution_session.intern("test1");
|
||||
// set.insert(ptr);
|
||||
|
||||
// auto error = context.execution_session.getMainJITDylib().remove(set);
|
||||
|
||||
// if (error)
|
||||
// llvm::logAllUnhandledErrors(std::move(error), llvm::errs(), "Error logging ");
|
||||
|
||||
// std::cerr << "ExecutionSession after module release dump " << std::endl;
|
||||
// context.execution_session.dump(llvm::errs());
|
||||
|
||||
// auto * symbol = context.symbols.at("test1");
|
||||
// auto compiled_func = reinterpret_cast<int64_t (*)(int64_t)>(symbol);
|
||||
// std::cerr << "Function " << compiled_func(5) << std::endl;
|
||||
auto * test_name_function = reinterpret_cast<int64_t (*)(int64_t)>(jit.findCompiledFunction("test_name"));
|
||||
auto result = test_name_function(5);
|
||||
std::cerr << "Result " << result << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user