From 1433e6e849af13e19b949b1bbc00db5ce66c1cd4 Mon Sep 17 00:00:00 2001 From: pyos Date: Mon, 7 May 2018 22:21:23 +0300 Subject: [PATCH] Extract native bool cast; generalize number cast to nullables --- dbms/src/DataTypes/Native.h | 55 +++++++++++++++++++----- dbms/src/Functions/FunctionsArithmetic.h | 8 ++-- dbms/src/Functions/FunctionsLogical.h | 22 ++-------- 3 files changed, 53 insertions(+), 32 deletions(-) diff --git a/dbms/src/DataTypes/Native.h b/dbms/src/DataTypes/Native.h index 61daececd3e..6a793d13ca4 100644 --- a/dbms/src/DataTypes/Native.h +++ b/dbms/src/DataTypes/Native.h @@ -62,19 +62,54 @@ static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const Dat return toNativeType(builder, *type); } -static inline llvm::Value * castNativeNumber(llvm::IRBuilder<> & builder, llvm::Value * value, llvm::Type * type, bool is_signed) +static inline llvm::Value * nativeBoolCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value) { - if (value->getType() == type) - return value; - if (value->getType()->isIntegerTy()) + if (from->isNullable()) { - if (type->isIntegerTy()) - return builder.CreateIntCast(value, type, is_signed); - return is_signed ? builder.CreateSIToFP(value, type) : builder.CreateUIToFP(value, type); + auto * inner = nativeBoolCast(b, removeNullable(from), b.CreateExtractValue(value, {0})); + return b.CreateAnd(b.CreateNot(b.CreateExtractValue(value, {1})), inner); } - if (type->isFloatingPointTy()) - return builder.CreateFPCast(value, type); - return is_signed ? builder.CreateFPToSI(value, type) : builder.CreateFPToUI(value, type); + auto * zero = llvm::Constant::getNullValue(value->getType()); + if (value->getType()->isIntegerTy()) + return b.CreateICmpNE(value, zero); + if (value->getType()->isFloatingPointTy()) + return b.CreateFCmpONE(value, zero); /// QNaN is false + throw Exception("Cannot cast non-number " + from->getName() + " to bool", ErrorCodes::NOT_IMPLEMENTED); +} + +static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value, const DataTypePtr & to) +{ + auto * n_from = value->getType(); + auto * n_to = toNativeType(b, to); + if (n_from == n_to) + return value; + if (from->isNullable() && to->isNullable()) + { + auto * inner = nativeCast(b, removeNullable(from), b.CreateExtractValue(value, {0}), to); + return b.CreateInsertValue(inner, b.CreateExtractValue(value, {1}), {1}); + } + if (from->isNullable()) + return nativeCast(b, removeNullable(from), b.CreateExtractValue(value, {0}), to); + if (to->isNullable()) + { + auto * inner = nativeCast(b, from, value, removeNullable(to)); + return b.CreateInsertValue(llvm::Constant::getNullValue(n_to), inner, {0}); + } + + bool is_signed = typeIsEither< + DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64, + DataTypeFloat32, DataTypeFloat64, + DataTypeDate, DataTypeDateTime, DataTypeInterval + >(*from); + if (n_from->isIntegerTy() && n_to->isFloatingPointTy()) + return is_signed ? b.CreateSIToFP(value, n_to) : b.CreateUIToFP(value, n_to); + if (n_from->isFloatingPointTy() && n_to->isIntegerTy()) + return is_signed ? b.CreateFPToSI(value, n_to) : b.CreateFPToUI(value, n_to); + if (n_from->isIntegerTy() && n_to->isIntegerTy()) + return b.CreateIntCast(value, n_to, is_signed); + if (n_from->isFloatingPointTy() && n_to->isFloatingPointTy()) + return b.CreateFPCast(value, n_to); + throw Exception("Cannot cast " + from->getName() + " to " + to->getName(), ErrorCodes::NOT_IMPLEMENTED); } } diff --git a/dbms/src/Functions/FunctionsArithmetic.h b/dbms/src/Functions/FunctionsArithmetic.h index cef7ec4b924..a8ec0cf4942 100644 --- a/dbms/src/Functions/FunctionsArithmetic.h +++ b/dbms/src/Functions/FunctionsArithmetic.h @@ -982,9 +982,9 @@ public: if constexpr (!std::is_same_v && OpSpec::compilable) { auto & b = static_cast &>(builder); - auto * type = toNativeType(b, ResultDataType{}); - auto * lval = castNativeNumber(b, values[0](), type, std::is_signed_v); - auto * rval = castNativeNumber(b, values[1](), type, std::is_signed_v); + auto type = std::make_shared(); + auto * lval = nativeCast(b, types[0], values[0](), type); + auto * rval = nativeCast(b, types[1], values[1](), type); result = OpSpec::compile(b, lval, rval, std::is_signed_v); return true; } @@ -1088,7 +1088,7 @@ public: if constexpr (Op::compilable) { auto & b = static_cast &>(builder); - auto * v = castNativeNumber(b, values[0](), toNativeType(b, DataTypeNumber{}), std::is_signed_v); + auto * v = nativeCast(b, types[0], values[0](), std::make_shared>()); result = Op::compile(b, v, std::is_signed_v); return true; } diff --git a/dbms/src/Functions/FunctionsLogical.h b/dbms/src/Functions/FunctionsLogical.h index c62816be734..d5e77c4a450 100644 --- a/dbms/src/Functions/FunctionsLogical.h +++ b/dbms/src/Functions/FunctionsLogical.h @@ -192,20 +192,6 @@ struct AssociativeOperationImpl }; -#if USE_EMBEDDED_COMPILER -static llvm::Value * isNativeTrueValue(llvm::IRBuilder<> & b, const DataTypePtr & type, llvm::Value * x) -{ - if (type->isNullable()) - { - auto * subexpr = isNativeTrueValue(b, removeNullable(type), b.CreateExtractValue(x, {0})); - return b.CreateAnd(b.CreateNot(b.CreateExtractValue(x, {1})), subexpr); - } - auto * zero = llvm::Constant::getNullValue(x->getType()); - return x->getType()->isIntegerTy() ? b.CreateICmpNE(x, zero) : b.CreateFCmpONE(x, zero); /// QNaN -> false -} -#endif - - template class FunctionAnyArityLogical : public IFunction { @@ -407,9 +393,9 @@ public: auto & b = static_cast &>(builder); if constexpr (!Impl::isSaturable()) { - auto * result = isNativeTrueValue(b, types[0], values[0]()); + auto * result = nativeBoolCast(b, types[0], values[0]()); for (size_t i = 1; i < types.size(); i++) - result = Impl::apply(b, result, isNativeTrueValue(b, types[i], values[i]())); + result = Impl::apply(b, result, nativeBoolCast(b, types[i], values[i]())); return b.CreateSelect(result, b.getInt8(1), b.getInt8(0)); } constexpr bool breakOnTrue = Impl::isSaturatedValue(true); @@ -421,7 +407,7 @@ public: { b.SetInsertPoint(next); auto * value = values[i](); - auto * truth = isNativeTrueValue(b, types[i], value); + auto * truth = nativeBoolCast(b, types[i], value); if (!types[i]->equals(DataTypeUInt8{})) value = b.CreateSelect(truth, b.getInt8(1), b.getInt8(0)); phi->addIncoming(value, b.GetInsertBlock()); @@ -509,7 +495,7 @@ public: llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, ValuePlaceholders values) const override { auto & b = static_cast &>(builder); - return b.CreateSelect(Impl::apply(b, isNativeTrueValue(b, types[0], values[0]())), b.getInt8(1), b.getInt8(0)); + return b.CreateSelect(Impl::apply(b, nativeBoolCast(b, types[0], values[0]())), b.getInt8(1), b.getInt8(0)); } #endif };