diff --git a/dbms/CMakeLists.txt b/dbms/CMakeLists.txt index 2cd85d63700..e3bf825226b 100644 --- a/dbms/CMakeLists.txt +++ b/dbms/CMakeLists.txt @@ -105,6 +105,7 @@ if (USE_EMBEDDED_COMPILER) target_include_directories (dbms BEFORE PUBLIC ${LLVM_INCLUDE_DIRS}) # LLVM 5.0 has a bunch of unused parameters in its header files. # TODO: global-disable no-unused-parameter + set_source_files_properties(src/Functions/IFunction.cpp PROPERTIES COMPILE_FLAGS "-Wno-unused-parameter") set_source_files_properties(src/Interpreters/ExpressionJIT.cpp PROPERTIES COMPILE_FLAGS "-Wno-unused-parameter -Wno-non-virtual-dtor") endif () diff --git a/dbms/src/DataTypes/Native.h b/dbms/src/DataTypes/Native.h new file mode 100644 index 00000000000..411ba6bb1da --- /dev/null +++ b/dbms/src/DataTypes/Native.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include + +namespace llvm +{ + class IRBuilderBase; + class Type; +} + +#if USE_EMBEDDED_COMPILER +#include +#endif + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int NOT_IMPLEMENTED; +} + +static llvm::Type * toNativeType([[maybe_unused]] llvm::IRBuilderBase & builder, [[maybe_unused]] const DataTypePtr & type) +{ +#if USE_EMBEDDED_COMPILER + if (auto * nullable = typeid_cast(type.get())) + { + auto * wrapped = toNativeType(builder, nullable->getNestedType()); + return wrapped ? llvm::PointerType::get(wrapped, 0) : nullptr; + } + /// LLVM doesn't have unsigned types, it has unsigned instructions. + if (typeid_cast(type.get()) || typeid_cast(type.get())) + return builder.getInt8Ty(); + if (typeid_cast(type.get()) || typeid_cast(type.get())) + return builder.getInt16Ty(); + if (typeid_cast(type.get()) || typeid_cast(type.get())) + return builder.getInt32Ty(); + if (typeid_cast(type.get()) || typeid_cast(type.get())) + return builder.getInt64Ty(); + if (typeid_cast(type.get())) + return builder.getFloatTy(); + if (typeid_cast(type.get())) + return builder.getDoubleTy(); + return nullptr; +#else + throw Exception("JIT-compilation is disabled", ErrorCodes::NOT_IMPLEMENTED); +#endif +} + +} diff --git a/dbms/src/Functions/FunctionsLLVMTest.cpp b/dbms/src/Functions/FunctionsLLVMTest.cpp index 6342daa76c8..8619c5b0201 100644 --- a/dbms/src/Functions/FunctionsLLVMTest.cpp +++ b/dbms/src/Functions/FunctionsLLVMTest.cpp @@ -25,12 +25,12 @@ public: static constexpr auto name = "something"; #if USE_EMBEDDED_COMPILER - bool isCompilable(const DataTypes & types) const override + bool isCompilableImpl(const DataTypes & types) const override { return types.size() == 2 && types[0]->equals(*types[1]); } - llvm::Value * compile(llvm::IRBuilderBase & builder, const DataTypes & types, const ValuePlaceholders & values) const override + llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, ValuePlaceholders values) const override { if (types[0]->equals(DataTypeFloat32{}) || types[0]->equals(DataTypeFloat64{})) return static_cast&>(builder).CreateFAdd(values[0](), values[1]()); diff --git a/dbms/src/Functions/IFunction.cpp b/dbms/src/Functions/IFunction.cpp index 12e8dfabbd8..ca8df11719c 100644 --- a/dbms/src/Functions/IFunction.cpp +++ b/dbms/src/Functions/IFunction.cpp @@ -1,14 +1,20 @@ -#include -#include -#include -#include -#include -#include -#include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include #include #include +#if USE_EMBEDDED_COMPILER +#include +#endif + namespace DB { @@ -254,4 +260,75 @@ DataTypePtr FunctionBuilderImpl::getReturnType(const ColumnsWithTypeAndName & ar return getReturnTypeImpl(arguments); } + +static bool anyNullable(const DataTypes & types) +{ + for (const auto & type : types) + if (typeid_cast(type.get())) + return true; + return false; +} + +bool IFunction::isCompilable(const DataTypes & arguments) const +{ + if (useDefaultImplementationForNulls() && anyNullable(arguments)) + { + DataTypes filtered; + for (const auto & type : arguments) + filtered.emplace_back(removeNullable(type)); + return isCompilableImpl(filtered); + } + return isCompilableImpl(arguments); +} + +std::vector IFunction::compilePrologue(llvm::IRBuilderBase & builder, const DataTypes & arguments) const +{ + auto result = compilePrologueImpl(builder, arguments); +#if USE_EMBEDDED_COMPILER + if (useDefaultImplementationForNulls() && anyNullable(arguments)) + result.push_back(static_cast &>(builder).CreateAlloca(toNativeType(builder, getReturnTypeImpl(arguments)))); +#endif + return result; +} + +llvm::Value * IFunction::compile(llvm::IRBuilderBase & builder, const DataTypes & arguments, ValuePlaceholders values) const +{ +#if USE_EMBEDDED_COMPILER + if (useDefaultImplementationForNulls() && anyNullable(arguments)) + { + /// FIXME: when only one column is nullable, this is actually slower than the non-jitted version + /// because this involves copying the null map while `wrapInNullable` reuses it. + auto & b = static_cast &>(builder); + auto * fail = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent()); + auto * join = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent()); + auto * space = values.back()(); + values.pop_back(); + for (size_t i = 0; i < arguments.size(); i++) + { + if (!arguments[i]->isNullable()) + continue; + values[i] = [&, previous = std::move(values[i])]() + { + auto * value = previous(); + auto * ok = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent()); + b.CreateCondBr(b.CreateIsNull(value), fail, ok); + b.SetInsertPoint(ok); + return b.CreateLoad(value); + }; + } + b.CreateStore(compileImpl(builder, arguments, std::move(values)), space); + b.CreateBr(join); + auto * result_block = b.GetInsertBlock(); + b.SetInsertPoint(fail); /// an empty joining block to avoid keeping track of where we could jump from + b.CreateBr(join); + b.SetInsertPoint(join); + auto * phi = b.CreatePHI(space->getType(), 2); + phi->addIncoming(space, result_block); + phi->addIncoming(llvm::ConstantPointerNull::get(static_cast(space->getType())), fail); + return phi; + } +#endif + return compileImpl(builder, arguments, std::move(values)); +} + } diff --git a/dbms/src/Functions/IFunction.h b/dbms/src/Functions/IFunction.h index a07f0a5c99e..43d3ea060e4 100644 --- a/dbms/src/Functions/IFunction.h +++ b/dbms/src/Functions/IFunction.h @@ -102,17 +102,25 @@ public: virtual bool isCompilable() const { return false; } - /** Produce LLVM IR code that operates on *scalar* values. JIT-compilation is only supported for native - * data types, i.e. numbers. This method will never be called if there is a non-number argument or - * a non-number result type. Also, for any compilable function default behavior on NULL values is assumed, - * i.e. the result is NULL if and only if any argument is NULL. + /// Produce LLVM IR code that runs before the loop over the input rows. Mostly useful for allocating stack variables. + virtual std::vector compilePrologue(llvm::IRBuilderBase &) const + { + return {}; + } + + /** Produce LLVM IR code that operates on scalar values. + * + * The first `getArgumentTypes().size()` values describe the current row of each column. Supported value types: + * - numbers, represented as native numbers; + * - nullable numbers, as pointers to native numbers or a null pointer. + * The rest are values returned by `compilePrologue`. * * NOTE: the builder is actually guaranteed to be exactly `llvm::IRBuilder<>`, so you may safely * downcast it to that type. This method is specified with `IRBuilderBase` because forward-declaring * templates with default arguments is impossible and including LLVM in such a generic header * as this one is a major pain. */ - virtual llvm::Value * compile(llvm::IRBuilderBase & /*builder*/, const ValuePlaceholders & /*values*/) const + virtual llvm::Value * compile(llvm::IRBuilderBase & /*builder*/, ValuePlaceholders /*values*/) const { throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED); } @@ -286,11 +294,8 @@ public: using PreparedFunctionImpl::execute; using FunctionBuilderImpl::getReturnTypeImpl; using FunctionBuilderImpl::getLambdaArgumentTypesImpl; - using FunctionBuilderImpl::getReturnType; - virtual bool isCompilable(const DataTypes & /*types*/) const { return false; } - bool isCompilable() const final { throw Exception("isCompilable without explicit types is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED); @@ -301,12 +306,12 @@ public: throw Exception("prepare is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED); } - virtual llvm::Value * compile(llvm::IRBuilderBase & /*builder*/, const DataTypes & /*types*/, const ValuePlaceholders & /*values*/) const + std::vector compilePrologue(llvm::IRBuilderBase &) const final { - throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED); + throw Exception("compilePrologue without explicit types is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED); } - llvm::Value * compile(llvm::IRBuilderBase & /*builder*/, const ValuePlaceholders & /*values*/) const final + llvm::Value * compile(llvm::IRBuilderBase & /*builder*/, ValuePlaceholders /*values*/) const final { throw Exception("compile without explicit types is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED); } @@ -321,7 +326,25 @@ public: throw Exception("getReturnType is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED); } + bool isCompilable(const DataTypes & arguments) const; + + std::vector compilePrologue(llvm::IRBuilderBase &, const DataTypes & arguments) const; + + llvm::Value * compile(llvm::IRBuilderBase &, const DataTypes & arguments, ValuePlaceholders values) const; + protected: + virtual bool isCompilableImpl(const DataTypes &) const { return false; } + + virtual std::vector compilePrologueImpl(llvm::IRBuilderBase &, const DataTypes &) const + { + return {}; + } + + virtual llvm::Value * compileImpl(llvm::IRBuilderBase &, const DataTypes &, ValuePlaceholders) const + { + throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED); + } + FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & /*arguments*/, const DataTypePtr & /*return_type*/) const final { throw Exception("buildImpl is not implemented for IFunction", ErrorCodes::NOT_IMPLEMENTED); @@ -363,7 +386,9 @@ public: bool isCompilable() const override { return function->isCompilable(arguments); } - llvm::Value * compile(llvm::IRBuilderBase & builder, const ValuePlaceholders & values) const override { return function->compile(builder, arguments, values); } + std::vector compilePrologue(llvm::IRBuilderBase & builder) const override { return function->compilePrologue(builder, arguments); } + + llvm::Value * compile(llvm::IRBuilderBase & builder, ValuePlaceholders values) const override { return function->compile(builder, arguments, std::move(values)); } PreparedFunctionPtr prepare(const Block & /*sample_block*/) const override { return std::make_shared(function); } diff --git a/dbms/src/Interpreters/ExpressionJIT.cpp b/dbms/src/Interpreters/ExpressionJIT.cpp index 29d03db4f0d..ef46fb67f94 100644 --- a/dbms/src/Interpreters/ExpressionJIT.cpp +++ b/dbms/src/Interpreters/ExpressionJIT.cpp @@ -3,10 +3,12 @@ #if USE_EMBEDDED_COMPILER #include +#include #include #include #include #include +#include #include #include @@ -17,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -25,12 +26,10 @@ #include #include #include -#include -#include #include +#include #include -#include namespace DB { @@ -40,12 +39,6 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; } -template -static bool typeIsA(const DataTypePtr & type) -{ - return typeid_cast(removeNullable(type).get());; -} - struct LLVMContext::Data { llvm::LLVMContext context; @@ -67,33 +60,6 @@ struct LLVMContext::Data module->setDataLayout(layout); module->setTargetTriple(machine->getTargetTriple().getTriple()); } - - llvm::Type * toNativeType(const DataTypePtr & type) - { - /// LLVM doesn't have unsigned types, it has unsigned instructions. - if (typeIsA(type) || typeIsA(type)) - return builder.getInt8Ty(); - if (typeIsA(type) || typeIsA(type)) - return builder.getInt16Ty(); - if (typeIsA(type) || typeIsA(type)) - return builder.getInt32Ty(); - if (typeIsA(type) || typeIsA(type)) - return builder.getInt64Ty(); - if (typeIsA(type)) - return builder.getFloatTy(); - if (typeIsA(type)) - return builder.getDoubleTy(); - return nullptr; - } - - const void * lookup(const std::string& name) - { - std::string mangledName; - llvm::raw_string_ostream mangledNameStream(mangledName); - llvm::Mangler::getNameWithPrefix(mangledNameStream, name, layout); - /// why is `findSymbol` not const? we may never know. - return reinterpret_cast(compileLayer.findSymbol(mangledNameStream.str(), false).getAddress().get()); - } }; LLVMContext::LLVMContext() @@ -104,7 +70,6 @@ void LLVMContext::finalize() { if (!shared->module->size()) return; - shared->module->print(llvm::errs(), nullptr, false, true); llvm::PassManagerBuilder builder; llvm::legacy::FunctionPassManager fpm(shared->module.get()); builder.OptLevel = 2; @@ -112,46 +77,67 @@ void LLVMContext::finalize() for (auto & function : *shared->module) fpm.run(function); llvm::cantFail(shared->compileLayer.addModule(shared->module, std::make_shared())); - shared->module->print(llvm::errs(), nullptr, false, true); } bool LLVMContext::isCompilable(const IFunctionBase& function) const { - if (!function.isCompilable() || !shared->toNativeType(function.getReturnType())) + if (!function.isCompilable() || !toNativeType(shared->builder, function.getReturnType())) return false; for (const auto & type : function.getArgumentTypes()) - if (!shared->toNativeType(type)) + if (!toNativeType(shared->builder, type)) return false; return true; } LLVMPreparedFunction::LLVMPreparedFunction(LLVMContext context, std::shared_ptr parent) - : parent(parent), context(context), function(context->lookup(parent->getName())) -{} + : parent(parent), context(context) +{ + std::string mangledName; + llvm::raw_string_ostream mangledNameStream(mangledName); + llvm::Mangler::getNameWithPrefix(mangledNameStream, parent->getName(), context->layout); + function = reinterpret_cast(context->compileLayer.findSymbol(mangledNameStream.str(), false).getAddress().get()); +} namespace { struct ColumnData { - const char * data; + const char * data = nullptr; + const char * null = nullptr; size_t stride; }; + + struct ColumnDataPlaceholders + { + llvm::PHINode * data; + llvm::PHINode * null; + llvm::Value * data_init; + llvm::Value * null_init; + llvm::Value * stride; + llvm::Value * is_const; + }; } static ColumnData getColumnData(const IColumn * column) { - if (!column->isFixedAndContiguous()) - throw Exception("column type " + column->getName() + " is not a contiguous array; its data type " - "should've had no native equivalent in LLVMContext::Data::toNativeType", ErrorCodes::LOGICAL_ERROR); - /// TODO: handle ColumnNullable - return {column->getRawData().data, !column->isColumnConst() ? column->sizeOfValueIfFixed() : 0}; + ColumnData result; + const bool is_const = column->isColumnConst(); + if (is_const) + column = &reinterpret_cast(column)->getDataColumn(); + if (auto * nullable = typeid_cast(column)) + { + result.null = nullable->getNullMapColumn().getRawData().data; + column = &nullable->getNestedColumn(); + } + result.data = column->getRawData().data; + result.stride = is_const ? 0 : column->sizeOfValueIfFixed(); + return result; } -void LLVMPreparedFunction::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) +void LLVMPreparedFunction::execute(Block & block, const ColumnNumbers & arguments, size_t result) { size_t block_size = block.rows(); - /// assuming that the function has default behavior on NULL, the column will be wrapped by `PreparedFunctionImpl::execute`. - auto col_res = removeNullable(parent->getReturnType())->createColumn()->cloneResized(block_size); + auto col_res = parent->getReturnType()->createColumn()->cloneResized(block_size); if (block_size) { std::vector columns(arguments.size() + 1); @@ -171,22 +157,34 @@ void LLVMPreparedFunction::executeImpl(Block & block, const ColumnNumbers & argu LLVMFunction::LLVMFunction(ExpressionActions::Actions actions_, LLVMContext context, const Block & sample_block) : actions(std::move(actions_)), context(context) { + auto & b = context->builder; + auto * size_type = b.getIntNTy(sizeof(size_t) * 8); + auto * data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy(), size_type); + auto * func_type = llvm::FunctionType::get(b.getVoidTy(), { size_type, llvm::PointerType::get(data_type, 0) }, /*isVarArg=*/false); + auto * func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, actions.back().result_name, context->module.get()); + auto args = func->args().begin(); + llvm::Value * counter = &*args++; + llvm::Value * columns = &*args++; + + auto * entry = llvm::BasicBlock::Create(context->context, "entry", func); + b.SetInsertPoint(entry); + std::unordered_map> by_name; for (const auto & c : sample_block) { - auto generator = [&]() -> llvm::Value * - { - auto * type = context->toNativeType(c.type); - if (typeIsA(c.type)) - return llvm::ConstantFP::get(type, typeid_cast *>(c.column.get())->getElement(0)); - if (typeIsA(c.type)) - return llvm::ConstantFP::get(type, typeid_cast *>(c.column.get())->getElement(0)); - if (type && type->isIntegerTy()) - return llvm::ConstantInt::get(type, c.column->getUInt(0)); - return nullptr; - }; - if (c.column && generator() && !by_name.emplace(c.name, std::move(generator)).second) - throw Exception("duplicate constant column " + c.name, ErrorCodes::LOGICAL_ERROR); + auto * type = toNativeType(b, c.type); + if (!type || !c.column) + continue; + llvm::Value * value = nullptr; + if (type->isFloatTy()) + value = llvm::ConstantFP::get(type, typeid_cast *>(c.column.get())->getElement(0)); + else if (type->isDoubleTy()) + value = llvm::ConstantFP::get(type, typeid_cast *>(c.column.get())->getElement(0)); + else if (type->isIntegerTy()) + value = llvm::ConstantInt::get(type, c.column->getUInt(0)); + /// TODO: handle nullable (create a pointer) + if (value) + by_name[c.name] = [=]() { return value; }; } std::unordered_set seen; @@ -196,85 +194,100 @@ LLVMFunction::LLVMFunction(ExpressionActions::Actions actions_, LLVMContext cont const auto & types = action.function->getArgumentTypes(); for (size_t i = 0; i < names.size(); i++) { - if (seen.emplace(names[i]).second && by_name.find(names[i]) == by_name.end()) - { - arg_names.push_back(names[i]); - arg_types.push_back(types[i]); - } + if (!seen.emplace(names[i]).second || by_name.find(names[i]) != by_name.end()) + continue; + arg_names.push_back(names[i]); + arg_types.push_back(types[i]); } seen.insert(action.result_name); } - auto * char_type = context->builder.getInt8Ty(); - auto * size_type = context->builder.getIntNTy(sizeof(size_t) * 8); - auto * data_type = llvm::StructType::get(llvm::PointerType::get(char_type, 0), size_type); - auto * func_type = llvm::FunctionType::get(context->builder.getVoidTy(), { size_type, llvm::PointerType::get(data_type, 0) }, /*isVarArg=*/false); - auto * func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, actions.back().result_name, context->module.get()); - auto args = func->args().begin(); - llvm::Value * counter = &*args++; - llvm::Value * columns = &*args++; - - auto * entry = llvm::BasicBlock::Create(context->context, "entry", func); - context->builder.SetInsertPoint(entry); - - struct CastedColumnData - { - llvm::PHINode * data; - llvm::Value * data_init; - llvm::Value * stride; - }; - std::vector columns_v(arg_types.size() + 1); + std::vector columns_v(arg_types.size() + 1); for (size_t i = 0; i <= arg_types.size(); i++) { - auto * type = llvm::PointerType::getUnqual(context->toNativeType(i == arg_types.size() ? getReturnType() : arg_types[i])); - auto * data = context->builder.CreateConstInBoundsGEP2_32(data_type, columns, i, 0); - auto * stride = context->builder.CreateConstInBoundsGEP2_32(data_type, columns, i, 1); - columns_v[i] = { nullptr, context->builder.CreatePointerCast(context->builder.CreateLoad(data), type), context->builder.CreateLoad(stride) }; + auto & column_type = (i == arg_types.size()) ? getReturnType() : arg_types[i]; + auto * type = llvm::PointerType::get(toNativeType(b, removeNullable(column_type)), 0); + columns_v[i].data_init = b.CreatePointerCast(b.CreateLoad(b.CreateConstInBoundsGEP2_32(data_type, columns, i, 0)), type); + columns_v[i].stride = b.CreateLoad(b.CreateConstInBoundsGEP2_32(data_type, columns, i, 2)); + if (column_type->isNullable()) + { + columns_v[i].null_init = b.CreateLoad(b.CreateConstInBoundsGEP2_32(data_type, columns, i, 1)); + columns_v[i].is_const = b.CreateICmpEQ(columns_v[i].stride, b.getIntN(sizeof(size_t) * 8, 0)); + } + } + + for (size_t i = 0; i < arg_types.size(); i++) + { + by_name[arg_names[i]] = [&, &col = columns_v[i]]() -> llvm::Value * + { + if (!col.null) + return b.CreateLoad(col.data); + auto * is_valid = b.CreateICmpNE(b.CreateLoad(col.null), b.getInt8(1)); + auto * null_ptr = llvm::ConstantPointerNull::get(reinterpret_cast(col.data->getType())); + return b.CreateSelect(is_valid, col.data, null_ptr); + }; + } + for (const auto & action : actions) + { + ValuePlaceholders input; + for (const auto & name : action.argument_names) + input.push_back(by_name.at(name)); + /// TODO: pass compile-time constant arguments to `compilePrologue`? + auto extra = action.function->compilePrologue(b); + for (auto * value : extra) + input.emplace_back([=]() { return value; }); + by_name[action.result_name] = [&, input = std::move(input)]() { return action.function->compile(b, input); }; } /// assume nonzero initial value in `counter` auto * loop = llvm::BasicBlock::Create(context->context, "loop", func); - context->builder.CreateBr(loop); - context->builder.SetInsertPoint(loop); - auto * counter_phi = context->builder.CreatePHI(counter->getType(), 2); + b.CreateBr(loop); + b.SetInsertPoint(loop); + auto * counter_phi = b.CreatePHI(counter->getType(), 2); counter_phi->addIncoming(counter, entry); for (auto & col : columns_v) { - col.data = context->builder.CreatePHI(col.data_init->getType(), 2); + col.data = b.CreatePHI(col.data_init->getType(), 2); col.data->addIncoming(col.data_init, entry); - } - - for (size_t i = 0; i < arg_types.size(); i++) - if (!by_name.emplace(arg_names[i], [&, i]() { return context->builder.CreateLoad(columns_v[i].data); }).second) - throw Exception("duplicate input column name " + arg_names[i], ErrorCodes::LOGICAL_ERROR); - for (const auto & action : actions) - { - ValuePlaceholders action_input; - action_input.reserve(action.argument_names.size()); - for (const auto & name : action.argument_names) - action_input.push_back(by_name.at(name)); - auto generator = [&action, &context, action_input{std::move(action_input)}]() + if (col.null_init) { - return action.function->compile(context->builder, action_input); - }; - if (!by_name.emplace(action.result_name, std::move(generator)).second) - throw Exception("duplicate action result name " + action.result_name, ErrorCodes::LOGICAL_ERROR); + col.null = b.CreatePHI(col.null_init->getType(), 2); + col.null->addIncoming(col.null_init, entry); + } } - context->builder.CreateStore(by_name.at(actions.back().result_name)(), columns_v[arg_types.size()].data); - auto * cur_block = context->builder.GetInsertBlock(); + auto * result = by_name.at(actions.back().result_name)(); + if (columns_v[arg_types.size()].null) + { + auto * read = llvm::BasicBlock::Create(context->context, "not_null", func); + auto * join = llvm::BasicBlock::Create(context->context, "join", func); + b.CreateCondBr(b.CreateIsNull(result), join, read); + b.SetInsertPoint(read); + b.CreateStore(b.getInt8(0), columns_v[arg_types.size()].null); /// column initialized to all-NULL + b.CreateStore(b.CreateLoad(result), columns_v[arg_types.size()].data); + b.CreateBr(join); + b.SetInsertPoint(join); + } + else + { + b.CreateStore(result, columns_v[arg_types.size()].data); + } + + auto * cur_block = b.GetInsertBlock(); for (auto & col : columns_v) { - auto * as_char = context->builder.CreatePointerCast(col.data, llvm::PointerType::get(char_type, 0)); - auto * as_type = context->builder.CreatePointerCast(context->builder.CreateGEP(as_char, col.stride), col.data->getType()); + auto * as_char = b.CreatePointerCast(col.data, b.getInt8PtrTy()); + auto * as_type = b.CreatePointerCast(b.CreateGEP(as_char, col.stride), col.data->getType()); col.data->addIncoming(as_type, cur_block); + if (col.null) + col.null->addIncoming(b.CreateSelect(col.is_const, col.null, b.CreateConstGEP1_32(col.null, 1)), cur_block); } - counter_phi->addIncoming(context->builder.CreateSub(counter_phi, llvm::ConstantInt::get(counter_phi->getType(), 1)), cur_block); + counter_phi->addIncoming(b.CreateSub(counter_phi, llvm::ConstantInt::get(size_type, 1)), cur_block); auto * end = llvm::BasicBlock::Create(context->context, "end", func); - context->builder.CreateCondBr(context->builder.CreateICmpNE(counter_phi, llvm::ConstantInt::get(counter_phi->getType(), 1)), loop, end); - context->builder.SetInsertPoint(end); - context->builder.CreateRetVoid(); + b.CreateCondBr(b.CreateICmpNE(counter_phi, llvm::ConstantInt::get(size_type, 1)), loop, end); + b.SetInsertPoint(end); + b.CreateRetVoid(); } static Field evaluateFunction(IFunctionBase & function, const IDataType & type, const Field & arg) diff --git a/dbms/src/Interpreters/ExpressionJIT.h b/dbms/src/Interpreters/ExpressionJIT.h index 7aa7ee4098a..75d16d9facf 100644 --- a/dbms/src/Interpreters/ExpressionJIT.h +++ b/dbms/src/Interpreters/ExpressionJIT.h @@ -28,7 +28,7 @@ public: } }; -class LLVMPreparedFunction : public PreparedFunctionImpl +class LLVMPreparedFunction : public IPreparedFunction { std::shared_ptr parent; LLVMContext context; @@ -39,7 +39,7 @@ public: String getName() const override { return parent->getName(); } - void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override; + void execute(Block & block, const ColumnNumbers & arguments, size_t result) override; }; class LLVMFunction : public std::enable_shared_from_this, public IFunctionBase