new interface for the function

This commit is contained in:
nikitamikhaylov 2020-11-06 20:48:58 +03:00
parent 516b152af3
commit 89547e77cf
9 changed files with 219 additions and 27 deletions

View File

@ -30,12 +30,12 @@ struct StudentTTestData : public TTestMoments<Float64>
std::pair<Float64, Float64> getResult() const std::pair<Float64, Float64> getResult() const
{ {
Float64 mean_x = x1 / m0; Float64 mean_x = x1 / nx;
Float64 mean_y = y1 / m0; Float64 mean_y = y1 / ny;
/// To estimate the variance we first estimate two means. /// To estimate the variance we first estimate two means.
/// That's why the number of degrees of freedom is the total number of values of both samples minus 2. /// That's why the number of degrees of freedom is the total number of values of both samples minus 2.
Float64 degrees_of_freedom = 2.0 * (m0 - 1); Float64 degrees_of_freedom = nx + ny - 2;
/// Calculate s^2 /// Calculate s^2
/// The original formulae looks like /// The original formulae looks like
@ -43,11 +43,11 @@ struct StudentTTestData : public TTestMoments<Float64>
/// But we made some mathematical transformations not to store original sequences. /// But we made some mathematical transformations not to store original sequences.
/// Also we dropped sqrt, because later it will be squared later. /// Also we dropped sqrt, because later it will be squared later.
Float64 all_x = x2 + m0 * mean_x * mean_x - 2 * mean_x * m0; Float64 all_x = x2 + nx * mean_x * mean_x - 2 * mean_x * x1;
Float64 all_y = y2 + m0 * mean_y * mean_y - 2 * mean_y * m0; Float64 all_y = y2 + ny * mean_y * mean_y - 2 * mean_y * y1;
Float64 s2 = (all_x + all_y) / degrees_of_freedom; Float64 s2 = (all_x + all_y) / degrees_of_freedom;
Float64 std_err2 = 2.0 * s2 / m0; Float64 std_err2 = s2 * (1 / nx + 1 / ny);
/// t-statistic /// t-statistic
Float64 t_stat = (mean_x - mean_y) / sqrt(std_err2); Float64 t_stat = (mean_x - mean_y) / sqrt(std_err2);
@ -71,7 +71,7 @@ AggregateFunctionPtr createAggregateFunctionStudentTTest(const std::string & nam
void registerAggregateFunctionStudentTTest(AggregateFunctionFactory & factory) void registerAggregateFunctionStudentTTest(AggregateFunctionFactory & factory)
{ {
factory.registerFunction("studentTTest", createAggregateFunctionStudentTTest, AggregateFunctionFactory::CaseInsensitive); factory.registerFunction("studentTTest", createAggregateFunctionStudentTTest);
} }
} }

View File

@ -95,10 +95,10 @@ public:
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{ {
Float64 x = columns[0]->getFloat64(row_num); Float64 value = columns[0]->getFloat64(row_num);
Float64 y = columns[1]->getFloat64(row_num); UInt8 is_second = columns[1]->getUInt(row_num);
this->data(place).add(x, y); this->data(place).add(value, static_cast<bool>(is_second));
} }
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override

View File

@ -24,8 +24,8 @@ struct WelchTTestData : public TTestMoments<Float64>
std::pair<Float64, Float64> getResult() const std::pair<Float64, Float64> getResult() const
{ {
Float64 mean_x = x1 / m0; Float64 mean_x = x1 / nx;
Float64 mean_y = y1 / m0; Float64 mean_y = y1 / ny;
/// s_x^2, s_y^2 /// s_x^2, s_y^2
@ -33,19 +33,19 @@ struct WelchTTestData : public TTestMoments<Float64>
/// But we made some mathematical transformations not to store original sequences. /// But we made some mathematical transformations not to store original sequences.
/// Also we dropped sqrt, because later it will be squared later. /// Also we dropped sqrt, because later it will be squared later.
Float64 sx2 = (x2 + m0 * mean_x * mean_x - 2 * mean_x * x1) / (m0 - 1); Float64 sx2 = (x2 + nx * mean_x * mean_x - 2 * mean_x * x1) / (nx - 1);
Float64 sy2 = (y2 + m0 * mean_y * mean_y - 2 * mean_y * y1) / (m0 - 1); Float64 sy2 = (y2 + ny * mean_y * mean_y - 2 * mean_y * y1) / (ny - 1);
/// t-statistic /// t-statistic
Float64 t_stat = (mean_x - mean_y) / sqrt(sx2 / m0 + sy2 / m0); Float64 t_stat = (mean_x - mean_y) / sqrt(sx2 / nx + sy2 / ny);
/// degrees of freedom /// degrees of freedom
Float64 numerator_sqrt = sx2 / m0 + sy2 / m0; Float64 numerator_sqrt = sx2 / nx + sy2 / ny;
Float64 numerator = numerator_sqrt * numerator_sqrt; Float64 numerator = numerator_sqrt * numerator_sqrt;
Float64 denominator_x = sx2 * sx2 / (m0 * m0 * (m0 - 1)); Float64 denominator_x = sx2 * sx2 / (nx * nx * (nx - 1));
Float64 denominator_y = sy2 * sy2 / (m0 * m0 * (m0 - 1)); Float64 denominator_y = sy2 * sy2 / (ny * ny * (ny - 1));
Float64 degrees_of_freedom = numerator / (denominator_x + denominator_y); Float64 degrees_of_freedom = numerator / (denominator_x + denominator_y);
@ -68,7 +68,7 @@ AggregateFunctionPtr createAggregateFunctionWelchTTest(const std::string & name,
void registerAggregateFunctionWelchTTest(AggregateFunctionFactory & factory) void registerAggregateFunctionWelchTTest(AggregateFunctionFactory & factory)
{ {
factory.registerFunction("welchTTest", createAggregateFunctionWelchTTest, AggregateFunctionFactory::CaseInsensitive); factory.registerFunction("welchTTest", createAggregateFunctionWelchTTest);
} }
} }

View File

@ -314,24 +314,30 @@ struct CorrMoments
template <typename T> template <typename T>
struct TTestMoments struct TTestMoments
{ {
T m0{}; T nx{};
T ny{};
T x1{}; T x1{};
T y1{}; T y1{};
T x2{}; T x2{};
T y2{}; T y2{};
void add(T x, T y) void add(T value, bool second_sample)
{ {
++m0; if (second_sample) {
x1 += x; ++ny;
y1 += y; y1 += value;
x2 += x * x; y2 += value * value;
y2 += y * y; } else {
++nx;
x1 += value;
x2 += value * value;
}
} }
void merge(const TTestMoments & rhs) void merge(const TTestMoments & rhs)
{ {
m0 += rhs.m0; nx += rhs.nx;
ny += rhs.ny;
x1 += rhs.x1; x1 += rhs.x1;
y1 += rhs.y1; y1 += rhs.y1;
x2 += rhs.x2; x2 += rhs.x2;

View File

@ -0,0 +1,14 @@
0.021378001462867
0.0213780014628671
0.090773324285671
0.0907733242891952
0.00339907162713746
0.0033990715715539
-0.5028215369186904 0.6152361677168877
-0.5028215369187079 0.6152361677171103
14.971190998235835 5.898143508382202e-44
14.971190998235837 0
-2.610898982580138 0.00916587538237954
-2.610898982580134 0.0091658753823834
-28.740781574102936 7.667329672103986e-133
-28.74078157410298 0

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,108 @@
#!/usr/bin/env python3
import os
import io
import sys
import requests
import time
import pandas as pd
import numpy as np
from scipy import stats
CLICKHOUSE_HOST = os.environ.get('CLICKHOUSE_HOST', '127.0.0.1')
CLICKHOUSE_PORT_HTTP = os.environ.get('CLICKHOUSE_PORT_HTTP', '8123')
CLICKHOUSE_SERVER_URL_STR = 'http://' + ':'.join(str(s) for s in [CLICKHOUSE_HOST, CLICKHOUSE_PORT_HTTP]) + "/"
class ClickHouseClient:
def __init__(self, host = CLICKHOUSE_SERVER_URL_STR):
self.host = host
def query(self, query, connection_timeout = 1500):
NUMBER_OF_TRIES = 30
DELAY = 10
for i in range(NUMBER_OF_TRIES):
r = requests.post(
self.host,
params = {'timeout_before_checking_execution_speed': 120, 'max_execution_time': 6000},
timeout = connection_timeout,
data = query)
if r.status_code == 200:
return r.text
else:
print('ATTENTION: try #%d failed' % i)
if i != (NUMBER_OF_TRIES-1):
print(query)
print(r.text)
time.sleep(DELAY*(i+1))
else:
raise ValueError(r.text)
def query_return_df(self, query, connection_timeout = 1500):
data = self.query(query, connection_timeout)
df = pd.read_csv(io.StringIO(data), sep = '\t')
return df
def query_with_data(self, query, content):
content = content.encode('utf-8')
r = requests.post(self.host, data=content)
result = r.text
if r.status_code == 200:
return result
else:
raise ValueError(r.text)
def test_and_check(name, a, b, t_stat, p_value):
client = ClickHouseClient()
client.query("DROP TABLE IF EXISTS ttest;")
client.query("CREATE TABLE ttest (left Float64, right UInt8) ENGINE = Memory;");
client.query("INSERT INTO ttest VALUES {};".format(", ".join(['({},{}), ({},{})'.format(i, 0, j, 1) for i,j in zip(a, b)])))
real = client.query_return_df(
"SELECT roundBankers({}(left, right).1, 16) as t_stat, ".format(name) +
"roundBankers({}(left, right).2, 16) as p_value ".format(name) +
"FROM ttest FORMAT TabSeparatedWithNames;")
real_t_stat = real['t_stat'][0]
real_p_value = real['p_value'][0]
assert(abs(real_t_stat - np.float64(t_stat) < 1e-2)), "clickhouse_t_stat {}, scipy_t_stat {}".format(real_t_stat, t_stat)
assert(abs(real_p_value - np.float64(p_value)) < 1e-2), "clickhouse_p_value {}, scipy_p_value {}".format(real_p_value, p_value)
client.query("DROP TABLE IF EXISTS ttest;")
def test_student():
rvs1 = np.round(stats.norm.rvs(loc=1, scale=5,size=500), 5)
rvs2 = np.round(stats.norm.rvs(loc=10, scale=5,size=500), 5)
s, p = stats.ttest_ind(rvs1, rvs2, equal_var = True)
test_and_check("studentTTest", rvs1, rvs2, s, p)
rvs1 = np.round(stats.norm.rvs(loc=0, scale=5,size=500), 5)
rvs2 = np.round(stats.norm.rvs(loc=0, scale=5,size=500), 5)
s, p = stats.ttest_ind(rvs1, rvs2, equal_var = True)
test_and_check("studentTTest", rvs1, rvs2, s, p)
rvs1 = np.round(stats.norm.rvs(loc=0, scale=10,size=65536), 5)
rvs2 = np.round(stats.norm.rvs(loc=5, scale=1,size=65536), 5)
s, p = stats.ttest_ind(rvs1, rvs2, equal_var = True)
test_and_check("studentTTest", rvs1, rvs2, s, p)
def test_welch():
rvs1 = np.round(stats.norm.rvs(loc=1, scale=15,size=500), 5)
rvs2 = np.round(stats.norm.rvs(loc=10, scale=5,size=500), 5)
s, p = stats.ttest_ind(rvs1, rvs2, equal_var = True)
test_and_check("studentTTest", rvs1, rvs2, s, p)
rvs1 = np.round(stats.norm.rvs(loc=0, scale=7,size=500), 5)
rvs2 = np.round(stats.norm.rvs(loc=0, scale=3,size=500), 5)
s, p = stats.ttest_ind(rvs1, rvs2, equal_var = True)
test_and_check("studentTTest", rvs1, rvs2, s, p)
rvs1 = np.round(stats.norm.rvs(loc=0, scale=10,size=65536), 5)
rvs2 = np.round(stats.norm.rvs(loc=5, scale=1,size=65536), 5)
s, p = stats.ttest_ind(rvs1, rvs2, equal_var = True)
test_and_check("studentTTest", rvs1, rvs2, s, p)
if __name__ == "__main__":
test_student()
test_welch()
print("Ok.")

View File

@ -0,0 +1 @@
Ok.

View File

@ -0,0 +1,8 @@
#!/usr/bin/env bash
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
. "$CURDIR"/../shell_config.sh
# We should have correct env vars from shell_config.sh to run this test
python3 "$CURDIR"/01558_ttest_scipy.python