From 69a7761812db0c34e261409ee0216b59e0b75639 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sat, 23 Dec 2017 01:59:45 +0300 Subject: [PATCH] Added simple, non-numerically stable variants of statistical functions; use it by default; old functions are accessible under -Stable suffix, like varPopStable [#CLICKHOUSE-2]. --- .../AggregateFunctionStatisticsSimple.h | 34 +++++++++++-------- .../AggregateFunctionsStatisticsSimple.cpp | 20 +++++------ ...1_aggregate_functions_statistics.reference | 8 ++--- 3 files changed, 34 insertions(+), 28 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h b/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h index 978c32f7b3e..598d9bfea07 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h @@ -168,10 +168,16 @@ enum class StatisticsFunctionKind }; -template +template +using VarianceCalcType = std::conditional_t && std::is_same_v, Float32, Float64>; + + +template class AggregateFunctionVarianceSimple final - : public IAggregateFunctionDataHelper> + : public IAggregateFunctionDataHelper> { + using ResultType = VarianceCalcType; + public: String getName() const override { @@ -189,18 +195,18 @@ public: DataTypePtr getReturnType() const override { - return std::make_shared>(); + return std::make_shared>(); } void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override { if constexpr (Kind == StatisticsFunctionKind::covarPop || Kind == StatisticsFunctionKind::covarSamp || Kind == StatisticsFunctionKind::corr) this->data(place).add( - static_cast &>(*columns[0]).getData()[row_num], - static_cast &>(*columns[1]).getData()[row_num]); + static_cast &>(*columns[0]).getData()[row_num], + static_cast &>(*columns[1]).getData()[row_num]); else this->data(place).add( - static_cast &>(*columns[0]).getData()[row_num]); + static_cast &>(*columns[0]).getData()[row_num]); } void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override @@ -221,7 +227,7 @@ public: void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override { const auto & data = this->data(place); - auto & dst = static_cast &>(to).getData(); + auto & dst = static_cast &>(to).getData(); if constexpr (Kind == StatisticsFunctionKind::varPop) dst.push_back(data.template get()); else if constexpr (Kind == StatisticsFunctionKind::varSamp) dst.push_back(data.template get()); @@ -236,12 +242,12 @@ public: }; -template using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple, StatisticsFunctionKind::varPop>; -template using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple, StatisticsFunctionKind::varSamp>; -template using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple, StatisticsFunctionKind::stddevPop>; -template using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple, StatisticsFunctionKind::stddevSamp>; -template using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple, StatisticsFunctionKind::covarPop>; -template using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple, StatisticsFunctionKind::covarSamp>; -template using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple, StatisticsFunctionKind::corr>; +template using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple>, StatisticsFunctionKind::varPop>; +template using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple>, StatisticsFunctionKind::varSamp>; +template using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple>, StatisticsFunctionKind::stddevPop>; +template using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple>, StatisticsFunctionKind::stddevSamp>; +template using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple>, StatisticsFunctionKind::covarPop>; +template using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple>, StatisticsFunctionKind::covarSamp>; +template using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple>, StatisticsFunctionKind::corr>; } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionsStatisticsSimple.cpp b/dbms/src/AggregateFunctions/AggregateFunctionsStatisticsSimple.cpp index 6d46b0a2275..089ea59cd79 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionsStatisticsSimple.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionsStatisticsSimple.cpp @@ -21,26 +21,26 @@ AggregateFunctionPtr createAggregateFunctionStatisticsUnary(const std::string & assertNoParameters(name, parameters); assertUnary(name, argument_types); - if (typeid_cast(argument_types[0].get())) return std::make_shared>(); - if (typeid_cast(argument_types[0].get())) return std::make_shared>(); + AggregateFunctionPtr res(createWithNumericType(*argument_types[0])); - throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if (!res) + throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; } -template