diff --git a/docs/en/sql-reference/aggregate-functions/reference/avg.md b/docs/en/sql-reference/aggregate-functions/reference/avg.md index 1741bbb744b..06117cb83cf 100644 --- a/docs/en/sql-reference/aggregate-functions/reference/avg.md +++ b/docs/en/sql-reference/aggregate-functions/reference/avg.md @@ -4,5 +4,59 @@ toc_priority: 5 # avg {#agg_function-avg} -Calculates the average. Only works for numbers (Integral, floating-point, or Decimals). -The result is always Float64. +Calculates the arithmetic mean. + +**Syntax** + +``` sql +avgWeighted(x) +``` + +**Parameter** + +- `x` — Values. + +`x` must be +[Integer](../../../sql-reference/data-types/int-uint.md), +[floating-point](../../../sql-reference/data-types/float.md), or +[Decimal](../../../sql-reference/data-types/decimal.md). + +**Returned value** + +- `0` if the supplied parameter is empty. +- Mean otherwise. + +**Return type** is always [Float64](../../../sql-reference/data-types/float.md). + +**Example** + +Query: + +``` sql +SELECT avg(x) FROM values('x Int8', 0, 1, 2, 3, 4, 5) +``` + +Result: + +``` text +┌─avg(x)─┐ +│ 2.5 │ +└────────┘ +``` + +**Example** + +Query: + +``` sql +CREATE table test (t UInt8) ENGINE = Memory; +SELECT avg(t) FROM test +``` + +Result: + +``` text +┌─avg(x)─┐ +│ 0 │ +└────────┘ +``` diff --git a/docs/en/sql-reference/aggregate-functions/reference/avgweighted.md b/docs/en/sql-reference/aggregate-functions/reference/avgweighted.md index 22993f93e16..7b9c0de2755 100644 --- a/docs/en/sql-reference/aggregate-functions/reference/avgweighted.md +++ b/docs/en/sql-reference/aggregate-functions/reference/avgweighted.md @@ -25,7 +25,7 @@ but may have different types. **Returned value** -- `NaN`. If all the weights are equal to 0. +- `NaN` if all the weights are equal to 0 or the supplied weights parameter is empty. - Weighted mean otherwise. **Return type** is always [Float64](../../../sql-reference/data-types/float.md). @@ -63,3 +63,37 @@ Result: │ 8 │ └────────────────────────┘ ``` + +**Example** + +Query: + +``` sql +SELECT avgWeighted(x, w) +FROM values('x Int8, w Int8', (0, 0), (1, 0), (10, 0)) +``` + +Result: + +``` text +┌─avgWeighted(x, weight)─┐ +│ nan │ +└────────────────────────┘ +``` + +**Example** + +Query: + +``` sql +CREATE table test (t UInt8) ENGINE = Memory; +SELECT avgWeighted(t) FROM test +``` + +Result: + +``` text +┌─avgWeighted(x, weight)─┐ +│ nan │ +└────────────────────────┘ +``` diff --git a/src/AggregateFunctions/AggregateFunctionAvg.h b/src/AggregateFunctions/AggregateFunctionAvg.h index 16eb11143da..c28d235a8f4 100644 --- a/src/AggregateFunctions/AggregateFunctionAvg.h +++ b/src/AggregateFunctions/AggregateFunctionAvg.h @@ -10,26 +10,23 @@ namespace DB { -/// A type-fixed fraction represented by a pair of #Numerator and #Denominator. -template + +/// @tparam BothZeroMeansNaN If false, the pair 0 / 0 = 0, nan otherwise. +template struct RationalFraction { - constexpr RationalFraction(): numerator(0), denominator(0) {} + Float64 numerator{0}; + Denominator denominator{0}; - Numerator numerator; - Denominator denominator; - - /// Calculate the fraction as a #Result. - template - Result NO_SANITIZE_UNDEFINED result() const + Float64 NO_SANITIZE_UNDEFINED result() const { - if constexpr (std::is_floating_point_v && std::numeric_limits::is_iec559) - return static_cast(numerator) / denominator; /// allow division by zero + if constexpr (BothZeroMeansNaN && std::numeric_limits::is_iec559) + return static_cast(numerator) / denominator; /// allow division by zero if (denominator == static_cast(0)) - return static_cast(0); + return static_cast(0); - return static_cast(numerator / denominator); + return static_cast(numerator / denominator); } }; @@ -46,31 +43,17 @@ struct RationalFraction * @tparam Derived When deriving from this class, use the child class name as in CRTP, e.g. * class Self : Agg. */ -template +template class AggregateFunctionAvgBase : public - IAggregateFunctionDataHelper, Derived> + IAggregateFunctionDataHelper, Derived> { public: - using Numerator = Float64; - using Fraction = RationalFraction; - - using ResultType = Float64; - using ResultDataType = DataTypeNumber; - using ResultVectorType = ColumnVector; - + using Fraction = RationalFraction; using Base = IAggregateFunctionDataHelper; - /// ctor for native types - explicit AggregateFunctionAvgBase(const DataTypes & argument_types_): Base(argument_types_, {}), scale(0) {} + explicit AggregateFunctionAvgBase(const DataTypes & argument_types_): Base(argument_types_, {}) {} - /// ctor for Decimals - AggregateFunctionAvgBase(const IDataType & data_type, const DataTypes & argument_types_) - : Base(argument_types_, {}), scale(getDecimalScale(data_type)) {} - - DataTypePtr getReturnType() const override - { - return std::make_shared(); - } + DataTypePtr getReturnType() const override { return std::make_shared>(); } void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override { @@ -100,17 +83,14 @@ public: void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override { - static_cast(to).getData().push_back(this->data(place).template result()); + static_cast &>(to).getData().push_back(this->data(place).result()); } - -protected: - UInt32 scale; }; -class AggregateFunctionAvg final : public AggregateFunctionAvgBase +class AggregateFunctionAvg final : public AggregateFunctionAvgBase { public: - using AggregateFunctionAvgBase::AggregateFunctionAvgBase; + using AggregateFunctionAvgBase::AggregateFunctionAvgBase; void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const final { diff --git a/src/AggregateFunctions/AggregateFunctionAvgWeighted.h b/src/AggregateFunctions/AggregateFunctionAvgWeighted.h index ca9f0757cba..ef9384e48ab 100644 --- a/src/AggregateFunctions/AggregateFunctionAvgWeighted.h +++ b/src/AggregateFunctions/AggregateFunctionAvgWeighted.h @@ -5,10 +5,10 @@ namespace DB { -class AggregateFunctionAvgWeighted final : public AggregateFunctionAvgBase +class AggregateFunctionAvgWeighted final : public AggregateFunctionAvgBase { public: - using AggregateFunctionAvgBase::AggregateFunctionAvgBase; + using AggregateFunctionAvgBase::AggregateFunctionAvgBase; void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override { diff --git a/tests/queries/0_stateless/00700_decimal_empty_aggregates.reference b/tests/queries/0_stateless/00700_decimal_empty_aggregates.reference index 580cf0e26b7..b079e91fddc 100644 --- a/tests/queries/0_stateless/00700_decimal_empty_aggregates.reference +++ b/tests/queries/0_stateless/00700_decimal_empty_aggregates.reference @@ -5,9 +5,6 @@ 0.0000 0.0000000 0.00000000 0.0000 0.0000000 0.00000000 0.0000 0.0000000 0.00000000 0.0000 0.0000000 0.00000000 0.0000 0.0000000 0.00000000 0.0000 0.0000000 0.00000000 -0.0000 0.0000000 0.00000000 Decimal(9, 4) Decimal(18, 7) Decimal(38, 8) -0.0000 0.0000000 0.00000000 Decimal(9, 4) Decimal(18, 7) Decimal(38, 8) -0.0000 0.0000000 0.00000000 Decimal(9, 4) Decimal(18, 7) Decimal(38, 8) (0,0,0) (0,0,0) (0,0,0) (0,0,0) (0,0,0) 0 0 0 0 0 0 diff --git a/tests/queries/0_stateless/00700_decimal_empty_aggregates.sql b/tests/queries/0_stateless/00700_decimal_empty_aggregates.sql index 2d14ffae49d..c77f605a4c2 100644 --- a/tests/queries/0_stateless/00700_decimal_empty_aggregates.sql +++ b/tests/queries/0_stateless/00700_decimal_empty_aggregates.sql @@ -16,10 +16,6 @@ SELECT sum(a), sum(b), sum(c), sumWithOverflow(a), sumWithOverflow(b), sumWithOv SELECT sum(a+1), sum(b+1), sum(c+1), sumWithOverflow(a+1), sumWithOverflow(b+1), sumWithOverflow(c+1) FROM decimal; SELECT sum(a-1), sum(b-1), sum(c-1), sumWithOverflow(a-1), sumWithOverflow(b-1), sumWithOverflow(c-1) FROM decimal; -SELECT avg(a) as aa, avg(b) as ab, avg(c) as ac, toTypeName(aa), toTypeName(ab),toTypeName(ac) FROM decimal; -SELECT avg(a) as aa, avg(b) as ab, avg(c) as ac, toTypeName(aa), toTypeName(ab),toTypeName(ac) FROM decimal WHERE a > 0; -SELECT avg(a) as aa, avg(b) as ab, avg(c) as ac, toTypeName(aa), toTypeName(ab),toTypeName(ac) FROM decimal WHERE a < 0; - SELECT (uniq(a), uniq(b), uniq(c)), (uniqCombined(a), uniqCombined(b), uniqCombined(c)), (uniqCombined(17)(a), uniqCombined(17)(b), uniqCombined(17)(c)), diff --git a/tests/queries/0_stateless/01035_avg.reference b/tests/queries/0_stateless/01035_avg.reference new file mode 100644 index 00000000000..d1644f95165 --- /dev/null +++ b/tests/queries/0_stateless/01035_avg.reference @@ -0,0 +1,2 @@ +0 +499.5 diff --git a/tests/queries/0_stateless/01035_avg.sql b/tests/queries/0_stateless/01035_avg.sql new file mode 100644 index 00000000000..ee58587736f --- /dev/null +++ b/tests/queries/0_stateless/01035_avg.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS test_01035 ( + t UInt16 +) ENGINE = Memory; + +SELECT avg(t) FROM test_01035; +INSERT INTO test_01035 SELECT * FROM system.numbers LIMIT 1000; +SELECT avg(t) FROM test_01035; + +DROP TABLE IF EXISTS test_01035