mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-26 01:22:04 +00:00
Let jit-compilable functions deal with NULLs themselves.
And provide a default implementation of compile() for nullable columns that actually works and is consistent with execute().
This commit is contained in:
parent
49b61cd27d
commit
979c4d959f
@ -105,6 +105,7 @@ if (USE_EMBEDDED_COMPILER)
|
||||
target_include_directories (dbms BEFORE PUBLIC ${LLVM_INCLUDE_DIRS})
|
||||
# LLVM 5.0 has a bunch of unused parameters in its header files.
|
||||
# TODO: global-disable no-unused-parameter
|
||||
set_source_files_properties(src/Functions/IFunction.cpp PROPERTIES COMPILE_FLAGS "-Wno-unused-parameter")
|
||||
set_source_files_properties(src/Interpreters/ExpressionJIT.cpp PROPERTIES COMPILE_FLAGS "-Wno-unused-parameter -Wno-non-virtual-dtor")
|
||||
endif ()
|
||||
|
||||
|
52
dbms/src/DataTypes/Native.h
Normal file
52
dbms/src/DataTypes/Native.h
Normal file
@ -0,0 +1,52 @@
|
||||
#pragma once
|
||||
|
||||
#include <Common/config.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
|
||||
namespace llvm
|
||||
{
|
||||
class IRBuilderBase;
|
||||
class Type;
|
||||
}
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
#include <llvm/IR/IRBuilder.h>
|
||||
#endif
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
/** Map a ClickHouse data type to the LLVM type that represents one of its values in JIT-compiled code.
  *
  * - Nullable(T) maps to a pointer to the native type of T (address space 0);
  * - Int8/UInt8 ... Int64/UInt64 map to the integer type of the same width;
  * - Float32 / Float64 map to float / double.
  *
  * Returns nullptr when the type (or the type nested in a Nullable) has no native equivalent.
  * Throws NOT_IMPLEMENTED when built without USE_EMBEDDED_COMPILER.
  *
  * Declared `inline` rather than `static`: this is a header, and a plain `static` function would be
  * instantiated separately (and reported as unused) in every translation unit that includes it.
  */
inline llvm::Type * toNativeType([[maybe_unused]] llvm::IRBuilderBase & builder, [[maybe_unused]] const DataTypePtr & type)
{
#if USE_EMBEDDED_COMPILER
    if (auto * nullable = typeid_cast<const DataTypeNullable *>(type.get()))
    {
        /// A nullable value is a pointer to the nested native value; an unsupported nested type propagates as nullptr.
        auto * wrapped = toNativeType(builder, nullable->getNestedType());
        return wrapped ? llvm::PointerType::get(wrapped, 0) : nullptr;
    }
    /// LLVM doesn't have unsigned types, it has unsigned instructions.
    if (typeid_cast<const DataTypeInt8 *>(type.get()) || typeid_cast<const DataTypeUInt8 *>(type.get()))
        return builder.getInt8Ty();
    if (typeid_cast<const DataTypeInt16 *>(type.get()) || typeid_cast<const DataTypeUInt16 *>(type.get()))
        return builder.getInt16Ty();
    if (typeid_cast<const DataTypeInt32 *>(type.get()) || typeid_cast<const DataTypeUInt32 *>(type.get()))
        return builder.getInt32Ty();
    if (typeid_cast<const DataTypeInt64 *>(type.get()) || typeid_cast<const DataTypeUInt64 *>(type.get()))
        return builder.getInt64Ty();
    if (typeid_cast<const DataTypeFloat32 *>(type.get()))
        return builder.getFloatTy();
    if (typeid_cast<const DataTypeFloat64 *>(type.get()))
        return builder.getDoubleTy();
    return nullptr;
#else
    throw Exception("JIT-compilation is disabled", ErrorCodes::NOT_IMPLEMENTED);
#endif
}
|
||||
|
||||
}
|
@ -25,12 +25,12 @@ public:
|
||||
static constexpr auto name = "something";
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
bool isCompilable(const DataTypes & types) const override
|
||||
bool isCompilableImpl(const DataTypes & types) const override
|
||||
{
|
||||
return types.size() == 2 && types[0]->equals(*types[1]);
|
||||
}
|
||||
|
||||
llvm::Value * compile(llvm::IRBuilderBase & builder, const DataTypes & types, const ValuePlaceholders & values) const override
|
||||
llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, ValuePlaceholders values) const override
|
||||
{
|
||||
if (types[0]->equals(DataTypeFloat32{}) || types[0]->equals(DataTypeFloat64{}))
|
||||
return static_cast<llvm::IRBuilder<>&>(builder).CreateFAdd(values[0](), values[1]());
|
||||
|
@ -1,14 +1,20 @@
|
||||
#include <Functions/IFunction.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <Columns/ColumnNullable.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypeNothing.h>
|
||||
#include <Columns/ColumnConst.h>
|
||||
#include <Interpreters/ExpressionActions.h>
|
||||
#include <Common/config.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Columns/ColumnConst.h>
|
||||
#include <Columns/ColumnNullable.h>
|
||||
#include <DataTypes/DataTypeNothing.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/Native.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <Functions/IFunction.h>
|
||||
#include <Interpreters/ExpressionActions.h>
|
||||
#include <ext/range.h>
|
||||
#include <ext/collection_cast.h>
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
#include <llvm/IR/IRBuilder.h>
|
||||
#endif
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -254,4 +260,75 @@ DataTypePtr FunctionBuilderImpl::getReturnType(const ColumnsWithTypeAndName & ar
|
||||
|
||||
return getReturnTypeImpl(arguments);
|
||||
}
|
||||
|
||||
static bool anyNullable(const DataTypes & types)
|
||||
{
|
||||
for (const auto & type : types)
|
||||
if (typeid_cast<const DataTypeNullable *>(type.get()))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IFunction::isCompilable(const DataTypes & arguments) const
|
||||
{
|
||||
if (useDefaultImplementationForNulls() && anyNullable(arguments))
|
||||
{
|
||||
DataTypes filtered;
|
||||
for (const auto & type : arguments)
|
||||
filtered.emplace_back(removeNullable(type));
|
||||
return isCompilableImpl(filtered);
|
||||
}
|
||||
return isCompilableImpl(arguments);
|
||||
}
|
||||
|
||||
/** Emit the IR that has to run once, before the loop over rows, and return the values it produced.
  *
  * When the default NULL handling is in effect and some argument is Nullable, one extra value is appended
  * after those of compilePrologueImpl: a stack slot for the (non-nullable) result of compileImpl.
  * compile() pops that slot from the back of its placeholders.
  *
  * The implementation and the result type are queried with Nullable removed from the argument types,
  * the same way isCompilable() does: the slot must hold a plain value, while a Nullable type would be
  * mapped by toNativeType to a pointer.
  */
std::vector<llvm::Value *> IFunction::compilePrologue(llvm::IRBuilderBase & builder, const DataTypes & arguments) const
{
#if USE_EMBEDDED_COMPILER
    if (useDefaultImplementationForNulls() && anyNullable(arguments))
    {
        DataTypes filtered;
        filtered.reserve(arguments.size());
        for (const auto & type : arguments)
            filtered.emplace_back(removeNullable(type));

        auto result = compilePrologueImpl(builder, filtered);
        /// removeNullable on the result type guards against an implementation that returns Nullable on its own.
        auto * slot_type = toNativeType(builder, removeNullable(getReturnTypeImpl(filtered)));
        result.push_back(static_cast<llvm::IRBuilder<> &>(builder).CreateAlloca(slot_type));
        return result;
    }
#endif
    return compilePrologueImpl(builder, arguments);
}
|
||||
|
||||
/** Emit the IR that computes this function for the current row.
  *
  * Without default NULL handling, or when no argument is Nullable, this forwards to compileImpl unchanged.
  *
  * Otherwise the last placeholder is taken to be the result slot appended by compilePrologue(). Every
  * Nullable argument placeholder is wrapped so that, when evaluated, it branches to a shared `fail` block
  * if the pointer is null and yields the loaded value otherwise. compileImpl is then called with
  * the argument types stripped of Nullable (matching the unwrapped values and what isCompilable() checked),
  * its result is stored into the slot, and the function returns a phi that is the slot pointer on success
  * and a null pointer if any argument was NULL.
  */
llvm::Value * IFunction::compile(llvm::IRBuilderBase & builder, const DataTypes & arguments, ValuePlaceholders values) const
{
#if USE_EMBEDDED_COMPILER
    if (useDefaultImplementationForNulls() && anyNullable(arguments))
    {
        /// FIXME: when only one column is nullable, this is actually slower than the non-jitted version
        ///        because this involves copying the null map while `wrapInNullable` reuses it.
        auto & b = static_cast<llvm::IRBuilder<> &>(builder);
        auto * fail = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent());
        auto * join = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent());

        /// The result slot allocated by compilePrologue() is the last placeholder.
        auto * space = values.back()();
        values.pop_back();

        DataTypes nested_types;
        nested_types.reserve(arguments.size());
        for (size_t i = 0; i < arguments.size(); i++)
        {
            nested_types.emplace_back(removeNullable(arguments[i]));
            if (!arguments[i]->isNullable())
                continue;
            /// Capture only what is needed; `fail` is taken by value so the placeholder does not refer to a local.
            values[i] = [&b, fail, previous = std::move(values[i])]()
            {
                auto * value = previous();
                auto * ok = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent());
                b.CreateCondBr(b.CreateIsNull(value), fail, ok);
                b.SetInsertPoint(ok);
                return b.CreateLoad(value);
            };
        }

        /// The implementation sees plain values, so it must also see the non-nullable types.
        b.CreateStore(compileImpl(builder, nested_types, std::move(values)), space);
        b.CreateBr(join);
        auto * result_block = b.GetInsertBlock();
        b.SetInsertPoint(fail); /// an empty joining block to avoid keeping track of where we could jump from
        b.CreateBr(join);
        b.SetInsertPoint(join);
        auto * phi = b.CreatePHI(space->getType(), 2);
        phi->addIncoming(space, result_block);
        phi->addIncoming(llvm::ConstantPointerNull::get(static_cast<llvm::PointerType *>(space->getType())), fail);
        return phi;
    }
#endif
    return compileImpl(builder, arguments, std::move(values));
}
|
||||
|
||||
}
|
||||
|
@ -102,17 +102,25 @@ public:
|
||||
|
||||
virtual bool isCompilable() const { return false; }
|
||||
|
||||
/** Produce LLVM IR code that operates on *scalar* values. JIT-compilation is only supported for native
|
||||
* data types, i.e. numbers. This method will never be called if there is a non-number argument or
|
||||
* a non-number result type. Also, for any compilable function default behavior on NULL values is assumed,
|
||||
* i.e. the result is NULL if and only if any argument is NULL.
|
||||
/// Produce LLVM IR code that runs before the loop over the input rows. Mostly useful for allocating stack variables.
|
||||
virtual std::vector<llvm::Value *> compilePrologue(llvm::IRBuilderBase &) const
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
/** Produce LLVM IR code that operates on scalar values.
|
||||
*
|
||||
* The first `getArgumentTypes().size()` values describe the current row of each column. Supported value types:
|
||||
* - numbers, represented as native numbers;
|
||||
* - nullable numbers, as pointers to native numbers or a null pointer.
|
||||
* The rest are values returned by `compilePrologue`.
|
||||
*
|
||||
* NOTE: the builder is actually guaranteed to be exactly `llvm::IRBuilder<>`, so you may safely
|
||||
* downcast it to that type. This method is specified with `IRBuilderBase` because forward-declaring
|
||||
* templates with default arguments is impossible and including LLVM in such a generic header
|
||||
* as this one is a major pain.
|
||||
*/
|
||||
virtual llvm::Value * compile(llvm::IRBuilderBase & /*builder*/, const ValuePlaceholders & /*values*/) const
|
||||
virtual llvm::Value * compile(llvm::IRBuilderBase & /*builder*/, ValuePlaceholders /*values*/) const
|
||||
{
|
||||
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
@ -286,11 +294,8 @@ public:
|
||||
using PreparedFunctionImpl::execute;
|
||||
using FunctionBuilderImpl::getReturnTypeImpl;
|
||||
using FunctionBuilderImpl::getLambdaArgumentTypesImpl;
|
||||
|
||||
using FunctionBuilderImpl::getReturnType;
|
||||
|
||||
virtual bool isCompilable(const DataTypes & /*types*/) const { return false; }
|
||||
|
||||
bool isCompilable() const final
|
||||
{
|
||||
throw Exception("isCompilable without explicit types is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED);
|
||||
@ -301,12 +306,12 @@ public:
|
||||
throw Exception("prepare is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
virtual llvm::Value * compile(llvm::IRBuilderBase & /*builder*/, const DataTypes & /*types*/, const ValuePlaceholders & /*values*/) const
|
||||
std::vector<llvm::Value *> compilePrologue(llvm::IRBuilderBase &) const final
|
||||
{
|
||||
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
|
||||
throw Exception("compilePrologue without explicit types is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
llvm::Value * compile(llvm::IRBuilderBase & /*builder*/, const ValuePlaceholders & /*values*/) const final
|
||||
llvm::Value * compile(llvm::IRBuilderBase & /*builder*/, ValuePlaceholders /*values*/) const final
|
||||
{
|
||||
throw Exception("compile without explicit types is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
@ -321,7 +326,25 @@ public:
|
||||
throw Exception("getReturnType is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
bool isCompilable(const DataTypes & arguments) const;
|
||||
|
||||
std::vector<llvm::Value *> compilePrologue(llvm::IRBuilderBase &, const DataTypes & arguments) const;
|
||||
|
||||
llvm::Value * compile(llvm::IRBuilderBase &, const DataTypes & arguments, ValuePlaceholders values) const;
|
||||
|
||||
protected:
|
||||
/// Hook consulted by isCompilable(). The default declares the function not JIT-compilable.
virtual bool isCompilableImpl(const DataTypes &) const
{
    return false;
}
|
||||
|
||||
/// Hook consulted by compilePrologue(). The default emits nothing and returns no values.
virtual std::vector<llvm::Value *> compilePrologueImpl(llvm::IRBuilderBase &, const DataTypes &) const
{
    return std::vector<llvm::Value *>();
}
|
||||
|
||||
/// Hook consulted by compile(). The default throws NOT_IMPLEMENTED, naming the function.
virtual llvm::Value * compileImpl(llvm::IRBuilderBase &, const DataTypes &, ValuePlaceholders) const
{
    const auto message = getName() + " is not JIT-compilable";
    throw Exception(message, ErrorCodes::NOT_IMPLEMENTED);
}
|
||||
|
||||
FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & /*arguments*/, const DataTypePtr & /*return_type*/) const final
|
||||
{
|
||||
throw Exception("buildImpl is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED);
|
||||
@ -363,7 +386,9 @@ public:
|
||||
|
||||
bool isCompilable() const override { return function->isCompilable(arguments); }
|
||||
|
||||
llvm::Value * compile(llvm::IRBuilderBase & builder, const ValuePlaceholders & values) const override { return function->compile(builder, arguments, values); }
|
||||
std::vector<llvm::Value *> compilePrologue(llvm::IRBuilderBase & builder) const override { return function->compilePrologue(builder, arguments); }
|
||||
|
||||
llvm::Value * compile(llvm::IRBuilderBase & builder, ValuePlaceholders values) const override { return function->compile(builder, arguments, std::move(values)); }
|
||||
|
||||
PreparedFunctionPtr prepare(const Block & /*sample_block*/) const override { return std::make_shared<DefaultExecutable>(function); }
|
||||
|
||||
|
@ -3,10 +3,12 @@
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
#include <Columns/ColumnConst.h>
|
||||
#include <Columns/ColumnNullable.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/Native.h>
|
||||
|
||||
#include <llvm/IR/BasicBlock.h>
|
||||
#include <llvm/IR/DataLayout.h>
|
||||
@ -17,7 +19,6 @@
|
||||
#include <llvm/IR/Mangler.h>
|
||||
#include <llvm/IR/Module.h>
|
||||
#include <llvm/IR/Type.h>
|
||||
#include <llvm/IR/Verifier.h>
|
||||
#include <llvm/ExecutionEngine/ExecutionEngine.h>
|
||||
#include <llvm/ExecutionEngine/JITSymbol.h>
|
||||
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
|
||||
@ -25,12 +26,10 @@
|
||||
#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
|
||||
#include <llvm/ExecutionEngine/Orc/NullResolver.h>
|
||||
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
|
||||
#include <llvm/Support/TargetSelect.h>
|
||||
#include <llvm/Support/raw_ostream.h>
|
||||
#include <llvm/Target/TargetMachine.h>
|
||||
#include <llvm/Support/TargetSelect.h>
|
||||
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -40,12 +39,6 @@ namespace ErrorCodes
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static bool typeIsA(const DataTypePtr & type)
|
||||
{
|
||||
return typeid_cast<const T *>(removeNullable(type).get());;
|
||||
}
|
||||
|
||||
struct LLVMContext::Data
|
||||
{
|
||||
llvm::LLVMContext context;
|
||||
@ -67,33 +60,6 @@ struct LLVMContext::Data
|
||||
module->setDataLayout(layout);
|
||||
module->setTargetTriple(machine->getTargetTriple().getTriple());
|
||||
}
|
||||
|
||||
llvm::Type * toNativeType(const DataTypePtr & type)
|
||||
{
|
||||
/// LLVM doesn't have unsigned types, it has unsigned instructions.
|
||||
if (typeIsA<DataTypeInt8>(type) || typeIsA<DataTypeUInt8>(type))
|
||||
return builder.getInt8Ty();
|
||||
if (typeIsA<DataTypeInt16>(type) || typeIsA<DataTypeUInt16>(type))
|
||||
return builder.getInt16Ty();
|
||||
if (typeIsA<DataTypeInt32>(type) || typeIsA<DataTypeUInt32>(type))
|
||||
return builder.getInt32Ty();
|
||||
if (typeIsA<DataTypeInt64>(type) || typeIsA<DataTypeUInt64>(type))
|
||||
return builder.getInt64Ty();
|
||||
if (typeIsA<DataTypeFloat32>(type))
|
||||
return builder.getFloatTy();
|
||||
if (typeIsA<DataTypeFloat64>(type))
|
||||
return builder.getDoubleTy();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const void * lookup(const std::string& name)
|
||||
{
|
||||
std::string mangledName;
|
||||
llvm::raw_string_ostream mangledNameStream(mangledName);
|
||||
llvm::Mangler::getNameWithPrefix(mangledNameStream, name, layout);
|
||||
/// why is `findSymbol` not const? we may never know.
|
||||
return reinterpret_cast<const void *>(compileLayer.findSymbol(mangledNameStream.str(), false).getAddress().get());
|
||||
}
|
||||
};
|
||||
|
||||
LLVMContext::LLVMContext()
|
||||
@ -104,7 +70,6 @@ void LLVMContext::finalize()
|
||||
{
|
||||
if (!shared->module->size())
|
||||
return;
|
||||
shared->module->print(llvm::errs(), nullptr, false, true);
|
||||
llvm::PassManagerBuilder builder;
|
||||
llvm::legacy::FunctionPassManager fpm(shared->module.get());
|
||||
builder.OptLevel = 2;
|
||||
@ -112,46 +77,67 @@ void LLVMContext::finalize()
|
||||
for (auto & function : *shared->module)
|
||||
fpm.run(function);
|
||||
llvm::cantFail(shared->compileLayer.addModule(shared->module, std::make_shared<llvm::orc::NullResolver>()));
|
||||
shared->module->print(llvm::errs(), nullptr, false, true);
|
||||
}
|
||||
|
||||
bool LLVMContext::isCompilable(const IFunctionBase& function) const
|
||||
{
|
||||
if (!function.isCompilable() || !shared->toNativeType(function.getReturnType()))
|
||||
if (!function.isCompilable() || !toNativeType(shared->builder, function.getReturnType()))
|
||||
return false;
|
||||
for (const auto & type : function.getArgumentTypes())
|
||||
if (!shared->toNativeType(type))
|
||||
if (!toNativeType(shared->builder, type))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
LLVMPreparedFunction::LLVMPreparedFunction(LLVMContext context, std::shared_ptr<const IFunctionBase> parent)
|
||||
: parent(parent), context(context), function(context->lookup(parent->getName()))
|
||||
{}
|
||||
: parent(parent), context(context)
|
||||
{
|
||||
std::string mangledName;
|
||||
llvm::raw_string_ostream mangledNameStream(mangledName);
|
||||
llvm::Mangler::getNameWithPrefix(mangledNameStream, parent->getName(), context->layout);
|
||||
function = reinterpret_cast<const void *>(context->compileLayer.findSymbol(mangledNameStream.str(), false).getAddress().get());
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
struct ColumnData
|
||||
{
|
||||
const char * data;
|
||||
const char * data = nullptr;
|
||||
const char * null = nullptr;
|
||||
size_t stride;
|
||||
};
|
||||
|
||||
struct ColumnDataPlaceholders
|
||||
{
|
||||
llvm::PHINode * data;
|
||||
llvm::PHINode * null;
|
||||
llvm::Value * data_init;
|
||||
llvm::Value * null_init;
|
||||
llvm::Value * stride;
|
||||
llvm::Value * is_const;
|
||||
};
|
||||
}
|
||||
|
||||
static ColumnData getColumnData(const IColumn * column)
|
||||
{
|
||||
if (!column->isFixedAndContiguous())
|
||||
throw Exception("column type " + column->getName() + " is not a contiguous array; its data type "
|
||||
"should've had no native equivalent in LLVMContext::Data::toNativeType", ErrorCodes::LOGICAL_ERROR);
|
||||
/// TODO: handle ColumnNullable
|
||||
return {column->getRawData().data, !column->isColumnConst() ? column->sizeOfValueIfFixed() : 0};
|
||||
ColumnData result;
|
||||
const bool is_const = column->isColumnConst();
|
||||
if (is_const)
|
||||
column = &reinterpret_cast<const ColumnConst *>(column)->getDataColumn();
|
||||
if (auto * nullable = typeid_cast<const ColumnNullable *>(column))
|
||||
{
|
||||
result.null = nullable->getNullMapColumn().getRawData().data;
|
||||
column = &nullable->getNestedColumn();
|
||||
}
|
||||
result.data = column->getRawData().data;
|
||||
result.stride = is_const ? 0 : column->sizeOfValueIfFixed();
|
||||
return result;
|
||||
}
|
||||
|
||||
void LLVMPreparedFunction::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result)
|
||||
void LLVMPreparedFunction::execute(Block & block, const ColumnNumbers & arguments, size_t result)
|
||||
{
|
||||
size_t block_size = block.rows();
|
||||
/// assuming that the function has default behavior on NULL, the column will be wrapped by `PreparedFunctionImpl::execute`.
|
||||
auto col_res = removeNullable(parent->getReturnType())->createColumn()->cloneResized(block_size);
|
||||
auto col_res = parent->getReturnType()->createColumn()->cloneResized(block_size);
|
||||
if (block_size)
|
||||
{
|
||||
std::vector<ColumnData> columns(arguments.size() + 1);
|
||||
@ -171,22 +157,34 @@ void LLVMPreparedFunction::executeImpl(Block & block, const ColumnNumbers & argu
|
||||
LLVMFunction::LLVMFunction(ExpressionActions::Actions actions_, LLVMContext context, const Block & sample_block)
|
||||
: actions(std::move(actions_)), context(context)
|
||||
{
|
||||
auto & b = context->builder;
|
||||
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
|
||||
auto * data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy(), size_type);
|
||||
auto * func_type = llvm::FunctionType::get(b.getVoidTy(), { size_type, llvm::PointerType::get(data_type, 0) }, /*isVarArg=*/false);
|
||||
auto * func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, actions.back().result_name, context->module.get());
|
||||
auto args = func->args().begin();
|
||||
llvm::Value * counter = &*args++;
|
||||
llvm::Value * columns = &*args++;
|
||||
|
||||
auto * entry = llvm::BasicBlock::Create(context->context, "entry", func);
|
||||
b.SetInsertPoint(entry);
|
||||
|
||||
std::unordered_map<std::string, std::function<llvm::Value * ()>> by_name;
|
||||
for (const auto & c : sample_block)
|
||||
{
|
||||
auto generator = [&]() -> llvm::Value *
|
||||
{
|
||||
auto * type = context->toNativeType(c.type);
|
||||
if (typeIsA<DataTypeFloat32>(c.type))
|
||||
return llvm::ConstantFP::get(type, typeid_cast<const ColumnVector<Float32> *>(c.column.get())->getElement(0));
|
||||
if (typeIsA<DataTypeFloat64>(c.type))
|
||||
return llvm::ConstantFP::get(type, typeid_cast<const ColumnVector<Float64> *>(c.column.get())->getElement(0));
|
||||
if (type && type->isIntegerTy())
|
||||
return llvm::ConstantInt::get(type, c.column->getUInt(0));
|
||||
return nullptr;
|
||||
};
|
||||
if (c.column && generator() && !by_name.emplace(c.name, std::move(generator)).second)
|
||||
throw Exception("duplicate constant column " + c.name, ErrorCodes::LOGICAL_ERROR);
|
||||
auto * type = toNativeType(b, c.type);
|
||||
if (!type || !c.column)
|
||||
continue;
|
||||
llvm::Value * value = nullptr;
|
||||
if (type->isFloatTy())
|
||||
value = llvm::ConstantFP::get(type, typeid_cast<const ColumnVector<Float32> *>(c.column.get())->getElement(0));
|
||||
else if (type->isDoubleTy())
|
||||
value = llvm::ConstantFP::get(type, typeid_cast<const ColumnVector<Float64> *>(c.column.get())->getElement(0));
|
||||
else if (type->isIntegerTy())
|
||||
value = llvm::ConstantInt::get(type, c.column->getUInt(0));
|
||||
/// TODO: handle nullable (create a pointer)
|
||||
if (value)
|
||||
by_name[c.name] = [=]() { return value; };
|
||||
}
|
||||
|
||||
std::unordered_set<std::string> seen;
|
||||
@ -196,85 +194,100 @@ LLVMFunction::LLVMFunction(ExpressionActions::Actions actions_, LLVMContext cont
|
||||
const auto & types = action.function->getArgumentTypes();
|
||||
for (size_t i = 0; i < names.size(); i++)
|
||||
{
|
||||
if (seen.emplace(names[i]).second && by_name.find(names[i]) == by_name.end())
|
||||
{
|
||||
arg_names.push_back(names[i]);
|
||||
arg_types.push_back(types[i]);
|
||||
}
|
||||
if (!seen.emplace(names[i]).second || by_name.find(names[i]) != by_name.end())
|
||||
continue;
|
||||
arg_names.push_back(names[i]);
|
||||
arg_types.push_back(types[i]);
|
||||
}
|
||||
seen.insert(action.result_name);
|
||||
}
|
||||
|
||||
auto * char_type = context->builder.getInt8Ty();
|
||||
auto * size_type = context->builder.getIntNTy(sizeof(size_t) * 8);
|
||||
auto * data_type = llvm::StructType::get(llvm::PointerType::get(char_type, 0), size_type);
|
||||
auto * func_type = llvm::FunctionType::get(context->builder.getVoidTy(), { size_type, llvm::PointerType::get(data_type, 0) }, /*isVarArg=*/false);
|
||||
auto * func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, actions.back().result_name, context->module.get());
|
||||
auto args = func->args().begin();
|
||||
llvm::Value * counter = &*args++;
|
||||
llvm::Value * columns = &*args++;
|
||||
|
||||
auto * entry = llvm::BasicBlock::Create(context->context, "entry", func);
|
||||
context->builder.SetInsertPoint(entry);
|
||||
|
||||
struct CastedColumnData
|
||||
{
|
||||
llvm::PHINode * data;
|
||||
llvm::Value * data_init;
|
||||
llvm::Value * stride;
|
||||
};
|
||||
std::vector<CastedColumnData> columns_v(arg_types.size() + 1);
|
||||
std::vector<ColumnDataPlaceholders> columns_v(arg_types.size() + 1);
|
||||
for (size_t i = 0; i <= arg_types.size(); i++)
|
||||
{
|
||||
auto * type = llvm::PointerType::getUnqual(context->toNativeType(i == arg_types.size() ? getReturnType() : arg_types[i]));
|
||||
auto * data = context->builder.CreateConstInBoundsGEP2_32(data_type, columns, i, 0);
|
||||
auto * stride = context->builder.CreateConstInBoundsGEP2_32(data_type, columns, i, 1);
|
||||
columns_v[i] = { nullptr, context->builder.CreatePointerCast(context->builder.CreateLoad(data), type), context->builder.CreateLoad(stride) };
|
||||
auto & column_type = (i == arg_types.size()) ? getReturnType() : arg_types[i];
|
||||
auto * type = llvm::PointerType::get(toNativeType(b, removeNullable(column_type)), 0);
|
||||
columns_v[i].data_init = b.CreatePointerCast(b.CreateLoad(b.CreateConstInBoundsGEP2_32(data_type, columns, i, 0)), type);
|
||||
columns_v[i].stride = b.CreateLoad(b.CreateConstInBoundsGEP2_32(data_type, columns, i, 2));
|
||||
if (column_type->isNullable())
|
||||
{
|
||||
columns_v[i].null_init = b.CreateLoad(b.CreateConstInBoundsGEP2_32(data_type, columns, i, 1));
|
||||
columns_v[i].is_const = b.CreateICmpEQ(columns_v[i].stride, b.getIntN(sizeof(size_t) * 8, 0));
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < arg_types.size(); i++)
|
||||
{
|
||||
by_name[arg_names[i]] = [&, &col = columns_v[i]]() -> llvm::Value *
|
||||
{
|
||||
if (!col.null)
|
||||
return b.CreateLoad(col.data);
|
||||
auto * is_valid = b.CreateICmpNE(b.CreateLoad(col.null), b.getInt8(1));
|
||||
auto * null_ptr = llvm::ConstantPointerNull::get(reinterpret_cast<llvm::PointerType *>(col.data->getType()));
|
||||
return b.CreateSelect(is_valid, col.data, null_ptr);
|
||||
};
|
||||
}
|
||||
for (const auto & action : actions)
|
||||
{
|
||||
ValuePlaceholders input;
|
||||
for (const auto & name : action.argument_names)
|
||||
input.push_back(by_name.at(name));
|
||||
/// TODO: pass compile-time constant arguments to `compilePrologue`?
|
||||
auto extra = action.function->compilePrologue(b);
|
||||
for (auto * value : extra)
|
||||
input.emplace_back([=]() { return value; });
|
||||
by_name[action.result_name] = [&, input = std::move(input)]() { return action.function->compile(b, input); };
|
||||
}
|
||||
|
||||
/// assume nonzero initial value in `counter`
|
||||
auto * loop = llvm::BasicBlock::Create(context->context, "loop", func);
|
||||
context->builder.CreateBr(loop);
|
||||
context->builder.SetInsertPoint(loop);
|
||||
auto * counter_phi = context->builder.CreatePHI(counter->getType(), 2);
|
||||
b.CreateBr(loop);
|
||||
b.SetInsertPoint(loop);
|
||||
auto * counter_phi = b.CreatePHI(counter->getType(), 2);
|
||||
counter_phi->addIncoming(counter, entry);
|
||||
for (auto & col : columns_v)
|
||||
{
|
||||
col.data = context->builder.CreatePHI(col.data_init->getType(), 2);
|
||||
col.data = b.CreatePHI(col.data_init->getType(), 2);
|
||||
col.data->addIncoming(col.data_init, entry);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < arg_types.size(); i++)
|
||||
if (!by_name.emplace(arg_names[i], [&, i]() { return context->builder.CreateLoad(columns_v[i].data); }).second)
|
||||
throw Exception("duplicate input column name " + arg_names[i], ErrorCodes::LOGICAL_ERROR);
|
||||
for (const auto & action : actions)
|
||||
{
|
||||
ValuePlaceholders action_input;
|
||||
action_input.reserve(action.argument_names.size());
|
||||
for (const auto & name : action.argument_names)
|
||||
action_input.push_back(by_name.at(name));
|
||||
auto generator = [&action, &context, action_input{std::move(action_input)}]()
|
||||
if (col.null_init)
|
||||
{
|
||||
return action.function->compile(context->builder, action_input);
|
||||
};
|
||||
if (!by_name.emplace(action.result_name, std::move(generator)).second)
|
||||
throw Exception("duplicate action result name " + action.result_name, ErrorCodes::LOGICAL_ERROR);
|
||||
col.null = b.CreatePHI(col.null_init->getType(), 2);
|
||||
col.null->addIncoming(col.null_init, entry);
|
||||
}
|
||||
}
|
||||
context->builder.CreateStore(by_name.at(actions.back().result_name)(), columns_v[arg_types.size()].data);
|
||||
|
||||
auto * cur_block = context->builder.GetInsertBlock();
|
||||
auto * result = by_name.at(actions.back().result_name)();
|
||||
if (columns_v[arg_types.size()].null)
|
||||
{
|
||||
auto * read = llvm::BasicBlock::Create(context->context, "not_null", func);
|
||||
auto * join = llvm::BasicBlock::Create(context->context, "join", func);
|
||||
b.CreateCondBr(b.CreateIsNull(result), join, read);
|
||||
b.SetInsertPoint(read);
|
||||
b.CreateStore(b.getInt8(0), columns_v[arg_types.size()].null); /// column initialized to all-NULL
|
||||
b.CreateStore(b.CreateLoad(result), columns_v[arg_types.size()].data);
|
||||
b.CreateBr(join);
|
||||
b.SetInsertPoint(join);
|
||||
}
|
||||
else
|
||||
{
|
||||
b.CreateStore(result, columns_v[arg_types.size()].data);
|
||||
}
|
||||
|
||||
auto * cur_block = b.GetInsertBlock();
|
||||
for (auto & col : columns_v)
|
||||
{
|
||||
auto * as_char = context->builder.CreatePointerCast(col.data, llvm::PointerType::get(char_type, 0));
|
||||
auto * as_type = context->builder.CreatePointerCast(context->builder.CreateGEP(as_char, col.stride), col.data->getType());
|
||||
auto * as_char = b.CreatePointerCast(col.data, b.getInt8PtrTy());
|
||||
auto * as_type = b.CreatePointerCast(b.CreateGEP(as_char, col.stride), col.data->getType());
|
||||
col.data->addIncoming(as_type, cur_block);
|
||||
if (col.null)
|
||||
col.null->addIncoming(b.CreateSelect(col.is_const, col.null, b.CreateConstGEP1_32(col.null, 1)), cur_block);
|
||||
}
|
||||
counter_phi->addIncoming(context->builder.CreateSub(counter_phi, llvm::ConstantInt::get(counter_phi->getType(), 1)), cur_block);
|
||||
counter_phi->addIncoming(b.CreateSub(counter_phi, llvm::ConstantInt::get(size_type, 1)), cur_block);
|
||||
|
||||
auto * end = llvm::BasicBlock::Create(context->context, "end", func);
|
||||
context->builder.CreateCondBr(context->builder.CreateICmpNE(counter_phi, llvm::ConstantInt::get(counter_phi->getType(), 1)), loop, end);
|
||||
context->builder.SetInsertPoint(end);
|
||||
context->builder.CreateRetVoid();
|
||||
b.CreateCondBr(b.CreateICmpNE(counter_phi, llvm::ConstantInt::get(size_type, 1)), loop, end);
|
||||
b.SetInsertPoint(end);
|
||||
b.CreateRetVoid();
|
||||
}
|
||||
|
||||
static Field evaluateFunction(IFunctionBase & function, const IDataType & type, const Field & arg)
|
||||
|
@ -28,7 +28,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class LLVMPreparedFunction : public PreparedFunctionImpl
|
||||
class LLVMPreparedFunction : public IPreparedFunction
|
||||
{
|
||||
std::shared_ptr<const IFunctionBase> parent;
|
||||
LLVMContext context;
|
||||
@ -39,7 +39,7 @@ public:
|
||||
|
||||
String getName() const override { return parent->getName(); }
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override;
|
||||
void execute(Block & block, const ColumnNumbers & arguments, size_t result) override;
|
||||
};
|
||||
|
||||
class LLVMFunction : public std::enable_shared_from_this<LLVMFunction>, public IFunctionBase
|
||||
|
Loading…
Reference in New Issue
Block a user