#include #if !defined(ARCADIA_BUILD) && USE_STATS #include #include #include #include #include #include #include #include #include #include #define STATS_ENABLE_STDVEC_WRAPPERS #include namespace DB { namespace ErrorCodes { extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int BAD_ARGUMENTS; } static const String BETA = "beta"; static const String GAMMA = "gamma"; template Variants bayesian_ab_test(String distribution, PODArray & xs, PODArray & ys) { const size_t r = 1000, c = 100; Variants variants(xs.size(), {0.0, 0.0, 0.0, 0.0}); std::vector> samples_matrix; for (size_t i = 0; i < xs.size(); ++i) { variants[i].x = xs[i]; variants[i].y = ys[i]; } if (distribution == BETA) { Float64 alpha, beta; for (size_t i = 0; i < xs.size(); ++i) if (xs[i] < ys[i]) throw Exception("Conversions cannot be larger than trials", ErrorCodes::BAD_ARGUMENTS); for (size_t i = 0; i < xs.size(); ++i) { alpha = 1.0 + ys[i]; beta = 1.0 + xs[i] - ys[i]; samples_matrix.emplace_back(stats::rbeta>(r, c, alpha, beta)); } } else if (distribution == GAMMA) { Float64 shape, scale; for (size_t i = 0; i < xs.size(); ++i) { shape = 1.0 + xs[i]; scale = 250.0 / (1 + 250.0 * ys[i]); std::vector samples = stats::rgamma>(r, c, shape, scale); for (auto & sample : samples) sample = 1 / sample; samples_matrix.emplace_back(std::move(samples)); } } PODArray means; for (auto & samples : samples_matrix) { Float64 total = 0.0; for (auto sample : samples) total += sample; means.push_back(total / samples.size()); } // Beats control for (size_t i = 1; i < xs.size(); ++i) { for (size_t n = 0; n < r * c; ++n) { if (higher_is_better) { if (samples_matrix[i][n] > samples_matrix[0][n]) ++variants[i].beats_control; } else { if (samples_matrix[i][n] < samples_matrix[0][n]) ++variants[i].beats_control; } } } for (auto & variant : variants) variant.beats_control = static_cast(variant.beats_control) / r / c; // To be best PODArray count_m(xs.size(), 0); PODArray row(xs.size(), 0); for (size_t n = 0; n < r * c; ++n) { for (size_t i = 0; i < xs.size(); ++i) row[i] = samples_matrix[i][n]; Float64 m; if (higher_is_better) m = *std::max_element(row.begin(), row.end()); else m = *std::min_element(row.begin(), row.end()); for (size_t i = 0; i < xs.size(); ++i) { if (m == samples_matrix[i][n]) { ++variants[i].best; break; } } } for (auto & variant : variants) variant.best = static_cast(variant.best) / r / c; return variants; } String convertToJson(const PODArray & variant_names, const Variants & variants) { FormatSettings settings; std::stringstream s; { WriteBufferFromOStream buf(s); writeCString("{\"data\":[", buf); for (size_t i = 0; i < variants.size(); ++i) { writeCString("{\"variant_name\":", buf); writeJSONString(variant_names[i], buf, settings); writeCString(",\"x\":", buf); writeText(variants[i].x, buf); writeCString(",\"y\":", buf); writeText(variants[i].y, buf); writeCString(",\"beats_control\":", buf); writeText(variants[i].beats_control, buf); writeCString(",\"to_be_best\":", buf); writeText(variants[i].best, buf); writeCString("}", buf); if (i != variant_names.size() -1) writeCString(",", buf); } writeCString("]}", buf); } return s.str(); } class FunctionBayesAB : public IFunction { public: static constexpr auto name = "bayesAB"; static FunctionPtr create(const Context &) { return std::make_shared(); } String getName() const override { return name; } bool isDeterministic() const override { return false; } bool isDeterministicInScopeOfQuery() const override { return false; } size_t getNumberOfArguments() const override { return 5; } DataTypePtr getReturnTypeImpl(const DataTypes &) const override { return std::make_shared(); } static bool toFloat64(const ColumnConst * col_const_arr, PODArray & output) { Array src_arr = col_const_arr->getValue(); for (size_t i = 0, size = src_arr.size(); i < size; ++i) { switch (src_arr[i].getType()) { case Field::Types::Int64: output.push_back(static_cast(src_arr[i].get())); break; case Field::Types::UInt64: output.push_back(static_cast(src_arr[i].get())); break; case Field::Types::Float64: output.push_back(src_arr[i].get()); break; default: return false; } } return true; } void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) const override { if (input_rows_count == 0) { block.getByPosition(result).column = ColumnString::create(); return; } PODArray xs, ys; PODArray variant_names; String dist; bool higher_is_better; if (const ColumnConst * col_dist = checkAndGetColumnConst(block.getByPosition(arguments[0]).column.get())) { dist = col_dist->getDataAt(0).data; dist = Poco::toLower(dist); if (dist != BETA && dist != GAMMA) throw Exception("First argument for function " + getName() + " cannot be " + dist, ErrorCodes::BAD_ARGUMENTS); } else throw Exception("First argument for function " + getName() + " must be Constant string", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); if (const ColumnConst * col_higher_is_better = checkAndGetColumnConst(block.getByPosition(arguments[1]).column.get())) higher_is_better = col_higher_is_better->getBool(0); else throw Exception("Second argument for function " + getName() + " must be Constatnt boolean", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); if (const ColumnConst * col_const_arr = checkAndGetColumnConst(block.getByPosition(arguments[2]).column.get())) { if (!col_const_arr) throw Exception("Third argument for function " + getName() + " must be Array of constant strings", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); Array src_arr = col_const_arr->getValue(); for (size_t i = 0; i < src_arr.size(); ++i) { if (src_arr[i].getType() != Field::Types::String) throw Exception("Third argument for function " + getName() + " must be Array of constant strings", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); variant_names.push_back(src_arr[i].get()); } } if (const ColumnConst * col_const_arr = checkAndGetColumnConst(block.getByPosition(arguments[3]).column.get())) { if (!col_const_arr) throw Exception("Forth argument for function " + getName() + " must be Array of constant numbers", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); if (!toFloat64(col_const_arr, xs)) throw Exception("Forth and fifth Argument for function " + getName() + " must be Array of constant Numbers", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } if (const ColumnConst * col_const_arr = checkAndGetColumnConst(block.getByPosition(arguments[4]).column.get())) { if (!col_const_arr) throw Exception("Fifth argument for function " + getName() + " must be Array of constant numbers", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); if (!toFloat64(col_const_arr, ys)) throw Exception("Fifth Argument for function " + getName() + " must be Array of constant Numbers", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } if (variant_names.size() != xs.size() || xs.size() != ys.size()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Sizes of arguments doesn't match: variant_names: {}, xs: {}, ys: {}", variant_names.size(), xs.size(), ys.size()); if (variant_names.size() < 2) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Sizes of arguments must be larger than 1. variant_names: {}, xs: {}, ys: {}", variant_names.size(), xs.size(), ys.size()); if (std::count_if(xs.begin(), xs.end(), [](Float64 v) { return v < 0; }) > 0 || std::count_if(ys.begin(), ys.end(), [](Float64 v) { return v < 0; }) > 0) throw Exception("Negative values don't allowed", ErrorCodes::BAD_ARGUMENTS); Variants variants; if (higher_is_better) variants = bayesian_ab_test(dist, xs, ys); else variants = bayesian_ab_test(dist, xs, ys); auto dst = ColumnString::create(); std::string result_str = convertToJson(variant_names, variants); dst->insertData(result_str.c_str(), result_str.length()); block.getByPosition(result).column = std::move(dst); } }; void registerFunctionBayesAB(FunctionFactory & factory) { factory.registerFunction(); } } #else namespace DB { class FunctionFactory; void registerFunctionBayesAB(FunctionFactory & /* factory */) { } } #endif