mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-26 09:32:01 +00:00
Compile AggregateFunctionAvgWeighted
This commit is contained in:
parent
56c1a4e447
commit
507d9405e2
@ -93,11 +93,13 @@ struct AvgFraction
|
|||||||
* @tparam Derived When deriving from this class, use the child class name as in CRTP, e.g.
|
* @tparam Derived When deriving from this class, use the child class name as in CRTP, e.g.
|
||||||
* class Self : Agg<char, bool, bool, Self>.
|
* class Self : Agg<char, bool, bool, Self>.
|
||||||
*/
|
*/
|
||||||
template <typename Numerator, typename Denominator, typename Derived>
|
template <typename TNumerator, typename TDenominator, typename Derived>
|
||||||
class AggregateFunctionAvgBase : public
|
class AggregateFunctionAvgBase : public
|
||||||
IAggregateFunctionDataHelper<AvgFraction<Numerator, Denominator>, Derived>
|
IAggregateFunctionDataHelper<AvgFraction<TNumerator, TDenominator>, Derived>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
using Numerator = TNumerator;
|
||||||
|
using Denominator = TDenominator;
|
||||||
using Fraction = AvgFraction<Numerator, Denominator>;
|
using Fraction = AvgFraction<Numerator, Denominator>;
|
||||||
using Base = IAggregateFunctionDataHelper<Fraction, Derived>;
|
using Base = IAggregateFunctionDataHelper<Fraction, Derived>;
|
||||||
|
|
||||||
@ -143,6 +145,89 @@ public:
|
|||||||
else
|
else
|
||||||
assert_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).divide());
|
assert_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).divide());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#if USE_EMBEDDED_COMPILER
|
||||||
|
|
||||||
|
bool isCompilable() const override
|
||||||
|
{
|
||||||
|
bool can_be_compiled = true;
|
||||||
|
|
||||||
|
for (const auto & argument : this->argument_types)
|
||||||
|
can_be_compiled &= canBeNativeType(*argument);
|
||||||
|
|
||||||
|
auto return_type = getReturnType();
|
||||||
|
can_be_compiled &= canBeNativeType(*return_type);
|
||||||
|
|
||||||
|
return can_be_compiled;
|
||||||
|
}
|
||||||
|
|
||||||
|
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
|
||||||
|
{
|
||||||
|
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||||
|
|
||||||
|
auto * numerator_type = toNativeType<Numerator>(b);
|
||||||
|
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
|
||||||
|
|
||||||
|
auto * denominator_type = toNativeType<Denominator>(b);
|
||||||
|
auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, sizeof(Numerator));
|
||||||
|
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
|
||||||
|
|
||||||
|
b.CreateStore(llvm::Constant::getNullValue(numerator_type), numerator_ptr);
|
||||||
|
b.CreateStore(llvm::Constant::getNullValue(denominator_type), denominator_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
|
||||||
|
{
|
||||||
|
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||||
|
|
||||||
|
auto * numerator_type = toNativeType<Numerator>(b);
|
||||||
|
|
||||||
|
auto * numerator_dst_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, numerator_type->getPointerTo());
|
||||||
|
auto * numerator_dst_value = b.CreateLoad(numerator_type, numerator_dst_ptr);
|
||||||
|
|
||||||
|
auto * numerator_src_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, numerator_type->getPointerTo());
|
||||||
|
auto * numerator_src_value = b.CreateLoad(numerator_type, numerator_src_ptr);
|
||||||
|
|
||||||
|
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_dst_value, numerator_src_value) : b.CreateFAdd(numerator_dst_value, numerator_src_value);
|
||||||
|
b.CreateStore(numerator_result_value, numerator_dst_ptr);
|
||||||
|
|
||||||
|
auto * denominator_type = toNativeType<Denominator>(b);
|
||||||
|
|
||||||
|
auto * denominator_dst_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_dst_ptr, sizeof(Numerator));
|
||||||
|
auto * denominator_src_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_src_ptr, sizeof(Numerator));
|
||||||
|
|
||||||
|
auto * denominator_dst_ptr = b.CreatePointerCast(denominator_dst_offset_ptr, denominator_type->getPointerTo());
|
||||||
|
auto * denominator_src_ptr = b.CreatePointerCast(denominator_src_offset_ptr, denominator_type->getPointerTo());
|
||||||
|
|
||||||
|
auto * denominator_dst_value = b.CreateLoad(denominator_type, denominator_dst_ptr);
|
||||||
|
auto * denominator_src_value = b.CreateLoad(denominator_type, denominator_src_ptr);
|
||||||
|
|
||||||
|
auto * denominator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(denominator_src_value, denominator_dst_value) : b.CreateFAdd(denominator_src_value, denominator_dst_value);
|
||||||
|
b.CreateStore(denominator_result_value, denominator_dst_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
|
||||||
|
{
|
||||||
|
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||||
|
|
||||||
|
auto * numerator_type = toNativeType<Numerator>(b);
|
||||||
|
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
|
||||||
|
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
|
||||||
|
|
||||||
|
auto * denominator_type = toNativeType<Denominator>(b);
|
||||||
|
auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, sizeof(Numerator));
|
||||||
|
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
|
||||||
|
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
|
||||||
|
|
||||||
|
auto * double_numerator = nativeCast<Numerator>(b, numerator_value, b.getDoubleTy());
|
||||||
|
auto * double_denominator = nativeCast<Denominator>(b, denominator_value, b.getDoubleTy());
|
||||||
|
|
||||||
|
return b.CreateFDiv(double_numerator, double_denominator);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
UInt32 num_scale;
|
UInt32 num_scale;
|
||||||
UInt32 denom_scale;
|
UInt32 denom_scale;
|
||||||
@ -157,7 +242,11 @@ template <typename T>
|
|||||||
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>
|
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
using AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>::AggregateFunctionAvgBase;
|
using Base = AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>;
|
||||||
|
using Base::Base;
|
||||||
|
|
||||||
|
using Numerator = typename Base::Numerator;
|
||||||
|
using Denominator = typename Base::Denominator;
|
||||||
|
|
||||||
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final
|
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final
|
||||||
{
|
{
|
||||||
@ -169,36 +258,11 @@ public:
|
|||||||
|
|
||||||
#if USE_EMBEDDED_COMPILER
|
#if USE_EMBEDDED_COMPILER
|
||||||
|
|
||||||
bool isCompilable() const override
|
|
||||||
{
|
|
||||||
bool can_be_compiled = true;
|
|
||||||
|
|
||||||
for (const auto & argument : this->argument_types)
|
|
||||||
can_be_compiled &= canBeNativeType(*argument);
|
|
||||||
|
|
||||||
return can_be_compiled;
|
|
||||||
}
|
|
||||||
|
|
||||||
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
|
|
||||||
{
|
|
||||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
|
||||||
|
|
||||||
auto * numerator_type = toNativeType<AvgFieldType<T>>(b);
|
|
||||||
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
|
|
||||||
|
|
||||||
auto * denominator_type = toNativeType<UInt64>(b);
|
|
||||||
auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, sizeof(AvgFieldType<T>));
|
|
||||||
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
|
|
||||||
|
|
||||||
b.CreateStore(llvm::Constant::getNullValue(numerator_type), numerator_ptr);
|
|
||||||
b.CreateStore(llvm::Constant::getNullValue(denominator_type), denominator_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
|
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
|
||||||
{
|
{
|
||||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||||
|
|
||||||
auto * numerator_type = toNativeType<AvgFieldType<T>>(b);
|
auto * numerator_type = toNativeType<Numerator>(b);
|
||||||
|
|
||||||
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
|
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
|
||||||
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
|
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
|
||||||
@ -209,9 +273,9 @@ public:
|
|||||||
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_cast_to_numerator) : b.CreateFAdd(numerator_value, value_cast_to_numerator);
|
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_cast_to_numerator) : b.CreateFAdd(numerator_value, value_cast_to_numerator);
|
||||||
b.CreateStore(numerator_result_value, numerator_ptr);
|
b.CreateStore(numerator_result_value, numerator_ptr);
|
||||||
|
|
||||||
auto * denominator_type = toNativeType<UInt64>(b);
|
auto * denominator_type = toNativeType<Denominator>(b);
|
||||||
|
|
||||||
auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, sizeof(AvgFieldType<T>));
|
auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, sizeof(Numerator));
|
||||||
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
|
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
|
||||||
|
|
||||||
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
|
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
|
||||||
@ -220,55 +284,6 @@ public:
|
|||||||
b.CreateStore(denominator_value_updated, denominator_ptr);
|
b.CreateStore(denominator_value_updated, denominator_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
|
|
||||||
{
|
|
||||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
|
||||||
|
|
||||||
auto * numerator_type = toNativeType<AvgFieldType<T>>(b);
|
|
||||||
|
|
||||||
auto * numerator_dst_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, numerator_type->getPointerTo());
|
|
||||||
auto * numerator_dst_value = b.CreateLoad(numerator_type, numerator_dst_ptr);
|
|
||||||
|
|
||||||
auto * numerator_src_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, numerator_type->getPointerTo());
|
|
||||||
auto * numerator_src_value = b.CreateLoad(numerator_type, numerator_src_ptr);
|
|
||||||
|
|
||||||
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_dst_value, numerator_src_value) : b.CreateFAdd(numerator_dst_value, numerator_src_value);
|
|
||||||
b.CreateStore(numerator_result_value, numerator_dst_ptr);
|
|
||||||
|
|
||||||
auto * denominator_type = toNativeType<UInt64>(b);
|
|
||||||
|
|
||||||
auto * denominator_dst_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_dst_ptr, sizeof(AvgFieldType<T>));
|
|
||||||
auto * denominator_src_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_src_ptr, sizeof(AvgFieldType<T>));
|
|
||||||
|
|
||||||
auto * denominator_dst_ptr = b.CreatePointerCast(denominator_dst_offset_ptr, denominator_type->getPointerTo());
|
|
||||||
auto * denominator_src_ptr = b.CreatePointerCast(denominator_src_offset_ptr, denominator_type->getPointerTo());
|
|
||||||
|
|
||||||
auto * denominator_dst_value = b.CreateLoad(denominator_type, denominator_dst_ptr);
|
|
||||||
auto * denominator_src_value = b.CreateLoad(denominator_type, denominator_src_ptr);
|
|
||||||
|
|
||||||
auto * denominator_result_value = b.CreateAdd(denominator_src_value, denominator_dst_value);
|
|
||||||
b.CreateStore(denominator_result_value, denominator_dst_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
|
|
||||||
{
|
|
||||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
|
||||||
|
|
||||||
auto * numerator_type = toNativeType<AvgFieldType<T>>(b);
|
|
||||||
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
|
|
||||||
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
|
|
||||||
|
|
||||||
auto * denominator_type = toNativeType<UInt64>(b);
|
|
||||||
auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, sizeof(AvgFieldType<T>));
|
|
||||||
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
|
|
||||||
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
|
|
||||||
|
|
||||||
auto * double_numerator = nativeCast<AvgFieldType<T>>(b, numerator_value, b.getDoubleTy());
|
|
||||||
auto * double_denominator = nativeCast<UInt64>(b, denominator_value, b.getDoubleTy());
|
|
||||||
|
|
||||||
return b.CreateFDiv(double_numerator, double_denominator);
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
};
|
};
|
||||||
|
@ -28,19 +28,64 @@ public:
|
|||||||
MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>;
|
MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>;
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
|
||||||
using ValueT = MaxFieldType<Value, Weight>;
|
using Numerator = typename Base::Numerator;
|
||||||
|
using Denominator = typename Base::Denominator;
|
||||||
|
|
||||||
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||||
{
|
{
|
||||||
const auto& weights = static_cast<const DecimalOrVectorCol<Weight> &>(*columns[1]);
|
const auto& weights = static_cast<const DecimalOrVectorCol<Weight> &>(*columns[1]);
|
||||||
|
|
||||||
this->data(place).numerator += static_cast<ValueT>(
|
this->data(place).numerator += static_cast<Numerator>(
|
||||||
static_cast<const DecimalOrVectorCol<Value> &>(*columns[0]).getData()[row_num]) *
|
static_cast<const DecimalOrVectorCol<Value> &>(*columns[0]).getData()[row_num]) *
|
||||||
static_cast<ValueT>(weights.getData()[row_num]);
|
static_cast<Numerator>(weights.getData()[row_num]);
|
||||||
|
|
||||||
this->data(place).denominator += static_cast<AvgWeightedFieldType<Weight>>(weights.getData()[row_num]);
|
this->data(place).denominator += static_cast<Denominator>(weights.getData()[row_num]);
|
||||||
}
|
}
|
||||||
|
|
||||||
String getName() const override { return "avgWeighted"; }
|
String getName() const override { return "avgWeighted"; }
|
||||||
|
|
||||||
|
#if USE_EMBEDDED_COMPILER
|
||||||
|
|
||||||
|
bool isCompilable() const override
|
||||||
|
{
|
||||||
|
bool can_be_compiled = Base::isCompilable();
|
||||||
|
can_be_compiled &= canBeNativeType<Weight>();
|
||||||
|
|
||||||
|
return can_be_compiled;
|
||||||
|
}
|
||||||
|
|
||||||
|
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
|
||||||
|
{
|
||||||
|
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||||
|
|
||||||
|
auto * numerator_type = toNativeType<Numerator>(b);
|
||||||
|
|
||||||
|
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
|
||||||
|
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
|
||||||
|
|
||||||
|
const auto & argument = nativeCast(b, arguments_types[0], argument_values[0], numerator_type);
|
||||||
|
const auto & weight = nativeCast(b, arguments_types[1], argument_values[1], numerator_type);
|
||||||
|
|
||||||
|
llvm::Value * value_weight_multiplication = argument->getType()->isIntegerTy() ? b.CreateMul(argument, weight) : b.CreateFMul(argument, weight);
|
||||||
|
|
||||||
|
/// TODO: Fix accuracy
|
||||||
|
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_weight_multiplication) : b.CreateFAdd(numerator_value, value_weight_multiplication);
|
||||||
|
b.CreateStore(numerator_result_value, numerator_ptr);
|
||||||
|
|
||||||
|
auto * denominator_type = toNativeType<Denominator>(b);
|
||||||
|
|
||||||
|
auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, sizeof(Numerator));
|
||||||
|
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
|
||||||
|
|
||||||
|
auto * weight_cast_to_denominator = nativeCast(b, arguments_types[1], argument_values[1], numerator_type);
|
||||||
|
|
||||||
|
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
|
||||||
|
auto * denominator_value_updated = numerator_type->isIntegerTy() ? b.CreateAdd(denominator_value, weight_cast_to_denominator) : b.CreateFAdd(denominator_value, weight_cast_to_denominator);
|
||||||
|
|
||||||
|
b.CreateStore(denominator_value_updated, denominator_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -106,6 +106,16 @@ public:
|
|||||||
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
|
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if USE_EMBEDDED_COMPILER
|
||||||
|
|
||||||
|
bool isCompilable() const override
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped>
|
template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped>
|
||||||
@ -168,6 +178,15 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if USE_EMBEDDED_COMPILER
|
||||||
|
|
||||||
|
bool isCompilable() const override
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
using Base = AggregateFunctionNullBase<result_is_nullable, serialize_flag,
|
using Base = AggregateFunctionNullBase<result_is_nullable, serialize_flag,
|
||||||
AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>;
|
AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>;
|
||||||
|
@ -192,50 +192,6 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
AggregateFunctionPtr getNestedFunction() const override { return nested_function; }
|
AggregateFunctionPtr getNestedFunction() const override { return nested_function; }
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
/** There are two cases: for single argument and variadic.
|
|
||||||
* Code for single argument is much more efficient.
|
|
||||||
*/
|
|
||||||
template <bool result_is_nullable, bool serialize_flag>
|
|
||||||
class AggregateFunctionNullUnary final
|
|
||||||
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
|
|
||||||
AggregateFunctionNullUnary<result_is_nullable, serialize_flag>>
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
|
|
||||||
: AggregateFunctionNullBase<result_is_nullable, serialize_flag,
|
|
||||||
AggregateFunctionNullUnary<result_is_nullable, serialize_flag>>(std::move(nested_function_), arguments, params)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
|
||||||
{
|
|
||||||
const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]);
|
|
||||||
const IColumn * nested_column = &column->getNestedColumn();
|
|
||||||
if (!column->isNullAt(row_num))
|
|
||||||
{
|
|
||||||
this->setFlag(place);
|
|
||||||
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void addBatchSinglePlace(
|
|
||||||
size_t batch_size, AggregateDataPtr place, const IColumn ** columns, Arena * arena, ssize_t if_argument_pos = -1) const override
|
|
||||||
{
|
|
||||||
const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]);
|
|
||||||
const IColumn * nested_column = &column->getNestedColumn();
|
|
||||||
const UInt8 * null_map = column->getNullMapData().data();
|
|
||||||
|
|
||||||
this->nested_function->addBatchSinglePlaceNotNull(
|
|
||||||
batch_size, this->nestedPlace(place), &nested_column, null_map, arena, if_argument_pos);
|
|
||||||
|
|
||||||
if constexpr (result_is_nullable)
|
|
||||||
if (!memoryIsByte(null_map, batch_size, 1))
|
|
||||||
this->setFlag(place);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
#if USE_EMBEDDED_COMPILER
|
#if USE_EMBEDDED_COMPILER
|
||||||
|
|
||||||
@ -258,38 +214,6 @@ public:
|
|||||||
this->nested_function->compileCreate(b, aggregate_data_ptr_with_prefix_size_offset);
|
this->nested_function->compileCreate(b, aggregate_data_ptr_with_prefix_size_offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
|
|
||||||
{
|
|
||||||
|
|
||||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
|
||||||
|
|
||||||
const auto & nullable_type = arguments_types[0];
|
|
||||||
const auto & nullable_value = argument_values[0];
|
|
||||||
|
|
||||||
auto * wrapped_value = b.CreateExtractValue(nullable_value, {0});
|
|
||||||
auto * is_null_value = b.CreateExtractValue(nullable_value, {1});
|
|
||||||
|
|
||||||
auto * head = b.GetInsertBlock();
|
|
||||||
|
|
||||||
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
|
|
||||||
auto * if_null = llvm::BasicBlock::Create(head->getContext(), "if_null", head->getParent());
|
|
||||||
auto * if_not_null = llvm::BasicBlock::Create(head->getContext(), "if_not_null", head->getParent());
|
|
||||||
|
|
||||||
b.CreateCondBr(is_null_value, if_null, if_not_null);
|
|
||||||
|
|
||||||
b.SetInsertPoint(if_null);
|
|
||||||
b.CreateBr(join_block);
|
|
||||||
|
|
||||||
b.SetInsertPoint(if_not_null);
|
|
||||||
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
|
|
||||||
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
|
|
||||||
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, { removeNullable(nullable_type) }, { wrapped_value });
|
|
||||||
b.CreateBr(join_block);
|
|
||||||
|
|
||||||
b.SetInsertPoint(join_block);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
|
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
|
||||||
{
|
{
|
||||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||||
@ -357,6 +281,85 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/** There are two cases: for single argument and variadic.
|
||||||
|
* Code for single argument is much more efficient.
|
||||||
|
*/
|
||||||
|
template <bool result_is_nullable, bool serialize_flag>
|
||||||
|
class AggregateFunctionNullUnary final
|
||||||
|
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
|
||||||
|
AggregateFunctionNullUnary<result_is_nullable, serialize_flag>>
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
|
||||||
|
: AggregateFunctionNullBase<result_is_nullable, serialize_flag,
|
||||||
|
AggregateFunctionNullUnary<result_is_nullable, serialize_flag>>(std::move(nested_function_), arguments, params)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||||
|
{
|
||||||
|
const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]);
|
||||||
|
const IColumn * nested_column = &column->getNestedColumn();
|
||||||
|
if (!column->isNullAt(row_num))
|
||||||
|
{
|
||||||
|
this->setFlag(place);
|
||||||
|
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void addBatchSinglePlace(
|
||||||
|
size_t batch_size, AggregateDataPtr place, const IColumn ** columns, Arena * arena, ssize_t if_argument_pos = -1) const override
|
||||||
|
{
|
||||||
|
const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]);
|
||||||
|
const IColumn * nested_column = &column->getNestedColumn();
|
||||||
|
const UInt8 * null_map = column->getNullMapData().data();
|
||||||
|
|
||||||
|
this->nested_function->addBatchSinglePlaceNotNull(
|
||||||
|
batch_size, this->nestedPlace(place), &nested_column, null_map, arena, if_argument_pos);
|
||||||
|
|
||||||
|
if constexpr (result_is_nullable)
|
||||||
|
if (!memoryIsByte(null_map, batch_size, 1))
|
||||||
|
this->setFlag(place);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if USE_EMBEDDED_COMPILER
|
||||||
|
|
||||||
|
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
|
||||||
|
{
|
||||||
|
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||||
|
|
||||||
|
const auto & nullable_type = arguments_types[0];
|
||||||
|
const auto & nullable_value = argument_values[0];
|
||||||
|
|
||||||
|
auto * wrapped_value = b.CreateExtractValue(nullable_value, {0});
|
||||||
|
auto * is_null_value = b.CreateExtractValue(nullable_value, {1});
|
||||||
|
|
||||||
|
auto * head = b.GetInsertBlock();
|
||||||
|
|
||||||
|
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
|
||||||
|
auto * if_null = llvm::BasicBlock::Create(head->getContext(), "if_null", head->getParent());
|
||||||
|
auto * if_not_null = llvm::BasicBlock::Create(head->getContext(), "if_not_null", head->getParent());
|
||||||
|
|
||||||
|
b.CreateCondBr(is_null_value, if_null, if_not_null);
|
||||||
|
|
||||||
|
b.SetInsertPoint(if_null);
|
||||||
|
b.CreateBr(join_block);
|
||||||
|
|
||||||
|
b.SetInsertPoint(if_not_null);
|
||||||
|
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
|
||||||
|
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
|
||||||
|
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, { removeNullable(nullable_type) }, { wrapped_value });
|
||||||
|
b.CreateBr(join_block);
|
||||||
|
|
||||||
|
b.SetInsertPoint(join_block);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped>
|
template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped>
|
||||||
class AggregateFunctionNullVariadic final
|
class AggregateFunctionNullVariadic final
|
||||||
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
|
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
|
||||||
@ -405,6 +408,87 @@ public:
|
|||||||
this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
|
this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#if USE_EMBEDDED_COMPILER
|
||||||
|
|
||||||
|
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
|
||||||
|
{
|
||||||
|
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||||
|
|
||||||
|
size_t arguments_size = arguments_types.size();
|
||||||
|
|
||||||
|
DataTypes non_nullable_types;
|
||||||
|
std::vector<llvm::Value * > wrapped_values;
|
||||||
|
std::vector<llvm::Value * > is_null_values;
|
||||||
|
|
||||||
|
non_nullable_types.resize(arguments_size);
|
||||||
|
wrapped_values.resize(arguments_size);
|
||||||
|
is_null_values.resize(arguments_size);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < arguments_size; ++i)
|
||||||
|
{
|
||||||
|
const auto & argument_value = argument_values[i];
|
||||||
|
|
||||||
|
if (is_nullable[i])
|
||||||
|
{
|
||||||
|
auto * wrapped_value = b.CreateExtractValue(argument_value, {0});
|
||||||
|
|
||||||
|
if constexpr (null_is_skipped)
|
||||||
|
is_null_values[i] = b.CreateExtractValue(argument_value, {1});
|
||||||
|
|
||||||
|
wrapped_values[i] = wrapped_value;
|
||||||
|
non_nullable_types[i] = removeNullable(arguments_types[i]);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
wrapped_values[i] = argument_value;
|
||||||
|
non_nullable_types[i] = arguments_types[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (null_is_skipped)
|
||||||
|
{
|
||||||
|
auto * head = b.GetInsertBlock();
|
||||||
|
|
||||||
|
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
|
||||||
|
auto * if_null = llvm::BasicBlock::Create(head->getContext(), "if_null", head->getParent());
|
||||||
|
auto * if_not_null = llvm::BasicBlock::Create(head->getContext(), "if_not_null", head->getParent());
|
||||||
|
|
||||||
|
auto * values_have_null_ptr = b.CreateAlloca(b.getInt1Ty());
|
||||||
|
b.CreateStore(b.getInt1(false), values_have_null_ptr);
|
||||||
|
|
||||||
|
for (auto * is_null_value : is_null_values)
|
||||||
|
{
|
||||||
|
if (!is_null_value)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto * values_have_null = b.CreateLoad(b.getInt1Ty(), values_have_null_ptr);
|
||||||
|
b.CreateStore(b.CreateOr(values_have_null, is_null_value), values_have_null_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
b.CreateCondBr(b.CreateLoad(b.getInt1Ty(), values_have_null_ptr), if_null, if_not_null);
|
||||||
|
|
||||||
|
b.SetInsertPoint(if_null);
|
||||||
|
b.CreateBr(join_block);
|
||||||
|
|
||||||
|
b.SetInsertPoint(if_not_null);
|
||||||
|
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
|
||||||
|
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
|
||||||
|
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, arguments_types, wrapped_values);
|
||||||
|
b.CreateBr(join_block);
|
||||||
|
|
||||||
|
b.SetInsertPoint(join_block);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
|
||||||
|
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
|
||||||
|
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, non_nullable_types, wrapped_values);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
enum { MAX_ARGS = 8 };
|
enum { MAX_ARGS = 8 };
|
||||||
size_t number_of_arguments = 0;
|
size_t number_of_arguments = 0;
|
||||||
|
@ -48,6 +48,15 @@ public:
|
|||||||
|
|
||||||
String getName() const final { return "sumCount"; }
|
String getName() const final { return "sumCount"; }
|
||||||
|
|
||||||
|
#if USE_EMBEDDED_COMPILER
|
||||||
|
|
||||||
|
bool isCompilable() const override
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
UInt32 scale;
|
UInt32 scale;
|
||||||
};
|
};
|
||||||
|
@ -251,6 +251,30 @@ public:
|
|||||||
// of true window functions, so this hack-ish interface suffices.
|
// of true window functions, so this hack-ish interface suffices.
|
||||||
virtual bool isOnlyWindowFunction() const { return false; }
|
virtual bool isOnlyWindowFunction() const { return false; }
|
||||||
|
|
||||||
|
virtual String getDescription() const
|
||||||
|
{
|
||||||
|
String description;
|
||||||
|
|
||||||
|
description += getName();
|
||||||
|
description += '(';
|
||||||
|
|
||||||
|
for (const auto & argument_type : argument_types)
|
||||||
|
{
|
||||||
|
description += argument_type->getName();
|
||||||
|
description += ", ";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!argument_types.empty())
|
||||||
|
{
|
||||||
|
description.pop_back();
|
||||||
|
description.pop_back();
|
||||||
|
}
|
||||||
|
|
||||||
|
description += ')';
|
||||||
|
|
||||||
|
return description;
|
||||||
|
}
|
||||||
|
|
||||||
#if USE_EMBEDDED_COMPILER
|
#if USE_EMBEDDED_COMPILER
|
||||||
|
|
||||||
virtual bool isCompilable() const { return false; }
|
virtual bool isCompilable() const { return false; }
|
||||||
|
@ -80,6 +80,25 @@ static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder)
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Type>
|
||||||
|
static inline bool canBeNativeType()
|
||||||
|
{
|
||||||
|
if constexpr (std::is_same_v<Type, Int8> || std::is_same_v<Type, UInt8>)
|
||||||
|
return true;
|
||||||
|
else if constexpr (std::is_same_v<Type, Int16> || std::is_same_v<Type, UInt16>)
|
||||||
|
return true;
|
||||||
|
else if constexpr (std::is_same_v<Type, Int32> || std::is_same_v<Type, UInt32>)
|
||||||
|
return true;
|
||||||
|
else if constexpr (std::is_same_v<Type, Int64> || std::is_same_v<Type, UInt64>)
|
||||||
|
return true;
|
||||||
|
else if constexpr (std::is_same_v<Type, Float32>)
|
||||||
|
return true;
|
||||||
|
else if constexpr (std::is_same_v<Type, Float64>)
|
||||||
|
return true;
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static inline bool canBeNativeType(const IDataType & type)
|
static inline bool canBeNativeType(const IDataType & type)
|
||||||
{
|
{
|
||||||
WhichDataType data_type(type);
|
WhichDataType data_type(type);
|
||||||
@ -180,6 +199,37 @@ static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr
|
|||||||
return nativeCast(b, from, value, n_to);
|
return nativeCast(b, from, value, n_to);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline std::pair<llvm::Value *, llvm::Value *> nativeCastToCommon(llvm::IRBuilder<> & b, const DataTypePtr & lhs_type, llvm::Value * lhs, const DataTypePtr & rhs_type, llvm::Value * rhs)
|
||||||
|
{
|
||||||
|
llvm::Type * common;
|
||||||
|
|
||||||
|
bool lhs_is_signed = typeIsSigned(*lhs_type);
|
||||||
|
bool rhs_is_signed = typeIsSigned(*rhs_type);
|
||||||
|
|
||||||
|
if (lhs->getType()->isIntegerTy() && rhs->getType()->isIntegerTy())
|
||||||
|
{
|
||||||
|
/// if one integer has a sign bit, make sure the other does as well. llvm generates optimal code
|
||||||
|
/// (e.g. uses overflow flag on x86) for (word size + 1)-bit integer operations.
|
||||||
|
|
||||||
|
size_t lhs_bit_width = lhs->getType()->getIntegerBitWidth() + (!lhs_is_signed && rhs_is_signed);
|
||||||
|
size_t rhs_bit_width = rhs->getType()->getIntegerBitWidth() + (!rhs_is_signed && lhs_is_signed);
|
||||||
|
|
||||||
|
size_t max_bit_width = std::max(lhs_bit_width, rhs_bit_width);
|
||||||
|
common = b.getIntNTy(max_bit_width);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
/// TODO: Check
|
||||||
|
/// (double, float) or (double, int_N where N <= double's mantissa width) -> double
|
||||||
|
common = b.getDoubleTy();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto * cast_lhs_to_common = nativeCast(b, lhs_type, lhs, common);
|
||||||
|
auto * cast_rhs_to_common = nativeCast(b, rhs_type, rhs, common);
|
||||||
|
|
||||||
|
return std::make_pair(cast_lhs_to_common, cast_rhs_to_common);
|
||||||
|
}
|
||||||
|
|
||||||
static inline llvm::Constant * getColumnNativeValue(llvm::IRBuilderBase & builder, const DataTypePtr & column_type, const IColumn & column, size_t index)
|
static inline llvm::Constant * getColumnNativeValue(llvm::IRBuilderBase & builder, const DataTypePtr & column_type, const IColumn & column, size_t index)
|
||||||
{
|
{
|
||||||
if (const auto * constant = typeid_cast<const ColumnConst *>(&column))
|
if (const auto * constant = typeid_cast<const ColumnConst *>(&column))
|
||||||
|
@ -1265,23 +1265,7 @@ public:
|
|||||||
assert(2 == types.size() && 2 == values.size());
|
assert(2 == types.size() && 2 == values.size());
|
||||||
|
|
||||||
auto & b = static_cast<llvm::IRBuilder<> &>(builder);
|
auto & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||||
auto * x = values[0];
|
auto [x, y] = nativeCastToCommon(b, types[0], values[0], types[1], values[1]);
|
||||||
auto * y = values[1];
|
|
||||||
if (!types[0]->equals(*types[1]))
|
|
||||||
{
|
|
||||||
llvm::Type * common;
|
|
||||||
if (x->getType()->isIntegerTy() && y->getType()->isIntegerTy())
|
|
||||||
common = b.getIntNTy(std::max(
|
|
||||||
/// if one integer has a sign bit, make sure the other does as well. llvm generates optimal code
|
|
||||||
/// (e.g. uses overflow flag on x86) for (word size + 1)-bit integer operations.
|
|
||||||
x->getType()->getIntegerBitWidth() + (!typeIsSigned(*types[0]) && typeIsSigned(*types[1])),
|
|
||||||
y->getType()->getIntegerBitWidth() + (!typeIsSigned(*types[1]) && typeIsSigned(*types[0]))));
|
|
||||||
else
|
|
||||||
/// (double, float) or (double, int_N where N <= double's mantissa width) -> double
|
|
||||||
common = b.getDoubleTy();
|
|
||||||
x = nativeCast(b, types[0], x, common);
|
|
||||||
y = nativeCast(b, types[1], y, common);
|
|
||||||
}
|
|
||||||
auto * result = CompileOp<Op>::compile(b, x, y, typeIsSigned(*types[0]) || typeIsSigned(*types[1]));
|
auto * result = CompileOp<Op>::compile(b, x, y, typeIsSigned(*types[0]) || typeIsSigned(*types[1]));
|
||||||
return b.CreateSelect(result, b.getInt8(1), b.getInt8(0));
|
return b.CreateSelect(result, b.getInt8(1), b.getInt8(0));
|
||||||
}
|
}
|
||||||
|
@ -222,32 +222,6 @@ static CHJIT & getJITInstance()
|
|||||||
return jit;
|
return jit;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string dumpAggregateFunction(const IAggregateFunction * function)
|
|
||||||
{
|
|
||||||
std::string function_dump;
|
|
||||||
|
|
||||||
auto return_type_name = function->getReturnType()->getName();
|
|
||||||
|
|
||||||
function_dump += return_type_name;
|
|
||||||
function_dump += ' ';
|
|
||||||
function_dump += function->getName();
|
|
||||||
function_dump += '(';
|
|
||||||
|
|
||||||
const auto & argument_types = function->getArgumentTypes();
|
|
||||||
for (const auto & argument_type : argument_types)
|
|
||||||
{
|
|
||||||
function_dump += argument_type->getName();
|
|
||||||
function_dump += ',';
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!argument_types.empty())
|
|
||||||
function_dump.pop_back();
|
|
||||||
|
|
||||||
function_dump += ')';
|
|
||||||
|
|
||||||
return function_dump;
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
Aggregator::Aggregator(const Params & params_)
|
Aggregator::Aggregator(const Params & params_)
|
||||||
@ -317,7 +291,7 @@ void Aggregator::compileAggregateFunctions()
|
|||||||
|
|
||||||
std::vector<AggregateFunctionWithOffset> functions_to_compile;
|
std::vector<AggregateFunctionWithOffset> functions_to_compile;
|
||||||
size_t aggregate_instructions_size = 0;
|
size_t aggregate_instructions_size = 0;
|
||||||
std::string functions_dump;
|
String functions_description;
|
||||||
|
|
||||||
/// Add values to the aggregate functions.
|
/// Add values to the aggregate functions.
|
||||||
for (size_t i = 0; i < aggregate_functions.size(); ++i)
|
for (size_t i = 0; i < aggregate_functions.size(); ++i)
|
||||||
@ -333,11 +307,10 @@ void Aggregator::compileAggregateFunctions()
|
|||||||
.aggregate_data_offset = offset_of_aggregate_function
|
.aggregate_data_offset = offset_of_aggregate_function
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string function_dump = dumpAggregateFunction(function);
|
|
||||||
functions_dump += function_dump;
|
|
||||||
functions_dump += ' ';
|
|
||||||
|
|
||||||
functions_to_compile.emplace_back(std::move(function_to_compile));
|
functions_to_compile.emplace_back(std::move(function_to_compile));
|
||||||
|
|
||||||
|
functions_description += function->getDescription();
|
||||||
|
functions_description += ' ';
|
||||||
}
|
}
|
||||||
|
|
||||||
++aggregate_instructions_size;
|
++aggregate_instructions_size;
|
||||||
@ -354,20 +327,21 @@ void Aggregator::compileAggregateFunctions()
|
|||||||
|
|
||||||
std::lock_guard<std::mutex> lock(mtx);
|
std::lock_guard<std::mutex> lock(mtx);
|
||||||
|
|
||||||
auto it = aggregation_functions_dump_to_add_compiled.find(functions_dump);
|
auto it = aggregation_functions_dump_to_add_compiled.find(functions_description);
|
||||||
if (it != aggregation_functions_dump_to_add_compiled.end())
|
if (it != aggregation_functions_dump_to_add_compiled.end())
|
||||||
{
|
{
|
||||||
compiled_aggregate_functions = it->second;
|
compiled_aggregate_functions = it->second;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
LOG_TRACE(log, "Compile expression {}", functions_dump);
|
LOG_TRACE(log, "Compile expression {}", functions_description);
|
||||||
|
|
||||||
compiled_aggregate_functions = compileAggregateFunctons(getJITInstance(), functions_to_compile, functions_dump);
|
compiled_aggregate_functions = compileAggregateFunctons(getJITInstance(), functions_to_compile, functions_description);
|
||||||
aggregation_functions_dump_to_add_compiled[functions_dump] = compiled_aggregate_functions;
|
aggregation_functions_dump_to_add_compiled[functions_description] = compiled_aggregate_functions;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LOG_TRACE(log, "Use compiled expression {}", functions_description);
|
||||||
compiled_functions.emplace(std::move(compiled_aggregate_functions));
|
compiled_functions.emplace(std::move(compiled_aggregate_functions));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,12 @@
|
|||||||
|
Test unsigned integer values
|
||||||
|
0 1140 1140 1140 1140
|
||||||
|
1 1220 1220 1220 1220
|
||||||
|
2 1180 1180 1180 1180
|
||||||
|
Test signed integer values
|
||||||
|
0 1140 1140 1140 1140
|
||||||
|
1 1220 1220 1220 1220
|
||||||
|
2 1180 1180 1180 1180
|
||||||
|
Test float values
|
||||||
|
0 1140 1140
|
||||||
|
1 1220 1220
|
||||||
|
2 1180 1180
|
141
tests/queries/0_stateless/01896_jit_aggregation_function_if.sql
Normal file
141
tests/queries/0_stateless/01896_jit_aggregation_function_if.sql
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
SET compile_aggregate_expressions = 1;
|
||||||
|
SET min_count_to_compile_aggregate_expression = 0;
|
||||||
|
|
||||||
|
SELECT 'Test unsigned integer values';
|
||||||
|
|
||||||
|
DROP TABLE IF EXISTS test_table_unsigned_values;
|
||||||
|
CREATE TABLE test_table_unsigned_values
|
||||||
|
(
|
||||||
|
id UInt64,
|
||||||
|
|
||||||
|
value1 UInt8,
|
||||||
|
value2 UInt16,
|
||||||
|
value3 UInt32,
|
||||||
|
value4 UInt64,
|
||||||
|
|
||||||
|
predicate_value UInt8
|
||||||
|
) ENGINE=TinyLog;
|
||||||
|
|
||||||
|
INSERT INTO test_table_unsigned_values SELECT number % 3, number, number, number, number, if(number % 2 == 0, 1, 0) FROM system.numbers LIMIT 120;
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
sumIf(value1, predicate_value),
|
||||||
|
sumIf(value2, predicate_value),
|
||||||
|
sumIf(value3, predicate_value),
|
||||||
|
sumIf(value4, predicate_value)
|
||||||
|
FROM test_table_unsigned_values GROUP BY id ORDER BY id;
|
||||||
|
DROP TABLE test_table_unsigned_values;
|
||||||
|
|
||||||
|
SELECT 'Test signed integer values';
|
||||||
|
|
||||||
|
DROP TABLE IF EXISTS test_table_signed_values;
|
||||||
|
CREATE TABLE test_table_signed_values
|
||||||
|
(
|
||||||
|
id UInt64,
|
||||||
|
|
||||||
|
value1 Int8,
|
||||||
|
value2 Int16,
|
||||||
|
value3 Int32,
|
||||||
|
value4 Int64,
|
||||||
|
|
||||||
|
predicate_value UInt8
|
||||||
|
) ENGINE=TinyLog;
|
||||||
|
|
||||||
|
INSERT INTO test_table_signed_values SELECT number % 3, number, number, number, number, if(number % 2 == 0, 1, 0) FROM system.numbers LIMIT 120;
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
sumIf(value1, predicate_value),
|
||||||
|
sumIf(value2, predicate_value),
|
||||||
|
sumIf(value3, predicate_value),
|
||||||
|
sumIf(value4, predicate_value)
|
||||||
|
FROM test_table_signed_values GROUP BY id ORDER BY id;
|
||||||
|
DROP TABLE test_table_signed_values;
|
||||||
|
|
||||||
|
SELECT 'Test float values';
|
||||||
|
|
||||||
|
DROP TABLE IF EXISTS test_table_float_values;
|
||||||
|
CREATE TABLE test_table_float_values
|
||||||
|
(
|
||||||
|
id UInt64,
|
||||||
|
|
||||||
|
value1 Float32,
|
||||||
|
value2 Float64,
|
||||||
|
|
||||||
|
predicate_value UInt8
|
||||||
|
) ENGINE=TinyLog;
|
||||||
|
|
||||||
|
INSERT INTO test_table_float_values SELECT number % 3, number, number, if(number % 2 == 0, 1, 0) FROM system.numbers LIMIT 120;
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
sumIf(value1, predicate_value),
|
||||||
|
sumIf(value2, predicate_value)
|
||||||
|
FROM test_table_float_values GROUP BY id ORDER BY id;
|
||||||
|
DROP TABLE test_table_float_values;
|
||||||
|
|
||||||
|
-- SELECT 'Test nullable unsigned integer values';
|
||||||
|
|
||||||
|
-- DROP TABLE IF EXISTS test_table_nullable_unsigned_values;
|
||||||
|
-- CREATE TABLE test_table_nullable_unsigned_values
|
||||||
|
-- (
|
||||||
|
-- id UInt64,
|
||||||
|
|
||||||
|
-- value1 Nullable(UInt8),
|
||||||
|
-- value2 Nullable(UInt16),
|
||||||
|
-- value3 Nullable(UInt32),
|
||||||
|
-- value4 Nullable(UInt64)
|
||||||
|
-- ) ENGINE=TinyLog;
|
||||||
|
|
||||||
|
-- INSERT INTO test_table_nullable_unsigned_values SELECT number % 3, number, number, number, number FROM system.numbers LIMIT 120;
|
||||||
|
-- SELECT id, sum(value1), sum(value2), sum(value3), sum(value4) FROM test_table_nullable_unsigned_values GROUP BY id ORDER BY id;
|
||||||
|
-- DROP TABLE test_table_nullable_unsigned_values;
|
||||||
|
|
||||||
|
-- SELECT 'Test nullable signed integer values';
|
||||||
|
|
||||||
|
-- DROP TABLE IF EXISTS test_table_nullable_signed_values;
|
||||||
|
-- CREATE TABLE test_table_nullable_signed_values
|
||||||
|
-- (
|
||||||
|
-- id UInt64,
|
||||||
|
|
||||||
|
-- value1 Nullable(Int8),
|
||||||
|
-- value2 Nullable(Int16),
|
||||||
|
-- value3 Nullable(Int32),
|
||||||
|
-- value4 Nullable(Int64)
|
||||||
|
-- ) ENGINE=TinyLog;
|
||||||
|
|
||||||
|
-- INSERT INTO test_table_nullable_signed_values SELECT number % 3, number, number, number, number FROM system.numbers LIMIT 120;
|
||||||
|
-- SELECT id, sum(value1), sum(value2), sum(value3), sum(value4) FROM test_table_nullable_signed_values GROUP BY id ORDER BY id;
|
||||||
|
-- DROP TABLE test_table_nullable_signed_values;
|
||||||
|
|
||||||
|
-- SELECT 'Test nullable float values';
|
||||||
|
|
||||||
|
-- DROP TABLE IF EXISTS test_table_nullable_float_values;
|
||||||
|
-- CREATE TABLE test_table_nullable_float_values
|
||||||
|
-- (
|
||||||
|
-- id UInt64,
|
||||||
|
|
||||||
|
-- value1 Nullable(Float32),
|
||||||
|
-- value2 Nullable(Float64)
|
||||||
|
-- ) ENGINE=TinyLog;
|
||||||
|
|
||||||
|
-- INSERT INTO test_table_nullable_float_values SELECT number % 3, number, number FROM system.numbers LIMIT 120;
|
||||||
|
-- SELECT id, sum(value1), sum(value2) FROM test_table_nullable_float_values GROUP BY id ORDER BY id;
|
||||||
|
-- DROP TABLE test_table_nullable_float_values;
|
||||||
|
|
||||||
|
-- SELECT 'Test null specifics';
|
||||||
|
|
||||||
|
-- DROP TABLE IF EXISTS test_table_null_specifics;
|
||||||
|
-- CREATE TABLE test_table_null_specifics
|
||||||
|
-- (
|
||||||
|
-- id UInt64,
|
||||||
|
|
||||||
|
-- value1 Nullable(UInt64),
|
||||||
|
-- value2 Nullable(UInt64),
|
||||||
|
-- value3 Nullable(UInt64)
|
||||||
|
-- ) ENGINE=TinyLog;
|
||||||
|
|
||||||
|
-- INSERT INTO test_table_null_specifics VALUES (0, 1, 1, NULL);
|
||||||
|
-- INSERT INTO test_table_null_specifics VALUES (0, 2, NULL, NULL);
|
||||||
|
-- INSERT INTO test_table_null_specifics VALUES (0, 3, 3, NULL);
|
||||||
|
|
||||||
|
-- SELECT id, sum(value1), sum(value2), sum(value3) FROM test_table_null_specifics GROUP BY id ORDER BY id;
|
||||||
|
-- DROP TABLE IF EXISTS test_table_null_specifics;
|
@ -0,0 +1,26 @@
|
|||||||
|
Test unsigned integer values
|
||||||
|
0 nan nan nan nan
|
||||||
|
1 59.5 59.5 59.5 59.5
|
||||||
|
2 60.5 60.5 60.5 60.5
|
||||||
|
Test signed integer values
|
||||||
|
0 nan nan nan nan
|
||||||
|
1 59.5 59.5 59.5 59.5
|
||||||
|
2 60.5 60.5 60.5 60.5
|
||||||
|
Test float values
|
||||||
|
0 nan nan
|
||||||
|
1 59.5 59.5
|
||||||
|
2 60.5 60.5
|
||||||
|
Test nullable unsigned integer values
|
||||||
|
0 nan nan nan nan
|
||||||
|
1 59.5 59.5 59.5 59.5
|
||||||
|
2 60.5 60.5 60.5 60.5
|
||||||
|
Test nullable signed integer values
|
||||||
|
0 nan nan nan nan
|
||||||
|
1 59.5 59.5 59.5 59.5
|
||||||
|
2 60.5 60.5 60.5 60.5
|
||||||
|
Test nullable float values
|
||||||
|
0 nan nan
|
||||||
|
1 59.5 59.5
|
||||||
|
2 60.5 60.5
|
||||||
|
Test null specifics
|
||||||
|
0 2.3333333333333335 2.5 \N 2.5 2.5 \N
|
@ -0,0 +1,167 @@
|
|||||||
|
SET compile_aggregate_expressions = 1;
|
||||||
|
SET min_count_to_compile_aggregate_expression = 0;
|
||||||
|
|
||||||
|
SELECT 'Test unsigned integer values';
|
||||||
|
|
||||||
|
DROP TABLE IF EXISTS test_table_unsigned_values;
|
||||||
|
CREATE TABLE test_table_unsigned_values
|
||||||
|
(
|
||||||
|
id UInt64,
|
||||||
|
|
||||||
|
value1 UInt8,
|
||||||
|
value2 UInt16,
|
||||||
|
value3 UInt32,
|
||||||
|
value4 UInt64,
|
||||||
|
|
||||||
|
weight UInt64
|
||||||
|
) ENGINE=TinyLog;
|
||||||
|
|
||||||
|
INSERT INTO test_table_unsigned_values SELECT number % 3, number, number, number, number, number % 3 FROM system.numbers LIMIT 120;
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
avgWeighted(value1, weight),
|
||||||
|
avgWeighted(value2, weight),
|
||||||
|
avgWeighted(value3, weight),
|
||||||
|
avgWeighted(value4, weight)
|
||||||
|
FROM test_table_unsigned_values GROUP BY id ORDER BY id;
|
||||||
|
DROP TABLE test_table_unsigned_values;
|
||||||
|
|
||||||
|
SELECT 'Test signed integer values';
|
||||||
|
|
||||||
|
DROP TABLE IF EXISTS test_table_signed_values;
|
||||||
|
CREATE TABLE test_table_signed_values
|
||||||
|
(
|
||||||
|
id UInt64,
|
||||||
|
|
||||||
|
value1 Int8,
|
||||||
|
value2 Int16,
|
||||||
|
value3 Int32,
|
||||||
|
value4 Int64,
|
||||||
|
|
||||||
|
weight UInt64
|
||||||
|
) ENGINE=TinyLog;
|
||||||
|
|
||||||
|
INSERT INTO test_table_signed_values SELECT number % 3, number, number, number, number, number % 3 FROM system.numbers LIMIT 120;
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
avgWeighted(value1, weight),
|
||||||
|
avgWeighted(value2, weight),
|
||||||
|
avgWeighted(value3, weight),
|
||||||
|
avgWeighted(value4, weight)
|
||||||
|
FROM test_table_signed_values GROUP BY id ORDER BY id;
|
||||||
|
DROP TABLE test_table_signed_values;
|
||||||
|
|
||||||
|
SELECT 'Test float values';
|
||||||
|
|
||||||
|
DROP TABLE IF EXISTS test_table_float_values;
|
||||||
|
CREATE TABLE test_table_float_values
|
||||||
|
(
|
||||||
|
id UInt64,
|
||||||
|
|
||||||
|
value1 Float32,
|
||||||
|
value2 Float64,
|
||||||
|
|
||||||
|
weight UInt64
|
||||||
|
) ENGINE=TinyLog;
|
||||||
|
|
||||||
|
INSERT INTO test_table_float_values SELECT number % 3, number, number, number % 3 FROM system.numbers LIMIT 120;
|
||||||
|
SELECT id, avgWeighted(value1, weight), avgWeighted(value2, weight) FROM test_table_float_values GROUP BY id ORDER BY id;
|
||||||
|
DROP TABLE test_table_float_values;
|
||||||
|
|
||||||
|
SELECT 'Test nullable unsigned integer values';
|
||||||
|
|
||||||
|
DROP TABLE IF EXISTS test_table_nullable_unsigned_values;
|
||||||
|
CREATE TABLE test_table_nullable_unsigned_values
|
||||||
|
(
|
||||||
|
id UInt64,
|
||||||
|
|
||||||
|
value1 Nullable(UInt8),
|
||||||
|
value2 Nullable(UInt16),
|
||||||
|
value3 Nullable(UInt32),
|
||||||
|
value4 Nullable(UInt64),
|
||||||
|
|
||||||
|
weight UInt64
|
||||||
|
) ENGINE=TinyLog;
|
||||||
|
|
||||||
|
INSERT INTO test_table_nullable_unsigned_values SELECT number % 3, number, number, number, number, number % 3 FROM system.numbers LIMIT 120;
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
avgWeighted(value1, weight),
|
||||||
|
avgWeighted(value2, weight),
|
||||||
|
avgWeighted(value3, weight),
|
||||||
|
avgWeighted(value4, weight)
|
||||||
|
FROM test_table_nullable_unsigned_values GROUP BY id ORDER BY id;
|
||||||
|
DROP TABLE test_table_nullable_unsigned_values;
|
||||||
|
|
||||||
|
SELECT 'Test nullable signed integer values';
|
||||||
|
|
||||||
|
DROP TABLE IF EXISTS test_table_nullable_signed_values;
|
||||||
|
CREATE TABLE test_table_nullable_signed_values
|
||||||
|
(
|
||||||
|
id UInt64,
|
||||||
|
|
||||||
|
value1 Nullable(Int8),
|
||||||
|
value2 Nullable(Int16),
|
||||||
|
value3 Nullable(Int32),
|
||||||
|
value4 Nullable(Int64),
|
||||||
|
|
||||||
|
weight UInt64
|
||||||
|
) ENGINE=TinyLog;
|
||||||
|
|
||||||
|
|
||||||
|
INSERT INTO test_table_nullable_signed_values SELECT number % 3, number, number, number, number, number % 3 FROM system.numbers LIMIT 120;
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
avgWeighted(value1, weight),
|
||||||
|
avgWeighted(value2, weight),
|
||||||
|
avgWeighted(value3, weight),
|
||||||
|
avgWeighted(value4, weight)
|
||||||
|
FROM test_table_nullable_signed_values GROUP BY id ORDER BY id;
|
||||||
|
DROP TABLE test_table_nullable_signed_values;
|
||||||
|
|
||||||
|
SELECT 'Test nullable float values';
|
||||||
|
|
||||||
|
DROP TABLE IF EXISTS test_table_nullable_float_values;
|
||||||
|
CREATE TABLE test_table_nullable_float_values
|
||||||
|
(
|
||||||
|
id UInt64,
|
||||||
|
|
||||||
|
value1 Nullable(Float32),
|
||||||
|
value2 Nullable(Float64),
|
||||||
|
|
||||||
|
weight UInt64
|
||||||
|
) ENGINE=TinyLog;
|
||||||
|
|
||||||
|
INSERT INTO test_table_nullable_float_values SELECT number % 3, number, number, number % 3 FROM system.numbers LIMIT 120;
|
||||||
|
SELECT id, avgWeighted(value1, weight), avgWeighted(value2, weight) FROM test_table_nullable_float_values GROUP BY id ORDER BY id;
|
||||||
|
DROP TABLE test_table_nullable_float_values;
|
||||||
|
|
||||||
|
SELECT 'Test null specifics';
|
||||||
|
|
||||||
|
DROP TABLE IF EXISTS test_table_null_specifics;
|
||||||
|
CREATE TABLE test_table_null_specifics
|
||||||
|
(
|
||||||
|
id UInt64,
|
||||||
|
|
||||||
|
value1 Nullable(UInt64),
|
||||||
|
value2 Nullable(UInt64),
|
||||||
|
value3 Nullable(UInt64),
|
||||||
|
|
||||||
|
weight UInt64,
|
||||||
|
weight_nullable Nullable(UInt64)
|
||||||
|
) ENGINE=TinyLog;
|
||||||
|
|
||||||
|
INSERT INTO test_table_null_specifics VALUES (0, 1, 1, NULL, 1, 1);
|
||||||
|
INSERT INTO test_table_null_specifics VALUES (0, 2, NULL, NULL, 2, NULL);
|
||||||
|
INSERT INTO test_table_null_specifics VALUES (0, 3, 3, NULL, 3, 3);
|
||||||
|
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
avgWeighted(value1, weight),
|
||||||
|
avgWeighted(value2, weight),
|
||||||
|
avgWeighted(value3, weight),
|
||||||
|
avgWeighted(value1, weight_nullable),
|
||||||
|
avgWeighted(value2, weight_nullable),
|
||||||
|
avgWeighted(value3, weight_nullable)
|
||||||
|
FROM test_table_null_specifics GROUP BY id ORDER BY id;
|
||||||
|
DROP TABLE IF EXISTS test_table_null_specifics;
|
Loading…
Reference in New Issue
Block a user