#pragma once #include #include #include #include #include #include namespace DB { namespace ErrorCodes { extern const int LOGICAL_ERROR; } template struct AggregateFunctionAvgData { T sum = 0; UInt64 count = 0; template ResultT NO_SANITIZE_UNDEFINED result() const { if constexpr (std::is_floating_point_v) if constexpr (std::numeric_limits::is_iec559) return static_cast(sum) / count; /// allow division by zero if (!count) throw Exception("AggregateFunctionAvg with zero values", ErrorCodes::LOGICAL_ERROR); return static_cast(sum) / count; } }; /// Calculates arithmetic mean of numbers. template class AggregateFunctionAvg final : public IAggregateFunctionDataHelper> { public: using ResultType = std::conditional_t, Decimal128, Float64>; using ResultDataType = std::conditional_t, DataTypeDecimal, DataTypeNumber>; using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; using ColVecResult = std::conditional_t, ColumnDecimal, ColumnVector>; /// ctor for native types AggregateFunctionAvg(const DataTypes & argument_types_) : IAggregateFunctionDataHelper>(argument_types_, {}) , scale(0) {} /// ctor for Decimals AggregateFunctionAvg(const IDataType & data_type, const DataTypes & argument_types_) : IAggregateFunctionDataHelper>(argument_types_, {}) , scale(getDecimalScale(data_type)) {} String getName() const override { return "avg"; } DataTypePtr getReturnType() const override { if constexpr (IsDecimalNumber) return std::make_shared(ResultDataType::maxPrecision(), scale); else return std::make_shared(); } void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override { const auto & column = static_cast(*columns[0]); this->data(place).sum += column.getData()[row_num]; ++this->data(place).count; } void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).sum += this->data(rhs).sum; this->data(place).count += this->data(rhs).count; } void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { writeBinary(this->data(place).sum, buf); writeVarUInt(this->data(place).count, buf); } void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override { readBinary(this->data(place).sum, buf); readVarUInt(this->data(place).count, buf); } void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override { auto & column = static_cast(to); column.getData().push_back(this->data(place).template result()); } const char * getHeaderFilePath() const override { return __FILE__; } private: UInt32 scale; }; }