// ClickHouse/dbms/src/AggregateFunctions/AggregateFunctionAvg.h
// (112 lines, 3.7 KiB, C++)
#pragma once
#include <limits>
#include <type_traits>

#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>

#include <Columns/ColumnsNumber.h>

#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>

#include <AggregateFunctions/IAggregateFunction.h>

namespace DB
{
/// Error code whose definition lives in another translation unit (declared `extern` here).
namespace ErrorCodes { extern const int LOGICAL_ERROR; }

template <typename T, typename Denominator>
struct AggregateFunctionAvgData
2011-09-26 04:00:46 +00:00
{
2019-11-23 07:48:22 +00:00
T numerator = 0;
2019-12-17 14:00:40 +00:00
Denominator denominator = 0;
template <typename ResultT>
2018-12-27 00:51:14 +00:00
ResultT NO_SANITIZE_UNDEFINED result() const
{
if constexpr (std::is_floating_point_v<ResultT>)
if constexpr (std::numeric_limits<ResultT>::is_iec559)
2019-11-23 07:48:22 +00:00
return static_cast<ResultT>(numerator) / denominator; /// allow division by zero
2019-11-23 07:48:22 +00:00
if (denominator == 0)
return static_cast<ResultT>(0);
2019-11-23 07:48:22 +00:00
return static_cast<ResultT>(numerator / denominator);
}
};
2017-03-09 00:56:38 +00:00
/// Calculates arithmetic mean of numbers.
2019-12-15 13:36:44 +00:00
template <typename T, typename Data, typename Derived>
class AggregateFunctionAvgBase : public IAggregateFunctionDataHelper<Data, Derived>
{
public:
2019-10-22 15:31:56 +00:00
using ResultType = std::conditional_t<IsDecimalNumber<T>, T, Float64>;
using ResultDataType = std::conditional_t<IsDecimalNumber<T>, DataTypeDecimal<T>, DataTypeNumber<Float64>>;
2018-09-12 13:27:32 +00:00
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
2019-10-22 15:31:56 +00:00
using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<Float64>>;
2018-09-12 13:27:32 +00:00
/// ctor for native types
2019-12-15 13:36:44 +00:00
AggregateFunctionAvgBase(const DataTypes & argument_types_) : IAggregateFunctionDataHelper<Data, Derived>(argument_types_, {}), scale(0) {}
2018-09-12 13:27:32 +00:00
/// ctor for Decimals
2019-11-23 07:48:22 +00:00
AggregateFunctionAvgBase(const IDataType & data_type, const DataTypes & argument_types_)
2019-12-15 13:36:44 +00:00
: IAggregateFunctionDataHelper<Data, Derived>(argument_types_, {}), scale(getDecimalScale(data_type))
2019-11-23 07:55:41 +00:00
{
}
2018-09-12 13:27:32 +00:00
DataTypePtr getReturnType() const override
{
2018-09-12 13:27:32 +00:00
if constexpr (IsDecimalNumber<T>)
return std::make_shared<ResultDataType>(ResultDataType::maxPrecision(), scale);
2018-09-12 13:27:32 +00:00
else
return std::make_shared<ResultDataType>();
}
2017-12-01 21:51:50 +00:00
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{
2019-11-23 07:48:22 +00:00
this->data(place).numerator += this->data(rhs).numerator;
this->data(place).denominator += this->data(rhs).denominator;
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
2019-11-23 07:48:22 +00:00
writeBinary(this->data(place).numerator, buf);
writeVarUInt(this->data(place).denominator, buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
2019-11-23 07:48:22 +00:00
readBinary(this->data(place).numerator, buf);
readVarUInt(this->data(place).denominator, buf);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
2018-09-12 13:27:32 +00:00
auto & column = static_cast<ColVecResult &>(to);
column.getData().push_back(this->data(place).template result<ResultType>());
}
2019-11-23 07:48:22 +00:00
protected:
2018-09-12 13:27:32 +00:00
UInt32 scale;
2011-09-26 04:00:46 +00:00
};
2019-11-23 07:48:22 +00:00
template <typename T, typename Data>
2019-12-15 13:36:44 +00:00
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<T, Data, AggregateFunctionAvg<T, Data>>
2019-11-23 07:48:22 +00:00
{
public:
2019-12-15 13:36:44 +00:00
using AggregateFunctionAvgBase<T, Data, AggregateFunctionAvg<T, Data>>::AggregateFunctionAvgBase;
2019-11-23 07:48:22 +00:00
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
const auto & column = static_cast<const ColVecType &>(*columns[0]);
this->data(place).numerator += column.getData()[row_num];
2019-12-15 13:36:44 +00:00
this->data(place).denominator += 1;
2019-11-23 07:48:22 +00:00
}
String getName() const override { return "avg"; }
};
2011-09-26 04:00:46 +00:00
}