mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-24 08:32:02 +00:00
decimal avg [CLICKHOUSE-3765]
This commit is contained in:
parent
dd5c55df2c
commit
1c4825138a
@ -9,16 +9,31 @@ namespace DB
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
struct Avg
|
||||
{
|
||||
using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128, typename NearestFieldType<T>::Type>;
|
||||
using Function = AggregateFunctionAvg<T, AggregateFunctionAvgData<FieldType>>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using AggregateFuncAvg = typename Avg<T>::Function;
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
assertUnary(name, argument_types);
|
||||
|
||||
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionAvg>(*argument_types[0]));
|
||||
AggregateFunctionPtr res;
|
||||
DataTypePtr data_type = argument_types[0];
|
||||
if (isDecimal(data_type))
|
||||
res.reset(createWithDecimalType<AggregateFuncAvg>(*data_type));
|
||||
else
|
||||
res.reset(createWithNumericType<AggregateFuncAvg>(*data_type));
|
||||
|
||||
if (!res)
|
||||
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
@ -22,20 +23,39 @@ struct AggregateFunctionAvgData
|
||||
|
||||
|
||||
/// Calculates arithmetic mean of numbers.
|
||||
template <typename T>
|
||||
class AggregateFunctionAvg final : public IAggregateFunctionDataHelper<AggregateFunctionAvgData<typename NearestFieldType<T>::Type>, AggregateFunctionAvg<T>>
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionAvg final : public IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>
|
||||
{
|
||||
public:
|
||||
using ResultType = std::conditional_t<IsDecimalNumber<T>, Decimal128, Float64>;
|
||||
using ResultDataType = std::conditional_t<IsDecimalNumber<T>, DataTypeDecimal<Decimal128>, DataTypeNumber<Float64>>;
|
||||
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
||||
using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, ColumnVector<Float64>>;
|
||||
|
||||
/// ctor for native types
|
||||
AggregateFunctionAvg()
|
||||
: scale(0)
|
||||
{}
|
||||
|
||||
/// ctor for Decimals
|
||||
AggregateFunctionAvg(const IDataType & data_type)
|
||||
: scale(getDecimalScale(data_type))
|
||||
{}
|
||||
|
||||
String getName() const override { return "avg"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeFloat64>();
|
||||
if constexpr (IsDecimalNumber<T>)
|
||||
return std::make_shared<ResultDataType>(ResultDataType::maxPrecision(), scale);
|
||||
else
|
||||
return std::make_shared<ResultDataType>();
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
this->data(place).sum += static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
|
||||
const auto & column = static_cast<const ColVecType &>(*columns[0]);
|
||||
this->data(place).sum += column.getData()[row_num];
|
||||
++this->data(place).count;
|
||||
}
|
||||
|
||||
@ -59,11 +79,14 @@ public:
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
static_cast<ColumnFloat64 &>(to).getData().push_back(
|
||||
static_cast<Float64>(this->data(place).sum) / this->data(place).count);
|
||||
auto & column = static_cast<ColVecResult &>(to);
|
||||
column.getData().push_back(static_cast<ResultType>(this->data(place).sum) / this->data(place).count);
|
||||
}
|
||||
|
||||
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||
|
||||
private:
|
||||
UInt32 scale;
|
||||
};
|
||||
|
||||
|
||||
|
@ -1,12 +1,19 @@
|
||||
101 101 101
|
||||
-25.0000 -16.66666666 -10.00000000
|
||||
25.0000 16.66666666 10.00000000
|
||||
-50.0000 -16.66666666 -10.00000000
|
||||
50.0000 16.66666666 10.00000000
|
||||
0.0000 0.00000000 0.00000000 0.0000 0.00000000 0.00000000
|
||||
1275.0000 424.99999983 255.00000000 1275.0000 424.99999983 255.00000000
|
||||
-1275.0000 -424.99999983 -255.00000000 -1275.0000 -424.99999983 -255.00000000
|
||||
101.0000 101.00000000 101.00000000 101.0000 101.00000000 101.00000000
|
||||
-101.0000 -101.00000000 -101.00000000 -101.0000 -101.00000000 -101.00000000
|
||||
-25.0000 -16.66666666 -10.00000000
|
||||
25.0000 16.66666666 10.00000000
|
||||
0.0000 0.00000000 0.00000000
|
||||
25.5000 8.49999999 5.10000000
|
||||
-25.5000 -8.49999999 -5.10000000
|
||||
101 101 101
|
||||
101 101 101
|
||||
101 101 101
|
||||
101 100 101
|
||||
102 100 101
|
||||
-50.0000 -50.0000 -16.66666666 -16.66666666 -10.00000000 -10.00000000
|
||||
1.0000 1.0000 0.33333333 0.33333333 0.20000000 0.20000000
|
||||
50.0000 50.0000 16.66666666 16.66666666 10.00000000 10.00000000
|
||||
-1.0000 -1.0000 -0.33333333 -0.33333333 -0.20000000 -0.20000000
|
||||
|
@ -12,7 +12,7 @@ CREATE TABLE test.decimal
|
||||
) ENGINE = Memory;
|
||||
|
||||
INSERT INTO test.decimal (a, b, c)
|
||||
SELECT toDecimal32(number - 50, 4) / 2, toDecimal64(number - 50, 8) / 3, toDecimal128(number - 50, 8) / 5
|
||||
SELECT toDecimal32(number - 50, 4), toDecimal64(number - 50, 8) / 3, toDecimal128(number - 50, 8) / 5
|
||||
FROM system.numbers LIMIT 101;
|
||||
|
||||
SELECT count(a), count(b), count(c) FROM test.decimal;
|
||||
@ -20,18 +20,24 @@ SELECT min(a), min(b), min(c) FROM test.decimal;
|
||||
SELECT max(a), max(b), max(c) FROM test.decimal;
|
||||
|
||||
SELECT sum(a), sum(b), sum(c), sumWithOverflow(a), sumWithOverflow(b), sumWithOverflow(c) FROM test.decimal;
|
||||
SELECT sum(a), sum(b), sum(c), sumWithOverflow(a), sumWithOverflow(b), sumWithOverflow(c) FROM test.decimal WHERE a > 0;
|
||||
SELECT sum(a), sum(b), sum(c), sumWithOverflow(a), sumWithOverflow(b), sumWithOverflow(c) FROM test.decimal WHERE a < 0;
|
||||
SELECT sum(a+1), sum(b+1), sum(c+1), sumWithOverflow(a+1), sumWithOverflow(b+1), sumWithOverflow(c+1) FROM test.decimal;
|
||||
SELECT sum(a-1), sum(b-1), sum(c-1), sumWithOverflow(a-1), sumWithOverflow(b-1), sumWithOverflow(c-1) FROM test.decimal;
|
||||
|
||||
--SELECT avg(a), avg(b), avg(c) FROM test.decimal;
|
||||
SELECT avg(a), avg(b), avg(c) FROM test.decimal;
|
||||
SELECT avg(a), avg(b), avg(c) FROM test.decimal WHERE a > 0;
|
||||
SELECT avg(a), avg(b), avg(c) FROM test.decimal WHERE a < 0;
|
||||
|
||||
SELECT argMin(a, b), argMin(b, a), argMin(c, a) FROM test.decimal;
|
||||
SELECT argMax(a, b), argMax(b, a), argMax(c, a) FROM test.decimal;
|
||||
SELECT uniq(a), uniq(b), uniq(c) FROM (SELECT * FROM test.decimal ORDER BY a);
|
||||
SELECT uniqCombined(a), uniqCombined(b), uniqCombined(c) FROM (SELECT * FROM test.decimal ORDER BY a);
|
||||
SELECT uniqExact(a), uniqExact(b), uniqExact(c) FROM (SELECT * FROM test.decimal ORDER BY a);
|
||||
SELECT uniqHLL12(a), uniqHLL12(b), uniqHLL12(c) FROM (SELECT * FROM test.decimal ORDER BY a);
|
||||
|
||||
SELECT uniq(a), uniq(b), uniq(c) FROM test.decimal;
|
||||
SELECT uniqCombined(a), uniqCombined(b), uniqCombined(c) FROM test.decimal;
|
||||
SELECT uniqExact(a), uniqExact(b), uniqExact(c) FROM test.decimal;
|
||||
SELECT uniqHLL12(a), uniqHLL12(b), uniqHLL12(c) FROM test.decimal;
|
||||
SELECT argMin(a, b), argMin(a, c), argMin(b, a), argMin(b, c), argMin(c, a), argMin(c, b) FROM test.decimal;
|
||||
SELECT argMin(a, b), argMin(a, c), argMin(b, a), argMin(b, c), argMin(c, a), argMin(c, b) FROM test.decimal WHERE a > 0;
|
||||
SELECT argMax(a, b), argMax(a, c), argMax(b, a), argMax(b, c), argMax(c, a), argMax(c, b) FROM test.decimal;
|
||||
SELECT argMax(a, b), argMax(a, c), argMax(b, a), argMax(b, c), argMax(c, a), argMax(c, b) FROM test.decimal WHERE a < 0;
|
||||
|
||||
--SELECT median(a), median(b), median(c) FROM test.decimal;
|
||||
--SELECT quantile(a), quantile(b), quantile(c) FROM test.decimal;
|
||||
@ -45,5 +51,3 @@ SELECT uniqHLL12(a), uniqHLL12(b), uniqHLL12(c) FROM test.decimal;
|
||||
|
||||
-- TODO: sumMap
|
||||
-- TODO: groupArray, groupArrayInsertAt, groupUniqArray
|
||||
|
||||
--SELECT topK(2)(a), topK(2)(b), topK(2)(c) FROM test.decimal; TODO: deterministic
|
||||
|
Loading…
Reference in New Issue
Block a user