Make LLVMFunction monotonicity computation shorter (and fix a typo-bug)

This commit is contained in:
pyos 2018-04-28 17:41:13 +03:00
parent a1eb938ed2
commit 1ffc2a0775
2 changed files with 31 additions and 41 deletions

View File

@ -12,12 +12,7 @@
namespace DB
{
namespace ErrorCodes
{
extern const int NOT_IMPLEMENTED;
}
static llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const DataTypePtr & type)
static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const DataTypePtr & type)
{
if (auto * nullable = typeid_cast<const DataTypeNullable *>(type.get()))
{
@ -40,7 +35,7 @@ static llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const DataTypePt
return nullptr;
}
static llvm::Constant * getDefaultNativeValue(llvm::Type * type)
static inline llvm::Constant * getDefaultNativeValue(llvm::Type * type)
{
if (type->isIntegerTy())
return llvm::ConstantInt::get(type, 0);
@ -52,6 +47,27 @@ static llvm::Constant * getDefaultNativeValue(llvm::Type * type)
return llvm::ConstantStruct::get(static_cast<llvm::StructType *>(type), value, is_null);
}
static inline llvm::Constant * getNativeValue(llvm::Type * type, const IColumn * column, size_t i)
{
if (!column || !type)
return nullptr;
if (auto * constant = typeid_cast<const ColumnConst *>(column))
return getNativeValue(type, &constant->getDataColumn(), 0);
if (auto * nullable = typeid_cast<const ColumnNullable *>(column))
{
auto * value = getNativeValue(type->getContainedType(0), &nullable->getNestedColumn(), i);
auto * is_null = llvm::ConstantInt::get(type->getContainedType(1), nullable->isNullAt(i));
return value ? llvm::ConstantStruct::get(static_cast<llvm::StructType *>(type), value, is_null) : nullptr;
}
if (type->isFloatTy())
return llvm::ConstantFP::get(type, static_cast<const ColumnVector<Float32> *>(column)->getElement(i));
if (type->isDoubleTy())
return llvm::ConstantFP::get(type, static_cast<const ColumnVector<Float64> *>(column)->getElement(i));
if (type->isIntegerTy())
return llvm::ConstantInt::get(type, column->getUInt(i));
return nullptr;
}
}
#endif

View File

@ -159,27 +159,6 @@ void LLVMPreparedFunction::execute(Block & block, const ColumnNumbers & argument
block.getByPosition(result).column = std::move(col_res);
};
static llvm::Constant * getConstantValue(const IColumn * column, llvm::Type * type)
{
if (!column || !type)
return nullptr;
if (auto * constant = typeid_cast<const ColumnConst *>(column))
return getConstantValue(&constant->getDataColumn(), type);
if (auto * nullable = typeid_cast<const ColumnNullable *>(column))
{
auto * value = getConstantValue(&nullable->getNestedColumn(), type->getContainedType(0));
auto * is_null = llvm::ConstantInt::get(type->getContainedType(1), nullable->isNullAt(0));
return value ? llvm::ConstantStruct::get(static_cast<llvm::StructType *>(type), value, is_null) : nullptr;
}
if (type->isFloatTy())
return llvm::ConstantFP::get(type, static_cast<const ColumnVector<Float32> *>(column)->getElement(0));
if (type->isDoubleTy())
return llvm::ConstantFP::get(type, static_cast<const ColumnVector<Float64> *>(column)->getElement(0));
if (type->isIntegerTy())
return llvm::ConstantInt::get(type, column->getUInt(0));
return nullptr;
}
LLVMFunction::LLVMFunction(ExpressionActions::Actions actions_, LLVMContext context, const Block & sample_block)
: actions(std::move(actions_)), context(context)
{
@ -197,7 +176,7 @@ LLVMFunction::LLVMFunction(ExpressionActions::Actions actions_, LLVMContext cont
std::unordered_map<std::string, std::function<llvm::Value * ()>> by_name;
for (const auto & c : sample_block)
if (auto * value = getConstantValue(c.column.get(), toNativeType(b, c.type)))
if (auto * value = getNativeValue(toNativeType(b, c.type), c.column.get(), 0))
by_name[c.name] = [=]() { return value; };
std::unordered_set<std::string> seen;
@ -304,17 +283,12 @@ LLVMFunction::LLVMFunction(ExpressionActions::Actions actions_, LLVMContext cont
b.CreateRetVoid();
}
static Field evaluateFunction(IFunctionBase & function, const IDataType & type, const Field & arg)
static void applyFunction(IFunctionBase & function, Field & value)
{
const auto & arg_types = function.getArgumentTypes();
if (arg_types.size() != 1 || !arg_types[0]->equals(type))
return {};
auto column = arg_types[0]->createColumn();
column->insert(arg);
Block block = {{ ColumnConst::create(std::move(column), 1), arg_types[0], "_arg" }, { nullptr, function.getReturnType(), "_result" }};
const auto & type = function.getArgumentTypes().at(0);
Block block = {{ type->createColumnConst(1, value), type, "x" }, { nullptr, function.getReturnType(), "y" }};
function.execute(block, {0}, 1);
auto result = block.getByPosition(1).column;
return result && result->size() == 1 ? (*result)[0] : Field();
block.safeGetByPosition(1).column->get(0, value);
}
IFunctionBase::Monotonicity LLVMFunction::getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const
@ -326,7 +300,7 @@ IFunctionBase::Monotonicity LLVMFunction::getMonotonicityForRange(const IDataTyp
/// monotonicity is only defined for unary functions, so the chain must describe a sequence of nested calls
for (size_t i = 0; i < actions.size(); i++)
{
Monotonicity m = actions[i].function->getMonotonicityForRange(type, left_, right_);
Monotonicity m = actions[i].function->getMonotonicityForRange(*type_, left_, right_);
if (!m.is_monotonic)
return m;
result.is_positive ^= !m.is_positive;
@ -334,9 +308,9 @@ IFunctionBase::Monotonicity LLVMFunction::getMonotonicityForRange(const IDataTyp
if (i + 1 < actions.size())
{
if (left_ != Field())
left_ = evaluateFunction(*actions[i].function, *type_, left_);
applyFunction(*actions[i].function, left_);
if (right_ != Field())
right_ = evaluateFunction(*actions[i].function, *type_, right_);
applyFunction(*actions[i].function, right_);
if (!m.is_positive)
std::swap(left_, right_);
type_ = actions[i].function->getReturnType().get();