/// NOTE(review): in this copy every `#include` target and every `<...>` template
/// parameter/argument list appears to have been stripped (e.g. `template struct`,
/// `std::conditional_t, T, Float64>`, `static_cast(numerator)`). The code cannot
/// compile as shown; restore those lists from version control before building.
/// The comments below describe the visible logic and hedge wherever the stripped
/// arguments matter.
#pragma once
#include
#include
#include
#include
#include
#include

namespace DB
{
/// Empty: nothing in this file declares or throws an error code.
namespace ErrorCodes
{
}

/// Running state of an average: an accumulated sum (`numerator`) and a count of
/// accumulated values (`denominator`), both zero-initialised.
/// NOTE(review): template parameter list stripped; presumably `T` (sum type) and
/// `Denominator` (counter type) — confirm.
template struct AggregateFunctionAvgData
{
    T numerator = 0;
    Denominator denominator = 0;

    /// Returns numerator / denominator converted to the requested result type.
    /// - If the result type is floating point AND IEC 559 (IEEE 754), the division
    ///   is done in floating point with no zero check. A never-filled state
    ///   therefore yields the IEEE result of 0 / 0 (NaN) rather than 0.
    ///   NO_SANITIZE_UNDEFINED is presumably there so UBSan does not flag this
    ///   floating-point division by zero.
    /// - Otherwise (including a floating type that is not IEC 559, which falls
    ///   through the nested `if constexpr`), a zero denominator returns 0.
    ///   A non-zero denominator returns numerator / denominator cast to the
    ///   result type.
    /// NOTE(review): the `<...>` arguments of is_floating_point_v, numeric_limits
    /// and static_cast are stripped here; presumably all are `ResultT` — confirm.
    template ResultT NO_SANITIZE_UNDEFINED result() const
    {
        if constexpr (std::is_floating_point_v)
            if constexpr (std::numeric_limits::is_iec559)
                return static_cast(numerator) / denominator; /// allow division by zero

        if (denominator == 0)
            return static_cast(0);
        return static_cast(numerator / denominator);
    }
};

/// Calculates arithmetic mean of numbers.
/// Shared base for avg-style aggregate functions. It implements the return type,
/// merging, (de)serialisation and result extraction over a state reached through
/// `this->data(place)` that has `numerator` / `denominator` members (presumably
/// AggregateFunctionAvgData — template arguments stripped).
/// The base does not define add() or getName(); the visible subclass
/// AggregateFunctionAvg supplies both.
template class AggregateFunctionAvgBase : public IAggregateFunctionDataHelper
{
public:
    /// Each alias below chooses between a decimal and a native alternative.
    /// The condition is stripped; presumably IsDecimalNumber<T>, matching the
    /// test used in getReturnType() — confirm.
    /// ResultType: T for decimals, otherwise Float64.
    using ResultType = std::conditional_t, T, Float64>;
    using ResultDataType = std::conditional_t, DataTypeDecimal, DataTypeNumber>;
    using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>;
    using ColVecResult = std::conditional_t, ColumnDecimal, ColumnVector>;

    /// ctor for native types
    /// Native arguments carry no decimal scale, so `scale` is 0.
    AggregateFunctionAvgBase(const DataTypes & argument_types_) : IAggregateFunctionDataHelper(argument_types_, {}), scale(0) {}

    /// ctor for Decimals
    /// `scale` is taken from the argument's decimal type via getDecimalScale().
    AggregateFunctionAvgBase(const IDataType & data_type, const DataTypes & argument_types_)
        : IAggregateFunctionDataHelper(argument_types_, {}), scale(getDecimalScale(data_type))
    {
    }

    /// Decimal inputs: a decimal result type at maximum precision, keeping the
    /// argument's scale. Other inputs: a default-constructed result type.
    DataTypePtr getReturnType() const override
    {
        if constexpr (IsDecimalNumber)
            return std::make_shared(ResultDataType::maxPrecision(), scale);
        else
            return std::make_shared();
    }

    /// Folds the partial state at `rhs` into `place` by summing both the
    /// numerators and the denominators.
    void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
    {
        this->data(place).numerator += this->data(rhs).numerator;
        this->data(place).denominator += this->data(rhs).denominator;
    }

    /// Writes the numerator with writeBinary and then the denominator with
    /// writeVarUInt. deserialize() must read the fields in the same order and
    /// with the same encodings.
    void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
    {
        writeBinary(this->data(place).numerator, buf);
        writeVarUInt(this->data(place).denominator, buf);
    }

    /// Inverse of serialize(): reads the numerator (readBinary), then the
    /// denominator (readVarUInt).
    void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
    {
        readBinary(this->data(place).numerator, buf);
        readVarUInt(this->data(place).denominator, buf);
    }

    /// Appends the mean stored at `place` to the output column `to`.
    /// `to` is static_cast (no runtime type check) to the result column type,
    /// and the value comes from the state's result() template. Both template
    /// arguments are stripped; presumably ColVecResult and ResultType — confirm.
    void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
    {
        auto & column = static_cast(to);
        column.getData().push_back(this->data(place).template result());
    }

protected:
    /// Decimal scale of the argument type; 0 when built with the native-type ctor.
    UInt32 scale;
};

/// The concrete `avg` aggregate function. It inherits the base constructors and
/// adds per-row accumulation.
/// NOTE(review): template parameters and base-class arguments are stripped here.
template class AggregateFunctionAvg final : public AggregateFunctionAvgBase>
{
public:
    using AggregateFunctionAvgBase>::AggregateFunctionAvgBase;

    using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>;

    /// Accumulates one input row. The first argument column is static_cast to
    /// the input column type (stripped; presumably ColVecType). Its value at
    /// `row_num` is added to the numerator, and the denominator grows by 1.
    void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
    {
        const auto & column = static_cast(*columns[0]);
        this->data(place).numerator += column.getData()[row_num];
        this->data(place).denominator += 1;
    }

    String getName() const override { return "avg"; }
};

}