#pragma once #include #include #include #include #include #include #include #include #include /** This is simple, not numerically stable * implementations of variance/covariance/correlation functions. * * It is about two times faster than stable variants. * Numerical errors may occur during summation. * * This implementation is selected as default, * because "you don't pay for what you don't need" principle. * * For more sophisticated implementation, look at AggregateFunctionStatistics.h */ namespace DB { namespace ErrorCodes { extern const int DECIMAL_OVERFLOW; } template struct VarMoments { T m0{}; T m1{}; T m2{}; void add(T x) { ++m0; m1 += x; m2 += x * x; } void merge(const VarMoments & rhs) { m0 += rhs.m0; m1 += rhs.m1; m2 += rhs.m2; } void write(WriteBuffer & buf) const { writePODBinary(*this, buf); } void read(ReadBuffer & buf) { readPODBinary(*this, buf); } T NO_SANITIZE_UNDEFINED getPopulation() const { return (m2 - m1 * m1 / m0) / m0; } T NO_SANITIZE_UNDEFINED getSample() const { if (m0 == 0) return std::numeric_limits::quiet_NaN(); return (m2 - m1 * m1 / m0) / (m0 - 1); } }; template struct VarMomentsDecimal { using NativeType = typename T::NativeType; UInt64 m0{}; NativeType m1{}; NativeType m2{}; void add(NativeType x) { ++m0; m1 += x; NativeType tmp; /// scale' = 2 * scale if (common::mulOverflow(x, x, tmp) || common::addOverflow(m2, tmp, m2)) throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); } void merge(const VarMomentsDecimal & rhs) { m0 += rhs.m0; m1 += rhs.m1; if (common::addOverflow(m2, rhs.m2, m2)) throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); } void write(WriteBuffer & buf) const { writePODBinary(*this, buf); } void read(ReadBuffer & buf) { readPODBinary(*this, buf); } Float64 getPopulation(UInt32 scale) const { if (m0 == 0) return std::numeric_limits::infinity(); NativeType tmp; if (common::mulOverflow(m1, m1, tmp) || common::subOverflow(m2, NativeType(tmp / m0), tmp)) throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); return convertFromDecimal, DataTypeNumber>(tmp / m0, scale); } Float64 getSample(UInt32 scale) const { if (m0 == 0) return std::numeric_limits::quiet_NaN(); if (m0 == 1) return std::numeric_limits::infinity(); NativeType tmp; if (common::mulOverflow(m1, m1, tmp) || common::subOverflow(m2, NativeType(tmp / m0), tmp)) throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); return convertFromDecimal, DataTypeNumber>(tmp / (m0 - 1), scale); } }; template struct HighOrderMoments { T m[_level + 1]{}; void add(T x) { ++m[0]; m[1] += x; m[2] += x * x; if constexpr (_level >= 3) m[3] += x * x * x; if constexpr (_level >= 4) m[4] += x * x * x * x; } void merge(const HighOrderMoments & rhs) { m[0] += rhs.m[0]; m[1] += rhs.m[1]; m[2] += rhs.m[2]; if constexpr (_level >= 3) m[3] += rhs.m[3]; if constexpr (_level >= 4) m[4] += rhs.m[4]; } void write(WriteBuffer & buf) const { writePODBinary(*this, buf); } void read(ReadBuffer & buf) { readPODBinary(*this, buf); } T NO_SANITIZE_UNDEFINED getPopulation() const { return (m[2] - m[1] * m[1] / m[0]) / m[0]; } T NO_SANITIZE_UNDEFINED getSample() const { if (m[0] == 0) return std::numeric_limits::quiet_NaN(); return (m[2] - m[1] * m[1] / m[0]) / (m[0] - 1); } T NO_SANITIZE_UNDEFINED getMoment3() const { // to avoid accuracy problem if (m[0] == 1) return 0; return (m[3] - (3 * m[2] - 2 * m[1] * m[1] / m[0] ) * m[1] / m[0] ) / m[0]; } T NO_SANITIZE_UNDEFINED getMoment4() const { // to avoid accuracy problem if (m[0] == 1) return 0; return (m[4] - (4 * m[3] - (6 * m[2] - 3 * m[1] * m[1] / m[0] ) * m[1] / m[0] ) * m[1] / m[0] ) / m[0]; } }; /** Calculating high-order central moments References: https://en.wikipedia.org/wiki/Moment_(mathematics) https://en.wikipedia.org/wiki/Skewness https://en.wikipedia.org/wiki/Kurtosis */ template struct HighOrderMomentsDecimal { using NativeType = typename T::NativeType; UInt64 m0{}; NativeType m[_level]{}; NativeType & getM(size_t i) { return m[i - 1]; } const NativeType & getM(size_t i) const { return m[i - 1]; } void add(NativeType x) { ++m0; getM(1) += x; NativeType tmp; if (common::mulOverflow(x, x, tmp) || common::addOverflow(getM(2), tmp, getM(2))) throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); if constexpr (_level >= 3) if (common::mulOverflow(tmp, x, tmp) || common::addOverflow(getM(3), tmp, getM(3))) throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); if constexpr (_level >= 4) if (common::mulOverflow(tmp, x, tmp) || common::addOverflow(getM(4), tmp, getM(4))) throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); } void merge(const HighOrderMomentsDecimal & rhs) { m0 += rhs.m0; getM(1) += rhs.getM(1); if (common::addOverflow(getM(2), rhs.getM(2), getM(2))) throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); if constexpr (_level >= 3) if (common::addOverflow(getM(3), rhs.getM(3), getM(3))) throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); if constexpr (_level >= 4) if (common::addOverflow(getM(4), rhs.getM(4), getM(4))) throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); } void write(WriteBuffer & buf) const { writePODBinary(*this, buf); } void read(ReadBuffer & buf) { readPODBinary(*this, buf); } Float64 getPopulation(UInt32 scale) const { if (m0 == 0) return std::numeric_limits::infinity(); NativeType tmp; if (common::mulOverflow(getM(1), getM(1), tmp) || common::subOverflow(getM(2), NativeType(tmp / m0), tmp)) throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); return convertFromDecimal, DataTypeNumber>(tmp / m0, scale); } Float64 getSample(UInt32 scale) const { if (m0 == 0) return std::numeric_limits::quiet_NaN(); if (m0 == 1) return std::numeric_limits::infinity(); NativeType tmp; if (common::mulOverflow(getM(1), getM(1), tmp) || common::subOverflow(getM(2), NativeType(tmp / m0), tmp)) throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); return convertFromDecimal, DataTypeNumber>(tmp / (m0 - 1), scale); } Float64 getMoment3(UInt32 scale) const { if (m0 == 0) return std::numeric_limits::infinity(); NativeType tmp; if (common::mulOverflow(2 * getM(1), getM(1), tmp) || common::subOverflow(3 * getM(2), NativeType(tmp / m0), tmp) || common::mulOverflow(tmp, getM(1), tmp) || common::subOverflow(getM(3), NativeType(tmp / m0), tmp)) throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); return convertFromDecimal, DataTypeNumber>(tmp / m0, scale); } Float64 getMoment4(UInt32 scale) const { if (m0 == 0) return std::numeric_limits::infinity(); NativeType tmp; if (common::mulOverflow(3 * getM(1), getM(1), tmp) || common::subOverflow(6 * getM(2), NativeType(tmp / m0), tmp) || common::mulOverflow(tmp, getM(1), tmp) || common::subOverflow(4 * getM(3), NativeType(tmp / m0), tmp) || common::mulOverflow(tmp, getM(1), tmp) || common::subOverflow(getM(4), NativeType(tmp / m0), tmp)) throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); return convertFromDecimal, DataTypeNumber>(tmp / m0, scale); } }; template struct CovarMoments { T m0{}; T x1{}; T y1{}; T xy{}; void add(T x, T y) { ++m0; x1 += x; y1 += y; xy += x * y; } void merge(const CovarMoments & rhs) { m0 += rhs.m0; x1 += rhs.x1; y1 += rhs.y1; xy += rhs.xy; } void write(WriteBuffer & buf) const { writePODBinary(*this, buf); } void read(ReadBuffer & buf) { readPODBinary(*this, buf); } T NO_SANITIZE_UNDEFINED getPopulation() const { return (xy - x1 * y1 / m0) / m0; } T NO_SANITIZE_UNDEFINED getSample() const { if (m0 == 0) return std::numeric_limits::quiet_NaN(); return (xy - x1 * y1 / m0) / (m0 - 1); } }; template struct CorrMoments { T m0{}; T x1{}; T y1{}; T xy{}; T x2{}; T y2{}; void add(T x, T y) { ++m0; x1 += x; y1 += y; xy += x * y; x2 += x * x; y2 += y * y; } void merge(const CorrMoments & rhs) { m0 += rhs.m0; x1 += rhs.x1; y1 += rhs.y1; xy += rhs.xy; x2 += rhs.x2; y2 += rhs.y2; } void write(WriteBuffer & buf) const { writePODBinary(*this, buf); } void read(ReadBuffer & buf) { readPODBinary(*this, buf); } T NO_SANITIZE_UNDEFINED get() const { return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1)); } }; enum class StatisticsFunctionKind { varPop, varSamp, stddevPop, stddevSamp, skewPop, skewSamp, kurtPop, kurtSamp, covarPop, covarSamp, corr }; template struct StatFuncOneArg { using Type1 = T; using Type2 = T; using ResultType = std::conditional_t, Float32, Float64>; using Data = std::conditional_t< _level <= 2, std::conditional_t, VarMomentsDecimal, VarMoments>, std::conditional_t, HighOrderMomentsDecimal, HighOrderMoments> >; static constexpr StatisticsFunctionKind kind = _kind; static constexpr UInt32 num_args = 1; }; template struct StatFuncTwoArg { using Type1 = T1; using Type2 = T2; using ResultType = std::conditional_t && std::is_same_v, Float32, Float64>; using Data = std::conditional_t<_kind == StatisticsFunctionKind::corr, CorrMoments, CovarMoments>; static constexpr StatisticsFunctionKind kind = _kind; static constexpr UInt32 num_args = 2; }; template class AggregateFunctionVarianceSimple final : public IAggregateFunctionDataHelper> { public: using T1 = typename StatFunc::Type1; using T2 = typename StatFunc::Type2; using ColVecT1 = std::conditional_t, ColumnDecimal, ColumnVector>; using ColVecT2 = std::conditional_t, ColumnDecimal, ColumnVector>; using ResultType = typename StatFunc::ResultType; using ColVecResult = ColumnVector; AggregateFunctionVarianceSimple(const DataTypes & argument_types_) : IAggregateFunctionDataHelper>(argument_types_, {}) , src_scale(0) {} AggregateFunctionVarianceSimple(const IDataType & data_type, const DataTypes & argument_types_) : IAggregateFunctionDataHelper>(argument_types_, {}) , src_scale(getDecimalScale(data_type)) {} String getName() const override { if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop) return "varPop"; if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp) return "varSamp"; if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop) return "stddevPop"; if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp) return "stddevSamp"; if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop) return "skewPop"; if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp) return "skewSamp"; if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop) return "kurtPop"; if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp) return "kurtSamp"; if constexpr (StatFunc::kind == StatisticsFunctionKind::covarPop) return "covarPop"; if constexpr (StatFunc::kind == StatisticsFunctionKind::covarSamp) return "covarSamp"; if constexpr (StatFunc::kind == StatisticsFunctionKind::corr) return "corr"; } DataTypePtr getReturnType() const override { return std::make_shared>(); } void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override { if constexpr (StatFunc::num_args == 2) this->data(place).add( static_cast(*columns[0]).getData()[row_num], static_cast(*columns[1]).getData()[row_num]); else this->data(place).add( static_cast(*columns[0]).getData()[row_num]); } void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).merge(this->data(rhs)); } void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { this->data(place).write(buf); } void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override { this->data(place).read(buf); } void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override { const auto & data = this->data(place); auto & dst = static_cast(to).getData(); if constexpr (IsDecimalNumber) { if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop) dst.push_back(data.getPopulation(src_scale * 2)); if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp) dst.push_back(data.getSample(src_scale * 2)); if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop) dst.push_back(sqrt(data.getPopulation(src_scale * 2))); if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp) dst.push_back(sqrt(data.getSample(src_scale * 2))); if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop) dst.push_back(data.getMoment3(src_scale * 3) / pow(data.getPopulation(src_scale * 2), 1.5)); if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp) dst.push_back(data.getMoment3(src_scale * 3) / pow(data.getSample(src_scale * 2), 1.5)); if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop) dst.push_back(data.getMoment4(src_scale * 4) / pow(data.getPopulation(src_scale * 2), 2)); if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp) dst.push_back(data.getMoment4(src_scale * 4) / pow(data.getSample(src_scale * 2), 2)); } else { if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop) dst.push_back(data.getPopulation()); if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp) dst.push_back(data.getSample()); if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop) dst.push_back(sqrt(data.getPopulation())); if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp) dst.push_back(sqrt(data.getSample())); if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop) dst.push_back(data.getMoment3() / pow(data.getPopulation(), 1.5)); if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp) dst.push_back(data.getMoment3() / pow(data.getSample(), 1.5)); if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop) dst.push_back(data.getMoment4() / pow(data.getPopulation(), 2)); if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp) dst.push_back(data.getMoment4() / pow(data.getSample(), 2)); if constexpr (StatFunc::kind == StatisticsFunctionKind::covarPop) dst.push_back(data.getPopulation()); if constexpr (StatFunc::kind == StatisticsFunctionKind::covarSamp) dst.push_back(data.getSample()); if constexpr (StatFunc::kind == StatisticsFunctionKind::corr) dst.push_back(data.get()); } } const char * getHeaderFilePath() const override { return __FILE__; } private: UInt32 src_scale; }; template using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple>; template using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple>; template using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple>; template using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple>; template using AggregateFunctionSkewPopSimple = AggregateFunctionVarianceSimple>; template using AggregateFunctionSkewSampSimple = AggregateFunctionVarianceSimple>; template using AggregateFunctionKurtPopSimple = AggregateFunctionVarianceSimple>; template using AggregateFunctionKurtSampSimple = AggregateFunctionVarianceSimple>; template using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple>; template using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple>; template using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple>; }