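// ClickHouse/src/AggregateFunctions/AggregateFunctionAvgWeighted.h
//
// 92 lines
// 4.0 KiB
// C++
// Raw Normal View History
//
// NOTE: the timestamp lines in this file appear to be residue of a web "blame" view.
// 2019-11-23 07:48:22 +00:00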
#pragma once
#include <type_traits>
// 2019-11-23 07:48:22 +00:00
#include <AggregateFunctions/AggregateFunctionAvg.h>
namespace DB
{
struct Settings;
// 2021-05-06 15:45:58 +00:00
template <typename T>
2021-09-10 11:49:22 +00:00
using AvgWeightedFieldType = std::conditional_t<is_decimal<T>,
std::conditional_t<std::is_same_v<T, Decimal256>, Decimal256, Decimal128>,
std::conditional_t<DecimalOrExtendedInt<T>,
Float64, // no way to do UInt128 * UInt128, better cast to Float64
NearestFieldType<T>>>;
// 2021-05-06 15:45:58 +00:00
template <typename T, typename U>
2020-11-04 15:23:29 +00:00
using MaxFieldType = std::conditional_t<(sizeof(AvgWeightedFieldType<T>) > sizeof(AvgWeightedFieldType<U>)),
AvgWeightedFieldType<T>, AvgWeightedFieldType<U>>;
// 2020-11-04 13:14:07 +00:00
// 2021-05-06 15:45:58 +00:00
template <typename Value, typename Weight>
2020-11-03 14:56:07 +00:00
class AggregateFunctionAvgWeighted final :
public AggregateFunctionAvgBase<
2020-11-04 15:23:29 +00:00
MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>
2019-11-23 07:48:22 +00:00
{
public:
2020-11-04 13:14:07 +00:00
using Base = AggregateFunctionAvgBase<
2020-11-04 15:23:29 +00:00
MaxFieldType<Value, Weight>, AvgWeightedFieldType<Weight>, AggregateFunctionAvgWeighted<Value, Weight>>;
2020-11-03 14:56:07 +00:00
using Base::Base;
2020-09-28 14:33:52 +00:00
2021-06-06 15:43:03 +00:00
using Numerator = typename Base::Numerator;
using Denominator = typename Base::Denominator;
2021-06-30 11:44:45 +00:00
using Fraction = typename Base::Fraction;
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
2019-11-23 07:48:22 +00:00
{
2021-09-10 21:28:43 +00:00
const auto& weights = static_cast<const ColumnVectorOrDecimal<Weight> &>(*columns[1]);
2020-11-03 14:56:07 +00:00
2021-06-06 15:43:03 +00:00
this->data(place).numerator += static_cast<Numerator>(
2021-09-10 21:28:43 +00:00
static_cast<const ColumnVectorOrDecimal<Value> &>(*columns[0]).getData()[row_num]) *
2021-06-06 15:43:03 +00:00
static_cast<Numerator>(weights.getData()[row_num]);
2021-06-06 15:43:03 +00:00
this->data(place).denominator += static_cast<Denominator>(weights.getData()[row_num]);
2019-11-23 07:48:22 +00:00
}
String getName() const override { return "avgWeighted"; }
2021-06-06 15:43:03 +00:00
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
bool can_be_compiled = Base::isCompilable();
can_be_compiled &= canBeNativeType<Weight>();
2021-06-06 15:43:03 +00:00
return can_be_compiled;
2021-06-06 15:43:03 +00:00
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
2021-06-06 15:43:03 +00:00
auto * numerator_type = toNativeType<Numerator>(b);
2021-06-06 15:43:03 +00:00
auto * numerator_ptr = b.CreatePointerCast(aggregate_data_ptr, numerator_type->getPointerTo());
auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr);
2021-06-06 15:43:03 +00:00
auto * argument = nativeCast(b, arguments_types[0], argument_values[0], numerator_type);
auto * weight = nativeCast(b, arguments_types[1], argument_values[1], numerator_type);
2021-06-06 15:43:03 +00:00
llvm::Value * value_weight_multiplication = argument->getType()->isIntegerTy() ? b.CreateMul(argument, weight) : b.CreateFMul(argument, weight);
auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_weight_multiplication) : b.CreateFAdd(numerator_value, value_weight_multiplication);
b.CreateStore(numerator_result_value, numerator_ptr);
2021-06-06 15:43:03 +00:00
auto * denominator_type = toNativeType<Denominator>(b);
2021-06-06 15:43:03 +00:00
2021-06-30 11:44:45 +00:00
static constexpr size_t denominator_offset = offsetof(Fraction, denominator);
auto * denominator_offset_ptr = b.CreateConstInBoundsGEP1_64(nullptr, aggregate_data_ptr, denominator_offset);
auto * denominator_ptr = b.CreatePointerCast(denominator_offset_ptr, denominator_type->getPointerTo());
2021-06-06 15:43:03 +00:00
2021-06-30 11:44:45 +00:00
auto * weight_cast_to_denominator = nativeCast(b, arguments_types[1], argument_values[1], denominator_type);
2021-06-06 15:43:03 +00:00
auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr);
auto * denominator_value_updated = denominator_type->isIntegerTy() ? b.CreateAdd(denominator_value, weight_cast_to_denominator) : b.CreateFAdd(denominator_value, weight_cast_to_denominator);
2021-06-06 15:43:03 +00:00
b.CreateStore(denominator_value_updated, denominator_ptr);
}
2021-06-06 15:43:03 +00:00
#endif
2019-11-23 07:48:22 +00:00
};
}