diff --git a/dbms/include/DB/AggregateFunctions/AggregateFunctionsStatistics.h b/dbms/include/DB/AggregateFunctions/AggregateFunctionsStatistics.h new file mode 100644 index 00000000000..2ed3b05deea --- /dev/null +++ b/dbms/include/DB/AggregateFunctions/AggregateFunctionsStatistics.h @@ -0,0 +1,426 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace DB +{ + +/** Статистические аггрегатные функции: + * varSamp - выборочная дисперсия + * stddevSamp - среднее выборочное квадратичное отклонение + * varPop - дисперсия + * stddevPop - среднее квадратичное отклонение + * covarSamp - выборочная ковариация + * covarPop - ковариация + * corr - корреляция + */ + +/** Параллельный и инкрементальный алгоритм для вычисления дисперсии. + * Источник: "Updating formulae and a pairwise algorithm for computing sample variances" + * (Chan et al., Stanford University, 12.1979) + */ +template +class AggregateFunctionVarianceData +{ +public: + AggregateFunctionVarianceData() = default; + + void update(const IColumn & column, size_t row_num) + { + T received = static_cast &>(column).getData()[row_num]; + Float64 val = static_cast(received); + Float64 delta = val - mean; + + ++count; + mean += delta / count; + m2 += delta * (val - mean); + } + + void mergeWith(const AggregateFunctionVarianceData & source) + { + UInt64 total_count = count + source.count; + if (total_count == 0) + return; + + Float64 factor = static_cast(count * source.count) / total_count; + Float64 delta = mean - source.mean; + + auto res = std::minmax(count, source.count); + if (((1 - static_cast(res.first) / res.second) < 0.001) && (res.first > 10000)) + { + /// Эта формула более стабильная, когда размеры обоих источников велики и сравнимы. + mean = (source.count * source.mean + count * mean) / total_count; + } + else + mean = source.mean + delta * (static_cast(count) / total_count); + + m2 += source.m2 + delta * delta * factor; + count = total_count; + } + + void serialize(WriteBuffer & buf) const + { + writeVarUInt(count, buf); + writeBinary(mean, buf); + writeBinary(m2, buf); + } + + void deserialize(ReadBuffer & buf) + { + readVarUInt(count, buf); + readBinary(mean, buf); + readBinary(m2, buf); + } + + void publish(IColumn & to) const + { + static_cast(to).getData().push_back(Op::apply(m2, count)); + } + +private: + UInt64 count = 0; + Float64 mean = 0.0; + Float64 m2 = 0.0; +}; + +/** Основной код для реализации функций varSamp, stddevSamp, varPop, stddevPop. + */ +template +class AggregateFunctionVariance final : public IUnaryAggregateFunction, AggregateFunctionVariance > +{ +public: + String getName() const override { return Op::name; } + + DataTypePtr getReturnType() const override + { + return new DataTypeFloat64; + } + + void setArgument(const DataTypePtr & argument) override + { + if (!argument->behavesAsNumber()) + throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + + void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const + { + this->data(place).update(column, row_num); + } + + void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override + { + this->data(place).mergeWith(this->data(rhs)); + } + + void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override + { + this->data(place).serialize(buf); + } + + void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override + { + AggregateFunctionVarianceData source; + source.deserialize(buf); + + this->data(place).mergeWith(source); + } + + void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override + { + this->data(place).publish(to); + } +}; + +namespace +{ + +/** Реализации функции varSamp. + */ +struct VarSampImpl +{ + static constexpr auto name = "varSamp"; + + static inline Float64 apply(Float64 m2, UInt64 count) + { + if (count < 2) + return 0.0; + else + return m2 / (count - 1); + } +}; + +/** Реализация функции stddevSamp. + */ +struct StdDevSampImpl +{ + static constexpr auto name = "stddevSamp"; + + static inline Float64 apply(Float64 m2, UInt64 count) + { + return sqrt(VarSampImpl::apply(m2, count)); + } +}; + +/** Реализация функции varPop. + */ +struct VarPopImpl +{ + static constexpr auto name = "varPop"; + + static inline Float64 apply(Float64 m2, UInt64 count) + { + if (count < 2) + return 0.0; + else + return m2 / count; + } +}; + +/** Реализация функции stddevPop. + */ +struct StdDevPopImpl +{ + static constexpr auto name = "stddevPop"; + + static inline Float64 apply(Float64 m2, UInt64 count) + { + return sqrt(VarPopImpl::apply(m2, count)); + } +}; + +} + +/** Параллельный и инкрементальный алгоритм для вычисления ковариации. + * Источник: "Numerically Stable, Single-Pass, Parallel Statistics Algorithms" + * (J. Bennett et al., Sandia National Laboratories, + * 2009 IEEE International Conference on Cluster Computing) + */ +template +class CovarianceData +{ +public: + CovarianceData() = default; + + void update(const IColumn & column_left, const IColumn & column_right, size_t row_num) + { + T left_received = static_cast &>(column_left).getData()[row_num]; + Float64 val_left = static_cast(left_received); + Float64 left_delta = val_left - left_mean; + + U right_received = static_cast &>(column_right).getData()[row_num]; + Float64 val_right = static_cast(right_received); + Float64 right_delta = val_right - right_mean; + + Float64 old_right_mean = right_mean; + + ++count; + + left_mean += left_delta / count; + right_mean += right_delta / count; + co_moment += (val_left - left_mean) * (val_right - old_right_mean); + + if (compute_marginal_moments) + { + left_m2 += left_delta * (val_left - left_mean); + right_m2 += right_delta * (val_right - right_mean); + } + } + + void mergeWith(const CovarianceData & source) + { + UInt64 total_count = count + source.count; + if (total_count == 0) + return; + + Float64 factor = static_cast(count * source.count) / total_count; + Float64 left_delta = left_mean - source.left_mean; + Float64 right_delta = right_mean - source.right_mean; + + left_mean = source.left_mean + left_delta * (static_cast(count) / total_count); + right_mean = source.right_mean + right_delta * (static_cast(count) / total_count); + co_moment += source.co_moment + left_delta * right_delta * factor; + count = total_count; + + if (compute_marginal_moments) + { + left_m2 += source.left_m2 + left_delta * left_delta * factor; + right_m2 += source.right_m2 + right_delta * right_delta * factor; + } + } + + void serialize(WriteBuffer & buf) const + { + writeVarUInt(count, buf); + writeBinary(left_mean, buf); + writeBinary(right_mean, buf); + writeBinary(co_moment, buf); + + if (compute_marginal_moments) + { + writeBinary(left_m2, buf); + writeBinary(right_m2, buf); + } + } + + void deserialize(ReadBuffer & buf) + { + readVarUInt(count, buf); + readBinary(left_mean, buf); + readBinary(right_mean, buf); + readBinary(co_moment, buf); + + if (compute_marginal_moments) + { + readBinary(left_m2, buf); + readBinary(right_m2, buf); + } + } + + void publish(IColumn & to) const + { + static_cast(to).getData().push_back(Op::apply(co_moment, left_m2, right_m2, count)); + } + +private: + UInt64 count = 0; + Float64 left_mean = 0.0; + Float64 right_mean = 0.0; + Float64 co_moment = 0.0; + Float64 left_m2 = 0.0; + Float64 right_m2 = 0.0; +}; + +template +class AggregateFunctionCovariance final + : public IBinaryAggregateFunction< + CovarianceData, + AggregateFunctionCovariance > +{ +public: + String getName() const override { return Op::name; } + + DataTypePtr getReturnType() const override + { + return new DataTypeFloat64; + } + + void setArgumentsImpl(const DataTypes & arguments) + { + if (!arguments[0]->behavesAsNumber()) + throw Exception("Illegal type " + arguments[0]->getName() + " of first argument to function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + if (!arguments[1]->behavesAsNumber()) + throw Exception("Illegal type " + arguments[1]->getName() + " of second argument to function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + + void addOne(AggregateDataPtr place, const IColumn & column_left, const IColumn & column_right, size_t row_num) const + { + this->data(place).update(column_left, column_right, row_num); + } + + void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override + { + this->data(place).mergeWith(this->data(rhs)); + } + + void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override + { + this->data(place).serialize(buf); + } + + void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override + { + CovarianceData source; + source.deserialize(buf); + + this->data(place).mergeWith(source); + } + + void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override + { + this->data(place).publish(to); + } +}; + +namespace +{ + +/** Реализация функции covarSamp. + */ +struct CovarSampImpl +{ + static constexpr auto name = "covarSamp"; + + static inline Float64 apply(Float64 co_moment, Float64 left_m2, Float64 right_m2, UInt64 count) + { + if (count < 2) + return 0.0; + else + return co_moment / (count - 1); + } +}; + +/** Реализация функции covarPop. + */ +struct CovarPopImpl +{ + static constexpr auto name = "covarPop"; + + static inline Float64 apply(Float64 co_moment, Float64 left_m2, Float64 right_m2, UInt64 count) + { + if (count < 2) + return 0.0; + else + return co_moment / count; + } +}; + +/** Реализация функции corr. + */ +struct CorrImpl +{ + static constexpr auto name = "corr"; + + static inline Float64 apply(Float64 co_moment, Float64 left_m2, Float64 right_m2, UInt64 count) + { + if (count < 2) + return 0.0; + else + return co_moment / sqrt(left_m2 * right_m2); + } +}; + +} + +template +using AggregateFunctionVarSamp = AggregateFunctionVariance; + +template +using AggregateFunctionStdDevSamp = AggregateFunctionVariance; + +template +using AggregateFunctionVarPop = AggregateFunctionVariance; + +template +using AggregateFunctionStdDevPop = AggregateFunctionVariance; + +template +using AggregateFunctionCovarSamp = AggregateFunctionCovariance; + +template +using AggregateFunctionCovarPop = AggregateFunctionCovariance; + +template +using AggregateFunctionCorr = AggregateFunctionCovariance; + +} diff --git a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp index 5db9268e4db..61b81fc37a5 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include @@ -544,6 +545,90 @@ AggregateFunctionPtr AggregateFunctionFactory::get(const String & name, const Da return new AggregateFunctionSequenceMatch; } + else if (name == "varSamp") + { + if (argument_types.size() != 1) + throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + AggregateFunctionPtr res = createWithNumericType(*argument_types[0]); + + if (!res) + throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; + } + else if (name == "varPop") + { + if (argument_types.size() != 1) + throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + AggregateFunctionPtr res = createWithNumericType(*argument_types[0]); + + if (!res) + throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; + } + else if (name == "stddevSamp") + { + if (argument_types.size() != 1) + throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + AggregateFunctionPtr res = createWithNumericType(*argument_types[0]); + + if (!res) + throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; + } + else if (name == "stddevPop") + { + if (argument_types.size() != 1) + throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + AggregateFunctionPtr res = createWithNumericType(*argument_types[0]); + + if (!res) + throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; + } + else if (name == "covarSamp") + { + if (argument_types.size() != 2) + throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + AggregateFunctionPtr res = createWithTwoNumericTypes(*argument_types[0], *argument_types[1]); + if (!res) + throw Exception("Illegal types " + argument_types[0]->getName() + " and " + argument_types[1]->getName() + + " of arguments for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; + } + else if (name == "covarPop") + { + if (argument_types.size() != 2) + throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + AggregateFunctionPtr res = createWithTwoNumericTypes(*argument_types[0], *argument_types[1]); + if (!res) + throw Exception("Illegal types " + argument_types[0]->getName() + " and " + argument_types[1]->getName() + + " of arguments for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; + } + else if (name == "corr") + { + if (argument_types.size() != 2) + throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + AggregateFunctionPtr res = createWithTwoNumericTypes(*argument_types[0], *argument_types[1]); + if (!res) + throw Exception("Illegal types " + argument_types[0]->getName() + " and " + argument_types[1]->getName() + + " of arguments for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; + } else if (recursion_level == 0 && name.size() > strlen("State") && !(strcmp(name.data() + name.size() - strlen("State"), "State"))) { /// Для агрегатных функций вида aggState, где agg - имя другой агрегатной функции. @@ -639,7 +724,14 @@ const AggregateFunctionFactory::FunctionNames & AggregateFunctionFactory::getFun "medianTimingWeighted", "quantileDeterministic", "quantilesDeterministic", - "sequenceMatch" + "sequenceMatch", + "varSamp", + "varPop", + "stddevSamp", + "stddevPop", + "covarSamp", + "covarPop", + "corr" }; return names;