From 779538bd8971be1b22354216834da14bbe7fdc51 Mon Sep 17 00:00:00 2001 From: achimbab <07c00h@gmail.com> Date: Thu, 20 Jan 2022 22:57:37 +0900 Subject: [PATCH] Implemented meanZTest (#33354) --- .../reference/meanztest.md | 70 +++++++++ .../AggregateFunctionMeanZTest.cpp | 64 ++++++++ .../AggregateFunctionMeanZTest.h | 139 ++++++++++++++++++ src/AggregateFunctions/Moments.h | 71 +++++++++ .../registerAggregateFunctions.cpp | 2 + .../queries/0_stateless/02158_ztest.reference | 1 + tests/queries/0_stateless/02158_ztest.sql | 6 + .../0_stateless/02158_ztest_cmp.python | 77 ++++++++++ .../0_stateless/02158_ztest_cmp.reference | 1 + tests/queries/0_stateless/02158_ztest_cmp.sh | 9 ++ 10 files changed, 440 insertions(+) create mode 100644 docs/en/sql-reference/aggregate-functions/reference/meanztest.md create mode 100644 src/AggregateFunctions/AggregateFunctionMeanZTest.cpp create mode 100644 src/AggregateFunctions/AggregateFunctionMeanZTest.h create mode 100644 tests/queries/0_stateless/02158_ztest.reference create mode 100644 tests/queries/0_stateless/02158_ztest.sql create mode 100644 tests/queries/0_stateless/02158_ztest_cmp.python create mode 100644 tests/queries/0_stateless/02158_ztest_cmp.reference create mode 100755 tests/queries/0_stateless/02158_ztest_cmp.sh diff --git a/docs/en/sql-reference/aggregate-functions/reference/meanztest.md b/docs/en/sql-reference/aggregate-functions/reference/meanztest.md new file mode 100644 index 00000000000..7d016f42819 --- /dev/null +++ b/docs/en/sql-reference/aggregate-functions/reference/meanztest.md @@ -0,0 +1,70 @@ +--- +toc_priority: 303 +toc_title: meanZTest +--- + +# meanZTest {#meanztest} + +Applies mean z-test to samples from two populations. + +**Syntax** + +``` sql +meanZTest(population_variance_x, population_variance_y, confidence_level)(sample_data, sample_index) +``` + +Values of both samples are in the `sample_data` column. 
If `sample_index` equals 0 then the value in that row belongs to the sample from the first population. Otherwise it belongs to the sample from the second population. +The null hypothesis is that the means of the populations are equal. Normal distribution is assumed. Populations may have unequal variance and the variances are known. + +**Arguments** + +- `sample_data` — Sample data. [Integer](../../../sql-reference/data-types/int-uint.md), [Float](../../../sql-reference/data-types/float.md) or [Decimal](../../../sql-reference/data-types/decimal.md). +- `sample_index` — Sample index. [Integer](../../../sql-reference/data-types/int-uint.md). + +**Parameters** + +- `population_variance_x` — Variance for population x. [Float](../../../sql-reference/data-types/float.md). +- `population_variance_y` — Variance for population y. [Float](../../../sql-reference/data-types/float.md). +- `confidence_level` — Confidence level used to calculate the confidence interval. [Float](../../../sql-reference/data-types/float.md). + +**Returned values** + +[Tuple](../../../sql-reference/data-types/tuple.md) with four elements: + +- calculated z-statistic. [Float64](../../../sql-reference/data-types/float.md). +- calculated p-value. [Float64](../../../sql-reference/data-types/float.md). +- calculated confidence-interval-low. [Float64](../../../sql-reference/data-types/float.md). +- calculated confidence-interval-high. [Float64](../../../sql-reference/data-types/float.md). 
+ + +**Example** + +Input table: + +``` text +┌─sample_data─┬─sample_index─┐ +│ 20.3 │ 0 │ +│ 21.9 │ 0 │ +│ 22.1 │ 0 │ +│ 18.9 │ 1 │ +│ 19 │ 1 │ +│ 20.3 │ 1 │ +└─────────────┴──────────────┘ +``` + +Query: + +``` sql +SELECT meanZTest(0.7, 0.45, 0.95)(sample_data, sample_index) FROM mean_ztest +``` + +Result: + +``` text +┌─meanZTest(0.7, 0.45, 0.95)(sample_data, sample_index)────────────────────────────┐ +│ (3.2841296025548123,0.0010229786769086013,0.8198428246768334,3.2468238419898365) │ +└──────────────────────────────────────────────────────────────────────────────────┘ +``` + + +[Original article](https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/meanZTest/) diff --git a/src/AggregateFunctions/AggregateFunctionMeanZTest.cpp b/src/AggregateFunctions/AggregateFunctionMeanZTest.cpp new file mode 100644 index 00000000000..edc4361bce3 --- /dev/null +++ b/src/AggregateFunctions/AggregateFunctionMeanZTest.cpp @@ -0,0 +1,64 @@ +#include +#include +#include +#include + + +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} + + +namespace DB +{ +struct Settings; + +namespace +{ + +struct MeanZTestData : public ZTestMoments +{ + static constexpr auto name = "meanZTest"; + + std::pair getResult(Float64 pop_var_x, Float64 pop_var_y) const + { + Float64 mean_x = getMeanX(); + Float64 mean_y = getMeanY(); + + /// z = \frac{\bar{X_{1}} - \bar{X_{2}}}{\sqrt{\frac{\sigma_{1}^{2}}{n_{1}} + \frac{\sigma_{2}^{2}}{n_{2}}}} + Float64 zstat = (mean_x - mean_y) / getStandardError(pop_var_x, pop_var_y); + if (!std::isfinite(zstat)) + { + return {std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN()}; + } + + Float64 pvalue = 2.0 * boost::math::cdf(boost::math::normal(0.0, 1.0), -1.0 * std::abs(zstat)); + + return {zstat, pvalue}; + } +}; + +AggregateFunctionPtr createAggregateFunctionMeanZTest( + const std::string & name, const DataTypes & argument_types, const Array & parameters, const 
Settings *) +{ + assertBinary(name, argument_types); + + if (parameters.size() != 3) + throw Exception("Aggregate function " + name + " requires three parameters.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + if (!isNumber(argument_types[0]) || !isNumber(argument_types[1])) + throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::BAD_ARGUMENTS); + + return std::make_shared>(argument_types, parameters); +} + +} + +void registerAggregateFunctionMeanZTest(AggregateFunctionFactory & factory) +{ + factory.registerFunction("meanZTest", createAggregateFunctionMeanZTest); +} + +} diff --git a/src/AggregateFunctions/AggregateFunctionMeanZTest.h b/src/AggregateFunctions/AggregateFunctionMeanZTest.h new file mode 100644 index 00000000000..e4be2503d87 --- /dev/null +++ b/src/AggregateFunctions/AggregateFunctionMeanZTest.h @@ -0,0 +1,139 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +struct Settings; + +class ReadBuffer; +class WriteBuffer; + +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} + + +/// Returns tuple of (z-statistic, p-value, confidence-interval-low, confidence-interval-high) +template +class AggregateFunctionMeanZTest : + public IAggregateFunctionDataHelper> +{ +private: + Float64 pop_var_x; + Float64 pop_var_y; + Float64 confidence_level; + +public: + AggregateFunctionMeanZTest(const DataTypes & arguments, const Array & params) + : IAggregateFunctionDataHelper>({arguments}, params) + { + pop_var_x = params.at(0).safeGet(); + pop_var_y = params.at(1).safeGet(); + confidence_level = params.at(2).safeGet(); + + if (!std::isfinite(pop_var_x) || !std::isfinite(pop_var_y) || !std::isfinite(confidence_level)) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} requires finite parameter values.", Data::name); + } + + if (pop_var_x < 0.0 || pop_var_y < 0.0) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, 
"Population variance parameters must be larger than or equal to zero in aggregate function {}.", Data::name); + } + + if (confidence_level <= 0.0 || confidence_level >= 1.0) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Confidence level parameter must be between 0 and 1 in aggregate function {}.", Data::name); + } + } + + String getName() const override + { + return Data::name; + } + + DataTypePtr getReturnType() const override + { + DataTypes types + { + std::make_shared>(), + std::make_shared>(), + std::make_shared>(), + std::make_shared>(), + }; + + Strings names + { + "z_statistic", + "p_value", + "confidence_interval_low", + "confidence_interval_high" + }; + + return std::make_shared( + std::move(types), + std::move(names) + ); + } + + bool allocatesMemoryInArena() const override { return false; } + + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override + { + Float64 value = columns[0]->getFloat64(row_num); + bool is_second = columns[1]->getUInt(row_num) != 0; /// Test the full value against zero: narrowing it to UInt8 first would send indexes such as 256 to the first sample. + + if (is_second) + this->data(place).addY(value); + else + this->data(place).addX(value); + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override + { + this->data(place).merge(this->data(rhs)); + } + + void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional /* version */) const override + { + this->data(place).write(buf); + } + + void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional /* version */, Arena *) const override + { + this->data(place).read(buf); + } + + void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override + { + auto [z_stat, p_value] = this->data(place).getResult(pop_var_x, pop_var_y); + auto [ci_low, ci_high] = this->data(place).getConfidenceIntervals(pop_var_x, pop_var_y, confidence_level); + + /// Because p-value is a probability. 
+ p_value = std::min(1.0, std::max(0.0, p_value)); + + auto & column_tuple = assert_cast(to); + auto & column_stat = assert_cast &>(column_tuple.getColumn(0)); + auto & column_value = assert_cast &>(column_tuple.getColumn(1)); + auto & column_ci_low = assert_cast &>(column_tuple.getColumn(2)); + auto & column_ci_high = assert_cast &>(column_tuple.getColumn(3)); + + column_stat.getData().push_back(z_stat); + column_value.getData().push_back(p_value); + column_ci_low.getData().push_back(ci_low); + column_ci_high.getData().push_back(ci_high); + } +}; + +}; diff --git a/src/AggregateFunctions/Moments.h b/src/AggregateFunctions/Moments.h index 6f51e76607f..d2a6b0b5581 100644 --- a/src/AggregateFunctions/Moments.h +++ b/src/AggregateFunctions/Moments.h @@ -2,6 +2,7 @@ #include #include +#include namespace DB @@ -359,4 +360,74 @@ struct TTestMoments } }; +template +struct ZTestMoments +{ + T nx{}; + T ny{}; + T x1{}; + T y1{}; + + void addX(T value) + { + ++nx; + x1 += value; + } + + void addY(T value) + { + ++ny; + y1 += value; + } + + void merge(const ZTestMoments & rhs) + { + nx += rhs.nx; + ny += rhs.ny; + x1 += rhs.x1; + y1 += rhs.y1; + } + + void write(WriteBuffer & buf) const + { + writePODBinary(*this, buf); + } + + void read(ReadBuffer & buf) + { + readPODBinary(*this, buf); + } + + Float64 getMeanX() const + { + return x1 / nx; + } + + Float64 getMeanY() const + { + return y1 / ny; + } + + Float64 getStandardError(Float64 pop_var_x, Float64 pop_var_y) const + { + /// \sqrt{\frac{\sigma_{1}^{2}}{n_{1}} + \frac{\sigma_{2}^{2}}{n_{2}}} + return std::sqrt(pop_var_x / nx + pop_var_y / ny); + } + + std::pair getConfidenceIntervals(Float64 pop_var_x, Float64 pop_var_y, Float64 confidence_level) const + { + /// (\bar{x_{1}} - \bar{x_{2}}) \pm zscore \times \sqrt{\frac{\sigma_{1}^{2}}{n_{1}} + \frac{\sigma_{2}^{2}}{n_{2}}} + Float64 mean_x = getMeanX(); + Float64 mean_y = getMeanY(); + + Float64 z = boost::math::quantile(boost::math::complement( + 
boost::math::normal(0.0f, 1.0f), (1.0f - confidence_level) / 2.0f)); + Float64 se = getStandardError(pop_var_x, pop_var_y); + Float64 ci_low = (mean_x - mean_y) - z * se; + Float64 ci_high = (mean_x - mean_y) + z * se; + + return {ci_low, ci_high}; + } +}; + } diff --git a/src/AggregateFunctions/registerAggregateFunctions.cpp b/src/AggregateFunctions/registerAggregateFunctions.cpp index 33f6a532224..351adac31bb 100644 --- a/src/AggregateFunctions/registerAggregateFunctions.cpp +++ b/src/AggregateFunctions/registerAggregateFunctions.cpp @@ -48,6 +48,7 @@ void registerAggregateFunctionRankCorrelation(AggregateFunctionFactory &); void registerAggregateFunctionMannWhitney(AggregateFunctionFactory &); void registerAggregateFunctionWelchTTest(AggregateFunctionFactory &); void registerAggregateFunctionStudentTTest(AggregateFunctionFactory &); +void registerAggregateFunctionMeanZTest(AggregateFunctionFactory &); void registerAggregateFunctionCramersV(AggregateFunctionFactory &); void registerAggregateFunctionTheilsU(AggregateFunctionFactory &); void registerAggregateFunctionContingency(AggregateFunctionFactory &); @@ -123,6 +124,7 @@ void registerAggregateFunctions() registerAggregateFunctionSequenceNextNode(factory); registerAggregateFunctionWelchTTest(factory); registerAggregateFunctionStudentTTest(factory); + registerAggregateFunctionMeanZTest(factory); registerAggregateFunctionNothing(factory); registerAggregateFunctionSingleValueOrNull(factory); registerAggregateFunctionIntervalLengthSum(factory); diff --git a/tests/queries/0_stateless/02158_ztest.reference b/tests/queries/0_stateless/02158_ztest.reference new file mode 100644 index 00000000000..0b0dd7134b3 --- /dev/null +++ b/tests/queries/0_stateless/02158_ztest.reference @@ -0,0 +1 @@ +-0.1749814092128543 0.8610942415056733 -12.200984112294334 10.200984112294334 diff --git a/tests/queries/0_stateless/02158_ztest.sql b/tests/queries/0_stateless/02158_ztest.sql new file mode 100644 index 00000000000..1d3e55db9ca --- 
/dev/null +++ b/tests/queries/0_stateless/02158_ztest.sql @@ -0,0 +1,6 @@ +DROP TABLE IF EXISTS mean_ztest; +CREATE TABLE mean_ztest (v int, s UInt8) ENGINE = Memory; +INSERT INTO mean_ztest SELECT number, 0 FROM numbers(100) WHERE number % 2 = 0; +INSERT INTO mean_ztest SELECT number, 1 FROM numbers(100) WHERE number % 2 = 1; +SELECT roundBankers(meanZTest(833.0, 800.0, 0.95)(v, s).1, 16) as z_stat, roundBankers(meanZTest(833.0, 800.0, 0.95)(v, s).2, 16) as p_value, roundBankers(meanZTest(833.0, 800.0, 0.95)(v, s).3, 16) as ci_low, roundBankers(meanZTest(833.0, 800.0, 0.95)(v, s).4, 16) as ci_high FROM mean_ztest; +DROP TABLE IF EXISTS mean_ztest; diff --git a/tests/queries/0_stateless/02158_ztest_cmp.python b/tests/queries/0_stateless/02158_ztest_cmp.python new file mode 100644 index 00000000000..8fc22d78e74 --- /dev/null +++ b/tests/queries/0_stateless/02158_ztest_cmp.python @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +import os +import sys +from statistics import variance +from scipy import stats +import pandas as pd +import numpy as np + +CURDIR = os.path.dirname(os.path.realpath(__file__)) +sys.path.insert(0, os.path.join(CURDIR, 'helpers')) + +from pure_http_client import ClickHouseClient + + +# unpooled variance z-test for means of two samples +def twosample_mean_ztest(rvs1, rvs2, alpha=0.05): + mean_rvs1 = np.mean(rvs1) + mean_rvs2 = np.mean(rvs2) + var_pop_rvs1 = variance(rvs1) + var_pop_rvs2 = variance(rvs2) + se = np.sqrt(var_pop_rvs1 / len(rvs1) + var_pop_rvs2 / len(rvs2)) + z_stat = (mean_rvs1 - mean_rvs2) / se + p_val = 2 * stats.norm.cdf(-1 * abs(z_stat)) + z_a = stats.norm.ppf(1 - alpha / 2) + ci_low = (mean_rvs1 - mean_rvs2) - z_a * se + ci_high = (mean_rvs1 - mean_rvs2) + z_a * se + return z_stat, p_val, ci_low, ci_high + + +def test_and_check(name, a, b, t_stat, p_value, ci_low, ci_high, precision=1e-2): + client = ClickHouseClient() + client.query("DROP TABLE IF EXISTS ztest;") + client.query("CREATE TABLE ztest (left Float64, right UInt8) ENGINE 
= Memory;"); + client.query("INSERT INTO ztest VALUES {};".format(", ".join(['({},{})'.format(i, 0) for i in a]))) + client.query("INSERT INTO ztest VALUES {};".format(", ".join(['({},{})'.format(j, 1) for j in b]))) + real = client.query_return_df( + "SELECT roundBankers({}(left, right).1, 16) as t_stat, ".format(name) + + "roundBankers({}(left, right).2, 16) as p_value, ".format(name) + + "roundBankers({}(left, right).3, 16) as ci_low, ".format(name) + + "roundBankers({}(left, right).4, 16) as ci_high ".format(name) + + "FROM ztest FORMAT TabSeparatedWithNames;") + real_t_stat = real['t_stat'][0] + real_p_value = real['p_value'][0] + real_ci_low = real['ci_low'][0] + real_ci_high = real['ci_high'][0] + assert(abs(real_t_stat - np.float64(t_stat)) < precision), "clickhouse_t_stat {}, py_t_stat {}".format(real_t_stat, t_stat) + assert(abs(real_p_value - np.float64(p_value)) < precision), "clickhouse_p_value {}, py_p_value {}".format(real_p_value, p_value) + assert(abs(real_ci_low - np.float64(ci_low)) < precision), "clickhouse_ci_low {}, py_ci_low {}".format(real_ci_low, ci_low) + assert(abs(real_ci_high - np.float64(ci_high)) < precision), "clickhouse_ci_high {}, py_ci_high {}".format(real_ci_high, ci_high) + client.query("DROP TABLE IF EXISTS ztest;") + + +def test_mean_ztest(): + rvs1 = np.round(stats.norm.rvs(loc=1, scale=5,size=500), 2) + rvs2 = np.round(stats.norm.rvs(loc=10, scale=5,size=500), 2) + s, p, cl, ch = twosample_mean_ztest(rvs1, rvs2) + test_and_check("meanZTest(%f, %f, 0.95)" % (variance(rvs1), variance(rvs2)), rvs1, rvs2, s, p, cl, ch) + + rvs1 = np.round(stats.norm.rvs(loc=0, scale=5,size=500), 2) + rvs2 = np.round(stats.norm.rvs(loc=0, scale=5,size=500), 2) + s, p, cl, ch = twosample_mean_ztest(rvs1, rvs2) + test_and_check("meanZTest(%f, %f, 0.95)" % (variance(rvs1), variance(rvs2)), rvs1, rvs2, s, p, cl, ch) + + rvs1 = np.round(stats.norm.rvs(loc=2, scale=10,size=512), 2) + rvs2 = np.round(stats.norm.rvs(loc=5, scale=20,size=1024), 2) + s, p, 
cl, ch = twosample_mean_ztest(rvs1, rvs2) + test_and_check("meanZTest(%f, %f, 0.95)" % (variance(rvs1), variance(rvs2)), rvs1, rvs2, s, p, cl, ch) + + rvs1 = np.round(stats.norm.rvs(loc=0, scale=10,size=1024), 2) + rvs2 = np.round(stats.norm.rvs(loc=0, scale=10,size=512), 2) + s, p, cl, ch = twosample_mean_ztest(rvs1, rvs2) + test_and_check("meanZTest(%f, %f, 0.95)" % (variance(rvs1), variance(rvs2)), rvs1, rvs2, s, p, cl, ch) + + +if __name__ == "__main__": + test_mean_ztest() + print("Ok.") diff --git a/tests/queries/0_stateless/02158_ztest_cmp.reference b/tests/queries/0_stateless/02158_ztest_cmp.reference new file mode 100644 index 00000000000..587579af915 --- /dev/null +++ b/tests/queries/0_stateless/02158_ztest_cmp.reference @@ -0,0 +1 @@ +Ok. diff --git a/tests/queries/0_stateless/02158_ztest_cmp.sh b/tests/queries/0_stateless/02158_ztest_cmp.sh new file mode 100755 index 00000000000..4e6affbe11a --- /dev/null +++ b/tests/queries/0_stateless/02158_ztest_cmp.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CURDIR"/../shell_config.sh + +# We should have correct env vars from shell_config.sh to run this test + +python3 "$CURDIR"/02158_ztest_cmp.python