mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-14 18:32:29 +00:00
226 lines
8.8 KiB
C++
226 lines
8.8 KiB
C++
#include <Columns/ColumnNullable.h>
|
|
#include <Columns/ColumnString.h>
|
|
#include <Columns/ColumnTuple.h>
|
|
#include <Columns/ColumnsNumber.h>
|
|
#include <Columns/IColumn.h>
|
|
#include <DataTypes/DataTypeTuple.h>
|
|
#include <DataTypes/DataTypesNumber.h>
|
|
#include <Functions/FunctionFactory.h>
|
|
#include <Functions/FunctionHelpers.h>
|
|
#include <Functions/IFunction.h>
|
|
#include <Functions/castTypeToEither.h>
|
|
#include <Interpreters/castColumn.h>
|
|
#include <boost/math/distributions/normal.hpp>
|
|
#include <Common/typeid_cast.h>
|
|
|
|
|
|
namespace DB
|
|
{
|
|
|
|
/// Error codes thrown by this translation unit; only declared here, the definitions live elsewhere.
namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int BAD_ARGUMENTS;
}
|
|
|
|
|
|
class FunctionTwoSampleProportionsZTest : public IFunction
|
|
{
|
|
public:
|
|
static constexpr auto POOLED = "pooled";
|
|
static constexpr auto UNPOOLED = "unpooled";
|
|
|
|
static constexpr auto name = "proportionsZTest";
|
|
|
|
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionTwoSampleProportionsZTest>(); }
|
|
|
|
String getName() const override { return name; }
|
|
|
|
size_t getNumberOfArguments() const override { return 6; }
|
|
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {5}; }
|
|
|
|
bool useDefaultImplementationForNulls() const override { return false; }
|
|
bool useDefaultImplementationForConstants() const override { return true; }
|
|
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
|
|
|
|
static DataTypePtr getReturnType()
|
|
{
|
|
auto float_data_type = std::make_shared<DataTypeNumber<Float64>>();
|
|
DataTypes types(4, float_data_type);
|
|
|
|
Strings names{"z_statistic", "p_value", "confidence_interval_low", "confidence_interval_high"};
|
|
|
|
return std::make_shared<DataTypeTuple>(std::move(types), std::move(names));
|
|
}
|
|
|
|
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
|
|
{
|
|
for (size_t i = 0; i < 4; ++i)
|
|
{
|
|
if (!isUInt(arguments[i].type))
|
|
{
|
|
throw Exception(
|
|
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
|
"The {}th Argument of function {} must be an unsigned integer.",
|
|
i + 1,
|
|
getName());
|
|
}
|
|
}
|
|
|
|
if (!isFloat(arguments[4].type))
|
|
{
|
|
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
|
"The fifth argument {} of function {} should be a float,",
|
|
arguments[4].type->getName(),
|
|
getName()};
|
|
}
|
|
|
|
/// There is an additional check for constancy in ExecuteImpl
|
|
if (!isString(arguments[5].type) || !arguments[5].column)
|
|
{
|
|
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
|
"The sixth argument {} of function {} should be a constant string",
|
|
arguments[5].type->getName(),
|
|
getName()};
|
|
}
|
|
|
|
return getReturnType();
|
|
}
|
|
|
|
|
|
ColumnPtr executeImpl(const ColumnsWithTypeAndName & const_arguments, const DataTypePtr &, size_t input_rows_count) const override
|
|
{
|
|
auto arguments = const_arguments;
|
|
/// Only last argument have to be constant
|
|
for (size_t i = 0; i < 5; ++i)
|
|
arguments[i].column = arguments[i].column->convertToFullColumnIfConst();
|
|
|
|
static const auto uint64_data_type = std::make_shared<DataTypeNumber<UInt64>>();
|
|
|
|
auto column_successes_x = castColumnAccurate(arguments[0], uint64_data_type);
|
|
const auto & data_successes_x = checkAndGetColumn<ColumnVector<UInt64>>(column_successes_x.get())->getData();
|
|
|
|
auto column_successes_y = castColumnAccurate(arguments[1], uint64_data_type);
|
|
const auto & data_successes_y = checkAndGetColumn<ColumnVector<UInt64>>(column_successes_y.get())->getData();
|
|
|
|
auto column_trials_x = castColumnAccurate(arguments[2], uint64_data_type);
|
|
const auto & data_trials_x = checkAndGetColumn<ColumnVector<UInt64>>(column_trials_x.get())->getData();
|
|
|
|
auto column_trials_y = castColumnAccurate(arguments[3], uint64_data_type);
|
|
const auto & data_trials_y = checkAndGetColumn<ColumnVector<UInt64>>(column_trials_y.get())->getData();
|
|
|
|
static const auto float64_data_type = std::make_shared<DataTypeNumber<Float64>>();
|
|
|
|
auto column_confidence_level = castColumnAccurate(arguments[4], float64_data_type);
|
|
const auto & data_confidence_level = checkAndGetColumn<ColumnVector<Float64>>(column_confidence_level.get())->getData();
|
|
|
|
String usevar = checkAndGetColumnConst<ColumnString>(arguments[5].column.get())->getValue<String>();
|
|
|
|
if (usevar != UNPOOLED && usevar != POOLED)
|
|
throw Exception{ErrorCodes::BAD_ARGUMENTS,
|
|
"The sixth argument {} of function {} must be equal to `pooled` or `unpooled`",
|
|
arguments[5].type->getName(),
|
|
getName()};
|
|
|
|
const bool is_unpooled = (usevar == UNPOOLED);
|
|
|
|
auto res_z_statistic = ColumnFloat64::create();
|
|
auto & data_z_statistic = res_z_statistic->getData();
|
|
data_z_statistic.reserve(input_rows_count);
|
|
|
|
auto res_p_value = ColumnFloat64::create();
|
|
auto & data_p_value = res_p_value->getData();
|
|
data_p_value.reserve(input_rows_count);
|
|
|
|
auto res_ci_lower = ColumnFloat64::create();
|
|
auto & data_ci_lower = res_ci_lower->getData();
|
|
data_ci_lower.reserve(input_rows_count);
|
|
|
|
auto res_ci_upper = ColumnFloat64::create();
|
|
auto & data_ci_upper = res_ci_upper->getData();
|
|
data_ci_upper.reserve(input_rows_count);
|
|
|
|
auto insert_values_into_result = [&data_z_statistic, &data_p_value, &data_ci_lower, &data_ci_upper](
|
|
Float64 z_stat, Float64 p_value, Float64 lower, Float64 upper)
|
|
{
|
|
data_z_statistic.emplace_back(z_stat);
|
|
data_p_value.emplace_back(p_value);
|
|
data_ci_lower.emplace_back(lower);
|
|
data_ci_upper.emplace_back(upper);
|
|
};
|
|
|
|
static constexpr Float64 nan = std::numeric_limits<Float64>::quiet_NaN();
|
|
|
|
boost::math::normal_distribution<> nd(0.0, 1.0);
|
|
|
|
for (size_t row_num = 0; row_num < input_rows_count; ++row_num)
|
|
{
|
|
const UInt64 successes_x = data_successes_x[row_num];
|
|
const UInt64 successes_y = data_successes_y[row_num];
|
|
const UInt64 trials_x = data_trials_x[row_num];
|
|
const UInt64 trials_y = data_trials_y[row_num];
|
|
const Float64 confidence_level = data_confidence_level[row_num];
|
|
|
|
const Float64 props_x = static_cast<Float64>(successes_x) / trials_x;
|
|
const Float64 props_y = static_cast<Float64>(successes_y) / trials_y;
|
|
const Float64 diff = props_x - props_y;
|
|
const UInt64 trials_total = trials_x + trials_y;
|
|
|
|
if (successes_x == 0 || successes_y == 0 || successes_x > trials_x || successes_y > trials_y || trials_total == 0
|
|
|| !std::isfinite(confidence_level) || confidence_level < 0.0 || confidence_level > 1.0)
|
|
{
|
|
insert_values_into_result(nan, nan, nan, nan);
|
|
continue;
|
|
}
|
|
|
|
Float64 se = std::sqrt(props_x * (1.0 - props_x) / trials_x + props_y * (1.0 - props_y) / trials_y);
|
|
|
|
/// z-statistics
|
|
/// z = \frac{ \bar{p_{1}} - \bar{p_{2}} }{ \sqrt{ \frac{ \bar{p_{1}} \left ( 1 - \bar{p_{1}} \right ) }{ n_{1} } \frac{ \bar{p_{2}} \left ( 1 - \bar{p_{2}} \right ) }{ n_{2} } } }
|
|
Float64 zstat;
|
|
if (is_unpooled)
|
|
{
|
|
zstat = (props_x - props_y) / se;
|
|
}
|
|
else
|
|
{
|
|
UInt64 successes_total = successes_x + successes_y;
|
|
Float64 p_pooled = static_cast<Float64>(successes_total) / trials_total;
|
|
Float64 trials_fact = 1.0 / trials_x + 1.0 / trials_y;
|
|
zstat = diff / std::sqrt(p_pooled * (1.0 - p_pooled) * trials_fact);
|
|
}
|
|
|
|
if (unlikely(!std::isfinite(zstat)))
|
|
{
|
|
insert_values_into_result(nan, nan, nan, nan);
|
|
continue;
|
|
}
|
|
|
|
// pvalue
|
|
Float64 pvalue = 0;
|
|
Float64 one_side = 1 - boost::math::cdf(nd, std::abs(zstat));
|
|
pvalue = one_side * 2;
|
|
|
|
// Confidence intervals
|
|
Float64 d = props_x - props_y;
|
|
Float64 z = -boost::math::quantile(nd, (1.0 - confidence_level) / 2.0);
|
|
Float64 dist = z * se;
|
|
Float64 ci_low = d - dist;
|
|
Float64 ci_high = d + dist;
|
|
|
|
insert_values_into_result(zstat, pvalue, ci_low, ci_high);
|
|
}
|
|
|
|
return ColumnTuple::create(
|
|
Columns{std::move(res_z_statistic), std::move(res_p_value), std::move(res_ci_lower), std::move(res_ci_upper)});
|
|
}
|
|
};
|
|
|
|
|
|
/// Registers proportionsZTest in the function factory.
REGISTER_FUNCTION(ZTest)
{
    factory.registerFunction<FunctionTwoSampleProportionsZTest>();
}
|
|
|
|
}
|