var/stddev for decimal [CLICKHOUSE-3765]

This commit is contained in:
chertus 2018-09-13 21:36:47 +03:00
parent 59f8313b83
commit 7adf8d29cf
4 changed files with 201 additions and 52 deletions

View File

@ -2,13 +2,17 @@
#include <cmath>
#include <common/arithmeticOverflow.h>
#include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnDecimal.h>
/** This is simple, not numerically stable
@ -26,17 +30,11 @@
namespace DB
{
enum class VarianceMode
namespace ErrorCodes
{
Population,
Sample
};
enum class VariancePower
{
Original,
Sqrt
};
extern const int LOGICAL_ERROR;
extern const int DECIMAL_OVERFLOW;
}
template <typename T>
@ -70,15 +68,74 @@ struct VarMoments
readPODBinary(*this, buf);
}
template <VarianceMode mode, VariancePower power>
T get() const
T getPopulation() const
{
if (m0 == 0 && mode == VarianceMode::Sample)
return (m2 - m1 * m1 / m0) / m0;
}
T getSample() const
{
if (m0 == 0)
return std::numeric_limits<T>::quiet_NaN();
return (m2 - m1 * m1 / m0) / (m0 - 1);
}
T get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
};
template <typename T>
struct VarMomentsDecimal
{
using NativeType = typename T::NativeType;
UInt64 m0{};
NativeType m1{};
NativeType m2{};
void add(NativeType x)
{
++m0;
m1 += x;
NativeType tmp; /// scale' = 2 * scale
if (common::mulOverflow(x, x, tmp) || common::addOverflow(m2, tmp, m2))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
}
void merge(const VarMomentsDecimal & rhs)
{
m0 += rhs.m0;
m1 += rhs.m1;
if (common::addOverflow(m2, rhs.m2, m2))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
}
void write(WriteBuffer & buf) const { writePODBinary(*this, buf); }
void read(ReadBuffer & buf) { readPODBinary(*this, buf); }
Float64 getPopulation(UInt32 scale) const
{
NativeType tmp;
if (common::mulOverflow(m1, m1, tmp) ||
common::subOverflow(m2, NativeType(tmp/m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
}
Float64 getSample(UInt32 scale) const
{
if (m0 == 0)
return std::numeric_limits<T>::quiet_NaN();
T res = (m2 - m1 * m1 / m0) / (m0 - (mode == VarianceMode::Sample));
return power == VariancePower::Original ? res : sqrt(res);
NativeType tmp;
if (common::mulOverflow(m1, m1, tmp) ||
common::subOverflow(m2, NativeType(tmp/m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / (m0 - 1), scale);
}
Float64 get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
};
template <typename T>
@ -115,14 +172,19 @@ struct CovarMoments
readPODBinary(*this, buf);
}
template <VarianceMode mode>
T get() const
T getPopulation() const
{
if (m0 == 0 && mode == VarianceMode::Sample)
return std::numeric_limits<T>::quiet_NaN();
return (xy - x1 * y1 / m0) / (m0 - (mode == VarianceMode::Sample));
return (xy - x1 * y1 / m0) / m0;
}
T getSample() const
{
if (m0 == 0)
return std::numeric_limits<T>::quiet_NaN();
return (xy - x1 * y1 / m0) / (m0 - 1);
}
T get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
};
template <typename T>
@ -169,6 +231,9 @@ struct CorrMoments
{
return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1));
}
T getPopulation() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
T getSample() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
};
@ -181,20 +246,54 @@ enum class StatisticsFunctionKind
};
template <typename T1, typename T2>
using VarianceCalcType = std::conditional_t<std::is_same_v<T1, Float32> && std::is_same_v<T2, Float32>, Float32, Float64>;
template <typename T1, typename T2, typename Data, StatisticsFunctionKind Kind>
class AggregateFunctionVarianceSimple final
: public IAggregateFunctionDataHelper<Data, AggregateFunctionVarianceSimple<T1, T2, Data, Kind>>
template <typename T, StatisticsFunctionKind _kind>
struct StatFuncOneArg
{
using ResultType = VarianceCalcType<T1, T2>;
using Type1 = T;
using Type2 = T;
using ResultType = std::conditional_t<std::is_same_v<T, Float32>, Float32, Float64>;
using Data = std::conditional_t<IsDecimalNumber<T>, VarMomentsDecimal<Decimal128>, VarMoments<ResultType>>;
static constexpr StatisticsFunctionKind kind = _kind;
static constexpr UInt32 num_args = 1;
};
template <typename T1, typename T2, StatisticsFunctionKind _kind>
struct StatFuncTwoArg
{
using Type1 = T1;
using Type2 = T2;
using ResultType = std::conditional_t<std::is_same_v<T1, T2> && std::is_same_v<T1, Float32>, Float32, Float64>;
using Data = std::conditional_t<_kind == StatisticsFunctionKind::corr, CorrMoments<ResultType>, CovarMoments<ResultType>>;
static constexpr StatisticsFunctionKind kind = _kind;
static constexpr UInt32 num_args = 2;
};
template <typename StatFunc>
class AggregateFunctionVarianceSimple final
: public IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>
{
public:
using T1 = typename StatFunc::Type1;
using T2 = typename StatFunc::Type2;
using ColVecT1 = std::conditional_t<IsDecimalNumber<T1>, ColumnDecimal<T1>, ColumnVector<T1>>;
using ColVecT2 = std::conditional_t<IsDecimalNumber<T2>, ColumnDecimal<T2>, ColumnVector<T2>>;
using ResultType = typename StatFunc::ResultType;
using ColVecResult = ColumnVector<ResultType>;
AggregateFunctionVarianceSimple()
: src_scale(0)
{}
AggregateFunctionVarianceSimple(const IDataType & data_type)
: src_scale(getDecimalScale(data_type))
{}
String getName() const override
{
switch (Kind)
switch (StatFunc::kind)
{
case StatisticsFunctionKind::varPop: return "varPop";
case StatisticsFunctionKind::varSamp: return "varSamp";
@ -214,13 +313,13 @@ public:
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if constexpr (Kind == StatisticsFunctionKind::covarPop || Kind == StatisticsFunctionKind::covarSamp || Kind == StatisticsFunctionKind::corr)
if constexpr (StatFunc::num_args == 2)
this->data(place).add(
static_cast<const ColumnVector<T1> &>(*columns[0]).getData()[row_num],
static_cast<const ColumnVector<T2> &>(*columns[1]).getData()[row_num]);
static_cast<const ColVecT1 &>(*columns[0]).getData()[row_num],
static_cast<const ColVecT2 &>(*columns[1]).getData()[row_num]);
else
this->data(place).add(
static_cast<const ColumnVector<T1> &>(*columns[0]).getData()[row_num]);
static_cast<const ColVecT1 &>(*columns[0]).getData()[row_num]);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
@ -241,27 +340,46 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
const auto & data = this->data(place);
auto & dst = static_cast<ColumnVector<ResultType> &>(to).getData();
auto & dst = static_cast<ColVecResult &>(to).getData();
if constexpr (Kind == StatisticsFunctionKind::varPop) dst.push_back(data.template get<VarianceMode::Population, VariancePower::Original>());
else if constexpr (Kind == StatisticsFunctionKind::varSamp) dst.push_back(data.template get<VarianceMode::Sample, VariancePower::Original>());
else if constexpr (Kind == StatisticsFunctionKind::stddevPop) dst.push_back(data.template get<VarianceMode::Population, VariancePower::Sqrt>());
else if constexpr (Kind == StatisticsFunctionKind::stddevSamp) dst.push_back(data.template get<VarianceMode::Sample, VariancePower::Sqrt>());
else if constexpr (Kind == StatisticsFunctionKind::covarPop) dst.push_back(data.template get<VarianceMode::Population>());
else if constexpr (Kind == StatisticsFunctionKind::covarSamp) dst.push_back(data.template get<VarianceMode::Sample>());
else if constexpr (Kind == StatisticsFunctionKind::corr) dst.push_back(data.get());
if constexpr (IsDecimalNumber<T1>)
{
switch (StatFunc::kind)
{
case StatisticsFunctionKind::varPop: dst.push_back(data.getPopulation(src_scale * 2)); break;
case StatisticsFunctionKind::varSamp: dst.push_back(data.getSample(src_scale * 2)); break;
case StatisticsFunctionKind::stddevPop: dst.push_back(sqrt(data.getPopulation(src_scale * 2))); break;
case StatisticsFunctionKind::stddevSamp: dst.push_back(sqrt(data.getSample(src_scale * 2))); break;
}
}
else
{
switch (StatFunc::kind)
{
case StatisticsFunctionKind::varPop: dst.push_back(data.getPopulation()); break;
case StatisticsFunctionKind::varSamp: dst.push_back(data.getSample()); break;
case StatisticsFunctionKind::stddevPop: dst.push_back(sqrt(data.getPopulation())); break;
case StatisticsFunctionKind::stddevSamp: dst.push_back(sqrt(data.getSample())); break;
case StatisticsFunctionKind::covarPop: dst.push_back(data.getPopulation()); break;
case StatisticsFunctionKind::covarSamp: dst.push_back(data.getSample()); break;
case StatisticsFunctionKind::corr: dst.push_back(data.get()); break;
}
}
}
const char * getHeaderFilePath() const override { return __FILE__; }
private:
UInt32 src_scale;
};
template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::varPop>;
template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::varSamp>;
template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::stddevPop>;
template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::stddevSamp>;
template <typename T1, typename T2> using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple<T1, T2, CovarMoments<VarianceCalcType<T1, T2>>, StatisticsFunctionKind::covarPop>;
template <typename T1, typename T2> using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple<T1, T2, CovarMoments<VarianceCalcType<T1, T2>>, StatisticsFunctionKind::covarSamp>;
template <typename T1, typename T2> using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple<T1, T2, CorrMoments<VarianceCalcType<T1, T2>>, StatisticsFunctionKind::corr>;
template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varPop>>;
template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varSamp>>;
template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevPop>>;
template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevSamp>>;
template <typename T1, typename T2> using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::covarPop>>;
template <typename T1, typename T2> using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::covarSamp>>;
template <typename T1, typename T2> using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::corr>>;
}

View File

@ -21,11 +21,18 @@ AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string &
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
AggregateFunctionPtr res(createWithNumericType<FunctionTemplate>(*argument_types[0]));
AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (isDecimal(data_type))
{
res.reset(createWithDecimalType<FunctionTemplate>(*data_type));
}
else
res.reset(createWithNumericType<FunctionTemplate>(*data_type));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
@ -51,6 +58,7 @@ void registerAggregateFunctionsStatisticsSimple(AggregateFunctionFactory & facto
factory.registerFunction("varPop", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPopSimple>);
factory.registerFunction("stddevSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevSampSimple>);
factory.registerFunction("stddevPop", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevPopSimple>);
factory.registerFunction("covarSamp", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSampSimple>);
factory.registerFunction("covarPop", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPopSimple>);
factory.registerFunction("corr", createAggregateFunctionStatisticsBinary<AggregateFunctionCorrSimple>, AggregateFunctionFactory::CaseInsensitive);

View File

@ -26,3 +26,11 @@
[-50.0000,-40.0000,-30.0000,-20.0000,-10.0000,0.0000,10.0000,20.0000,30.0000,40.0000,50.0000]
[-16.66666666,-13.33333333,-10.00000000,-6.66666666,-3.33333333,0.00000000,3.33333333,6.66666666,10.00000000,13.33333333,16.66666666]
[-10.00000000,-8.00000000,-6.00000000,-4.00000000,-2.00000000,0.00000000,2.00000000,4.00000000,6.00000000,8.00000000,10.00000000]
850 94.44444438684269 34 Float64 Float64 Float64
850 94.4444443868427 34.00000000000001
858.5 95.38888883071111 34.34 Float64 Float64 Float64
858.5 95.38888883071112 34.34
29.154759474226502 9.718253155111915 5.830951894845301 Float64 Float64 Float64
29.154759474226502 9.718253155111915 5.830951894845301
29.300170647967224 9.766723546344041 5.860034129593445 Float64 Float64 Float64
29.300170647967224 9.766723546344041 5.860034129593445

View File

@ -54,6 +54,21 @@ SELECT quantilesExact(0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)(a)
SELECT quantilesExact(0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)(b) FROM test.decimal;
SELECT quantilesExact(0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)(c) FROM test.decimal;
SELECT varPop(a) AS va, varPop(b) AS vb, varPop(c) AS vc, toTypeName(va), toTypeName(vb), toTypeName(vc) FROM test.decimal;
SELECT varPop(toFloat64(a)), varPop(toFloat64(b)), varPop(toFloat64(c)) FROM test.decimal;
SELECT varSamp(a) AS va, varSamp(b) AS vb, varSamp(c) AS vc, toTypeName(va), toTypeName(vb), toTypeName(vc) FROM test.decimal;
SELECT varSamp(toFloat64(a)), varSamp(toFloat64(b)), varSamp(toFloat64(c)) FROM test.decimal;
SELECT stddevPop(a) AS da, stddevPop(b) AS db, stddevPop(c) AS dc, toTypeName(da), toTypeName(db), toTypeName(dc) FROM test.decimal;
SELECT stddevPop(toFloat64(a)), stddevPop(toFloat64(b)), stddevPop(toFloat64(c)) FROM test.decimal;
SELECT stddevSamp(a) AS da, stddevSamp(b) AS db, stddevSamp(c) AS dc, toTypeName(da), toTypeName(db), toTypeName(dc) FROM test.decimal;
SELECT stddevSamp(toFloat64(a)), stddevSamp(toFloat64(b)), stddevSamp(toFloat64(c)) FROM test.decimal;
SELECT covarPop(a, a), covarPop(b, b), covarPop(c, c) FROM test.decimal; -- { serverError 43 }
SELECT covarSamp(a, a), covarSamp(b, b), covarSamp(c, c) FROM test.decimal; -- { serverError 43 }
SELECT corr(a, a), corr(b, b), corr(c, c) FROM test.decimal; -- { serverError 43 }
SELECT 1 LIMIT 0;
-- TODO: sumMap
-- TODO: other quantile(s)
-- TODO: groupArray, groupArrayInsertAt, groupUniqArray