#pragma once

// NOTE(review): the include targets are not visible in this view of the file.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

/** These are simple, not numerically stable
  * implementations of the variance/covariance/correlation functions.
  *
  * They are about two times faster than the stable variants.
  * Numerical errors may occur during summation.
  *
  * This implementation is selected as the default
  * because of the "you don't pay for what you don't need" principle.
  *
  * For a more sophisticated implementation, look at AggregateFunctionStatistics.h
  */

namespace DB
{

/// Selects which statistic a given instantiation reports.
/// getName() and insertResultInto() below branch on this value with `if constexpr`.
enum class StatisticsFunctionKind
{
    varPop, varSamp,
    stddevPop, stddevSamp,
    skewPop, skewSamp,
    kurtPop, kurtSamp,
    covarPop, covarSamp,
    corr
};

/** Traits for the statistics that take one argument (num_args == 1).
  * Type1 and Type2 are both aliases of T; ResultType is either Float32 or Float64;
  * Data (the aggregation state type) is either VarMomentsDecimal or VarMoments.
  *
  * NOTE(review): the template parameter list and the conditions/arguments of
  * std::conditional_t are not visible in this view - confirm against the full file.
  */
template
struct StatFuncOneArg
{
    using Type1 = T;
    using Type2 = T;
    using ResultType = std::conditional_t, Float32, Float64>;
    using Data = std::conditional_t, VarMomentsDecimal, VarMoments>;

    static constexpr StatisticsFunctionKind kind = _kind;
    static constexpr UInt32 num_args = 1;
};

/** Traits for the statistics that take two arguments (num_args == 2).
  * Data is CorrMoments when _kind is StatisticsFunctionKind::corr and CovarMoments otherwise.
  *
  * NOTE(review): the template parameter list and some template arguments are not
  * visible in this view - confirm against the full file.
  */
template
struct StatFuncTwoArg
{
    using Type1 = T1;
    using Type2 = T2;
    using ResultType = std::conditional_t && std::is_same_v, Float32, Float64>;
    using Data = std::conditional_t<_kind == StatisticsFunctionKind::corr, CorrMoments, CovarMoments>;

    static constexpr StatisticsFunctionKind kind = _kind;
    static constexpr UInt32 num_args = 2;
};

/** Aggregate function parameterized by a StatFunc traits type (see StatFuncOneArg / StatFuncTwoArg).
  *
  * The traits supply the argument types (Type1, Type2), the result type, the statistic kind
  * and the number of arguments. The per-group state is reached through this->data(place);
  * add/merge/serialize/deserialize forward to that state, and insertResultInto() turns the
  * state into the final value according to StatFunc::kind.
  *
  * NOTE(review): the template parameter list and the base class arguments are not
  * visible in this view.
  */
template
class AggregateFunctionVarianceSimple final
    : public IAggregateFunctionDataHelper>
{
public:
    using T1 = typename StatFunc::Type1;
    using T2 = typename StatFunc::Type2;
    /// Column types for the first and second argument: either ColumnDecimal or ColumnVector.
    using ColVecT1 = std::conditional_t, ColumnDecimal, ColumnVector>;
    using ColVecT2 = std::conditional_t, ColumnDecimal, ColumnVector>;
    using ResultType = typename StatFunc::ResultType;
    using ColVecResult = ColumnVector;

    /// Constructor that leaves the decimal scale at 0.
    AggregateFunctionVarianceSimple(const DataTypes & argument_types_)
        : IAggregateFunctionDataHelper>(argument_types_, {})
        , src_scale(0)
    {}

    /// Constructor that takes the scale from `data_type` via getDecimalScale().
    /// The stored scale is only read in the IsDecimalNumber branch of insertResultInto().
    AggregateFunctionVarianceSimple(const IDataType & data_type, const DataTypes & argument_types_)
        : IAggregateFunctionDataHelper>(argument_types_, {})
        , src_scale(getDecimalScale(data_type))
    {}

    /// Returns the function name that corresponds to StatFunc::kind.
    String getName() const override
    {
        if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
            return "varPop";
        if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
            return "varSamp";
        if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
            return "stddevPop";
        if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
            return "stddevSamp";
        if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
            return "skewPop";
        if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
            return "skewSamp";
        if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
            return "kurtPop";
        if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
            return "kurtSamp";
        if constexpr (StatFunc::kind == StatisticsFunctionKind::covarPop)
            return "covarPop";
        if constexpr (StatFunc::kind == StatisticsFunctionKind::covarSamp)
            return "covarSamp";
        if constexpr (StatFunc::kind == StatisticsFunctionKind::corr)
            return "corr";
        /// Every enumerator is handled above, so control is not expected to get here.
        __builtin_unreachable();
    }

    /// Returns the data type of the result.
    /// NOTE(review): the type argument of std::make_shared is not visible in this view.
    DataTypePtr getReturnType() const override
    {
        return std::make_shared>();
    }

    /// Feeds the value(s) at `row_num` into the state at `place`.
    /// The Arena argument is unnamed and unused.
    void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
    {
        /// Two-argument statistics read one value from columns[0] and one from columns[1].
        if constexpr (StatFunc::num_args == 2)
            this->data(place).add(
                static_cast(static_cast(*columns[0]).getData()[row_num]),
                static_cast(static_cast(*columns[1]).getData()[row_num]));
        else
        {
            /// For one specific argument type the element is not cast directly:
            /// its `.value` member is cast instead.
            /// NOTE(review): the types compared by std::is_same_v are not visible in this view.
            if constexpr (std::is_same_v)
            {
                this->data(place).add(static_cast(
                    static_cast(*columns[0]).getData()[row_num].value
                ));
            }
            else
                this->data(place).add(
                    static_cast(static_cast(*columns[0]).getData()[row_num]));
        }
    }

    /// Merges the state at `rhs` into the state at `place`.
    void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
    {
        this->data(place).merge(this->data(rhs));
    }

    /// Writes the state at `place` to `buf`.
    void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
    {
        this->data(place).write(buf);
    }

    /// Reads the state at `place` from `buf`.
    void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
    {
        this->data(place).read(buf);
    }

    /** Computes the final value from the state at `place` and appends it to the data of `to`.
      *
      * Only the branch matching StatFunc::kind is compiled in.
      *  - var*:    the population / sample variance reported by the state;
      *  - stddev*: sqrt of the corresponding variance;
      *  - skew*:   getMoment3 divided by variance to the power 1.5;
      *  - kurt*:   getMoment4 divided by variance squared;
      *  - covar*, corr: taken directly from the state (only in the non-decimal branch).
      * For skew* and kurt*, NaN is appended when the variance is not greater than zero.
      * The *Samp variants of skew/kurt differ from *Pop only in which variance is used
      * as the denominator; the numerator is the same getMoment3 / getMoment4 call.
      */
    void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override
    {
        const auto & data = this->data(place);
        auto & dst = static_cast(to).getData();

        if constexpr (IsDecimalNumber)
        {
            /// The state getters take a scale here: src_scale * 2 for the variance,
            /// * 3 for the third moment and * 4 for the fourth moment.
            /// Presumably the k-th power of a value with scale s has scale k * s -
            /// TODO confirm against the VarMomentsDecimal implementation.
            if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
                dst.push_back(data.getPopulation(src_scale * 2));
            if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
                dst.push_back(data.getSample(src_scale * 2));
            if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
                dst.push_back(sqrt(data.getPopulation(src_scale * 2)));
            if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
                dst.push_back(sqrt(data.getSample(src_scale * 2)));
            if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
            {
                Float64 var_value = data.getPopulation(src_scale * 2);

                /// Guard against division by zero (and a negative base for pow).
                if (var_value > 0)
                    dst.push_back(data.getMoment3(src_scale * 3) / pow(var_value, 1.5));
                else
                    dst.push_back(std::numeric_limits::quiet_NaN());
            }
            if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
            {
                Float64 var_value = data.getSample(src_scale * 2);

                if (var_value > 0)
                    dst.push_back(data.getMoment3(src_scale * 3) / pow(var_value, 1.5));
                else
                    dst.push_back(std::numeric_limits::quiet_NaN());
            }
            if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
            {
                Float64 var_value = data.getPopulation(src_scale * 2);

                if (var_value > 0)
                    dst.push_back(data.getMoment4(src_scale * 4) / pow(var_value, 2));
                else
                    dst.push_back(std::numeric_limits::quiet_NaN());
            }
            if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
            {
                Float64 var_value = data.getSample(src_scale * 2);

                if (var_value > 0)
                    dst.push_back(data.getMoment4(src_scale * 4) / pow(var_value, 2));
                else
                    dst.push_back(std::numeric_limits::quiet_NaN());
            }
        }
        else
        {
            /// Same formulas as above, but the getters take no scale
            /// and the intermediate variance is held in ResultType.
            if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
                dst.push_back(data.getPopulation());
            if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
                dst.push_back(data.getSample());
            if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
                dst.push_back(sqrt(data.getPopulation()));
            if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
                dst.push_back(sqrt(data.getSample()));
            if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
            {
                ResultType var_value = data.getPopulation();

                if (var_value > 0)
                    dst.push_back(data.getMoment3() / pow(var_value, 1.5));
                else
                    dst.push_back(std::numeric_limits::quiet_NaN());
            }
            if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
            {
                ResultType var_value = data.getSample();

                if (var_value > 0)
                    dst.push_back(data.getMoment3() / pow(var_value, 1.5));
                else
                    dst.push_back(std::numeric_limits::quiet_NaN());
            }
            if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
            {
                ResultType var_value = data.getPopulation();

                if (var_value > 0)
                    dst.push_back(data.getMoment4() / pow(var_value, 2));
                else
                    dst.push_back(std::numeric_limits::quiet_NaN());
            }
            if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
            {
                ResultType var_value = data.getSample();

                if (var_value > 0)
                    dst.push_back(data.getMoment4() / pow(var_value, 2));
                else
                    dst.push_back(std::numeric_limits::quiet_NaN());
            }
            /// Two-argument statistics are handled only in this (non-decimal) branch.
            if constexpr (StatFunc::kind == StatisticsFunctionKind::covarPop)
                dst.push_back(data.getPopulation());
            if constexpr (StatFunc::kind == StatisticsFunctionKind::covarSamp)
                dst.push_back(data.getSample());
            if constexpr (StatFunc::kind == StatisticsFunctionKind::corr)
                dst.push_back(data.get());
        }
    }

private:
    /// Scale passed (multiplied by 2, 3 or 4) to the state getters in the
    /// IsDecimalNumber branch of insertResultInto(); 0 when constructed without a data type.
    UInt32 src_scale;
};

/// Convenience aliases: one per statistic, each naming an instantiation of
/// AggregateFunctionVarianceSimple.
/// NOTE(review): the template parameter lists and the traits arguments
/// are not visible in this view.
template using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple>;
template using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple>;
template using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple>;
template using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple>;
template using AggregateFunctionSkewPopSimple = AggregateFunctionVarianceSimple>;
template using AggregateFunctionSkewSampSimple = AggregateFunctionVarianceSimple>;
template using AggregateFunctionKurtPopSimple = AggregateFunctionVarianceSimple>;
template using AggregateFunctionKurtSampSimple = AggregateFunctionVarianceSimple>;
template using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple>;
template using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple>;
template using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple>;

}