mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
new interface for the function
This commit is contained in:
parent
516b152af3
commit
89547e77cf
@ -30,12 +30,12 @@ struct StudentTTestData : public TTestMoments<Float64>
|
|||||||
|
|
||||||
std::pair<Float64, Float64> getResult() const
|
std::pair<Float64, Float64> getResult() const
|
||||||
{
|
{
|
||||||
Float64 mean_x = x1 / m0;
|
Float64 mean_x = x1 / nx;
|
||||||
Float64 mean_y = y1 / m0;
|
Float64 mean_y = y1 / ny;
|
||||||
|
|
||||||
/// To estimate the variance we first estimate two means.
|
/// To estimate the variance we first estimate two means.
|
||||||
/// That's why the number of degrees of freedom is the total number of values of both samples minus 2.
|
/// That's why the number of degrees of freedom is the total number of values of both samples minus 2.
|
||||||
Float64 degrees_of_freedom = 2.0 * (m0 - 1);
|
Float64 degrees_of_freedom = nx + ny - 2;
|
||||||
|
|
||||||
/// Calculate s^2
|
/// Calculate s^2
|
||||||
/// The original formula looks like
|
/// The original formula looks like
|
||||||
@ -43,11 +43,11 @@ struct StudentTTestData : public TTestMoments<Float64>
|
|||||||
/// But we made some mathematical transformations not to store original sequences.
|
/// But we made some mathematical transformations not to store original sequences.
|
||||||
/// Also we dropped sqrt, because it will be squared later.
|
/// Also we dropped sqrt, because it will be squared later.
|
||||||
|
|
||||||
Float64 all_x = x2 + m0 * mean_x * mean_x - 2 * mean_x * m0;
|
Float64 all_x = x2 + nx * mean_x * mean_x - 2 * mean_x * x1;
|
||||||
Float64 all_y = y2 + m0 * mean_y * mean_y - 2 * mean_y * m0;
|
Float64 all_y = y2 + ny * mean_y * mean_y - 2 * mean_y * y1;
|
||||||
|
|
||||||
Float64 s2 = (all_x + all_y) / degrees_of_freedom;
|
Float64 s2 = (all_x + all_y) / degrees_of_freedom;
|
||||||
Float64 std_err2 = 2.0 * s2 / m0;
|
Float64 std_err2 = s2 * (1 / nx + 1 / ny);
|
||||||
|
|
||||||
/// t-statistic
|
/// t-statistic
|
||||||
Float64 t_stat = (mean_x - mean_y) / sqrt(std_err2);
|
Float64 t_stat = (mean_x - mean_y) / sqrt(std_err2);
|
||||||
@ -71,7 +71,7 @@ AggregateFunctionPtr createAggregateFunctionStudentTTest(const std::string & nam
|
|||||||
|
|
||||||
void registerAggregateFunctionStudentTTest(AggregateFunctionFactory & factory)
|
void registerAggregateFunctionStudentTTest(AggregateFunctionFactory & factory)
|
||||||
{
|
{
|
||||||
factory.registerFunction("studentTTest", createAggregateFunctionStudentTTest, AggregateFunctionFactory::CaseInsensitive);
|
factory.registerFunction("studentTTest", createAggregateFunctionStudentTTest);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -95,10 +95,10 @@ public:
|
|||||||
|
|
||||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||||
{
|
{
|
||||||
Float64 x = columns[0]->getFloat64(row_num);
|
Float64 value = columns[0]->getFloat64(row_num);
|
||||||
Float64 y = columns[1]->getFloat64(row_num);
|
UInt8 is_second = columns[1]->getUInt(row_num);
|
||||||
|
|
||||||
this->data(place).add(x, y);
|
this->data(place).add(value, static_cast<bool>(is_second));
|
||||||
}
|
}
|
||||||
|
|
||||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
|
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||||
|
@ -24,8 +24,8 @@ struct WelchTTestData : public TTestMoments<Float64>
|
|||||||
|
|
||||||
std::pair<Float64, Float64> getResult() const
|
std::pair<Float64, Float64> getResult() const
|
||||||
{
|
{
|
||||||
Float64 mean_x = x1 / m0;
|
Float64 mean_x = x1 / nx;
|
||||||
Float64 mean_y = y1 / m0;
|
Float64 mean_y = y1 / ny;
|
||||||
|
|
||||||
/// s_x^2, s_y^2
|
/// s_x^2, s_y^2
|
||||||
|
|
||||||
@ -33,19 +33,19 @@ struct WelchTTestData : public TTestMoments<Float64>
|
|||||||
/// But we made some mathematical transformations not to store original sequences.
|
/// But we made some mathematical transformations not to store original sequences.
|
||||||
/// Also we dropped sqrt, because it will be squared later.
|
/// Also we dropped sqrt, because it will be squared later.
|
||||||
|
|
||||||
Float64 sx2 = (x2 + m0 * mean_x * mean_x - 2 * mean_x * x1) / (m0 - 1);
|
Float64 sx2 = (x2 + nx * mean_x * mean_x - 2 * mean_x * x1) / (nx - 1);
|
||||||
Float64 sy2 = (y2 + m0 * mean_y * mean_y - 2 * mean_y * y1) / (m0 - 1);
|
Float64 sy2 = (y2 + ny * mean_y * mean_y - 2 * mean_y * y1) / (ny - 1);
|
||||||
|
|
||||||
/// t-statistic
|
/// t-statistic
|
||||||
Float64 t_stat = (mean_x - mean_y) / sqrt(sx2 / m0 + sy2 / m0);
|
Float64 t_stat = (mean_x - mean_y) / sqrt(sx2 / nx + sy2 / ny);
|
||||||
|
|
||||||
/// degrees of freedom
|
/// degrees of freedom
|
||||||
|
|
||||||
Float64 numerator_sqrt = sx2 / m0 + sy2 / m0;
|
Float64 numerator_sqrt = sx2 / nx + sy2 / ny;
|
||||||
Float64 numerator = numerator_sqrt * numerator_sqrt;
|
Float64 numerator = numerator_sqrt * numerator_sqrt;
|
||||||
|
|
||||||
Float64 denominator_x = sx2 * sx2 / (m0 * m0 * (m0 - 1));
|
Float64 denominator_x = sx2 * sx2 / (nx * nx * (nx - 1));
|
||||||
Float64 denominator_y = sy2 * sy2 / (m0 * m0 * (m0 - 1));
|
Float64 denominator_y = sy2 * sy2 / (ny * ny * (ny - 1));
|
||||||
|
|
||||||
Float64 degrees_of_freedom = numerator / (denominator_x + denominator_y);
|
Float64 degrees_of_freedom = numerator / (denominator_x + denominator_y);
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ AggregateFunctionPtr createAggregateFunctionWelchTTest(const std::string & name,
|
|||||||
|
|
||||||
void registerAggregateFunctionWelchTTest(AggregateFunctionFactory & factory)
|
void registerAggregateFunctionWelchTTest(AggregateFunctionFactory & factory)
|
||||||
{
|
{
|
||||||
factory.registerFunction("welchTTest", createAggregateFunctionWelchTTest, AggregateFunctionFactory::CaseInsensitive);
|
factory.registerFunction("welchTTest", createAggregateFunctionWelchTTest);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -314,24 +314,30 @@ struct CorrMoments
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
struct TTestMoments
|
struct TTestMoments
|
||||||
{
|
{
|
||||||
T m0{};
|
T nx{};
|
||||||
|
T ny{};
|
||||||
T x1{};
|
T x1{};
|
||||||
T y1{};
|
T y1{};
|
||||||
T x2{};
|
T x2{};
|
||||||
T y2{};
|
T y2{};
|
||||||
|
|
||||||
void add(T x, T y)
|
void add(T value, bool second_sample)
|
||||||
{
|
{
|
||||||
++m0;
|
if (second_sample) {
|
||||||
x1 += x;
|
++ny;
|
||||||
y1 += y;
|
y1 += value;
|
||||||
x2 += x * x;
|
y2 += value * value;
|
||||||
y2 += y * y;
|
} else {
|
||||||
|
++nx;
|
||||||
|
x1 += value;
|
||||||
|
x2 += value * value;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void merge(const TTestMoments & rhs)
|
void merge(const TTestMoments & rhs)
|
||||||
{
|
{
|
||||||
m0 += rhs.m0;
|
nx += rhs.nx;
|
||||||
|
ny += rhs.ny;
|
||||||
x1 += rhs.x1;
|
x1 += rhs.x1;
|
||||||
y1 += rhs.y1;
|
y1 += rhs.y1;
|
||||||
x2 += rhs.x2;
|
x2 += rhs.x2;
|
||||||
|
14
tests/queries/0_stateless/01558_ttest.reference
Normal file
14
tests/queries/0_stateless/01558_ttest.reference
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
0.021378001462867
|
||||||
|
0.0213780014628671
|
||||||
|
0.090773324285671
|
||||||
|
0.0907733242891952
|
||||||
|
0.00339907162713746
|
||||||
|
0.0033990715715539
|
||||||
|
-0.5028215369186904 0.6152361677168877
|
||||||
|
-0.5028215369187079 0.6152361677171103
|
||||||
|
14.971190998235835 5.898143508382202e-44
|
||||||
|
14.971190998235837 0
|
||||||
|
-2.610898982580138 0.00916587538237954
|
||||||
|
-2.610898982580134 0.0091658753823834
|
||||||
|
-28.740781574102936 7.667329672103986e-133
|
||||||
|
-28.74078157410298 0
|
55
tests/queries/0_stateless/01558_ttest.sql
Normal file
55
tests/queries/0_stateless/01558_ttest.sql
Normal file
File diff suppressed because one or more lines are too long
108
tests/queries/0_stateless/01558_ttest_scipy.python
Normal file
108
tests/queries/0_stateless/01558_ttest_scipy.python
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import sys
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from scipy import stats
|
||||||
|
|
||||||
|
# Connection settings are taken from the environment, with local defaults.
CLICKHOUSE_HOST = os.environ.get('CLICKHOUSE_HOST', '127.0.0.1')
CLICKHOUSE_PORT_HTTP = os.environ.get('CLICKHOUSE_PORT_HTTP', '8123')
# Base URL of the HTTP interface: "http://<host>:<port>/".
CLICKHOUSE_SERVER_URL_STR = 'http://{}:{}/'.format(CLICKHOUSE_HOST, CLICKHOUSE_PORT_HTTP)
|
||||||
|
|
||||||
|
class ClickHouseClient:
    """Small HTTP client for a ClickHouse server, used by this test script.

    Every request is a POST to ``self.host`` issued through ``requests``.
    """

    def __init__(self, host = CLICKHOUSE_SERVER_URL_STR):
        # URL that all requests are posted to.
        self.host = host

    def query(self, query, connection_timeout = 1500):
        """Post `query` as the request body and return the response text.

        A response whose status is not 200 is retried, up to NUMBER_OF_TRIES
        attempts in total, sleeping DELAY * (attempt + 1) seconds between
        attempts. Raises ValueError with the response body if the last
        attempt fails as well.
        """
        NUMBER_OF_TRIES = 30
        DELAY = 10

        for i in range(NUMBER_OF_TRIES):
            r = requests.post(
                self.host,
                params = {'timeout_before_checking_execution_speed': 120, 'max_execution_time': 6000},
                timeout = connection_timeout,
                data = query)
            if r.status_code == 200:
                return r.text
            else:
                print('ATTENTION: try #%d failed' % i)
                if i != (NUMBER_OF_TRIES-1):
                    print(query)
                    print(r.text)
                    time.sleep(DELAY*(i+1))
                else:
                    raise ValueError(r.text)

    def query_return_df(self, query, connection_timeout = 1500):
        """Run `query` and parse the tab-separated response into a pandas DataFrame."""
        data = self.query(query, connection_timeout)
        df = pd.read_csv(io.StringIO(data), sep = '\t')
        return df

    def query_with_data(self, query, content):
        """Post `content` (UTF-8 encoded) as the body of `query` and return the response text.

        The query text goes in the `query` URL parameter and the data in the
        POST body. Previously `query` was accepted but never sent, so only
        the raw data reached the server.
        Raises ValueError with the response body on a non-200 status.
        """
        content = content.encode('utf-8')
        r = requests.post(self.host, params = {'query': query}, data = content)
        result = r.text
        if r.status_code == 200:
            return result
        else:
            raise ValueError(r.text)
|
||||||
|
|
||||||
|
def test_and_check(name, a, b, t_stat, p_value):
    """Run aggregate function `name` on two samples and compare with reference values.

    `a` is inserted as the first sample (flag 0) and `b` as the second
    sample (flag 1) into a temporary `ttest` table. The function is then
    called as `name(left, right)` and its (t-statistic, p-value) result must
    match `t_stat` and `p_value` within an absolute tolerance of 1e-2.

    Raises AssertionError when either value is outside the tolerance.
    """
    client = ClickHouseClient()
    client.query("DROP TABLE IF EXISTS ttest;")
    client.query("CREATE TABLE ttest (left Float64, right UInt8) ENGINE = Memory;")

    # Insert every value of both samples; zip() would silently truncate the
    # longer sample if the two ever differ in size.
    rows = ['({},{})'.format(value, 0) for value in a]
    rows += ['({},{})'.format(value, 1) for value in b]
    client.query("INSERT INTO ttest VALUES {};".format(", ".join(rows)))

    real = client.query_return_df(
        "SELECT roundBankers({}(left, right).1, 16) as t_stat, ".format(name) +
        "roundBankers({}(left, right).2, 16) as p_value ".format(name) +
        "FROM ttest FORMAT TabSeparatedWithNames;")
    real_t_stat = real['t_stat'][0]
    real_p_value = real['p_value'][0]

    # abs() must wrap the difference itself; the comparison was previously
    # inside abs(), which made the check one-sided.
    assert abs(real_t_stat - np.float64(t_stat)) < 1e-2, "clickhouse_t_stat {}, scipy_t_stat {}".format(real_t_stat, t_stat)
    assert abs(real_p_value - np.float64(p_value)) < 1e-2, "clickhouse_p_value {}, scipy_p_value {}".format(real_p_value, p_value)
    client.query("DROP TABLE IF EXISTS ttest;")
|
||||||
|
|
||||||
|
|
||||||
|
def test_student():
    """Check studentTTest against scipy's equal-variance t-test on random normal samples."""
    # Each case: (loc of sample 1, scale of sample 1, loc of sample 2, scale of sample 2, size).
    cases = (
        (1, 5, 10, 5, 500),
        (0, 5, 0, 5, 500),
        (0, 10, 5, 1, 65536),
    )
    for first_loc, first_scale, second_loc, second_scale, count in cases:
        first = np.round(stats.norm.rvs(loc=first_loc, scale=first_scale, size=count), 5)
        second = np.round(stats.norm.rvs(loc=second_loc, scale=second_scale, size=count), 5)
        expected_t, expected_p = stats.ttest_ind(first, second, equal_var = True)
        test_and_check("studentTTest", first, second, expected_t, expected_p)
|
||||||
|
|
||||||
|
def test_welch():
    """Check welchTTest against scipy's unequal-variance (Welch) t-test.

    The reference values come from `stats.ttest_ind(..., equal_var = False)`
    and the server-side function under test is `welchTTest`. Previously this
    function repeated the Student's t-test checks, so welchTTest was never
    exercised.
    """
    rvs1 = np.round(stats.norm.rvs(loc=1, scale=15,size=500), 5)
    rvs2 = np.round(stats.norm.rvs(loc=10, scale=5,size=500), 5)
    s, p = stats.ttest_ind(rvs1, rvs2, equal_var = False)
    test_and_check("welchTTest", rvs1, rvs2, s, p)

    rvs1 = np.round(stats.norm.rvs(loc=0, scale=7,size=500), 5)
    rvs2 = np.round(stats.norm.rvs(loc=0, scale=3,size=500), 5)
    s, p = stats.ttest_ind(rvs1, rvs2, equal_var = False)
    test_and_check("welchTTest", rvs1, rvs2, s, p)

    rvs1 = np.round(stats.norm.rvs(loc=0, scale=10,size=65536), 5)
    rvs2 = np.round(stats.norm.rvs(loc=5, scale=1,size=65536), 5)
    s, p = stats.ttest_ind(rvs1, rvs2, equal_var = False)
    test_and_check("welchTTest", rvs1, rvs2, s, p)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run both suites in order; "Ok." is printed only if neither raised.
    for run_suite in (test_student, test_welch):
        run_suite()
    print("Ok.")
|
1
tests/queries/0_stateless/01558_ttest_scipy.reference
Normal file
1
tests/queries/0_stateless/01558_ttest_scipy.reference
Normal file
@ -0,0 +1 @@
|
|||||||
|
Ok.
|
8
tests/queries/0_stateless/01558_ttest_scipy.sh
Executable file
8
tests/queries/0_stateless/01558_ttest_scipy.sh
Executable file
@ -0,0 +1,8 @@
|
|||||||
|
#!/usr/bin/env bash

# Directory containing this script, resolved so the test works from any cwd.
CURDIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${CURDIR}/../shell_config.sh"

# We should have correct env vars from shell_config.sh to run this test
python3 "${CURDIR}/01558_ttest_scipy.python"
|
Loading…
Reference in New Issue
Block a user