diff --git a/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h b/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h index e23c220dd3d..5caa30dbdab 100644 --- a/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h +++ b/src/AggregateFunctions/AggregateFunctionStatisticsSimple.h @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -30,310 +31,6 @@ namespace DB { -namespace ErrorCodes -{ - extern const int DECIMAL_OVERFLOW; -} - - -/** - Calculating univariate central moments - Levels: - level 2 (pop & samp): var, stddev - level 3: skewness - level 4: kurtosis - References: - https://en.wikipedia.org/wiki/Moment_(mathematics) - https://en.wikipedia.org/wiki/Skewness - https://en.wikipedia.org/wiki/Kurtosis -*/ -template -struct VarMoments -{ - T m[_level + 1]{}; - - void add(T x) - { - ++m[0]; - m[1] += x; - m[2] += x * x; - if constexpr (_level >= 3) m[3] += x * x * x; - if constexpr (_level >= 4) m[4] += x * x * x * x; - } - - void merge(const VarMoments & rhs) - { - m[0] += rhs.m[0]; - m[1] += rhs.m[1]; - m[2] += rhs.m[2]; - if constexpr (_level >= 3) m[3] += rhs.m[3]; - if constexpr (_level >= 4) m[4] += rhs.m[4]; - } - - void write(WriteBuffer & buf) const - { - writePODBinary(*this, buf); - } - - void read(ReadBuffer & buf) - { - readPODBinary(*this, buf); - } - - T getPopulation() const - { - if (m[0] == 0) - return std::numeric_limits::quiet_NaN(); - - /// Due to numerical errors, the result can be slightly less than zero, - /// but it should be impossible. Trim to zero. - - return std::max(T{}, (m[2] - m[1] * m[1] / m[0]) / m[0]); - } - - T getSample() const - { - if (m[0] <= 1) - return std::numeric_limits::quiet_NaN(); - return std::max(T{}, (m[2] - m[1] * m[1] / m[0]) / (m[0] - 1)); - } - - T getMoment3() const - { - if (m[0] == 0) - return std::numeric_limits::quiet_NaN(); - // to avoid accuracy problem - if (m[0] == 1) - return 0; - return (m[3] - - (3 * m[2] - - 2 * m[1] * m[1] / m[0] - ) * m[1] / m[0] - ) / m[0]; - } - - T getMoment4() const - { - if (m[0] == 0) - return std::numeric_limits::quiet_NaN(); - // to avoid accuracy problem - if (m[0] == 1) - return 0; - return (m[4] - - (4 * m[3] - - (6 * m[2] - - 3 * m[1] * m[1] / m[0] - ) * m[1] / m[0] - ) * m[1] / m[0] - ) / m[0]; - } -}; - -template -class VarMomentsDecimal -{ -public: - using NativeType = typename T::NativeType; - - void add(NativeType x) - { - ++m0; - getM(1) += x; - - NativeType tmp; - bool overflow = common::mulOverflow(x, x, tmp) || common::addOverflow(getM(2), tmp, getM(2)); - if constexpr (_level >= 3) - overflow = overflow || common::mulOverflow(tmp, x, tmp) || common::addOverflow(getM(3), tmp, getM(3)); - if constexpr (_level >= 4) - overflow = overflow || common::mulOverflow(tmp, x, tmp) || common::addOverflow(getM(4), tmp, getM(4)); - - if (overflow) - throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); - } - - void merge(const VarMomentsDecimal & rhs) - { - m0 += rhs.m0; - getM(1) += rhs.getM(1); - - bool overflow = common::addOverflow(getM(2), rhs.getM(2), getM(2)); - if constexpr (_level >= 3) - overflow = overflow || common::addOverflow(getM(3), rhs.getM(3), getM(3)); - if constexpr (_level >= 4) - overflow = overflow || common::addOverflow(getM(4), rhs.getM(4), getM(4)); - - if (overflow) - throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); - } - - void write(WriteBuffer & buf) const { writePODBinary(*this, buf); } - void read(ReadBuffer & buf) { readPODBinary(*this, buf); } - - Float64 getPopulation(UInt32 scale) const - { - if (m0 == 0) - return std::numeric_limits::infinity(); - - NativeType tmp; - if (common::mulOverflow(getM(1), getM(1), tmp) || - common::subOverflow(getM(2), NativeType(tmp / m0), tmp)) - throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); - return std::max(Float64{}, DecimalUtils::convertTo(T(tmp / m0), scale)); - } - - Float64 getSample(UInt32 scale) const - { - if (m0 == 0) - return std::numeric_limits::quiet_NaN(); - if (m0 == 1) - return std::numeric_limits::infinity(); - - NativeType tmp; - if (common::mulOverflow(getM(1), getM(1), tmp) || - common::subOverflow(getM(2), NativeType(tmp / m0), tmp)) - throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); - return std::max(Float64{}, DecimalUtils::convertTo(T(tmp / (m0 - 1)), scale)); - } - - Float64 getMoment3(UInt32 scale) const - { - if (m0 == 0) - return std::numeric_limits::infinity(); - - NativeType tmp; - if (common::mulOverflow(2 * getM(1), getM(1), tmp) || - common::subOverflow(3 * getM(2), NativeType(tmp / m0), tmp) || - common::mulOverflow(tmp, getM(1), tmp) || - common::subOverflow(getM(3), NativeType(tmp / m0), tmp)) - throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); - return DecimalUtils::convertTo(T(tmp / m0), scale); - } - - Float64 getMoment4(UInt32 scale) const - { - if (m0 == 0) - return std::numeric_limits::infinity(); - - NativeType tmp; - if (common::mulOverflow(3 * getM(1), getM(1), tmp) || - common::subOverflow(6 * getM(2), NativeType(tmp / m0), tmp) || - common::mulOverflow(tmp, getM(1), tmp) || - common::subOverflow(4 * getM(3), NativeType(tmp / m0), tmp) || - common::mulOverflow(tmp, getM(1), tmp) || - common::subOverflow(getM(4), NativeType(tmp / m0), tmp)) - throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW); - return DecimalUtils::convertTo(T(tmp / m0), scale); - } - -private: - UInt64 m0{}; - NativeType m[_level]{}; - - NativeType & getM(size_t i) { return m[i - 1]; } - const NativeType & getM(size_t i) const { return m[i - 1]; } -}; - -/** - Calculating multivariate central moments - Levels: - level 2 (pop & samp): covar - References: - https://en.wikipedia.org/wiki/Moment_(mathematics) -*/ -template -struct CovarMoments -{ - T m0{}; - T x1{}; - T y1{}; - T xy{}; - - void add(T x, T y) - { - ++m0; - x1 += x; - y1 += y; - xy += x * y; - } - - void merge(const CovarMoments & rhs) - { - m0 += rhs.m0; - x1 += rhs.x1; - y1 += rhs.y1; - xy += rhs.xy; - } - - void write(WriteBuffer & buf) const - { - writePODBinary(*this, buf); - } - - void read(ReadBuffer & buf) - { - readPODBinary(*this, buf); - } - - T NO_SANITIZE_UNDEFINED getPopulation() const - { - return (xy - x1 * y1 / m0) / m0; - } - - T NO_SANITIZE_UNDEFINED getSample() const - { - if (m0 == 0) - return std::numeric_limits::quiet_NaN(); - return (xy - x1 * y1 / m0) / (m0 - 1); - } -}; - -template -struct CorrMoments -{ - T m0{}; - T x1{}; - T y1{}; - T xy{}; - T x2{}; - T y2{}; - - void add(T x, T y) - { - ++m0; - x1 += x; - y1 += y; - xy += x * y; - x2 += x * x; - y2 += y * y; - } - - void merge(const CorrMoments & rhs) - { - m0 += rhs.m0; - x1 += rhs.x1; - y1 += rhs.y1; - xy += rhs.xy; - x2 += rhs.x2; - y2 += rhs.y2; - } - - void write(WriteBuffer & buf) const - { - writePODBinary(*this, buf); - } - - void read(ReadBuffer & buf) - { - readPODBinary(*this, buf); - } - - T NO_SANITIZE_UNDEFINED get() const - { - return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1)); - } -}; - - enum class StatisticsFunctionKind { varPop, varSamp, diff --git a/src/AggregateFunctions/AggregateFunctionStudentTTest.cpp b/src/AggregateFunctions/AggregateFunctionStudentTTest.cpp index 58fc9e5b5b9..3137f9718bd 100644 --- a/src/AggregateFunctions/AggregateFunctionStudentTTest.cpp +++ b/src/AggregateFunctions/AggregateFunctionStudentTTest.cpp @@ -1,52 +1,70 @@ #include -#include +#include #include +#include + #include "registerAggregateFunctions.h" -#include -#include - - -// the return type is boolean (we use UInt8 as we do not have boolean in clickhouse) namespace ErrorCodes { -extern const int NOT_IMPLEMENTED; + extern const int BAD_ARGUMENTS; } + namespace DB { namespace { +struct StudentTTestData : public TTestMoments +{ + static constexpr auto name = "studentTTest"; + + std::pair getResult() const + { + Float64 degrees_of_freedom = 2.0 * (m0 - 1); + + Float64 mean_x = x1 / m0; + Float64 mean_y = y1 / m0; + + /// Calculate s^2 + + /// The original formulae looks like + /// \frac{\sum_{i = 1}^{n_x}{(x_i - \bar{x}) ^ 2} + \sum_{i = 1}^{n_y}{(y_i - \bar{y}) ^ 2}}{n_x + n_y - 2} + /// But we made some mathematical transformations not to store original sequences. + /// Also we dropped sqrt, because later it will be squared later. + + Float64 all_x = x2 + m0 * mean_x * mean_x - 2 * mean_x * m0; + Float64 all_y = y2 + m0 * mean_y * mean_y - 2 * mean_y * m0; + + Float64 s2 = (all_x + all_y) / degrees_of_freedom; + Float64 std_err2 = 2.0 * s2 / m0; + + /// t-statistic, squared + Float64 t_stat = (mean_x - mean_y) / sqrt(std_err2); + + return {t_stat, getPValue(degrees_of_freedom, t_stat * t_stat)}; + } +}; + AggregateFunctionPtr createAggregateFunctionStudentTTest(const std::string & name, const DataTypes & argument_types, const Array & parameters) { assertBinary(name, argument_types); assertNoParameters(name, parameters); - AggregateFunctionPtr res; + if (!isNumber(argument_types[0]) || !isNumber(argument_types[1])) + throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::BAD_ARGUMENTS); - if (isDecimal(argument_types[0]) || isDecimal(argument_types[1])) - { - throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED); - } - else - { - res.reset(createWithTwoNumericTypes(*argument_types[0], *argument_types[1], argument_types)); - } - - if (!res) - { - throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED); - } - - return res; + return std::make_shared>(argument_types); } + } void registerAggregateFunctionStudentTTest(AggregateFunctionFactory & factory) { factory.registerFunction("studentTTest", createAggregateFunctionStudentTTest); } + } diff --git a/src/AggregateFunctions/AggregateFunctionStudentTTest.h b/src/AggregateFunctions/AggregateFunctionStudentTTest.h deleted file mode 100644 index 0aef8f3ee2a..00000000000 --- a/src/AggregateFunctions/AggregateFunctionStudentTTest.h +++ /dev/null @@ -1,262 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace ErrorCodes -{ - extern const int BAD_ARGUMENTS; -} - -#if defined(OS_DARWIN) -extern "C" -{ - double lgammal_r(double x, int * signgamp); -} -#endif - - -namespace DB -{ - -template -struct AggregateFunctionStudentTTestData final -{ - size_t size_x = 0; - size_t size_y = 0; - X sum_x = static_cast(0); - Y sum_y = static_cast(0); - X square_sum_x = static_cast(0); - Y square_sum_y = static_cast(0); - Float64 mean_x = static_cast(0); - Float64 mean_y = static_cast(0); - - void add(X x, Y y) - { - sum_x += x; - sum_y += y; - size_x++; - size_y++; - mean_x = static_cast(sum_x) / size_x; - mean_y = static_cast(sum_y) / size_y; - square_sum_x += x * x; - square_sum_y += y * y; - } - - void merge(const AggregateFunctionStudentTTestData &other) - { - sum_x += other.sum_x; - sum_y += other.sum_y; - size_x += other.size_x; - size_y += other.size_y; - mean_x = static_cast(sum_x) / size_x; - mean_y = static_cast(sum_y) / size_y; - square_sum_x += other.square_sum_x; - square_sum_y += other.square_sum_y; - } - - void serialize(WriteBuffer &buf) const - { - writeBinary(mean_x, buf); - writeBinary(mean_y, buf); - writeBinary(sum_x, buf); - writeBinary(sum_y, buf); - writeBinary(square_sum_x, buf); - writeBinary(square_sum_y, buf); - writeBinary(size_x, buf); - writeBinary(size_y, buf); - } - - void deserialize(ReadBuffer &buf) - { - readBinary(mean_x, buf); - readBinary(mean_y, buf); - readBinary(sum_x, buf); - readBinary(sum_y, buf); - readBinary(square_sum_x, buf); - readBinary(square_sum_y, buf); - readBinary(size_x, buf); - readBinary(size_y, buf); - } - - size_t getSizeY() const - { - return size_y; - } - - size_t getSizeX() const - { - return size_x; - } - - Float64 getSSquared() const - { - /// The original formulae looks like - /// \frac{\sum_{i = 1}^{n_x}{(x_i - \bar{x}) ^ 2} + \sum_{i = 1}^{n_y}{(y_i - \bar{y}) ^ 2}}{n_x + n_y - 2} - /// But we made some mathematical transformations not to store original sequences. - /// Also we dropped sqrt, because later it will be squared later. - const Float64 all_x = square_sum_x + size_x * std::pow(mean_x, 2) - 2 * mean_x * sum_x; - const Float64 all_y = square_sum_y + size_y * std::pow(mean_y, 2) - 2 * mean_y * sum_y; - return static_cast(all_x + all_y) / (size_x + size_y - 2); - } - - - Float64 getTStatisticSquared() const - { - return std::pow(mean_x - mean_y, 2) / getStandartErrorSquared(); - } - - Float64 getTStatistic() const - { - return (mean_x - mean_y) / std::sqrt(getStandartErrorSquared()); - } - - Float64 getStandartErrorSquared() const - { - if (size_x == 0 || size_y == 0) - throw Exception("Division by zero encountered in Aggregate function StudentTTest", ErrorCodes::BAD_ARGUMENTS); - - return getSSquared() * (1.0 / static_cast(size_x) + 1.0 / static_cast(size_y)); - } - - Float64 getDegreesOfFreedom() const - { - return static_cast(size_x + size_y - 2); - } - - static Float64 integrateSimpson(Float64 a, Float64 b, std::function func) - { - const size_t iterations = std::max(1e6, 1e4 * std::abs(std::round(b))); - const long double h = (b - a) / iterations; - Float64 sum_odds = 0.0; - for (size_t i = 1; i < iterations; i += 2) - sum_odds += func(a + i * h); - Float64 sum_evens = 0.0; - for (size_t i = 2; i < iterations; i += 2) - sum_evens += func(a + i * h); - return (func(a) + func(b) + 2 * sum_evens + 4 * sum_odds) * h / 3; - } - - Float64 getPValue() const - { - const Float64 v = getDegreesOfFreedom(); - const Float64 t = getTStatisticSquared(); - auto f = [&v] (double x) { return std::pow(x, v/2 - 1) / std::sqrt(1 - x); }; - Float64 numenator = integrateSimpson(0, v / (t + v), f); - int unused; - Float64 denominator = std::exp(lgammal_r(v / 2, &unused) + lgammal_r(0.5, &unused) - lgammal_r(v / 2 + 0.5, &unused)); - return numenator / denominator; - } - - std::pair getResult() const - { - return std::make_pair(getTStatistic(), getPValue()); - } -}; - -/// Returns tuple of (t-statistic, p-value) -/// https://cpb-us-w2.wpmucdn.com/voices.uchicago.edu/dist/9/1193/files/2016/01/05b-TandP.pdf -template -class AggregateFunctionStudentTTest : - public IAggregateFunctionDataHelper,AggregateFunctionStudentTTest> -{ - -public: - AggregateFunctionStudentTTest(const DataTypes & arguments) - : IAggregateFunctionDataHelper, AggregateFunctionStudentTTest> ({arguments}, {}) - {} - - String getName() const override - { - return "studentTTest"; - } - - DataTypePtr getReturnType() const override - { - DataTypes types - { - std::make_shared>(), - std::make_shared>(), - }; - - Strings names - { - "t-statistic", - "p-value" - }; - - return std::make_shared( - std::move(types), - std::move(names) - ); - } - - void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override - { - auto col_x = assert_cast *>(columns[0]); - auto col_y = assert_cast *>(columns[1]); - - X x = col_x->getData()[row_num]; - Y y = col_y->getData()[row_num]; - - this->data(place).add(x, y); - } - - void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override - { - this->data(place).merge(this->data(rhs)); - } - - void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override - { - this->data(place).serialize(buf); - } - - void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override - { - this->data(place).deserialize(buf); - } - - void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * /*arena*/) const override - { - size_t size_x = this->data(place).getSizeX(); - size_t size_y = this->data(place).getSizeY(); - - if (size_x < 2 || size_y < 2) - { - throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS); - } - - Float64 t_statistic = 0.0; - Float64 p_value = 0.0; - std::tie(t_statistic, p_value) = this->data(place).getResult(); - - /// Because p-value is a probability. - p_value = std::min(1.0, std::max(0.0, p_value)); - - auto & column_tuple = assert_cast(to); - auto & column_stat = assert_cast &>(column_tuple.getColumn(0)); - auto & column_value = assert_cast &>(column_tuple.getColumn(1)); - - column_stat.getData().push_back(t_statistic); - column_value.getData().push_back(p_value); - } - -}; - -}; diff --git a/src/AggregateFunctions/AggregateFunctionTTest.h b/src/AggregateFunctions/AggregateFunctionTTest.h new file mode 100644 index 00000000000..941327eb53e --- /dev/null +++ b/src/AggregateFunctions/AggregateFunctionTTest.h @@ -0,0 +1,135 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + + +/// This function is used in implementations of different T-Tests. +/// On Darwin it's unavailable in math.h but actually exists in the library (can be linked successfully). +#if defined(OS_DARWIN) +extern "C" +{ + double lgamma_r(double x, int * signgamp); +} +#endif + + +namespace DB +{ + +class ReadBuffer; +class WriteBuffer; + + +template +static Float64 integrateSimpson(Float64 a, Float64 b, F && func) +{ + const size_t iterations = std::max(1e6, 1e4 * std::abs(std::round(b))); + const long double h = (b - a) / iterations; + Float64 sum_odds = 0.0; + for (size_t i = 1; i < iterations; i += 2) + sum_odds += func(a + i * h); + Float64 sum_evens = 0.0; + for (size_t i = 2; i < iterations; i += 2) + sum_evens += func(a + i * h); + return (func(a) + func(b) + 2 * sum_evens + 4 * sum_odds) * h / 3; +} + +static inline Float64 getPValue(Float64 degrees_of_freedom, Float64 t_stat2) +{ + Float64 numerator = integrateSimpson(0, degrees_of_freedom / (t_stat2 + degrees_of_freedom), + [degrees_of_freedom](double x) { return std::pow(x, degrees_of_freedom / 2 - 1) / std::sqrt(1 - x); }); + + int unused; + Float64 denominator = std::exp( + lgamma_r(degrees_of_freedom / 2, &unused) + + lgamma_r(0.5, &unused) + - lgamma_r(degrees_of_freedom / 2 + 0.5, &unused)); + + return std::min(1.0, std::max(0.0, numerator / denominator)); +} + + +/// Returns tuple of (t-statistic, p-value) +/// https://cpb-us-w2.wpmucdn.com/voices.uchicago.edu/dist/9/1193/files/2016/01/05b-TandP.pdf +template +class AggregateFunctionTTest : + public IAggregateFunctionDataHelper> +{ +public: + AggregateFunctionTTest(const DataTypes & arguments) + : IAggregateFunctionDataHelper>({arguments}, {}) + { + } + + String getName() const override + { + return Data::name; + } + + DataTypePtr getReturnType() const override + { + DataTypes types + { + std::make_shared>(), + std::make_shared>(), + }; + + Strings names + { + "t_statistic", + "p_value" + }; + + return std::make_shared( + std::move(types), + std::move(names) + ); + } + + void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override + { + Float64 x = columns[0]->getFloat64(row_num); + Float64 y = columns[1]->getFloat64(row_num); + + this->data(place).add(x, y); + } + + void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override + { + this->data(place).merge(this->data(rhs)); + } + + void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override + { + this->data(place).write(buf); + } + + void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override + { + this->data(place).read(buf); + } + + void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override + { + auto [t_statistic, p_value] = this->data(place).getResult(); + + /// Because p-value is a probability. + p_value = std::min(1.0, std::max(0.0, p_value)); + + auto & column_tuple = assert_cast(to); + auto & column_stat = assert_cast &>(column_tuple.getColumn(0)); + auto & column_value = assert_cast &>(column_tuple.getColumn(1)); + + column_stat.getData().push_back(t_statistic); + column_value.getData().push_back(p_value); + } +}; + +}; diff --git a/src/AggregateFunctions/AggregateFunctionWelchTTest.cpp b/src/AggregateFunctions/AggregateFunctionWelchTTest.cpp index 0dcb125305d..044d52a42a0 100644 --- a/src/AggregateFunctions/AggregateFunctionWelchTTest.cpp +++ b/src/AggregateFunctions/AggregateFunctionWelchTTest.cpp @@ -1,49 +1,74 @@ #include -#include +#include #include +#include + #include "registerAggregateFunctions.h" -#include -#include namespace ErrorCodes { -extern const int NOT_IMPLEMENTED; + extern const int BAD_ARGUMENTS; } + namespace DB { namespace { +struct WelchTTestData : public TTestMoments +{ + static constexpr auto name = "welchTTest"; + + std::pair getResult() const + { + Float64 mean_x = x1 / m0; + Float64 mean_y = y1 / m0; + + /// s_x^2, s_y^2 + + /// The original formulae looks like \frac{1}{size_x - 1} \sum_{i = 1}^{size_x}{(x_i - \bar{x}) ^ 2} + /// But we made some mathematical transformations not to store original sequences. + /// Also we dropped sqrt, because later it will be squared later. + + Float64 sx2 = (x2 + m0 * mean_x * mean_x - 2 * mean_x * x1) / (m0 - 1); + Float64 sy2 = (y2 + m0 * mean_y * mean_y - 2 * mean_y * y1) / (m0 - 1); + + /// t-statistic, squared + Float64 t_stat = (mean_x - mean_y) / sqrt(sx2 / m0 + sy2 / m0); + + /// degrees of freedom + + Float64 numerator_sqrt = sx2 / m0 + sy2 / m0; + Float64 numerator = numerator_sqrt * numerator_sqrt; + + Float64 denominator_x = sx2 * sx2 / (m0 * m0 * (m0 - 1)); + Float64 denominator_y = sy2 * sy2 / (m0 * m0 * (m0 - 1)); + + Float64 degrees_of_freedom = numerator / (denominator_x + denominator_y); + + return {t_stat, getPValue(degrees_of_freedom, t_stat * t_stat)}; + } +}; + AggregateFunctionPtr createAggregateFunctionWelchTTest(const std::string & name, const DataTypes & argument_types, const Array & parameters) { assertBinary(name, argument_types); assertNoParameters(name, parameters); - AggregateFunctionPtr res; + if (!isNumber(argument_types[0]) || !isNumber(argument_types[1])) + throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::BAD_ARGUMENTS); - if (isDecimal(argument_types[0]) || isDecimal(argument_types[1])) - { - throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED); - } - else - { - res.reset(createWithTwoNumericTypes(*argument_types[0], *argument_types[1], argument_types)); - } - - if (!res) - { - throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED); - } - - return res; + return std::make_shared>(argument_types); } + } void registerAggregateFunctionWelchTTest(AggregateFunctionFactory & factory) { factory.registerFunction("welchTTest", createAggregateFunctionWelchTTest); } + } diff --git a/src/AggregateFunctions/AggregateFunctionWelchTTest.h b/src/AggregateFunctions/AggregateFunctionWelchTTest.h deleted file mode 100644 index b598f25162e..00000000000 --- a/src/AggregateFunctions/AggregateFunctionWelchTTest.h +++ /dev/null @@ -1,274 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - - -namespace ErrorCodes -{ - extern const int BAD_ARGUMENTS; -} - -#if defined(OS_DARWIN) -extern "C" -{ - double lgammal_r(double x, int * signgamp); -} -#endif - - -namespace DB -{ - -template -struct AggregateFunctionWelchTTestData final -{ - size_t size_x = 0; - size_t size_y = 0; - X sum_x = static_cast(0); - Y sum_y = static_cast(0); - X square_sum_x = static_cast(0); - Y square_sum_y = static_cast(0); - Float64 mean_x = static_cast(0); - Float64 mean_y = static_cast(0); - - void add(X x, Y y) - { - sum_x += x; - sum_y += y; - size_x++; - size_y++; - mean_x = static_cast(sum_x) / size_x; - mean_y = static_cast(sum_y) / size_y; - square_sum_x += x * x; - square_sum_y += y * y; - } - - void merge(const AggregateFunctionWelchTTestData &other) - { - sum_x += other.sum_x; - sum_y += other.sum_y; - size_x += other.size_x; - size_y += other.size_y; - mean_x = static_cast(sum_x) / size_x; - mean_y = static_cast(sum_y) / size_y; - square_sum_x += other.square_sum_x; - square_sum_y += other.square_sum_y; - } - - void serialize(WriteBuffer &buf) const - { - writeBinary(mean_x, buf); - writeBinary(mean_y, buf); - writeBinary(sum_x, buf); - writeBinary(sum_y, buf); - writeBinary(square_sum_x, buf); - writeBinary(square_sum_y, buf); - writeBinary(size_x, buf); - writeBinary(size_y, buf); - } - - void deserialize(ReadBuffer &buf) - { - readBinary(mean_x, buf); - readBinary(mean_y, buf); - readBinary(sum_x, buf); - readBinary(sum_y, buf); - readBinary(square_sum_x, buf); - readBinary(square_sum_y, buf); - readBinary(size_x, buf); - readBinary(size_y, buf); - } - - size_t getSizeY() const - { - return size_y; - } - - size_t getSizeX() const - { - return size_x; - } - - Float64 getSxSquared() const - { - /// The original formulae looks like \frac{1}{size_x - 1} \sum_{i = 1}^{size_x}{(x_i - \bar{x}) ^ 2} - /// But we made some mathematical transformations not to store original sequences. - /// Also we dropped sqrt, because later it will be squared later. - return static_cast(square_sum_x + size_x * std::pow(mean_x, 2) - 2 * mean_x * sum_x) / (size_x - 1); - } - - Float64 getSySquared() const - { - /// The original formulae looks like \frac{1}{size_y - 1} \sum_{i = 1}^{size_y}{(y_i - \bar{y}) ^ 2} - /// But we made some mathematical transformations not to store original sequences. - /// Also we dropped sqrt, because later it will be squared later. - return static_cast(square_sum_y + size_y * std::pow(mean_y, 2) - 2 * mean_y * sum_y) / (size_y - 1); - } - - Float64 getTStatisticSquared() const - { - if (size_x == 0 || size_y == 0) - { - throw Exception("Division by zero encountered in Aggregate function WelchTTest", ErrorCodes::BAD_ARGUMENTS); - } - - return std::pow(mean_x - mean_y, 2) / (getSxSquared() / size_x + getSySquared() / size_y); - } - - Float64 getTStatistic() const - { - if (size_x == 0 || size_y == 0) - { - throw Exception("Division by zero encountered in Aggregate function WelchTTest", ErrorCodes::BAD_ARGUMENTS); - } - - return (mean_x - mean_y) / std::sqrt(getSxSquared() / size_x + getSySquared() / size_y); - } - - Float64 getDegreesOfFreedom() const - { - auto sx = getSxSquared(); - auto sy = getSySquared(); - Float64 numerator = std::pow(sx / size_x + sy / size_y, 2); - Float64 denominator_first = std::pow(sx, 2) / (std::pow(size_x, 2) * (size_x - 1)); - Float64 denominator_second = std::pow(sy, 2) / (std::pow(size_y, 2) * (size_y - 1)); - return numerator / (denominator_first + denominator_second); - } - - static Float64 integrateSimpson(Float64 a, Float64 b, std::function func) - { - size_t iterations = std::max(1e6, 1e4 * std::abs(std::round(b))); - double h = (b - a) / iterations; - Float64 sum_odds = 0.0; - for (size_t i = 1; i < iterations; i += 2) - sum_odds += func(a + i * h); - Float64 sum_evens = 0.0; - for (size_t i = 2; i < iterations; i += 2) - sum_evens += func(a + i * h); - return (func(a) + func(b) + 2 * sum_evens + 4 * sum_odds) * h / 3; - } - - Float64 getPValue() const - { - const Float64 v = getDegreesOfFreedom(); - const Float64 t = getTStatisticSquared(); - auto f = [&v] (double x) { return std::pow(x, v / 2 - 1) / std::sqrt(1 - x); }; - Float64 numenator = integrateSimpson(0, v / (t + v), f); - int unused; - Float64 denominator = std::exp(lgammal_r(v / 2, &unused) + lgammal_r(0.5, &unused) - lgammal_r(v / 2 + 0.5, &unused)); - return numenator / denominator; - } - - std::pair getResult() const - { - return std::make_pair(getTStatistic(), getPValue()); - } -}; - -/// Returns tuple of (t-statistic, p-value) -/// https://cpb-us-w2.wpmucdn.com/voices.uchicago.edu/dist/9/1193/files/2016/01/05b-TandP.pdf -template -class AggregateFunctionWelchTTest : - public IAggregateFunctionDataHelper,AggregateFunctionWelchTTest> -{ - -public: - AggregateFunctionWelchTTest(const DataTypes & arguments) - : IAggregateFunctionDataHelper, AggregateFunctionWelchTTest> ({arguments}, {}) - {} - - String getName() const override - { - return "welchTTest"; - } - - DataTypePtr getReturnType() const override - { - DataTypes types - { - std::make_shared>(), - std::make_shared>(), - }; - - Strings names - { - "t-statistic", - "p-value" - }; - - return std::make_shared( - std::move(types), - std::move(names) - ); - } - - void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override - { - auto col_x = assert_cast *>(columns[0]); - auto col_y = assert_cast *>(columns[1]); - - X x = col_x->getData()[row_num]; - Y y = col_y->getData()[row_num]; - - this->data(place).add(x, y); - } - - void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override - { - this->data(place).merge(this->data(rhs)); - } - - void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override - { - this->data(place).serialize(buf); - } - - void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override - { - this->data(place).deserialize(buf); - } - - void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * /*arena*/) const override - { - size_t size_x = this->data(place).getSizeX(); - size_t size_y = this->data(place).getSizeY(); - - if (size_x < 2 || size_y < 2) - { - throw Exception("Aggregate function " + getName() + " requires samples to be of size > 1", ErrorCodes::BAD_ARGUMENTS); - } - - Float64 t_statistic = 0.0; - Float64 p_value = 0.0; - std::tie(t_statistic, p_value) = this->data(place).getResult(); - - /// Because p-value is a probability. - p_value = std::min(1.0, std::max(0.0, p_value)); - - auto & column_tuple = assert_cast(to); - auto & column_stat = assert_cast &>(column_tuple.getColumn(0)); - auto & column_value = assert_cast &>(column_tuple.getColumn(1)); - - column_stat.getData().push_back(t_statistic); - column_value.getData().push_back(p_value); - } - -}; - -};