ClickHouse/src/AggregateFunctions/AggregateFunctionAvg.h
2020-11-03 17:56:07 +03:00

110 lines
4.0 KiB
C++

#pragma once
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <AggregateFunctions/IAggregateFunction.h>
namespace DB
{
/// Aggregate state holding a running ratio numerator / denominator.
/// The numerator is always accumulated as Float64; the denominator type is chosen by the user of the struct.
template <typename Denominator>
struct RationalFraction
{
    Float64 numerator{0};
    Denominator denominator{0};

    /// Division by zero is deliberately permitted. For example, a state that received no rows
    /// is 0 / 0 and must produce NaN, hence the NO_SANITIZE_UNDEFINED annotation.
    Float64 NO_SANITIZE_UNDEFINED result() const
    {
        return numerator / denominator;
    }
};
/// True for Decimal types and for the 128/256-bit signed and unsigned extended integers.
/// AggregateFunctionAvg::add gives these argument types a dedicated conversion branch
/// instead of adding the raw column value directly to the Float64 numerator.
template <class T>
constexpr bool DecimalOrExtendedInt = IsDecimalNumber<T>
    || std::is_same_v<T, Int128> || std::is_same_v<T, Int256>
    || std::is_same_v<T, UInt128> || std::is_same_v<T, UInt256>;
/**
 * After discussion, the simplest approach chosen was to cast both the numerator and the denominator
 * columns to Float64. The alternative would be template machinery that picks a suitable numerator
 * type, denominator type and result type for extended integral types (e.g. UInt128) and for Decimals
 * (which are a mess in themselves). That alternative is also of limited value, because Decimals
 * are not currently used natively in functions like avg.
 *
 * The denominator type can be specified explicitly for the sake of avg: avg uses an integral
 * denominator, since it is simply the number of rows aggregated.
 *
 * @tparam Denominator Type of the accumulated denominator (e.g. UInt64 for a row count).
 * @tparam Derived When deriving from this class, pass the child class itself, as in CRTP, e.g.
 *         class Self : public AggregateFunctionAvgBase<UInt64, Self>.
 */
template <class Denominator, class Derived>
class AggregateFunctionAvgBase : public IAggregateFunctionDataHelper<RationalFraction<Denominator>, Derived>
{
public:
    using Fraction = RationalFraction<Denominator>;
    using Base = IAggregateFunctionDataHelper<Fraction, Derived>;

    /// The function takes no parameters, so an empty parameter list is forwarded to the helper.
    explicit AggregateFunctionAvgBase(const DataTypes & argument_types_)
        : Base(argument_types_, {})
    {
    }

    /// The result is always Float64, whatever the argument type.
    DataTypePtr getReturnType() const override
    {
        return std::make_shared<DataTypeNumber<Float64>>();
    }

    /// Combines two partial states by summing numerators and denominators independently.
    void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
    {
        auto & target = this->data(place);
        const auto & source = this->data(rhs);

        target.numerator += source.numerator;
        target.denominator += source.denominator;
    }

    /// Writes the numerator as raw binary, then the denominator.
    /// Must stay symmetric with deserialize().
    void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
    {
        const auto & state = this->data(place);
        writeBinary(state.numerator, buf);

        /// Unsigned denominators use the compact VarUInt encoding; every other type
        /// (e.g. a floating-point denominator) is written as raw binary.
        if constexpr (std::is_unsigned_v<Denominator>)
            writeVarUInt(state.denominator, buf);
        else
            writeBinary(state.denominator, buf);
    }

    /// Reads a state written by serialize(), using the same denominator encoding choice.
    void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
    {
        auto & state = this->data(place);
        readBinary(state.numerator, buf);

        if constexpr (std::is_unsigned_v<Denominator>)
            readVarUInt(state.denominator, buf);
        else
            readBinary(state.denominator, buf);
    }

    /// Appends numerator / denominator to the Float64 result column.
    void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override
    {
        auto & result_column = static_cast<ColumnVector<Float64> &>(to);
        result_column.getData().push_back(this->data(place).result());
    }
};
/// avg(x): sums the argument into a Float64 numerator and counts rows in a UInt64 denominator.
template <class T>
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<UInt64, AggregateFunctionAvg<T>>
{
    using Parent = AggregateFunctionAvgBase<UInt64, AggregateFunctionAvg<T>>;

public:
    using Parent::Parent;

    void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const final
    {
        auto & state = this->data(place);
        const IColumn & column = *columns[0];

        if constexpr (IsDecimalNumber<T>)
        {
            /// Decimals are read through the generic IColumn::getFloat64 interface.
            state.numerator += column.getFloat64(row_num);
        }
        else
        {
            const auto & value = static_cast<const ColumnVector<T> &>(column).getData()[row_num];

            /// Decimals were handled above, so this check selects only the 128/256-bit integers.
            /// They get an explicit cast (presumably their conversion to Float64 is not implicit — TODO confirm).
            if constexpr (DecimalOrExtendedInt<T>)
                state.numerator += static_cast<Float64>(value);
            else
                state.numerator += value;
        }

        ++state.denominator;
    }

    String getName() const final { return "avg"; }
};
}