mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-03 13:02:00 +00:00
test added
This commit is contained in:
parent
02ce3ed4e7
commit
9177ba3c02
@ -23,7 +23,7 @@ AggregateFunctionPtr createAggregateFunctionMannWhitneyUTest(const std::string &
|
|||||||
if (!isNumber(argument_types[0]) || !isNumber(argument_types[1]))
|
if (!isNumber(argument_types[0]) || !isNumber(argument_types[1]))
|
||||||
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
|
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
|
||||||
|
|
||||||
return std::make_shared<AggregateFunctionMannWhitney<Float64>>(argument_types, parameters);
|
return std::make_shared<AggregateFunctionMannWhitney>(argument_types, parameters);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -33,9 +33,7 @@ namespace ErrorCodes
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Requires the two samples to be of the same type, because we need to compute ranks of all observations from both samples.
|
struct MannWhitneyData : public StatisticalSample<Float64, Float64>
|
||||||
template <typename T>
|
|
||||||
struct MannWhitneyData : public StatisticalSample<T, T>
|
|
||||||
{
|
{
|
||||||
enum class Alternative
|
enum class Alternative
|
||||||
{
|
{
|
||||||
@ -44,8 +42,6 @@ struct MannWhitneyData : public StatisticalSample<T, T>
|
|||||||
Greater
|
Greater
|
||||||
};
|
};
|
||||||
|
|
||||||
using Sample = typename StatisticalSample<T, T>::SampleX;
|
|
||||||
|
|
||||||
std::pair<Float64, Float64> getResult(Alternative alternative, bool continuity_correction)
|
std::pair<Float64, Float64> getResult(Alternative alternative, bool continuity_correction)
|
||||||
{
|
{
|
||||||
ConcatenatedSamples both(this->x, this->y);
|
ConcatenatedSamples both(this->x, this->y);
|
||||||
@ -88,6 +84,8 @@ struct MannWhitneyData : public StatisticalSample<T, T>
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
using Sample = typename StatisticalSample<Float64, Float64>::SampleX;
|
||||||
|
|
||||||
/// We need to compute ranks according to all samples. Use this class to avoid extra copy and memory allocation.
|
/// We need to compute ranks according to all samples. Use this class to avoid extra copy and memory allocation.
|
||||||
class ConcatenatedSamples
|
class ConcatenatedSamples
|
||||||
{
|
{
|
||||||
@ -95,7 +93,7 @@ private:
|
|||||||
ConcatenatedSamples(const Sample & first_, const Sample & second_)
|
ConcatenatedSamples(const Sample & first_, const Sample & second_)
|
||||||
: first(first_), second(second_) {}
|
: first(first_), second(second_) {}
|
||||||
|
|
||||||
const T & operator[](size_t ind) const
|
const Float64 & operator[](size_t ind) const
|
||||||
{
|
{
|
||||||
if (ind < first.size())
|
if (ind < first.size())
|
||||||
return first[ind];
|
return first[ind];
|
||||||
@ -113,18 +111,17 @@ private:
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
class AggregateFunctionMannWhitney final:
|
||||||
class AggregateFunctionMannWhitney :
|
public IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney>
|
||||||
public IAggregateFunctionDataHelper<MannWhitneyData<T>, AggregateFunctionMannWhitney<T>>
|
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
using Alternative = typename MannWhitneyData<T>::Alternative;
|
using Alternative = typename MannWhitneyData::Alternative;
|
||||||
typename MannWhitneyData<T>::Alternative alternative;
|
Alternative alternative;
|
||||||
bool continuity_correction{true};
|
bool continuity_correction{true};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit AggregateFunctionMannWhitney(const DataTypes & arguments, const Array & params)
|
explicit AggregateFunctionMannWhitney(const DataTypes & arguments, const Array & params)
|
||||||
:IAggregateFunctionDataHelper<MannWhitneyData<T>, AggregateFunctionMannWhitney<T>> ({arguments}, {})
|
:IAggregateFunctionDataHelper<MannWhitneyData, AggregateFunctionMannWhitney> ({arguments}, {})
|
||||||
{
|
{
|
||||||
if (params.size() > 2)
|
if (params.size() > 2)
|
||||||
throw Exception("Aggregate function " + getName() + " require two parameter or less", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
throw Exception("Aggregate function " + getName() + " require two parameter or less", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||||
|
54
tests/queries/0_stateless/01561_mann_whitney_scipy.python
Normal file
54
tests/queries/0_stateless/01561_mann_whitney_scipy.python
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from scipy import stats
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
CURDIR = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
sys.path.insert(0, os.path.join(CURDIR, 'helpers'))
|
||||||
|
|
||||||
|
from pure_http_client import ClickHouseClient
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_and_check(name, a, b, t_stat, p_value):
    """Evaluate aggregate function `name` in ClickHouse and compare it with expected values.

    The `mann_whitney` table is recreated, every value of `a` is inserted with
    sample index 0 and every value of `b` with sample index 1, and then
    `name(left, right)` is queried. Both elements of the returned tuple must be
    within 1e-2 of the expected values.

    Args:
        name: SQL text of the aggregate function, optionally with parameters,
            e.g. "mannWhitneyUTest('two-sided')".
        a: values of the first sample.
        b: values of the second sample. The samples are paired with zip(), so
            surplus elements of the longer one are not inserted.
        t_stat: expected value of the first tuple element (the statistic).
        p_value: expected value of the second tuple element (the p-value).

    Raises:
        AssertionError: if either value differs from the expectation by 1e-2 or more.
    """
    client = ClickHouseClient()
    client.query("DROP TABLE IF EXISTS mann_whitney;")
    client.query("CREATE TABLE mann_whitney (left Float64, right UInt8) ENGINE = Memory;")

    # One "(value, sample_index)" tuple per observation; index 0 marks `a`, index 1 marks `b`.
    rows = ", ".join(['({},{}), ({},{})'.format(i, 0, j, 1) for i, j in zip(a, b)])
    client.query("INSERT INTO mann_whitney VALUES {};".format(rows))

    real = client.query_return_df(
        "SELECT roundBankers({}(left, right).1, 16) as t_stat, ".format(name) +
        "roundBankers({}(left, right).2, 16) as p_value ".format(name) +
        "FROM mann_whitney FORMAT TabSeparatedWithNames;")
    real_t_stat = real['t_stat'][0]
    real_p_value = real['p_value'][0]

    # abs() must wrap the difference, not the comparison: with the parenthesis
    # after `< 1e-2` the check degenerated to `difference < 1e-2` and accepted
    # any statistic that was too small.
    assert abs(real_t_stat - np.float64(t_stat)) < 1e-2, "clickhouse_t_stat {}, scipy_t_stat {}".format(real_t_stat, t_stat)
    assert abs(real_p_value - np.float64(p_value)) < 1e-2, "clickhouse_p_value {}, scipy_p_value {}".format(real_p_value, p_value)
    client.query("DROP TABLE IF EXISTS mann_whitney;")
|
||||||
|
|
||||||
|
|
||||||
|
def test_mann_whitney():
    """Check ClickHouse's mannWhitneyUTest against scipy.stats.mannwhitneyu.

    Random samples are drawn with scipy, rounded to 5 decimals, and the scipy
    result for each alternative is handed to test_and_check as the expectation.
    """
    # Case 1: normal vs. exponential sample. The bare function name and the
    # explicit 'two-sided' parameter are both checked against the same
    # two-sided scipy result.
    normal_sample = np.round(stats.norm.rvs(loc=1, scale=5, size=500), 5)
    expon_sample = np.round(stats.expon.rvs(scale=0.2, size=500), 5)
    stat, pval = stats.mannwhitneyu(normal_sample, expon_sample, alternative='two-sided')
    for function_sql in ("mannWhitneyUTest", "mannWhitneyUTest('two-sided')"):
        test_and_check(function_sql, normal_sample, expon_sample, stat, pval)

    # Case 2: the same Cauchy sample on both sides, two-sided alternative.
    identical = np.round(stats.cauchy.rvs(scale=5, size=500), 5)
    stat, pval = stats.mannwhitneyu(identical, identical, alternative='two-sided')
    test_and_check("mannWhitneyUTest('two-sided')", identical, identical, stat, pval)

    # Case 3: same data, 'less' alternative with continuity correction disabled
    # on both the scipy side and the ClickHouse side (second parameter 0).
    stat, pval = stats.mannwhitneyu(identical, identical, alternative='less', use_continuity=False)
    test_and_check("mannWhitneyUTest('less', 0)", identical, identical, stat, pval)

    # Case 4: large Cauchy vs. normal samples, 'greater' alternative.
    cauchy_sample = np.round(stats.cauchy.rvs(scale=10, size=65536), 5)
    wide_normal_sample = np.round(stats.norm.rvs(loc=0, scale=10, size=65536), 5)
    stat, pval = stats.mannwhitneyu(cauchy_sample, wide_normal_sample, alternative='greater')
    test_and_check("mannWhitneyUTest('greater')", cauchy_sample, wide_normal_sample, stat, pval)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: run the comparison and print a fixed marker line
    # once every check has passed.
    test_mann_whitney()
    print("Ok.")
|
@ -0,0 +1 @@
|
|||||||
|
Ok.
|
8
tests/queries/0_stateless/01561_mann_whitney_scipy.sh
Executable file
8
tests/queries/0_stateless/01561_mann_whitney_scipy.sh
Executable file
@ -0,0 +1,8 @@
|
|||||||
|
#!/usr/bin/env bash
# Wrapper that runs the Python Mann-Whitney comparison test located next to this script.

# Absolute path of the directory that contains this script.
CURDIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${CURDIR}/../shell_config.sh"

# We should have correct env vars from shell_config.sh to run this test
python3 "${CURDIR}/01561_mann_whitney_scipy.python"
|
49
tests/queries/0_stateless/helpers/pure_http_client.py
Normal file
49
tests/queries/0_stateless/helpers/pure_http_client.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import os
|
||||||
|
import io
|
||||||
|
import sys
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
# Location of the ClickHouse HTTP endpoint. Host and port are read from the
# environment and fall back to 127.0.0.1 and 8123 when the variables are unset.
CLICKHOUSE_HOST = os.getenv('CLICKHOUSE_HOST', '127.0.0.1')
CLICKHOUSE_PORT_HTTP = os.getenv('CLICKHOUSE_PORT_HTTP', '8123')
# Base URL in the form "http://<host>:<port>/".
CLICKHOUSE_SERVER_URL_STR = 'http://' + str(CLICKHOUSE_HOST) + ':' + str(CLICKHOUSE_PORT_HTTP) + "/"
|
||||||
|
|
||||||
|
class ClickHouseClient:
    """Small ClickHouse client that sends queries to the server with HTTP POST requests."""

    def __init__(self, host = CLICKHOUSE_SERVER_URL_STR):
        """Remember the base URL that every request is posted to."""
        self.host = host

    def query(self, query, connection_timeout = 1500):
        """Send `query` as the request body and return the response text.

        A response with a status other than 200 is retried, up to 30 attempts
        in total, sleeping 10 * attempt_number seconds between attempts.

        Args:
            query: SQL text to execute.
            connection_timeout: value passed to `requests.post` as `timeout`.

        Returns:
            The response body as text.

        Raises:
            ValueError: with the server's response text when the last attempt
                also returns a non-200 status.
        """
        NUMBER_OF_TRIES = 30
        DELAY = 10

        for i in range(NUMBER_OF_TRIES):
            r = requests.post(
                self.host,
                params = {'timeout_before_checking_execution_speed': 120, 'max_execution_time': 6000},
                timeout = connection_timeout,
                data = query)
            if r.status_code == 200:
                return r.text

            print('ATTENTION: try #%d failed' % i)
            if i == NUMBER_OF_TRIES - 1:
                raise ValueError(r.text)

            # Not the last attempt: show what failed, then back off linearly.
            print(query)
            print(r.text)
            time.sleep(DELAY*(i+1))

    def query_return_df(self, query, connection_timeout = 1500):
        """Run `query` and parse the tab-separated response into a pandas DataFrame.

        The first line of the response is used as the column names (pandas
        default), so the query should request a format that emits a header row,
        such as TabSeparatedWithNames.
        """
        data = self.query(query, connection_timeout)
        df = pd.read_csv(io.StringIO(data), sep = '\t')
        return df

    def query_with_data(self, query, content, connection_timeout = 1500):
        """Send `query` as the URL `query` parameter with `content` as the request body.

        Args:
            query: SQL statement, e.g. an INSERT whose data follows in the body.
                When empty, no `query` parameter is sent and only `content` is posted.
            content: payload for the request body; `str` is encoded as UTF-8,
                `bytes` is sent unchanged.
            connection_timeout: value passed to `requests.post` as `timeout`.

        Returns:
            The response body as text.

        Raises:
            ValueError: with the server's response text on a non-200 status.
        """
        if isinstance(content, str):
            content = content.encode('utf-8')
        # Previously `query` was accepted but never sent, so the server only
        # received the raw payload.
        params = {'query': query} if query else None
        r = requests.post(self.host, params=params, timeout=connection_timeout, data=content)
        if r.status_code == 200:
            return r.text
        raise ValueError(r.text)
|
Loading…
Reference in New Issue
Block a user