2019-11-23 07:48:22 +00:00
|
|
|
#pragma once
|
|
|
|
|
2020-10-22 14:29:32 +00:00
|
|
|
#include <type_traits>
|
2019-11-23 07:48:22 +00:00
|
|
|
#include <AggregateFunctions/AggregateFunctionAvg.h>
|
|
|
|
|
|
|
|
namespace DB
|
|
|
|
{
|
2021-05-26 11:32:14 +00:00
|
|
|
/// Forward declaration only: no definition of Settings appears in the visible part of this header.
struct Settings;
|
|
|
|
|
2021-05-06 15:45:58 +00:00
|
|
|
template <typename T>
|
2021-09-10 11:49:22 +00:00
|
|
|
using AvgWeightedFieldType = std::conditional_t<is_decimal<T>,
|
2020-11-04 14:23:04 +00:00
|
|
|
std::conditional_t<std::is_same_v<T, Decimal256>, Decimal256, Decimal128>,
|
|
|
|
std::conditional_t<DecimalOrExtendedInt<T>,
|
|
|
|
Float64, // no way to do UInt128 * UInt128, better cast to Float64
|
|
|
|
NearestFieldType<T>>>;
|
|
|
|
|
2021-05-06 15:45:58 +00:00
|
|
|
template <typename T, typename U>
|
2020-11-04 15:23:29 +00:00
|
|
|
using MaxFieldType = std::conditional_t<(sizeof(AvgWeightedFieldType<T>) > sizeof(AvgWeightedFieldType<U>)),
|
|
|
|
AvgWeightedFieldType<T>, AvgWeightedFieldType<U>>;
|
2020-11-04 13:14:07 +00:00
|
|
|
|
2021-05-06 15:45:58 +00:00
|
|
|
template <typename Value, typename Weight>
|
2020-11-03 14:56:07 +00:00
|
|
|
class AggregateFunctionAvgWeighted final :
|
2020-11-04 14:23:04 +00:00
|
|
|
public AggregateFunctionAvgBase<
|
2020-11-04 15:23:29 +00:00
|
|
|
MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>
|
2019-11-23 07:48:22 +00:00
|
|
|
{
|
|
|
|
public:
|
2020-11-04 13:14:07 +00:00
|
|
|
using Base = AggregateFunctionAvgBase<
|
2020-11-04 15:23:29 +00:00
|
|
|
MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>;
|
2020-11-03 14:56:07 +00:00
|
|
|
using Base::Base;
|
2020-09-28 14:33:52 +00:00
|
|
|
|
2021-06-06 15:43:03 +00:00
|
|
|
using Numerator = typename Base::Numerator;
|
|
|
|
using Denominator = typename Base::Denominator;
|
2021-06-30 11:44:45 +00:00
|
|
|
using Fraction = typename Base::Fraction;
|
2020-11-04 14:23:04 +00:00
|
|
|
|
2021-02-01 17:12:12 +00:00
|
|
|
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
2019-11-23 07:48:22 +00:00
|
|
|
{
|
2021-09-10 21:28:43 +00:00
|
|
|
const auto& weights = static_cast<const ColumnVectorOrDecimal<Weight> &>(*columns[1]);
|
2020-11-03 14:56:07 +00:00
|
|
|
|
2021-06-06 15:43:03 +00:00
|
|
|
this->data(place).numerator += static_cast<Numerator>(
|
2021-09-10 21:28:43 +00:00
|
|
|
static_cast<const ColumnVectorOrDecimal<Value> &>(*columns[0]).getData()[row_num]) *
|
2021-06-06 15:43:03 +00:00
|
|
|
static_cast<Numerator>(weights.getData()[row_num]);
|
2020-11-04 14:23:04 +00:00
|
|
|
|
2021-06-06 15:43:03 +00:00
|
|
|
this->data(place).denominator += static_cast<Denominator>(weights.getData()[row_num]);
|
2019-11-23 07:48:22 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
String getName() const override { return "avgWeighted"; }
|
2021-06-06 15:43:03 +00:00
|
|
|
|
|
|
|
#if USE_EMBEDDED_COMPILER
|
|
|
|
|
|
|
|
bool isCompilable() const override
|
|
|
|
{
|
2021-06-26 16:26:32 +00:00
|
|
|
bool can_be_compiled = Base::isCompilable();
|
|
|
|
can_be_compiled &= canBeNativeType<Weight>();
|
2021-06-06 15:43:03 +00:00
|
|
|
|
2021-06-26 16:26:32 +00:00
|
|
|
return can_be_compiled;
|
2021-06-06 15:43:03 +00:00
|
|
|
}
|
|
|
|
|
2021-06-26 16:26:32 +00:00
|
|
|
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
|
|
|
|
{
|
|
|
|
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
2021-06-06 15:43:03 +00:00
|
|
|
|
2021-06-26 16:26:32 +00:00
|
|
|
auto * numerator_type = toNativeType<Numerator>(b);
|
2021-06-06 15:43:03 +00:00
|
|
|
|
2021-06-26 16:26:32 +00:00
|
|
|
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
|
|
|
|
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
|
2021-06-06 15:43:03 +00:00
|
|
|
|
2021-06-26 16:26:32 +00:00
|
|
|
auto * argument = nativeCast(b, arguments_types[0], argument_values[0], numerator_type);
|
|
|
|
auto * weight = nativeCast(b, arguments_types[1], argument_values[1], numerator_type);
|
2021-06-06 15:43:03 +00:00
|
|
|
|
2021-06-26 16:26:32 +00:00
|
|
|
llvm::Value * value_weight_multiplication = argument->getType()->isIntegerTy() ? b.CreateMul(argument, weight) : b.CreateFMul(argument, weight);
|
|
|
|
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_weight_multiplication) : b.CreateFAdd(numerator_value, value_weight_multiplication);
|
|
|
|
b.CreateStore(numerator_result_value, numerator_ptr);
|
2021-06-06 15:43:03 +00:00
|
|
|
|
2021-06-26 16:26:32 +00:00
|
|
|
auto * denominator_type = toNativeType<Denominator>(b);
|
2021-06-06 15:43:03 +00:00
|
|
|
|
2021-06-30 11:44:45 +00:00
|
|
|
static constexpr size_t denominator_offset = offsetof(Fraction, denominator);
|
2021-07-05 09:17:01 +00:00
|
|
|
auto * denominator_offset_ptr = b.CreateConstInBoundsGEP1_64(nullptr, aggregate_data_ptr, denominator_offset);
|
2021-06-26 16:26:32 +00:00
|
|
|
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
|
2021-06-06 15:43:03 +00:00
|
|
|
|
2021-06-30 11:44:45 +00:00
|
|
|
auto * weight_cast_to_denominator = nativeCast(b, arguments_types[1], argument_values[1], denominator_type);
|
2021-06-06 15:43:03 +00:00
|
|
|
|
2021-06-26 16:26:32 +00:00
|
|
|
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
|
|
|
|
auto * denominator_value_updated = denominator_type->isIntegerTy() ? b.CreateAdd(denominator_value, weight_cast_to_denominator) : b.CreateFAdd(denominator_value, weight_cast_to_denominator);
|
2021-06-06 15:43:03 +00:00
|
|
|
|
2021-06-26 16:26:32 +00:00
|
|
|
b.CreateStore(denominator_value_updated, denominator_ptr);
|
|
|
|
}
|
2021-06-06 15:43:03 +00:00
|
|
|
|
|
|
|
#endif
|
|
|
|
|
2019-11-23 07:48:22 +00:00
|
|
|
};
|
|
|
|
}
|