decimal avg [CLICKHOUSE-3765]

This commit is contained in:
chertus 2018-09-12 16:27:32 +03:00
parent dd5c55df2c
commit 1c4825138a
4 changed files with 73 additions and 24 deletions

View File

@ -9,16 +9,31 @@ namespace DB
namespace
{
template <typename T>
struct Avg
{
using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128, typename NearestFieldType<T>::Type>;
using Function = AggregateFunctionAvg<T, AggregateFunctionAvgData<FieldType>>;
};
template <typename T>
using AggregateFuncAvg = typename Avg<T>::Function;
AggregateFunctionPtr createAggregateFunctionAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionAvg>(*argument_types[0]));
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (isDecimal(data_type))
res.reset(createWithDecimalType<AggregateFuncAvg>(*data_type));
else
res.reset(createWithNumericType<AggregateFuncAvg>(*data_type));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}

View File

@ -4,6 +4,7 @@
#include <IO/ReadHelpers.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Columns/ColumnsNumber.h>
#include <AggregateFunctions/IAggregateFunction.h>
@ -22,20 +23,39 @@ struct AggregateFunctionAvgData
/// Calculates arithmetic mean of numbers.
template <typename T>
class AggregateFunctionAvg final : public IAggregateFunctionDataHelper<AggregateFunctionAvgData<typename NearestFieldType<T>::Type>, AggregateFunctionAvg<T>>
template <typename T, typename Data>
class AggregateFunctionAvg final : public IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>
{
public:
using ResultType = std::conditional_t<IsDecimalNumber<T>, Decimal128, Float64>;
using ResultDataType = std::conditional_t<IsDecimalNumber<T>, DataTypeDecimal<Decimal128>, DataTypeNumber<Float64>>;
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, ColumnVector<Float64>>;
/// ctor for native types
AggregateFunctionAvg()
: scale(0)
{}
/// ctor for Decimals
AggregateFunctionAvg(const IDataType & data_type)
: scale(getDecimalScale(data_type))
{}
String getName() const override { return "avg"; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeFloat64>();
if constexpr (IsDecimalNumber<T>)
return std::make_shared<ResultDataType>(ResultDataType::maxPrecision(), scale);
else
return std::make_shared<ResultDataType>();
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
this->data(place).sum += static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
const auto & column = static_cast<const ColVecType &>(*columns[0]);
this->data(place).sum += column.getData()[row_num];
++this->data(place).count;
}
@ -59,11 +79,14 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
static_cast<ColumnFloat64 &>(to).getData().push_back(
static_cast<Float64>(this->data(place).sum) / this->data(place).count);
auto & column = static_cast<ColVecResult &>(to);
column.getData().push_back(static_cast<ResultType>(this->data(place).sum) / this->data(place).count);
}
const char * getHeaderFilePath() const override { return __FILE__; }
private:
UInt32 scale;
};

View File

@ -1,12 +1,19 @@
101 101 101
-25.0000 -16.66666666 -10.00000000
25.0000 16.66666666 10.00000000
-50.0000 -16.66666666 -10.00000000
50.0000 16.66666666 10.00000000
0.0000 0.00000000 0.00000000 0.0000 0.00000000 0.00000000
1275.0000 424.99999983 255.00000000 1275.0000 424.99999983 255.00000000
-1275.0000 -424.99999983 -255.00000000 -1275.0000 -424.99999983 -255.00000000
101.0000 101.00000000 101.00000000 101.0000 101.00000000 101.00000000
-101.0000 -101.00000000 -101.00000000 -101.0000 -101.00000000 -101.00000000
-25.0000 -16.66666666 -10.00000000
25.0000 16.66666666 10.00000000
0.0000 0.00000000 0.00000000
25.5000 8.49999999 5.10000000
-25.5000 -8.49999999 -5.10000000
101 101 101
101 101 101
101 101 101
101 100 101
102 100 101
-50.0000 -50.0000 -16.66666666 -16.66666666 -10.00000000 -10.00000000
1.0000 1.0000 0.33333333 0.33333333 0.20000000 0.20000000
50.0000 50.0000 16.66666666 16.66666666 10.00000000 10.00000000
-1.0000 -1.0000 -0.33333333 -0.33333333 -0.20000000 -0.20000000

View File

@ -12,7 +12,7 @@ CREATE TABLE test.decimal
) ENGINE = Memory;
INSERT INTO test.decimal (a, b, c)
SELECT toDecimal32(number - 50, 4) / 2, toDecimal64(number - 50, 8) / 3, toDecimal128(number - 50, 8) / 5
SELECT toDecimal32(number - 50, 4), toDecimal64(number - 50, 8) / 3, toDecimal128(number - 50, 8) / 5
FROM system.numbers LIMIT 101;
SELECT count(a), count(b), count(c) FROM test.decimal;
@ -20,18 +20,24 @@ SELECT min(a), min(b), min(c) FROM test.decimal;
SELECT max(a), max(b), max(c) FROM test.decimal;
SELECT sum(a), sum(b), sum(c), sumWithOverflow(a), sumWithOverflow(b), sumWithOverflow(c) FROM test.decimal;
SELECT sum(a), sum(b), sum(c), sumWithOverflow(a), sumWithOverflow(b), sumWithOverflow(c) FROM test.decimal WHERE a > 0;
SELECT sum(a), sum(b), sum(c), sumWithOverflow(a), sumWithOverflow(b), sumWithOverflow(c) FROM test.decimal WHERE a < 0;
SELECT sum(a+1), sum(b+1), sum(c+1), sumWithOverflow(a+1), sumWithOverflow(b+1), sumWithOverflow(c+1) FROM test.decimal;
SELECT sum(a-1), sum(b-1), sum(c-1), sumWithOverflow(a-1), sumWithOverflow(b-1), sumWithOverflow(c-1) FROM test.decimal;
--SELECT avg(a), avg(b), avg(c) FROM test.decimal;
SELECT avg(a), avg(b), avg(c) FROM test.decimal;
SELECT avg(a), avg(b), avg(c) FROM test.decimal WHERE a > 0;
SELECT avg(a), avg(b), avg(c) FROM test.decimal WHERE a < 0;
SELECT argMin(a, b), argMin(b, a), argMin(c, a) FROM test.decimal;
SELECT argMax(a, b), argMax(b, a), argMax(c, a) FROM test.decimal;
SELECT uniq(a), uniq(b), uniq(c) FROM (SELECT * FROM test.decimal ORDER BY a);
SELECT uniqCombined(a), uniqCombined(b), uniqCombined(c) FROM (SELECT * FROM test.decimal ORDER BY a);
SELECT uniqExact(a), uniqExact(b), uniqExact(c) FROM (SELECT * FROM test.decimal ORDER BY a);
SELECT uniqHLL12(a), uniqHLL12(b), uniqHLL12(c) FROM (SELECT * FROM test.decimal ORDER BY a);
SELECT uniq(a), uniq(b), uniq(c) FROM test.decimal;
SELECT uniqCombined(a), uniqCombined(b), uniqCombined(c) FROM test.decimal;
SELECT uniqExact(a), uniqExact(b), uniqExact(c) FROM test.decimal;
SELECT uniqHLL12(a), uniqHLL12(b), uniqHLL12(c) FROM test.decimal;
SELECT argMin(a, b), argMin(a, c), argMin(b, a), argMin(b, c), argMin(c, a), argMin(c, b) FROM test.decimal;
SELECT argMin(a, b), argMin(a, c), argMin(b, a), argMin(b, c), argMin(c, a), argMin(c, b) FROM test.decimal WHERE a > 0;
SELECT argMax(a, b), argMax(a, c), argMax(b, a), argMax(b, c), argMax(c, a), argMax(c, b) FROM test.decimal;
SELECT argMax(a, b), argMax(a, c), argMax(b, a), argMax(b, c), argMax(c, a), argMax(c, b) FROM test.decimal WHERE a < 0;
--SELECT median(a), median(b), median(c) FROM test.decimal;
--SELECT quantile(a), quantile(b), quantile(c) FROM test.decimal;
@ -45,5 +51,3 @@ SELECT uniqHLL12(a), uniqHLL12(b), uniqHLL12(c) FROM test.decimal;
-- TODO: sumMap
-- TODO: groupArray, groupArrayInsertAt, groupUniqArray
--SELECT topK(2)(a), topK(2)(b), topK(2)(c) FROM test.decimal; TODO: deterministic