From 02ce3ed4e714c5224a529fbb81b5a52bd2625529 Mon Sep 17 00:00:00 2001 From: nikitamikhaylov Date: Thu, 12 Nov 2020 22:17:15 +0300 Subject: [PATCH] style --- .../AggregateFunctionMannWhitney.cpp | 2 +- .../AggregateFunctionMannWhitney.h | 31 ++++++++++--------- .../AggregateFunctionRankCorrelation.cpp | 2 +- .../AggregateFunctionRankCorrelation.h | 3 +- src/AggregateFunctions/StatCommon.h | 22 +++++++------ src/AggregateFunctions/tests/gtest_ranks.cpp | 4 +-- src/AggregateFunctions/ya.make | 1 + 7 files changed, 35 insertions(+), 30 deletions(-) diff --git a/src/AggregateFunctions/AggregateFunctionMannWhitney.cpp b/src/AggregateFunctions/AggregateFunctionMannWhitney.cpp index ef68ea5a9a1..b5fd39a451e 100644 --- a/src/AggregateFunctions/AggregateFunctionMannWhitney.cpp +++ b/src/AggregateFunctions/AggregateFunctionMannWhitney.cpp @@ -22,7 +22,7 @@ AggregateFunctionPtr createAggregateFunctionMannWhitneyUTest(const std::string & if (!isNumber(argument_types[0]) || !isNumber(argument_types[1])) throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED); - + return std::make_shared>(argument_types, parameters); } diff --git a/src/AggregateFunctions/AggregateFunctionMannWhitney.h b/src/AggregateFunctions/AggregateFunctionMannWhitney.h index cbfe19ccf5d..7dbc7722498 100644 --- a/src/AggregateFunctions/AggregateFunctionMannWhitney.h +++ b/src/AggregateFunctions/AggregateFunctionMannWhitney.h @@ -37,7 +37,7 @@ namespace ErrorCodes template struct MannWhitneyData : public StatisticalSample { - enum class Alternative + enum class Alternative { TwoSided, Less, @@ -58,7 +58,7 @@ struct MannWhitneyData : public StatisticalSample const Float64 n2 = this->size_y; Float64 r1 = 0; - for (size_t i = 0; i < n1; ++i) + for (size_t i = 0; i < n1; ++i) r1 += ranks[i]; const Float64 u1 = n1 * n2 + (n1 * (n1 + 1.)) / 2. - r1; @@ -74,7 +74,7 @@ struct MannWhitneyData : public StatisticalSample u = u1; else if (alternative == Alternative::Greater) u = u2; - + const Float64 z = (u - meanrank) / sd; const Float64 cdf = integrateSimpson(0, z, [] (Float64 t) { return std::pow(M_E, -0.5 * t * t) / std::sqrt(2 * M_PI);}); @@ -89,14 +89,15 @@ struct MannWhitneyData : public StatisticalSample private: /// We need to compute ranks according to all samples. Use this class to avoid extra copy and memory allocation. - class ConcatenatedSamples { + class ConcatenatedSamples + { public: - ConcatenatedSamples(const Sample & first_, const Sample & second_) + ConcatenatedSamples(const Sample & first_, const Sample & second_) : first(first_), second(second_) {} const T & operator[](size_t ind) const { - if (ind < first.size()) + if (ind < first.size()) return first[ind]; return second[ind % first.size()]; } @@ -128,7 +129,8 @@ public: if (params.size() > 2) throw Exception("Aggregate function " + getName() + " require two parameter or less", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - if (params.empty()) { + if (params.empty()) + { alternative = Alternative::TwoSided; return; } @@ -143,17 +145,16 @@ public: alternative = Alternative::Less; else if (param == "greater") alternative = Alternative::Greater; - else - throw Exception("Unknown parameter in aggregate function " + getName() + + else + throw Exception("Unknown parameter in aggregate function " + getName() + ". It must be one of: 'two sided', 'less', 'greater'", ErrorCodes::BAD_ARGUMENTS); - - if (params.size() != 2) { + + if (params.size() != 2) return; - } - + if (params[1].getType() != Field::Types::UInt64) throw Exception("Aggregate function " + getName() + " require require second parameter to be a UInt64", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - + continuity_correction = static_cast(params[1].get()); } @@ -215,7 +216,7 @@ public: { if (!this->data(place).size_x || !this->data(place).size_y) throw Exception("Aggregate function " + getName() + " require both samples to be non empty", ErrorCodes::BAD_ARGUMENTS); - + auto [u_statistic, p_value] = this->data(place).getResult(alternative, continuity_correction); /// Because p-value is a probability. diff --git a/src/AggregateFunctions/AggregateFunctionRankCorrelation.cpp b/src/AggregateFunctions/AggregateFunctionRankCorrelation.cpp index f8a778fa002..87fc24f8f98 100644 --- a/src/AggregateFunctions/AggregateFunctionRankCorrelation.cpp +++ b/src/AggregateFunctions/AggregateFunctionRankCorrelation.cpp @@ -23,7 +23,7 @@ AggregateFunctionPtr createAggregateFunctionRankCorrelation(const std::string & if (!isNumber(argument_types[0]) || !isNumber(argument_types[1])) throw Exception("Aggregate function " + name + " only supports numerical types", ErrorCodes::NOT_IMPLEMENTED); - + return std::make_shared(argument_types); } diff --git a/src/AggregateFunctions/AggregateFunctionRankCorrelation.h b/src/AggregateFunctions/AggregateFunctionRankCorrelation.h index 99f5e6f141a..bdec03d5975 100644 --- a/src/AggregateFunctions/AggregateFunctionRankCorrelation.h +++ b/src/AggregateFunctions/AggregateFunctionRankCorrelation.h @@ -22,7 +22,8 @@ namespace DB struct RankCorrelationData : public StatisticalSample { - Float64 getResult() { + Float64 getResult() + { RanksArray ranks_x; std::tie(ranks_x, std::ignore) = computeRanksAndTieCorrection(this->x); diff --git a/src/AggregateFunctions/StatCommon.h b/src/AggregateFunctions/StatCommon.h index 9376cae15dd..437b3bbf4d2 100644 --- a/src/AggregateFunctions/StatCommon.h +++ b/src/AggregateFunctions/StatCommon.h @@ -8,7 +8,7 @@ #include #include -namespace DB +namespace DB { template @@ -29,7 +29,8 @@ static Float64 integrateSimpson(Float64 a, Float64 b, F && func) using RanksArray = std::vector; template -std::pair computeRanksAndTieCorrection(const Values & values) { +std::pair computeRanksAndTieCorrection(const Values & values) +{ const size_t size = values.size(); /// Save initial positions, than sort indices according to the values. std::vector indexes(size); @@ -40,17 +41,16 @@ std::pair computeRanksAndTieCorrection(const Values & value size_t left = 0; Float64 tie_numenator = 0; RanksArray out(size); - while (left < size) { + while (left < size) + { size_t right = left; - while (right < size && values[indexes[left]] == values[indexes[right]]) { + while (right < size && values[indexes[left]] == values[indexes[right]]) ++right; - } auto adjusted = (left + right + 1.) / 2.; auto count_equal = right - left; tie_numenator += std::pow(count_equal, 3) - count_equal; - for (size_t iter = left; iter < right; ++iter) { + for (size_t iter = left; iter < right; ++iter) out[indexes[iter]] = adjusted; - } left = right; } return {out, 1 - (tie_numenator / (std::pow(size, 3) - size))}; @@ -71,12 +71,14 @@ struct StatisticalSample size_t size_x{0}; size_t size_y{0}; - void addX(X value, Arena * arena) { + void addX(X value, Arena * arena) + { ++size_x; x.push_back(value, arena); } - void addY(Y value, Arena * arena) { + void addY(Y value, Arena * arena) + { ++size_y; y.push_back(value, arena); } @@ -97,7 +99,7 @@ struct StatisticalSample buf.write(reinterpret_cast(y.data()), size_y * sizeof(y[0])); } - void read(ReadBuffer & buf, Arena * arena) + void read(ReadBuffer & buf, Arena * arena) { readVarUInt(size_x, buf); readVarUInt(size_y, buf); diff --git a/src/AggregateFunctions/tests/gtest_ranks.cpp b/src/AggregateFunctions/tests/gtest_ranks.cpp index 4e289716705..b29271cbec7 100644 --- a/src/AggregateFunctions/tests/gtest_ranks.cpp +++ b/src/AggregateFunctions/tests/gtest_ranks.cpp @@ -20,8 +20,8 @@ TEST(Ranks, Simple) ASSERT_EQ(ranks.size(), expected.size()); - for (size_t i = 0; i < ranks.size(); ++i) { + for (size_t i = 0; i < ranks.size(); ++i) ASSERT_DOUBLE_EQ(ranks[i], expected[i]); - } + ASSERT_DOUBLE_EQ(t, 0.9975296442687747); } diff --git a/src/AggregateFunctions/ya.make b/src/AggregateFunctions/ya.make index f5e64f1471b..ea36a6acd91 100644 --- a/src/AggregateFunctions/ya.make +++ b/src/AggregateFunctions/ya.make @@ -29,6 +29,7 @@ SRCS( AggregateFunctionHistogram.cpp AggregateFunctionIf.cpp AggregateFunctionMLMethod.cpp + AggregateFunctionMannWhitney.cpp AggregateFunctionMaxIntersections.cpp AggregateFunctionMerge.cpp AggregateFunctionMinMaxAny.cpp