From 88fec929213229f7987758cd87cdcdf531e66e0a Mon Sep 17 00:00:00 2001 From: nikitamikhaylov Date: Thu, 5 Nov 2020 22:08:49 +0300 Subject: [PATCH] initial (cherry picked from commit 158665e84f95f5677bacab84ab362bee37b8b4af) --- .../AggregateFunctionStudentTTest.cpp | 52 ++++ .../AggregateFunctionStudentTTest.h | 253 +++++++++++++++++ .../AggregateFunctionWelchTTest.cpp | 49 ++++ .../AggregateFunctionWelchTTest.h | 264 ++++++++++++++++++ .../registerAggregateFunctions.cpp | 6 +- 5 files changed, 622 insertions(+), 2 deletions(-) create mode 100644 src/AggregateFunctions/AggregateFunctionStudentTTest.cpp create mode 100644 src/AggregateFunctions/AggregateFunctionStudentTTest.h create mode 100644 src/AggregateFunctions/AggregateFunctionWelchTTest.cpp create mode 100644 src/AggregateFunctions/AggregateFunctionWelchTTest.h diff --git a/src/AggregateFunctions/AggregateFunctionStudentTTest.cpp b/src/AggregateFunctions/AggregateFunctionStudentTTest.cpp new file mode 100644 index 00000000000..a2c36e43488 --- /dev/null +++ b/src/AggregateFunctions/AggregateFunctionStudentTTest.cpp @@ -0,0 +1,52 @@ +#include +#include +#include +#include "registerAggregateFunctions.h" + +#include +#include + + +// the return type is boolean (we use UInt8 as we do not have boolean in clickhouse) + +namespace ErrorCodes +{ +extern const int NOT_IMPLEMENTED; +} + +namespace DB +{ + +namespace +{ + +AggregateFunctionPtr createAggregateFunctionStudentTTest(const std::string & name, const DataTypes & argument_types, const Array & parameters) +{ + assertBinary(name, argument_types); + assertNoParameters(name, parameters); + + AggregateFunctionPtr res; + + if (isDecimal(argument_types[0]) || isDecimal(argument_types[1])) + { + throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED); + } + else + { + res.reset(createWithTwoNumericTypes(*argument_types[0], *argument_types[1], argument_types)); + } + + if (!res) + { + throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED); + } + + return res; +} +} + +void registerAggregateFunctionStudentTTest(AggregateFunctionFactory & factory) +{ + factory.registerFunction("studentTTest", createAggregateFunctionStudentTTest, AggregateFunctionFactory::CaseInsensitive); +} +} diff --git a/src/AggregateFunctions/AggregateFunctionStudentTTest.h b/src/AggregateFunctions/AggregateFunctionStudentTTest.h new file mode 100644 index 00000000000..d260a6be980 --- /dev/null +++ b/src/AggregateFunctions/AggregateFunctionStudentTTest.h @@ -0,0 +1,253 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace ErrorCodes +{ +extern const int BAD_ARGUMENTS; +} + +namespace DB +{ + +template +struct AggregateFunctionStudentTTestData final +{ + size_t size_x = 0; + size_t size_y = 0; + X sum_x = static_cast(0); + Y sum_y = static_cast(0); + X square_sum_x = static_cast(0); + Y square_sum_y = static_cast(0); + Float64 mean_x = static_cast(0); + Float64 mean_y = static_cast(0); + + void add(X x, Y y) + { + sum_x += x; + sum_y += y; + size_x++; + size_y++; + mean_x = static_cast(sum_x) / size_x; + mean_y = static_cast(sum_y) / size_y; + square_sum_x += x * x; + square_sum_y += y * y; + } + + void merge(const AggregateFunctionStudentTTestData &other) + { + sum_x += other.sum_x; + sum_y += other.sum_y; + size_x += other.size_x; + size_y += other.size_y; + mean_x = static_cast(sum_x) / size_x; + mean_y = static_cast(sum_y) / size_y; + square_sum_x += other.square_sum_x; + square_sum_y += other.square_sum_y; + } + + void serialize(WriteBuffer &buf) const + { + writeBinary(mean_x, buf); + writeBinary(mean_y, buf); + writeBinary(sum_x, buf); + writeBinary(sum_y, buf); + writeBinary(square_sum_x, buf); + writeBinary(square_sum_y, buf); + writeBinary(size_x, buf); + writeBinary(size_y, buf); + } + + void deserialize(ReadBuffer &buf) + { + readBinary(mean_x, buf); + readBinary(mean_y, buf); + readBinary(sum_x, buf); + readBinary(sum_y, buf); + readBinary(square_sum_x, buf); + readBinary(square_sum_y, buf); + readBinary(size_x, buf); + readBinary(size_y, buf); + } + + size_t getSizeY() const + { + return size_y; + } + + size_t getSizeX() const + { + return size_x; + } + + Float64 getSSquared() const + { + /// The original formulae looks like + /// \frac{\sum_{i = 1}^{n_x}{(x_i - \bar{x}) ^ 2} + \sum_{i = 1}^{n_y}{(y_i - \bar{y}) ^ 2}}{n_x + n_y - 2} + /// But we made some mathematical transformations not to store original sequences. + /// Also we dropped sqrt, because later it will be squared later. + const Float64 all_x = square_sum_x + size_x * std::pow(mean_x, 2) - 2 * mean_x * sum_x; + const Float64 all_y = square_sum_y + size_y * std::pow(mean_y, 2) - 2 * mean_y * sum_y; + return static_cast(all_x + all_y) / (size_x + size_y - 2); + } + + + Float64 getTStatisticSquared() const + { + return std::pow(mean_x - mean_y, 2) / getStandartErrorSquared(); + } + + Float64 getTStatistic() const + { + return (mean_x - mean_y) / std::sqrt(getStandartErrorSquared()); + } + + Float64 getStandartErrorSquared() const + { + if (size_x == 0 || size_y == 0) + throw Exception("Division by zero encountered in Aggregate function StudentTTest", ErrorCodes::BAD_ARGUMENTS); + + return getSSquared() * (1.0 / static_cast(size_x) + 1.0 / static_cast(size_y)); + } + + Float64 getDegreesOfFreedom() const + { + return static_cast(size_x + size_y - 2); + } + + static Float64 integrateSimpson(Float64 a, Float64 b, std::function func) + { + const size_t iterations = std::max(1e6, 1e4 * std::abs(std::round(b))); + const long double h = (b - a) / iterations; + Float64 sum_odds = 0.0; + for (size_t i = 1; i < iterations; i += 2) + sum_odds += func(a + i * h); + Float64 sum_evens = 0.0; + for (size_t i = 2; i < iterations; i += 2) + sum_evens += func(a + i * h); + return (func(a) + func(b) + 2 * sum_evens + 4 * sum_odds) * h / 3; + } + + Float64 getPValue() const + { + const Float64 v = getDegreesOfFreedom(); + const Float64 t = getTStatisticSquared(); + auto f = [&v] (double x) { return std::pow(x, v/2 - 1) / std::sqrt(1 - x); }; + Float64 numenator = integrateSimpson(0, v / (t + v), f); + Float64 denominator = std::exp(std::lgammal(v/2) + std::lgammal(0.5) - std::lgammal(v/2 + 0.5)); + return numenator / denominator; + } + + std::pair getResult() const + { + return std::make_pair(getTStatistic(), getPValue()); + } +}; + +/// Returns tuple of (t-statistic, p-value) +/// https://cpb-us-w2.wpmucdn.com/voices.uchicago.edu/dist/9/1193/files/2016/01/05b-TandP.pdf +template +class AggregateFunctionStudentTTest : + public IAggregateFunctionDataHelper,AggregateFunctionStudentTTest> +{ + +public: + AggregateFunctionStudentTTest(const DataTypes & arguments) + : IAggregateFunctionDataHelper, AggregateFunctionStudentTTest> ({arguments}, {}) + {} + + String getName() const override + { + return "studentTTest"; + } + + DataTypePtr getReturnType() const override + { + DataTypes types + { + std::make_shared>(), + std::make_shared>(), + }; + + Strings names + { + "t-statistic", + "p-value" + }; + + return std::make_shared( + std::move(types), + std::move(names) + ); + } + + void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override + { + auto col_x = assert_cast *>(columns[0]); + auto col_y = assert_cast *>(columns[1]); + + X x = col_x->getData()[row_num]; + Y y = col_y->getData()[row_num]; + + this->data(place).add(x, y); + } + + void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override + { + this->data(place).merge(this->data(rhs)); + } + + void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override + { + this->data(place).serialize(buf); + } + + void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override + { + this->data(place).deserialize(buf); + } + + void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * /*arena*/) const override + { + size_t size_x = this->data(place).getSizeX(); + size_t size_y = this->data(place).getSizeY(); + + if (size_x < 2 || size_y < 2) + { + throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS); + } + + Float64 t_statistic = 0.0; + Float64 p_value = 0.0; + std::tie(t_statistic, p_value) = this->data(place).getResult(); + + /// Because p-value is a probability. + p_value = std::min(1.0, std::max(0.0, p_value)); + + auto & column_tuple = assert_cast(to); + auto & column_stat = assert_cast &>(column_tuple.getColumn(0)); + auto & column_value = assert_cast &>(column_tuple.getColumn(1)); + + column_stat.getData().push_back(t_statistic); + column_value.getData().push_back(p_value); + } + +}; + +}; diff --git a/src/AggregateFunctions/AggregateFunctionWelchTTest.cpp b/src/AggregateFunctions/AggregateFunctionWelchTTest.cpp new file mode 100644 index 00000000000..483c99dde9b --- /dev/null +++ b/src/AggregateFunctions/AggregateFunctionWelchTTest.cpp @@ -0,0 +1,49 @@ +#include +#include +#include +#include "registerAggregateFunctions.h" + +#include +#include + +namespace ErrorCodes +{ +extern const int NOT_IMPLEMENTED; +} + +namespace DB +{ + +namespace +{ + +AggregateFunctionPtr createAggregateFunctionWelchTTest(const std::string & name, const DataTypes & argument_types, const Array & parameters) +{ + assertBinary(name, argument_types); + assertNoParameters(name, parameters); + + AggregateFunctionPtr res; + + if (isDecimal(argument_types[0]) || isDecimal(argument_types[1])) + { + throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED); + } + else + { + res.reset(createWithTwoNumericTypes(*argument_types[0], *argument_types[1], argument_types)); + } + + if (!res) + { + throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED); + } + + return res; +} +} + +void registerAggregateFunctionWelchTTest(AggregateFunctionFactory & factory) +{ + factory.registerFunction("welchTTest", createAggregateFunctionWelchTTest, AggregateFunctionFactory::CaseInsensitive); +} +} diff --git a/src/AggregateFunctions/AggregateFunctionWelchTTest.h b/src/AggregateFunctions/AggregateFunctionWelchTTest.h new file mode 100644 index 00000000000..175e0171606 --- /dev/null +++ b/src/AggregateFunctions/AggregateFunctionWelchTTest.h @@ -0,0 +1,264 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace ErrorCodes +{ +extern const int BAD_ARGUMENTS; +} + +namespace DB +{ + +template +struct AggregateFunctionWelchTTestData final +{ + size_t size_x = 0; + size_t size_y = 0; + X sum_x = static_cast(0); + Y sum_y = static_cast(0); + X square_sum_x = static_cast(0); + Y square_sum_y = static_cast(0); + Float64 mean_x = static_cast(0); + Float64 mean_y = static_cast(0); + + void add(X x, Y y) + { + sum_x += x; + sum_y += y; + size_x++; + size_y++; + mean_x = static_cast(sum_x) / size_x; + mean_y = static_cast(sum_y) / size_y; + square_sum_x += x * x; + square_sum_y += y * y; + } + + void merge(const AggregateFunctionWelchTTestData &other) + { + sum_x += other.sum_x; + sum_y += other.sum_y; + size_x += other.size_x; + size_y += other.size_y; + mean_x = static_cast(sum_x) / size_x; + mean_y = static_cast(sum_y) / size_y; + square_sum_x += other.square_sum_x; + square_sum_y += other.square_sum_y; + } + + void serialize(WriteBuffer &buf) const + { + writeBinary(mean_x, buf); + writeBinary(mean_y, buf); + writeBinary(sum_x, buf); + writeBinary(sum_y, buf); + writeBinary(square_sum_x, buf); + writeBinary(square_sum_y, buf); + writeBinary(size_x, buf); + writeBinary(size_y, buf); + } + + void deserialize(ReadBuffer &buf) + { + readBinary(mean_x, buf); + readBinary(mean_y, buf); + readBinary(sum_x, buf); + readBinary(sum_y, buf); + readBinary(square_sum_x, buf); + readBinary(square_sum_y, buf); + readBinary(size_x, buf); + readBinary(size_y, buf); + } + + size_t getSizeY() const + { + return size_y; + } + + size_t getSizeX() const + { + return size_x; + } + + Float64 getSxSquared() const + { + /// The original formulae looks like \frac{1}{size_x - 1} \sum_{i = 1}^{size_x}{(x_i - \bar{x}) ^ 2} + /// But we made some mathematical transformations not to store original sequences. + /// Also we dropped sqrt, because later it will be squared later. + return static_cast(square_sum_x + size_x * std::pow(mean_x, 2) - 2 * mean_x * sum_x) / (size_x - 1); + } + + Float64 getSySquared() const + { + /// The original formulae looks like \frac{1}{size_y - 1} \sum_{i = 1}^{size_y}{(y_i - \bar{y}) ^ 2} + /// But we made some mathematical transformations not to store original sequences. + /// Also we dropped sqrt, because later it will be squared later. + return static_cast(square_sum_y + size_y * std::pow(mean_y, 2) - 2 * mean_y * sum_y) / (size_y - 1); + } + + Float64 getTStatisticSquared() const + { + if (size_x == 0 || size_y == 0) + { + throw Exception("Division by zero encountered in Aggregate function WelchTTest", ErrorCodes::BAD_ARGUMENTS); + } + + return std::pow(mean_x - mean_y, 2) / (getSxSquared() / size_x + getSySquared() / size_y); + } + + Float64 getTStatistic() const + { + if (size_x == 0 || size_y == 0) + { + throw Exception("Division by zero encountered in Aggregate function WelchTTest", ErrorCodes::BAD_ARGUMENTS); + } + + return (mean_x - mean_y) / std::sqrt(getSxSquared() / size_x + getSySquared() / size_y); + } + + Float64 getDegreesOfFreedom() const + { + auto sx = getSxSquared(); + auto sy = getSySquared(); + Float64 numerator = std::pow(sx / size_x + sy / size_y, 2); + Float64 denominator_first = std::pow(sx, 2) / (std::pow(size_x, 2) * (size_x - 1)); + Float64 denominator_second = std::pow(sy, 2) / (std::pow(size_y, 2) * (size_y - 1)); + return numerator / (denominator_first + denominator_second); + } + + static Float64 integrateSimpson(Float64 a, Float64 b, std::function func) + { + size_t iterations = std::max(1e6, 1e4 * std::abs(std::round(b))); + double h = (b - a) / iterations; + Float64 sum_odds = 0.0; + for (size_t i = 1; i < iterations; i += 2) + sum_odds += func(a + i * h); + Float64 sum_evens = 0.0; + for (size_t i = 2; i < iterations; i += 2) + sum_evens += func(a + i * h); + return (func(a) + func(b) + 2 * sum_evens + 4 * sum_odds) * h / 3; + } + + Float64 getPValue() const + { + const Float64 v = getDegreesOfFreedom(); + const Float64 t = getTStatisticSquared(); + auto f = [&v] (double x) { return std::pow(x, v/2 - 1) / std::sqrt(1 - x); }; + Float64 numenator = integrateSimpson(0, v / (t + v), f); + Float64 denominator = std::exp(std::lgammal(v/2) + std::lgammal(0.5) - std::lgammal(v/2 + 0.5)); + return numenator / denominator; + } + + std::pair getResult() const + { + return std::make_pair(getTStatistic(), getPValue()); + } +}; + +/// Returns tuple of (t-statistic, p-value) +/// https://cpb-us-w2.wpmucdn.com/voices.uchicago.edu/dist/9/1193/files/2016/01/05b-TandP.pdf +template +class AggregateFunctionWelchTTest : + public IAggregateFunctionDataHelper,AggregateFunctionWelchTTest> +{ + +public: + AggregateFunctionWelchTTest(const DataTypes & arguments) + : IAggregateFunctionDataHelper, AggregateFunctionWelchTTest> ({arguments}, {}) + {} + + String getName() const override + { + return "welchTTest"; + } + + DataTypePtr getReturnType() const override + { + DataTypes types + { + std::make_shared>(), + std::make_shared>(), + }; + + Strings names + { + "t-statistic", + "p-value" + }; + + return std::make_shared( + std::move(types), + std::move(names) + ); + } + + void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override + { + auto col_x = assert_cast *>(columns[0]); + auto col_y = assert_cast *>(columns[1]); + + X x = col_x->getData()[row_num]; + Y y = col_y->getData()[row_num]; + + this->data(place).add(x, y); + } + + void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override + { + this->data(place).merge(this->data(rhs)); + } + + void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override + { + this->data(place).serialize(buf); + } + + void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override + { + this->data(place).deserialize(buf); + } + + void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * /*arena*/) const override + { + size_t size_x = this->data(place).getSizeX(); + size_t size_y = this->data(place).getSizeY(); + + if (size_x < 2 || size_y < 2) + { + throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS); + } + + Float64 t_statistic = 0.0; + Float64 p_value = 0.0; + std::tie(t_statistic, p_value) = this->data(place).getResult(); + + /// Because p-value is a probability. + p_value = std::min(1.0, std::max(0.0, p_value)); + + auto & column_tuple = assert_cast(to); + auto & column_stat = assert_cast &>(column_tuple.getColumn(0)); + auto & column_value = assert_cast &>(column_tuple.getColumn(1)); + + column_stat.getData().push_back(t_statistic); + column_value.getData().push_back(p_value); + } + +}; + +}; diff --git a/src/AggregateFunctions/registerAggregateFunctions.cpp b/src/AggregateFunctions/registerAggregateFunctions.cpp index 02c3867f6f3..41e8d32a3e3 100644 --- a/src/AggregateFunctions/registerAggregateFunctions.cpp +++ b/src/AggregateFunctions/registerAggregateFunctions.cpp @@ -39,10 +39,10 @@ void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &) void registerAggregateFunctionMoving(AggregateFunctionFactory &); void registerAggregateFunctionCategoricalIV(AggregateFunctionFactory &); void registerAggregateFunctionAggThrow(AggregateFunctionFactory &); -void registerAggregateFunctionWelchTTest(AggregateFunctionFactory &); -void registerAggregateFunctionStudentTTest(AggregateFunctionFactory &); void registerAggregateFunctionRankCorrelation(AggregateFunctionFactory &); void registerAggregateFunctionMannWhitney(AggregateFunctionFactory &); +void registerAggregateFunctionWelchTTest(AggregateFunctionFactory &); +void registerAggregateFunctionStudentTTest(AggregateFunctionFactory &); class AggregateFunctionCombinatorFactory; void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &); @@ -96,6 +96,8 @@ void registerAggregateFunctions() registerAggregateFunctionAggThrow(factory); registerAggregateFunctionRankCorrelation(factory); registerAggregateFunctionMannWhitney(factory); + registerAggregateFunctionWelchTTest(factory); + registerAggregateFunctionStudentTTest(factory); } {