This commit is contained in:
nikitamikhaylov 2020-11-12 22:17:15 +03:00
parent de75c96a75
commit 02ce3ed4e7
7 changed files with 35 additions and 30 deletions

View File

@ -22,7 +22,7 @@ AggregateFunctionPtr createAggregateFunctionMannWhitneyUTest(const std::string &
if (!isNumber(argument_types[0]) || !isNumber(argument_types[1]))
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
return std::make_shared<AggregateFunctionMannWhitney<Float64>>(argument_types, parameters);
}

View File

@ -37,7 +37,7 @@ namespace ErrorCodes
template <typename T>
struct MannWhitneyData : public StatisticalSample<T, T>
{
enum class Alternative
enum class Alternative
{
TwoSided,
Less,
@ -58,7 +58,7 @@ struct MannWhitneyData : public StatisticalSample<T, T>
const Float64 n2 = this->size_y;
Float64 r1 = 0;
for (size_t i = 0; i < n1; ++i)
for (size_t i = 0; i < n1; ++i)
r1 += ranks[i];
const Float64 u1 = n1 * n2 + (n1 * (n1 + 1.)) / 2. - r1;
@ -74,7 +74,7 @@ struct MannWhitneyData : public StatisticalSample<T, T>
u = u1;
else if (alternative == Alternative::Greater)
u = u2;
const Float64 z = (u - meanrank) / sd;
const Float64 cdf = integrateSimpson(0, z, [] (Float64 t) { return std::pow(M_E, -0.5 * t * t) / std::sqrt(2 * M_PI);});
@ -89,14 +89,15 @@ struct MannWhitneyData : public StatisticalSample<T, T>
private:
/// We need to compute ranks according to all samples. Use this class to avoid extra copy and memory allocation.
class ConcatenatedSamples {
class ConcatenatedSamples
{
public:
ConcatenatedSamples(const Sample & first_, const Sample & second_)
ConcatenatedSamples(const Sample & first_, const Sample & second_)
: first(first_), second(second_) {}
const T & operator[](size_t ind) const
{
if (ind < first.size())
if (ind < first.size())
return first[ind];
return second[ind % first.size()];
}
@ -128,7 +129,8 @@ public:
if (params.size() > 2)
throw Exception("Aggregate function " + getName() + " require two parameter or less", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (params.empty()) {
if (params.empty())
{
alternative = Alternative::TwoSided;
return;
}
@ -143,17 +145,16 @@ public:
alternative = Alternative::Less;
else if (param == "greater")
alternative = Alternative::Greater;
else
throw Exception("Unknown parameter in aggregate function " + getName() +
else
throw Exception("Unknown parameter in aggregate function " + getName() +
". It must be one of: 'two sided', 'less', 'greater'", ErrorCodes::BAD_ARGUMENTS);
if (params.size() != 2) {
if (params.size() != 2)
return;
}
if (params[1].getType() != Field::Types::UInt64)
throw Exception("Aggregate function " + getName() + " require require second parameter to be a UInt64", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
continuity_correction = static_cast<bool>(params[1].get<UInt64>());
}
@ -215,7 +216,7 @@ public:
{
if (!this->data(place).size_x || !this->data(place).size_y)
throw Exception("Aggregate function " + getName() + " require both samples to be non empty", ErrorCodes::BAD_ARGUMENTS);
auto [u_statistic, p_value] = this->data(place).getResult(alternative, continuity_correction);
/// Because p-value is a probability.

View File

@ -23,7 +23,7 @@ AggregateFunctionPtr createAggregateFunctionRankCorrelation(const std::string &
if (!isNumber(argument_types[0]) || !isNumber(argument_types[1]))
throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED);
return std::make_shared<AggregateFunctionRankCorrelation>(argument_types);
}

View File

@ -22,7 +22,8 @@ namespace DB
struct RankCorrelationData : public StatisticalSample<Float64, Float64>
{
Float64 getResult() {
Float64 getResult()
{
RanksArray ranks_x;
std::tie(ranks_x, std::ignore) = computeRanksAndTieCorrection(this->x);

View File

@ -8,7 +8,7 @@
#include <algorithm>
#include <utility>
namespace DB
namespace DB
{
template <typename F>
@ -29,7 +29,8 @@ static Float64 integrateSimpson(Float64 a, Float64 b, F && func)
using RanksArray = std::vector<Float64>;
template <typename Values>
std::pair<RanksArray, Float64> computeRanksAndTieCorrection(const Values & values) {
std::pair<RanksArray, Float64> computeRanksAndTieCorrection(const Values & values)
{
const size_t size = values.size();
/// Save initial positions, than sort indices according to the values.
std::vector<size_t> indexes(size);
@ -40,17 +41,16 @@ std::pair<RanksArray, Float64> computeRanksAndTieCorrection(const Values & value
size_t left = 0;
Float64 tie_numenator = 0;
RanksArray out(size);
while (left < size) {
while (left < size)
{
size_t right = left;
while (right < size && values[indexes[left]] == values[indexes[right]]) {
while (right < size && values[indexes[left]] == values[indexes[right]])
++right;
}
auto adjusted = (left + right + 1.) / 2.;
auto count_equal = right - left;
tie_numenator += std::pow(count_equal, 3) - count_equal;
for (size_t iter = left; iter < right; ++iter) {
for (size_t iter = left; iter < right; ++iter)
out[indexes[iter]] = adjusted;
}
left = right;
}
return {out, 1 - (tie_numenator / (std::pow(size, 3) - size))};
@ -71,12 +71,14 @@ struct StatisticalSample
size_t size_x{0};
size_t size_y{0};
void addX(X value, Arena * arena) {
void addX(X value, Arena * arena)
{
++size_x;
x.push_back(value, arena);
}
void addY(Y value, Arena * arena) {
void addY(Y value, Arena * arena)
{
++size_y;
y.push_back(value, arena);
}
@ -97,7 +99,7 @@ struct StatisticalSample
buf.write(reinterpret_cast<const char *>(y.data()), size_y * sizeof(y[0]));
}
void read(ReadBuffer & buf, Arena * arena)
void read(ReadBuffer & buf, Arena * arena)
{
readVarUInt(size_x, buf);
readVarUInt(size_y, buf);

View File

@ -20,8 +20,8 @@ TEST(Ranks, Simple)
ASSERT_EQ(ranks.size(), expected.size());
for (size_t i = 0; i < ranks.size(); ++i) {
for (size_t i = 0; i < ranks.size(); ++i)
ASSERT_DOUBLE_EQ(ranks[i], expected[i]);
}
ASSERT_DOUBLE_EQ(t, 0.9975296442687747);
}

View File

@ -29,6 +29,7 @@ SRCS(
AggregateFunctionHistogram.cpp
AggregateFunctionIf.cpp
AggregateFunctionMLMethod.cpp
AggregateFunctionMannWhitney.cpp
AggregateFunctionMaxIntersections.cpp
AggregateFunctionMerge.cpp
AggregateFunctionMinMaxAny.cpp