mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-23 16:12:01 +00:00
Add statistical aggregate function kolmogorovSmirnovTest (#48325)
This commit is contained in:
parent
db864891f8
commit
6e8f77ee9c
@ -0,0 +1,36 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionKolmogorovSmirnovTest.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionKolmogorovSmirnovTest(
|
||||
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertBinary(name, argument_types);
|
||||
|
||||
if (!isNumber(argument_types[0]) || !isNumber(argument_types[1]))
|
||||
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Aggregate function {} only supports numerical types", name);
|
||||
|
||||
return std::make_shared<AggregateFunctionKolmogorovSmirnov>(argument_types, parameters);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionKolmogorovSmirnovTest(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("kolmogorovSmirnovTest", createAggregateFunctionKolmogorovSmirnovTest, AggregateFunctionFactory::CaseInsensitive);
|
||||
}
|
||||
|
||||
}
|
323
src/AggregateFunctions/AggregateFunctionKolmogorovSmirnovTest.h
Normal file
323
src/AggregateFunctions/AggregateFunctionKolmogorovSmirnovTest.h
Normal file
@ -0,0 +1,323 @@
|
||||
#pragma once
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/StatCommon.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Common/ArenaAllocator.h>
|
||||
#include <Common/PODArray_fwd.h>
|
||||
#include <base/types.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
struct KolmogorovSmirnov : public StatisticalSample<Float64, Float64>
|
||||
{
|
||||
enum class Alternative
|
||||
{
|
||||
TwoSided,
|
||||
Less,
|
||||
Greater
|
||||
};
|
||||
|
||||
std::pair<Float64, Float64> getResult(Alternative alternative, String method)
|
||||
{
|
||||
::sort(x.begin(), x.end());
|
||||
::sort(y.begin(), y.end());
|
||||
|
||||
Float64 max_s = std::numeric_limits<Float64>::min();
|
||||
Float64 min_s = std::numeric_limits<Float64>::max();
|
||||
Float64 now_s = 0;
|
||||
UInt64 pos_x = 0;
|
||||
UInt64 pos_y = 0;
|
||||
UInt64 n1 = x.size();
|
||||
UInt64 n2 = y.size();
|
||||
|
||||
const Float64 n1_d = 1. / n1;
|
||||
const Float64 n2_d = 1. / n2;
|
||||
const Float64 tol = 1e-7;
|
||||
|
||||
// reference: https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test
|
||||
while (pos_x < x.size() && pos_y < y.size())
|
||||
{
|
||||
if (likely(fabs(x[pos_x] - y[pos_y]) >= tol))
|
||||
{
|
||||
if (x[pos_x] < y[pos_y])
|
||||
{
|
||||
now_s += n1_d;
|
||||
++pos_x;
|
||||
}
|
||||
else
|
||||
{
|
||||
now_s -= n2_d;
|
||||
++pos_y;
|
||||
}
|
||||
max_s = std::max(max_s, now_s);
|
||||
min_s = std::min(min_s, now_s);
|
||||
}
|
||||
else
|
||||
{
|
||||
now_s += n1_d;
|
||||
++pos_x;
|
||||
}
|
||||
}
|
||||
now_s += n1_d * (x.size() - pos_x) - n2_d * (y.size() - pos_y);
|
||||
min_s = std::min(min_s, now_s);
|
||||
max_s = std::max(max_s, now_s);
|
||||
|
||||
Float64 d = 0;
|
||||
if (alternative == Alternative::TwoSided)
|
||||
d = std::max(std::abs(max_s), std::abs(min_s));
|
||||
else if (alternative == Alternative::Less)
|
||||
d = -min_s;
|
||||
else if (alternative == Alternative::Greater)
|
||||
d = max_s;
|
||||
|
||||
UInt64 g = std::__gcd(n1, n2);
|
||||
UInt64 nx_g = n1 / g;
|
||||
UInt64 ny_g = n2 / g;
|
||||
|
||||
if (method == "auto")
|
||||
method = std::max(n1, n2) <= 10000 ? "exact" : "asymp";
|
||||
else if (method == "exact" && nx_g >= std::numeric_limits<Int32>::max() / ny_g)
|
||||
method = "asymp";
|
||||
|
||||
Float64 p_value = std::numeric_limits<Float64>::infinity();
|
||||
|
||||
if (method == "exact")
|
||||
{
|
||||
/* reference:
|
||||
* Gunar Schröer and Dietrich Trenkler
|
||||
* Exact and Randomization Distributions of Kolmogorov-Smirnov, Tests for Two or Three Samples
|
||||
*
|
||||
* and
|
||||
*
|
||||
* Thomas Viehmann
|
||||
* Numerically more stable computation of the p-values for the two-sample Kolmogorov-Smirnov test
|
||||
*/
|
||||
if (n2 > n1)
|
||||
std::swap(n1, n2);
|
||||
|
||||
const Float64 f_n1 = static_cast<Float64>(n1);
|
||||
const Float64 f_n2 = static_cast<Float64>(n2);
|
||||
const Float64 k_d = (0.5 + floor(d * f_n2 * f_n1 - tol)) / (f_n2 * f_n1);
|
||||
PaddedPODArray<Float64> c(n1 + 1);
|
||||
|
||||
auto check = alternative == Alternative::TwoSided ?
|
||||
[](const Float64 & q, const Float64 & r, const Float64 & s) { return fabs(r - s) >= q; }
|
||||
: [](const Float64 & q, const Float64 & r, const Float64 & s) { return r - s >= q; };
|
||||
|
||||
c[0] = 0;
|
||||
for (UInt64 j = 1; j <= n1; j++)
|
||||
if (check(k_d, 0., j / f_n1))
|
||||
c[j] = 1.;
|
||||
else
|
||||
c[j] = c[j - 1];
|
||||
|
||||
for (UInt64 i = 1; i <= n2; i++)
|
||||
{
|
||||
if (check(k_d, i / f_n2, 0.))
|
||||
c[0] = 1.;
|
||||
for (UInt64 j = 1; j <= n1; j++)
|
||||
if (check(k_d, i / f_n2, j / f_n1))
|
||||
c[j] = 1.;
|
||||
else
|
||||
{
|
||||
Float64 v = i / static_cast<Float64>(i + j);
|
||||
Float64 w = j / static_cast<Float64>(i + j);
|
||||
c[j] = v * c[j] + w * c[j - 1];
|
||||
}
|
||||
}
|
||||
p_value = c[n1];
|
||||
}
|
||||
else if (method == "asymp")
|
||||
{
|
||||
Float64 n = std::min(n1, n2);
|
||||
Float64 m = std::max(n1, n2);
|
||||
Float64 p = sqrt((n * m) / (n + m)) * d;
|
||||
|
||||
if (alternative == Alternative::TwoSided)
|
||||
{
|
||||
/* reference:
|
||||
* J.DURBIN
|
||||
* Distribution theory for tests based on the sample distribution function
|
||||
*/
|
||||
Float64 new_val, old_val, s, w, z;
|
||||
UInt64 k_max = static_cast<UInt64>(sqrt(2 - log(tol)));
|
||||
|
||||
if (p < 1)
|
||||
{
|
||||
z = - (M_PI_2 * M_PI_4) / (p * p);
|
||||
w = log(p);
|
||||
s = 0;
|
||||
for (UInt64 k = 1; k < k_max; k += 2)
|
||||
s += exp(k * k * z - w);
|
||||
p = s / 0.398942280401432677939946059934;
|
||||
}
|
||||
else
|
||||
{
|
||||
z = -2 * p * p;
|
||||
s = -1;
|
||||
UInt64 k = 1;
|
||||
old_val = 0;
|
||||
new_val = 1;
|
||||
while (fabs(old_val - new_val) > tol)
|
||||
{
|
||||
old_val = new_val;
|
||||
new_val += 2 * s * exp(z * k * k);
|
||||
s *= -1;
|
||||
k++;
|
||||
}
|
||||
p = new_val;
|
||||
}
|
||||
p_value = 1 - p;
|
||||
}
|
||||
else
|
||||
{
|
||||
/* reference:
|
||||
* J. L. HODGES, Jr
|
||||
* The significance probability of the Smirnov two-sample test
|
||||
*/
|
||||
|
||||
// Use Hodges' suggested approximation Eqn 5.3
|
||||
// Requires m to be the larger of (n1, n2)
|
||||
Float64 expt = -2 * p * p - 2 * p * (m + 2 * n) / sqrt(m * n * (m + n)) / 3.0;
|
||||
p_value = exp(expt);
|
||||
}
|
||||
}
|
||||
return {d, p_value};
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class AggregateFunctionKolmogorovSmirnov final:
|
||||
public IAggregateFunctionDataHelper<KolmogorovSmirnov, AggregateFunctionKolmogorovSmirnov>
|
||||
{
|
||||
private:
|
||||
using Alternative = typename KolmogorovSmirnov::Alternative;
|
||||
Alternative alternative = Alternative::TwoSided;
|
||||
String method = "auto";
|
||||
|
||||
public:
|
||||
explicit AggregateFunctionKolmogorovSmirnov(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<KolmogorovSmirnov, AggregateFunctionKolmogorovSmirnov> ({arguments}, {}, createResultType())
|
||||
{
|
||||
if (params.size() > 2)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} require two parameter or less", getName());
|
||||
|
||||
if (params.empty())
|
||||
return;
|
||||
|
||||
if (params[0].getType() != Field::Types::String)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require first parameter to be a String", getName());
|
||||
|
||||
const auto & param = params[0].get<String>();
|
||||
if (param == "two-sided")
|
||||
alternative = Alternative::TwoSided;
|
||||
else if (param == "less")
|
||||
alternative = Alternative::Less;
|
||||
else if (param == "greater")
|
||||
alternative = Alternative::Greater;
|
||||
else
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown parameter in aggregate function {}. "
|
||||
"It must be one of: 'two-sided', 'less', 'greater'", getName());
|
||||
|
||||
if (params.size() != 2)
|
||||
return;
|
||||
|
||||
if (params[1].getType() != Field::Types::String)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function {} require second parameter to be a String", getName());
|
||||
|
||||
method = params[1].get<String>();
|
||||
if (method != "auto" && method != "exact" && method != "asymp")
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown method in aggregate function {}. "
|
||||
"It must be one of: 'auto', 'exact', 'asymp'", getName());
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "kolmogorovSmirnovTest";
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
|
||||
static DataTypePtr createResultType()
|
||||
{
|
||||
DataTypes types
|
||||
{
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
std::make_shared<DataTypeNumber<Float64>>(),
|
||||
};
|
||||
|
||||
Strings names
|
||||
{
|
||||
"d_statistic",
|
||||
"p_value"
|
||||
};
|
||||
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
Float64 value = columns[0]->getFloat64(row_num);
|
||||
UInt8 is_second = columns[1]->getUInt(row_num);
|
||||
if (is_second)
|
||||
this->data(place).addY(value, arena);
|
||||
else
|
||||
this->data(place).addX(value, arena);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs), arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
|
||||
{
|
||||
this->data(place).read(buf, arena);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
if (!this->data(place).size_x || !this->data(place).size_y)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} require both samples to be non empty", getName());
|
||||
|
||||
auto [d_statistic, p_value] = this->data(place).getResult(alternative, method);
|
||||
|
||||
/// Because p-value is a probability.
|
||||
p_value = std::min(1.0, std::max(0.0, p_value));
|
||||
|
||||
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||
|
||||
column_stat.getData().push_back(d_statistic);
|
||||
column_value.getData().push_back(p_value);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
@ -79,6 +79,7 @@ void registerAggregateFunctionExponentialMovingAverage(AggregateFunctionFactory
|
||||
void registerAggregateFunctionSparkbar(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionIntervalLengthSum(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionAnalysisOfVariance(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionKolmogorovSmirnovTest(AggregateFunctionFactory & factory);
|
||||
|
||||
class AggregateFunctionCombinatorFactory;
|
||||
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
||||
@ -170,6 +171,7 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionExponentialMovingAverage(factory);
|
||||
registerAggregateFunctionSparkbar(factory);
|
||||
registerAggregateFunctionAnalysisOfVariance(factory);
|
||||
registerAggregateFunctionKolmogorovSmirnovTest(factory);
|
||||
|
||||
registerWindowFunctions(factory);
|
||||
}
|
||||
|
@ -0,0 +1,3 @@
|
||||
0.1 0.1 1 0.05 0.1 1 0.05 0.1 1 0.05 0.099562 1 0.018316 1 1 -0 1 1 -0 1 1 -0 1 1 -0 1
|
||||
0.000007 0.000007 0.000004 0.000023 0.000007 0.000004 0.000023 0.000007 0.000004 0.000023 0.000008 0.000003 0.00002 0.158 0.158 0.158 0.146 0.158 0.158 0.146 0.158 0.158 0.146 0.158 0.158 0.146
|
||||
0 0 0 0.523357 0 0 0.523357 0 0 0.523357 0 0 0.504595 0.486 0.486 0.486 0.036 0.486 0.486 0.036 0.486 0.486 0.036 0.486 0.486 0.036
|
107
tests/queries/0_stateless/02706_kolmogorov_smirnov_test.sql
Normal file
107
tests/queries/0_stateless/02706_kolmogorov_smirnov_test.sql
Normal file
File diff suppressed because one or more lines are too long
@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import sys
|
||||
from scipy import stats
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
CURDIR = os.path.dirname(os.path.realpath(__file__))
|
||||
sys.path.insert(0, os.path.join(CURDIR, "helpers"))
|
||||
|
||||
from pure_http_client import ClickHouseClient
|
||||
|
||||
|
||||
def test_and_check(name, a, b, t_stat, p_value, precision=1e-2):
|
||||
client = ClickHouseClient()
|
||||
client.query("DROP TABLE IF EXISTS ks_test;")
|
||||
client.query("CREATE TABLE ks_test (left Float64, right UInt8) ENGINE = Memory;")
|
||||
client.query(
|
||||
"INSERT INTO ks_test VALUES {};".format(
|
||||
", ".join(["({},{})".format(i, 0) for i in a])
|
||||
)
|
||||
)
|
||||
client.query(
|
||||
"INSERT INTO ks_test VALUES {};".format(
|
||||
", ".join(["({},{})".format(j, 1) for j in b])
|
||||
)
|
||||
)
|
||||
real = client.query_return_df(
|
||||
"SELECT roundBankers({}(left, right).1, 16) as t_stat, ".format(name)
|
||||
+ "roundBankers({}(left, right).2, 16) as p_value ".format(name)
|
||||
+ "FROM ks_test FORMAT TabSeparatedWithNames;"
|
||||
)
|
||||
real_t_stat = real["t_stat"][0]
|
||||
real_p_value = real["p_value"][0]
|
||||
assert (
|
||||
abs(real_t_stat - np.float64(t_stat)) < precision
|
||||
), "clickhouse_t_stat {}, scipy_t_stat {}".format(real_t_stat, t_stat)
|
||||
assert (
|
||||
abs(real_p_value - np.float64(p_value)) < precision
|
||||
), "clickhouse_p_value {}, scipy_p_value {}".format(real_p_value, p_value)
|
||||
client.query("DROP TABLE IF EXISTS ks_test;")
|
||||
|
||||
|
||||
def test_ks_all_alternatives(rvs1, rvs2):
|
||||
s, p = stats.ks_2samp(rvs1, rvs2)
|
||||
test_and_check("kolmogorovSmirnovTest", rvs1, rvs2, s, p)
|
||||
|
||||
s, p = stats.ks_2samp(rvs1, rvs2, alternative="two-sided")
|
||||
test_and_check("kolmogorovSmirnovTest('two-sided')", rvs1, rvs2, s, p)
|
||||
|
||||
s, p = stats.ks_2samp(rvs1, rvs2, alternative="greater", method="auto")
|
||||
test_and_check("kolmogorovSmirnovTest('greater', 'auto')", rvs1, rvs2, s, p)
|
||||
|
||||
s, p = stats.ks_2samp(rvs1, rvs2, alternative="less", method="exact")
|
||||
test_and_check("kolmogorovSmirnovTest('less', 'exact')", rvs1, rvs2, s, p)
|
||||
|
||||
if max(len(rvs1), len(rvs2)) > 10000:
|
||||
s, p = stats.ks_2samp(rvs1, rvs2, alternative="two-sided", method="asymp")
|
||||
test_and_check("kolmogorovSmirnovTest('two-sided', 'asymp')", rvs1, rvs2, s, p)
|
||||
s, p = stats.ks_2samp(rvs1, rvs2, alternative="greater", method="asymp")
|
||||
test_and_check("kolmogorovSmirnovTest('greater', 'asymp')", rvs1, rvs2, s, p)
|
||||
|
||||
|
||||
def test_kolmogorov_smirnov():
|
||||
rvs1 = np.round(stats.norm.rvs(loc=1, scale=5, size=10), 2)
|
||||
rvs2 = np.round(stats.norm.rvs(loc=1.5, scale=5, size=20), 2)
|
||||
test_ks_all_alternatives(rvs1, rvs2)
|
||||
|
||||
rvs1 = np.round(stats.norm.rvs(loc=13, scale=1, size=100), 2)
|
||||
rvs2 = np.round(stats.norm.rvs(loc=1.52, scale=9, size=100), 2)
|
||||
test_ks_all_alternatives(rvs1, rvs2)
|
||||
|
||||
rvs1 = np.round(stats.norm.rvs(loc=1, scale=5, size=100), 2)
|
||||
rvs2 = np.round(stats.norm.rvs(loc=11.5, scale=50, size=1000), 2)
|
||||
test_ks_all_alternatives(rvs1, rvs2)
|
||||
|
||||
rvs1 = np.round(stats.norm.rvs(loc=1, scale=5, size=11000), 2)
|
||||
rvs2 = np.round(stats.norm.rvs(loc=3.5, scale=5.5, size=11000), 2)
|
||||
test_ks_all_alternatives(rvs1, rvs2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_kolmogorov_smirnov()
|
||||
print("Ok.")
|
@ -0,0 +1 @@
|
||||
Ok.
|
9
tests/queries/0_stateless/02706_kolmogorov_smirnov_test_scipy.sh
Executable file
9
tests/queries/0_stateless/02706_kolmogorov_smirnov_test_scipy.sh
Executable file
@ -0,0 +1,9 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
||||
# shellcheck source=../shell_config.sh
|
||||
. "$CURDIR"/../shell_config.sh
|
||||
|
||||
# We should have correct env vars from shell_config.sh to run this test
|
||||
|
||||
python3 "$CURDIR"/02706_kolmogorov_smirnov_test_scipy.python
|
Loading…
Reference in New Issue
Block a user