enable jit for bitCount

This commit is contained in:
taiyang-li 2024-10-12 18:00:34 +08:00
parent 7ab71ecf90
commit e5886d4e66
5 changed files with 65 additions and 5 deletions

View File

@ -497,7 +497,10 @@ public:
using T0 = typename DataType::FieldType; using T0 = typename DataType::FieldType;
using T1 = typename Op<T0>::ResultType; using T1 = typename Op<T0>::ResultType;
if constexpr (!std::is_same_v<T1, InvalidType> && !IsDataTypeDecimal<DataType> && Op<T0>::compilable) if constexpr (!std::is_same_v<T1, InvalidType> && !IsDataTypeDecimal<DataType> && Op<T0>::compilable)
{
std::cout << "abs is compilable" << std::endl;
return true; return true;
}
} }
return false; return false;
@ -523,9 +526,10 @@ public:
if constexpr (!std::is_same_v<T1, InvalidType> && !IsDataTypeDecimal<DataType> && Op<T0>::compilable) if constexpr (!std::is_same_v<T1, InvalidType> && !IsDataTypeDecimal<DataType> && Op<T0>::compilable)
{ {
auto & b = static_cast<llvm::IRBuilder<> &>(builder); auto & b = static_cast<llvm::IRBuilder<> &>(builder);
if constexpr (std::is_same_v<Op<T0>, AbsImpl<T0>>) if constexpr (std::is_same_v<Op<T0>, AbsImpl<T0>> || std::is_same_v<Op<T0>, BitCountImpl<T0>>)
{ {
/// We don't need to cast the argument to the result type if it's abs function. std::cout << "start to compile abs" << std::endl;
/// We don't need to cast the argument to the result type if it's abs/bitcount function.
result = Op<T0>::compile(b, arguments[0].value, is_signed_v<T0>); result = Op<T0>::compile(b, arguments[0].value, is_signed_v<T0>);
} }
else else

View File

@ -27,7 +27,7 @@ struct AbsImpl
} }
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER
static constexpr bool compilable = true; /// special type handling, some other time static constexpr bool compilable = true;
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * arg, bool sign) static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * arg, bool sign)
{ {

View File

@ -38,7 +38,26 @@ struct BitCountImpl
} }
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER
static constexpr bool compilable = false; static constexpr bool compilable = true;
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * arg, bool)
{
const auto & type = arg->getType();
llvm::Value * int_value = nullptr;
if (type->isIntegerTy())
int_value = arg;
else if (type->isFloatTy())
int_value = b.CreateBitCast(arg, llvm::Type::getInt32Ty(b.getContext()));
else if (type->isDoubleTy())
int_value = b.CreateBitCast(arg, llvm::Type::getInt64Ty(b.getContext()));
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "BitCountImpl compilation expected native integer or floating-point type");
auto * func_ctpop = llvm::Intrinsic::getDeclaration(b.GetInsertBlock()->getModule(), llvm::Intrinsic::ctpop, {int_value->getType()});
llvm::Value * ctpop_value = b.CreateCall(func_ctpop, {int_value});
return b.CreateZExtOrTrunc(ctpop_value, llvm::Type::getInt8Ty(b.getContext()));
}
#endif #endif
}; };

View File

@ -22,7 +22,40 @@ struct SignImpl
} }
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER
static constexpr bool compilable = false; static constexpr bool compilable = true;
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * arg, bool)
{
const auto & type = arg->getType();
if (type->isIntegerTy())
{
auto * zero = llvm::ConstantInt::get(type, 0);
auto * one = llvm::ConstantInt::get(type, 1);
auto * minus_one = llvm::ConstantInt::getSigned(type, -1);
auto * is_zero = b.CreateICmpEQ(arg, zero);
auto * is_negative = b.CreateICmpSLT(arg, zero);
auto * select_zero = b.CreateSelect(is_zero, zero, one);
return b.CreateSelect(is_negative, minus_one, select_zero);
}
else if (type->isDoubleTy() || type->isFloatTy())
{
auto * zero = llvm::ConstantFP::get(type, 0.0);
auto * one = llvm::ConstantFP::get(type, 1.0);
auto * minus_one = llvm::ConstantFP::get(type, -1.0);
auto * is_zero = b.CreateFCmpOEQ(arg, zero);
auto * is_negative = b.CreateFCmpOLT(arg, zero);
auto * select_zero = b.CreateSelect(is_zero, zero, one);
return b.CreateSelect(is_negative, minus_one, select_zero);
}
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "SignImpl compilation expected native integer or floating point type");
}
#endif #endif
}; };

View File

@ -59,7 +59,11 @@ ExpressionActions::ExpressionActions(ActionsDAG actions_dag_, const ExpressionAc
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER
if (settings.can_compile_expressions && settings.compile_expressions == CompileExpressions::yes) if (settings.can_compile_expressions && settings.compile_expressions == CompileExpressions::yes)
{
std::cout << "old actions_dag: " << actions_dag.dumpDAG() << std::endl;
actions_dag.compileExpressions(settings.min_count_to_compile_expression, lazy_executed_nodes); actions_dag.compileExpressions(settings.min_count_to_compile_expression, lazy_executed_nodes);
std::cout << "new actions_dag: " << actions_dag.dumpDAG() << std::endl;
}
#endif #endif
linearizeActions(lazy_executed_nodes); linearizeActions(lazy_executed_nodes);