fixed some wrong tests, updated docs

This commit is contained in:
myrrc 2020-10-30 19:17:57 +03:00
parent 9ca35c0b44
commit b3379e9cf0
8 changed files with 122 additions and 50 deletions

View File

@ -4,5 +4,59 @@ toc_priority: 5
# avg {#agg_function-avg} # avg {#agg_function-avg}
Calculates the average. Only works for numbers (Integral, floating-point, or Decimals). Calculates the arithmetic mean.
The result is always Float64.
**Syntax**
``` sql
avgWeighted(x)
```
**Parameter**
- `x` — Values.
`x` must be
[Integer](../../../sql-reference/data-types/int-uint.md),
[floating-point](../../../sql-reference/data-types/float.md), or
[Decimal](../../../sql-reference/data-types/decimal.md).
**Returned value**
- `0` if the supplied parameter is empty.
- Mean otherwise.
**Return type** is always [Float64](../../../sql-reference/data-types/float.md).
**Example**
Query:
``` sql
SELECT avg(x) FROM values('x Int8', 0, 1, 2, 3, 4, 5)
```
Result:
``` text
┌─avg(x)─┐
│ 2.5 │
└────────┘
```
**Example**
Query:
``` sql
CREATE table test (t UInt8) ENGINE = Memory;
SELECT avg(t) FROM test
```
Result:
``` text
┌─avg(x)─┐
│ 0 │
└────────┘
```

View File

@ -25,7 +25,7 @@ but may have different types.
**Returned value** **Returned value**
- `NaN`. If all the weights are equal to 0. - `NaN` if all the weights are equal to 0 or the supplied weights parameter is empty.
- Weighted mean otherwise. - Weighted mean otherwise.
**Return type** is always [Float64](../../../sql-reference/data-types/float.md). **Return type** is always [Float64](../../../sql-reference/data-types/float.md).
@ -63,3 +63,37 @@ Result:
│ 8 │ │ 8 │
└────────────────────────┘ └────────────────────────┘
``` ```
**Example**
Query:
``` sql
SELECT avgWeighted(x, w)
FROM values('x Int8, w Int8', (0, 0), (1, 0), (10, 0))
```
Result:
``` text
┌─avgWeighted(x, weight)─┐
│ nan │
└────────────────────────┘
```
**Example**
Query:
``` sql
CREATE table test (t UInt8) ENGINE = Memory;
SELECT avgWeighted(t) FROM test
```
Result:
``` text
┌─avgWeighted(x, weight)─┐
│ nan │
└────────────────────────┘
```

View File

@ -10,26 +10,23 @@
namespace DB namespace DB
{ {
/// A type-fixed fraction represented by a pair of #Numerator and #Denominator.
template <class Numerator, class Denominator> /// @tparam BothZeroMeansNaN If false, the pair 0 / 0 = 0, nan otherwise.
template <class Denominator, bool BothZeroMeansNaN = true>
struct RationalFraction struct RationalFraction
{ {
constexpr RationalFraction(): numerator(0), denominator(0) {} Float64 numerator{0};
Denominator denominator{0};
Numerator numerator; Float64 NO_SANITIZE_UNDEFINED result() const
Denominator denominator;
/// Calculate the fraction as a #Result.
template <class Result>
Result NO_SANITIZE_UNDEFINED result() const
{ {
if constexpr (std::is_floating_point_v<Result> && std::numeric_limits<Result>::is_iec559) if constexpr (BothZeroMeansNaN && std::numeric_limits<Float64>::is_iec559)
return static_cast<Result>(numerator) / denominator; /// allow division by zero return static_cast<Float64>(numerator) / denominator; /// allow division by zero
if (denominator == static_cast<Denominator>(0)) if (denominator == static_cast<Denominator>(0))
return static_cast<Result>(0); return static_cast<Float64>(0);
return static_cast<Result>(numerator / denominator); return static_cast<Float64>(numerator / denominator);
} }
}; };
@ -46,31 +43,17 @@ struct RationalFraction
* @tparam Derived When deriving from this class, use the child class name as in CRTP, e.g. * @tparam Derived When deriving from this class, use the child class name as in CRTP, e.g.
* class Self : Agg<char, bool, bool, Self>. * class Self : Agg<char, bool, bool, Self>.
*/ */
template <class Denominator, class Derived> template <class Denominator, bool BothZeroMeansNaN, class Derived>
class AggregateFunctionAvgBase : public class AggregateFunctionAvgBase : public
IAggregateFunctionDataHelper<RationalFraction<Float64, Denominator>, Derived> IAggregateFunctionDataHelper<RationalFraction<Denominator, BothZeroMeansNaN>, Derived>
{ {
public: public:
using Numerator = Float64; using Fraction = RationalFraction<Denominator, BothZeroMeansNaN>;
using Fraction = RationalFraction<Numerator, Denominator>;
using ResultType = Float64;
using ResultDataType = DataTypeNumber<Float64>;
using ResultVectorType = ColumnVector<Float64>;
using Base = IAggregateFunctionDataHelper<Fraction, Derived>; using Base = IAggregateFunctionDataHelper<Fraction, Derived>;
/// ctor for native types explicit AggregateFunctionAvgBase(const DataTypes & argument_types_): Base(argument_types_, {}) {}
explicit AggregateFunctionAvgBase(const DataTypes & argument_types_): Base(argument_types_, {}), scale(0) {}
/// ctor for Decimals DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); }
AggregateFunctionAvgBase(const IDataType & data_type, const DataTypes & argument_types_)
: Base(argument_types_, {}), scale(getDecimalScale(data_type)) {}
DataTypePtr getReturnType() const override
{
return std::make_shared<ResultDataType>();
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{ {
@ -100,17 +83,14 @@ public:
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override
{ {
static_cast<ResultVectorType &>(to).getData().push_back(this->data(place).template result<ResultType>()); static_cast<ColumnVector<Float64> &>(to).getData().push_back(this->data(place).result());
} }
protected:
UInt32 scale;
}; };
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<UInt64, AggregateFunctionAvg> class AggregateFunctionAvg final : public AggregateFunctionAvgBase<UInt64, false, AggregateFunctionAvg>
{ {
public: public:
using AggregateFunctionAvgBase<UInt64, AggregateFunctionAvg>::AggregateFunctionAvgBase; using AggregateFunctionAvgBase<UInt64, false, AggregateFunctionAvg>::AggregateFunctionAvgBase;
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const final void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const final
{ {

View File

@ -5,10 +5,10 @@
namespace DB namespace DB
{ {
class AggregateFunctionAvgWeighted final : public AggregateFunctionAvgBase<Float64, AggregateFunctionAvgWeighted> class AggregateFunctionAvgWeighted final : public AggregateFunctionAvgBase<Float64, true, AggregateFunctionAvgWeighted>
{ {
public: public:
using AggregateFunctionAvgBase<Float64, AggregateFunctionAvgWeighted>::AggregateFunctionAvgBase; using AggregateFunctionAvgBase<Float64, true, AggregateFunctionAvgWeighted>::AggregateFunctionAvgBase;
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{ {

View File

@ -5,9 +5,6 @@
0.0000 0.0000000 0.00000000 0.0000 0.0000000 0.00000000 0.0000 0.0000000 0.00000000 0.0000 0.0000000 0.00000000
0.0000 0.0000000 0.00000000 0.0000 0.0000000 0.00000000 0.0000 0.0000000 0.00000000 0.0000 0.0000000 0.00000000
0.0000 0.0000000 0.00000000 0.0000 0.0000000 0.00000000 0.0000 0.0000000 0.00000000 0.0000 0.0000000 0.00000000
0.0000 0.0000000 0.00000000 Decimal(9, 4) Decimal(18, 7) Decimal(38, 8)
0.0000 0.0000000 0.00000000 Decimal(9, 4) Decimal(18, 7) Decimal(38, 8)
0.0000 0.0000000 0.00000000 Decimal(9, 4) Decimal(18, 7) Decimal(38, 8)
(0,0,0) (0,0,0) (0,0,0) (0,0,0) (0,0,0) (0,0,0) (0,0,0) (0,0,0) (0,0,0) (0,0,0)
0 0 0 0 0 0
0 0 0 0 0 0

View File

@ -16,10 +16,6 @@ SELECT sum(a), sum(b), sum(c), sumWithOverflow(a), sumWithOverflow(b), sumWithOv
SELECT sum(a+1), sum(b+1), sum(c+1), sumWithOverflow(a+1), sumWithOverflow(b+1), sumWithOverflow(c+1) FROM decimal; SELECT sum(a+1), sum(b+1), sum(c+1), sumWithOverflow(a+1), sumWithOverflow(b+1), sumWithOverflow(c+1) FROM decimal;
SELECT sum(a-1), sum(b-1), sum(c-1), sumWithOverflow(a-1), sumWithOverflow(b-1), sumWithOverflow(c-1) FROM decimal; SELECT sum(a-1), sum(b-1), sum(c-1), sumWithOverflow(a-1), sumWithOverflow(b-1), sumWithOverflow(c-1) FROM decimal;
SELECT avg(a) as aa, avg(b) as ab, avg(c) as ac, toTypeName(aa), toTypeName(ab),toTypeName(ac) FROM decimal;
SELECT avg(a) as aa, avg(b) as ab, avg(c) as ac, toTypeName(aa), toTypeName(ab),toTypeName(ac) FROM decimal WHERE a > 0;
SELECT avg(a) as aa, avg(b) as ab, avg(c) as ac, toTypeName(aa), toTypeName(ab),toTypeName(ac) FROM decimal WHERE a < 0;
SELECT (uniq(a), uniq(b), uniq(c)), SELECT (uniq(a), uniq(b), uniq(c)),
(uniqCombined(a), uniqCombined(b), uniqCombined(c)), (uniqCombined(a), uniqCombined(b), uniqCombined(c)),
(uniqCombined(17)(a), uniqCombined(17)(b), uniqCombined(17)(c)), (uniqCombined(17)(a), uniqCombined(17)(b), uniqCombined(17)(c)),

View File

@ -0,0 +1,2 @@
0
499.5

View File

@ -0,0 +1,9 @@
CREATE TABLE IF NOT EXISTS test_01035 (
t UInt16
) ENGINE = Memory;
SELECT avg(t) FROM test_01035;
INSERT INTO test_01035 SELECT * FROM system.numbers LIMIT 1000;
SELECT avg(t) FROM test_01035;
DROP TABLE IF EXISTS test_01035