Compile AggregateFunctionAvgWeighted

This commit is contained in:
Maksim Kita 2021-06-06 18:43:03 +03:00
parent 56c1a4e447
commit 507d9405e2
13 changed files with 762 additions and 212 deletions

View File

@ -93,11 +93,13 @@ struct AvgFraction
* @tparam Derived When deriving from this class, use the child class name as in CRTP, e.g. * @tparam Derived When deriving from this class, use the child class name as in CRTP, e.g.
* class Self : Agg<char, bool, bool, Self>. * class Self : Agg<char, bool, bool, Self>.
*/ */
template <typename Numerator, typename Denominator, typename Derived> template <typename TNumerator, typename TDenominator, typename Derived>
class AggregateFunctionAvgBase : public class AggregateFunctionAvgBase : public
IAggregateFunctionDataHelper<AvgFraction<Numerator, Denominator>, Derived> IAggregateFunctionDataHelper<AvgFraction<TNumerator, TDenominator>, Derived>
{ {
public: public:
using Numerator = TNumerator;
using Denominator = TDenominator;
using Fraction = AvgFraction<Numerator, Denominator>; using Fraction = AvgFraction<Numerator, Denominator>;
using Base = IAggregateFunctionDataHelper<Fraction, Derived>; using Base = IAggregateFunctionDataHelper<Fraction, Derived>;
@ -143,6 +145,89 @@ public:
else else
assert_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).divide()); assert_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).divide());
} }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
bool can_be_compiled = true;
for (const auto & argument : this->argument_types)
can_be_compiled &= canBeNativeType(*argument);
auto return_type = getReturnType();
can_be_compiled &= canBeNativeType(*return_type);
return can_be_compiled;
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * numerator_type = toNativeType<Numerator>(b);
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
auto * denominator_type = toNativeType<Denominator>(b);
auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, sizeof(Numerator));
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
b.CreateStore(llvm::Constant::getNullValue(numerator_type), numerator_ptr);
b.CreateStore(llvm::Constant::getNullValue(denominator_type), denominator_ptr);
}
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * numerator_type = toNativeType<Numerator>(b);
auto * numerator_dst_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, numerator_type->getPointerTo());
auto * numerator_dst_value = b.CreateLoad(numerator_type, numerator_dst_ptr);
auto * numerator_src_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, numerator_type->getPointerTo());
auto * numerator_src_value = b.CreateLoad(numerator_type, numerator_src_ptr);
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_dst_value, numerator_src_value) : b.CreateFAdd(numerator_dst_value, numerator_src_value);
b.CreateStore(numerator_result_value, numerator_dst_ptr);
auto * denominator_type = toNativeType<Denominator>(b);
auto * denominator_dst_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_dst_ptr, sizeof(Numerator));
auto * denominator_src_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_src_ptr, sizeof(Numerator));
auto * denominator_dst_ptr = b.CreatePointerCast(denominator_dst_offset_ptr, denominator_type->getPointerTo());
auto * denominator_src_ptr = b.CreatePointerCast(denominator_src_offset_ptr, denominator_type->getPointerTo());
auto * denominator_dst_value = b.CreateLoad(denominator_type, denominator_dst_ptr);
auto * denominator_src_value = b.CreateLoad(denominator_type, denominator_src_ptr);
auto * denominator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(denominator_src_value, denominator_dst_value) : b.CreateFAdd(denominator_src_value, denominator_dst_value);
b.CreateStore(denominator_result_value, denominator_dst_ptr);
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * numerator_type = toNativeType<Numerator>(b);
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
auto * denominator_type = toNativeType<Denominator>(b);
auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, sizeof(Numerator));
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
auto * double_numerator = nativeCast<Numerator>(b, numerator_value, b.getDoubleTy());
auto * double_denominator = nativeCast<Denominator>(b, denominator_value, b.getDoubleTy());
return b.CreateFDiv(double_numerator, double_denominator);
}
#endif
private: private:
UInt32 num_scale; UInt32 num_scale;
UInt32 denom_scale; UInt32 denom_scale;
@ -157,7 +242,11 @@ template <typename T>
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>> class AggregateFunctionAvg final : public AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>
{ {
public: public:
using AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>::AggregateFunctionAvgBase; using Base = AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>;
using Base::Base;
using Numerator = typename Base::Numerator;
using Denominator = typename Base::Denominator;
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final
{ {
@ -169,36 +258,11 @@ public:
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
bool can_be_compiled = true;
for (const auto & argument : this->argument_types)
can_be_compiled &= canBeNativeType(*argument);
return can_be_compiled;
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * numerator_type = toNativeType<AvgFieldType<T>>(b);
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
auto * denominator_type = toNativeType<UInt64>(b);
auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, sizeof(AvgFieldType<T>));
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
b.CreateStore(llvm::Constant::getNullValue(numerator_type), numerator_ptr);
b.CreateStore(llvm::Constant::getNullValue(denominator_type), denominator_ptr);
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * numerator_type = toNativeType<AvgFieldType<T>>(b); auto * numerator_type = toNativeType<Numerator>(b);
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo()); auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr); auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
@ -209,9 +273,9 @@ public:
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_cast_to_numerator) : b.CreateFAdd(numerator_value, value_cast_to_numerator); auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_cast_to_numerator) : b.CreateFAdd(numerator_value, value_cast_to_numerator);
b.CreateStore(numerator_result_value, numerator_ptr); b.CreateStore(numerator_result_value, numerator_ptr);
auto * denominator_type = toNativeType<UInt64>(b); auto * denominator_type = toNativeType<Denominator>(b);
auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, sizeof(AvgFieldType<T>)); auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, sizeof(Numerator));
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo()); auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr); auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
@ -220,55 +284,6 @@ public:
b.CreateStore(denominator_value_updated, denominator_ptr); b.CreateStore(denominator_value_updated, denominator_ptr);
} }
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * numerator_type = toNativeType<AvgFieldType<T>>(b);
auto * numerator_dst_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, numerator_type->getPointerTo());
auto * numerator_dst_value = b.CreateLoad(numerator_type, numerator_dst_ptr);
auto * numerator_src_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, numerator_type->getPointerTo());
auto * numerator_src_value = b.CreateLoad(numerator_type, numerator_src_ptr);
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_dst_value, numerator_src_value) : b.CreateFAdd(numerator_dst_value, numerator_src_value);
b.CreateStore(numerator_result_value, numerator_dst_ptr);
auto * denominator_type = toNativeType<UInt64>(b);
auto * denominator_dst_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_dst_ptr, sizeof(AvgFieldType<T>));
auto * denominator_src_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_src_ptr, sizeof(AvgFieldType<T>));
auto * denominator_dst_ptr = b.CreatePointerCast(denominator_dst_offset_ptr, denominator_type->getPointerTo());
auto * denominator_src_ptr = b.CreatePointerCast(denominator_src_offset_ptr, denominator_type->getPointerTo());
auto * denominator_dst_value = b.CreateLoad(denominator_type, denominator_dst_ptr);
auto * denominator_src_value = b.CreateLoad(denominator_type, denominator_src_ptr);
auto * denominator_result_value = b.CreateAdd(denominator_src_value, denominator_dst_value);
b.CreateStore(denominator_result_value, denominator_dst_ptr);
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * numerator_type = toNativeType<AvgFieldType<T>>(b);
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
auto * denominator_type = toNativeType<UInt64>(b);
auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, sizeof(AvgFieldType<T>));
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
auto * double_numerator = nativeCast<AvgFieldType<T>>(b, numerator_value, b.getDoubleTy());
auto * double_denominator = nativeCast<UInt64>(b, denominator_value, b.getDoubleTy());
return b.CreateFDiv(double_numerator, double_denominator);
}
#endif #endif
}; };

View File

@ -28,19 +28,64 @@ public:
MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>; MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>;
using Base::Base; using Base::Base;
using ValueT = MaxFieldType<Value, Weight>; using Numerator = typename Base::Numerator;
using Denominator = typename Base::Denominator;
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{ {
const auto& weights = static_cast<const DecimalOrVectorCol<Weight> &>(*columns[1]); const auto& weights = static_cast<const DecimalOrVectorCol<Weight> &>(*columns[1]);
this->data(place).numerator += static_cast<ValueT>( this->data(place).numerator += static_cast<Numerator>(
static_cast<const DecimalOrVectorCol<Value> &>(*columns[0]).getData()[row_num]) * static_cast<const DecimalOrVectorCol<Value> &>(*columns[0]).getData()[row_num]) *
static_cast<ValueT>(weights.getData()[row_num]); static_cast<Numerator>(weights.getData()[row_num]);
this->data(place).denominator += static_cast<AvgWeightedFieldType<Weight>>(weights.getData()[row_num]); this->data(place).denominator += static_cast<Denominator>(weights.getData()[row_num]);
} }
String getName() const override { return "avgWeighted"; } String getName() const override { return "avgWeighted"; }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
bool can_be_compiled = Base::isCompilable();
can_be_compiled &= canBeNativeType<Weight>();
return can_be_compiled;
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * numerator_type = toNativeType<Numerator>(b);
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
const auto & argument = nativeCast(b, arguments_types[0], argument_values[0], numerator_type);
const auto & weight = nativeCast(b, arguments_types[1], argument_values[1], numerator_type);
llvm::Value * value_weight_multiplication = argument->getType()->isIntegerTy() ? b.CreateMul(argument, weight) : b.CreateFMul(argument, weight);
/// TODO: Fix accuracy
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_weight_multiplication) : b.CreateFAdd(numerator_value, value_weight_multiplication);
b.CreateStore(numerator_result_value, numerator_ptr);
auto * denominator_type = toNativeType<Denominator>(b);
auto * denominator_offset_ptr = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, sizeof(Numerator));
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
auto * weight_cast_to_denominator = nativeCast(b, arguments_types[1], argument_values[1], numerator_type);
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
auto * denominator_value_updated = numerator_type->isIntegerTy() ? b.CreateAdd(denominator_value, weight_cast_to_denominator) : b.CreateFAdd(denominator_value, weight_cast_to_denominator);
b.CreateStore(denominator_value_updated, denominator_ptr);
}
#endif
}; };
} }

View File

@ -106,6 +106,16 @@ public:
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
} }
} }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
return false;
}
#endif
}; };
template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped> template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped>
@ -168,6 +178,15 @@ public:
} }
} }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
return false;
}
#endif
private: private:
using Base = AggregateFunctionNullBase<result_is_nullable, serialize_flag, using Base = AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>; AggregateFunctionIfNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>;

View File

@ -192,50 +192,6 @@ public:
} }
AggregateFunctionPtr getNestedFunction() const override { return nested_function; } AggregateFunctionPtr getNestedFunction() const override { return nested_function; }
};
/** There are two cases: for single argument and variadic.
* Code for single argument is much more efficient.
*/
template <bool result_is_nullable, bool serialize_flag>
class AggregateFunctionNullUnary final
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullUnary<result_is_nullable, serialize_flag>>
{
public:
AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullUnary<result_is_nullable, serialize_flag>>(std::move(nested_function_), arguments, params)
{
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]);
const IColumn * nested_column = &column->getNestedColumn();
if (!column->isNullAt(row_num))
{
this->setFlag(place);
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
}
}
void addBatchSinglePlace(
size_t batch_size, AggregateDataPtr place, const IColumn ** columns, Arena * arena, ssize_t if_argument_pos = -1) const override
{
const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]);
const IColumn * nested_column = &column->getNestedColumn();
const UInt8 * null_map = column->getNullMapData().data();
this->nested_function->addBatchSinglePlaceNotNull(
batch_size, this->nestedPlace(place), &nested_column, null_map, arena, if_argument_pos);
if constexpr (result_is_nullable)
if (!memoryIsByte(null_map, batch_size, 1))
this->setFlag(place);
}
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER
@ -258,38 +214,6 @@ public:
this->nested_function->compileCreate(b, aggregate_data_ptr_with_prefix_size_offset); this->nested_function->compileCreate(b, aggregate_data_ptr_with_prefix_size_offset);
} }
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
const auto & nullable_type = arguments_types[0];
const auto & nullable_value = argument_values[0];
auto * wrapped_value = b.CreateExtractValue(nullable_value, {0});
auto * is_null_value = b.CreateExtractValue(nullable_value, {1});
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_null = llvm::BasicBlock::Create(head->getContext(), "if_null", head->getParent());
auto * if_not_null = llvm::BasicBlock::Create(head->getContext(), "if_not_null", head->getParent());
b.CreateCondBr(is_null_value, if_null, if_not_null);
b.SetInsertPoint(if_null);
b.CreateBr(join_block);
b.SetInsertPoint(if_not_null);
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, { removeNullable(nullable_type) }, { wrapped_value });
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{ {
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder); llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
@ -357,6 +281,85 @@ public:
}; };
/** There are two cases: for single argument and variadic.
* Code for single argument is much more efficient.
*/
template <bool result_is_nullable, bool serialize_flag>
class AggregateFunctionNullUnary final
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullUnary<result_is_nullable, serialize_flag>>
{
public:
AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullUnary<result_is_nullable, serialize_flag>>(std::move(nested_function_), arguments, params)
{
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]);
const IColumn * nested_column = &column->getNestedColumn();
if (!column->isNullAt(row_num))
{
this->setFlag(place);
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
}
}
void addBatchSinglePlace(
size_t batch_size, AggregateDataPtr place, const IColumn ** columns, Arena * arena, ssize_t if_argument_pos = -1) const override
{
const ColumnNullable * column = assert_cast<const ColumnNullable *>(columns[0]);
const IColumn * nested_column = &column->getNestedColumn();
const UInt8 * null_map = column->getNullMapData().data();
this->nested_function->addBatchSinglePlaceNotNull(
batch_size, this->nestedPlace(place), &nested_column, null_map, arena, if_argument_pos);
if constexpr (result_is_nullable)
if (!memoryIsByte(null_map, batch_size, 1))
this->setFlag(place);
}
#if USE_EMBEDDED_COMPILER
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
const auto & nullable_type = arguments_types[0];
const auto & nullable_value = argument_values[0];
auto * wrapped_value = b.CreateExtractValue(nullable_value, {0});
auto * is_null_value = b.CreateExtractValue(nullable_value, {1});
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_null = llvm::BasicBlock::Create(head->getContext(), "if_null", head->getParent());
auto * if_not_null = llvm::BasicBlock::Create(head->getContext(), "if_not_null", head->getParent());
b.CreateCondBr(is_null_value, if_null, if_not_null);
b.SetInsertPoint(if_null);
b.CreateBr(join_block);
b.SetInsertPoint(if_not_null);
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, { removeNullable(nullable_type) }, { wrapped_value });
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
#endif
};
template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped> template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped>
class AggregateFunctionNullVariadic final class AggregateFunctionNullVariadic final
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag, : public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
@ -405,6 +408,87 @@ public:
this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena); this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
} }
#if USE_EMBEDDED_COMPILER
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
size_t arguments_size = arguments_types.size();
DataTypes non_nullable_types;
std::vector<llvm::Value * > wrapped_values;
std::vector<llvm::Value * > is_null_values;
non_nullable_types.resize(arguments_size);
wrapped_values.resize(arguments_size);
is_null_values.resize(arguments_size);
for (size_t i = 0; i < arguments_size; ++i)
{
const auto & argument_value = argument_values[i];
if (is_nullable[i])
{
auto * wrapped_value = b.CreateExtractValue(argument_value, {0});
if constexpr (null_is_skipped)
is_null_values[i] = b.CreateExtractValue(argument_value, {1});
wrapped_values[i] = wrapped_value;
non_nullable_types[i] = removeNullable(arguments_types[i]);
}
else
{
wrapped_values[i] = argument_value;
non_nullable_types[i] = arguments_types[i];
}
}
if constexpr (null_is_skipped)
{
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_null = llvm::BasicBlock::Create(head->getContext(), "if_null", head->getParent());
auto * if_not_null = llvm::BasicBlock::Create(head->getContext(), "if_not_null", head->getParent());
auto * values_have_null_ptr = b.CreateAlloca(b.getInt1Ty());
b.CreateStore(b.getInt1(false), values_have_null_ptr);
for (auto * is_null_value : is_null_values)
{
if (!is_null_value)
continue;
auto * values_have_null = b.CreateLoad(b.getInt1Ty(), values_have_null_ptr);
b.CreateStore(b.CreateOr(values_have_null, is_null_value), values_have_null_ptr);
}
b.CreateCondBr(b.CreateLoad(b.getInt1Ty(), values_have_null_ptr), if_null, if_not_null);
b.SetInsertPoint(if_null);
b.CreateBr(join_block);
b.SetInsertPoint(if_not_null);
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, arguments_types, wrapped_values);
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
else
{
b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr);
auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstGEP1_32(nullptr, aggregate_data_ptr, this->prefix_size);
this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, non_nullable_types, wrapped_values);
}
}
#endif
private: private:
enum { MAX_ARGS = 8 }; enum { MAX_ARGS = 8 };
size_t number_of_arguments = 0; size_t number_of_arguments = 0;

View File

@ -48,6 +48,15 @@ public:
String getName() const final { return "sumCount"; } String getName() const final { return "sumCount"; }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
return false;
}
#endif
private: private:
UInt32 scale; UInt32 scale;
}; };

View File

@ -251,6 +251,30 @@ public:
// of true window functions, so this hack-ish interface suffices. // of true window functions, so this hack-ish interface suffices.
virtual bool isOnlyWindowFunction() const { return false; } virtual bool isOnlyWindowFunction() const { return false; }
virtual String getDescription() const
{
String description;
description += getName();
description += '(';
for (const auto & argument_type : argument_types)
{
description += argument_type->getName();
description += ", ";
}
if (!argument_types.empty())
{
description.pop_back();
description.pop_back();
}
description += ')';
return description;
}
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER
virtual bool isCompilable() const { return false; } virtual bool isCompilable() const { return false; }

View File

@ -80,6 +80,25 @@ static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder)
return nullptr; return nullptr;
} }
template <typename Type>
static inline bool canBeNativeType()
{
if constexpr (std::is_same_v<Type, Int8> || std::is_same_v<Type, UInt8>)
return true;
else if constexpr (std::is_same_v<Type, Int16> || std::is_same_v<Type, UInt16>)
return true;
else if constexpr (std::is_same_v<Type, Int32> || std::is_same_v<Type, UInt32>)
return true;
else if constexpr (std::is_same_v<Type, Int64> || std::is_same_v<Type, UInt64>)
return true;
else if constexpr (std::is_same_v<Type, Float32>)
return true;
else if constexpr (std::is_same_v<Type, Float64>)
return true;
return false;
}
static inline bool canBeNativeType(const IDataType & type) static inline bool canBeNativeType(const IDataType & type)
{ {
WhichDataType data_type(type); WhichDataType data_type(type);
@ -180,6 +199,37 @@ static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr
return nativeCast(b, from, value, n_to); return nativeCast(b, from, value, n_to);
} }
static inline std::pair<llvm::Value *, llvm::Value *> nativeCastToCommon(llvm::IRBuilder<> & b, const DataTypePtr & lhs_type, llvm::Value * lhs, const DataTypePtr & rhs_type, llvm::Value * rhs)
{
llvm::Type * common;
bool lhs_is_signed = typeIsSigned(*lhs_type);
bool rhs_is_signed = typeIsSigned(*rhs_type);
if (lhs->getType()->isIntegerTy() && rhs->getType()->isIntegerTy())
{
/// if one integer has a sign bit, make sure the other does as well. llvm generates optimal code
/// (e.g. uses overflow flag on x86) for (word size + 1)-bit integer operations.
size_t lhs_bit_width = lhs->getType()->getIntegerBitWidth() + (!lhs_is_signed && rhs_is_signed);
size_t rhs_bit_width = rhs->getType()->getIntegerBitWidth() + (!rhs_is_signed && lhs_is_signed);
size_t max_bit_width = std::max(lhs_bit_width, rhs_bit_width);
common = b.getIntNTy(max_bit_width);
}
else
{
/// TODO: Check
/// (double, float) or (double, int_N where N <= double's mantissa width) -> double
common = b.getDoubleTy();
}
auto * cast_lhs_to_common = nativeCast(b, lhs_type, lhs, common);
auto * cast_rhs_to_common = nativeCast(b, rhs_type, rhs, common);
return std::make_pair(cast_lhs_to_common, cast_rhs_to_common);
}
static inline llvm::Constant * getColumnNativeValue(llvm::IRBuilderBase & builder, const DataTypePtr & column_type, const IColumn & column, size_t index) static inline llvm::Constant * getColumnNativeValue(llvm::IRBuilderBase & builder, const DataTypePtr & column_type, const IColumn & column, size_t index)
{ {
if (const auto * constant = typeid_cast<const ColumnConst *>(&column)) if (const auto * constant = typeid_cast<const ColumnConst *>(&column))

View File

@ -1265,23 +1265,7 @@ public:
assert(2 == types.size() && 2 == values.size()); assert(2 == types.size() && 2 == values.size());
auto & b = static_cast<llvm::IRBuilder<> &>(builder); auto & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * x = values[0]; auto [x, y] = nativeCastToCommon(b, types[0], values[0], types[1], values[1]);
auto * y = values[1];
if (!types[0]->equals(*types[1]))
{
llvm::Type * common;
if (x->getType()->isIntegerTy() && y->getType()->isIntegerTy())
common = b.getIntNTy(std::max(
/// if one integer has a sign bit, make sure the other does as well. llvm generates optimal code
/// (e.g. uses overflow flag on x86) for (word size + 1)-bit integer operations.
x->getType()->getIntegerBitWidth() + (!typeIsSigned(*types[0]) && typeIsSigned(*types[1])),
y->getType()->getIntegerBitWidth() + (!typeIsSigned(*types[1]) && typeIsSigned(*types[0]))));
else
/// (double, float) or (double, int_N where N <= double's mantissa width) -> double
common = b.getDoubleTy();
x = nativeCast(b, types[0], x, common);
y = nativeCast(b, types[1], y, common);
}
auto * result = CompileOp<Op>::compile(b, x, y, typeIsSigned(*types[0]) || typeIsSigned(*types[1])); auto * result = CompileOp<Op>::compile(b, x, y, typeIsSigned(*types[0]) || typeIsSigned(*types[1]));
return b.CreateSelect(result, b.getInt8(1), b.getInt8(0)); return b.CreateSelect(result, b.getInt8(1), b.getInt8(0));
} }

View File

@ -222,32 +222,6 @@ static CHJIT & getJITInstance()
return jit; return jit;
} }
static std::string dumpAggregateFunction(const IAggregateFunction * function)
{
std::string function_dump;
auto return_type_name = function->getReturnType()->getName();
function_dump += return_type_name;
function_dump += ' ';
function_dump += function->getName();
function_dump += '(';
const auto & argument_types = function->getArgumentTypes();
for (const auto & argument_type : argument_types)
{
function_dump += argument_type->getName();
function_dump += ',';
}
if (!argument_types.empty())
function_dump.pop_back();
function_dump += ')';
return function_dump;
}
#endif #endif
Aggregator::Aggregator(const Params & params_) Aggregator::Aggregator(const Params & params_)
@ -317,7 +291,7 @@ void Aggregator::compileAggregateFunctions()
std::vector<AggregateFunctionWithOffset> functions_to_compile; std::vector<AggregateFunctionWithOffset> functions_to_compile;
size_t aggregate_instructions_size = 0; size_t aggregate_instructions_size = 0;
std::string functions_dump; String functions_description;
/// Add values to the aggregate functions. /// Add values to the aggregate functions.
for (size_t i = 0; i < aggregate_functions.size(); ++i) for (size_t i = 0; i < aggregate_functions.size(); ++i)
@ -333,11 +307,10 @@ void Aggregator::compileAggregateFunctions()
.aggregate_data_offset = offset_of_aggregate_function .aggregate_data_offset = offset_of_aggregate_function
}; };
std::string function_dump = dumpAggregateFunction(function);
functions_dump += function_dump;
functions_dump += ' ';
functions_to_compile.emplace_back(std::move(function_to_compile)); functions_to_compile.emplace_back(std::move(function_to_compile));
functions_description += function->getDescription();
functions_description += ' ';
} }
++aggregate_instructions_size; ++aggregate_instructions_size;
@ -354,20 +327,21 @@ void Aggregator::compileAggregateFunctions()
std::lock_guard<std::mutex> lock(mtx); std::lock_guard<std::mutex> lock(mtx);
auto it = aggregation_functions_dump_to_add_compiled.find(functions_dump); auto it = aggregation_functions_dump_to_add_compiled.find(functions_description);
if (it != aggregation_functions_dump_to_add_compiled.end()) if (it != aggregation_functions_dump_to_add_compiled.end())
{ {
compiled_aggregate_functions = it->second; compiled_aggregate_functions = it->second;
} }
else else
{ {
LOG_TRACE(log, "Compile expression {}", functions_dump); LOG_TRACE(log, "Compile expression {}", functions_description);
compiled_aggregate_functions = compileAggregateFunctons(getJITInstance(), functions_to_compile, functions_dump); compiled_aggregate_functions = compileAggregateFunctons(getJITInstance(), functions_to_compile, functions_description);
aggregation_functions_dump_to_add_compiled[functions_dump] = compiled_aggregate_functions; aggregation_functions_dump_to_add_compiled[functions_description] = compiled_aggregate_functions;
} }
} }
LOG_TRACE(log, "Use compiled expression {}", functions_description);
compiled_functions.emplace(std::move(compiled_aggregate_functions)); compiled_functions.emplace(std::move(compiled_aggregate_functions));
} }

View File

@ -0,0 +1,12 @@
Test unsigned integer values
0 1140 1140 1140 1140
1 1220 1220 1220 1220
2 1180 1180 1180 1180
Test signed integer values
0 1140 1140 1140 1140
1 1220 1220 1220 1220
2 1180 1180 1180 1180
Test float values
0 1140 1140
1 1220 1220
2 1180 1180

View File

@ -0,0 +1,141 @@
SET compile_aggregate_expressions = 1;
SET min_count_to_compile_aggregate_expression = 0;
SELECT 'Test unsigned integer values';
DROP TABLE IF EXISTS test_table_unsigned_values;
CREATE TABLE test_table_unsigned_values
(
id UInt64,
value1 UInt8,
value2 UInt16,
value3 UInt32,
value4 UInt64,
predicate_value UInt8
) ENGINE=TinyLog;
INSERT INTO test_table_unsigned_values SELECT number % 3, number, number, number, number, if(number % 2 == 0, 1, 0) FROM system.numbers LIMIT 120;
SELECT
id,
sumIf(value1, predicate_value),
sumIf(value2, predicate_value),
sumIf(value3, predicate_value),
sumIf(value4, predicate_value)
FROM test_table_unsigned_values GROUP BY id ORDER BY id;
DROP TABLE test_table_unsigned_values;
SELECT 'Test signed integer values';
DROP TABLE IF EXISTS test_table_signed_values;
CREATE TABLE test_table_signed_values
(
id UInt64,
value1 Int8,
value2 Int16,
value3 Int32,
value4 Int64,
predicate_value UInt8
) ENGINE=TinyLog;
INSERT INTO test_table_signed_values SELECT number % 3, number, number, number, number, if(number % 2 == 0, 1, 0) FROM system.numbers LIMIT 120;
SELECT
id,
sumIf(value1, predicate_value),
sumIf(value2, predicate_value),
sumIf(value3, predicate_value),
sumIf(value4, predicate_value)
FROM test_table_signed_values GROUP BY id ORDER BY id;
DROP TABLE test_table_signed_values;
SELECT 'Test float values';
DROP TABLE IF EXISTS test_table_float_values;
CREATE TABLE test_table_float_values
(
id UInt64,
value1 Float32,
value2 Float64,
predicate_value UInt8
) ENGINE=TinyLog;
INSERT INTO test_table_float_values SELECT number % 3, number, number, if(number % 2 == 0, 1, 0) FROM system.numbers LIMIT 120;
SELECT
id,
sumIf(value1, predicate_value),
sumIf(value2, predicate_value)
FROM test_table_float_values GROUP BY id ORDER BY id;
DROP TABLE test_table_float_values;
-- SELECT 'Test nullable unsigned integer values';
-- DROP TABLE IF EXISTS test_table_nullable_unsigned_values;
-- CREATE TABLE test_table_nullable_unsigned_values
-- (
-- id UInt64,
-- value1 Nullable(UInt8),
-- value2 Nullable(UInt16),
-- value3 Nullable(UInt32),
-- value4 Nullable(UInt64)
-- ) ENGINE=TinyLog;
-- INSERT INTO test_table_nullable_unsigned_values SELECT number % 3, number, number, number, number FROM system.numbers LIMIT 120;
-- SELECT id, sum(value1), sum(value2), sum(value3), sum(value4) FROM test_table_nullable_unsigned_values GROUP BY id ORDER BY id;
-- DROP TABLE test_table_nullable_unsigned_values;
-- SELECT 'Test nullable signed integer values';
-- DROP TABLE IF EXISTS test_table_nullable_signed_values;
-- CREATE TABLE test_table_nullable_signed_values
-- (
-- id UInt64,
-- value1 Nullable(Int8),
-- value2 Nullable(Int16),
-- value3 Nullable(Int32),
-- value4 Nullable(Int64)
-- ) ENGINE=TinyLog;
-- INSERT INTO test_table_nullable_signed_values SELECT number % 3, number, number, number, number FROM system.numbers LIMIT 120;
-- SELECT id, sum(value1), sum(value2), sum(value3), sum(value4) FROM test_table_nullable_signed_values GROUP BY id ORDER BY id;
-- DROP TABLE test_table_nullable_signed_values;
-- SELECT 'Test nullable float values';
-- DROP TABLE IF EXISTS test_table_nullable_float_values;
-- CREATE TABLE test_table_nullable_float_values
-- (
-- id UInt64,
-- value1 Nullable(Float32),
-- value2 Nullable(Float64)
-- ) ENGINE=TinyLog;
-- INSERT INTO test_table_nullable_float_values SELECT number % 3, number, number FROM system.numbers LIMIT 120;
-- SELECT id, sum(value1), sum(value2) FROM test_table_nullable_float_values GROUP BY id ORDER BY id;
-- DROP TABLE test_table_nullable_float_values;
-- SELECT 'Test null specifics';
-- DROP TABLE IF EXISTS test_table_null_specifics;
-- CREATE TABLE test_table_null_specifics
-- (
-- id UInt64,
-- value1 Nullable(UInt64),
-- value2 Nullable(UInt64),
-- value3 Nullable(UInt64)
-- ) ENGINE=TinyLog;
-- INSERT INTO test_table_null_specifics VALUES (0, 1, 1, NULL);
-- INSERT INTO test_table_null_specifics VALUES (0, 2, NULL, NULL);
-- INSERT INTO test_table_null_specifics VALUES (0, 3, 3, NULL);
-- SELECT id, sum(value1), sum(value2), sum(value3) FROM test_table_null_specifics GROUP BY id ORDER BY id;
-- DROP TABLE IF EXISTS test_table_null_specifics;

View File

@ -0,0 +1,26 @@
Test unsigned integer values
0 nan nan nan nan
1 59.5 59.5 59.5 59.5
2 60.5 60.5 60.5 60.5
Test signed integer values
0 nan nan nan nan
1 59.5 59.5 59.5 59.5
2 60.5 60.5 60.5 60.5
Test float values
0 nan nan
1 59.5 59.5
2 60.5 60.5
Test nullable unsigned integer values
0 nan nan nan nan
1 59.5 59.5 59.5 59.5
2 60.5 60.5 60.5 60.5
Test nullable signed integer values
0 nan nan nan nan
1 59.5 59.5 59.5 59.5
2 60.5 60.5 60.5 60.5
Test nullable float values
0 nan nan
1 59.5 59.5
2 60.5 60.5
Test null specifics
0 2.3333333333333335 2.5 \N 2.5 2.5 \N

View File

@ -0,0 +1,167 @@
SET compile_aggregate_expressions = 1;
SET min_count_to_compile_aggregate_expression = 0;
SELECT 'Test unsigned integer values';
DROP TABLE IF EXISTS test_table_unsigned_values;
CREATE TABLE test_table_unsigned_values
(
id UInt64,
value1 UInt8,
value2 UInt16,
value3 UInt32,
value4 UInt64,
weight UInt64
) ENGINE=TinyLog;
INSERT INTO test_table_unsigned_values SELECT number % 3, number, number, number, number, number % 3 FROM system.numbers LIMIT 120;
SELECT
id,
avgWeighted(value1, weight),
avgWeighted(value2, weight),
avgWeighted(value3, weight),
avgWeighted(value4, weight)
FROM test_table_unsigned_values GROUP BY id ORDER BY id;
DROP TABLE test_table_unsigned_values;
SELECT 'Test signed integer values';
DROP TABLE IF EXISTS test_table_signed_values;
CREATE TABLE test_table_signed_values
(
id UInt64,
value1 Int8,
value2 Int16,
value3 Int32,
value4 Int64,
weight UInt64
) ENGINE=TinyLog;
INSERT INTO test_table_signed_values SELECT number % 3, number, number, number, number, number % 3 FROM system.numbers LIMIT 120;
SELECT
id,
avgWeighted(value1, weight),
avgWeighted(value2, weight),
avgWeighted(value3, weight),
avgWeighted(value4, weight)
FROM test_table_signed_values GROUP BY id ORDER BY id;
DROP TABLE test_table_signed_values;
SELECT 'Test float values';
DROP TABLE IF EXISTS test_table_float_values;
CREATE TABLE test_table_float_values
(
id UInt64,
value1 Float32,
value2 Float64,
weight UInt64
) ENGINE=TinyLog;
INSERT INTO test_table_float_values SELECT number % 3, number, number, number % 3 FROM system.numbers LIMIT 120;
SELECT id, avgWeighted(value1, weight), avgWeighted(value2, weight) FROM test_table_float_values GROUP BY id ORDER BY id;
DROP TABLE test_table_float_values;
SELECT 'Test nullable unsigned integer values';
DROP TABLE IF EXISTS test_table_nullable_unsigned_values;
CREATE TABLE test_table_nullable_unsigned_values
(
id UInt64,
value1 Nullable(UInt8),
value2 Nullable(UInt16),
value3 Nullable(UInt32),
value4 Nullable(UInt64),
weight UInt64
) ENGINE=TinyLog;
INSERT INTO test_table_nullable_unsigned_values SELECT number % 3, number, number, number, number, number % 3 FROM system.numbers LIMIT 120;
SELECT
id,
avgWeighted(value1, weight),
avgWeighted(value2, weight),
avgWeighted(value3, weight),
avgWeighted(value4, weight)
FROM test_table_nullable_unsigned_values GROUP BY id ORDER BY id;
DROP TABLE test_table_nullable_unsigned_values;
SELECT 'Test nullable signed integer values';
DROP TABLE IF EXISTS test_table_nullable_signed_values;
CREATE TABLE test_table_nullable_signed_values
(
id UInt64,
value1 Nullable(Int8),
value2 Nullable(Int16),
value3 Nullable(Int32),
value4 Nullable(Int64),
weight UInt64
) ENGINE=TinyLog;
INSERT INTO test_table_nullable_signed_values SELECT number % 3, number, number, number, number, number % 3 FROM system.numbers LIMIT 120;
SELECT
id,
avgWeighted(value1, weight),
avgWeighted(value2, weight),
avgWeighted(value3, weight),
avgWeighted(value4, weight)
FROM test_table_nullable_signed_values GROUP BY id ORDER BY id;
DROP TABLE test_table_nullable_signed_values;
SELECT 'Test nullable float values';
DROP TABLE IF EXISTS test_table_nullable_float_values;
CREATE TABLE test_table_nullable_float_values
(
id UInt64,
value1 Nullable(Float32),
value2 Nullable(Float64),
weight UInt64
) ENGINE=TinyLog;
INSERT INTO test_table_nullable_float_values SELECT number % 3, number, number, number % 3 FROM system.numbers LIMIT 120;
SELECT id, avgWeighted(value1, weight), avgWeighted(value2, weight) FROM test_table_nullable_float_values GROUP BY id ORDER BY id;
DROP TABLE test_table_nullable_float_values;
SELECT 'Test null specifics';
DROP TABLE IF EXISTS test_table_null_specifics;
CREATE TABLE test_table_null_specifics
(
id UInt64,
value1 Nullable(UInt64),
value2 Nullable(UInt64),
value3 Nullable(UInt64),
weight UInt64,
weight_nullable Nullable(UInt64)
) ENGINE=TinyLog;
INSERT INTO test_table_null_specifics VALUES (0, 1, 1, NULL, 1, 1);
INSERT INTO test_table_null_specifics VALUES (0, 2, NULL, NULL, 2, NULL);
INSERT INTO test_table_null_specifics VALUES (0, 3, 3, NULL, 3, 3);
SELECT
id,
avgWeighted(value1, weight),
avgWeighted(value2, weight),
avgWeighted(value3, weight),
avgWeighted(value1, weight_nullable),
avgWeighted(value2, weight_nullable),
avgWeighted(value3, weight_nullable)
FROM test_table_null_specifics GROUP BY id ORDER BY id;
DROP TABLE IF EXISTS test_table_null_specifics;