diff --git a/dbms/include/DB/AggregateFunctions/AggregateFunctionsStatistics.h b/dbms/include/DB/AggregateFunctions/AggregateFunctionsStatistics.h new file mode 100644 index 00000000000..a2b632e8deb --- /dev/null +++ b/dbms/include/DB/AggregateFunctions/AggregateFunctionsStatistics.h @@ -0,0 +1,166 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace DB +{ + +template +class AggregateFunctionVarianceData +{ +public: + AggregateFunctionVarianceData() = default; + + void update(const IColumn & column, size_t row_num) + { + T received = static_cast &>(column).getData()[row_num]; + Float64 val = static_cast(received); + Float64 delta = val - mean; + + ++count; + mean += delta / count; + m2 += delta * (val - mean); + } + + void mergeWith(const AggregateFunctionVarianceData & source) + { + UInt64 total_count = count + source.count; + Float64 factor = static_cast(count * source.count) / total_count; + Float64 delta = mean - source.mean; + + count = total_count; + mean += delta * (source.count / count); + m2 += source.m2 + delta * delta * factor; + } + + void serialize(WriteBuffer & buf) const + { + writeVarUInt(count, buf); + writeBinary(mean, buf); + writeBinary(m2, buf); + } + + void deserialize(ReadBuffer & buf) + { + readVarUInt(count, buf); + readBinary(mean, buf); + readBinary(m2, buf); + } + + void publish(IColumn & to) const + { + static_cast(to).getData().push_back(Op::apply(m2, count)); + } + +private: + UInt64 count = 0; + Float64 mean = 0.0; + Float64 m2 = 0.0; +}; + +template +class AggregateFunctionVariance final : public IUnaryAggregateFunction, AggregateFunctionVariance > +{ +public: + String getName() const override { return Op::name; } + + DataTypePtr getReturnType() const override + { + return new DataTypeFloat64; + } + + void setArgument(const DataTypePtr & argument) override + { + if (!argument->behavesAsNumber()) + throw Exception("Illegal type " + argument->getName() + " of argument for aggregate function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + + void addOne(AggregateDataPtr place, const IColumn & column, size_t row_num) const + { + this->data(place).update(column, row_num); + } + + void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override + { + this->data(place).mergeWith(this->data(rhs)); + } + + void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override + { + this->data(place).serialize(buf); + } + + void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override + { + AggregateFunctionVarianceData source; + source.deserialize(buf); + + this->data(place).mergeWith(source); + } + + void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override + { + this->data(place).publish(to); + } +}; + +struct VarSampImpl +{ + static constexpr auto name = "varSamp"; + + static inline Float64 apply(Float64 m2, UInt64 count) + { + return m2 / (count - 1); + } +}; + +struct StdDevSampImpl +{ + static constexpr auto name = "stddevSamp"; + + static inline Float64 apply(Float64 m2, UInt64 count) + { + return sqrt(VarSampImpl::apply(m2, count)); + } +}; + +struct VarPopImpl +{ + static constexpr auto name = "varPop"; + + static inline Float64 apply(Float64 m2, UInt64 count) + { + return m2 / count; + } +}; + +struct StdDevPopImpl +{ + static constexpr auto name = "stddevPop"; + + static inline Float64 apply(Float64 m2, UInt64 count) + { + return sqrt(VarPopImpl::apply(m2, count)); + } +}; + +template +using AggregateFunctionVarSamp = AggregateFunctionVariance; + +template +using AggregateFunctionStdDevSamp = AggregateFunctionVariance; + +template +using AggregateFunctionVarPop = AggregateFunctionVariance; + +template +using AggregateFunctionStdDevPop = AggregateFunctionVariance; + +} diff --git a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp index 5db9268e4db..e4996175a82 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include @@ -544,6 +545,54 @@ AggregateFunctionPtr AggregateFunctionFactory::get(const String & name, const Da return new AggregateFunctionSequenceMatch; } + else if (name == "varSamp") + { + if (argument_types.size() != 1) + throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + AggregateFunctionPtr res = createWithNumericType(*argument_types[0]); + + if (!res) + throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; + } + else if (name == "varPop") + { + if (argument_types.size() != 1) + throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + AggregateFunctionPtr res = createWithNumericType(*argument_types[0]); + + if (!res) + throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; + } + else if (name == "stddevSamp") + { + if (argument_types.size() != 1) + throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + AggregateFunctionPtr res = createWithNumericType(*argument_types[0]); + + if (!res) + throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; + } + else if (name == "stddevPop") + { + if (argument_types.size() != 1) + throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + AggregateFunctionPtr res = createWithNumericType(*argument_types[0]); + + if (!res) + throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; + } else if (recursion_level == 0 && name.size() > strlen("State") && !(strcmp(name.data() + name.size() - strlen("State"), "State"))) { /// Для агрегатных функций вида aggState, где agg - имя другой агрегатной функции. @@ -639,7 +688,11 @@ const AggregateFunctionFactory::FunctionNames & AggregateFunctionFactory::getFun "medianTimingWeighted", "quantileDeterministic", "quantilesDeterministic", - "sequenceMatch" + "sequenceMatch", + "varSamp", + "varPop", + "stddevSamp", + "stddevPop" }; return names;