From 72bcbc76b1e116aa9e0af9a78618ca827c19f0fe Mon Sep 17 00:00:00 2001 From: hcz Date: Wed, 13 Mar 2019 15:22:57 +0800 Subject: [PATCH] Add aggregate function leastSqr --- .../AggregateFunctionLeastSqr.cpp | 112 +++++++++++ .../AggregateFunctionLeastSqr.h | 176 ++++++++++++++++++ .../registerAggregateFunctions.cpp | 2 + .../0_stateless/00917_least_sqr.reference | 7 + .../queries/0_stateless/00917_least_sqr.sql | 7 + 5 files changed, 304 insertions(+) create mode 100644 dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.cpp create mode 100644 dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.h create mode 100644 dbms/tests/queries/0_stateless/00917_least_sqr.reference create mode 100644 dbms/tests/queries/0_stateless/00917_least_sqr.sql diff --git a/dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.cpp b/dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.cpp new file mode 100644 index 00000000000..1cb213b6360 --- /dev/null +++ b/dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.cpp @@ -0,0 +1,112 @@ +#include + +#include +#include + + +namespace DB +{ + +namespace +{ + +AggregateFunctionPtr createAggregateFunctionLeastSqr( + const String & name, + const DataTypes & arguments, + const Array & params +) +{ + assertNoParameters(name, params); + assertBinary(name, arguments); + + const IDataType * x_arg = arguments.front().get(); + + WhichDataType which_x { + x_arg + }; + + if ( + !which_x.isNativeUInt() + && !which_x.isNativeInt() + && !which_x.isFloat() + ) + throw Exception { + "Illegal type " + x_arg->getName() + + " of first argument of aggregate function " + + name + ", must be Native Int, Native UInt or Float", + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT + }; + + const IDataType * y_arg = arguments.back().get(); + + WhichDataType which_y { + y_arg + }; + + if ( + !which_y.isNativeUInt() + && !which_y.isNativeInt() + && !which_y.isFloat() + ) + throw Exception { + "Illegal type " + y_arg->getName() + + " of second argument of aggregate function " + + name + ", must be Native Int, Native UInt or Float", + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT + }; + + if (which_x.isNativeUInt() && which_y.isNativeUInt()) + return std::make_shared>( + arguments, + params + ); + else if (which_x.isNativeUInt() && which_y.isNativeInt()) + return std::make_shared>( + arguments, + params + ); + else if (which_x.isNativeUInt() && which_y.isFloat()) + return std::make_shared>( + arguments, + params + ); + else if (which_x.isNativeInt() && which_y.isNativeUInt()) + return std::make_shared>( + arguments, + params + ); + else if (which_x.isNativeInt() && which_y.isNativeInt()) + return std::make_shared>( + arguments, + params + ); + else if (which_x.isNativeInt() && which_y.isFloat()) + return std::make_shared>( + arguments, + params + ); + else if (which_x.isFloat() && which_y.isNativeUInt()) + return std::make_shared>( + arguments, + params + ); + else if (which_x.isFloat() && which_y.isNativeInt()) + return std::make_shared>( + arguments, + params + ); + else // if (which_x.isFloat() && which_y.isFloat()) + return std::make_shared>( + arguments, + params + ); +} + +} + +void registerAggregateFunctionLeastSqr(AggregateFunctionFactory & factory) +{ + factory.registerFunction("leastSqr", createAggregateFunctionLeastSqr); +} + +} diff --git a/dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.h b/dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.h new file mode 100644 index 00000000000..c527e34588d --- /dev/null +++ b/dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.h @@ -0,0 +1,176 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; +} + +template +struct AggregateFunctionLeastSqrData +{ + size_t count = 0; + Ret sum_x = 0; + Ret sum_y = 0; + Ret sum_xx = 0; + Ret sum_xy = 0; + + void add(X x, Y y) + { + count += 1; + sum_x += x; + sum_y += y; + sum_xx += x * x; + sum_xy += x * y; + } + + void merge(const AggregateFunctionLeastSqrData & other) + { + count += other.count; + sum_x += other.sum_x; + sum_y += other.sum_y; + sum_xx += other.sum_xx; + sum_xy += other.sum_xy; + } + + void serialize(WriteBuffer & buf) const + { + writeBinary(count, buf); + writeBinary(sum_x, buf); + writeBinary(sum_y, buf); + writeBinary(sum_xx, buf); + writeBinary(sum_xy, buf); + } + + void deserialize(ReadBuffer & buf) + { + readBinary(count, buf); + readBinary(sum_x, buf); + readBinary(sum_y, buf); + readBinary(sum_xx, buf); + readBinary(sum_xy, buf); + } + + Ret getK() const + { + return (sum_xy * count - sum_x * sum_y) + / (sum_xx * count - sum_x * sum_x); + } + + Ret getB(Ret k) const + { + return (sum_y - k * sum_x) / count; + } +}; + +template +class AggregateFunctionLeastSqr final : public IAggregateFunctionDataHelper< + AggregateFunctionLeastSqrData, + AggregateFunctionLeastSqr +> +{ +public: + AggregateFunctionLeastSqr( + const DataTypes & arguments, + const Array & params + ): + IAggregateFunctionDataHelper< + AggregateFunctionLeastSqrData, + AggregateFunctionLeastSqr + > {arguments, params} + { + // notice: arguments has been checked before + } + + void add( + AggregateDataPtr place, + const IColumn ** columns, + size_t row_num, + Arena * + ) const override + { + X x = (*columns[0])[row_num].template get(); + Y y = (*columns[1])[row_num].template get(); + + this->data(place).add(x, y); + } + + void merge( + AggregateDataPtr place, + ConstAggregateDataPtr rhs, Arena * + ) const override + { + this->data(place).merge(this->data(rhs)); + } + + void serialize( + ConstAggregateDataPtr place, + WriteBuffer & buf + ) const override + { + this->data(place).serialize(buf); + } + + void deserialize( + AggregateDataPtr place, + ReadBuffer & buf, Arena * + ) const override + { + this->data(place).deserialize(buf); + } + + DataTypePtr getReturnType() const override + { + DataTypes types { + std::make_shared( + std::make_shared() + ), + std::make_shared( + std::make_shared() + ), + }; + + Strings names { + "k", + "b", + }; + + return std::make_shared( + std::move(types), + std::move(names) + ); + } + + void insertResultInto( + ConstAggregateDataPtr place, + IColumn & to + ) const override + { + Ret k = this->data(place).getK(); + Ret b = this->data(place).getB(k); + + Tuple result; + result.toUnderType().reserve(2); + + result.toUnderType().emplace_back(k); + result.toUnderType().emplace_back(b); + + to.insert(std::move(result)); + } + + String getName() const override { return "leastSqr"; } + const char * getHeaderFilePath() const override { return __FILE__; } +}; + +} diff --git a/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp b/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp index 0ef138119f9..2d5a0eafc07 100644 --- a/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp +++ b/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp @@ -29,6 +29,7 @@ void registerAggregateFunctionsBitwise(AggregateFunctionFactory &); void registerAggregateFunctionsBitmap(AggregateFunctionFactory &); void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &); void registerAggregateFunctionEntropy(AggregateFunctionFactory &); +void registerAggregateFunctionLeastSqr(AggregateFunctionFactory &); void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &); void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &); @@ -69,6 +70,7 @@ void registerAggregateFunctions() registerAggregateFunctionHistogram(factory); registerAggregateFunctionRetention(factory); registerAggregateFunctionEntropy(factory); + registerAggregateFunctionLeastSqr(factory); } { diff --git a/dbms/tests/queries/0_stateless/00917_least_sqr.reference b/dbms/tests/queries/0_stateless/00917_least_sqr.reference new file mode 100644 index 00000000000..89d168b03bb --- /dev/null +++ b/dbms/tests/queries/0_stateless/00917_least_sqr.reference @@ -0,0 +1,7 @@ +(10,90) +(10.3,89.5) +(10,-90) +(1,1) +(nan,nan) +(0,3) +(nan,nan) diff --git a/dbms/tests/queries/0_stateless/00917_least_sqr.sql b/dbms/tests/queries/0_stateless/00917_least_sqr.sql new file mode 100644 index 00000000000..80f28a6abd9 --- /dev/null +++ b/dbms/tests/queries/0_stateless/00917_least_sqr.sql @@ -0,0 +1,7 @@ +select arrayReduce('leastSqr', [1, 2, 3, 4], [100, 110, 120, 130]); +select arrayReduce('leastSqr', [1, 2, 3, 4], [100, 110, 120, 131]); +select arrayReduce('leastSqr', [-1, -2, -3, -4], [-100, -110, -120, -130]); +select arrayReduce('leastSqr', [5, 5.1], [6, 6.1]); +select arrayReduce('leastSqr', [0], [0]); +select arrayReduce('leastSqr', [3, 4], [3, 3]); +select arrayReduce('leastSqr', [3, 3], [3, 4]);