Implement jit for comparisons, except for (double, int).

That one has some edge cases which I can't be bothered to code.
This commit is contained in:
pyos 2018-05-10 17:02:25 +03:00
parent bd332b9171
commit 6d2259f2cf
2 changed files with 104 additions and 17 deletions

View File

@ -30,6 +30,15 @@ static inline bool typeIsEither(const IDataType & type)
return (typeid_cast<const Ts *>(&type) || ...);
}
static inline bool typeIsSigned(const IDataType & type)
{
return typeIsEither<
DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64,
DataTypeFloat32, DataTypeFloat64,
DataTypeDate, DataTypeDateTime, DataTypeInterval
>(type);
}
static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const IDataType & type)
{
if (auto * nullable = typeid_cast<const DataTypeNullable *>(&type))
@ -77,11 +86,26 @@ static inline llvm::Value * nativeBoolCast(llvm::IRBuilder<> & b, const DataType
throw Exception("Cannot cast non-number " + from->getName() + " to bool", ErrorCodes::NOT_IMPLEMENTED);
}
static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value, const DataTypePtr & to)
static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value, llvm::Type * to)
{
auto * n_from = value->getType();
if (n_from == to)
return value;
if (n_from->isIntegerTy() && to->isFloatingPointTy())
return typeIsSigned(*from) ? b.CreateSIToFP(value, to) : b.CreateUIToFP(value, to);
if (n_from->isFloatingPointTy() && to->isIntegerTy())
return typeIsSigned(*from) ? b.CreateFPToSI(value, to) : b.CreateFPToUI(value, to);
if (n_from->isIntegerTy() && to->isIntegerTy())
return b.CreateIntCast(value, to, typeIsSigned(*from));
if (n_from->isFloatingPointTy() && to->isFloatingPointTy())
return b.CreateFPCast(value, to);
throw Exception("Cannot cast " + from->getName() + " to requested type", ErrorCodes::NOT_IMPLEMENTED);
}
static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value, const DataTypePtr & to)
{
auto * n_to = toNativeType(b, to);
if (n_from == n_to)
if (value->getType() == n_to)
return value;
if (from->isNullable() && to->isNullable())
{
@ -95,21 +119,7 @@ static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr
auto * inner = nativeCast(b, from, value, removeNullable(to));
return b.CreateInsertValue(llvm::Constant::getNullValue(n_to), inner, {0});
}
bool is_signed = typeIsEither<
DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64,
DataTypeFloat32, DataTypeFloat64,
DataTypeDate, DataTypeDateTime, DataTypeInterval
>(*from);
if (n_from->isIntegerTy() && n_to->isFloatingPointTy())
return is_signed ? b.CreateSIToFP(value, n_to) : b.CreateUIToFP(value, n_to);
if (n_from->isFloatingPointTy() && n_to->isIntegerTy())
return is_signed ? b.CreateFPToSI(value, n_to) : b.CreateFPToUI(value, n_to);
if (n_from->isIntegerTy() && n_to->isIntegerTy())
return b.CreateIntCast(value, n_to, is_signed);
if (n_from->isFloatingPointTy() && n_to->isFloatingPointTy())
return b.CreateFPCast(value, n_to);
throw Exception("Cannot cast " + from->getName() + " to " + to->getName(), ErrorCodes::NOT_IMPLEMENTED);
return nativeCast(b, from, value, n_to);
}
}

View File

@ -53,12 +53,26 @@ template <typename A, typename B> struct EqualsOp
using SymmetricOp = EqualsOp<B, A>;
static UInt8 apply(A a, B b) { return accurate::equalsOp(a, b); }
#if USE_EMBEDDED_COMPILER
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool /*is_signed*/)
{
return x->getType()->isIntegerTy() ? b.CreateICmpEQ(x, y) : b.CreateFCmpOEQ(x, y); /// qNaNs always compare false
}
#endif
};
template <typename A, typename B> struct NotEqualsOp
{
using SymmetricOp = NotEqualsOp<B, A>;
static UInt8 apply(A a, B b) { return accurate::notEqualsOp(a, b); }
#if USE_EMBEDDED_COMPILER
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool /*is_signed*/)
{
return x->getType()->isIntegerTy() ? b.CreateICmpNE(x, y) : b.CreateFCmpONE(x, y);
}
#endif
};
template <typename A, typename B> struct GreaterOp;
@ -67,12 +81,26 @@ template <typename A, typename B> struct LessOp
{
using SymmetricOp = GreaterOp<B, A>;
static UInt8 apply(A a, B b) { return accurate::lessOp(a, b); }
#if USE_EMBEDDED_COMPILER
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool is_signed)
{
return x->getType()->isIntegerTy() ? (is_signed ? b.CreateICmpSLT(x, y) : b.CreateICmpULT(x, y)) : b.CreateFCmpOLT(x, y);
}
#endif
};
template <typename A, typename B> struct GreaterOp
{
using SymmetricOp = LessOp<B, A>;
static UInt8 apply(A a, B b) { return accurate::greaterOp(a, b); }
#if USE_EMBEDDED_COMPILER
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool is_signed)
{
return x->getType()->isIntegerTy() ? (is_signed ? b.CreateICmpSGT(x, y) : b.CreateICmpUGT(x, y)) : b.CreateFCmpOGT(x, y);
}
#endif
};
template <typename A, typename B> struct GreaterOrEqualsOp;
@ -81,12 +109,26 @@ template <typename A, typename B> struct LessOrEqualsOp
{
using SymmetricOp = GreaterOrEqualsOp<B, A>;
static UInt8 apply(A a, B b) { return accurate::lessOrEqualsOp(a, b); }
#if USE_EMBEDDED_COMPILER
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool is_signed)
{
return x->getType()->isIntegerTy() ? (is_signed ? b.CreateICmpSLE(x, y) : b.CreateICmpULE(x, y)) : b.CreateFCmpOLE(x, y);
}
#endif
};
template <typename A, typename B> struct GreaterOrEqualsOp
{
using SymmetricOp = LessOrEqualsOp<B, A>;
static UInt8 apply(A a, B b) { return accurate::greaterOrEqualsOp(a, b); }
#if USE_EMBEDDED_COMPILER
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool is_signed)
{
return x->getType()->isIntegerTy() ? (is_signed ? b.CreateICmpSGE(x, y) : b.CreateICmpUGE(x, y)) : b.CreateFCmpOGE(x, y);
}
#endif
};
@ -1136,6 +1178,41 @@ public:
col_with_type_and_name_left.type, col_with_type_and_name_right.type,
left_is_num, input_rows_count);
}
#if USE_EMBEDDED_COMPILER
bool isCompilableImpl(const DataTypes & types) const override
{
auto isBigInteger = &typeIsEither<DataTypeInt64, DataTypeUInt64, DataTypeUUID>;
auto isFloatingPoint = &typeIsEither<DataTypeFloat32, DataTypeFloat64>;
if ((isBigInteger(*types[0]) && isFloatingPoint(*types[1])) || (isBigInteger(*types[1]) && isFloatingPoint(*types[0])))
return false; /// TODO: implement (double, int_N where N > double's mantissa width)
return types[0]->isValueRepresentedByNumber() && types[1]->isValueRepresentedByNumber();
}
llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, ValuePlaceholders values) const override
{
auto & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * x = values[0]();
auto * y = values[1]();
if (!types[0]->equals(*types[1]))
{
llvm::Type * common;
if (x->getType()->isIntegerTy() && y->getType()->isIntegerTy())
common = b.getIntNTy(std::max(
/// if one integer has a sign bit, make sure the other does as well. llvm generates optimal code
/// (e.g. uses overflow flag on x86) for (word size + 1)-bit integer operations.
x->getType()->getIntegerBitWidth() + (!typeIsSigned(*types[0]) && typeIsSigned(*types[1])),
y->getType()->getIntegerBitWidth() + (!typeIsSigned(*types[1]) && typeIsSigned(*types[0]))));
else
/// (double, float) or (double, int_N where N <= double's mantissa width) -> double
common = b.getDoubleTy();
x = nativeCast(b, types[0], x, common);
y = nativeCast(b, types[1], y, common);
}
auto * result = Op<int, int>::compile(b, x, y, typeIsSigned(*types[0]) || typeIsSigned(*types[1]));
return b.CreateSelect(result, b.getInt8(1), b.getInt8(0));
}
#endif
};