#include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace DB { namespace ErrorCodes { extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int BAD_ARGUMENTS; } class FunctionTwoSampleProportionsZTest : public IFunction { public: static constexpr auto POOLED = "pooled"; static constexpr auto UNPOOLED = "unpooled"; static constexpr auto name = "proportionsZTest"; static FunctionPtr create(ContextPtr) { return std::make_shared(); } String getName() const override { return name; } size_t getNumberOfArguments() const override { return 6; } ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {5}; } bool useDefaultImplementationForNulls() const override { return false; } bool useDefaultImplementationForConstants() const override { return true; } bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } static DataTypePtr getReturnType() { auto float_data_type = std::make_shared>(); DataTypes types(4, float_data_type); Strings names{"z_statistic", "p_value", "confidence_interval_low", "confidence_interval_high"}; return std::make_shared(std::move(types), std::move(names)); } DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { for (size_t i = 0; i < 4; ++i) { if (!isUInt(arguments[i].type)) { throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The {}th Argument of function {} must be an unsigned integer.", i + 1, getName()); } } if (!isFloat(arguments[4].type)) { throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The fifth argument {} of function {} should be a float,", arguments[4].type->getName(), getName()}; } /// There is an additional check for constancy in ExecuteImpl if (!isString(arguments[5].type) || !arguments[5].column) { throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The sixth argument {} of function {} should be a constant string", arguments[5].type->getName(), getName()}; } return getReturnType(); } ColumnPtr executeImpl(const ColumnsWithTypeAndName & const_arguments, const DataTypePtr &, size_t input_rows_count) const override { auto arguments = const_arguments; /// Only last argument have to be constant for (size_t i = 0; i < 5; ++i) arguments[i].column = arguments[i].column->convertToFullColumnIfConst(); static const auto uint64_data_type = std::make_shared>(); auto column_successes_x = castColumnAccurate(arguments[0], uint64_data_type); const auto & data_successes_x = checkAndGetColumn>(column_successes_x.get())->getData(); auto column_successes_y = castColumnAccurate(arguments[1], uint64_data_type); const auto & data_successes_y = checkAndGetColumn>(column_successes_y.get())->getData(); auto column_trials_x = castColumnAccurate(arguments[2], uint64_data_type); const auto & data_trials_x = checkAndGetColumn>(column_trials_x.get())->getData(); auto column_trials_y = castColumnAccurate(arguments[3], uint64_data_type); const auto & data_trials_y = checkAndGetColumn>(column_trials_y.get())->getData(); static const auto float64_data_type = std::make_shared>(); auto column_confidence_level = castColumnAccurate(arguments[4], float64_data_type); const auto & data_confidence_level = checkAndGetColumn>(column_confidence_level.get())->getData(); String usevar = checkAndGetColumnConst(arguments[5].column.get())->getValue(); if (usevar != UNPOOLED && usevar != POOLED) throw Exception{ErrorCodes::BAD_ARGUMENTS, "The sixth argument {} of function {} must be equal to `pooled` or `unpooled`", arguments[5].type->getName(), getName()}; const bool is_unpooled = (usevar == UNPOOLED); auto res_z_statistic = ColumnFloat64::create(); auto & data_z_statistic = res_z_statistic->getData(); data_z_statistic.reserve(input_rows_count); auto res_p_value = ColumnFloat64::create(); auto & data_p_value = res_p_value->getData(); data_p_value.reserve(input_rows_count); auto res_ci_lower = ColumnFloat64::create(); auto & data_ci_lower = res_ci_lower->getData(); data_ci_lower.reserve(input_rows_count); auto res_ci_upper = ColumnFloat64::create(); auto & data_ci_upper = res_ci_upper->getData(); data_ci_upper.reserve(input_rows_count); auto insert_values_into_result = [&data_z_statistic, &data_p_value, &data_ci_lower, &data_ci_upper]( Float64 z_stat, Float64 p_value, Float64 lower, Float64 upper) { data_z_statistic.emplace_back(z_stat); data_p_value.emplace_back(p_value); data_ci_lower.emplace_back(lower); data_ci_upper.emplace_back(upper); }; static constexpr Float64 nan = std::numeric_limits::quiet_NaN(); boost::math::normal_distribution<> nd(0.0, 1.0); for (size_t row_num = 0; row_num < input_rows_count; ++row_num) { const UInt64 successes_x = data_successes_x[row_num]; const UInt64 successes_y = data_successes_y[row_num]; const UInt64 trials_x = data_trials_x[row_num]; const UInt64 trials_y = data_trials_y[row_num]; const Float64 confidence_level = data_confidence_level[row_num]; const Float64 props_x = static_cast(successes_x) / trials_x; const Float64 props_y = static_cast(successes_y) / trials_y; const Float64 diff = props_x - props_y; const UInt64 trials_total = trials_x + trials_y; if (successes_x == 0 || successes_y == 0 || successes_x > trials_x || successes_y > trials_y || trials_total == 0 || !std::isfinite(confidence_level) || confidence_level < 0.0 || confidence_level > 1.0) { insert_values_into_result(nan, nan, nan, nan); continue; } Float64 se = std::sqrt(props_x * (1.0 - props_x) / trials_x + props_y * (1.0 - props_y) / trials_y); /// z-statistics /// z = \frac{ \bar{p_{1}} - \bar{p_{2}} }{ \sqrt{ \frac{ \bar{p_{1}} \left ( 1 - \bar{p_{1}} \right ) }{ n_{1} } \frac{ \bar{p_{2}} \left ( 1 - \bar{p_{2}} \right ) }{ n_{2} } } } Float64 zstat; if (is_unpooled) { zstat = (props_x - props_y) / se; } else { UInt64 successes_total = successes_x + successes_y; Float64 p_pooled = static_cast(successes_total) / trials_total; Float64 trials_fact = 1.0 / trials_x + 1.0 / trials_y; zstat = diff / std::sqrt(p_pooled * (1.0 - p_pooled) * trials_fact); } if (unlikely(!std::isfinite(zstat))) { insert_values_into_result(nan, nan, nan, nan); continue; } // pvalue Float64 pvalue = 0; Float64 one_side = 1 - boost::math::cdf(nd, std::abs(zstat)); pvalue = one_side * 2; // Confidence intervals Float64 d = props_x - props_y; Float64 z = -boost::math::quantile(nd, (1.0 - confidence_level) / 2.0); Float64 dist = z * se; Float64 ci_low = d - dist; Float64 ci_high = d + dist; insert_values_into_result(zstat, pvalue, ci_low, ci_high); } return ColumnTuple::create( Columns{std::move(res_z_statistic), std::move(res_p_value), std::move(res_ci_lower), std::move(res_ci_upper)}); } }; REGISTER_FUNCTION(ZTest) { factory.registerFunction(); } }