enable jit for modulo

This commit is contained in:
taiyang-li 2024-10-15 18:49:25 +08:00
parent 7fd58a78df
commit a24e63751f
2 changed files with 48 additions and 6 deletions

View File

@ -8,6 +8,10 @@
#include "config.h"
#if USE_EMBEDDED_COMPILER
# include <llvm/IR/IRBuilder.h>
#endif
namespace DB
{
@ -15,6 +19,7 @@ namespace DB
namespace ErrorCodes
{
extern const int ILLEGAL_DIVISION;
extern const int LOGICAL_ERROR;
}
template <typename A, typename B>
@ -158,7 +163,20 @@ struct ModuloImpl
}
#if USE_EMBEDDED_COMPILER
static constexpr bool compilable = false; /// don't know how to throw from LLVM IR
static constexpr bool compilable = true; /// Ignore exceptions in LLVM IR
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, llvm::Value * right, bool is_signed)
{
if (left->getType()->isFloatingPointTy())
{
auto * func_frem = llvm::Intrinsic::getDeclaration(b.GetInsertBlock()->getModule(), llvm::Intrinsic::vp_frem, left->getType());
return b.CreateCall(func_frem, {left, right});
}
else if (left->getType()->isIntegerTy())
return is_signed ? b.CreateSRem(left, right) : b.CreateURem(left, right);
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "ModuloImpl compilation expected native integer or floating point type");
}
#endif
};

View File

@ -2368,7 +2368,18 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A
using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType;
using OpSpec = Op<typename LeftDataType::FieldType, typename RightDataType::FieldType>;
if constexpr (!std::is_same_v<ResultDataType, InvalidType> && !IsDataTypeDecimal<ResultDataType> && OpSpec::compilable)
return true;
{
if constexpr (is_modulo)
{
using PromotedType = std::conditional_t<
std::is_floating_point_v<typename ResultDataType::FieldType>,
Float64,
NumberTraits::ResultOfIf<typename LeftDataType::FieldType, typename RightDataType::FieldType>>;
return std::is_integral_v<PromotedType> || std::is_floating_point_v<PromotedType>;
}
else
return true;
}
}
return false;
});
@ -2393,10 +2404,23 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A
if constexpr (!std::is_same_v<ResultDataType, InvalidType> && !IsDataTypeDecimal<ResultDataType> && OpSpec::compilable)
{
auto & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * lval = nativeCast(b, arguments[0], result_type);
auto * rval = nativeCast(b, arguments[1], result_type);
result = OpSpec::compile(b, lval, rval, std::is_signed_v<typename ResultDataType::FieldType>);
if constexpr (is_modulo)
{
using PromotedType = std::conditional_t<
std::is_floating_point_v<typename ResultDataType::FieldType>,
Float64,
NumberTraits::ResultOfIf<typename LeftDataType::FieldType, typename RightDataType::FieldType>>;
auto promoted_type = std::make_shared<DataTypeNumber<PromotedType>>();
auto * lval = nativeCast(b, arguments[0], promoted_type);
auto * rval = nativeCast(b, arguments[1], promoted_type);
result = OpSpec::compile(b, lval, rval, std::is_signed_v<PromotedType>);
}
else
{
auto * lval = nativeCast(b, arguments[0], result_type);
auto * rval = nativeCast(b, arguments[1], result_type);
result = OpSpec::compile(b, lval, rval, std::is_signed_v<typename ResultDataType::FieldType>);
}
return true;
}
}