mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-10 01:25:21 +00:00
Implemented meanZTest (#33354)
This commit is contained in:
parent
7156e64ee2
commit
779538bd89
@ -0,0 +1,70 @@
|
|||||||
|
---
|
||||||
|
toc_priority: 303
|
||||||
|
toc_title: meanZTest
|
||||||
|
---
|
||||||
|
|
||||||
|
# meanZTest {#meanztest}
|
||||||
|
|
||||||
|
Applies mean z-test to samples from two populations.
|
||||||
|
|
||||||
|
**Syntax**
|
||||||
|
|
||||||
|
``` sql
|
||||||
|
meanZTest(population_variance_x, population_variance_y, confidence_level)(sample_data, sample_index)
|
||||||
|
```
|
||||||
|
|
||||||
|
Values of both samples are in the `sample_data` column. If `sample_index` equals to 0 then the value in that row belongs to the sample from the first population. Otherwise it belongs to the sample from the second population.
|
||||||
|
The null hypothesis is that means of populations are equal. Normal distribution is assumed. Populations may have unequal variance and the variances are known.
|
||||||
|
|
||||||
|
**Arguments**
|
||||||
|
|
||||||
|
- `sample_data` — Sample data. [Integer](../../../sql-reference/data-types/int-uint.md), [Float](../../../sql-reference/data-types/float.md) or [Decimal](../../../sql-reference/data-types/decimal.md).
|
||||||
|
- `sample_index` — Sample index. [Integer](../../../sql-reference/data-types/int-uint.md).
|
||||||
|
|
||||||
|
**Parameters**
|
||||||
|
|
||||||
|
- `population_variance_x` — Variance for population x. [Float](../../../sql-reference/data-types/float.md).
|
||||||
|
- `population_variance_y` — Variance for population y. [Float](../../../sql-reference/data-types/float.md).
|
||||||
|
- `confidence_level` — Confidence level in order to calculate confidence intervals. [Float](../../../sql-reference/data-types/float.md).
|
||||||
|
|
||||||
|
**Returned values**
|
||||||
|
|
||||||
|
[Tuple](../../../sql-reference/data-types/tuple.md) with four elements:
|
||||||
|
|
||||||
|
- calculated t-statistic. [Float64](../../../sql-reference/data-types/float.md).
|
||||||
|
- calculated p-value. [Float64](../../../sql-reference/data-types/float.md).
|
||||||
|
- calculated confidence-interval-low. [Float64](../../../sql-reference/data-types/float.md).
|
||||||
|
- calculated confidence-interval-high. [Float64](../../../sql-reference/data-types/float.md).
|
||||||
|
|
||||||
|
|
||||||
|
**Example**
|
||||||
|
|
||||||
|
Input table:
|
||||||
|
|
||||||
|
``` text
|
||||||
|
┌─sample_data─┬─sample_index─┐
|
||||||
|
│ 20.3 │ 0 │
|
||||||
|
│ 21.9 │ 0 │
|
||||||
|
│ 22.1 │ 0 │
|
||||||
|
│ 18.9 │ 1 │
|
||||||
|
│ 19 │ 1 │
|
||||||
|
│ 20.3 │ 1 │
|
||||||
|
└─────────────┴──────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
Query:
|
||||||
|
|
||||||
|
``` sql
|
||||||
|
SELECT meanZTest(0.7, 0.45, 0.95)(sample_data, sample_index) FROM mean_ztest
|
||||||
|
```
|
||||||
|
|
||||||
|
Result:
|
||||||
|
|
||||||
|
``` text
|
||||||
|
┌─meanZTest(0.7, 0.45, 0.95)(sample_data, sample_index)────────────────────────────┐
|
||||||
|
│ (3.2841296025548123,0.0010229786769086013,0.8198428246768334,3.2468238419898365) │
|
||||||
|
└──────────────────────────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
[Original article](https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/meanZTest/) <!--hide-->
|
64
src/AggregateFunctions/AggregateFunctionMeanZTest.cpp
Normal file
64
src/AggregateFunctions/AggregateFunctionMeanZTest.cpp
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||||
|
#include <AggregateFunctions/AggregateFunctionMeanZTest.h>
|
||||||
|
#include <AggregateFunctions/FactoryHelpers.h>
|
||||||
|
#include <AggregateFunctions/Moments.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int BAD_ARGUMENTS;
|
||||||
|
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
struct Settings;
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
|
||||||
|
struct MeanZTestData : public ZTestMoments<Float64>
|
||||||
|
{
|
||||||
|
static constexpr auto name = "meanZTest";
|
||||||
|
|
||||||
|
std::pair<Float64, Float64> getResult(Float64 pop_var_x, Float64 pop_var_y) const
|
||||||
|
{
|
||||||
|
Float64 mean_x = getMeanX();
|
||||||
|
Float64 mean_y = getMeanY();
|
||||||
|
|
||||||
|
/// z = \frac{\bar{X_{1}} - \bar{X_{2}}}{\sqrt{\frac{\sigma_{1}^{2}}{n_{1}} + \frac{\sigma_{2}^{2}}{n_{2}}}}
|
||||||
|
Float64 zstat = (mean_x - mean_y) / getStandardError(pop_var_x, pop_var_y);
|
||||||
|
if (!std::isfinite(zstat))
|
||||||
|
{
|
||||||
|
return {std::numeric_limits<Float64>::quiet_NaN(), std::numeric_limits<Float64>::quiet_NaN()};
|
||||||
|
}
|
||||||
|
|
||||||
|
Float64 pvalue = 2.0 * boost::math::cdf(boost::math::normal(0.0, 1.0), -1.0 * std::abs(zstat));
|
||||||
|
|
||||||
|
return {zstat, pvalue};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
AggregateFunctionPtr createAggregateFunctionMeanZTest(
|
||||||
|
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||||
|
{
|
||||||
|
assertBinary(name, argument_types);
|
||||||
|
|
||||||
|
if (parameters.size() != 3)
|
||||||
|
throw Exception("Aggregate function " + name + " requires three parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||||
|
|
||||||
|
if (!isNumber(argument_types[0]) || !isNumber(argument_types[1]))
|
||||||
|
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::BAD_ARGUMENTS);
|
||||||
|
|
||||||
|
return std::make_shared<AggregateFunctionMeanZTest<MeanZTestData>>(argument_types, parameters);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void registerAggregateFunctionMeanZTest(AggregateFunctionFactory & factory)
|
||||||
|
{
|
||||||
|
factory.registerFunction("meanZTest", createAggregateFunctionMeanZTest);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
139
src/AggregateFunctions/AggregateFunctionMeanZTest.h
Normal file
139
src/AggregateFunctions/AggregateFunctionMeanZTest.h
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <AggregateFunctions/IAggregateFunction.h>
|
||||||
|
#include <AggregateFunctions/StatCommon.h>
|
||||||
|
#include <Columns/ColumnVector.h>
|
||||||
|
#include <Columns/ColumnTuple.h>
|
||||||
|
#include <Common/assert_cast.h>
|
||||||
|
#include <Core/Types.h>
|
||||||
|
#include <DataTypes/DataTypesNumber.h>
|
||||||
|
#include <DataTypes/DataTypeTuple.h>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
struct Settings;
|
||||||
|
|
||||||
|
class ReadBuffer;
|
||||||
|
class WriteBuffer;
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int BAD_ARGUMENTS;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Returns tuple of (z-statistic, p-value, confidence-interval-low, confidence-interval-high)
|
||||||
|
template <typename Data>
|
||||||
|
class AggregateFunctionMeanZTest :
|
||||||
|
public IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
Float64 pop_var_x;
|
||||||
|
Float64 pop_var_y;
|
||||||
|
Float64 confidence_level;
|
||||||
|
|
||||||
|
public:
|
||||||
|
AggregateFunctionMeanZTest(const DataTypes & arguments, const Array & params)
|
||||||
|
: IAggregateFunctionDataHelper<Data, AggregateFunctionMeanZTest<Data>>({arguments}, params)
|
||||||
|
{
|
||||||
|
pop_var_x = params.at(0).safeGet<Float64>();
|
||||||
|
pop_var_y = params.at(1).safeGet<Float64>();
|
||||||
|
confidence_level = params.at(2).safeGet<Float64>();
|
||||||
|
|
||||||
|
if (!std::isfinite(pop_var_x) || !std::isfinite(pop_var_y) || !std::isfinite(confidence_level))
|
||||||
|
{
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Aggregate function {} requires finite parameter values.", Data::name);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pop_var_x < 0.0 || pop_var_y < 0.0)
|
||||||
|
{
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Population variance parameters must be larger than or equal to zero in aggregate function {}.", Data::name);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (confidence_level <= 0.0 || confidence_level >= 1.0)
|
||||||
|
{
|
||||||
|
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Confidence level parameter must be between 0 and 1 in aggregate function {}.", Data::name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
String getName() const override
|
||||||
|
{
|
||||||
|
return Data::name;
|
||||||
|
}
|
||||||
|
|
||||||
|
DataTypePtr getReturnType() const override
|
||||||
|
{
|
||||||
|
DataTypes types
|
||||||
|
{
|
||||||
|
std::make_shared<DataTypeNumber<Float64>>(),
|
||||||
|
std::make_shared<DataTypeNumber<Float64>>(),
|
||||||
|
std::make_shared<DataTypeNumber<Float64>>(),
|
||||||
|
std::make_shared<DataTypeNumber<Float64>>(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Strings names
|
||||||
|
{
|
||||||
|
"z_statistic",
|
||||||
|
"p_value",
|
||||||
|
"confidence_interval_low",
|
||||||
|
"confidence_interval_high"
|
||||||
|
};
|
||||||
|
|
||||||
|
return std::make_shared<DataTypeTuple>(
|
||||||
|
std::move(types),
|
||||||
|
std::move(names)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool allocatesMemoryInArena() const override { return false; }
|
||||||
|
|
||||||
|
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||||
|
{
|
||||||
|
Float64 value = columns[0]->getFloat64(row_num);
|
||||||
|
UInt8 is_second = columns[1]->getUInt(row_num);
|
||||||
|
|
||||||
|
if (is_second)
|
||||||
|
this->data(place).addY(value);
|
||||||
|
else
|
||||||
|
this->data(place).addX(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||||
|
{
|
||||||
|
this->data(place).merge(this->data(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||||
|
{
|
||||||
|
this->data(place).write(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena *) const override
|
||||||
|
{
|
||||||
|
this->data(place).read(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||||
|
{
|
||||||
|
auto [z_stat, p_value] = this->data(place).getResult(pop_var_x, pop_var_y);
|
||||||
|
auto [ci_low, ci_high] = this->data(place).getConfidenceIntervals(pop_var_x, pop_var_y, confidence_level);
|
||||||
|
|
||||||
|
/// Because p-value is a probability.
|
||||||
|
p_value = std::min(1.0, std::max(0.0, p_value));
|
||||||
|
|
||||||
|
auto & column_tuple = assert_cast<ColumnTuple &>(to);
|
||||||
|
auto & column_stat = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(0));
|
||||||
|
auto & column_value = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(1));
|
||||||
|
auto & column_ci_low = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(2));
|
||||||
|
auto & column_ci_high = assert_cast<ColumnVector<Float64> &>(column_tuple.getColumn(3));
|
||||||
|
|
||||||
|
column_stat.getData().push_back(z_stat);
|
||||||
|
column_value.getData().push_back(p_value);
|
||||||
|
column_ci_low.getData().push_back(ci_low);
|
||||||
|
column_ci_high.getData().push_back(ci_high);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
};
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include <IO/WriteHelpers.h>
|
#include <IO/WriteHelpers.h>
|
||||||
#include <IO/ReadHelpers.h>
|
#include <IO/ReadHelpers.h>
|
||||||
|
#include <boost/math/distributions/normal.hpp>
|
||||||
|
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
@ -359,4 +360,74 @@ struct TTestMoments
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ZTestMoments
|
||||||
|
{
|
||||||
|
T nx{};
|
||||||
|
T ny{};
|
||||||
|
T x1{};
|
||||||
|
T y1{};
|
||||||
|
|
||||||
|
void addX(T value)
|
||||||
|
{
|
||||||
|
++nx;
|
||||||
|
x1 += value;
|
||||||
|
}
|
||||||
|
|
||||||
|
void addY(T value)
|
||||||
|
{
|
||||||
|
++ny;
|
||||||
|
y1 += value;
|
||||||
|
}
|
||||||
|
|
||||||
|
void merge(const ZTestMoments & rhs)
|
||||||
|
{
|
||||||
|
nx += rhs.nx;
|
||||||
|
ny += rhs.ny;
|
||||||
|
x1 += rhs.x1;
|
||||||
|
y1 += rhs.y1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void write(WriteBuffer & buf) const
|
||||||
|
{
|
||||||
|
writePODBinary(*this, buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void read(ReadBuffer & buf)
|
||||||
|
{
|
||||||
|
readPODBinary(*this, buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
Float64 getMeanX() const
|
||||||
|
{
|
||||||
|
return x1 / nx;
|
||||||
|
}
|
||||||
|
|
||||||
|
Float64 getMeanY() const
|
||||||
|
{
|
||||||
|
return y1 / ny;
|
||||||
|
}
|
||||||
|
|
||||||
|
Float64 getStandardError(Float64 pop_var_x, Float64 pop_var_y) const
|
||||||
|
{
|
||||||
|
/// \sqrt{\frac{\sigma_{1}^{2}}{n_{1}} + \frac{\sigma_{2}^{2}}{n_{2}}}
|
||||||
|
return std::sqrt(pop_var_x / nx + pop_var_y / ny);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<Float64, Float64> getConfidenceIntervals(Float64 pop_var_x, Float64 pop_var_y, Float64 confidence_level) const
|
||||||
|
{
|
||||||
|
/// (\bar{x_{1}} - \bar{x_{2}}) \pm zscore \times \sqrt{\frac{\sigma_{1}^{2}}{n_{1}} + \frac{\sigma_{2}^{2}}{n_{2}}}
|
||||||
|
Float64 mean_x = getMeanX();
|
||||||
|
Float64 mean_y = getMeanY();
|
||||||
|
|
||||||
|
Float64 z = boost::math::quantile(boost::math::complement(
|
||||||
|
boost::math::normal(0.0f, 1.0f), (1.0f - confidence_level) / 2.0f));
|
||||||
|
Float64 se = getStandardError(pop_var_x, pop_var_y);
|
||||||
|
Float64 ci_low = (mean_x - mean_y) - z * se;
|
||||||
|
Float64 ci_high = (mean_x - mean_y) + z * se;
|
||||||
|
|
||||||
|
return {ci_low, ci_high};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -48,6 +48,7 @@ void registerAggregateFunctionRankCorrelation(AggregateFunctionFactory &);
|
|||||||
void registerAggregateFunctionMannWhitney(AggregateFunctionFactory &);
|
void registerAggregateFunctionMannWhitney(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionWelchTTest(AggregateFunctionFactory &);
|
void registerAggregateFunctionWelchTTest(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionStudentTTest(AggregateFunctionFactory &);
|
void registerAggregateFunctionStudentTTest(AggregateFunctionFactory &);
|
||||||
|
void registerAggregateFunctionMeanZTest(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionCramersV(AggregateFunctionFactory &);
|
void registerAggregateFunctionCramersV(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionTheilsU(AggregateFunctionFactory &);
|
void registerAggregateFunctionTheilsU(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionContingency(AggregateFunctionFactory &);
|
void registerAggregateFunctionContingency(AggregateFunctionFactory &);
|
||||||
@ -123,6 +124,7 @@ void registerAggregateFunctions()
|
|||||||
registerAggregateFunctionSequenceNextNode(factory);
|
registerAggregateFunctionSequenceNextNode(factory);
|
||||||
registerAggregateFunctionWelchTTest(factory);
|
registerAggregateFunctionWelchTTest(factory);
|
||||||
registerAggregateFunctionStudentTTest(factory);
|
registerAggregateFunctionStudentTTest(factory);
|
||||||
|
registerAggregateFunctionMeanZTest(factory);
|
||||||
registerAggregateFunctionNothing(factory);
|
registerAggregateFunctionNothing(factory);
|
||||||
registerAggregateFunctionSingleValueOrNull(factory);
|
registerAggregateFunctionSingleValueOrNull(factory);
|
||||||
registerAggregateFunctionIntervalLengthSum(factory);
|
registerAggregateFunctionIntervalLengthSum(factory);
|
||||||
|
1
tests/queries/0_stateless/02158_ztest.reference
Normal file
1
tests/queries/0_stateless/02158_ztest.reference
Normal file
@ -0,0 +1 @@
|
|||||||
|
-0.1749814092128543 0.8610942415056733 -12.200984112294334 10.200984112294334
|
6
tests/queries/0_stateless/02158_ztest.sql
Normal file
6
tests/queries/0_stateless/02158_ztest.sql
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
DROP TABLE IF EXISTS mean_ztest;
|
||||||
|
CREATE TABLE mean_ztest (v int, s UInt8) ENGINE = Memory;
|
||||||
|
INSERT INTO mean_ztest SELECT number, 0 FROM numbers(100) WHERE number % 2 = 0;
|
||||||
|
INSERT INTO mean_ztest SELECT number, 1 FROM numbers(100) WHERE number % 2 = 1;
|
||||||
|
SELECT roundBankers(meanZTest(833.0, 800.0, 0.95)(v, s).1, 16) as z_stat, roundBankers(meanZTest(833.0, 800.0, 0.95)(v, s).2, 16) as p_value, roundBankers(meanZTest(833.0, 800.0, 0.95)(v, s).3, 16) as ci_low, roundBankers(meanZTest(833.0, 800.0, 0.95)(v, s).4, 16) as ci_high FROM mean_ztest;
|
||||||
|
DROP TABLE IF EXISTS mean_ztest;
|
77
tests/queries/0_stateless/02158_ztest_cmp.python
Normal file
77
tests/queries/0_stateless/02158_ztest_cmp.python
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from statistics import variance
|
||||||
|
from scipy import stats
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
CURDIR = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
sys.path.insert(0, os.path.join(CURDIR, 'helpers'))
|
||||||
|
|
||||||
|
from pure_http_client import ClickHouseClient
|
||||||
|
|
||||||
|
|
||||||
|
# unpooled variance z-test for means of two samples
|
||||||
|
def twosample_mean_ztest(rvs1, rvs2, alpha=0.05):
|
||||||
|
mean_rvs1 = np.mean(rvs1)
|
||||||
|
mean_rvs2 = np.mean(rvs2)
|
||||||
|
var_pop_rvs1 = variance(rvs1)
|
||||||
|
var_pop_rvs2 = variance(rvs2)
|
||||||
|
se = np.sqrt(var_pop_rvs1 / len(rvs1) + var_pop_rvs2 / len(rvs2))
|
||||||
|
z_stat = (mean_rvs1 - mean_rvs2) / se
|
||||||
|
p_val = 2 * stats.norm.cdf(-1 * abs(z_stat))
|
||||||
|
z_a = stats.norm.ppf(1 - alpha / 2)
|
||||||
|
ci_low = (mean_rvs1 - mean_rvs2) - z_a * se
|
||||||
|
ci_high = (mean_rvs1 - mean_rvs2) + z_a * se
|
||||||
|
return z_stat, p_val, ci_low, ci_high
|
||||||
|
|
||||||
|
|
||||||
|
def test_and_check(name, a, b, t_stat, p_value, ci_low, ci_high, precision=1e-2):
|
||||||
|
client = ClickHouseClient()
|
||||||
|
client.query("DROP TABLE IF EXISTS ztest;")
|
||||||
|
client.query("CREATE TABLE ztest (left Float64, right UInt8) ENGINE = Memory;");
|
||||||
|
client.query("INSERT INTO ztest VALUES {};".format(", ".join(['({},{})'.format(i, 0) for i in a])))
|
||||||
|
client.query("INSERT INTO ztest VALUES {};".format(", ".join(['({},{})'.format(j, 1) for j in b])))
|
||||||
|
real = client.query_return_df(
|
||||||
|
"SELECT roundBankers({}(left, right).1, 16) as t_stat, ".format(name) +
|
||||||
|
"roundBankers({}(left, right).2, 16) as p_value, ".format(name) +
|
||||||
|
"roundBankers({}(left, right).3, 16) as ci_low, ".format(name) +
|
||||||
|
"roundBankers({}(left, right).4, 16) as ci_high ".format(name) +
|
||||||
|
"FROM ztest FORMAT TabSeparatedWithNames;")
|
||||||
|
real_t_stat = real['t_stat'][0]
|
||||||
|
real_p_value = real['p_value'][0]
|
||||||
|
real_ci_low = real['ci_low'][0]
|
||||||
|
real_ci_high = real['ci_high'][0]
|
||||||
|
assert(abs(real_t_stat - np.float64(t_stat)) < precision), "clickhouse_t_stat {}, py_t_stat {}".format(real_t_stat, t_stat)
|
||||||
|
assert(abs(real_p_value - np.float64(p_value)) < precision), "clickhouse_p_value {}, py_p_value {}".format(real_p_value, p_value)
|
||||||
|
assert(abs(real_ci_low - np.float64(ci_low)) < precision), "clickhouse_ci_low {}, py_ci_low {}".format(real_ci_low, ci_low)
|
||||||
|
assert(abs(real_ci_high - np.float64(ci_high)) < precision), "clickhouse_ci_high {}, py_ci_high {}".format(real_ci_high, ci_high)
|
||||||
|
client.query("DROP TABLE IF EXISTS ztest;")
|
||||||
|
|
||||||
|
|
||||||
|
def test_mean_ztest():
|
||||||
|
rvs1 = np.round(stats.norm.rvs(loc=1, scale=5,size=500), 2)
|
||||||
|
rvs2 = np.round(stats.norm.rvs(loc=10, scale=5,size=500), 2)
|
||||||
|
s, p, cl, ch = twosample_mean_ztest(rvs1, rvs2)
|
||||||
|
test_and_check("meanZTest(%f, %f, 0.95)" % (variance(rvs1), variance(rvs2)), rvs1, rvs2, s, p, cl, ch)
|
||||||
|
|
||||||
|
rvs1 = np.round(stats.norm.rvs(loc=0, scale=5,size=500), 2)
|
||||||
|
rvs2 = np.round(stats.norm.rvs(loc=0, scale=5,size=500), 2)
|
||||||
|
s, p, cl, ch = twosample_mean_ztest(rvs1, rvs2)
|
||||||
|
test_and_check("meanZTest(%f, %f, 0.95)" % (variance(rvs1), variance(rvs2)), rvs1, rvs2, s, p, cl, ch)
|
||||||
|
|
||||||
|
rvs1 = np.round(stats.norm.rvs(loc=2, scale=10,size=512), 2)
|
||||||
|
rvs2 = np.round(stats.norm.rvs(loc=5, scale=20,size=1024), 2)
|
||||||
|
s, p, cl, ch = twosample_mean_ztest(rvs1, rvs2)
|
||||||
|
test_and_check("meanZTest(%f, %f, 0.95)" % (variance(rvs1), variance(rvs2)), rvs1, rvs2, s, p, cl, ch)
|
||||||
|
|
||||||
|
rvs1 = np.round(stats.norm.rvs(loc=0, scale=10,size=1024), 2)
|
||||||
|
rvs2 = np.round(stats.norm.rvs(loc=0, scale=10,size=512), 2)
|
||||||
|
s, p, cl, ch = twosample_mean_ztest(rvs1, rvs2)
|
||||||
|
test_and_check("meanZTest(%f, %f, 0.95)" % (variance(rvs1), variance(rvs2)), rvs1, rvs2, s, p, cl, ch)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_mean_ztest()
|
||||||
|
print("Ok.")
|
1
tests/queries/0_stateless/02158_ztest_cmp.reference
Normal file
1
tests/queries/0_stateless/02158_ztest_cmp.reference
Normal file
@ -0,0 +1 @@
|
|||||||
|
Ok.
|
9
tests/queries/0_stateless/02158_ztest_cmp.sh
Executable file
9
tests/queries/0_stateless/02158_ztest_cmp.sh
Executable file
@ -0,0 +1,9 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
||||||
|
# shellcheck source=../shell_config.sh
|
||||||
|
. "$CURDIR"/../shell_config.sh
|
||||||
|
|
||||||
|
# We should have correct env vars from shell_config.sh to run this test
|
||||||
|
|
||||||
|
python3 "$CURDIR"/02158_ztest_cmp.python
|
Loading…
Reference in New Issue
Block a user