var/stddev for decimal [CLICKHOUSE-3765]

This commit is contained in:
chertus 2018-09-13 21:36:47 +03:00
parent 59f8313b83
commit 7adf8d29cf
4 changed files with 201 additions and 52 deletions

View File

@ -2,13 +2,17 @@
#include <cmath> #include <cmath>
#include <common/arithmeticOverflow.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <AggregateFunctions/IAggregateFunction.h> #include <AggregateFunctions/IAggregateFunction.h>
#include <DataTypes/DataTypesNumber.h> #include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Columns/ColumnVector.h> #include <Columns/ColumnVector.h>
#include <Columns/ColumnDecimal.h>
/** This is simple, not numerically stable /** This is simple, not numerically stable
@ -26,17 +30,11 @@
namespace DB namespace DB
{ {
enum class VarianceMode namespace ErrorCodes
{ {
Population, extern const int LOGICAL_ERROR;
Sample extern const int DECIMAL_OVERFLOW;
}; }
enum class VariancePower
{
Original,
Sqrt
};
template <typename T> template <typename T>
@ -70,15 +68,74 @@ struct VarMoments
readPODBinary(*this, buf); readPODBinary(*this, buf);
} }
template <VarianceMode mode, VariancePower power> T getPopulation() const
T get() const
{ {
if (m0 == 0 && mode == VarianceMode::Sample) return (m2 - m1 * m1 / m0) / m0;
}
/// Sample variance: the centered second moment divided by (m0 - 1).
/// An empty state (m0 == 0) yields NaN.
T getSample() const
{
    return m0 == 0
        ? std::numeric_limits<T>::quiet_NaN()
        : (m2 - m1 * m1 / m0) / (m0 - 1);
}
/// Variance results come from getPopulation()/getSample(); get() is not a valid
/// operation for this moments type and always throws LOGICAL_ERROR.
/// It exists so generic code that mentions get() for every moments type still compiles.
T get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
};
template <typename T>
struct VarMomentsDecimal
{
using NativeType = typename T::NativeType;
UInt64 m0{};
NativeType m1{};
NativeType m2{};
/// Accumulates one value given in its native decimal representation.
///
/// Updates the count (m0), the sum (m1) and the sum of squares (m2).
/// All overflow checks run before any member is modified, so when
/// DECIMAL_OVERFLOW is thrown the state is left exactly as it was, and the
/// sum is never advanced with unchecked signed arithmetic.
void add(NativeType x)
{
    NativeType squared; /// scale' = 2 * scale
    NativeType new_m1;
    NativeType new_m2;

    if (common::mulOverflow(x, x, squared)
        || common::addOverflow(m1, x, new_m1)
        || common::addOverflow(m2, squared, new_m2))
        throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);

    ++m0;
    m1 = new_m1;
    m2 = new_m2;
}
void merge(const VarMomentsDecimal & rhs)
{
m0 += rhs.m0;
m1 += rhs.m1;
if (common::addOverflow(m2, rhs.m2, m2))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
}
/// Serializes the state by handing the whole object (m0, m1, m2) to writePODBinary.
void write(WriteBuffer & buf) const { writePODBinary(*this, buf); }
/// Restores the state by letting readPODBinary fill the whole object; the counterpart of write().
void read(ReadBuffer & buf) { readPODBinary(*this, buf); }
/// Population variance of the accumulated values, converted to Float64.
///
/// `scale` is the decimal scale used to interpret the result; m2 holds squared
/// values, so callers pass twice the scale of the source column.
/// Returns NaN for an empty state. Throws DECIMAL_OVERFLOW if an intermediate
/// product or difference does not fit NativeType.
Float64 getPopulation(UInt32 scale) const
{
    /// With no values the integer divisions below would divide by zero.
    /// NaN is also what the sample variant reports for an empty state.
    if (m0 == 0)
        return std::numeric_limits<Float64>::quiet_NaN();

    NativeType tmp;
    if (common::mulOverflow(m1, m1, tmp) ||
        common::subOverflow(m2, NativeType(tmp / m0), tmp))
        throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
    return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
}
Float64 getSample(UInt32 scale) const
{
if (m0 == 0)
return std::numeric_limits<T>::quiet_NaN(); return std::numeric_limits<T>::quiet_NaN();
T res = (m2 - m1 * m1 / m0) / (m0 - (mode == VarianceMode::Sample)); NativeType tmp;
return power == VariancePower::Original ? res : sqrt(res); if (common::mulOverflow(m1, m1, tmp) ||
common::subOverflow(m2, NativeType(tmp/m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / (m0 - 1), scale);
} }
/// Results come from getPopulation(scale)/getSample(scale); get() is not a valid
/// operation for decimal variance moments and always throws LOGICAL_ERROR.
Float64 get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
}; };
template <typename T> template <typename T>
@ -115,14 +172,19 @@ struct CovarMoments
readPODBinary(*this, buf); readPODBinary(*this, buf);
} }
template <VarianceMode mode> T getPopulation() const
T get() const
{ {
if (m0 == 0 && mode == VarianceMode::Sample) return (xy - x1 * y1 / m0) / m0;
return std::numeric_limits<T>::quiet_NaN();
return (xy - x1 * y1 / m0) / (m0 - (mode == VarianceMode::Sample));
} }
/// Sample covariance: the centered cross moment divided by (m0 - 1).
/// An empty state (m0 == 0) yields NaN.
T getSample() const
{
    return m0 == 0
        ? std::numeric_limits<T>::quiet_NaN()
        : (xy - x1 * y1 / m0) / (m0 - 1);
}
/// Covariance results come from getPopulation()/getSample(); get() is not a valid
/// operation for this moments type and always throws LOGICAL_ERROR.
T get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
}; };
template <typename T> template <typename T>
@ -169,6 +231,9 @@ struct CorrMoments
{ {
return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1)); return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1));
} }
/// Correlation has no population variant; always throws LOGICAL_ERROR.
/// Present so generic code that mentions getPopulation() for every moments type still compiles.
T getPopulation() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
/// Correlation has no sample variant; always throws LOGICAL_ERROR.
/// Present so generic code that mentions getSample() for every moments type still compiles.
T getSample() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
}; };
@ -181,20 +246,54 @@ enum class StatisticsFunctionKind
}; };
template <typename T1, typename T2> template <typename T, StatisticsFunctionKind _kind>
using VarianceCalcType = std::conditional_t<std::is_same_v<T1, Float32> && std::is_same_v<T2, Float32>, Float32, Float64>; struct StatFuncOneArg
template <typename T1, typename T2, typename Data, StatisticsFunctionKind Kind>
class AggregateFunctionVarianceSimple final
: public IAggregateFunctionDataHelper<Data, AggregateFunctionVarianceSimple<T1, T2, Data, Kind>>
{ {
using ResultType = VarianceCalcType<T1, T2>; using Type1 = T;
using Type2 = T;
using ResultType = std::conditional_t<std::is_same_v<T, Float32>, Float32, Float64>;
using Data = std::conditional_t<IsDecimalNumber<T>, VarMomentsDecimal<Decimal128>, VarMoments<ResultType>>;
static constexpr StatisticsFunctionKind kind = _kind;
static constexpr UInt32 num_args = 1;
};
template <typename T1, typename T2, StatisticsFunctionKind _kind>
struct StatFuncTwoArg
{
using Type1 = T1;
using Type2 = T2;
using ResultType = std::conditional_t<std::is_same_v<T1, T2> && std::is_same_v<T1, Float32>, Float32, Float64>;
using Data = std::conditional_t<_kind == StatisticsFunctionKind::corr, CorrMoments<ResultType>, CovarMoments<ResultType>>;
static constexpr StatisticsFunctionKind kind = _kind;
static constexpr UInt32 num_args = 2;
};
template <typename StatFunc>
class AggregateFunctionVarianceSimple final
: public IAggregateFunctionDataHelper<typename StatFunc::Data, AggregateFunctionVarianceSimple<StatFunc>>
{
public: public:
using T1 = typename StatFunc::Type1;
using T2 = typename StatFunc::Type2;
using ColVecT1 = std::conditional_t<IsDecimalNumber<T1>, ColumnDecimal<T1>, ColumnVector<T1>>;
using ColVecT2 = std::conditional_t<IsDecimalNumber<T2>, ColumnDecimal<T2>, ColumnVector<T2>>;
using ResultType = typename StatFunc::ResultType;
using ColVecResult = ColumnVector<ResultType>;
/// Default constructor: no argument type is supplied, so the source scale is set to 0.
AggregateFunctionVarianceSimple()
: src_scale(0)
{}
/// Constructs the function for an argument of type `data_type`, storing the scale
/// reported by getDecimalScale(data_type) as the source scale.
AggregateFunctionVarianceSimple(const IDataType & data_type)
: src_scale(getDecimalScale(data_type))
{}
String getName() const override String getName() const override
{ {
switch (Kind) switch (StatFunc::kind)
{ {
case StatisticsFunctionKind::varPop: return "varPop"; case StatisticsFunctionKind::varPop: return "varPop";
case StatisticsFunctionKind::varSamp: return "varSamp"; case StatisticsFunctionKind::varSamp: return "varSamp";
@ -214,13 +313,13 @@ public:
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{ {
if constexpr (Kind == StatisticsFunctionKind::covarPop || Kind == StatisticsFunctionKind::covarSamp || Kind == StatisticsFunctionKind::corr) if constexpr (StatFunc::num_args == 2)
this->data(place).add( this->data(place).add(
static_cast<const ColumnVector<T1> &>(*columns[0]).getData()[row_num], static_cast<const ColVecT1 &>(*columns[0]).getData()[row_num],
static_cast<const ColumnVector<T2> &>(*columns[1]).getData()[row_num]); static_cast<const ColVecT2 &>(*columns[1]).getData()[row_num]);
else else
this->data(place).add( this->data(place).add(
static_cast<const ColumnVector<T1> &>(*columns[0]).getData()[row_num]); static_cast<const ColVecT1 &>(*columns[0]).getData()[row_num]);
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
@ -241,27 +340,46 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
const auto & data = this->data(place); const auto & data = this->data(place);
auto & dst = static_cast<ColumnVector<ResultType> &>(to).getData(); auto & dst = static_cast<ColVecResult &>(to).getData();
if constexpr (Kind == StatisticsFunctionKind::varPop) dst.push_back(data.template get<VarianceMode::Population, VariancePower::Original>()); if constexpr (IsDecimalNumber<T1>)
else if constexpr (Kind == StatisticsFunctionKind::varSamp) dst.push_back(data.template get<VarianceMode::Sample, VariancePower::Original>()); {
else if constexpr (Kind == StatisticsFunctionKind::stddevPop) dst.push_back(data.template get<VarianceMode::Population, VariancePower::Sqrt>()); switch (StatFunc::kind)
else if constexpr (Kind == StatisticsFunctionKind::stddevSamp) dst.push_back(data.template get<VarianceMode::Sample, VariancePower::Sqrt>()); {
else if constexpr (Kind == StatisticsFunctionKind::covarPop) dst.push_back(data.template get<VarianceMode::Population>()); case StatisticsFunctionKind::varPop: dst.push_back(data.getPopulation(src_scale * 2)); break;
else if constexpr (Kind == StatisticsFunctionKind::covarSamp) dst.push_back(data.template get<VarianceMode::Sample>()); case StatisticsFunctionKind::varSamp: dst.push_back(data.getSample(src_scale * 2)); break;
else if constexpr (Kind == StatisticsFunctionKind::corr) dst.push_back(data.get()); case StatisticsFunctionKind::stddevPop: dst.push_back(sqrt(data.getPopulation(src_scale * 2))); break;
case StatisticsFunctionKind::stddevSamp: dst.push_back(sqrt(data.getSample(src_scale * 2))); break;
}
}
else
{
switch (StatFunc::kind)
{
case StatisticsFunctionKind::varPop: dst.push_back(data.getPopulation()); break;
case StatisticsFunctionKind::varSamp: dst.push_back(data.getSample()); break;
case StatisticsFunctionKind::stddevPop: dst.push_back(sqrt(data.getPopulation())); break;
case StatisticsFunctionKind::stddevSamp: dst.push_back(sqrt(data.getSample())); break;
case StatisticsFunctionKind::covarPop: dst.push_back(data.getPopulation()); break;
case StatisticsFunctionKind::covarSamp: dst.push_back(data.getSample()); break;
case StatisticsFunctionKind::corr: dst.push_back(data.get()); break;
}
}
} }
const char * getHeaderFilePath() const override { return __FILE__; } const char * getHeaderFilePath() const override { return __FILE__; }
private:
UInt32 src_scale;
}; };
template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::varPop>; template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varPop>>;
template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::varSamp>; template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varSamp>>;
template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::stddevPop>; template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevPop>>;
template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<T, T, VarMoments<VarianceCalcType<T, T>>, StatisticsFunctionKind::stddevSamp>; template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevSamp>>;
template <typename T1, typename T2> using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple<T1, T2, CovarMoments<VarianceCalcType<T1, T2>>, StatisticsFunctionKind::covarPop>; template <typename T1, typename T2> using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::covarPop>>;
template <typename T1, typename T2> using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple<T1, T2, CovarMoments<VarianceCalcType<T1, T2>>, StatisticsFunctionKind::covarSamp>; template <typename T1, typename T2> using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::covarSamp>>;
template <typename T1, typename T2> using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple<T1, T2, CorrMoments<VarianceCalcType<T1, T2>>, StatisticsFunctionKind::corr>; template <typename T1, typename T2> using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::corr>>;
} }

View File

@ -21,11 +21,18 @@ AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string &
assertNoParameters(name, parameters); assertNoParameters(name, parameters);
assertUnary(name, argument_types); assertUnary(name, argument_types);
AggregateFunctionPtr res(createWithNumericType<FunctionTemplate>(*argument_types[0])); AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (isDecimal(data_type))
{
res.reset(createWithDecimalType<FunctionTemplate>(*data_type));
}
else
res.reset(createWithNumericType<FunctionTemplate>(*data_type));
if (!res) if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res; return res;
} }
@ -51,6 +58,7 @@ void registerAggregateFunctionsStatisticsSimple(AggregateFunctionFactory & facto
factory.registerFunction("varPop", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPopSimple>); factory.registerFunction("varPop", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPopSimple>);
factory.registerFunction("stddevSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevSampSimple>); factory.registerFunction("stddevSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevSampSimple>);
factory.registerFunction("stddevPop", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevPopSimple>); factory.registerFunction("stddevPop", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevPopSimple>);
factory.registerFunction("covarSamp", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSampSimple>); factory.registerFunction("covarSamp", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSampSimple>);
factory.registerFunction("covarPop", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPopSimple>); factory.registerFunction("covarPop", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPopSimple>);
factory.registerFunction("corr", createAggregateFunctionStatisticsBinary<AggregateFunctionCorrSimple>, AggregateFunctionFactory::CaseInsensitive); factory.registerFunction("corr", createAggregateFunctionStatisticsBinary<AggregateFunctionCorrSimple>, AggregateFunctionFactory::CaseInsensitive);

View File

@ -26,3 +26,11 @@
[-50.0000,-40.0000,-30.0000,-20.0000,-10.0000,0.0000,10.0000,20.0000,30.0000,40.0000,50.0000] [-50.0000,-40.0000,-30.0000,-20.0000,-10.0000,0.0000,10.0000,20.0000,30.0000,40.0000,50.0000]
[-16.66666666,-13.33333333,-10.00000000,-6.66666666,-3.33333333,0.00000000,3.33333333,6.66666666,10.00000000,13.33333333,16.66666666] [-16.66666666,-13.33333333,-10.00000000,-6.66666666,-3.33333333,0.00000000,3.33333333,6.66666666,10.00000000,13.33333333,16.66666666]
[-10.00000000,-8.00000000,-6.00000000,-4.00000000,-2.00000000,0.00000000,2.00000000,4.00000000,6.00000000,8.00000000,10.00000000] [-10.00000000,-8.00000000,-6.00000000,-4.00000000,-2.00000000,0.00000000,2.00000000,4.00000000,6.00000000,8.00000000,10.00000000]
850 94.44444438684269 34 Float64 Float64 Float64
850 94.4444443868427 34.00000000000001
858.5 95.38888883071111 34.34 Float64 Float64 Float64
858.5 95.38888883071112 34.34
29.154759474226502 9.718253155111915 5.830951894845301 Float64 Float64 Float64
29.154759474226502 9.718253155111915 5.830951894845301
29.300170647967224 9.766723546344041 5.860034129593445 Float64 Float64 Float64
29.300170647967224 9.766723546344041 5.860034129593445

View File

@ -54,6 +54,21 @@ SELECT quantilesExact(0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)(a)
SELECT quantilesExact(0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)(b) FROM test.decimal; SELECT quantilesExact(0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)(b) FROM test.decimal;
SELECT quantilesExact(0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)(c) FROM test.decimal; SELECT quantilesExact(0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)(c) FROM test.decimal;
SELECT varPop(a) AS va, varPop(b) AS vb, varPop(c) AS vc, toTypeName(va), toTypeName(vb), toTypeName(vc) FROM test.decimal;
SELECT varPop(toFloat64(a)), varPop(toFloat64(b)), varPop(toFloat64(c)) FROM test.decimal;
SELECT varSamp(a) AS va, varSamp(b) AS vb, varSamp(c) AS vc, toTypeName(va), toTypeName(vb), toTypeName(vc) FROM test.decimal;
SELECT varSamp(toFloat64(a)), varSamp(toFloat64(b)), varSamp(toFloat64(c)) FROM test.decimal;
SELECT stddevPop(a) AS da, stddevPop(b) AS db, stddevPop(c) AS dc, toTypeName(da), toTypeName(db), toTypeName(dc) FROM test.decimal;
SELECT stddevPop(toFloat64(a)), stddevPop(toFloat64(b)), stddevPop(toFloat64(c)) FROM test.decimal;
SELECT stddevSamp(a) AS da, stddevSamp(b) AS db, stddevSamp(c) AS dc, toTypeName(da), toTypeName(db), toTypeName(dc) FROM test.decimal;
SELECT stddevSamp(toFloat64(a)), stddevSamp(toFloat64(b)), stddevSamp(toFloat64(c)) FROM test.decimal;
SELECT covarPop(a, a), covarPop(b, b), covarPop(c, c) FROM test.decimal; -- { serverError 43 }
SELECT covarSamp(a, a), covarSamp(b, b), covarSamp(c, c) FROM test.decimal; -- { serverError 43 }
SELECT corr(a, a), corr(b, b), corr(c, c) FROM test.decimal; -- { serverError 43 }
SELECT 1 LIMIT 0;
-- TODO: sumMap -- TODO: sumMap
-- TODO: other quantile(s) -- TODO: other quantile(s)
-- TODO: groupArray, groupArrayInsertAt, groupUniqArray -- TODO: groupArray, groupArrayInsertAt, groupUniqArray