mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 15:42:02 +00:00
Make the code less bad
(cherry picked from commit911cd77c1a
) (cherry picked from commitac7267ce48
)
This commit is contained in:
parent
88fec92921
commit
a7739d9afb
@ -8,6 +8,7 @@
|
|||||||
#include <IO/ReadHelpers.h>
|
#include <IO/ReadHelpers.h>
|
||||||
|
|
||||||
#include <AggregateFunctions/IAggregateFunction.h>
|
#include <AggregateFunctions/IAggregateFunction.h>
|
||||||
|
#include <AggregateFunctions/Moments.h>
|
||||||
|
|
||||||
#include <DataTypes/DataTypesNumber.h>
|
#include <DataTypes/DataTypesNumber.h>
|
||||||
#include <DataTypes/DataTypesDecimal.h>
|
#include <DataTypes/DataTypesDecimal.h>
|
||||||
@ -30,310 +31,6 @@
|
|||||||
namespace DB
|
namespace DB
|
||||||
{
|
{
|
||||||
|
|
||||||
namespace ErrorCodes
|
|
||||||
{
|
|
||||||
extern const int DECIMAL_OVERFLOW;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
Calculating univariate central moments
|
|
||||||
Levels:
|
|
||||||
level 2 (pop & samp): var, stddev
|
|
||||||
level 3: skewness
|
|
||||||
level 4: kurtosis
|
|
||||||
References:
|
|
||||||
https://en.wikipedia.org/wiki/Moment_(mathematics)
|
|
||||||
https://en.wikipedia.org/wiki/Skewness
|
|
||||||
https://en.wikipedia.org/wiki/Kurtosis
|
|
||||||
*/
|
|
||||||
template <typename T, size_t _level>
|
|
||||||
struct VarMoments
|
|
||||||
{
|
|
||||||
T m[_level + 1]{};
|
|
||||||
|
|
||||||
void add(T x)
|
|
||||||
{
|
|
||||||
++m[0];
|
|
||||||
m[1] += x;
|
|
||||||
m[2] += x * x;
|
|
||||||
if constexpr (_level >= 3) m[3] += x * x * x;
|
|
||||||
if constexpr (_level >= 4) m[4] += x * x * x * x;
|
|
||||||
}
|
|
||||||
|
|
||||||
void merge(const VarMoments & rhs)
|
|
||||||
{
|
|
||||||
m[0] += rhs.m[0];
|
|
||||||
m[1] += rhs.m[1];
|
|
||||||
m[2] += rhs.m[2];
|
|
||||||
if constexpr (_level >= 3) m[3] += rhs.m[3];
|
|
||||||
if constexpr (_level >= 4) m[4] += rhs.m[4];
|
|
||||||
}
|
|
||||||
|
|
||||||
void write(WriteBuffer & buf) const
|
|
||||||
{
|
|
||||||
writePODBinary(*this, buf);
|
|
||||||
}
|
|
||||||
|
|
||||||
void read(ReadBuffer & buf)
|
|
||||||
{
|
|
||||||
readPODBinary(*this, buf);
|
|
||||||
}
|
|
||||||
|
|
||||||
T getPopulation() const
|
|
||||||
{
|
|
||||||
if (m[0] == 0)
|
|
||||||
return std::numeric_limits<T>::quiet_NaN();
|
|
||||||
|
|
||||||
/// Due to numerical errors, the result can be slightly less than zero,
|
|
||||||
/// but it should be impossible. Trim to zero.
|
|
||||||
|
|
||||||
return std::max(T{}, (m[2] - m[1] * m[1] / m[0]) / m[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
T getSample() const
|
|
||||||
{
|
|
||||||
if (m[0] <= 1)
|
|
||||||
return std::numeric_limits<T>::quiet_NaN();
|
|
||||||
return std::max(T{}, (m[2] - m[1] * m[1] / m[0]) / (m[0] - 1));
|
|
||||||
}
|
|
||||||
|
|
||||||
T getMoment3() const
|
|
||||||
{
|
|
||||||
if (m[0] == 0)
|
|
||||||
return std::numeric_limits<T>::quiet_NaN();
|
|
||||||
// to avoid accuracy problem
|
|
||||||
if (m[0] == 1)
|
|
||||||
return 0;
|
|
||||||
return (m[3]
|
|
||||||
- (3 * m[2]
|
|
||||||
- 2 * m[1] * m[1] / m[0]
|
|
||||||
) * m[1] / m[0]
|
|
||||||
) / m[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
T getMoment4() const
|
|
||||||
{
|
|
||||||
if (m[0] == 0)
|
|
||||||
return std::numeric_limits<T>::quiet_NaN();
|
|
||||||
// to avoid accuracy problem
|
|
||||||
if (m[0] == 1)
|
|
||||||
return 0;
|
|
||||||
return (m[4]
|
|
||||||
- (4 * m[3]
|
|
||||||
- (6 * m[2]
|
|
||||||
- 3 * m[1] * m[1] / m[0]
|
|
||||||
) * m[1] / m[0]
|
|
||||||
) * m[1] / m[0]
|
|
||||||
) / m[0];
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T, size_t _level>
|
|
||||||
class VarMomentsDecimal
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
using NativeType = typename T::NativeType;
|
|
||||||
|
|
||||||
void add(NativeType x)
|
|
||||||
{
|
|
||||||
++m0;
|
|
||||||
getM(1) += x;
|
|
||||||
|
|
||||||
NativeType tmp;
|
|
||||||
bool overflow = common::mulOverflow(x, x, tmp) || common::addOverflow(getM(2), tmp, getM(2));
|
|
||||||
if constexpr (_level >= 3)
|
|
||||||
overflow = overflow || common::mulOverflow(tmp, x, tmp) || common::addOverflow(getM(3), tmp, getM(3));
|
|
||||||
if constexpr (_level >= 4)
|
|
||||||
overflow = overflow || common::mulOverflow(tmp, x, tmp) || common::addOverflow(getM(4), tmp, getM(4));
|
|
||||||
|
|
||||||
if (overflow)
|
|
||||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
|
||||||
}
|
|
||||||
|
|
||||||
void merge(const VarMomentsDecimal & rhs)
|
|
||||||
{
|
|
||||||
m0 += rhs.m0;
|
|
||||||
getM(1) += rhs.getM(1);
|
|
||||||
|
|
||||||
bool overflow = common::addOverflow(getM(2), rhs.getM(2), getM(2));
|
|
||||||
if constexpr (_level >= 3)
|
|
||||||
overflow = overflow || common::addOverflow(getM(3), rhs.getM(3), getM(3));
|
|
||||||
if constexpr (_level >= 4)
|
|
||||||
overflow = overflow || common::addOverflow(getM(4), rhs.getM(4), getM(4));
|
|
||||||
|
|
||||||
if (overflow)
|
|
||||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
|
||||||
}
|
|
||||||
|
|
||||||
void write(WriteBuffer & buf) const { writePODBinary(*this, buf); }
|
|
||||||
void read(ReadBuffer & buf) { readPODBinary(*this, buf); }
|
|
||||||
|
|
||||||
Float64 getPopulation(UInt32 scale) const
|
|
||||||
{
|
|
||||||
if (m0 == 0)
|
|
||||||
return std::numeric_limits<Float64>::infinity();
|
|
||||||
|
|
||||||
NativeType tmp;
|
|
||||||
if (common::mulOverflow(getM(1), getM(1), tmp) ||
|
|
||||||
common::subOverflow(getM(2), NativeType(tmp / m0), tmp))
|
|
||||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
|
||||||
return std::max(Float64{}, DecimalUtils::convertTo<Float64>(T(tmp / m0), scale));
|
|
||||||
}
|
|
||||||
|
|
||||||
Float64 getSample(UInt32 scale) const
|
|
||||||
{
|
|
||||||
if (m0 == 0)
|
|
||||||
return std::numeric_limits<Float64>::quiet_NaN();
|
|
||||||
if (m0 == 1)
|
|
||||||
return std::numeric_limits<Float64>::infinity();
|
|
||||||
|
|
||||||
NativeType tmp;
|
|
||||||
if (common::mulOverflow(getM(1), getM(1), tmp) ||
|
|
||||||
common::subOverflow(getM(2), NativeType(tmp / m0), tmp))
|
|
||||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
|
||||||
return std::max(Float64{}, DecimalUtils::convertTo<Float64>(T(tmp / (m0 - 1)), scale));
|
|
||||||
}
|
|
||||||
|
|
||||||
Float64 getMoment3(UInt32 scale) const
|
|
||||||
{
|
|
||||||
if (m0 == 0)
|
|
||||||
return std::numeric_limits<Float64>::infinity();
|
|
||||||
|
|
||||||
NativeType tmp;
|
|
||||||
if (common::mulOverflow(2 * getM(1), getM(1), tmp) ||
|
|
||||||
common::subOverflow(3 * getM(2), NativeType(tmp / m0), tmp) ||
|
|
||||||
common::mulOverflow(tmp, getM(1), tmp) ||
|
|
||||||
common::subOverflow(getM(3), NativeType(tmp / m0), tmp))
|
|
||||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
|
||||||
return DecimalUtils::convertTo<Float64>(T(tmp / m0), scale);
|
|
||||||
}
|
|
||||||
|
|
||||||
Float64 getMoment4(UInt32 scale) const
|
|
||||||
{
|
|
||||||
if (m0 == 0)
|
|
||||||
return std::numeric_limits<Float64>::infinity();
|
|
||||||
|
|
||||||
NativeType tmp;
|
|
||||||
if (common::mulOverflow(3 * getM(1), getM(1), tmp) ||
|
|
||||||
common::subOverflow(6 * getM(2), NativeType(tmp / m0), tmp) ||
|
|
||||||
common::mulOverflow(tmp, getM(1), tmp) ||
|
|
||||||
common::subOverflow(4 * getM(3), NativeType(tmp / m0), tmp) ||
|
|
||||||
common::mulOverflow(tmp, getM(1), tmp) ||
|
|
||||||
common::subOverflow(getM(4), NativeType(tmp / m0), tmp))
|
|
||||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
|
||||||
return DecimalUtils::convertTo<Float64>(T(tmp / m0), scale);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
UInt64 m0{};
|
|
||||||
NativeType m[_level]{};
|
|
||||||
|
|
||||||
NativeType & getM(size_t i) { return m[i - 1]; }
|
|
||||||
const NativeType & getM(size_t i) const { return m[i - 1]; }
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
Calculating multivariate central moments
|
|
||||||
Levels:
|
|
||||||
level 2 (pop & samp): covar
|
|
||||||
References:
|
|
||||||
https://en.wikipedia.org/wiki/Moment_(mathematics)
|
|
||||||
*/
|
|
||||||
template <typename T>
|
|
||||||
struct CovarMoments
|
|
||||||
{
|
|
||||||
T m0{};
|
|
||||||
T x1{};
|
|
||||||
T y1{};
|
|
||||||
T xy{};
|
|
||||||
|
|
||||||
void add(T x, T y)
|
|
||||||
{
|
|
||||||
++m0;
|
|
||||||
x1 += x;
|
|
||||||
y1 += y;
|
|
||||||
xy += x * y;
|
|
||||||
}
|
|
||||||
|
|
||||||
void merge(const CovarMoments & rhs)
|
|
||||||
{
|
|
||||||
m0 += rhs.m0;
|
|
||||||
x1 += rhs.x1;
|
|
||||||
y1 += rhs.y1;
|
|
||||||
xy += rhs.xy;
|
|
||||||
}
|
|
||||||
|
|
||||||
void write(WriteBuffer & buf) const
|
|
||||||
{
|
|
||||||
writePODBinary(*this, buf);
|
|
||||||
}
|
|
||||||
|
|
||||||
void read(ReadBuffer & buf)
|
|
||||||
{
|
|
||||||
readPODBinary(*this, buf);
|
|
||||||
}
|
|
||||||
|
|
||||||
T NO_SANITIZE_UNDEFINED getPopulation() const
|
|
||||||
{
|
|
||||||
return (xy - x1 * y1 / m0) / m0;
|
|
||||||
}
|
|
||||||
|
|
||||||
T NO_SANITIZE_UNDEFINED getSample() const
|
|
||||||
{
|
|
||||||
if (m0 == 0)
|
|
||||||
return std::numeric_limits<T>::quiet_NaN();
|
|
||||||
return (xy - x1 * y1 / m0) / (m0 - 1);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct CorrMoments
|
|
||||||
{
|
|
||||||
T m0{};
|
|
||||||
T x1{};
|
|
||||||
T y1{};
|
|
||||||
T xy{};
|
|
||||||
T x2{};
|
|
||||||
T y2{};
|
|
||||||
|
|
||||||
void add(T x, T y)
|
|
||||||
{
|
|
||||||
++m0;
|
|
||||||
x1 += x;
|
|
||||||
y1 += y;
|
|
||||||
xy += x * y;
|
|
||||||
x2 += x * x;
|
|
||||||
y2 += y * y;
|
|
||||||
}
|
|
||||||
|
|
||||||
void merge(const CorrMoments & rhs)
|
|
||||||
{
|
|
||||||
m0 += rhs.m0;
|
|
||||||
x1 += rhs.x1;
|
|
||||||
y1 += rhs.y1;
|
|
||||||
xy += rhs.xy;
|
|
||||||
x2 += rhs.x2;
|
|
||||||
y2 += rhs.y2;
|
|
||||||
}
|
|
||||||
|
|
||||||
void write(WriteBuffer & buf) const
|
|
||||||
{
|
|
||||||
writePODBinary(*this, buf);
|
|
||||||
}
|
|
||||||
|
|
||||||
void read(ReadBuffer & buf)
|
|
||||||
{
|
|
||||||
readPODBinary(*this, buf);
|
|
||||||
}
|
|
||||||
|
|
||||||
T NO_SANITIZE_UNDEFINED get() const
|
|
||||||
{
|
|
||||||
return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
enum class StatisticsFunctionKind
|
enum class StatisticsFunctionKind
|
||||||
{
|
{
|
||||||
varPop, varSamp,
|
varPop, varSamp,
|
||||||
|
@ -1,52 +1,70 @@
|
|||||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||||
#include <AggregateFunctions/AggregateFunctionStudentTTest.h>
|
#include <AggregateFunctions/AggregateFunctionTTest.h>
|
||||||
#include <AggregateFunctions/FactoryHelpers.h>
|
#include <AggregateFunctions/FactoryHelpers.h>
|
||||||
|
#include <AggregateFunctions/Moments.h>
|
||||||
|
|
||||||
#include "registerAggregateFunctions.h"
|
#include "registerAggregateFunctions.h"
|
||||||
|
|
||||||
#include <AggregateFunctions/Helpers.h>
|
|
||||||
#include <DataTypes/DataTypeAggregateFunction.h>
|
|
||||||
|
|
||||||
|
|
||||||
// the return type is boolean (we use UInt8 as we do not have boolean in clickhouse)
|
|
||||||
|
|
||||||
namespace ErrorCodes
|
namespace ErrorCodes
|
||||||
{
|
{
|
||||||
extern const int NOT_IMPLEMENTED;
|
extern const int BAD_ARGUMENTS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
{
|
{
|
||||||
|
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
|
|
||||||
|
struct StudentTTestData : public TTestMoments<Float64>
|
||||||
|
{
|
||||||
|
static constexpr auto name = "studentTTest";
|
||||||
|
|
||||||
|
std::pair<Float64, Float64> getResult() const
|
||||||
|
{
|
||||||
|
Float64 degrees_of_freedom = 2.0 * (m0 - 1);
|
||||||
|
|
||||||
|
Float64 mean_x = x1 / m0;
|
||||||
|
Float64 mean_y = y1 / m0;
|
||||||
|
|
||||||
|
/// Calculate s^2
|
||||||
|
|
||||||
|
/// The original formulae looks like
|
||||||
|
/// \frac{\sum_{i = 1}^{n_x}{(x_i - \bar{x}) ^ 2} + \sum_{i = 1}^{n_y}{(y_i - \bar{y}) ^ 2}}{n_x + n_y - 2}
|
||||||
|
/// But we made some mathematical transformations not to store original sequences.
|
||||||
|
/// Also we dropped sqrt, because later it will be squared later.
|
||||||
|
|
||||||
|
Float64 all_x = x2 + m0 * mean_x * mean_x - 2 * mean_x * m0;
|
||||||
|
Float64 all_y = y2 + m0 * mean_y * mean_y - 2 * mean_y * m0;
|
||||||
|
|
||||||
|
Float64 s2 = (all_x + all_y) / degrees_of_freedom;
|
||||||
|
Float64 std_err2 = 2.0 * s2 / m0;
|
||||||
|
|
||||||
|
/// t-statistic, squared
|
||||||
|
Float64 t_stat = (mean_x - mean_y) / sqrt(std_err2);
|
||||||
|
|
||||||
|
return {t_stat, getPValue(degrees_of_freedom, t_stat * t_stat)};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
AggregateFunctionPtr createAggregateFunctionStudentTTest(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
AggregateFunctionPtr createAggregateFunctionStudentTTest(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||||
{
|
{
|
||||||
assertBinary(name, argument_types);
|
assertBinary(name, argument_types);
|
||||||
assertNoParameters(name, parameters);
|
assertNoParameters(name, parameters);
|
||||||
|
|
||||||
AggregateFunctionPtr res;
|
if (!isNumber(argument_types[0]) || !isNumber(argument_types[1]))
|
||||||
|
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::BAD_ARGUMENTS);
|
||||||
|
|
||||||
if (isDecimal(argument_types[0]) || isDecimal(argument_types[1]))
|
return std::make_shared<AggregateFunctionTTest<StudentTTestData>>(argument_types);
|
||||||
{
|
|
||||||
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
res.reset(createWithTwoNumericTypes<AggregateFunctionStudentTTest>(*argument_types[0], *argument_types[1], argument_types));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!res)
|
|
||||||
{
|
|
||||||
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
|
|
||||||
}
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void registerAggregateFunctionStudentTTest(AggregateFunctionFactory & factory)
|
void registerAggregateFunctionStudentTTest(AggregateFunctionFactory & factory)
|
||||||
{
|
{
|
||||||
factory.registerFunction("studentTTest", createAggregateFunctionStudentTTest, AggregateFunctionFactory::CaseInsensitive);
|
factory.registerFunction("studentTTest", createAggregateFunctionStudentTTest, AggregateFunctionFactory::CaseInsensitive);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
135
src/AggregateFunctions/AggregateFunctionTTest.h
Normal file
135
src/AggregateFunctions/AggregateFunctionTTest.h
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <AggregateFunctions/IAggregateFunction.h>
|
||||||
|
#include <Columns/ColumnVector.h>
|
||||||
|
#include <Columns/ColumnTuple.h>
|
||||||
|
#include <Common/assert_cast.h>
|
||||||
|
#include <Core/Types.h>
|
||||||
|
#include <DataTypes/DataTypesNumber.h>
|
||||||
|
#include <DataTypes/DataTypeTuple.h>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
|
||||||
|
/// This function is used in implementations of different T-Tests.
|
||||||
|
/// On Darwin it's unavailable in math.h but actually exists in the library (can be linked successfully).
|
||||||
|
#if defined(OS_DARWIN)
|
||||||
|
extern "C"
|
||||||
|
{
|
||||||
|
double lgamma_r(double x, int * signgamp);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
class ReadBuffer;
|
||||||
|
class WriteBuffer;
|
||||||
|
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
static Float64 integrateSimpson(Float64 a, Float64 b, F && func)
|
||||||
|
{
|
||||||
|
const size_t iterations = std::max(1e6, 1e4 * std::abs(std::round(b)));
|
||||||
|
const long double h = (b - a) / iterations;
|
||||||
|
Float64 sum_odds = 0.0;
|
||||||
|
for (size_t i = 1; i < iterations; i += 2)
|
||||||
|
sum_odds += func(a + i * h);
|
||||||
|
Float64 sum_evens = 0.0;
|
||||||
|
for (size_t i = 2; i < iterations; i += 2)
|
||||||
|
sum_evens += func(a + i * h);
|
||||||
|
return (func(a) + func(b) + 2 * sum_evens + 4 * sum_odds) * h / 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline Float64 getPValue(Float64 degrees_of_freedom, Float64 t_stat2)
|
||||||
|
{
|
||||||
|
Float64 numerator = integrateSimpson(0, degrees_of_freedom / (t_stat2 + degrees_of_freedom),
|
||||||
|
[degrees_of_freedom](double x) { return std::pow(x, degrees_of_freedom / 2 - 1) / std::sqrt(1 - x); });
|
||||||
|
|
||||||
|
int unused;
|
||||||
|
Float64 denominator = std::exp(
|
||||||
|
lgamma_r(degrees_of_freedom / 2, &unused)
|
||||||
|
+ lgamma_r(0.5, &unused)
|
||||||
|
- lgamma_r(degrees_of_freedom / 2 + 0.5, &unused));
|
||||||
|
|
||||||
|
return std::min(1.0, std::max(0.0, numerator / denominator));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Returns tuple of (t-statistic, p-value)
|
||||||
|
/// https://cpb-us-w2.wpmucdn.com/voices.uchicago.edu/dist/9/1193/files/2016/01/05b-TandP.pdf
|
||||||
|
template <typename Data>
|
||||||
|
class AggregateFunctionTTest :
|
||||||
|
public IAggregateFunctionDataHelper<Data, AggregateFunctionTTest<Data>>
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
AggregateFunctionTTest(const DataTypes & arguments)
|
||||||
|
: IAggregateFunctionDataHelper<Data, AggregateFunctionTTest<Data>>({arguments}, {})
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
String getName() const override
|
||||||
|
{
|
||||||
|
return Data::name;
|
||||||
|
}
|
||||||
|
|
||||||
|
DataTypePtr getReturnType() const override
|
||||||
|
{
|
||||||
|
DataTypes types
|
||||||
|
{
|
||||||
|
std::make_shared<DataTypeNumber<Float64>>(),
|
||||||
|
std::make_shared<DataTypeNumber<Float64>>(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Strings names
|
||||||
|
{
|
||||||
|
"t_statistic",
|
||||||
|
"p_value"
|
||||||
|
};
|
||||||
|
|
||||||
|
return std::make_shared<DataTypeTuple>(
|
||||||
|
std::move(types),
|
||||||
|
std::move(names)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||||
|
{
|
||||||
|
Float64 x = columns[0]->getFloat64(row_num);
|
||||||
|
Float64 y = columns[1]->getFloat64(row_num);
|
||||||
|
|
||||||
|
this->data(place).add(x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||||
|
{
|
||||||
|
this->data(place).merge(this->data(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||||
|
{
|
||||||
|
this->data(place).write(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
|
||||||
|
{
|
||||||
|
this->data(place).read(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override
|
||||||
|
{
|
||||||
|
auto [t_statistic, p_value] = this->data(place).getResult();
|
||||||
|
|
||||||
|
/// Because p-value is a probability.
|
||||||
|
p_value = std::min(1.0, std::max(0.0, p_value));
|
||||||
|
|
||||||
|
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||||
|
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||||
|
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||||
|
|
||||||
|
column_stat.getData().push_back(t_statistic);
|
||||||
|
column_value.getData().push_back(p_value);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
};
|
@ -1,49 +1,74 @@
|
|||||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||||
#include <AggregateFunctions/AggregateFunctionWelchTTest.h>
|
#include <AggregateFunctions/AggregateFunctionTTest.h>
|
||||||
#include <AggregateFunctions/FactoryHelpers.h>
|
#include <AggregateFunctions/FactoryHelpers.h>
|
||||||
|
#include <AggregateFunctions/Moments.h>
|
||||||
|
|
||||||
#include "registerAggregateFunctions.h"
|
#include "registerAggregateFunctions.h"
|
||||||
|
|
||||||
#include <AggregateFunctions/Helpers.h>
|
|
||||||
#include <DataTypes/DataTypeAggregateFunction.h>
|
|
||||||
|
|
||||||
namespace ErrorCodes
|
namespace ErrorCodes
|
||||||
{
|
{
|
||||||
extern const int NOT_IMPLEMENTED;
|
extern const int BAD_ARGUMENTS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
{
|
{
|
||||||
|
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
|
|
||||||
|
struct WelchTTestData : public TTestMoments<Float64>
|
||||||
|
{
|
||||||
|
static constexpr auto name = "welchTTest";
|
||||||
|
|
||||||
|
std::pair<Float64, Float64> getResult() const
|
||||||
|
{
|
||||||
|
Float64 mean_x = x1 / m0;
|
||||||
|
Float64 mean_y = y1 / m0;
|
||||||
|
|
||||||
|
/// s_x^2, s_y^2
|
||||||
|
|
||||||
|
/// The original formulae looks like \frac{1}{size_x - 1} \sum_{i = 1}^{size_x}{(x_i - \bar{x}) ^ 2}
|
||||||
|
/// But we made some mathematical transformations not to store original sequences.
|
||||||
|
/// Also we dropped sqrt, because later it will be squared later.
|
||||||
|
|
||||||
|
Float64 sx2 = (x2 + m0 * mean_x * mean_x - 2 * mean_x * x1) / (m0 - 1);
|
||||||
|
Float64 sy2 = (y2 + m0 * mean_y * mean_y - 2 * mean_y * y1) / (m0 - 1);
|
||||||
|
|
||||||
|
/// t-statistic, squared
|
||||||
|
Float64 t_stat = (mean_x - mean_y) / sqrt(sx2 / m0 + sy2 / m0);
|
||||||
|
|
||||||
|
/// degrees of freedom
|
||||||
|
|
||||||
|
Float64 numerator_sqrt = sx2 / m0 + sy2 / m0;
|
||||||
|
Float64 numerator = numerator_sqrt * numerator_sqrt;
|
||||||
|
|
||||||
|
Float64 denominator_x = sx2 * sx2 / (m0 * m0 * (m0 - 1));
|
||||||
|
Float64 denominator_y = sy2 * sy2 / (m0 * m0 * (m0 - 1));
|
||||||
|
|
||||||
|
Float64 degrees_of_freedom = numerator / (denominator_x + denominator_y);
|
||||||
|
|
||||||
|
return {t_stat, getPValue(degrees_of_freedom, t_stat * t_stat)};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
AggregateFunctionPtr createAggregateFunctionWelchTTest(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
AggregateFunctionPtr createAggregateFunctionWelchTTest(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||||
{
|
{
|
||||||
assertBinary(name, argument_types);
|
assertBinary(name, argument_types);
|
||||||
assertNoParameters(name, parameters);
|
assertNoParameters(name, parameters);
|
||||||
|
|
||||||
AggregateFunctionPtr res;
|
if (!isNumber(argument_types[0]) || !isNumber(argument_types[1]))
|
||||||
|
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::BAD_ARGUMENTS);
|
||||||
|
|
||||||
if (isDecimal(argument_types[0]) || isDecimal(argument_types[1]))
|
return std::make_shared<AggregateFunctionTTest<WelchTTestData>>(argument_types);
|
||||||
{
|
|
||||||
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
res.reset(createWithTwoNumericTypes<AggregateFunctionWelchTTest>(*argument_types[0], *argument_types[1], argument_types));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!res)
|
|
||||||
{
|
|
||||||
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
|
|
||||||
}
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void registerAggregateFunctionWelchTTest(AggregateFunctionFactory & factory)
|
void registerAggregateFunctionWelchTTest(AggregateFunctionFactory & factory)
|
||||||
{
|
{
|
||||||
factory.registerFunction("welchTTest", createAggregateFunctionWelchTTest, AggregateFunctionFactory::CaseInsensitive);
|
factory.registerFunction("welchTTest", createAggregateFunctionWelchTTest, AggregateFunctionFactory::CaseInsensitive);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user