Represent nullable types as pairs instead of pointers.

Turns out LLVM has insertvalue & extractvalue for struct in registers. This is
faster than pointers because null checks are now subject to more optimizations.
This commit is contained in:
pyos 2018-04-28 14:12:21 +03:00
parent 5c75342d54
commit ccc895d162
4 changed files with 84 additions and 78 deletions

View File

@ -1,18 +1,13 @@
#pragma once
#include <Common/config.h>
#if USE_EMBEDDED_COMPILER
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
namespace llvm
{
class IRBuilderBase;
class Type;
}
#if USE_EMBEDDED_COMPILER
#include <llvm/IR/IRBuilder.h>
#endif
namespace DB
{
@ -22,13 +17,12 @@ namespace ErrorCodes
extern const int NOT_IMPLEMENTED;
}
static llvm::Type * toNativeType([[maybe_unused]] llvm::IRBuilderBase & builder, [[maybe_unused]] const DataTypePtr & type)
static llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const DataTypePtr & type)
{
#if USE_EMBEDDED_COMPILER
if (auto * nullable = typeid_cast<const DataTypeNullable *>(type.get()))
{
auto * wrapped = toNativeType(builder, nullable->getNestedType());
return wrapped ? llvm::PointerType::get(wrapped, 0) : nullptr;
return wrapped ? llvm::StructType::get(wrapped, /* is null = */ builder.getInt1Ty()) : nullptr;
}
/// LLVM doesn't have unsigned types, it has unsigned instructions.
if (typeid_cast<const DataTypeInt8 *>(type.get()) || typeid_cast<const DataTypeUInt8 *>(type.get()))
@ -44,9 +38,18 @@ static llvm::Type * toNativeType([[maybe_unused]] llvm::IRBuilderBase & builder,
if (typeid_cast<const DataTypeFloat64 *>(type.get()))
return builder.getDoubleTy();
return nullptr;
#else
throw Exception("JIT-compilation is disabled", ErrorCodes::NOT_IMPLEMENTED);
#endif
}
static llvm::Constant * getDefaultNativeValue(llvm::IRBuilder<> & builder, llvm::Type * type)
{
if (type->isIntegerTy())
return llvm::ConstantInt::get(type, 0);
if (type->isFloatTy() || type->isDoubleTy())
return llvm::ConstantFP::get(type, 0.0);
auto * as_struct = static_cast<llvm::StructType *>(type); /// nullable
return llvm::ConstantStruct::get(as_struct, getDefaultNativeValue(builder, as_struct->getElementType(0)), builder.getTrue());
}
}
#endif

View File

@ -261,71 +261,74 @@ DataTypePtr FunctionBuilderImpl::getReturnType(const ColumnsWithTypeAndName & ar
return getReturnTypeImpl(arguments);
}
static bool anyNullable(const DataTypes & types)
static std::optional<DataTypes> removeNullables(const DataTypes & types)
{
for (const auto & type : types)
if (typeid_cast<const DataTypeNullable *>(type.get()))
return true;
return false;
{
if (!typeid_cast<const DataTypeNullable *>(type.get()))
continue;
DataTypes filtered;
for (const auto & type : types)
filtered.emplace_back(removeNullable(type));
return filtered;
}
return {};
}
bool IFunction::isCompilable(const DataTypes & arguments) const
{
if (useDefaultImplementationForNulls() && anyNullable(arguments))
{
DataTypes filtered;
for (const auto & type : arguments)
filtered.emplace_back(removeNullable(type));
return isCompilableImpl(filtered);
}
if (useDefaultImplementationForNulls())
if (auto denulled = removeNullables(arguments))
return isCompilableImpl(*denulled);
return isCompilableImpl(arguments);
}
std::vector<llvm::Value *> IFunction::compilePrologue(llvm::IRBuilderBase & builder, const DataTypes & arguments) const
{
auto result = compilePrologueImpl(builder, arguments);
#if USE_EMBEDDED_COMPILER
if (useDefaultImplementationForNulls() && anyNullable(arguments))
result.push_back(static_cast<llvm::IRBuilder<> &>(builder).CreateAlloca(toNativeType(builder, getReturnTypeImpl(arguments))));
#endif
return result;
if (useDefaultImplementationForNulls())
if (auto denulled = removeNullables(arguments))
return compilePrologueImpl(builder, *denulled);
return compilePrologueImpl(builder, arguments);
}
llvm::Value * IFunction::compile(llvm::IRBuilderBase & builder, const DataTypes & arguments, ValuePlaceholders values) const
{
#if USE_EMBEDDED_COMPILER
if (useDefaultImplementationForNulls() && anyNullable(arguments))
if (useDefaultImplementationForNulls())
{
/// FIXME: when only one column is nullable, this is actually slower than the non-jitted version
/// because this involves copying the null map while `wrapInNullable` reuses it.
auto & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * fail = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent());
auto * join = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent());
auto * space = values.back()();
values.pop_back();
for (size_t i = 0; i < arguments.size(); i++)
if (auto denulled = removeNullables(arguments))
{
if (!arguments[i]->isNullable())
continue;
values[i] = [&, previous = std::move(values[i])]()
/// FIXME: when only one column is nullable, this is actually slower than the non-jitted version
/// because this involves copying the null map while `wrapInNullable` reuses it.
auto & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * fail = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent());
auto * join = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent());
auto * init = getDefaultNativeValue(b, toNativeType(b, makeNullable(getReturnTypeImpl(*denulled))));
for (size_t i = 0; i < arguments.size(); i++)
{
auto * value = previous();
auto * ok = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent());
b.CreateCondBr(b.CreateIsNull(value), fail, ok);
b.SetInsertPoint(ok);
return b.CreateLoad(value);
};
if (!arguments[i]->isNullable())
continue;
values[i] = [&, previous = std::move(values[i])]()
{
auto * value = previous();
auto * ok = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "", b.GetInsertBlock()->getParent());
b.CreateCondBr(b.CreateExtractValue(value, {1}), fail, ok);
b.SetInsertPoint(ok);
return b.CreateExtractValue(value, {0});
};
}
auto * result = compileImpl(builder, *denulled, std::move(values));
auto * result_nullable = b.CreateInsertValue(b.CreateInsertValue(init, result, {0}), b.getFalse(), {1});
auto * result_block = b.GetInsertBlock();
b.CreateBr(join);
b.SetInsertPoint(fail); /// an empty joining block to avoid keeping track of where we could jump from
b.CreateBr(join);
b.SetInsertPoint(join);
auto * phi = b.CreatePHI(result_nullable->getType(), 2);
phi->addIncoming(result_nullable, result_block);
phi->addIncoming(init, fail);
return phi;
}
b.CreateStore(compileImpl(builder, arguments, std::move(values)), space);
b.CreateBr(join);
auto * result_block = b.GetInsertBlock();
b.SetInsertPoint(fail); /// an empty joining block to avoid keeping track of where we could jump from
b.CreateBr(join);
b.SetInsertPoint(join);
auto * phi = b.CreatePHI(space->getType(), 2);
phi->addIncoming(space, result_block);
phi->addIncoming(llvm::ConstantPointerNull::get(static_cast<llvm::PointerType *>(space->getType())), fail);
return phi;
}
#endif
return compileImpl(builder, arguments, std::move(values));

View File

@ -110,9 +110,8 @@ public:
/** Produce LLVM IR code that operates on scalar values.
*
* The first `getArgumentTypes().size()` values describe the current row of each column. Supported value types:
* - numbers, represented as native numbers;
* - nullable numbers, as pointers to native numbers or a null pointer.
* The first `getArgumentTypes().size()` values describe the current row of each column. (See
* `toNativeType` in DataTypes/Native.h for supported value types and how they map to LLVM types.)
* The rest are values returned by `compilePrologue`.
*
* NOTE: the builder is actually guaranteed to be exactly `llvm::IRBuilder<>`, so you may safely

View File

@ -72,7 +72,12 @@ void LLVMContext::finalize()
return;
llvm::PassManagerBuilder builder;
llvm::legacy::FunctionPassManager fpm(shared->module.get());
builder.OptLevel = 2;
builder.OptLevel = 3;
builder.SLPVectorize = true;
builder.LoopVectorize = true;
builder.RerollLoops = true;
builder.VerifyInput = true;
builder.VerifyOutput = true;
builder.populateFunctionPassManager(fpm);
for (auto & function : *shared->module)
fpm.run(function);
@ -218,13 +223,14 @@ LLVMFunction::LLVMFunction(ExpressionActions::Actions actions_, LLVMContext cont
for (size_t i = 0; i < arg_types.size(); i++)
{
by_name[arg_names[i]] = [&, &col = columns_v[i]]() -> llvm::Value *
by_name[arg_names[i]] = [&, &col = columns_v[i], i]() -> llvm::Value *
{
auto * value = b.CreateLoad(col.data);
if (!col.null)
return b.CreateLoad(col.data);
auto * is_valid = b.CreateICmpNE(b.CreateLoad(col.null), b.getInt8(1));
auto * null_ptr = llvm::ConstantPointerNull::get(reinterpret_cast<llvm::PointerType *>(col.data->getType()));
return b.CreateSelect(is_valid, col.data, null_ptr);
return value;
auto * is_null = b.CreateICmpEQ(b.CreateLoad(col.null), b.getInt8(1));
auto * nullable = getDefaultNativeValue(b, toNativeType(b, arg_types[i]));
return b.CreateInsertValue(b.CreateInsertValue(nullable, value, {0}), is_null, {1});
};
}
for (const auto & action : actions)
@ -259,14 +265,9 @@ LLVMFunction::LLVMFunction(ExpressionActions::Actions actions_, LLVMContext cont
auto * result = by_name.at(actions.back().result_name)();
if (columns_v[arg_types.size()].null)
{
auto * read = llvm::BasicBlock::Create(context->context, "not_null", func);
auto * join = llvm::BasicBlock::Create(context->context, "join", func);
b.CreateCondBr(b.CreateIsNull(result), join, read);
b.SetInsertPoint(read);
b.CreateStore(b.getInt8(0), columns_v[arg_types.size()].null); /// column initialized to all-NULL
b.CreateStore(b.CreateLoad(result), columns_v[arg_types.size()].data);
b.CreateBr(join);
b.SetInsertPoint(join);
b.CreateStore(b.CreateExtractValue(result, {0}), columns_v[arg_types.size()].data);
/// XXX: should zero-extend it to 1 instead of sign-extending to -1?
b.CreateStore(b.CreateExtractValue(result, {1}), columns_v[arg_types.size()].null);
}
else
{
@ -277,10 +278,10 @@ LLVMFunction::LLVMFunction(ExpressionActions::Actions actions_, LLVMContext cont
for (auto & col : columns_v)
{
auto * as_char = b.CreatePointerCast(col.data, b.getInt8PtrTy());
auto * as_type = b.CreatePointerCast(b.CreateGEP(as_char, col.stride), col.data->getType());
auto * as_type = b.CreatePointerCast(b.CreateInBoundsGEP(as_char, col.stride), col.data->getType());
col.data->addIncoming(as_type, cur_block);
if (col.null)
col.null->addIncoming(b.CreateSelect(col.is_const, col.null, b.CreateConstGEP1_32(col.null, 1)), cur_block);
col.null->addIncoming(b.CreateSelect(col.is_const, col.null, b.CreateConstInBoundsGEP1_32(b.getInt8Ty(), col.null, 1)), cur_block);
}
counter_phi->addIncoming(b.CreateSub(counter_phi, llvm::ConstantInt::get(size_type, 1)), cur_block);