Refactored interfaces

This commit is contained in:
Maksim Kita 2021-05-04 23:32:43 +03:00
parent 24798ef07c
commit c79d7eae21
14 changed files with 751 additions and 633 deletions

View File

@ -879,7 +879,8 @@ int Server::main(const std::vector<std::string> & /*args*/)
global_context->setMMappedFileCache(mmap_cache_size); global_context->setMMappedFileCache(mmap_cache_size);
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER
size_t compiled_expression_cache_size = config().getUInt64("compiled_expression_cache_size", 500); constexpr size_t compiled_expression_cache_size_default = 1024 * 1024 * 1024;
size_t compiled_expression_cache_size = config().getUInt64("compiled_expression_cache_size", compiled_expression_cache_size_default);
CompiledExpressionCacheFactory::instance().init(compiled_expression_cache_size); CompiledExpressionCacheFactory::instance().init(compiled_expression_cache_size);
#endif #endif

View File

@ -10,6 +10,8 @@
# include <DataTypes/IDataType.h> # include <DataTypes/IDataType.h>
# include <DataTypes/DataTypeNullable.h> # include <DataTypes/DataTypeNullable.h>
# include <DataTypes/DataTypeFixedString.h> # include <DataTypes/DataTypeFixedString.h>
# include <Columns/ColumnConst.h>
# include <Columns/ColumnNullable.h>
# pragma GCC diagnostic push # pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wunused-parameter" # pragma GCC diagnostic ignored "-Wunused-parameter"
@ -142,6 +144,74 @@ static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr
return nativeCast(b, from, value, n_to); return nativeCast(b, from, value, n_to);
} }
static inline llvm::Constant * getColumnNativeConstant(llvm::Type * native_type, WhichDataType column_data_type, const IColumn & column)
{
llvm::Constant * result = nullptr;
if (column_data_type.isFloat32())
{
result = llvm::ConstantFP::get(native_type, column.getFloat32(0));
}
else if (column_data_type.isFloat64())
{
result = llvm::ConstantFP::get(native_type, column.getFloat64(0));
}
else if (column_data_type.isNativeInt())
{
result = llvm::ConstantInt::get(native_type, column.getInt(0));
}
else if (column_data_type.isNativeUInt())
{
result = llvm::ConstantInt::get(native_type, column.getUInt(0));
}
return result;
}
static inline llvm::Constant * getColumnNativeValue(llvm::IRBuilderBase & builder, const DataTypePtr & column_type, const IColumn & column, size_t index)
{
WhichDataType column_data_type(column_type);
auto * type = toNativeType(builder, column_type);
if (!type || column.size() <= index)
return nullptr;
if (const auto * constant = typeid_cast<const ColumnConst *>(&column))
{
return getColumnNativeValue(builder, column_type, constant->getDataColumn(), 0);
}
else if (column_data_type.isNullable())
{
const auto & nullable_data_type = assert_cast<const DataTypeNullable &>(*column_type);
const auto & nullable_column = assert_cast<const ColumnNullable &>(column);
auto * value = getColumnNativeValue(builder, nullable_data_type.getNestedType(), nullable_column.getNestedColumn(), index);
auto * is_null = llvm::ConstantInt::get(type->getContainedType(1), nullable_column.isNullAt(index));
return value ? llvm::ConstantStruct::get(static_cast<llvm::StructType *>(type), value, is_null) : nullptr;
}
else if (column_data_type.isFloat32())
{
return llvm::ConstantFP::get(type, assert_cast<const ColumnVector<Float32> &>(column).getElement(index));
}
else if (column_data_type.isFloat64())
{
return llvm::ConstantFP::get(type, assert_cast<const ColumnVector<Float64> &>(column).getElement(index));
}
else if (column_data_type.isNativeUInt())
{
return llvm::ConstantInt::get(type, column.getUInt(index));
}
else if (column_data_type.isNativeInt())
{
return llvm::ConstantInt::get(type, column.getInt(index));
}
return nullptr;
}
} }
#endif #endif

View File

@ -557,65 +557,6 @@ llvm::Value * IFunction::compile(llvm::IRBuilderBase & builder, const DataTypes
nullable_structure_result_null = b.CreateOr(nullable_structure_result_null, is_null_value); nullable_structure_result_null = b.CreateOr(nullable_structure_result_null, is_null_value);
return b.CreateInsertValue(nullable_structure_with_result_value, nullable_structure_result_null, {1}); return b.CreateInsertValue(nullable_structure_with_result_value, nullable_structure_result_null, {1});
// DataTypes non_null_arguments;
// non_null_arguments.reserve(arguments.size());
// for (size_t i = 0; i < arguments.size(); ++i)
// {
// WhichDataType data_type(arguments[i]);
// if (data_type.isNullable())
// {
// }
// else
// {
// }
// non_null_arguments.emplace_back(removeNullable(arguments[i]));
// auto * value = values[i]();
// }
// if (auto denulled = removeNullables(arguments))
// {
// DataTypes denulled_types = *denulled;
// std::cerr << "IFunction::denulled types " << std::endl;
// for (size_t i = 0; i < denulled_types.size(); ++i)
// {
// std::cerr << "Index " << i << " name " << denulled_types[i]->getName() << std::endl;
// }
// /// FIXME: when only one column is nullable, this can actually be slower than the non-jitted version
// /// because this involves copying the null map while `wrapInNullable` reuses it.
// auto & b = static_cast<llvm::IRBuilder<> &>(builder);
// auto * fail = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent());
// auto * join = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent());
// auto * zero = llvm::Constant::getNullValue(toNativeType(b, makeNullable(getReturnTypeImpl(*denulled))));
// for (size_t i = 0; i < arguments.size(); i++)
// {
// if (!arguments[i]->isNullable())
// continue;
// /// Would be nice to evaluate all this lazily, but that'd change semantics: if only unevaluated
// /// arguments happen to contain NULLs, the return value would not be NULL, though it should be.
// auto * value = values[i]();
// auto * ok = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent());
// b.CreateCondBr(b.CreateExtractValue(value, {1}), fail, ok);
// b.SetInsertPoint(ok);
// values[i] = [value = b.CreateExtractValue(value, {0})]() { return value; };
// }
// auto * result = b.CreateInsertValue(zero, compileImpl(builder, *denulled, std::move(values)), {0});
// auto * result_columns = b.GetInsertBlock();
// b.CreateBr(join);
// b.SetInsertPoint(fail);
// auto * null = b.CreateInsertValue(zero, b.getTrue(), {1});
// b.CreateBr(join);
// b.SetInsertPoint(join);
// auto * phi = b.CreatePHI(result->getType(), 2);
// phi->addIncoming(result, result_columns);
// phi->addIncoming(null, fail);
// return phi;
// }
} }
return compileImpl(builder, arguments, std::move(values)); return compileImpl(builder, arguments, std::move(values));

View File

@ -450,6 +450,12 @@ struct ContextSharedPart
/// TODO: Get rid of this. /// TODO: Get rid of this.
delete_system_logs = std::move(system_logs); delete_system_logs = std::move(system_logs);
#if USE_EMBEDDED_COMPILER
if (auto * cache = CompiledExpressionCacheFactory::instance().tryGetCache())
cache->reset();
#endif
embedded_dictionaries.reset(); embedded_dictionaries.reset();
external_dictionaries_loader.reset(); external_dictionaries_loader.reset();
models_repository_guard.reset(); models_repository_guard.reset();

View File

@ -116,10 +116,6 @@ struct BackgroundTaskSchedulingSettings;
class ZooKeeperMetadataTransaction; class ZooKeeperMetadataTransaction;
using ZooKeeperMetadataTransactionPtr = std::shared_ptr<ZooKeeperMetadataTransaction>; using ZooKeeperMetadataTransactionPtr = std::shared_ptr<ZooKeeperMetadataTransaction>;
#if USE_EMBEDDED_COMPILER
class CompiledExpressionCache;
#endif
/// Callback for external tables initializer /// Callback for external tables initializer
using ExternalTablesInitializer = std::function<void(ContextPtr)>; using ExternalTablesInitializer = std::function<void(ContextPtr)>;

View File

@ -10,27 +10,15 @@
#include <Columns/ColumnVector.h> #include <Columns/ColumnVector.h>
#include <Common/typeid_cast.h> #include <Common/typeid_cast.h>
#include <Common/assert_cast.h> #include <Common/assert_cast.h>
#include <Common/ProfileEvents.h>
#include <Common/Stopwatch.h>
#include <DataTypes/DataTypeNullable.h> #include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <DataTypes/Native.h> #include <DataTypes/Native.h>
#include <Functions/IFunctionAdaptors.h> #include <Functions/IFunctionAdaptors.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
#include <Interpreters/JIT/CHJIT.h> #include <Interpreters/JIT/CHJIT.h>
#include <Interpreters/JIT/CompileDAG.h>
namespace ProfileEvents #include <Interpreters/JIT/compileFunction.h>
{ #include <Interpreters/ActionsDAG.h>
extern const Event CompileFunction;
extern const Event CompileExpressionsMicroseconds;
extern const Event CompileExpressionsBytes;
}
namespace DB namespace DB
{ {
@ -40,51 +28,11 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR; extern const int LOGICAL_ERROR;
} }
namespace static CHJIT & getJITInstance()
{ {
struct ColumnData
{
const char * data = nullptr;
const char * null = nullptr;
};
struct ColumnDataPlaceholder
{
llvm::Value * data_init; /// first row
llvm::Value * null_init;
llvm::PHINode * data; /// current row
llvm::PHINode * null;
};
}
static ColumnData getColumnData(const IColumn * column)
{
ColumnData result;
const bool is_const = isColumnConst(*column);
if (is_const)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Input columns should not be constant");
if (const auto * nullable = typeid_cast<const ColumnNullable *>(column))
{
result.null = nullable->getNullMapColumn().getRawData().data;
column = &nullable->getNestedColumn();
}
result.data = column->getRawData().data;
return result;
}
static void applyFunction(IFunctionBase & function, Field & value)
{
const auto & type = function.getArgumentTypes().at(0);
ColumnsWithTypeAndName args{{type->createColumnConst(1, value), type, "x" }};
auto col = function.execute(args, function.getResultType(), 1);
col->get(0, value);
}
static CHJIT jit; static CHJIT jit;
return jit;
}
class LLVMExecutableFunction : public IExecutableFunctionImpl class LLVMExecutableFunction : public IExecutableFunctionImpl
{ {
@ -94,10 +42,10 @@ public:
explicit LLVMExecutableFunction(const std::string & name_) explicit LLVMExecutableFunction(const std::string & name_)
: name(name_) : name(name_)
{ {
function = jit.findCompiledFunction(name_); function = getJITInstance().findCompiledFunction(name);
if (!function) if (!function)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot find compiled function {}", name_); throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot find compiled function {}", name);
} }
String getName() const override { return name; } String getName() const override { return name; }
@ -122,13 +70,14 @@ public:
{ {
const auto * column = arguments[i].column.get(); const auto * column = arguments[i].column.get();
if (!column) if (!column)
throw Exception("Column " + arguments[i].name + " is missing", ErrorCodes::LOGICAL_ERROR); throw Exception(ErrorCodes::LOGICAL_ERROR, "Column {} is missing", arguments[i].name);
columns[i] = getColumnData(column); columns[i] = getColumnData(column);
} }
columns[arguments.size()] = getColumnData(result_column.get()); columns[arguments.size()] = getColumnData(result_column.get());
reinterpret_cast<void (*) (size_t, ColumnData *)>(function)(input_rows_count, columns.data()); auto * function_typed = reinterpret_cast<void (*) (size_t, ColumnData *)>(function);
function_typed(input_rows_count, columns.data());
#if defined(MEMORY_SANITIZER) #if defined(MEMORY_SANITIZER)
/// Memory sanitizer don't know about stores from JIT-ed code. /// Memory sanitizer don't know about stores from JIT-ed code.
@ -158,298 +107,183 @@ public:
}; };
static void compileFunction(llvm::Module & module, const IFunctionBaseImpl & f) class LLVMFunction : public IFunctionBaseImpl
{ {
ProfileEvents::increment(ProfileEvents::CompileFunction); public:
const auto & arg_types = f.getArgumentTypes(); explicit LLVMFunction(const CompileDAG & dag_)
: name(dag_.dump())
llvm::IRBuilder<> b(module.getContext()); , dag(dag_)
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
auto * data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy(), size_type);
auto * func_type = llvm::FunctionType::get(b.getVoidTy(), { size_type, data_type->getPointerTo() }, /*isVarArg=*/false);
auto * func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, f.getName(), module);
auto * args = func->args().begin();
llvm::Value * counter_arg = &*args++;
llvm::Value * columns_arg = &*args++;
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", func);
b.SetInsertPoint(entry);
std::vector<ColumnDataPlaceholder> columns(arg_types.size() + 1);
for (size_t i = 0; i <= arg_types.size(); ++i)
{ {
const auto & type = i == arg_types.size() ? f.getResultType() : arg_types[i]; for (size_t i = 0; i < dag.getNodesCount(); ++i)
auto * data = b.CreateLoad(b.CreateConstInBoundsGEP1_32(data_type, columns_arg, i)); {
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(type))->getPointerTo()); const auto & node = dag[i];
columns[i].null_init = type->isNullable() ? b.CreateExtractValue(data, {1}) : nullptr;
if (node.type == CompileDAG::CompileType::FUNCTION)
nested_functions.emplace_back(node.function);
else if (node.type == CompileDAG::CompileType::INPUT)
argument_types.emplace_back(node.result_type);
} }
/// assume nonzero initial value in `counter_arg` module_info = compileFunction(getJITInstance(), *this);
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", func);
b.CreateBr(loop);
b.SetInsertPoint(loop);
auto * counter_phi = b.CreatePHI(counter_arg->getType(), 2);
counter_phi->addIncoming(counter_arg, entry);
for (auto & col : columns)
{
col.data = b.CreatePHI(col.data_init->getType(), 2);
col.data->addIncoming(col.data_init, entry);
if (col.null_init)
{
col.null = b.CreatePHI(col.null_init->getType(), 2);
col.null->addIncoming(col.null_init, entry);
}
} }
Values arguments; ~LLVMFunction() override
arguments.reserve(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) // NOLINT
{ {
auto & column = columns[i]; getJITInstance().deleteCompiledModule(module_info);
auto type = arg_types[i];
auto * value = b.CreateLoad(column.data);
if (!column.null)
{
arguments.emplace_back(value);
continue;
} }
auto * is_null = b.CreateICmpNE(b.CreateLoad(column.null), b.getInt8(0)); size_t getCompiledSize() const { return module_info.size; }
auto * nullable_unitilized = llvm::Constant::getNullValue(toNativeType(b, type));
auto * nullable_value = b.CreateInsertValue(b.CreateInsertValue(nullable_unitilized, value, {0}), is_null, {1}); bool isCompilable() const override { return true; }
arguments.emplace_back(nullable_value);
llvm::Value * compile(llvm::IRBuilderBase & builder, Values values) const override
{
return dag.compile(builder, values);
} }
auto * result = f.compile(b, std::move(arguments)); String getName() const override { return name; }
if (columns.back().null)
const DataTypes & getArgumentTypes() const override { return argument_types; }
const DataTypePtr & getResultType() const override { return dag.back().result_type; }
ExecutableFunctionImplPtr prepare(const ColumnsWithTypeAndName &) const override
{ {
b.CreateStore(b.CreateExtractValue(result, {0}), columns.back().data); return std::make_unique<LLVMExecutableFunction>(name);
b.CreateStore(b.CreateSelect(b.CreateExtractValue(result, {1}), b.getInt8(1), b.getInt8(0)), columns.back().null);
}
else
{
b.CreateStore(result, columns.back().data);
} }
auto * cur_block = b.GetInsertBlock(); bool isDeterministic() const override
for (auto & col : columns)
{ {
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block); for (const auto & f : nested_functions)
if (col.null)
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
}
counter_phi->addIncoming(b.CreateSub(counter_phi, llvm::ConstantInt::get(size_type, 1)), cur_block);
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", func);
b.CreateCondBr(b.CreateICmpNE(counter_phi, llvm::ConstantInt::get(size_type, 1)), loop, end);
b.SetInsertPoint(end);
b.CreateRetVoid();
}
static llvm::Constant * getNativeValue(llvm::Type * type, const IColumn & column, size_t i)
{
/// TODO: Change name this is just for constants
if (!type || column.size() <= i)
return nullptr;
if (const auto * constant = typeid_cast<const ColumnConst *>(&column))
return getNativeValue(type, constant->getDataColumn(), 0);
if (const auto * nullable = typeid_cast<const ColumnNullable *>(&column))
{
auto * value = getNativeValue(type->getContainedType(0), nullable->getNestedColumn(), i);
auto * is_null = llvm::ConstantInt::get(type->getContainedType(1), nullable->isNullAt(i));
return value ? llvm::ConstantStruct::get(static_cast<llvm::StructType *>(type), value, is_null) : nullptr;
}
if (type->isFloatTy())
return llvm::ConstantFP::get(type, assert_cast<const ColumnVector<Float32> &>(column).getElement(i));
if (type->isDoubleTy())
return llvm::ConstantFP::get(type, assert_cast<const ColumnVector<Float64> &>(column).getElement(i));
if (type->isIntegerTy())
return llvm::ConstantInt::get(type, column.getUInt(i));
/// TODO: if (type->isVectorTy())
return nullptr;
}
/// Same as IFunctionBase::compile, but also for constants and input columns.
static CompilableExpression subexpression(ColumnPtr c, DataTypePtr type)
{
return [=](llvm::IRBuilderBase & b, const Values &)
{
auto * native_value = getNativeValue(toNativeType(b, type), *c, 0);
return native_value;
};
}
static CompilableExpression subexpression(size_t i)
{
return [=](llvm::IRBuilderBase &, const Values & inputs)
{
auto * column = inputs[i];
return column;
};
}
static CompilableExpression subexpression(const IFunctionBase & f, std::vector<CompilableExpression> args)
{
return [&, args = std::move(args)](llvm::IRBuilderBase & builder, const Values & inputs)
{
Values input;
for (const auto & arg : args)
input.push_back(arg(builder, inputs));
auto * result = f.compile(builder, input);
if (result->getType() != toNativeType(builder, f.getResultType()))
throw Exception("Function " + f.getName() + " generated an llvm::Value of invalid type", ErrorCodes::LOGICAL_ERROR);
return result;
};
}
LLVMFunction::LLVMFunction(const CompileDAG & dag)
: name(dag.dump())
{
std::vector<CompilableExpression> expressions;
expressions.reserve(dag.size());
jit.compileModule([&](llvm::Module & module)
{
auto & context = module.getContext();
llvm::IRBuilder<> builder(context);
for (const auto & node : dag)
{
switch (node.type)
{
case CompileNode::NodeType::CONSTANT:
{
const auto * col = typeid_cast<const ColumnConst *>(node.column.get());
/// TODO: implement `getNativeValue` for all types & replace the check with `c.column && toNativeType(...)`
if (!getNativeValue(toNativeType(builder, node.result_type), col->getDataColumn(), 0))
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Cannot compile constant of type {} = {}",
node.result_type->getName(),
applyVisitor(FieldVisitorToString(), col->getDataColumn()[0]));
expressions.emplace_back(subexpression(col->getDataColumnPtr(), node.result_type));
break;
}
case CompileNode::NodeType::FUNCTION:
{
std::vector<CompilableExpression> args;
args.reserve(node.arguments.size());
for (auto arg : node.arguments)
{
// std::cerr << "CompileNode::Function emplace expression " << arg << std::endl;
args.emplace_back(expressions[arg]);
}
originals.push_back(node.function);
expressions.emplace_back(subexpression(*node.function, std::move(args)));
break;
}
case CompileNode::NodeType::INPUT:
{
expressions.emplace_back(subexpression(arg_types.size()));
arg_types.push_back(node.result_type);
break;
}
}
}
expression = std::move(expressions.back());
compileFunction(module, *this);
// for (auto & func : module)
// {
// std::cerr << "Func name " << std::string(func.getName()) << std::endl;
// }
// module.print(llvm::errs(), nullptr);
});
}
llvm::Value * LLVMFunction::compile(llvm::IRBuilderBase & builder, Values values) const
{
return expression(builder, values);
}
ExecutableFunctionImplPtr LLVMFunction::prepare(const ColumnsWithTypeAndName &) const { return std::make_unique<LLVMExecutableFunction>(name); }
bool LLVMFunction::isDeterministic() const
{
for (const auto & f : originals)
if (!f->isDeterministic()) if (!f->isDeterministic())
return false; return false;
return true; return true;
} }
bool LLVMFunction::isDeterministicInScopeOfQuery() const bool isDeterministicInScopeOfQuery() const override
{ {
for (const auto & f : originals) for (const auto & f : nested_functions)
if (!f->isDeterministicInScopeOfQuery()) if (!f->isDeterministicInScopeOfQuery())
return false; return false;
return true; return true;
} }
bool LLVMFunction::isSuitableForConstantFolding() const bool isSuitableForConstantFolding() const override
{ {
for (const auto & f : originals) for (const auto & f : nested_functions)
if (!f->isSuitableForConstantFolding()) if (!f->isSuitableForConstantFolding())
return false; return false;
return true; return true;
} }
bool LLVMFunction::isInjective(const ColumnsWithTypeAndName & sample_block) const bool isInjective(const ColumnsWithTypeAndName & sample_block) const override
{ {
for (const auto & f : originals) for (const auto & f : nested_functions)
if (!f->isInjective(sample_block)) if (!f->isInjective(sample_block))
return false; return false;
return true; return true;
} }
bool LLVMFunction::hasInformationAboutMonotonicity() const bool hasInformationAboutMonotonicity() const override
{ {
for (const auto & f : originals) for (const auto & f : nested_functions)
if (!f->hasInformationAboutMonotonicity()) if (!f->hasInformationAboutMonotonicity())
return false; return false;
return true; return true;
} }
LLVMFunction::Monotonicity LLVMFunction::getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const Monotonicity getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const override
{ {
const IDataType * type_ptr = &type; const IDataType * type_ptr = &type;
Field left_mut = left; Field left_mut = left;
Field right_mut = right; Field right_mut = right;
Monotonicity result(true, true, true); Monotonicity result(true, true, true);
/// monotonicity is only defined for unary functions, so the chain must describe a sequence of nested calls /// monotonicity is only defined for unary functions, so the chain must describe a sequence of nested calls
for (size_t i = 0; i < originals.size(); ++i) for (size_t i = 0; i < nested_functions.size(); ++i)
{ {
Monotonicity m = originals[i]->getMonotonicityForRange(*type_ptr, left_mut, right_mut); Monotonicity m = nested_functions[i]->getMonotonicityForRange(*type_ptr, left_mut, right_mut);
if (!m.is_monotonic) if (!m.is_monotonic)
return m; return m;
result.is_positive ^= !m.is_positive; result.is_positive ^= !m.is_positive;
result.is_always_monotonic &= m.is_always_monotonic; result.is_always_monotonic &= m.is_always_monotonic;
if (i + 1 < originals.size()) if (i + 1 < nested_functions.size())
{ {
if (left_mut != Field()) if (left_mut != Field())
applyFunction(*originals[i], left_mut); applyFunction(*nested_functions[i], left_mut);
if (right_mut != Field()) if (right_mut != Field())
applyFunction(*originals[i], right_mut); applyFunction(*nested_functions[i], right_mut);
if (!m.is_positive) if (!m.is_positive)
std::swap(left_mut, right_mut); std::swap(left_mut, right_mut);
type_ptr = originals[i]->getResultType().get(); type_ptr = nested_functions[i]->getResultType().get();
} }
} }
return result; return result;
} }
static void applyFunction(IFunctionBase & function, Field & value)
{
const auto & type = function.getArgumentTypes().at(0);
ColumnsWithTypeAndName args{{type->createColumnConst(1, value), type, "x" }};
auto col = function.execute(args, function.getResultType(), 1);
col->get(0, value);
}
private:
std::string name;
CompileDAG dag;
DataTypes argument_types;
std::vector<FunctionBasePtr> nested_functions;
CHJIT::CompiledModuleInfo module_info;
};
static FunctionBasePtr compile(
const CompileDAG & dag,
size_t min_count_to_compile_expression)
{
static std::unordered_map<UInt128, UInt64, UInt128Hash> counter;
static std::mutex mutex;
auto hash_key = dag.hash();
{
std::lock_guard lock(mutex);
if (counter[hash_key]++ < min_count_to_compile_expression)
return nullptr;
}
FunctionBasePtr fn;
if (auto * compilation_cache = CompiledExpressionCacheFactory::instance().tryGetCache())
{
auto [compiled_function, was_inserted] = compilation_cache->getOrSet(hash_key, [&dag] ()
{
auto llvm_function = std::make_unique<LLVMFunction>(dag);
size_t compiled_size = llvm_function->getCompiledSize();
FunctionBasePtr llvm_function_wrapper = std::make_shared<FunctionBaseAdaptor>(std::move(llvm_function));
CompiledFunction compiled_function
{
.function = llvm_function_wrapper,
.compiled_size = compiled_size
};
return std::make_shared<CompiledFunction>(compiled_function);
});
fn = compiled_function->function;
}
else
{
fn = std::make_shared<FunctionBaseAdaptor>(std::make_unique<LLVMFunction>(dag));
}
return fn;
}
static bool isCompilable(const IFunctionBase & function) static bool isCompilable(const IFunctionBase & function)
{ {
@ -475,12 +309,12 @@ static bool isCompilableFunction(const ActionsDAG::Node & node)
return node.type == ActionsDAG::ActionType::FUNCTION && isCompilable(*node.function_base); return node.type == ActionsDAG::ActionType::FUNCTION && isCompilable(*node.function_base);
} }
static LLVMFunction::CompileDAG getCompilableDAG( static CompileDAG getCompilableDAG(
const ActionsDAG::Node * root, const ActionsDAG::Node * root,
ActionsDAG::NodeRawConstPtrs & children, ActionsDAG::NodeRawConstPtrs & children,
const std::unordered_set<const ActionsDAG::Node *> & used_in_result) const std::unordered_set<const ActionsDAG::Node *> & used_in_result)
{ {
LLVMFunction::CompileDAG dag; CompileDAG dag;
std::unordered_map<const ActionsDAG::Node *, size_t> positions; std::unordered_map<const ActionsDAG::Node *, size_t> positions;
struct Frame struct Frame
@ -514,25 +348,30 @@ static LLVMFunction::CompileDAG getCompilableDAG(
if (!is_compilable_function || frame.next_child_to_visit == frame.node->children.size()) if (!is_compilable_function || frame.next_child_to_visit == frame.node->children.size())
{ {
LLVMFunction::CompileNode node; CompileDAG::Node node;
node.function = frame.node->function_base; node.function = frame.node->function_base;
node.result_type = frame.node->result_type; node.result_type = frame.node->result_type;
node.type = is_const ? LLVMFunction::CompileNode::NodeType::CONSTANT
: (is_compilable_function ? LLVMFunction::CompileNode::NodeType::FUNCTION
: LLVMFunction::CompileNode::NodeType::INPUT);
if (node.type == LLVMFunction::CompileNode::NodeType::FUNCTION) if (is_compilable_function)
{
node.type = CompileDAG::CompileType::FUNCTION;
for (const auto * child : frame.node->children) for (const auto * child : frame.node->children)
node.arguments.push_back(positions[child]); node.arguments.push_back(positions[child]);
}
if (node.type == LLVMFunction::CompileNode::NodeType::CONSTANT) else if (is_const)
{
node.type = CompileDAG::CompileType::CONSTANT;
node.column = frame.node->column; node.column = frame.node->column;
}
if (node.type == LLVMFunction::CompileNode::NodeType::INPUT) else
{
node.type = CompileDAG::CompileType::INPUT;
children.emplace_back(frame.node); children.emplace_back(frame.node);
positions[frame.node] = dag.size(); }
dag.push_back(std::move(node));
positions[frame.node] = dag.getNodesCount();
dag.addNode(std::move(node));
stack.pop(); stack.pop();
} }
} }
@ -540,127 +379,8 @@ static LLVMFunction::CompileDAG getCompilableDAG(
return dag; return dag;
} }
std::string LLVMFunction::CompileDAG::dump() const
{
WriteBufferFromOwnString out;
bool first = true;
for (const auto & node : *this)
{
if (!first)
out << " ; ";
first = false;
switch (node.type)
{
case CompileNode::NodeType::CONSTANT:
{
const auto * column = typeid_cast<const ColumnConst *>(node.column.get());
const auto & data = column->getDataColumn();
out << node.result_type->getName() << " = " << applyVisitor(FieldVisitorToString(), data[0]);
break;
}
case CompileNode::NodeType::FUNCTION:
{
out << node.result_type->getName() << " = ";
out << node.function->getName() << "(";
for (size_t i = 0; i < node.arguments.size(); ++i)
{
if (i)
out << ", ";
out << node.arguments[i];
}
out << ")";
break;
}
case CompileNode::NodeType::INPUT:
{
out << node.result_type->getName();
break;
}
}
}
return out.str();
}
UInt128 LLVMFunction::CompileDAG::hash() const
{
SipHash hash;
for (const auto & node : *this)
{
hash.update(node.type);
hash.update(node.result_type->getName());
switch (node.type)
{
case CompileNode::NodeType::CONSTANT:
{
typeid_cast<const ColumnConst *>(node.column.get())->getDataColumn().updateHashWithValue(0, hash);
break;
}
case CompileNode::NodeType::FUNCTION:
{
hash.update(node.function->getName());
for (size_t arg : node.arguments)
hash.update(arg);
break;
}
case CompileNode::NodeType::INPUT:
{
break;
}
}
}
UInt128 result;
hash.get128(result.low, result.high);
return result;
}
static FunctionBasePtr compile(
const LLVMFunction::CompileDAG & dag,
size_t min_count_to_compile_expression)
{
static std::unordered_map<UInt128, UInt32, UInt128Hash> counter;
static std::mutex mutex;
auto hash_key = dag.hash();
{
std::lock_guard lock(mutex);
if (counter[hash_key]++ < min_count_to_compile_expression)
return nullptr;
}
FunctionBasePtr fn;
if (auto * compilation_cache = CompiledExpressionCacheFactory::instance().tryGetCache())
{
std::tie(fn, std::ignore) = compilation_cache->getOrSet(hash_key, [&dag] ()
{
Stopwatch watch;
FunctionBasePtr result_fn;
result_fn = std::make_shared<FunctionBaseAdaptor>(std::make_unique<LLVMFunction>(dag));
ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
return result_fn;
});
}
else
{
Stopwatch watch;
fn = std::make_shared<FunctionBaseAdaptor>(std::make_unique<LLVMFunction>(dag));
ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
}
return fn;
}
void ActionsDAG::compileFunctions(size_t min_count_to_compile_expression) void ActionsDAG::compileFunctions(size_t min_count_to_compile_expression)
{ {
/// TODO: Rewrite
struct Data struct Data
{ {
bool is_compilable = false; bool is_compilable = false;
@ -743,18 +463,7 @@ void ActionsDAG::compileFunctions(size_t min_count_to_compile_expression)
NodeRawConstPtrs new_children; NodeRawConstPtrs new_children;
auto dag = getCompilableDAG(frame.node, new_children, used_in_result); auto dag = getCompilableDAG(frame.node, new_children, used_in_result);
bool all_constants = true; if (dag.getInputNodesCount() > 0)
for (const auto & compiled_node : dag)
{
if (compiled_node.type == LLVMFunction::CompileNode::NodeType::INPUT)
{
all_constants = false;
break;
}
}
if (!all_constants)
{ {
if (auto fn = compile(dag, min_count_to_compile_expression)) if (auto fn = compile(dag, min_count_to_compile_expression))
{ {

View File

@ -5,93 +5,34 @@
#endif #endif
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER
# include <set>
# include <Functions/IFunctionImpl.h> # include <Functions/IFunctionImpl.h>
# include <Interpreters/Context.h>
# include <Interpreters/ExpressionActions.h>
# include <Common/LRUCache.h> # include <Common/LRUCache.h>
namespace DB namespace DB
{ {
using CompilableExpression = std::function<llvm::Value * (llvm::IRBuilderBase &, const Values &)>; struct CompiledFunction
class LLVMFunction : public IFunctionBaseImpl
{ {
std::string name;
DataTypes arg_types;
std::vector<FunctionBasePtr> originals;
CompilableExpression expression;
public:
/// LLVMFunction is a compiled part of ActionsDAG.
/// We store this part as independent DAG with minial required information to compile it.
struct CompileNode
{
enum class NodeType
{
INPUT = 0,
CONSTANT = 1,
FUNCTION = 2,
};
NodeType type;
DataTypePtr result_type;
/// For CONSTANT
ColumnPtr column;
/// For FUNCTION
FunctionBasePtr function; FunctionBasePtr function;
std::vector<size_t> arguments; size_t compiled_size;
}; };
/// DAG is represented as list of nodes stored in in-order traverse order. struct CompiledFunctionWeightFunction
/// Expression (a + 1) + (b + 1) will be represented like chain: a, 1, a + 1, b, b + 1, (a + 1) + (b + 1).
struct CompileDAG : public std::vector<CompileNode>
{ {
std::string dump() const; size_t operator()(const CompiledFunction & compiled_function) const
UInt128 hash() const; {
}; return compiled_function.compiled_size;
}
explicit LLVMFunction(const CompileDAG & dag);
bool isCompilable() const override { return true; }
llvm::Value * compile(llvm::IRBuilderBase & builder, Values values) const override;
String getName() const override { return name; }
const DataTypes & getArgumentTypes() const override { return arg_types; }
const DataTypePtr & getResultType() const override { return originals.back()->getResultType(); }
ExecutableFunctionImplPtr prepare(const ColumnsWithTypeAndName &) const override;
bool isDeterministic() const override;
bool isDeterministicInScopeOfQuery() const override;
bool isSuitableForConstantFolding() const override;
bool isInjective(const ColumnsWithTypeAndName & sample_block) const override;
bool hasInformationAboutMonotonicity() const override;
Monotonicity getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const override;
}; };
/** This child of LRUCache breaks one of it's invariants: total weight may be changed after insertion. /** This child of LRUCache breaks one of it's invariants: total weight may be changed after insertion.
* We have to do so, because we don't known real memory consumption of generated LLVM code for every function. * We have to do so, because we don't known real memory consumption of generated LLVM code for every function.
*/ */
class CompiledExpressionCache : public LRUCache<UInt128, IFunctionBase, UInt128Hash> class CompiledExpressionCache : public LRUCache<UInt128, CompiledFunction, UInt128Hash, CompiledFunctionWeightFunction>
{ {
public: public:
using Base = LRUCache<UInt128, IFunctionBase, UInt128Hash>; using Base = LRUCache<UInt128, CompiledFunction, UInt128Hash, CompiledFunctionWeightFunction>;
using Base::Base; using Base::Base;
}; };

View File

@ -13,6 +13,7 @@
#include <llvm/IR/LegacyPassManager.h> #include <llvm/IR/LegacyPassManager.h>
#include <llvm/ExecutionEngine/JITSymbol.h> #include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/SectionMemoryManager.h> #include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include <llvm/ExecutionEngine/JITEventListener.h>
#include <llvm/MC/SubtargetFeature.h> #include <llvm/MC/SubtargetFeature.h>
#include <llvm/Support/DynamicLibrary.h> #include <llvm/Support/DynamicLibrary.h>
#include <llvm/Support/Host.h> #include <llvm/Support/Host.h>
@ -140,6 +141,30 @@ private:
std::unordered_map<std::string, void *> symbol_name_to_symbol_address; std::unordered_map<std::string, void *> symbol_name_to_symbol_address;
}; };
// class JITEventListener
// {
// public:
// JITEventListener()
// : gdb_listener(llvm::JITEventListener::createGDBRegistrationListener())
// {}
// void notifyObjectLoaded(
// llvm::JITEventListener::ObjectKey object_key,
// const llvm::object::ObjectFile & object_file,
// const llvm::RuntimeDyld::LoadedObjectInfo & loaded_object_Info)
// {
// gdb_listener->notifyObjectLoaded(object_key, object_file, loaded_object_Info);
// }
// void notifyFreeingObject(llvm::JITEventListener::ObjectKey object_key)
// {
// gdb_listener->notifyFreeingObject(object_key);
// }
// private:
// llvm::JITEventListener * gdb_listener = nullptr;
// };
CHJIT::CHJIT() CHJIT::CHJIT()
: machine(getTargetMachine()) : machine(getTargetMachine())
, layout(machine->createDataLayout()) , layout(machine->createDataLayout())
@ -156,9 +181,12 @@ CHJIT::~CHJIT() = default;
CHJIT::CompiledModuleInfo CHJIT::compileModule(std::function<void (llvm::Module &)> compile_function) CHJIT::CompiledModuleInfo CHJIT::compileModule(std::function<void (llvm::Module &)> compile_function)
{ {
std::lock_guard<std::mutex> lock(jit_lock); std::lock_guard<std::mutex> lock(jit_lock);
auto module = createModuleForCompilation(); auto module = createModuleForCompilation();
compile_function(*module); compile_function(*module);
auto module_info = compileModule(std::move(module)); auto module_info = compileModule(std::move(module));
++current_module_key;
return module_info; return module_info;
} }
@ -168,8 +196,6 @@ std::unique_ptr<llvm::Module> CHJIT::createModuleForCompilation()
module->setDataLayout(layout); module->setDataLayout(layout);
module->setTargetTriple(machine->getTargetTriple().getTriple()); module->setTargetTriple(machine->getTargetTriple().getTriple());
++current_module_key;
return module; return module;
} }
@ -194,9 +220,6 @@ CHJIT::CompiledModuleInfo CHJIT::compileModule(std::unique_ptr<llvm::Module> mod
std::unique_ptr<llvm::RuntimeDyld::LoadedObjectInfo> linked_object = dynamic_linker.loadObject(*object.get()); std::unique_ptr<llvm::RuntimeDyld::LoadedObjectInfo> linked_object = dynamic_linker.loadObject(*object.get());
if (dynamic_linker.hasError())
throw Exception(ErrorCodes::CANNOT_COMPILE_CODE, "RuntimeDyld error {}", std::string(dynamic_linker.getErrorString()));
dynamic_linker.resolveRelocations(); dynamic_linker.resolveRelocations();
module_memory_manager->getManager().finalizeMemory(); module_memory_manager->getManager().finalizeMemory();
@ -217,14 +240,15 @@ CHJIT::CompiledModuleInfo CHJIT::compileModule(std::unique_ptr<llvm::Module> mod
auto * jit_symbol_address = reinterpret_cast<void *>(jit_symbol.getAddress()); auto * jit_symbol_address = reinterpret_cast<void *>(jit_symbol.getAddress());
name_to_symbol[function_name] = jit_symbol_address; name_to_symbol[function_name] = jit_symbol_address;
module_info.compiled_functions.emplace_back(std::move(function_name));
} }
auto module_identifier = module->getModuleIdentifier(); auto module_identifier = module->getModuleIdentifier();
module_info.size = module_memory_manager->getAllocatedSize(); module_info.size = module_memory_manager->getAllocatedSize();
module_info.module_identifier = module_identifier; module_info.module_identifier = current_module_key;
module_identifier_to_memory_manager[module_identifier] = std::move(module_memory_manager); module_identifier_to_memory_manager[current_module_key] = std::move(module_memory_manager);
compiled_code_size.fetch_add(module_info.size, std::memory_order_relaxed); compiled_code_size.fetch_add(module_info.size, std::memory_order_relaxed);

View File

@ -20,7 +20,6 @@ class JITModuleMemoryManager;
class JITSymbolResolver; class JITSymbolResolver;
class JITCompiler; class JITCompiler;
/// TODO: Add profile events
/// TODO: Add documentation /// TODO: Add documentation
class CHJIT class CHJIT
{ {
@ -32,7 +31,7 @@ public:
struct CompiledModuleInfo struct CompiledModuleInfo
{ {
size_t size; size_t size;
std::string module_identifier; uint64_t module_identifier;
std::vector<std::string> compiled_functions; std::vector<std::string> compiled_functions;
}; };
@ -65,10 +64,11 @@ private:
std::unique_ptr<JITSymbolResolver> symbol_resolver; std::unique_ptr<JITSymbolResolver> symbol_resolver;
std::unordered_map<std::string, void *> name_to_symbol; std::unordered_map<std::string, void *> name_to_symbol;
std::unordered_map<std::string, std::unique_ptr<JITModuleMemoryManager>> module_identifier_to_memory_manager; std::unordered_map<uint64_t, std::unique_ptr<JITModuleMemoryManager>> module_identifier_to_memory_manager;
size_t current_module_key = 0; uint64_t current_module_key = 0;
std::atomic<size_t> compiled_code_size = 0; std::atomic<size_t> compiled_code_size = 0;
mutable std::mutex jit_lock; mutable std::mutex jit_lock;
}; };
} }

View File

@ -0,0 +1,155 @@
#include "CompileDAG.h"
#if USE_EMBEDDED_COMPILER
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
#include <Common/SipHash.h>
#include <Common/FieldVisitors.h>
#include <Columns/ColumnConst.h>
#include <DataTypes/Native.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
namespace DB
{
llvm::Value * CompileDAG::compile(llvm::IRBuilderBase & builder, Values input_nodes_values) const
{
assert(input_nodes_values.size() == getInputNodesCount());
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
PaddedPODArray<llvm::Value *> compiled_values;
compiled_values.resize_fill(nodes.size());
size_t input_nodes_values_index = 0;
size_t compiled_values_index = 0;
size_t dag_size = nodes.size();
for (size_t i = 0; i < dag_size; ++i)
{
const auto & node = nodes[i];
switch (node.type)
{
case CompileType::CONSTANT:
{
compiled_values[compiled_values_index] = getColumnNativeValue(b, node.result_type, *node.column, 0);
break;
}
case CompileType::FUNCTION:
{
Values temporary_values;
temporary_values.reserve(node.arguments.size());
for (auto argument_index : node.arguments)
temporary_values.emplace_back(compiled_values[argument_index]);
compiled_values[compiled_values_index] = node.function->compile(builder, temporary_values);
break;
}
case CompileType::INPUT:
{
compiled_values[compiled_values_index] = input_nodes_values[input_nodes_values_index];
++input_nodes_values_index;
break;
}
}
++compiled_values_index;
}
return compiled_values.back();
}
std::string CompileDAG::dump() const
{
std::vector<std::string> dumped_values;
dumped_values.resize(nodes.size());
size_t input_index = 0;
size_t dag_size = nodes.size();
for (size_t i = 0; i < dag_size; ++i)
{
const auto & node = nodes[i];
switch (node.type)
{
case CompileType::CONSTANT:
{
const auto * column = typeid_cast<const ColumnConst *>(node.column.get());
const auto & data = column->getDataColumn();
dumped_values[i] = applyVisitor(FieldVisitorToString(), data[0]) + " : " + node.result_type->getName();
break;
}
case CompileType::FUNCTION:
{
std::string function_dump = node.function->getName();
function_dump += '(';
for (auto argument_index : node.arguments)
function_dump += dumped_values[argument_index] += ", ";
function_dump.pop_back();
function_dump.pop_back();
function_dump += ')';
dumped_values[i] = function_dump;
break;
}
case CompileType::INPUT:
{
dumped_values[i] = node.result_type->getName();
++input_index;
break;
}
}
}
return dumped_values.back();
}
UInt128 CompileDAG::hash() const
{
SipHash hash;
for (const auto & node : nodes)
{
hash.update(node.type);
hash.update(node.result_type->getName());
switch (node.type)
{
case CompileType::CONSTANT:
{
assert_cast<const ColumnConst *>(node.column.get())->getDataColumn().updateHashWithValue(0, hash);
break;
}
case CompileType::FUNCTION:
{
hash.update(node.function->getName());
for (size_t arg : node.arguments)
hash.update(arg);
break;
}
case CompileType::INPUT:
{
break;
}
}
}
UInt128 result;
hash.get128(result.low, result.high);
return result;
}
}
#endif

View File

@ -0,0 +1,86 @@
#pragma once
#if !defined(ARCADIA_BUILD)
# include "config_core.h"
#endif
#if USE_EMBEDDED_COMPILER
#include <vector>
#include <Core/Types.h>
#include <Common/UInt128.h>
#include <Columns/IColumn.h>
#include <DataTypes/IDataType.h>
#include <Functions/IFunctionImpl.h>
namespace llvm
{
class Value;
class IRBuilderBase;
}
namespace DB
{
/// DAG is represented as list of nodes stored in in-order traverse order.
/// Expression (a + 1) + (b + 1) will be represented like chain: a, 1, a + 1, b, b + 1, (a + 1) + (b + 1).
/// TODO: Consider to rename in CompileStack
class CompileDAG
{
public:
enum class CompileType
{
INPUT = 0,
CONSTANT = 1,
FUNCTION = 2,
};
struct Node
{
CompileType type;
DataTypePtr result_type;
/// For CONSTANT
ColumnPtr column;
/// For FUNCTION
FunctionBasePtr function;
std::vector<size_t> arguments;
};
llvm::Value * compile(llvm::IRBuilderBase & builder, Values input_nodes_values) const;
std::string dump() const;
UInt128 hash() const;
void addNode(Node node)
{
input_nodes_count += (node.type == CompileType::INPUT);
nodes.emplace_back(std::move(node));
}
inline size_t getNodesCount() const { return nodes.size(); }
inline size_t getInputNodesCount() const { return input_nodes_count; }
inline Node & operator[](size_t index) { return nodes[index]; }
inline const Node & operator[](size_t index) const { return nodes[index]; }
inline Node & front() { return nodes.front(); }
inline const Node & front() const { return nodes.front(); }
inline Node & back() { return nodes.back(); }
inline const Node & back() const { return nodes.back(); }
private:
std::vector<Node> nodes;
size_t input_nodes_count = 0;
};
}
#endif

View File

@ -0,0 +1,160 @@
#include "compileFunction.h"
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
#include <Common/Stopwatch.h>
#include <Common/ProfileEvents.h>
#include <DataTypes/Native.h>
#include <Interpreters/JIT/CHJIT.h>
namespace
{
struct ColumnDataPlaceholder
{
llvm::Value * data_init; /// first row
llvm::Value * null_init;
llvm::PHINode * data; /// current row
llvm::PHINode * null;
};
}
namespace ProfileEvents
{
extern const Event CompileFunction;
extern const Event CompileExpressionsMicroseconds;
extern const Event CompileExpressionsBytes;
}
namespace DB
{
ColumnData getColumnData(const IColumn * column)
{
ColumnData result;
const bool is_const = isColumnConst(*column);
if (is_const)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Input columns should not be constant");
if (const auto * nullable = typeid_cast<const ColumnNullable *>(column))
{
result.null = nullable->getNullMapColumn().getRawData().data;
column = & nullable->getNestedColumn();
}
result.data = column->getRawData().data;
return result;
}
static void compileFunction(llvm::Module & module, const IFunctionBaseImpl & f)
{
ProfileEvents::increment(ProfileEvents::CompileFunction);
const auto & arg_types = f.getArgumentTypes();
llvm::IRBuilder<> b(module.getContext());
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
auto * data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
auto * func_type = llvm::FunctionType::get(b.getVoidTy(), { size_type, data_type->getPointerTo() }, /*isVarArg=*/false);
auto * func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, f.getName(), module);
auto * args = func->args().begin();
llvm::Value * counter_arg = &*args++;
llvm::Value * columns_arg = &*args++;
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", func);
b.SetInsertPoint(entry);
std::vector<ColumnDataPlaceholder> columns(arg_types.size() + 1);
for (size_t i = 0; i <= arg_types.size(); ++i)
{
const auto & type = i == arg_types.size() ? f.getResultType() : arg_types[i];
auto * data = b.CreateLoad(b.CreateConstInBoundsGEP1_32(data_type, columns_arg, i));
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(type))->getPointerTo());
columns[i].null_init = type->isNullable() ? b.CreateExtractValue(data, {1}) : nullptr;
}
/// assume nonzero initial value in `counter_arg`
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", func);
b.CreateBr(loop);
b.SetInsertPoint(loop);
auto * counter_phi = b.CreatePHI(counter_arg->getType(), 2);
counter_phi->addIncoming(counter_arg, entry);
for (auto & col : columns)
{
col.data = b.CreatePHI(col.data_init->getType(), 2);
col.data->addIncoming(col.data_init, entry);
if (col.null_init)
{
col.null = b.CreatePHI(col.null_init->getType(), 2);
col.null->addIncoming(col.null_init, entry);
}
}
Values arguments;
arguments.reserve(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) // NOLINT
{
auto & column = columns[i];
auto type = arg_types[i];
auto * value = b.CreateLoad(column.data);
if (!column.null)
{
arguments.emplace_back(value);
continue;
}
auto * is_null = b.CreateICmpNE(b.CreateLoad(column.null), b.getInt8(0));
auto * nullable_unitilized = llvm::Constant::getNullValue(toNativeType(b, type));
auto * nullable_value = b.CreateInsertValue(b.CreateInsertValue(nullable_unitilized, value, {0}), is_null, {1});
arguments.emplace_back(nullable_value);
}
auto * result = f.compile(b, std::move(arguments));
if (columns.back().null)
{
b.CreateStore(b.CreateExtractValue(result, {0}), columns.back().data);
b.CreateStore(b.CreateSelect(b.CreateExtractValue(result, {1}), b.getInt8(1), b.getInt8(0)), columns.back().null);
}
else
{
b.CreateStore(result, columns.back().data);
}
auto * cur_block = b.GetInsertBlock();
for (auto & col : columns)
{
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
if (col.null)
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
}
counter_phi->addIncoming(b.CreateSub(counter_phi, llvm::ConstantInt::get(size_type, 1)), cur_block);
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", func);
b.CreateCondBr(b.CreateICmpNE(counter_phi, llvm::ConstantInt::get(size_type, 1)), loop, end);
b.SetInsertPoint(end);
b.CreateRetVoid();
}
CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBaseImpl & f)
{
Stopwatch watch;
auto compiled_module_info = jit.compileModule([&](llvm::Module & module)
{
compileFunction(module, f);
});
ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
ProfileEvents::increment(ProfileEvents::CompileExpressionsBytes, compiled_module_info.size);
ProfileEvents::increment(ProfileEvents::CompileFunction);
return compiled_module_info;
}
}

View File

@ -0,0 +1,25 @@
#if !defined(ARCADIA_BUILD)
# include "config_core.h"
#endif
#if USE_EMBEDDED_COMPILER
#include <Functions/IFunctionImpl.h>
#include <Interpreters/JIT/CHJIT.h>
namespace DB
{
struct ColumnData
{
const char * data = nullptr;
const char * null = nullptr;
};
ColumnData getColumnData(const IColumn * column);
CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBaseImpl & f);
}
#endif

View File

@ -26,7 +26,10 @@ int main(int argc, char **argv)
auto * func_declaration_type = llvm::FunctionType::get(b.getVoidTy(), { }, /*isVarArg=*/false); auto * func_declaration_type = llvm::FunctionType::get(b.getVoidTy(), { }, /*isVarArg=*/false);
auto * func_declaration = llvm::Function::Create(func_declaration_type, llvm::Function::ExternalLinkage, "test_function", module); auto * func_declaration = llvm::Function::Create(func_declaration_type, llvm::Function::ExternalLinkage, "test_function", module);
auto * func_type = llvm::FunctionType::get(b.getVoidTy(), { b.getInt64Ty() }, /*isVarArg=*/false); auto * value_type = b.getInt64Ty();
auto * pointer_type = value_type->getPointerTo();
auto * func_type = llvm::FunctionType::get(b.getVoidTy(), { pointer_type }, /*isVarArg=*/false);
auto * function = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, "test_name", module); auto * function = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, "test_name", module);
auto * entry = llvm::BasicBlock::Create(context, "entry", function); auto * entry = llvm::BasicBlock::Create(context, "entry", function);
@ -35,18 +38,19 @@ int main(int argc, char **argv)
b.CreateCall(func_declaration); b.CreateCall(func_declaration);
auto * value = b.CreateAdd(argument, argument); auto * load_argument = b.CreateLoad(argument);
auto * value = b.CreateAdd(load_argument, load_argument);
b.CreateRet(value); b.CreateRet(value);
}); });
std::cerr << "Compile module info " << compiled_module_info.module_identifier << " size " << compiled_module_info.size << std::endl;
for (const auto & compiled_function_name : compiled_module_info.compiled_functions) for (const auto & compiled_function_name : compiled_module_info.compiled_functions)
{ {
std::cerr << compiled_function_name << std::endl; std::cerr << compiled_function_name << std::endl;
} }
auto * test_name_function = reinterpret_cast<int64_t (*)(int64_t)>(jit.findCompiledFunction("test_name")); int64_t value = 5;
auto result = test_name_function(5); auto * test_name_function = reinterpret_cast<int64_t (*)(int64_t *)>(jit.findCompiledFunction("test_name"));
auto result = test_name_function(&value);
std::cerr << "Result " << result << std::endl; std::cerr << "Result " << result << std::endl;
return 0; return 0;