If all inputs to a jitted function are constant, return a constant

This commit is contained in:
pyos 2018-05-03 13:22:41 +03:00
parent 1f89849650
commit 23bbf632e5

View File

@ -124,7 +124,7 @@ struct LLVMContext
}
};
class LLVMPreparedFunction : public IPreparedFunction
class LLVMPreparedFunction : public PreparedFunctionImpl
{
std::string name;
std::shared_ptr<LLVMContext> context;
@ -142,7 +142,11 @@ public:
String getName() const override { return name; }
void execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t block_size) override
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t block_size) override
{
auto col_res = block.getByPosition(result).type->createColumn()->cloneResized(block_size);
if (block_size)
@ -168,7 +172,7 @@ static void compileFunction(std::shared_ptr<LLVMContext> & context, const IFunct
auto & b = context->builder;
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
auto * data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy(), size_type);
auto * func_type = llvm::FunctionType::get(b.getVoidTy(), { size_type, llvm::PointerType::get(data_type, 0) }, /*isVarArg=*/false);
auto * func_type = llvm::FunctionType::get(b.getVoidTy(), { size_type, data_type->getPointerTo() }, /*isVarArg=*/false);
auto * func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, f.getName(), context->module.get());
auto args = func->args().begin();
llvm::Value * counter_arg = &*args++;
@ -180,10 +184,10 @@ static void compileFunction(std::shared_ptr<LLVMContext> & context, const IFunct
for (size_t i = 0; i <= arg_types.size(); i++)
{
auto & type = i == arg_types.size() ? f.getReturnType() : arg_types[i];
auto * native = llvm::PointerType::get(toNativeType(b, removeNullable(type)), 0);
columns[i].data_init = b.CreatePointerCast(b.CreateLoad(b.CreateConstInBoundsGEP2_32(data_type, columns_arg, i, 0)), native);
columns[i].null_init = type->isNullable() ? b.CreateLoad(b.CreateConstInBoundsGEP2_32(data_type, columns_arg, i, 1)) : nullptr;
columns[i].stride = b.CreateLoad(b.CreateConstInBoundsGEP2_32(data_type, columns_arg, i, 2));
auto * data = b.CreateLoad(b.CreateConstInBoundsGEP1_32(data_type, columns_arg, i));
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(type))->getPointerTo());
columns[i].null_init = type->isNullable() ? b.CreateExtractValue(data, {1}) : nullptr;
columns[i].stride = b.CreateExtractValue(data, {2});
}
/// assume nonzero initial value in `counter_arg`
@ -228,14 +232,11 @@ static void compileFunction(std::shared_ptr<LLVMContext> & context, const IFunct
auto * cur_block = b.GetInsertBlock();
for (auto & col : columns)
{
auto * as_char = b.CreatePointerCast(col.data, b.getInt8PtrTy());
auto * as_type = b.CreatePointerCast(b.CreateInBoundsGEP(as_char, col.stride), col.data->getType());
col.data->addIncoming(as_type, cur_block);
/// currently, stride is either 0 or size of native type
auto * is_const = b.CreateICmpEQ(col.stride, llvm::ConstantInt::get(size_type, 0));
col.data->addIncoming(b.CreateSelect(is_const, col.data, b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1)), cur_block);
if (col.null)
{
auto * is_const = b.CreateICmpEQ(col.stride, llvm::ConstantInt::get(size_type, 0));
col.null->addIncoming(b.CreateSelect(is_const, col.null, b.CreateConstInBoundsGEP1_32(b.getInt8Ty(), col.null, 1)), cur_block);
}
col.null->addIncoming(b.CreateSelect(is_const, col.null, b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1)), cur_block);
}
counter_phi->addIncoming(b.CreateSub(counter_phi, llvm::ConstantInt::get(size_type, 1)), cur_block);