Implemented meanZTest (#33354)

This commit is contained in:
achimbab 2022-01-20 22:57:37 +09:00 committed by GitHub
parent 7156e64ee2
commit 779538bd89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 440 additions and 0 deletions

View File

@ -0,0 +1,70 @@
---
toc_priority: 303
toc_title: meanZTest
---
# meanZTest {#meanztest}
Applies mean z-test to samples from two populations.
**Syntax**
``` sql
meanZTest(population_variance_x, population_variance_y, confidence_level)(sample_data, sample_index)
```
Values of both samples are in the `sample_data` column. If `sample_index` equals to 0 then the value in that row belongs to the sample from the first population. Otherwise it belongs to the sample from the second population.
The null hypothesis is that means of populations are equal. Normal distribution is assumed. Populations may have unequal variance and the variances are known.
**Arguments**
- `sample_data` — Sample data. [Integer](../../../sql-reference/data-types/int-uint.md), [Float](../../../sql-reference/data-types/float.md) or [Decimal](../../../sql-reference/data-types/decimal.md).
- `sample_index` — Sample index. [Integer](../../../sql-reference/data-types/int-uint.md).
**Parameters**
- `population_variance_x` — Variance for population x. [Float](../../../sql-reference/data-types/float.md).
- `population_variance_y` — Variance for population y. [Float](../../../sql-reference/data-types/float.md).
- `confidence_level` — Confidence level in order to calculate confidence intervals. [Float](../../../sql-reference/data-types/float.md).
**Returned values**
[Tuple](../../../sql-reference/data-types/tuple.md) with four elements:
- calculated t-statistic. [Float64](../../../sql-reference/data-types/float.md).
- calculated p-value. [Float64](../../../sql-reference/data-types/float.md).
- calculated confidence-interval-low. [Float64](../../../sql-reference/data-types/float.md).
- calculated confidence-interval-high. [Float64](../../../sql-reference/data-types/float.md).
**Example**
Input table:
``` text
┌─sample_data─┬─sample_index─┐
│ 20.3 │ 0 │
│ 21.9 │ 0 │
│ 22.1 │ 0 │
│ 18.9 │ 1 │
│ 19 │ 1 │
│ 20.3 │ 1 │
└─────────────┴──────────────┘
```
Query:
``` sql
SELECT meanZTest(0.7, 0.45, 0.95)(sample_data, sample_index) FROM mean_ztest
```
Result:
``` text
┌─meanZTest(0.7, 0.45, 0.95)(sample_data, sample_index)────────────────────────────┐
│ (3.2841296025548123,0.0010229786769086013,0.8198428246768334,3.2468238419898365) │
└──────────────────────────────────────────────────────────────────────────────────┘
```
[Original article](https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/meanZTest/) <!--hide-->

View File

@ -0,0 +1,64 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionMeanZTest.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Moments.h>
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace DB
{
struct Settings;
namespace
{
struct MeanZTestData : public ZTestMoments<Float64>
{
static constexpr auto name = "meanZTest";
std::pair<Float64, Float64> getResult(Float64 pop_var_x, Float64 pop_var_y) const
{
Float64 mean_x = getMeanX();
Float64 mean_y = getMeanY();
/// z = \frac{\bar{X_{1}} - \bar{X_{2}}}{\sqrt{\frac{\sigma_{1}^{2}}{n_{1}} + \frac{\sigma_{2}^{2}}{n_{2}}}}
Float64 zstat = (mean_x - mean_y) / getStandardError(pop_var_x, pop_var_y);
if (!std::isfinite(zstat))
{
return {std::numeric_limits<Float64>::quiet_NaN(), std::numeric_limits<Float64>::quiet_NaN()};
}
Float64 pvalue = 2.0 * boost::math::cdf(boost::math::normal(0.0, 1.0), -1.0 * std::abs(zstat));
return {zstat, pvalue};
}
};
AggregateFunctionPtr createAggregateFunctionMeanZTest(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertBinary(name, argument_types);
if (parameters.size() != 3)
throw Exception("Aggregate function " + name + " requires three parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!isNumber(argument_types[0]) || !isNumber(argument_types[1]))
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::BAD_ARGUMENTS);
return std::make_shared<AggregateFunctionMeanZTest<MeanZTestData>>(argument_types, parameters);
}
}
void registerAggregateFunctionMeanZTest(AggregateFunctionFactory & factory)
{
factory.registerFunction("meanZTest", createAggregateFunctionMeanZTest);
}
}

View File

@ -0,0 +1,139 @@
#pragma once
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/StatCommon.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnTuple.h>
#include <Common/assert_cast.h>
#include <Core/Types.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeTuple.h>
#include <cmath>
namespace DB
{
struct Settings;
class ReadBuffer;
class WriteBuffer;
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
/// Returns tuple of (z-statistic, p-value, confidence-interval-low, confidence-interval-high)
template <typename Data>
class AggregateFunctionMeanZTest :
public IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>
{
private:
Float64 pop_var_x;
Float64 pop_var_y;
Float64 confidence_level;
public:
AggregateFunctionMeanZTest(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>({arguments}, params)
{
pop_var_x = params.at(0).safeGet<Float64>();
pop_var_y = params.at(1).safeGet<Float64>();
confidence_level = params.at(2).safeGet<Float64>();
if (!std::isfinite(pop_var_x) || !std::isfinite(pop_var_y) || !std::isfinite(confidence_level))
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} requires finite parameter values.", Data::name);
}
if (pop_var_x < 0.0 || pop_var_y < 0.0)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Population variance parameters must be larger than or equal to zero in aggregate function {}.", Data::name);
}
if (confidence_level <= 0.0 || confidence_level >= 1.0)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Confidence level parameter must be between 0 and 1 in aggregate function {}.", Data::name);
}
}
String getName() const override
{
return Data::name;
}
DataTypePtr getReturnType() const override
{
DataTypes types
{
std::make_shared<DataTypeNumber<Float64>>(),
std::make_shared<DataTypeNumber<Float64>>(),
std::make_shared<DataTypeNumber<Float64>>(),
std::make_shared<DataTypeNumber<Float64>>(),
};
Strings names
{
"z_statistic",
"p_value",
"confidence_interval_low",
"confidence_interval_high"
};
return std::make_shared<DataTypeTuple>(
std::move(types),
std::move(names)
);
}
bool allocatesMemoryInArena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
Float64 value = columns[0]->getFloat64(row_num);
UInt8 is_second = columns[1]->getUInt(row_num);
if (is_second)
this->data(place).addY(value);
else
this->data(place).addX(value);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
this->data(place).merge(this->data(rhs));
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
this->data(place).write(buf);
}
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
{
this->data(place).read(buf);
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
auto [z_stat, p_value] = this->data(place).getResult(pop_var_x, pop_var_y);
auto [ci_low, ci_high] = this->data(place).getConfidenceIntervals(pop_var_x, pop_var_y, confidence_level);
/// Because p-value is a probability.
p_value = std::min(1.0, std::max(0.0, p_value));
auto & column_tuple = assert_cast<ColumnTuple &>(to);
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
auto & column_ci_low = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(2));
auto & column_ci_high = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(3));
column_stat.getData().push_back(z_stat);
column_value.getData().push_back(p_value);
column_ci_low.getData().push_back(ci_low);
column_ci_high.getData().push_back(ci_high);
}
};
};

View File

@ -2,6 +2,7 @@
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <boost/math/distributions/normal.hpp>
namespace DB namespace DB
@ -359,4 +360,74 @@ struct TTestMoments
} }
}; };
template <typename T>
struct ZTestMoments
{
T nx{};
T ny{};
T x1{};
T y1{};
void addX(T value)
{
++nx;
x1 += value;
}
void addY(T value)
{
++ny;
y1 += value;
}
void merge(const ZTestMoments & rhs)
{
nx += rhs.nx;
ny += rhs.ny;
x1 += rhs.x1;
y1 += rhs.y1;
}
void write(WriteBuffer & buf) const
{
writePODBinary(*this, buf);
}
void read(ReadBuffer & buf)
{
readPODBinary(*this, buf);
}
Float64 getMeanX() const
{
return x1 / nx;
}
Float64 getMeanY() const
{
return y1 / ny;
}
Float64 getStandardError(Float64 pop_var_x, Float64 pop_var_y) const
{
/// \sqrt{\frac{\sigma_{1}^{2}}{n_{1}} + \frac{\sigma_{2}^{2}}{n_{2}}}
return std::sqrt(pop_var_x / nx + pop_var_y / ny);
}
std::pair<Float64, Float64> getConfidenceIntervals(Float64 pop_var_x, Float64 pop_var_y, Float64 confidence_level) const
{
/// (\bar{x_{1}} - \bar{x_{2}}) \pm zscore \times \sqrt{\frac{\sigma_{1}^{2}}{n_{1}} + \frac{\sigma_{2}^{2}}{n_{2}}}
Float64 mean_x = getMeanX();
Float64 mean_y = getMeanY();
Float64 z = boost::math::quantile(boost::math::complement(
boost::math::normal(0.0f, 1.0f), (1.0f - confidence_level) / 2.0f));
Float64 se = getStandardError(pop_var_x, pop_var_y);
Float64 ci_low = (mean_x - mean_y) - z * se;
Float64 ci_high = (mean_x - mean_y) + z * se;
return {ci_low, ci_high};
}
};
} }

View File

@ -48,6 +48,7 @@ void registerAggregateFunctionRankCorrelation(AggregateFunctionFactory &);
void registerAggregateFunctionMannWhitney(AggregateFunctionFactory &); void registerAggregateFunctionMannWhitney(AggregateFunctionFactory &);
void registerAggregateFunctionWelchTTest(AggregateFunctionFactory &); void registerAggregateFunctionWelchTTest(AggregateFunctionFactory &);
void registerAggregateFunctionStudentTTest(AggregateFunctionFactory &); void registerAggregateFunctionStudentTTest(AggregateFunctionFactory &);
void registerAggregateFunctionMeanZTest(AggregateFunctionFactory &);
void registerAggregateFunctionCramersV(AggregateFunctionFactory &); void registerAggregateFunctionCramersV(AggregateFunctionFactory &);
void registerAggregateFunctionTheilsU(AggregateFunctionFactory &); void registerAggregateFunctionTheilsU(AggregateFunctionFactory &);
void registerAggregateFunctionContingency(AggregateFunctionFactory &); void registerAggregateFunctionContingency(AggregateFunctionFactory &);
@ -123,6 +124,7 @@ void registerAggregateFunctions()
registerAggregateFunctionSequenceNextNode(factory); registerAggregateFunctionSequenceNextNode(factory);
registerAggregateFunctionWelchTTest(factory); registerAggregateFunctionWelchTTest(factory);
registerAggregateFunctionStudentTTest(factory); registerAggregateFunctionStudentTTest(factory);
registerAggregateFunctionMeanZTest(factory);
registerAggregateFunctionNothing(factory); registerAggregateFunctionNothing(factory);
registerAggregateFunctionSingleValueOrNull(factory); registerAggregateFunctionSingleValueOrNull(factory);
registerAggregateFunctionIntervalLengthSum(factory); registerAggregateFunctionIntervalLengthSum(factory);

View File

@ -0,0 +1 @@
-0.1749814092128543 0.8610942415056733 -12.200984112294334 10.200984112294334

View File

@ -0,0 +1,6 @@
DROP TABLE IF EXISTS mean_ztest;
CREATE TABLE mean_ztest (v int, s UInt8) ENGINE = Memory;
INSERT INTO mean_ztest SELECT number, 0 FROM numbers(100) WHERE number % 2 = 0;
INSERT INTO mean_ztest SELECT number, 1 FROM numbers(100) WHERE number % 2 = 1;
SELECT roundBankers(meanZTest(833.0, 800.0, 0.95)(v, s).1, 16) as z_stat, roundBankers(meanZTest(833.0, 800.0, 0.95)(v, s).2, 16) as p_value, roundBankers(meanZTest(833.0, 800.0, 0.95)(v, s).3, 16) as ci_low, roundBankers(meanZTest(833.0, 800.0, 0.95)(v, s).4, 16) as ci_high FROM mean_ztest;
DROP TABLE IF EXISTS mean_ztest;

View File

@ -0,0 +1,77 @@
#!/usr/bin/env python3
import os
import sys
from statistics import variance
from scipy import stats
import pandas as pd
import numpy as np
CURDIR = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, os.path.join(CURDIR, 'helpers'))
from pure_http_client import ClickHouseClient
# unpooled variance z-test for means of two samples
def twosample_mean_ztest(rvs1, rvs2, alpha=0.05):
mean_rvs1 = np.mean(rvs1)
mean_rvs2 = np.mean(rvs2)
var_pop_rvs1 = variance(rvs1)
var_pop_rvs2 = variance(rvs2)
se = np.sqrt(var_pop_rvs1 / len(rvs1) + var_pop_rvs2 / len(rvs2))
z_stat = (mean_rvs1 - mean_rvs2) / se
p_val = 2 * stats.norm.cdf(-1 * abs(z_stat))
z_a = stats.norm.ppf(1 - alpha / 2)
ci_low = (mean_rvs1 - mean_rvs2) - z_a * se
ci_high = (mean_rvs1 - mean_rvs2) + z_a * se
return z_stat, p_val, ci_low, ci_high
def test_and_check(name, a, b, t_stat, p_value, ci_low, ci_high, precision=1e-2):
client = ClickHouseClient()
client.query("DROP TABLE IF EXISTS ztest;")
client.query("CREATE TABLE ztest (left Float64, right UInt8) ENGINE = Memory;");
client.query("INSERT INTO ztest VALUES {};".format(", ".join(['({},{})'.format(i, 0) for i in a])))
client.query("INSERT INTO ztest VALUES {};".format(", ".join(['({},{})'.format(j, 1) for j in b])))
real = client.query_return_df(
"SELECT roundBankers({}(left, right).1, 16) as t_stat, ".format(name) +
"roundBankers({}(left, right).2, 16) as p_value, ".format(name) +
"roundBankers({}(left, right).3, 16) as ci_low, ".format(name) +
"roundBankers({}(left, right).4, 16) as ci_high ".format(name) +
"FROM ztest FORMAT TabSeparatedWithNames;")
real_t_stat = real['t_stat'][0]
real_p_value = real['p_value'][0]
real_ci_low = real['ci_low'][0]
real_ci_high = real['ci_high'][0]
assert(abs(real_t_stat - np.float64(t_stat)) < precision), "clickhouse_t_stat {}, py_t_stat {}".format(real_t_stat, t_stat)
assert(abs(real_p_value - np.float64(p_value)) < precision), "clickhouse_p_value {}, py_p_value {}".format(real_p_value, p_value)
assert(abs(real_ci_low - np.float64(ci_low)) < precision), "clickhouse_ci_low {}, py_ci_low {}".format(real_ci_low, ci_low)
assert(abs(real_ci_high - np.float64(ci_high)) < precision), "clickhouse_ci_high {}, py_ci_high {}".format(real_ci_high, ci_high)
client.query("DROP TABLE IF EXISTS ztest;")
def test_mean_ztest():
rvs1 = np.round(stats.norm.rvs(loc=1, scale=5,size=500), 2)
rvs2 = np.round(stats.norm.rvs(loc=10, scale=5,size=500), 2)
s, p, cl, ch = twosample_mean_ztest(rvs1, rvs2)
test_and_check("meanZTest(%f, %f, 0.95)" % (variance(rvs1), variance(rvs2)), rvs1, rvs2, s, p, cl, ch)
rvs1 = np.round(stats.norm.rvs(loc=0, scale=5,size=500), 2)
rvs2 = np.round(stats.norm.rvs(loc=0, scale=5,size=500), 2)
s, p, cl, ch = twosample_mean_ztest(rvs1, rvs2)
test_and_check("meanZTest(%f, %f, 0.95)" % (variance(rvs1), variance(rvs2)), rvs1, rvs2, s, p, cl, ch)
rvs1 = np.round(stats.norm.rvs(loc=2, scale=10,size=512), 2)
rvs2 = np.round(stats.norm.rvs(loc=5, scale=20,size=1024), 2)
s, p, cl, ch = twosample_mean_ztest(rvs1, rvs2)
test_and_check("meanZTest(%f, %f, 0.95)" % (variance(rvs1), variance(rvs2)), rvs1, rvs2, s, p, cl, ch)
rvs1 = np.round(stats.norm.rvs(loc=0, scale=10,size=1024), 2)
rvs2 = np.round(stats.norm.rvs(loc=0, scale=10,size=512), 2)
s, p, cl, ch = twosample_mean_ztest(rvs1, rvs2)
test_and_check("meanZTest(%f, %f, 0.95)" % (variance(rvs1), variance(rvs2)), rvs1, rvs2, s, p, cl, ch)
if __name__ == "__main__":
test_mean_ztest()
print("Ok.")

View File

@ -0,0 +1 @@
Ok.

View File

@ -0,0 +1,9 @@
#!/usr/bin/env bash
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
. "$CURDIR"/../shell_config.sh
# We should have correct env vars from shell_config.sh to run this test
python3 "$CURDIR"/02158_ztest_cmp.python