mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 15:42:02 +00:00
Make LLVMFunction monotonicity computation shorter (and fix a typo-bug)
This commit is contained in:
parent
a1eb938ed2
commit
1ffc2a0775
@ -12,12 +12,7 @@
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
static llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const DataTypePtr & type)
|
||||
static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const DataTypePtr & type)
|
||||
{
|
||||
if (auto * nullable = typeid_cast<const DataTypeNullable *>(type.get()))
|
||||
{
|
||||
@ -40,7 +35,7 @@ static llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const DataTypePt
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static llvm::Constant * getDefaultNativeValue(llvm::Type * type)
|
||||
static inline llvm::Constant * getDefaultNativeValue(llvm::Type * type)
|
||||
{
|
||||
if (type->isIntegerTy())
|
||||
return llvm::ConstantInt::get(type, 0);
|
||||
@ -52,6 +47,27 @@ static llvm::Constant * getDefaultNativeValue(llvm::Type * type)
|
||||
return llvm::ConstantStruct::get(static_cast<llvm::StructType *>(type), value, is_null);
|
||||
}
|
||||
|
||||
static inline llvm::Constant * getNativeValue(llvm::Type * type, const IColumn * column, size_t i)
|
||||
{
|
||||
if (!column || !type)
|
||||
return nullptr;
|
||||
if (auto * constant = typeid_cast<const ColumnConst *>(column))
|
||||
return getNativeValue(type, &constant->getDataColumn(), 0);
|
||||
if (auto * nullable = typeid_cast<const ColumnNullable *>(column))
|
||||
{
|
||||
auto * value = getNativeValue(type->getContainedType(0), &nullable->getNestedColumn(), i);
|
||||
auto * is_null = llvm::ConstantInt::get(type->getContainedType(1), nullable->isNullAt(i));
|
||||
return value ? llvm::ConstantStruct::get(static_cast<llvm::StructType *>(type), value, is_null) : nullptr;
|
||||
}
|
||||
if (type->isFloatTy())
|
||||
return llvm::ConstantFP::get(type, static_cast<const ColumnVector<Float32> *>(column)->getElement(i));
|
||||
if (type->isDoubleTy())
|
||||
return llvm::ConstantFP::get(type, static_cast<const ColumnVector<Float64> *>(column)->getElement(i));
|
||||
if (type->isIntegerTy())
|
||||
return llvm::ConstantInt::get(type, column->getUInt(i));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@ -159,27 +159,6 @@ void LLVMPreparedFunction::execute(Block & block, const ColumnNumbers & argument
|
||||
block.getByPosition(result).column = std::move(col_res);
|
||||
};
|
||||
|
||||
static llvm::Constant * getConstantValue(const IColumn * column, llvm::Type * type)
|
||||
{
|
||||
if (!column || !type)
|
||||
return nullptr;
|
||||
if (auto * constant = typeid_cast<const ColumnConst *>(column))
|
||||
return getConstantValue(&constant->getDataColumn(), type);
|
||||
if (auto * nullable = typeid_cast<const ColumnNullable *>(column))
|
||||
{
|
||||
auto * value = getConstantValue(&nullable->getNestedColumn(), type->getContainedType(0));
|
||||
auto * is_null = llvm::ConstantInt::get(type->getContainedType(1), nullable->isNullAt(0));
|
||||
return value ? llvm::ConstantStruct::get(static_cast<llvm::StructType *>(type), value, is_null) : nullptr;
|
||||
}
|
||||
if (type->isFloatTy())
|
||||
return llvm::ConstantFP::get(type, static_cast<const ColumnVector<Float32> *>(column)->getElement(0));
|
||||
if (type->isDoubleTy())
|
||||
return llvm::ConstantFP::get(type, static_cast<const ColumnVector<Float64> *>(column)->getElement(0));
|
||||
if (type->isIntegerTy())
|
||||
return llvm::ConstantInt::get(type, column->getUInt(0));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
LLVMFunction::LLVMFunction(ExpressionActions::Actions actions_, LLVMContext context, const Block & sample_block)
|
||||
: actions(std::move(actions_)), context(context)
|
||||
{
|
||||
@ -197,7 +176,7 @@ LLVMFunction::LLVMFunction(ExpressionActions::Actions actions_, LLVMContext cont
|
||||
|
||||
std::unordered_map<std::string, std::function<llvm::Value * ()>> by_name;
|
||||
for (const auto & c : sample_block)
|
||||
if (auto * value = getConstantValue(c.column.get(), toNativeType(b, c.type)))
|
||||
if (auto * value = getNativeValue(toNativeType(b, c.type), c.column.get(), 0))
|
||||
by_name[c.name] = [=]() { return value; };
|
||||
|
||||
std::unordered_set<std::string> seen;
|
||||
@ -304,17 +283,12 @@ LLVMFunction::LLVMFunction(ExpressionActions::Actions actions_, LLVMContext cont
|
||||
b.CreateRetVoid();
|
||||
}
|
||||
|
||||
static Field evaluateFunction(IFunctionBase & function, const IDataType & type, const Field & arg)
|
||||
static void applyFunction(IFunctionBase & function, Field & value)
|
||||
{
|
||||
const auto & arg_types = function.getArgumentTypes();
|
||||
if (arg_types.size() != 1 || !arg_types[0]->equals(type))
|
||||
return {};
|
||||
auto column = arg_types[0]->createColumn();
|
||||
column->insert(arg);
|
||||
Block block = {{ ColumnConst::create(std::move(column), 1), arg_types[0], "_arg" }, { nullptr, function.getReturnType(), "_result" }};
|
||||
const auto & type = function.getArgumentTypes().at(0);
|
||||
Block block = {{ type->createColumnConst(1, value), type, "x" }, { nullptr, function.getReturnType(), "y" }};
|
||||
function.execute(block, {0}, 1);
|
||||
auto result = block.getByPosition(1).column;
|
||||
return result && result->size() == 1 ? (*result)[0] : Field();
|
||||
block.safeGetByPosition(1).column->get(0, value);
|
||||
}
|
||||
|
||||
IFunctionBase::Monotonicity LLVMFunction::getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const
|
||||
@ -326,7 +300,7 @@ IFunctionBase::Monotonicity LLVMFunction::getMonotonicityForRange(const IDataTyp
|
||||
/// monotonicity is only defined for unary functions, so the chain must describe a sequence of nested calls
|
||||
for (size_t i = 0; i < actions.size(); i++)
|
||||
{
|
||||
Monotonicity m = actions[i].function->getMonotonicityForRange(type, left_, right_);
|
||||
Monotonicity m = actions[i].function->getMonotonicityForRange(*type_, left_, right_);
|
||||
if (!m.is_monotonic)
|
||||
return m;
|
||||
result.is_positive ^= !m.is_positive;
|
||||
@ -334,9 +308,9 @@ IFunctionBase::Monotonicity LLVMFunction::getMonotonicityForRange(const IDataTyp
|
||||
if (i + 1 < actions.size())
|
||||
{
|
||||
if (left_ != Field())
|
||||
left_ = evaluateFunction(*actions[i].function, *type_, left_);
|
||||
applyFunction(*actions[i].function, left_);
|
||||
if (right_ != Field())
|
||||
right_ = evaluateFunction(*actions[i].function, *type_, right_);
|
||||
applyFunction(*actions[i].function, right_);
|
||||
if (!m.is_positive)
|
||||
std::swap(left_, right_);
|
||||
type_ = actions[i].function->getReturnType().get();
|
||||
|
Loading…
Reference in New Issue
Block a user