ClickHouse/src/AggregateFunctions/AggregateFunctionAvg.h
2021-02-01 20:12:12 +03:00

160 lines
6.2 KiB
C++

#pragma once
#include <type_traits>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include "Core/DecimalFunctions.h"
namespace DB
{
template <typename T>
using DecimalOrVectorCol = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
template <typename T> constexpr bool DecimalOrExtendedInt =
IsDecimalNumber<T>
|| std::is_same_v<T, Int128>
|| std::is_same_v<T, Int256>
|| std::is_same_v<T, UInt128>
|| std::is_same_v<T, UInt256>;
/**
* Helper class to encapsulate values conversion for avg and avgWeighted.
*/
template <typename Numerator, typename Denominator>
struct AvgFraction
{
Numerator numerator{0};
Denominator denominator{0};
/// Allow division by zero as sometimes we need to return NaN.
/// Invoked only is either Numerator or Denominator are Decimal.
Float64 NO_SANITIZE_UNDEFINED divideIfAnyDecimal(UInt32 num_scale, UInt32 denom_scale [[maybe_unused]]) const
{
if constexpr (IsDecimalNumber<Numerator> && IsDecimalNumber<Denominator>)
{
// According to the docs, num(S1) / denom(S2) would have scale S1
if constexpr (std::is_same_v<Numerator, Decimal256> && std::is_same_v<Denominator, Decimal128>)
///Special case as Decimal256 / Decimal128 = compile error (as Decimal128 is not parametrized by a wide
///int), but an __int128 instead
return DecimalUtils::convertTo<Float64>(
numerator / (denominator.template convertTo<Decimal256>()), num_scale);
else
return DecimalUtils::convertTo<Float64>(numerator / denominator, num_scale);
}
/// Numerator is always casted to Float64 to divide correctly if the denominator is not Float64.
Float64 num_converted;
if constexpr (IsDecimalNumber<Numerator>)
num_converted = DecimalUtils::convertTo<Float64>(numerator, num_scale);
else
num_converted = static_cast<Float64>(numerator); /// all other types, including extended integral.
std::conditional_t<DecimalOrExtendedInt<Denominator>,
Float64, Denominator> denom_converted;
if constexpr (IsDecimalNumber<Denominator>)
denom_converted = DecimalUtils::convertTo<Float64>(denominator, denom_scale);
else if constexpr (DecimalOrExtendedInt<Denominator>)
/// no way to divide Float64 and extended integral type without an explicit cast.
denom_converted = static_cast<Float64>(denominator);
else
denom_converted = denominator; /// can divide on float, no cast required.
return num_converted / denom_converted;
}
Float64 NO_SANITIZE_UNDEFINED divide() const
{
if constexpr (DecimalOrExtendedInt<Denominator>) /// if extended int
return static_cast<Float64>(numerator) / static_cast<Float64>(denominator);
else
return static_cast<Float64>(numerator) / denominator;
}
};
/**
* @tparam Derived When deriving from this class, use the child class name as in CRTP, e.g.
* class Self : Agg<char, bool, bool, Self>.
*/
template <typename Numerator, typename Denominator, typename Derived>
class AggregateFunctionAvgBase : public
IAggregateFunctionDataHelper<AvgFraction<Numerator, Denominator>, Derived>
{
public:
using Fraction = AvgFraction<Numerator, Denominator>;
using Base = IAggregateFunctionDataHelper<Fraction, Derived>;
explicit AggregateFunctionAvgBase(const DataTypes & argument_types_,
UInt32 num_scale_ = 0, UInt32 denom_scale_ = 0)
: Base(argument_types_, {}), num_scale(num_scale_), denom_scale(denom_scale_) {}
DataTypePtr getReturnType() const final { return std::make_shared<DataTypeNumber<Float64>>(); }
void NO_SANITIZE_UNDEFINED merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
this->data(place).numerator += this->data(rhs).numerator;
this->data(place).denominator += this->data(rhs).denominator;
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf) const override
{
writeBinary(this->data(place).numerator, buf);
if constexpr (std::is_unsigned_v<Denominator>)
writeVarUInt(this->data(place).denominator, buf);
else /// Floating point denominator type can be used
writeBinary(this->data(place).denominator, buf);
}
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, Arena *) const override
{
readBinary(this->data(place).numerator, buf);
if constexpr (std::is_unsigned_v<Denominator>)
readVarUInt(this->data(place).denominator, buf);
else /// Floating point denominator type can be used
readBinary(this->data(place).denominator, buf);
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
if constexpr (IsDecimalNumber<Numerator> || IsDecimalNumber<Denominator>)
assert_cast<ColumnVector<Float64> &>(to).getData().push_back(
this->data(place).divideIfAnyDecimal(num_scale, denom_scale));
else
assert_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).divide());
}
private:
UInt32 num_scale;
UInt32 denom_scale;
};
template <typename T>
using AvgFieldType = std::conditional_t<IsDecimalNumber<T>,
std::conditional_t<std::is_same_v<T, Decimal256>, Decimal256, Decimal128>,
NearestFieldType<T>>;
template <typename T>
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>
{
public:
using AggregateFunctionAvgBase<AvgFieldType<T>, UInt64, AggregateFunctionAvg<T>>::AggregateFunctionAvgBase;
void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final
{
this->data(place).numerator += static_cast<const DecimalOrVectorCol<T> &>(*columns[0]).getData()[row_num];
++this->data(place).denominator;
}
String getName() const final { return "avg"; }
};
}