Implement jit for logic functions

This commit is contained in:
pyos 2018-04-30 15:33:40 +03:00
parent e4ace21f24
commit 7483ed24f0

View File

@ -12,6 +12,10 @@
#include <Functions/FunctionHelpers.h>
#include <type_traits>
#if USE_EMBEDDED_COMPILER
#include <llvm/IR/IRBuilder.h>
#endif
namespace DB
{
@ -31,65 +35,71 @@ namespace ErrorCodes
* For example, 1 OR NULL returns 1, not NULL.
*/
struct AndImpl
{
static inline bool isSaturable()
static inline constexpr bool isSaturable()
{
return true;
}
static inline bool isSaturatedValue(bool a)
static inline constexpr bool isSaturatedValue(bool a)
{
return !a;
}
static inline bool apply(bool a, bool b)
static inline constexpr bool apply(bool a, bool b)
{
return a && b;
}
static inline bool specialImplementationForNulls() { return false; }
static inline constexpr bool specialImplementationForNulls() { return false; }
};
struct OrImpl
{
static inline bool isSaturable()
static inline constexpr bool isSaturable()
{
return true;
}
static inline bool isSaturatedValue(bool a)
static inline constexpr bool isSaturatedValue(bool a)
{
return a;
}
static inline bool apply(bool a, bool b)
static inline constexpr bool apply(bool a, bool b)
{
return a || b;
}
static inline bool specialImplementationForNulls() { return true; }
static inline constexpr bool specialImplementationForNulls() { return true; }
};
struct XorImpl
{
static inline bool isSaturable()
static inline constexpr bool isSaturable()
{
return false;
}
static inline bool isSaturatedValue(bool)
static inline constexpr bool isSaturatedValue(bool)
{
return false;
}
static inline bool apply(bool a, bool b)
static inline constexpr bool apply(bool a, bool b)
{
return a != b;
}
static inline bool specialImplementationForNulls() { return false; }
static inline constexpr bool specialImplementationForNulls() { return false; }
#if USE_EMBEDDED_COMPILER
static inline llvm::Value * apply(llvm::IRBuilder<> & builder, llvm::Value * a, llvm::Value * b)
{
return builder.CreateXor(a, b);
}
#endif
};
template <typename A>
@ -101,6 +111,13 @@ struct NotImpl
{
return !a;
}
#if USE_EMBEDDED_COMPILER
static inline llvm::Value * apply(llvm::IRBuilder<> & builder, llvm::Value * a)
{
return builder.CreateNot(a);
}
#endif
};
@ -172,6 +189,20 @@ struct AssociativeOperationImpl<Op, 1>
};
#if USE_EMBEDDED_COMPILER
static llvm::Value * isNativeTrueValue(llvm::IRBuilder<> & b, const DataTypePtr & type, llvm::Value * x)
{
if (type->isNullable())
{
auto * subexpr = isNativeTrueValue(b, removeNullable(type), b.CreateExtractValue(x, {0}));
return b.CreateAnd(b.CreateNot(b.CreateExtractValue(x, {1})), subexpr);
}
auto * zero = llvm::Constant::getNullValue(x->getType());
return x->getType()->isIntegerTy() ? b.CreateICmpNE(x, zero) : b.CreateFCmpONE(x, zero); /// QNaN -> false
}
#endif
template <typename Impl, typename Name>
class FunctionAnyArityLogical : public IFunction
{
@ -364,6 +395,44 @@ public:
block.getByPosition(result).column = std::move(col_res);
}
#if USE_EMBEDDED_COMPILER
bool isCompilableImpl(const DataTypes &) const override { return true; }
llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, ValuePlaceholders values) const override
{
auto & b = static_cast<llvm::IRBuilder<> &>(builder);
if constexpr (!Impl::isSaturable())
{
auto * result = isNativeTrueValue(b, types[0], values[0]());
for (size_t i = 1; i < types.size(); i++)
result = Impl::apply(b, result, isNativeTrueValue(b, types[i], values[i]()));
return b.CreateSelect(result, b.getInt8(1), b.getInt8(0));
}
constexpr bool breakOnTrue = Impl::isSaturatedValue(true);
auto * next = b.GetInsertBlock();
auto * stop = llvm::BasicBlock::Create(next->getContext(), "", next->getParent());
b.SetInsertPoint(stop);
auto * phi = b.CreatePHI(b.getInt8Ty(), values.size());
for (size_t i = 0; i < types.size(); i++)
{
b.SetInsertPoint(next);
auto * value = values[i]();
auto * truth = isNativeTrueValue(b, types[i], value);
if (!types[i]->equals(DataTypeUInt8{}))
value = b.CreateSelect(truth, b.getInt8(1), b.getInt8(0));
phi->addIncoming(value, b.GetInsertBlock());
if (i + 1 < types.size())
{
next = llvm::BasicBlock::Create(next->getContext(), "", next->getParent());
b.CreateCondBr(truth, breakOnTrue ? stop : next, breakOnTrue ? next : stop);
}
}
b.CreateBr(stop);
b.SetInsertPoint(stop);
return phi;
}
#endif
};
@ -430,6 +499,16 @@ public:
+ " of argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
}
#if USE_EMBEDDED_COMPILER
bool isCompilableImpl(const DataTypes &) const override { return true; }
llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, ValuePlaceholders values) const override
{
auto & b = static_cast<llvm::IRBuilder<> &>(builder);
return b.CreateSelect(Impl<UInt8>::apply(b, isNativeTrueValue(b, types[0], values[0]())), b.getInt8(1), b.getInt8(0));
}
#endif
};