From 9177ba3c02dcfa479246cafa7d8817427d734d9b Mon Sep 17 00:00:00 2001 From: nikitamikhaylov Date: Fri, 13 Nov 2020 01:45:19 +0300 Subject: [PATCH] test added --- .../AggregateFunctionMannWhitney.cpp | 2 +- .../AggregateFunctionMannWhitney.h | 21 ++++---- .../01561_mann_whitney_scipy.python | 54 +++++++++++++++++++ .../01561_mann_whitney_scipy.reference | 1 + .../0_stateless/01561_mann_whitney_scipy.sh | 8 +++ .../0_stateless/helpers/pure_http_client.py | 49 +++++++++++++++++ 6 files changed, 122 insertions(+), 13 deletions(-) create mode 100644 tests/queries/0_stateless/01561_mann_whitney_scipy.python create mode 100644 tests/queries/0_stateless/01561_mann_whitney_scipy.reference create mode 100755 tests/queries/0_stateless/01561_mann_whitney_scipy.sh create mode 100644 tests/queries/0_stateless/helpers/pure_http_client.py diff --git a/src/AggregateFunctions/AggregateFunctionMannWhitney.cpp b/src/AggregateFunctions/AggregateFunctionMannWhitney.cpp index b5fd39a451e..ceb0b930f73 100644 --- a/src/AggregateFunctions/AggregateFunctionMannWhitney.cpp +++ b/src/AggregateFunctions/AggregateFunctionMannWhitney.cpp @@ -23,7 +23,7 @@ AggregateFunctionPtr createAggregateFunctionMannWhitneyUTest(const std::string & if (!isNumber(argument_types[0]) || !isNumber(argument_types[1])) throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED); - return std::make_shared>(argument_types, parameters); + return std::make_shared(argument_types, parameters); } } diff --git a/src/AggregateFunctions/AggregateFunctionMannWhitney.h b/src/AggregateFunctions/AggregateFunctionMannWhitney.h index 7dbc7722498..160d9b3e407 100644 --- a/src/AggregateFunctions/AggregateFunctionMannWhitney.h +++ b/src/AggregateFunctions/AggregateFunctionMannWhitney.h @@ -33,9 +33,7 @@ namespace ErrorCodes } -/// Required two samples be of the same type. Because we need to compute ranks of all observations from both samples. -template -struct MannWhitneyData : public StatisticalSample +struct MannWhitneyData : public StatisticalSample { enum class Alternative { @@ -44,8 +42,6 @@ struct MannWhitneyData : public StatisticalSample Greater }; - using Sample = typename StatisticalSample::SampleX; - std::pair getResult(Alternative alternative, bool continuity_correction) { ConcatenatedSamples both(this->x, this->y); @@ -88,6 +84,8 @@ struct MannWhitneyData : public StatisticalSample } private: + using Sample = typename StatisticalSample::SampleX; + /// We need to compute ranks according to all samples. Use this class to avoid extra copy and memory allocation. class ConcatenatedSamples { @@ -95,7 +93,7 @@ private: ConcatenatedSamples(const Sample & first_, const Sample & second_) : first(first_), second(second_) {} - const T & operator[](size_t ind) const + const Float64 & operator[](size_t ind) const { if (ind < first.size()) return first[ind]; @@ -113,18 +111,17 @@ private: }; }; -template -class AggregateFunctionMannWhitney : - public IAggregateFunctionDataHelper, AggregateFunctionMannWhitney> +class AggregateFunctionMannWhitney final: + public IAggregateFunctionDataHelper { private: - using Alternative = typename MannWhitneyData::Alternative; - typename MannWhitneyData::Alternative alternative; + using Alternative = typename MannWhitneyData::Alternative; + Alternative alternative; bool continuity_correction{true}; public: explicit AggregateFunctionMannWhitney(const DataTypes & arguments, const Array & params) - :IAggregateFunctionDataHelper, AggregateFunctionMannWhitney> ({arguments}, {}) + :IAggregateFunctionDataHelper ({arguments}, {}) { if (params.size() > 2) throw Exception("Aggregate function " + getName() + " require two parameter or less", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); diff --git a/tests/queries/0_stateless/01561_mann_whitney_scipy.python b/tests/queries/0_stateless/01561_mann_whitney_scipy.python new file mode 100644 index 00000000000..6905c758550 --- /dev/null +++ b/tests/queries/0_stateless/01561_mann_whitney_scipy.python @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +import os +import sys +from scipy import stats +import pandas as pd +import numpy as np + +CURDIR = os.path.dirname(os.path.realpath(__file__)) +sys.path.insert(0, os.path.join(CURDIR, 'helpers')) + +from pure_http_client import ClickHouseClient + + + +def test_and_check(name, a, b, t_stat, p_value): + client = ClickHouseClient() + client.query("DROP TABLE IF EXISTS mann_whitney;") + client.query("CREATE TABLE mann_whitney (left Float64, right UInt8) ENGINE = Memory;"); + client.query("INSERT INTO mann_whitney VALUES {};".format(", ".join(['({},{}), ({},{})'.format(i, 0, j, 1) for i,j in zip(a, b)]))) + + real = client.query_return_df( + "SELECT roundBankers({}(left, right).1, 16) as t_stat, ".format(name) + + "roundBankers({}(left, right).2, 16) as p_value ".format(name) + + "FROM mann_whitney FORMAT TabSeparatedWithNames;") + real_t_stat = real['t_stat'][0] + real_p_value = real['p_value'][0] + assert(abs(real_t_stat - np.float64(t_stat) < 1e-2)), "clickhouse_t_stat {}, scipy_t_stat {}".format(real_t_stat, t_stat) + assert(abs(real_p_value - np.float64(p_value)) < 1e-2), "clickhouse_p_value {}, scipy_p_value {}".format(real_p_value, p_value) + client.query("DROP TABLE IF EXISTS mann_whitney;") + + +def test_mann_whitney(): + rvs1 = np.round(stats.norm.rvs(loc=1, scale=5,size=500), 5) + rvs2 = np.round(stats.expon.rvs(scale=0.2,size=500), 5) + s, p = stats.mannwhitneyu(rvs1, rvs2, alternative='two-sided') + test_and_check("mannWhitneyUTest", rvs1, rvs2, s, p) + test_and_check("mannWhitneyUTest('two-sided')", rvs1, rvs2, s, p) + + equal = np.round(stats.cauchy.rvs(scale=5, size=500), 5) + s, p = stats.mannwhitneyu(equal, equal, alternative='two-sided') + test_and_check("mannWhitneyUTest('two-sided')", equal, equal, s, p) + + s, p = stats.mannwhitneyu(equal, equal, alternative='less', use_continuity=False) + test_and_check("mannWhitneyUTest('less', 0)", equal, equal, s, p) + + + rvs1 = np.round(stats.cauchy.rvs(scale=10,size=65536), 5) + rvs2 = np.round(stats.norm.rvs(loc=0, scale=10,size=65536), 5) + s, p = stats.mannwhitneyu(rvs1, rvs2, alternative='greater') + test_and_check("mannWhitneyUTest('greater')", rvs1, rvs2, s, p) + +if __name__ == "__main__": + test_mann_whitney() + print("Ok.") \ No newline at end of file diff --git a/tests/queries/0_stateless/01561_mann_whitney_scipy.reference b/tests/queries/0_stateless/01561_mann_whitney_scipy.reference new file mode 100644 index 00000000000..587579af915 --- /dev/null +++ b/tests/queries/0_stateless/01561_mann_whitney_scipy.reference @@ -0,0 +1 @@ +Ok. diff --git a/tests/queries/0_stateless/01561_mann_whitney_scipy.sh b/tests/queries/0_stateless/01561_mann_whitney_scipy.sh new file mode 100755 index 00000000000..e4e9152a97d --- /dev/null +++ b/tests/queries/0_stateless/01561_mann_whitney_scipy.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +. "$CURDIR"/../shell_config.sh + +# We should have correct env vars from shell_config.sh to run this test + +python3 "$CURDIR"/01561_mann_whitney_scipy.python \ No newline at end of file diff --git a/tests/queries/0_stateless/helpers/pure_http_client.py b/tests/queries/0_stateless/helpers/pure_http_client.py new file mode 100644 index 00000000000..4e18ab3a0f4 --- /dev/null +++ b/tests/queries/0_stateless/helpers/pure_http_client.py @@ -0,0 +1,49 @@ +import os +import io +import sys +import requests +import time +import pandas as pd + +CLICKHOUSE_HOST = os.environ.get('CLICKHOUSE_HOST', '127.0.0.1') +CLICKHOUSE_PORT_HTTP = os.environ.get('CLICKHOUSE_PORT_HTTP', '8123') +CLICKHOUSE_SERVER_URL_STR = 'http://' + ':'.join(str(s) for s in [CLICKHOUSE_HOST, CLICKHOUSE_PORT_HTTP]) + "/" + +class ClickHouseClient: + def __init__(self, host = CLICKHOUSE_SERVER_URL_STR): + self.host = host + + def query(self, query, connection_timeout = 1500): + NUMBER_OF_TRIES = 30 + DELAY = 10 + + for i in range(NUMBER_OF_TRIES): + r = requests.post( + self.host, + params = {'timeout_before_checking_execution_speed': 120, 'max_execution_time': 6000}, + timeout = connection_timeout, + data = query) + if r.status_code == 200: + return r.text + else: + print('ATTENTION: try #%d failed' % i) + if i != (NUMBER_OF_TRIES-1): + print(query) + print(r.text) + time.sleep(DELAY*(i+1)) + else: + raise ValueError(r.text) + + def query_return_df(self, query, connection_timeout = 1500): + data = self.query(query, connection_timeout) + df = pd.read_csv(io.StringIO(data), sep = '\t') + return df + + def query_with_data(self, query, content): + content = content.encode('utf-8') + r = requests.post(self.host, data=content) + result = r.text + if r.status_code == 200: + return result + else: + raise ValueError(r.text) \ No newline at end of file